import os
import sys
import time
import tempfile
import base64
import io
import warnings
import threading
from typing import Optional

import numpy as np
import torch
import nibabel as nib
from PIL import Image
from torch.amp import autocast
from scipy import ndimage

import model_loader
from processing import (
    preprocess_nifti,
    calculate_liver_volume,
    analyze_liver_morphology,
    generate_medical_report,
    refine_liver_mask,
    refine_liver_mask_enhanced,
)

warnings.filterwarnings('ignore', category=UserWarning, message='.*non-tuple sequence for multidimensional indexing.*')
warnings.filterwarnings('ignore', category=FutureWarning, message='.*non-tuple sequence.*')

# Guards against concurrent inference runs; the loaded models are shared module state.
PROCESSING_LOCK = threading.Lock()


def adjust_roi_for_volume(volume_shape):
    """Shrink the sliding-window ROI and overlap in place to fit the input volume."""
    if model_loader.WINDOW_INFER is None:
        return
    if len(volume_shape) < 4:
        return

    if len(volume_shape) == 5:
        _, _, depth, height, width = volume_shape
    elif len(volume_shape) == 4:
        _, depth, height, width = volume_shape
    else:
        depth, height, width = volume_shape[-3:]

    roi_size = model_loader.WINDOW_INFER.roi_size
    current_roi = list(roi_size) if isinstance(roi_size, (list, tuple)) else [roi_size[0], roi_size[1], roi_size[2]]
    original_roi = current_roi.copy()
    roi_d, roi_h, roi_w = current_roi
    adjusted = False

    # Clamp each ROI dimension to the corresponding volume dimension.
    if roi_d > depth:
        current_roi[0] = min(depth, 64)
        adjusted = True
        print(f" ⚠ ROI depth ({roi_d}) > volume depth ({depth}). Adjusting to {current_roi[0]}")
    if roi_h > height:
        current_roi[1] = min(height, current_roi[1])
        adjusted = True
        print(f" ⚠ ROI height ({roi_h}) > volume height ({height}). Adjusting to {current_roi[1]}")
    if roi_w > width:
        current_roi[2] = min(width, current_roi[2])
        adjusted = True
        print(f" ⚠ ROI width ({roi_w}) > volume width ({width}). Adjusting to {current_roi[2]}")

    # Reduce window overlap on large volumes to cut the number of windows.
    total_volume = depth * height * width
    if total_volume > 20_000_000:
        overlap = 0.1
        if model_loader.WINDOW_INFER.overlap > overlap:
            model_loader.WINDOW_INFER.overlap = overlap
            adjusted = True
            print(f" → Large volume detected ({total_volume:,} voxels). Reducing overlap to {overlap} for faster processing")
    elif total_volume > 10_000_000:
        overlap = 0.12
        if model_loader.WINDOW_INFER.overlap > overlap:
            model_loader.WINDOW_INFER.overlap = overlap
            adjusted = True
            print(f" → Medium-large volume detected ({total_volume:,} voxels). Reducing overlap to {overlap}")

    if depth < 64 and current_roi[0] > depth * 0.8:
        current_roi[0] = max(32, int(depth * 0.8))
        adjusted = True
        print(f" → Small depth dimension ({depth}). Optimizing ROI depth to {current_roi[0]}")

    if adjusted:
        model_loader.WINDOW_INFER.roi_size = tuple(current_roi)
        num_windows_approx = (
            (depth / current_roi[0]) * (height / current_roi[1]) * (width / current_roi[2])
        ) / (1 - model_loader.WINDOW_INFER.overlap)
        print(f" ✓ ROI adjusted: {original_roi} → {current_roi}")
        print(f" → Estimated windows: ~{int(num_windows_approx)} (overlap={model_loader.WINDOW_INFER.overlap:.2f})")
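# Example (hypothetical values): with a MONAI-style SlidingWindowInferer configured as
#
#     model_loader.WINDOW_INFER = SlidingWindowInferer(
#         roi_size=(64, 256, 256), sw_batch_size=2, overlap=0.25
#     )
#
# a thin axial volume such as (1, 1, 40, 512, 512) would get its ROI depth clamped
# (40 < 64) and, at 512 * 512 * 40 ≈ 10.5M voxels, its overlap lowered to 0.12:
#
#     adjust_roi_for_volume((1, 1, 40, 512, 512))
#
# This is an illustrative sketch; the actual inferer class and default settings
# come from model_loader's configuration, not from this module.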
def predict_volume(nifti_file, modality, slice_idx=None):
    if not PROCESSING_LOCK.acquire(blocking=False):
        return None, "**Already processing a file. Please wait for the current inference to complete.**", "", None

    if nifti_file is None:
        PROCESSING_LOCK.release()
        return None, "Please upload a NIfTI file (.nii.gz)", "", None

    image_tensor = None
    pred = None
    output_path = None
    try:
        print(f"Starting prediction for modality: {modality}")
        if modality == 'T1':
            if model_loader.MODEL_T1 is None:
                print("Loading T1 model_loader...")
                model_loader.MODEL_T1 = model_loader.load_model('T1')
            model_instance = model_loader.MODEL_T1
        else:
            if model_loader.MODEL_T2 is None:
                print("Loading T2 model_loader...")
                model_loader.MODEL_T2 = model_loader.load_model('T2')
            model_instance = model_loader.MODEL_T2

        print("Preprocessing NIfTI file...")
        from pathlib import Path
        fp = getattr(nifti_file, "name", nifti_file)
        file_path = str(Path(fp).resolve())
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Uploaded file not found: {file_path}")

        nifti_img = nib.load(file_path)
        nifti_img = nib.as_closest_canonical(nifti_img)
        voxel_spacing = nifti_img.header.get_zooms()[:3] if len(nifti_img.header.get_zooms()) >= 3 else (1.0, 1.0, 1.0)
        aff = nifti_img.affine
        hdr = nifti_img.header.copy()
        original_data = nifti_img.get_fdata(dtype=np.float32)

        if model_loader.DEVICE.type == 'cuda':
            if not torch.cuda.is_available():
                raise RuntimeError("CUDA not available but device is set to CUDA")
            torch.cuda.empty_cache()
            free_mem = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / (1024**3)
            if free_mem < 0.5:
                raise RuntimeError(f"Insufficient GPU memory: {free_mem:.2f} GB free. Please restart.")

        image_data = preprocess_nifti(file_path, device=model_loader.DEVICE)
        # Ensure a 5D batch/channel layout for sliding-window inference.
        if len(image_data.shape) == 3:
            image_tensor = image_data.unsqueeze(0).unsqueeze(0)
        elif len(image_data.shape) == 4:
            image_tensor = image_data.unsqueeze(0)
        else:
            image_tensor = image_data

        if model_loader.DEVICE.type == 'cuda' and len(image_tensor.shape) >= 4:
            try:
                image_tensor = image_tensor.contiguous(memory_format=torch.channels_last_3d)
                if image_tensor.is_contiguous(memory_format=torch.channels_last_3d):
                    print(" → Using channels-last 3D memory layout (optimized for GPU)")
            except Exception:
                pass

        print(f"Input tensor shape: {image_tensor.shape}")
        if model_loader.WINDOW_INFER is not None:
            print("Adjusting ROI size based on volume dimensions...")
            adjust_roi_for_volume(image_tensor.shape)

        print("Running inference...")
        inference_start = time.time()
        y1, y2, y3, y4 = None, None, None, None
        gpu_handle = None

        if model_loader.DEVICE.type == 'cuda':
            allocated = torch.cuda.memory_allocated(0) / (1024**3)
            free_memory_gb = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / (1024**3)
            print(f" → Pre-inference GPU memory: {allocated:.2f} GB allocated, {free_memory_gb:.2f} GB free")
            try:
                import pynvml
                pynvml.nvmlInit()
                gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(0)
                util = pynvml.nvmlDeviceGetUtilizationRates(gpu_handle)
                print(f" → GPU utilization: {util.gpu}%")
            except Exception:
                pass

            if free_memory_gb < 1.0:
                print(f" ⚠ WARNING: Very low free memory ({free_memory_gb:.2f} GB). Adjusting settings...")
                if model_loader.WINDOW_INFER is not None:
                    model_loader.WINDOW_INFER.sw_batch_size = 1
                    if free_memory_gb < 0.5:
                        model_loader.WINDOW_INFER.roi_size = [192, 192, 32]
                        model_loader.WINDOW_INFER.overlap = 0.25
                if free_memory_gb < 0.5:
                    torch.cuda.empty_cache()
                    free_memory_gb = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / (1024**3)
                    print(f" → After cache clear: {free_memory_gb:.2f} GB free")
                    if free_memory_gb < 0.5:
                        error_msg = (
                            f"Insufficient GPU memory ({free_memory_gb:.2f} GB free). "
                            f"The model is using {allocated:.2f} GB. Please use a smaller input volume."
                        )
                        print(f" ✗ {error_msg}")
                        raise RuntimeError(error_msg)
        import config

        # Watchdog: a timer sets this flag once the configured timeout elapses;
        # the inference path checks it before and after the model call.
        timeout_occurred = threading.Event()

        def timeout_handler():
            timeout_occurred.set()

        timeout_timer = threading.Timer(config.INFERENCE_TIMEOUT, timeout_handler)
        timeout_timer.start()

        try:
            with torch.no_grad():
                if model_loader.DEVICE.type == 'cuda':
                    torch.cuda.synchronize()
                    print(" → AMP enabled: YES")
                    sys.stdout.flush()
                    try:
                        with autocast(device_type='cuda'):
                            print(f" → Starting inference (timeout: {config.INFERENCE_TIMEOUT}s)...")
                            sys.stdout.flush()
                            start_time = time.time()
                            if timeout_occurred.is_set():
                                raise TimeoutError(f"Inference timeout: {config.INFERENCE_TIMEOUT}s")

                            # Sample GPU utilization once per second in a background thread.
                            gpu_utils = []
                            if gpu_handle is not None:
                                monitor_active = threading.Event()
                                monitor_active.set()

                                def monitor_gpu():
                                    while monitor_active.is_set():
                                        try:
                                            import pynvml
                                            util = pynvml.nvmlDeviceGetUtilizationRates(gpu_handle)
                                            gpu_utils.append(util.gpu)
                                            time.sleep(1.0)
                                        except Exception:
                                            break

                                monitor_thread = threading.Thread(target=monitor_gpu, daemon=True)
                                monitor_thread.start()

                            torch.cuda.synchronize()
                            inference_start_sync = time.time()
                            y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)

                            if gpu_handle is not None:
                                monitor_active.clear()
                                if monitor_thread.is_alive():
                                    monitor_thread.join(timeout=2.0)

                            if timeout_occurred.is_set():
                                raise TimeoutError(f"Inference timeout: {config.INFERENCE_TIMEOUT}s")

                            torch.cuda.synchronize()
                            inference_time = time.time() - inference_start_sync
                            elapsed = time.time() - start_time

                            if gpu_handle is not None and len(gpu_utils) > 0:
                                try:
                                    import pynvml
                                    final_util = pynvml.nvmlDeviceGetUtilizationRates(gpu_handle)
                                    avg_util = sum(gpu_utils) / len(gpu_utils)
                                    max_util = max(gpu_utils)
                                    print(f" ✓ Inference completed in {inference_time:.1f}s (total: {elapsed:.1f}s)")
                                    print(f" → GPU utilization: avg={avg_util:.1f}%, max={max_util:.1f}%, final={final_util.gpu}%")
                                    if avg_util < 50:
                                        print(f" ⚠ WARNING: Low GPU utilization ({avg_util:.1f}%). May be CPU/transfer-bound.")
                                        print(" → Consider: 1) Install CUDA extensions (mamba_ssm, selective_scan_cuda_oflex)")
                                        print(" → 2) Check data transfer efficiency (channels-last enabled)")
                                        print(" → 3) Verify GPU aggregation is enabled (check logs above)")
                                except Exception:
                                    print(f" ✓ Inference completed in {inference_time:.1f}s (total: {elapsed:.1f}s)")
                            else:
                                print(f" ✓ Inference completed in {inference_time:.1f}s (total: {elapsed:.1f}s)")
                    except TimeoutError:
                        timeout_timer.cancel()
                        raise
                    except RuntimeError as e:
                        if "out of memory" in str(e):
                            print(f" GPU OOM error: {e}")
                            print(" Applying progressive OOM fallback...")
                            torch.cuda.empty_cache()
                            original_batch = model_loader.WINDOW_INFER.sw_batch_size
                            original_roi = list(model_loader.WINDOW_INFER.roi_size)
                            original_overlap = model_loader.WINDOW_INFER.overlap
                            original_device = model_loader.WINDOW_INFER.device
                            original_device_str = str(original_device)
                            fallback_applied = None
                            # Each rung raises a sentinel "Already ..." RuntimeError when it is
                            # not applicable, so the ladder falls through to the next rung on
                            # either an OOM retry failure or a sentinel.
                            try:
                                # Fallback 1: drop the sliding-window batch size to 1.
                                if original_batch > 1:
                                    print(f" Fallback 1: Reducing sw_batch_size {original_batch} -> 1")
                                    model_loader.WINDOW_INFER.sw_batch_size = 1
                                    with autocast(device_type='cuda'):
                                        y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)
                                    torch.cuda.synchronize()
                                    fallback_applied = "sw_batch_size=1"
                                    print(f" Retry successful with {fallback_applied}")
                                else:
                                    raise RuntimeError("Already at batch_size=1")
                            except RuntimeError as e2:
                                if "out of memory" not in str(e2) and not str(e2).startswith("Already"):
                                    raise
                                try:
                                    # Fallback 2: shrink the ROI depth to 48.
                                    if original_roi[0] > 48:
                                        new_depth = 48
                                        print(f" Fallback 2: Reducing ROI depth {original_roi[0]} -> {new_depth}")
                                        model_loader.WINDOW_INFER.roi_size = (new_depth, original_roi[1], original_roi[2])
                                        with autocast(device_type='cuda'):
                                            y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)
                                        torch.cuda.synchronize()
                                        fallback_applied = f"roi_depth={new_depth}"
                                        print(f" Retry successful with {fallback_applied}")
                                    else:
                                        raise RuntimeError("Already at depth=48")
                                except RuntimeError as e3:
                                    if "out of memory" not in str(e3) and not str(e3).startswith("Already"):
                                        raise
                                    try:
                                        # Fallback 3: shrink the ROI depth further, to 32.
                                        if original_roi[0] > 32:
                                            new_depth = 32
                                            print(f" Fallback 3: Reducing ROI depth {original_roi[0]} -> {new_depth}")
                                            model_loader.WINDOW_INFER.roi_size = (new_depth, original_roi[1], original_roi[2])
                                            with autocast(device_type='cuda'):
                                                y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)
                                            torch.cuda.synchronize()
                                            fallback_applied = f"roi_depth={new_depth}"
                                            print(f" Retry successful with {fallback_applied}")
                                        else:
                                            raise RuntimeError("Already at depth=32")
                                    except RuntimeError as e4:
                                        if "out of memory" not in str(e4) and not str(e4).startswith("Already"):
                                            raise
                                        try:
                                            # Fallback 4: aggregate sliding-window outputs on the CPU.
                                            # Compare via the string form: torch.device vs. str
                                            # equality is unreliable across PyTorch versions.
                                            if original_device_str.startswith('cuda'):
                                                print(" Fallback 4: Switching aggregation to CPU")
                                                model_loader.WINDOW_INFER.device = torch.device('cpu')
                                                with autocast(device_type='cuda'):
                                                    y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)
                                                torch.cuda.synchronize()
                                                fallback_applied = "cpu_aggregation"
                                                print(f" Retry successful with {fallback_applied}")
                                            else:
                                                raise RuntimeError("Already using CPU aggregation")
                                        except RuntimeError as e5:
                                            model_loader.WINDOW_INFER.sw_batch_size = original_batch
                                            model_loader.WINDOW_INFER.roi_size = tuple(original_roi)
                                            model_loader.WINDOW_INFER.overlap = original_overlap
                                            model_loader.WINDOW_INFER.device = original_device if isinstance(original_device, torch.device) else torch.device(original_device_str)
                                            allocated = torch.cuda.memory_allocated(0) / (1024**3)
                                            total_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
                                            free_memory_gb = total_gb - allocated
                                            error_msg = (
                                                f"GPU out of memory after all fallbacks. GPU: {allocated:.2f} GB / {total_gb:.2f} GB "
                                                f"({100*allocated/total_gb:.1f}% full), {free_memory_gb:.2f} GB free. "
                                                f"Please restart Space or use smaller input. Error: {e5}"
                                            )
                                            print(f" {error_msg}")
                                            raise RuntimeError(error_msg)
                            finally:
                                if fallback_applied:
                                    print(f" Note: Applied fallback ({fallback_applied}). Restoring original settings after inference.")
                                    model_loader.WINDOW_INFER.sw_batch_size = original_batch
                                    model_loader.WINDOW_INFER.roi_size = tuple(original_roi)
                                    model_loader.WINDOW_INFER.overlap = original_overlap
                                    model_loader.WINDOW_INFER.device = original_device
                        else:
                            raise
                else:
                    print(" → AMP: Not available on CPU")
                    print(" → Starting sliding window inference (this may take 5-10 minutes on CPU)...")
                    start_time = time.time()
                    y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)
                    elapsed = time.time() - start_time
                    print(f" ✓ Inference completed in {elapsed:.1f}s")
        finally:
            timeout_timer.cancel()

        if y1 is not None:
            print(f" → Raw model output shapes: y1={y1.shape if hasattr(y1, 'shape') else 'N/A'}, y2={y2.shape if hasattr(y2, 'shape') else 'N/A'}")
            sys.stdout.flush()
        else:
            print(" → Inference failed before model output was generated")
            sys.stdout.flush()
            raise RuntimeError("Inference failed: model did not produce output")

        # The first decoder output is the primary segmentation head.
        pred = y1
        if isinstance(pred, torch.Tensor):
            raw_min = pred.min().item()
            raw_max = pred.max().item()
            raw_mean = pred.mean().item()
            print(f" → Raw logits (before sigmoid): min={raw_min:.6f}, max={raw_max:.6f}, mean={raw_mean:.6f}")
            sys.stdout.flush()

        print(" → Applying sigmoid activation (confirming: using sigmoid, not softmax)")
        sys.stdout.flush()
        pred = torch.sigmoid(pred)

        # Collapse batch/channel dimensions down to a 3D probability volume.
        if len(pred.shape) == 5:
            if pred.shape[1] > 1:
                print(f" ⚠ WARNING: Multi-channel output detected (C={pred.shape[1]}). Using channel 0 for binary segmentation.")
            pred = pred[0, 0]
        elif len(pred.shape) == 4:
            pred = pred[0] if pred.shape[0] == 1 else pred

        print(f" → Final prediction shape after sigmoid: {pred.shape}")
        sys.stdout.flush()

        inference_time = time.time() - inference_start
        print(f"✓ Inference completed in {inference_time:.1f} seconds")
        del y1, y2, y3, y4
        if model_loader.DEVICE.type == 'cuda' and (torch.cuda.memory_allocated(0) / (1024**3)) > 40:
            torch.cuda.empty_cache()

        print("Post-processing predictions...")
        sys.stdout.flush()
        pred = pred.detach().float().cpu()
        if pred.ndim == 5:
            pred_np = pred[0, 0].numpy()
        elif pred.ndim == 4 and pred.shape[0] == 1:
            pred_np = pred[0].numpy()
        else:
            pred_np = pred.numpy()

        pred_max = pred_np.max()
        pred_min = pred_np.min()
        pred_mean = pred_np.mean()
        if pred_max < 0.3:
            print(f" ⚠ WARNING: Maximum prediction value is very low ({pred_max:.4f}). Model confidence may be low.")
            print(" This can occur with:")
            print(" - Low resolution/compressed input (lost texture cues)")
            print(" - Normalization mismatch (input not properly preprocessed)")
            print(" - Missing metadata (wrong voxel spacing/scaling)")
        if pred_mean < 0.1:
            print(f" ⚠ WARNING: Average prediction is very low ({pred_mean:.4f}). Most voxels have low confidence.")
            print(" Consider checking input quality and preprocessing.")
        # Pick a threshold by grid search: prefer thresholds whose mask fraction
        # falls in the modality's expected liver range and whose largest connected
        # component dominates (fragmentation is penalized in the score).
        total_voxels = pred_np.size
        if modality.upper() == "T1":
            candidates = np.linspace(0.60, 0.80, 10)
            frac_min, frac_max = 0.015, 0.040
        else:
            candidates = np.linspace(0.30, 0.70, 9)
            frac_min, frac_max = 0.008, 0.040

        best = None
        for t in candidates:
            mb = (pred_np > t).astype(np.uint8)
            if mb.sum() < 1000:
                continue
            frac = mb.mean()
            if not (frac_min <= frac <= frac_max):
                continue
            lbl, n = ndimage.label(mb)
            if n == 0:
                continue
            sizes = ndimage.sum(mb, lbl, range(1, n + 1))
            score = sizes.max() - (n - 1) * 1e5
            if best is None or score > best[0]:
                best = (score, t, mb)

        default_threshold = (
            float(os.environ.get("T1_THRESHOLD", "0.65"))
            if modality.upper() == "T1"
            else float(os.environ.get("SEGMENTATION_THRESHOLD", "0.5"))
        )
        if best:
            _, threshold, pred_binary = best
            print(f" Grid search selected threshold: {threshold:.3f}")
        else:
            threshold = default_threshold
            pred_binary = (pred_np > threshold).astype(np.uint8)
            print(f" Grid search failed, using default threshold: {threshold:.3f}")

        # If nothing survived the threshold, progressively relax it.
        count_default = pred_binary.sum()
        count_035 = (pred_np > 0.35).astype(np.uint8).sum()
        count_03 = (pred_np > 0.3).astype(np.uint8).sum()
        if count_default == 0 and default_threshold >= 0.5:
            if count_035 > 0:
                print(f" ⚠ WARNING: No voxels > {default_threshold}, but {count_035:,} voxels > 0.35. Trying threshold 0.35...")
                threshold = 0.35
            elif count_03 > 0:
                print(f" ⚠ WARNING: No voxels > {default_threshold} or 0.35, but {count_03:,} voxels > 0.3. Trying threshold 0.3...")
                threshold = 0.3
            else:
                percentile_thresh = float(np.quantile(pred_np.flatten(), 0.995))
                threshold = max(0.3, percentile_thresh)
                print(f" ⚠ WARNING: No voxels > {default_threshold}. Using percentile threshold: {threshold:.4f}")
            pred_binary = (pred_np > threshold).astype(np.uint8)
        elif count_default == 0 and default_threshold < 0.5:
            if count_03 > 0:
                print(f" ⚠ WARNING: No voxels > {default_threshold}, but {count_03:,} voxels > 0.3. Trying threshold 0.3...")
                threshold = 0.3
            else:
                percentile_thresh = float(np.quantile(pred_np.flatten(), 0.995))
                threshold = max(0.3, percentile_thresh)
                print(f" ⚠ WARNING: No voxels > {default_threshold}. Using percentile threshold: {threshold:.4f}")
            pred_binary = (pred_np > threshold).astype(np.uint8)
        # T1 intensity gate: damp predictions whose intensities fall outside a
        # liver-like range estimated from the right upper quadrant, or suppress
        # stomach-like intensities when the input is pre-normalized to [0, 1].
        if modality.upper() == "T1":
            try:
                h, w, d = original_data.shape
                is_prenormalized = (original_data.max() <= 1.0 and original_data.min() >= 0.0)
                if not is_prenormalized:
                    right_upper_quadrant = original_data[h//4:3*h//4, w//2:, d//3:2*d//3]
                    if right_upper_quadrant.size > 0 and (right_upper_quadrant > 0).sum() > 0:
                        median_intensity = np.median(right_upper_quadrant[right_upper_quadrant > 0])
                        mad = np.median(np.abs(right_upper_quadrant[right_upper_quadrant > 0] - median_intensity))
                        k = 1.5
                        intensity_lower = median_intensity - k * mad
                        intensity_upper = median_intensity + k * mad
                        if intensity_upper - intensity_lower > 0.5:
                            if len(pred_np.shape) == 4:
                                pred_3d = pred_np[0, 0]
                            elif len(pred_np.shape) == 5:
                                pred_3d = pred_np[0, 0, 0]
                            else:
                                pred_3d = pred_np
                            if pred_3d.shape == original_data.shape:
                                intensity_mask = (original_data >= intensity_lower) & (original_data <= intensity_upper)
                                pred_3d_clamped = pred_3d.copy()
                                pred_3d_clamped[~intensity_mask] = np.minimum(pred_3d_clamped[~intensity_mask], threshold * 0.6)
                                if pred_3d_clamped.std() > 1e-6:
                                    if len(pred_np.shape) == 4:
                                        pred_np[0, 0] = pred_3d_clamped
                                    elif len(pred_np.shape) == 5:
                                        pred_np[0, 0, 0] = pred_3d_clamped
                                    else:
                                        pred_np = pred_3d_clamped
                                    print(f" Intensity gate: Clamped predictions outside liver-like intensity range [{intensity_lower:.2f}, {intensity_upper:.2f}]")
                            else:
                                print(f" Intensity gate: Shape mismatch (pred: {pred_3d.shape}, original: {original_data.shape}). Skipping intensity gate.")
                else:
                    left_upper_region = original_data[h//4:3*h//4, :w//2, d//3:2*d//3]
                    if left_upper_region.size > 0 and (left_upper_region > 0).sum() > 100:
                        stomach_like_intensity = np.percentile(left_upper_region[left_upper_region > 0], 75)
                        if len(pred_np.shape) == 4:
                            pred_3d = pred_np[0, 0]
                        elif len(pred_np.shape) == 5:
                            pred_3d = pred_np[0, 0, 0]
                        else:
                            pred_3d = pred_np
                        if pred_3d.shape == original_data.shape:
                            stomach_mask = (original_data > stomach_like_intensity * 0.9) & (original_data < stomach_like_intensity * 1.1)
                            pred_3d_clamped = pred_3d.copy()
                            pred_3d_clamped[stomach_mask] = np.minimum(pred_3d_clamped[stomach_mask], threshold * 0.5)
                            if pred_3d_clamped.std() > 1e-6:
                                if len(pred_np.shape) == 4:
                                    pred_np[0, 0] = pred_3d_clamped
                                elif len(pred_np.shape) == 5:
                                    pred_np[0, 0, 0] = pred_3d_clamped
                                else:
                                    pred_np = pred_3d_clamped
                                print(f" Intensity gate: Suppressed stomach-like intensities (threshold: {stomach_like_intensity:.3f})")
                        else:
                            print(f" Intensity gate: Shape mismatch (pred: {pred_3d.shape}, original: {original_data.shape}). Skipping intensity gate.")
            except Exception as e:
                print(f" Intensity gate failed (non-critical): {e}")
        pred_binary = (pred_np > threshold).astype(np.uint8)
        count_after_thresh = pred_binary.sum()
        fraction = count_after_thresh / total_voxels if total_voxels > 0 else 0.0
        volume_ml = calculate_liver_volume(pred_binary, voxel_spacing)
        print(f" After threshold ({threshold:.3f}): {count_after_thresh:,} voxels ({100*fraction:.2f}%), volume: {volume_ml:.1f} ml")

        # Size-aware auto-tune: if the mask is implausibly large, nudge the
        # threshold up once and re-binarize.
        if fraction > 0.040 or volume_ml > 2200:
            threshold_max = float(os.environ.get("THRESHOLD_MAX", "0.90"))
            new_threshold = min(threshold_max, threshold + 0.05)
            if new_threshold > threshold:
                print(f" Size-aware auto-tune: Mask too large (fraction={100*fraction:.2f}%, volume={volume_ml:.1f}ml). Increasing threshold {threshold:.3f} -> {new_threshold:.3f}")
                threshold = new_threshold
                pred_binary = (pred_np > threshold).astype(np.uint8)
                count_after_thresh = pred_binary.sum()
                fraction = count_after_thresh / total_voxels if total_voxels > 0 else 0.0
                volume_ml = calculate_liver_volume(pred_binary, voxel_spacing)
                print(f" After auto-tune threshold ({threshold:.3f}): {count_after_thresh:,} voxels ({100*fraction:.2f}%), volume: {volume_ml:.1f} ml")

        if fraction < 0.005 or fraction > 0.040:
            print(f" WARNING: Mask fraction {100*fraction:.3f}% outside typical range (0.5%-4.0%). May indicate segmentation issues.")
        sys.stdout.flush()

        print("Refining liver segmentation mask...")
        sys.stdout.flush()
        pred_np_copy = pred_np.copy()
        try:
            pred_binary_refined, refinement_metrics, confidence_score = refine_liver_mask_enhanced(
                pred_binary, voxel_spacing, pred_np_copy, threshold, modality
            )
            pred_binary = pred_binary_refined
            guards_ok = refinement_metrics.get('guards_ok', True)
            print(f" Refinement applied: {refinement_metrics['connected_components_before']} -> {refinement_metrics['connected_components_after']} components")
            print(f" Volume change: {refinement_metrics['volume_change_percent']:.2f}% ({refinement_metrics['volume_change_ml']:.2f} ml)")
            print(f" Confidence score: {confidence_score:.1f}% {'(guards OK)' if guards_ok else '(guards triggered)'}")
        except Exception as e:
            print(f" Refinement failed (using original mask): {e}")
            import traceback
            traceback.print_exc()
            confidence_score = 50.0
            guards_ok = False
        del pred_np, pred_np_copy

        print("Loading original volume for visualization...")
        original_volume = original_data.copy()
        if original_volume.max() > original_volume.min():
            original_volume = (original_volume - original_volume.min()) / (original_volume.max() - original_volume.min())
        else:
            original_volume = np.zeros_like(original_volume)

        # Clamp the requested slice index (default: middle slice).
        slice_idx = max(0, min(slice_idx if slice_idx is not None else (pred_binary.shape[-1] // 2), pred_binary.shape[-1] - 1))
        if len(original_volume.shape) == 3:
            original_slice = original_volume[:, :, slice_idx]
        else:
            original_slice = original_volume[:, :, slice_idx] if original_volume.shape[2] > slice_idx else original_volume[:, :, 0]

        if len(pred_binary.shape) == 4:
            pred_slice = pred_binary[0, :, :, slice_idx] if pred_binary.shape[3] > slice_idx else pred_binary[0, :, :, 0]
        elif len(pred_binary.shape) == 5:
            pred_slice = pred_binary[0, 0, :, :, slice_idx] if pred_binary.shape[4] > slice_idx else pred_binary[0, 0, :, :, 0]
        else:
            pred_slice = pred_binary[:, :, slice_idx] if pred_binary.shape[2] > slice_idx else pred_binary[:, :, 0]

        if pred_slice.shape != original_slice.shape:
            print(f" Warning: Shape mismatch in overlay (pred_slice: {pred_slice.shape}, original_slice: {original_slice.shape}). Resizing pred_slice to match.")
            from scipy.ndimage import zoom
            zoom_factors = (original_slice.shape[0] / pred_slice.shape[0], original_slice.shape[1] / pred_slice.shape[1])
            pred_slice = zoom(pred_slice, zoom_factors, order=0)

        # Build an RGB overlay: grayscale anatomy with the mask painted green.
        overlay = np.zeros((*original_slice.shape, 3), dtype=np.uint8)
        overlay[:, :, 0] = (original_slice * 255).astype(np.uint8)
        overlay[:, :, 1] = (original_slice * 255).astype(np.uint8)
        overlay[:, :, 2] = (original_slice * 255).astype(np.uint8)
        mask = pred_slice > 0
        overlay[mask, 0] = 0
        overlay[mask, 1] = 255
        overlay[mask, 2] = 0
        overlay_img = Image.fromarray(overlay)

        print("Saving segmentation mask...")
        aff = nifti_img.affine
        hdr = nifti_img.header.copy()
        hdr.set_data_dtype(np.uint8)
        if pred_binary.ndim == 5:
            pred_save = pred_binary[0, 0]
        elif pred_binary.ndim == 4:
            pred_save = pred_binary[0]
        else:
            pred_save = pred_binary

        original_shape = original_data.shape
        pred_shape = pred_save.shape
        if pred_shape != original_shape:
            print(f" Resampling prediction from {pred_shape} to original shape {original_shape}...")
            from scipy.ndimage import zoom
            zoom_factors = tuple(orig / pred for orig, pred in zip(original_shape, pred_shape))
            pred_save = zoom(pred_save.astype(np.float32), zoom_factors, order=0).astype(np.uint8)
            print(" ✓ Resampled to match original dimensions")

        with tempfile.NamedTemporaryFile(suffix='.nii.gz', delete=False) as tmp_file:
            nifti_pred = nib.Nifti1Image(pred_save.astype(np.uint8), affine=aff, header=hdr)
            nib.save(nifti_pred, tmp_file.name)
            output_path = tmp_file.name

        total_voxels = pred_binary.size
        liver_voxels = pred_binary.sum()
        liver_percentage = (liver_voxels / total_voxels) * 100
        print(f"✓ Segmentation complete. Liver: {liver_voxels:,} voxels ({liver_percentage:.2f}%)")

        volume_ml = calculate_liver_volume(pred_binary, voxel_spacing)
        if len(pred_binary.shape) == 4:
            mask_3d = pred_binary[0]
        elif len(pred_binary.shape) == 5:
            mask_3d = pred_binary[0, 0]
        else:
            mask_3d = pred_binary
        labels_final, num_components_final = ndimage.label(mask_3d)
        morphology = analyze_liver_morphology(pred_binary)

        from processing import check_volume_sanity
        volume_sanity_status, volume_sanity_msg = check_volume_sanity(volume_ml)
        if volume_sanity_status != "OK":
            print(f" ⚠ {volume_sanity_status}: {volume_sanity_msg}")
        if num_components_final != 1:
            print(f" ⚠ WARNING: Expected 1 connected component, found {num_components_final}. Quality check recommended.")
        medical_report = generate_medical_report(
            {
                "volume_shape": list(pred_binary.shape),
                "liver_voxels": int(liver_voxels),
                "total_voxels": int(total_voxels),
                "liver_percentage": float(liver_percentage),
                "slice_index": int(slice_idx),
                "total_slices": int(pred_binary.shape[-1]),
                "modality": modality,
                "confidence_score": float(confidence_score),
            },
            volume_ml,
            morphology,
            modality,
            confidence_score,
        )

        info_text = f"""
## Segmentation Results

- **Volume Shape:** {pred_binary.shape}
- **Liver Voxels:** {liver_voxels:,} ({liver_percentage:.2f}% of volume)
- **Liver Volume:** {volume_ml:.2f} ml
- **Slice:** {slice_idx + 1} of {pred_binary.shape[-1]} shown
- **Modality:** {modality}
"""

        severity_text = medical_report['severity'].upper().replace('_', ' ')
        confidence = medical_report['measurements'].get('confidence_score', 0.0)
        report_text = f"""
## Automated Liver Segmentation Report

---

### Study Information

- **Study Date & Time:** {medical_report['study_date']}
- **Imaging Modality:** {medical_report['modality']}
- **Report Status:** **{severity_text}**
- **Confidence Score:** **{confidence:.1f}%**

---

### Key Findings

{chr(10).join(f"{finding}" for finding in medical_report['findings'])}

---

### Quantitative Measurements

**Volume Analysis:**
- **Liver Volume:** **{medical_report['measurements']['liver_volume_ml']} ml** ({medical_report['measurements']['liver_volume_liters']} L)
- **Liver Percentage of Scan:** {medical_report['measurements']['liver_percentage']}%
- **Segmented Voxels:** {medical_report['measurements']['liver_voxels']:,} / {medical_report['measurements']['total_voxels']:,} total voxels

**Morphological Analysis:**
- **Connected Components:** {medical_report['measurements']['morphology']['connected_components']}
- **Fragmentation Level:** {medical_report['measurements']['morphology']['fragmentation'].upper()}
- **Largest Component Ratio:** {medical_report['measurements']['morphology']['largest_component_ratio']*100:.1f}%

**Image Characteristics:**
- **Volume Dimensions:** {medical_report['measurements']['volume_shape']}

**Confidence Metrics:**
- **Overall Confidence:** {confidence:.1f}%
- **Confidence Interpretation:** {'High' if confidence >= 80 else 'Moderate' if confidence >= 60 else 'Low'} confidence segmentation

---

### Quality Assessment

{chr(10).join(f"- {note}" for note in medical_report.get('quality_assessment', []))}

---

### Clinical Context

{chr(10).join(f"- {note}" for note in medical_report.get('clinical_notes', [])) if medical_report.get('clinical_notes') else "- No additional clinical notes at this time."}

---

### Impression

{medical_report['impression']}

---

{f"### Recommendations{chr(10)}{chr(10)}{chr(10).join(f'- {rec}' for rec in medical_report['recommendations'])}{chr(10)}" if medical_report['recommendations'] else ""}

### Methodology

{medical_report.get('methodology', 'SRMA-Mamba deep learning model for automated liver segmentation')}

---

### Important Notice

{medical_report['disclaimer']}

---

*Report generated automatically by SRMA-Mamba Liver Segmentation System*
"""

        del original_volume, pred_binary
        if model_loader.DEVICE.type == 'cuda':
            torch.cuda.empty_cache()

        return overlay_img, info_text, report_text, output_path

    except FileNotFoundError as e:
        error_msg = f"**Checkpoint Error:** {str(e)}\n\nPlease ensure checkpoint files are uploaded to the Space."
print(f"✗ {error_msg}") import traceback traceback.print_exc() return None, f"## ❌ Error\n\n{error_msg}", "", None except RuntimeError as e: error_msg = f"**Runtime Error:** {str(e)}\n\nThis might be a GPU/CUDA issue. The model will try to use CPU instead." print(f"✗ {error_msg}") import traceback traceback.print_exc() return None, f"## ❌ Error\n\n{error_msg}", "", None except Exception as e: error_msg = f"**Error during inference:** {str(e)}\n\nPlease check the logs for more details." print(f"✗ {error_msg}") import traceback tb = traceback.format_exc() print(f"Full traceback:\n{tb}") return None, f"## ❌ Error\n\n{error_msg}\n\n**Error Type:** `{type(e).__name__}`", "", None finally: PROCESSING_LOCK.release() try: if image_tensor is not None: del image_tensor except (NameError, UnboundLocalError): pass try: if pred is not None: del pred except (NameError, UnboundLocalError): pass if model_loader.DEVICE.type == 'cuda': torch.cuda.empty_cache() print("✓ Memory cleaned up, lock released") def predict_volume_api(file_path: str, modality: str = 'T1', slice_idx: Optional[int] = None): confidence_score = 50.0 try: print("=" * 60) print(f"API: Starting prediction for modality: {modality}, file: {file_path}") print("=" * 60) if modality == 'T1': if model_loader.MODEL_T1 is None: print("API: Loading T1 model_loader...") model_loader.MODEL_T1 = model_loader.load_model('T1') model_instance = model_loader.MODEL_T1 else: if model_loader.MODEL_T2 is None: print("API: Loading T2 model_loader...") model_loader.MODEL_T2 = model_loader.load_model('T2') model_instance = model_loader.MODEL_T2 print("API: Preprocessing NIfTI file...") from pathlib import Path file_path = str(Path(file_path).resolve()) if not os.path.exists(file_path): raise FileNotFoundError(f"File not found: {file_path}") nifti_img = nib.load(file_path) nifti_img = nib.as_closest_canonical(nifti_img) voxel_spacing = nifti_img.header.get_zooms()[:3] if len(nifti_img.header.get_zooms()) >= 3 else (1.0, 1.0, 1.0) aff = nifti_img.affine hdr = nifti_img.header.copy() original_data = nifti_img.get_fdata(dtype=np.float32) if model_loader.DEVICE.type == 'cuda': if not torch.cuda.is_available(): raise RuntimeError("CUDA not available but device is set to CUDA") torch.cuda.empty_cache() free_mem = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / (1024**3) if free_mem < 0.5: raise RuntimeError(f"Insufficient GPU memory: {free_mem:.2f} GB free. Please restart.") image_data = preprocess_nifti(file_path, device=model_loader.DEVICE) if len(image_data.shape) == 3: image_tensor = image_data.unsqueeze(0).unsqueeze(0) elif len(image_data.shape) == 4: image_tensor = image_data.unsqueeze(0) else: image_tensor = image_data print(f"API: Input tensor shape: {image_tensor.shape}") if model_loader.WINDOW_INFER is not None: print("API: Adjusting ROI size based on volume dimensions...") adjust_roi_for_volume(image_tensor.shape) print("API: Running inference...") if model_loader.DEVICE.type == 'cuda': allocated = torch.cuda.memory_allocated(0) / (1024**3) free_memory_gb = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / (1024**3) print(f"API: Pre-inference GPU memory: {allocated:.2f} GB allocated, {free_memory_gb:.2f} GB free") if free_memory_gb < 0.5: print(f"API: ⚠ WARNING: Very low free memory ({free_memory_gb:.2f} GB). 
def predict_volume_api(file_path: str, modality: str = 'T1', slice_idx: Optional[int] = None):
    confidence_score = 50.0
    try:
        print("=" * 60)
        print(f"API: Starting prediction for modality: {modality}, file: {file_path}")
        print("=" * 60)
        if modality == 'T1':
            if model_loader.MODEL_T1 is None:
                print("API: Loading T1 model_loader...")
                model_loader.MODEL_T1 = model_loader.load_model('T1')
            model_instance = model_loader.MODEL_T1
        else:
            if model_loader.MODEL_T2 is None:
                print("API: Loading T2 model_loader...")
                model_loader.MODEL_T2 = model_loader.load_model('T2')
            model_instance = model_loader.MODEL_T2

        print("API: Preprocessing NIfTI file...")
        from pathlib import Path
        file_path = str(Path(file_path).resolve())
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        nifti_img = nib.load(file_path)
        nifti_img = nib.as_closest_canonical(nifti_img)
        voxel_spacing = nifti_img.header.get_zooms()[:3] if len(nifti_img.header.get_zooms()) >= 3 else (1.0, 1.0, 1.0)
        aff = nifti_img.affine
        hdr = nifti_img.header.copy()
        original_data = nifti_img.get_fdata(dtype=np.float32)

        if model_loader.DEVICE.type == 'cuda':
            if not torch.cuda.is_available():
                raise RuntimeError("CUDA not available but device is set to CUDA")
            torch.cuda.empty_cache()
            free_mem = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / (1024**3)
            if free_mem < 0.5:
                raise RuntimeError(f"Insufficient GPU memory: {free_mem:.2f} GB free. Please restart.")

        image_data = preprocess_nifti(file_path, device=model_loader.DEVICE)
        if len(image_data.shape) == 3:
            image_tensor = image_data.unsqueeze(0).unsqueeze(0)
        elif len(image_data.shape) == 4:
            image_tensor = image_data.unsqueeze(0)
        else:
            image_tensor = image_data

        print(f"API: Input tensor shape: {image_tensor.shape}")
        if model_loader.WINDOW_INFER is not None:
            print("API: Adjusting ROI size based on volume dimensions...")
            adjust_roi_for_volume(image_tensor.shape)

        print("API: Running inference...")
        if model_loader.DEVICE.type == 'cuda':
            allocated = torch.cuda.memory_allocated(0) / (1024**3)
            free_memory_gb = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / (1024**3)
            print(f"API: Pre-inference GPU memory: {allocated:.2f} GB allocated, {free_memory_gb:.2f} GB free")
            if free_memory_gb < 0.5:
                print(f"API: ⚠ WARNING: Very low free memory ({free_memory_gb:.2f} GB). Adjusting settings...")
                if model_loader.WINDOW_INFER is not None:
                    model_loader.WINDOW_INFER.sw_batch_size = 1
                    model_loader.WINDOW_INFER.roi_size = [192, 192, 32]
                    model_loader.WINDOW_INFER.overlap = 0.25
                torch.cuda.empty_cache()
                free_memory_gb = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / (1024**3)
                print(f"API: → After cache clear: {free_memory_gb:.2f} GB free")
                if free_memory_gb < 0.3:
                    error_msg = f"API: Insufficient GPU memory ({free_memory_gb:.2f} GB free). Please use a smaller input volume."
                    print(f"API: ✗ {error_msg}")
                    raise RuntimeError(error_msg)

        with torch.no_grad():
            if model_loader.DEVICE.type == 'cuda':
                print("API: AMP (Automatic Mixed Precision) enabled: YES")
                sys.stdout.flush()
                try:
                    with autocast(device_type='cuda'):
                        y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)
                except RuntimeError as e:
                    if "out of memory" in str(e):
                        print(f"API: ⚠ GPU OOM error: {e}")
                        print("API: → Clearing cache and retrying with minimal settings...")
                        torch.cuda.empty_cache()
                        torch.cuda.synchronize()
                        original_batch = model_loader.WINDOW_INFER.sw_batch_size
                        original_roi = model_loader.WINDOW_INFER.roi_size
                        original_overlap = model_loader.WINDOW_INFER.overlap
                        model_loader.WINDOW_INFER.sw_batch_size = 1
                        model_loader.WINDOW_INFER.roi_size = [192, 192, 32]
                        model_loader.WINDOW_INFER.overlap = 0.25
                        try:
                            with autocast(device_type='cuda'):
                                y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)
                            print("API: ✓ Retry successful with minimal settings")
                        except RuntimeError as e2:
                            allocated = torch.cuda.memory_allocated(0) / (1024**3)
                            total_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
                            free_memory_gb = total_gb - allocated
                            error_msg = (
                                f"GPU out of memory even with minimal settings. GPU is {allocated:.2f} GB / {total_gb:.2f} GB "
                                f"({100*allocated/total_gb:.1f}% full), only {free_memory_gb:.2f} GB free. The model requires CUDA. "
                                f"Please restart the Space or use a smaller input volume. Error: {e2}"
                            )
                            print(f"API: ✗ {error_msg}")
                            raise RuntimeError(error_msg)
                        finally:
                            # Restores settings on both success and failure.
                            model_loader.WINDOW_INFER.sw_batch_size = original_batch
                            model_loader.WINDOW_INFER.roi_size = original_roi
                            model_loader.WINDOW_INFER.overlap = original_overlap
                    else:
                        raise
            else:
                print("API: AMP: Not available on CPU")
                y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)

        print(f"API: Raw model output shapes: y1={y1.shape if hasattr(y1, 'shape') else 'N/A'}")
        sys.stdout.flush()
        pred = y1
        if isinstance(pred, torch.Tensor):
            raw_min = pred.min().item()
            raw_max = pred.max().item()
            raw_mean = pred.mean().item()
            print(f"API: Raw logits (before sigmoid): min={raw_min:.6f}, max={raw_max:.6f}, mean={raw_mean:.6f}")
            sys.stdout.flush()
        print("API: Applying sigmoid activation (confirming: using sigmoid, not softmax)")
        sys.stdout.flush()
        pred = torch.sigmoid(pred)
        if len(pred.shape) == 5:
            if pred.shape[1] > 1:
                print(f"API: ⚠ WARNING: Multi-channel output detected (C={pred.shape[1]}). Using channel 0 for binary segmentation.")
            pred = pred[0, 0]
        elif len(pred.shape) == 4:
            pred = pred[0] if pred.shape[0] == 1 else pred
        print(f"API: Final prediction shape after sigmoid: {pred.shape}")
        sys.stdout.flush()

        print("API: Post-processing predictions...")
        sys.stdout.flush()
        pred_np = pred.detach().cpu().numpy()

        print("=" * 60)
        print("API: DIAGNOSTIC: SIGMOID OUTPUT STATS (before thresholding)")
        print("=" * 60)
        print(f" Max: {pred_np.max():.6f}")
        print(f" Mean: {pred_np.mean():.6f}")
        print(f" Min: {pred_np.min():.6f}")
        print(f" Std: {pred_np.std():.6f}")
        print(f" Shape: {pred_np.shape}")
        sys.stdout.flush()

        # Threshold selection: start from the configured default and relax it
        # when the default would produce an empty mask.
        default_threshold = float(os.environ.get("SEGMENTATION_THRESHOLD", "0.5"))
        count_035 = np.sum(pred_np > 0.35)
        count_03 = np.sum(pred_np > 0.3)
        count_05 = np.sum(pred_np > 0.5)
        total_voxels = pred_np.size
        threshold = default_threshold
        if count_05 == 0 and default_threshold >= 0.5:
            if count_035 > 0:
                print(f"API: ⚠ WARNING: No voxels > 0.5, but {count_035:,} voxels > 0.35. Trying threshold 0.35...")
                threshold = 0.35
            elif count_03 > 0:
                print(f"API: ⚠ WARNING: No voxels > 0.5 or 0.35, but {count_03:,} voxels > 0.3. Trying threshold 0.3...")
                threshold = 0.3
            else:
                percentile_thresh = float(np.quantile(pred_np.flatten(), 0.995))
                threshold = max(0.3, percentile_thresh)
                print(f"API: ⚠ WARNING: No voxels > 0.5. Using percentile threshold: {threshold:.4f}")
        elif count_035 == 0 and default_threshold < 0.5:
            if count_03 > 0:
                print(f"API: ⚠ WARNING: No voxels > {default_threshold}, but {count_03:,} voxels > 0.3. Trying threshold 0.3...")
                threshold = 0.3
            elif count_05 == 0:
                percentile_thresh = float(np.quantile(pred_np.flatten(), 0.995))
                threshold = max(0.3, percentile_thresh)
                print(f"API: ⚠ WARNING: No voxels > {default_threshold}. Using percentile threshold: {threshold:.4f}")

        pred_binary = (pred_np > threshold).astype(np.uint8)
        fraction = pred_binary.sum() / total_voxels if total_voxels > 0 else 0.0
        if fraction < 0.005 or fraction > 0.040:
            print(f"API: ⚠ WARNING: Mask fraction {100*fraction:.3f}% outside typical range (0.5%-4.0%). May indicate segmentation issues.")
        count_after_thresh = pred_binary.sum()
        print(f"API: After threshold ({threshold}): {count_after_thresh:,} voxels ({100*count_after_thresh/total_voxels:.2f}%)")
        if count_05 > 0 and count_after_thresh == 0 and threshold == 0.5:
            print(f"API: ⚠ WARNING: Had {count_05:,} voxels > 0.5 but 0 after thresholding - check threshold logic!")
        sys.stdout.flush()

        print("API: Refining liver segmentation mask...")
        try:
            pred_np_copy = pred_np.copy()
            pred_binary_refined, refinement_metrics, confidence_score = refine_liver_mask_enhanced(
                pred_binary, voxel_spacing, pred_np_copy, threshold, modality
            )
            pred_binary = pred_binary_refined
            print(f"API: ✓ Refinement applied: {refinement_metrics['connected_components_before']} → {refinement_metrics['connected_components_after']} components")
            print(f"API: → Volume change: {refinement_metrics['volume_change_percent']:.2f}% ({refinement_metrics['volume_change_ml']:.2f} ml)")
            print(f"API: → Confidence score: {confidence_score:.1f}%")
        except Exception as e:
            print(f"API: ⚠ Refinement failed (using original mask): {e}")
            confidence_score = 50.0

        print("API: Loading original volume for visualization...")
        original_volume = original_data.copy()
        if original_volume.max() > original_volume.min():
            original_volume = (original_volume - original_volume.min()) / (original_volume.max() - original_volume.min())
        else:
            original_volume = np.zeros_like(original_volume)

        slice_idx = max(0, min(slice_idx if slice_idx is not None else (pred_binary.shape[-1] // 2), pred_binary.shape[-1] - 1))
        if len(original_volume.shape) == 3:
            original_slice = original_volume[:, :, slice_idx]
        else:
            original_slice = original_volume[:, :, slice_idx] if original_volume.shape[2] > slice_idx else original_volume[:, :, 0]

        if len(pred_binary.shape) == 4:
            pred_slice = pred_binary[0, :, :, slice_idx] if pred_binary.shape[3] > slice_idx else pred_binary[0, :, :, 0]
        elif len(pred_binary.shape) == 5:
            pred_slice = pred_binary[0, 0, :, :, slice_idx] if pred_binary.shape[4] > slice_idx else pred_binary[0, 0, :, :, 0]
        else:
            pred_slice = pred_binary[:, :, slice_idx] if pred_binary.shape[2] > slice_idx else pred_binary[:, :, 0]

        if pred_slice.shape != original_slice.shape:
            print(f"API: Warning: Shape mismatch in overlay (pred_slice: {pred_slice.shape}, original_slice: {original_slice.shape}). Resizing pred_slice to match.")
            from scipy.ndimage import zoom
            zoom_factors = (original_slice.shape[0] / pred_slice.shape[0], original_slice.shape[1] / pred_slice.shape[1])
            pred_slice = zoom(pred_slice, zoom_factors, order=0)

        overlay = np.zeros((*original_slice.shape, 3), dtype=np.uint8)
        overlay[:, :, 0] = (original_slice * 255).astype(np.uint8)
        overlay[:, :, 1] = (original_slice * 255).astype(np.uint8)
        overlay[:, :, 2] = (original_slice * 255).astype(np.uint8)
        mask = pred_slice > 0
        overlay[mask, 0] = 0
        overlay[mask, 1] = 255
        overlay[mask, 2] = 0
        overlay_img = Image.fromarray(overlay)

        # Encode the overlay as a base64 PNG for JSON transport.
        img_buffer = io.BytesIO()
        overlay_img.save(img_buffer, format='PNG')
        img_base64 = base64.b64encode(img_buffer.getvalue()).decode('utf-8')

        print("API: Saving segmentation mask...")
        aff = nifti_img.affine
        hdr = nifti_img.header.copy()
        hdr.set_data_dtype(np.uint8)
        if pred_binary.ndim == 5:
            pred_save = pred_binary[0, 0]
        elif pred_binary.ndim == 4:
            pred_save = pred_binary[0]
        else:
            pred_save = pred_binary

        original_shape = original_data.shape
        pred_shape = pred_save.shape
        if pred_shape != original_shape:
            print(f"API: Resampling prediction from {pred_shape} to original shape {original_shape}...")
            from scipy.ndimage import zoom
            zoom_factors = tuple(orig / pred for orig, pred in zip(original_shape, pred_shape))
            pred_save = zoom(pred_save.astype(np.float32), zoom_factors, order=0).astype(np.uint8)
            print("API: ✓ Resampled to match original dimensions")

        with tempfile.NamedTemporaryFile(suffix='.nii.gz', delete=False) as tmp_file:
            nifti_pred = nib.Nifti1Image(pred_save.astype(np.uint8), affine=aff, header=hdr)
            nib.save(nifti_pred, tmp_file.name)
            output_path = tmp_file.name

        total_voxels = pred_binary.size
        liver_voxels = pred_binary.sum()
        liver_percentage = (liver_voxels / total_voxels) * 100

        print("API: Calculating volume and morphology...")
        volume_ml = calculate_liver_volume(pred_binary, voxel_spacing)
        if len(pred_binary.shape) == 4:
            mask_3d = pred_binary[0]
        elif len(pred_binary.shape) == 5:
            mask_3d = pred_binary[0, 0]
        else:
            mask_3d = pred_binary
        labels_final, num_components_final = ndimage.label(mask_3d)
        morphology = analyze_liver_morphology(pred_binary)

        from processing import check_volume_sanity
        volume_sanity_status, volume_sanity_msg = check_volume_sanity(volume_ml)
        if volume_sanity_status != "OK":
            print(f"API: ⚠ {volume_sanity_status}: {volume_sanity_msg}")
        if num_components_final != 1:
            print(f"API: ⚠ WARNING: Expected 1 connected component, found {num_components_final}. Quality check recommended.")

        medical_report = generate_medical_report(
            {
                "volume_shape": list(pred_binary.shape),
                "liver_voxels": int(liver_voxels),
                "total_voxels": int(total_voxels),
                "liver_percentage": float(liver_percentage),
                "slice_index": int(slice_idx),
                "total_slices": int(pred_binary.shape[-1]),
                "modality": modality,
                "confidence_score": float(confidence_score),
            },
            volume_ml,
            morphology,
            modality,
            confidence_score,
        )
        print(f"API: ✓ Segmentation complete. Liver: {liver_voxels:,} voxels ({liver_percentage:.2f}%), Volume: {volume_ml:.2f} ml")
        return {
            "success": True,
            "overlay_image": f"data:image/png;base64,{img_base64}",
            "segmentation_path": output_path,
            "statistics": {
                "volume_shape": list(pred_binary.shape),
                "liver_voxels": int(liver_voxels),
                "total_voxels": int(total_voxels),
                "liver_percentage": float(liver_percentage),
                "slice_index": int(slice_idx),
                "total_slices": int(pred_binary.shape[-1]),
                "modality": modality,
                "liver_volume_ml": round(volume_ml, 2),
            },
            "medical_report": medical_report,
        }

    except FileNotFoundError as e:
        error_msg = f"Checkpoint file not found: {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return {"success": False, "error": error_msg, "error_type": "checkpoint_not_found"}
    except RuntimeError as e:
        error_msg = f"Runtime error (possibly CUDA/GPU): {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return {"success": False, "error": error_msg, "error_type": "runtime_error"}
    except Exception as e:
        error_msg = f"Error during inference: {str(e)}"
        print(error_msg)
        import traceback
        tb = traceback.format_exc()
        print(f"Full traceback: {tb}")
        return {"success": False, "error": error_msg, "error_type": type(e).__name__}


def safe_predict_volume(nifti_file, modality, slice_idx=None):
    """Defensive wrapper around predict_volume that never raises into the UI."""
    try:
        result = predict_volume(nifti_file, modality, slice_idx)
        if result is None or len(result) != 4:
            return None, "## ❌ Error\n\nUnexpected return value from prediction function.", "", None
        overlay_img, info_text, report_text, output_path = result
        if overlay_img is None and info_text and "Error" in info_text:
            return None, info_text, report_text or "", output_path or None
        return overlay_img, info_text, report_text, output_path
    except Exception as e:
        error_msg = f"**Unexpected error:** {str(e)}\n\nPlease check the logs for details."
        print(f"✗ Safe wrapper caught error: {error_msg}")
        import traceback
        traceback.print_exc()
        return None, f"## ❌ Error\n\n{error_msg}", "", None
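if __name__ == "__main__":
    # Minimal smoke test, assuming a local NIfTI file is available; the path
    # below is a hypothetical example, not a file shipped with this repo.
    result = predict_volume_api("sample_t1.nii.gz", modality="T1")
    if result["success"]:
        stats = result["statistics"]
        print(f"Liver volume: {stats['liver_volume_ml']} ml "
              f"({stats['liver_percentage']:.2f}% of {stats['total_voxels']:,} voxels)")
        print(f"Mask saved to: {result['segmentation_path']}")
    else:
        print(f"Failed ({result['error_type']}): {result['error']}")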