import os
import sys
import time
import tempfile
import base64
import io
import warnings
import threading

import numpy as np
import torch
import nibabel as nib
from PIL import Image
from typing import Optional
from torch.amp import autocast
from scipy import ndimage

import model_loader
from processing import (
    preprocess_nifti,
    calculate_liver_volume,
    analyze_liver_morphology,
    generate_medical_report,
    refine_liver_mask,
    refine_liver_mask_enhanced,
)

warnings.filterwarnings('ignore', category=UserWarning, message='.*non-tuple sequence for multidimensional indexing.*')
warnings.filterwarnings('ignore', category=FutureWarning, message='.*non-tuple sequence.*')

PROCESSING_LOCK = threading.Lock()
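# PROCESSING_LOCK is a single-slot gate: predict_volume() acquires it
# non-blocking and returns a "busy" message instead of queueing, so at most
# one sliding-window inference holds the GPU at a time.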
def adjust_roi_for_volume(volume_shape):
    if model_loader.WINDOW_INFER is None:
        return
    if len(volume_shape) < 4:
        return
    if len(volume_shape) == 5:
        _, _, depth, height, width = volume_shape
    elif len(volume_shape) == 4:
        _, depth, height, width = volume_shape
    else:
        depth, height, width = volume_shape[-3:]
    current_roi = list(model_loader.WINDOW_INFER.roi_size)
    original_roi = current_roi.copy()
    roi_d, roi_h, roi_w = current_roi
    adjusted = False
    if roi_d > depth:
        current_roi[0] = min(depth, 64)
        adjusted = True
        print(f"  → ROI depth ({roi_d}) > volume depth ({depth}). Adjusting to {current_roi[0]}")
    if roi_h > height:
        current_roi[1] = height
        adjusted = True
        print(f"  → ROI height ({roi_h}) > volume height ({height}). Adjusting to {current_roi[1]}")
    if roi_w > width:
        current_roi[2] = width
        adjusted = True
        print(f"  → ROI width ({roi_w}) > volume width ({width}). Adjusting to {current_roi[2]}")
    total_volume = depth * height * width
    if total_volume > 20_000_000:
        overlap = 0.1
        if model_loader.WINDOW_INFER.overlap > overlap:
            model_loader.WINDOW_INFER.overlap = overlap
            adjusted = True
            print(f"  → Large volume detected ({total_volume:,} voxels). Reducing overlap to {overlap} for faster processing")
    elif total_volume > 10_000_000:
        overlap = 0.12
        if model_loader.WINDOW_INFER.overlap > overlap:
            model_loader.WINDOW_INFER.overlap = overlap
            adjusted = True
            print(f"  → Medium-large volume detected ({total_volume:,} voxels). Reducing overlap to {overlap}")
    if depth < 64 and current_roi[0] > depth * 0.8:
        current_roi[0] = max(32, int(depth * 0.8))
        adjusted = True
        print(f"  → Small depth dimension ({depth}). Optimizing ROI depth to {current_roi[0]}")
    if adjusted:
        model_loader.WINDOW_INFER.roi_size = tuple(current_roi)
        overlap = model_loader.WINDOW_INFER.overlap
        # Stride per axis is roi * (1 - overlap), so windows per axis ≈ dim / stride.
        num_windows_approx = (
            (depth / (current_roi[0] * (1 - overlap)))
            * (height / (current_roi[1] * (1 - overlap)))
            * (width / (current_roi[2] * (1 - overlap)))
        )
        print(f"  → ROI adjusted: {original_roi} → {current_roi}")
        print(f"  → Estimated windows: ~{int(num_windows_approx)} (overlap={overlap:.2f})")
def predict_volume(nifti_file, modality, slice_idx=None):
    if not PROCESSING_LOCK.acquire(blocking=False):
        return None, "**Already processing a file. Please wait for the current inference to complete.**", "", None
    if nifti_file is None:
        PROCESSING_LOCK.release()
        return None, "Please upload a NIfTI file (.nii.gz)", "", None
    image_tensor = None
    pred = None
    output_path = None
    try:
        print(f"Starting prediction for modality: {modality}")
        if modality == 'T1':
            if model_loader.MODEL_T1 is None:
                print("Loading T1 model...")
                model_loader.MODEL_T1 = model_loader.load_model('T1')
            model_instance = model_loader.MODEL_T1
        else:
            if model_loader.MODEL_T2 is None:
                print("Loading T2 model...")
                model_loader.MODEL_T2 = model_loader.load_model('T2')
            model_instance = model_loader.MODEL_T2
        print("Preprocessing NIfTI file...")
        from pathlib import Path
        fp = getattr(nifti_file, "name", nifti_file)
        file_path = str(Path(fp).resolve())
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Uploaded file not found: {file_path}")
        nifti_img = nib.load(file_path)
        nifti_img = nib.as_closest_canonical(nifti_img)
        voxel_spacing = nifti_img.header.get_zooms()[:3] if len(nifti_img.header.get_zooms()) >= 3 else (1.0, 1.0, 1.0)
        aff = nifti_img.affine
        hdr = nifti_img.header.copy()
        original_data = nifti_img.get_fdata(dtype=np.float32)
        if model_loader.DEVICE.type == 'cuda':
            if not torch.cuda.is_available():
                raise RuntimeError("CUDA not available but device is set to CUDA")
            torch.cuda.empty_cache()
            free_mem = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / (1024**3)
            if free_mem < 0.5:
                raise RuntimeError(f"Insufficient GPU memory: {free_mem:.2f} GB free. Please restart.")
        image_data = preprocess_nifti(file_path, device=model_loader.DEVICE)
        if len(image_data.shape) == 3:
            image_tensor = image_data.unsqueeze(0).unsqueeze(0)
        elif len(image_data.shape) == 4:
            image_tensor = image_data.unsqueeze(0)
        else:
            image_tensor = image_data
        if model_loader.DEVICE.type == 'cuda' and len(image_tensor.shape) >= 4:
            try:
                image_tensor = image_tensor.contiguous(memory_format=torch.channels_last_3d)
                if image_tensor.is_contiguous(memory_format=torch.channels_last_3d):
                    print("  → Using channels-last 3D memory layout (optimized for GPU)")
            except Exception:
                pass
        print(f"Input tensor shape: {image_tensor.shape}")
        if model_loader.WINDOW_INFER is not None:
            print("Adjusting ROI size based on volume dimensions...")
            adjust_roi_for_volume(image_tensor.shape)
        print("Running inference...")
        inference_start = time.time()
        y1, y2, y3, y4 = None, None, None, None
        gpu_handle = None
        if model_loader.DEVICE.type == 'cuda':
            allocated = torch.cuda.memory_allocated(0) / (1024**3)
            free_memory_gb = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / (1024**3)
            print(f"  → Pre-inference GPU memory: {allocated:.2f} GB allocated, {free_memory_gb:.2f} GB free")
            try:
                import pynvml
                pynvml.nvmlInit()
                gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(0)
                util = pynvml.nvmlDeviceGetUtilizationRates(gpu_handle)
                print(f"  → GPU utilization: {util.gpu}%")
            except Exception:
                pass
            if free_memory_gb < 1.0:
                print(f"  → WARNING: Very low free memory ({free_memory_gb:.2f} GB). Adjusting settings...")
                if model_loader.WINDOW_INFER is not None:
                    model_loader.WINDOW_INFER.sw_batch_size = 1
                    if free_memory_gb < 0.5:
                        model_loader.WINDOW_INFER.roi_size = [192, 192, 32]
                        model_loader.WINDOW_INFER.overlap = 0.25
                if free_memory_gb < 0.5:
                    torch.cuda.empty_cache()
                    free_memory_gb = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / (1024**3)
                    print(f"  → After cache clear: {free_memory_gb:.2f} GB free")
                    if free_memory_gb < 0.5:
                        error_msg = f"Insufficient GPU memory ({free_memory_gb:.2f} GB free). The model is using {allocated:.2f} GB. Please use a smaller input volume."
                        print(f"  → {error_msg}")
                        raise RuntimeError(error_msg)
        import config
        timeout_occurred = threading.Event()

        def timeout_handler():
            timeout_occurred.set()

        timeout_timer = threading.Timer(config.INFERENCE_TIMEOUT, timeout_handler)
        timeout_timer.start()
        try:
            with torch.no_grad():
                if model_loader.DEVICE.type == 'cuda':
                    torch.cuda.synchronize()
                    print("  → AMP enabled: YES")
                    sys.stdout.flush()
                    try:
                        with autocast(device_type='cuda'):
                            print(f"  → Starting inference (timeout: {config.INFERENCE_TIMEOUT}s)...")
                            sys.stdout.flush()
                            start_time = time.time()
                            if timeout_occurred.is_set():
                                raise TimeoutError(f"Inference timeout: {config.INFERENCE_TIMEOUT}s")
                            gpu_utils = []
                            if gpu_handle is not None:
                                monitor_active = threading.Event()
                                monitor_active.set()

                                def monitor_gpu():
                                    # Sample utilization once a second while inference runs.
                                    while monitor_active.is_set():
                                        try:
                                            util = pynvml.nvmlDeviceGetUtilizationRates(gpu_handle)
                                            gpu_utils.append(util.gpu)
                                            time.sleep(1.0)
                                        except Exception:
                                            break

                                monitor_thread = threading.Thread(target=monitor_gpu, daemon=True)
                                monitor_thread.start()
                            torch.cuda.synchronize()
                            inference_start_sync = time.time()
                            y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)
                            if gpu_handle is not None:
                                monitor_active.clear()
                                if monitor_thread.is_alive():
                                    monitor_thread.join(timeout=2.0)
                            if timeout_occurred.is_set():
                                raise TimeoutError(f"Inference timeout: {config.INFERENCE_TIMEOUT}s")
                            torch.cuda.synchronize()
                            inference_time = time.time() - inference_start_sync
                            elapsed = time.time() - start_time
                            if gpu_handle is not None and len(gpu_utils) > 0:
                                try:
                                    final_util = pynvml.nvmlDeviceGetUtilizationRates(gpu_handle)
                                    avg_util = sum(gpu_utils) / len(gpu_utils)
                                    max_util = max(gpu_utils)
                                    print(f"  → Inference completed in {inference_time:.1f}s (total: {elapsed:.1f}s)")
                                    print(f"  → GPU utilization: avg={avg_util:.1f}%, max={max_util:.1f}%, final={final_util.gpu}%")
                                    if avg_util < 50:
                                        print(f"  → WARNING: Low GPU utilization ({avg_util:.1f}%). May be CPU/transfer-bound.")
                                        print("  →   Consider: 1) Install CUDA extensions (mamba_ssm, selective_scan_cuda_oflex)")
                                        print("  →   2) Check data transfer efficiency (channels-last enabled)")
                                        print("  →   3) Verify GPU aggregation is enabled (check logs above)")
                                except Exception:
                                    print(f"  → Inference completed in {inference_time:.1f}s (total: {elapsed:.1f}s)")
                            else:
                                print(f"  → Inference completed in {inference_time:.1f}s (total: {elapsed:.1f}s)")
                    except TimeoutError:
                        timeout_timer.cancel()
                        raise
                    except RuntimeError as e:
                        if "out of memory" in str(e):
                            print(f"  GPU OOM error: {e}")
                            print("  Applying progressive OOM fallback...")
                            torch.cuda.empty_cache()
                            original_batch = model_loader.WINDOW_INFER.sw_batch_size
                            original_roi = list(model_loader.WINDOW_INFER.roi_size)
                            original_overlap = model_loader.WINDOW_INFER.overlap
                            original_device = model_loader.WINDOW_INFER.device
                            original_device_str = str(original_device)
                            fallback_applied = None
                            try:
                                if original_batch > 1:
                                    print(f"  Fallback 1: Reducing sw_batch_size {original_batch} -> 1")
                                    model_loader.WINDOW_INFER.sw_batch_size = 1
                                    with autocast(device_type='cuda'):
                                        y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)
                                    torch.cuda.synchronize()
                                    fallback_applied = "sw_batch_size=1"
                                    print(f"  Retry successful with {fallback_applied}")
                                else:
                                    # Sentinel must contain "out of memory" so the next
                                    # fallback level does not re-raise it as unrelated.
                                    raise RuntimeError("out of memory: already at sw_batch_size=1")
                            except RuntimeError as e2:
                                if "out of memory" not in str(e2):
                                    raise
                                try:
                                    if original_roi[0] > 48:
                                        new_depth = 48
                                        print(f"  Fallback 2: Reducing ROI depth {original_roi[0]} -> {new_depth}")
                                        model_loader.WINDOW_INFER.roi_size = (new_depth, original_roi[1], original_roi[2])
                                        with autocast(device_type='cuda'):
                                            y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)
                                        torch.cuda.synchronize()
                                        fallback_applied = f"roi_depth={new_depth}"
                                        print(f"  Retry successful with {fallback_applied}")
                                    else:
                                        raise RuntimeError("out of memory: already at depth<=48")
                                except RuntimeError as e3:
                                    if "out of memory" not in str(e3):
                                        raise
                                    try:
                                        if original_roi[0] > 32:
                                            new_depth = 32
                                            print(f"  Fallback 3: Reducing ROI depth {original_roi[0]} -> {new_depth}")
                                            model_loader.WINDOW_INFER.roi_size = (new_depth, original_roi[1], original_roi[2])
                                            with autocast(device_type='cuda'):
                                                y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)
                                            torch.cuda.synchronize()
                                            fallback_applied = f"roi_depth={new_depth}"
                                            print(f"  Retry successful with {fallback_applied}")
                                        else:
                                            raise RuntimeError("out of memory: already at depth<=32")
                                    except RuntimeError as e4:
                                        if "out of memory" not in str(e4):
                                            raise
                                        try:
                                            if original_device_str.startswith('cuda'):
                                                print("  Fallback 4: Switching aggregation to CPU")
                                                model_loader.WINDOW_INFER.device = torch.device('cpu')
                                                with autocast(device_type='cuda'):
                                                    y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)
                                                torch.cuda.synchronize()
                                                fallback_applied = "cpu_aggregation"
                                                print(f"  Retry successful with {fallback_applied}")
                                            else:
                                                raise RuntimeError("out of memory: already using CPU aggregation")
                                        except RuntimeError as e5:
                                            model_loader.WINDOW_INFER.sw_batch_size = original_batch
                                            model_loader.WINDOW_INFER.roi_size = tuple(original_roi)
                                            model_loader.WINDOW_INFER.overlap = original_overlap
                                            model_loader.WINDOW_INFER.device = original_device if isinstance(original_device, torch.device) else torch.device(original_device_str)
                                            allocated = torch.cuda.memory_allocated(0) / (1024**3)
                                            total_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
                                            free_memory_gb = total_gb - allocated
                                            error_msg = f"GPU out of memory after all fallbacks. GPU: {allocated:.2f} GB / {total_gb:.2f} GB ({100*allocated/total_gb:.1f}% full), {free_memory_gb:.2f} GB free. Please restart Space or use smaller input. Error: {e5}"
                                            print(f"  {error_msg}")
                                            raise RuntimeError(error_msg)
                            finally:
                                if fallback_applied:
                                    print(f"  Note: Applied fallback ({fallback_applied}). Restoring original settings after inference.")
                                    model_loader.WINDOW_INFER.sw_batch_size = original_batch
                                    model_loader.WINDOW_INFER.roi_size = tuple(original_roi)
                                    model_loader.WINDOW_INFER.overlap = original_overlap
                                    model_loader.WINDOW_INFER.device = original_device
                        else:
                            raise
                else:
                    print("  → AMP: Not available on CPU")
                    print("  → Starting sliding window inference (this may take 5-10 minutes on CPU)...")
                    start_time = time.time()
                    y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)
                    elapsed = time.time() - start_time
                    print(f"  → Inference completed in {elapsed:.1f}s")
        finally:
            timeout_timer.cancel()
        if y1 is not None:
            print(f"  → Raw model output shapes: y1={y1.shape if hasattr(y1, 'shape') else 'N/A'}, y2={y2.shape if hasattr(y2, 'shape') else 'N/A'}")
            sys.stdout.flush()
        else:
            print("  → Inference failed before model output was generated")
            sys.stdout.flush()
            raise RuntimeError("Inference failed: model did not produce output")
        pred = y1
        if isinstance(pred, torch.Tensor):
            raw_min = pred.min().item()
            raw_max = pred.max().item()
            raw_mean = pred.mean().item()
            print(f"  → Raw logits (before sigmoid): min={raw_min:.6f}, max={raw_max:.6f}, mean={raw_mean:.6f}")
            sys.stdout.flush()
        print("  → Applying sigmoid activation (confirming: using sigmoid, not softmax)")
        sys.stdout.flush()
        pred = torch.sigmoid(pred)
        if len(pred.shape) == 5:
            if pred.shape[1] > 1:
                print(f"  → WARNING: Multi-channel output detected (C={pred.shape[1]}). Using channel 0 for binary segmentation.")
            pred = pred[0, 0]
        elif len(pred.shape) == 4:
            pred = pred[0] if pred.shape[0] == 1 else pred
        print(f"  → Final prediction shape after sigmoid: {pred.shape}")
        sys.stdout.flush()
        inference_time = time.time() - inference_start
        print(f"✓ Inference completed in {inference_time:.1f} seconds")
        del y1, y2, y3, y4
        if model_loader.DEVICE.type == 'cuda' and (torch.cuda.memory_allocated(0) / (1024**3)) > 40:
            torch.cuda.empty_cache()
        print("Post-processing predictions...")
        sys.stdout.flush()
        pred = pred.detach().float().cpu()
        if pred.ndim == 5:
            pred_np = pred[0, 0].numpy()
        elif pred.ndim == 4 and pred.shape[0] == 1:
            pred_np = pred[0].numpy()
        else:
            pred_np = pred.numpy()
        pred_max = pred_np.max()
        pred_mean = pred_np.mean()
        if pred_max < 0.3:
            print(f"  → WARNING: Maximum prediction value is very low ({pred_max:.4f}). Model confidence may be low.")
            print("     This can occur with:")
            print("     - Low resolution/compressed input (lost texture cues)")
            print("     - Normalization mismatch (input not properly preprocessed)")
            print("     - Missing metadata (wrong voxel spacing/scaling)")
        if pred_mean < 0.1:
            print(f"  → WARNING: Average prediction is very low ({pred_mean:.4f}). Most voxels have low confidence.")
            print("     Consider checking input quality and preprocessing.")
        total_voxels = pred_np.size
        # Bind the default threshold up front: it is needed below even when the
        # grid search succeeds (previously it was only bound on the fallback path).
        if modality.upper() == "T1":
            candidates = np.linspace(0.60, 0.80, 10)
            frac_min, frac_max = 0.015, 0.040
            default_threshold = float(os.environ.get("T1_THRESHOLD", "0.65"))
        else:
            candidates = np.linspace(0.30, 0.70, 9)
            frac_min, frac_max = 0.008, 0.040
            default_threshold = float(os.environ.get("SEGMENTATION_THRESHOLD", "0.5"))
        best = None
        for t in candidates:
            mb = (pred_np > t).astype(np.uint8)
            if mb.sum() < 1000:
                continue
            frac = mb.mean()
            if not (frac_min <= frac <= frac_max):
                continue
            lbl, n = ndimage.label(mb)
            if n == 0:
                continue
            sizes = ndimage.sum(mb, lbl, range(1, n + 1))
            # Prefer one large component; penalize fragmentation.
            score = sizes.max() - (n - 1) * 1e5
            if best is None or score > best[0]:
                best = (score, t, mb)
        if best:
            _, threshold, pred_binary = best
            print(f"  Grid search selected threshold: {threshold:.3f}")
        else:
            threshold = default_threshold
            pred_binary = (pred_np > threshold).astype(np.uint8)
            print(f"  Grid search failed, using default threshold: {threshold:.3f}")
        count_default = pred_binary.sum()
        count_035 = (pred_np > 0.35).sum()
        count_03 = (pred_np > 0.3).sum()
        if count_default == 0 and default_threshold >= 0.5:
            if count_035 > 0:
                print(f"  → WARNING: No voxels > {default_threshold}, but {count_035:,} voxels > 0.35. Trying threshold 0.35...")
                threshold = 0.35
            elif count_03 > 0:
                print(f"  → WARNING: No voxels > {default_threshold} or 0.35, but {count_03:,} voxels > 0.3. Trying threshold 0.3...")
                threshold = 0.3
            else:
                percentile_thresh = float(np.quantile(pred_np.flatten(), 0.995))
                threshold = max(0.3, percentile_thresh)
                print(f"  → WARNING: No voxels > {default_threshold}. Using percentile threshold: {threshold:.4f}")
            pred_binary = (pred_np > threshold).astype(np.uint8)
        elif count_default == 0 and default_threshold < 0.5:
            if count_03 > 0:
                print(f"  → WARNING: No voxels > {default_threshold}, but {count_03:,} voxels > 0.3. Trying threshold 0.3...")
                threshold = 0.3
            else:
                percentile_thresh = float(np.quantile(pred_np.flatten(), 0.995))
                threshold = max(0.3, percentile_thresh)
                print(f"  → WARNING: No voxels > {default_threshold}. Using percentile threshold: {threshold:.4f}")
            pred_binary = (pred_np > threshold).astype(np.uint8)
        if modality.upper() == "T1":
            try:
                h, w, d = original_data.shape
                is_prenormalized = (original_data.max() <= 1.0 and original_data.min() >= 0.0)
                if not is_prenormalized:
                    right_upper_quadrant = original_data[h//4:3*h//4, w//2:, d//3:2*d//3]
                    if right_upper_quadrant.size > 0 and (right_upper_quadrant > 0).sum() > 0:
                        median_intensity = np.median(right_upper_quadrant[right_upper_quadrant > 0])
                        mad = np.median(np.abs(right_upper_quadrant[right_upper_quadrant > 0] - median_intensity))
                        k = 1.5
                        intensity_lower = median_intensity - k * mad
                        intensity_upper = median_intensity + k * mad
                        if intensity_upper - intensity_lower > 0.5:
                            if len(pred_np.shape) == 4:
                                pred_3d = pred_np[0, 0]
                            elif len(pred_np.shape) == 5:
                                pred_3d = pred_np[0, 0, 0]
                            else:
                                pred_3d = pred_np
                            if pred_3d.shape == original_data.shape:
                                intensity_mask = (original_data >= intensity_lower) & (original_data <= intensity_upper)
                                pred_3d_clamped = pred_3d.copy()
                                pred_3d_clamped[~intensity_mask] = np.minimum(pred_3d_clamped[~intensity_mask], threshold * 0.6)
                                if pred_3d_clamped.std() > 1e-6:
                                    if len(pred_np.shape) == 4:
                                        pred_np[0, 0] = pred_3d_clamped
                                    elif len(pred_np.shape) == 5:
                                        pred_np[0, 0, 0] = pred_3d_clamped
                                    else:
                                        pred_np = pred_3d_clamped
                                    print(f"  Intensity gate: Clamped predictions outside liver-like intensity range [{intensity_lower:.2f}, {intensity_upper:.2f}]")
                            else:
                                print(f"  Intensity gate: Shape mismatch (pred: {pred_3d.shape}, original: {original_data.shape}). Skipping intensity gate.")
                else:
                    left_upper_region = original_data[h//4:3*h//4, :w//2, d//3:2*d//3]
                    if left_upper_region.size > 0 and (left_upper_region > 0).sum() > 100:
                        stomach_like_intensity = np.percentile(left_upper_region[left_upper_region > 0], 75)
                        if len(pred_np.shape) == 4:
                            pred_3d = pred_np[0, 0]
                        elif len(pred_np.shape) == 5:
                            pred_3d = pred_np[0, 0, 0]
                        else:
                            pred_3d = pred_np
                        if pred_3d.shape == original_data.shape:
                            stomach_mask = (original_data > stomach_like_intensity * 0.9) & (original_data < stomach_like_intensity * 1.1)
                            pred_3d_clamped = pred_3d.copy()
                            pred_3d_clamped[stomach_mask] = np.minimum(pred_3d_clamped[stomach_mask], threshold * 0.5)
                            if pred_3d_clamped.std() > 1e-6:
                                if len(pred_np.shape) == 4:
                                    pred_np[0, 0] = pred_3d_clamped
                                elif len(pred_np.shape) == 5:
                                    pred_np[0, 0, 0] = pred_3d_clamped
                                else:
                                    pred_np = pred_3d_clamped
                                print(f"  Intensity gate: Suppressed stomach-like intensities (threshold: {stomach_like_intensity:.3f})")
                        else:
                            print(f"  Intensity gate: Shape mismatch (pred: {pred_3d.shape}, original: {original_data.shape}). Skipping intensity gate.")
            except Exception as e:
                print(f"  Intensity gate failed (non-critical): {e}")
            pred_binary = (pred_np > threshold).astype(np.uint8)
        count_after_thresh = pred_binary.sum()
        fraction = count_after_thresh / total_voxels if total_voxels > 0 else 0.0
        volume_ml = calculate_liver_volume(pred_binary, voxel_spacing)
        print(f"  After threshold ({threshold:.3f}): {count_after_thresh:,} voxels ({100*fraction:.2f}%), volume: {volume_ml:.1f} ml")
        if fraction > 0.040 or volume_ml > 2200:
            threshold_max = float(os.environ.get("THRESHOLD_MAX", "0.90"))
            new_threshold = min(threshold_max, threshold + 0.05)
            if new_threshold > threshold:
                print(f"  Size-aware auto-tune: Mask too large (fraction={100*fraction:.2f}%, volume={volume_ml:.1f}ml). Increasing threshold {threshold:.3f} -> {new_threshold:.3f}")
                threshold = new_threshold
                pred_binary = (pred_np > threshold).astype(np.uint8)
                count_after_thresh = pred_binary.sum()
                fraction = count_after_thresh / total_voxels if total_voxels > 0 else 0.0
                volume_ml = calculate_liver_volume(pred_binary, voxel_spacing)
                print(f"  After auto-tune threshold ({threshold:.3f}): {count_after_thresh:,} voxels ({100*fraction:.2f}%), volume: {volume_ml:.1f} ml")
        if fraction < 0.005 or fraction > 0.040:
            print(f"  WARNING: Mask fraction {100*fraction:.3f}% outside typical range (0.5%-4.0%). May indicate segmentation issues.")
        sys.stdout.flush()
        print("Refining liver segmentation mask...")
        sys.stdout.flush()
        pred_np_copy = pred_np.copy()
        try:
            pred_binary_refined, refinement_metrics, confidence_score = refine_liver_mask_enhanced(
                pred_binary, voxel_spacing, pred_np_copy, threshold, modality
            )
            pred_binary = pred_binary_refined
            guards_ok = refinement_metrics.get('guards_ok', True)
            print(f"  Refinement applied: {refinement_metrics['connected_components_before']} -> {refinement_metrics['connected_components_after']} components")
            print(f"  Volume change: {refinement_metrics['volume_change_percent']:.2f}% ({refinement_metrics['volume_change_ml']:.2f} ml)")
            print(f"  Confidence score: {confidence_score:.1f}% {'(guards OK)' if guards_ok else '(guards triggered)'}")
        except Exception as e:
            print(f"  Refinement failed (using original mask): {e}")
            import traceback
            traceback.print_exc()
            confidence_score = 50.0
            guards_ok = False
        del pred_np, pred_np_copy
| print("Loading original volume for visualization...") | |
| original_volume = original_data.copy() | |
| if original_volume.max() > original_volume.min(): | |
| original_volume = (original_volume - original_volume.min()) / (original_volume.max() - original_volume.min()) | |
| else: | |
| original_volume = np.zeros_like(original_volume) | |
| slice_idx = max(0, min(slice_idx if slice_idx is not None else (pred_binary.shape[-1] // 2), pred_binary.shape[-1] - 1)) | |
| if len(original_volume.shape) == 3: | |
| original_slice = original_volume[:, :, slice_idx] | |
| else: | |
| original_slice = original_volume[:, :, slice_idx] if original_volume.shape[2] > slice_idx else original_volume[:, :, 0] | |
| if len(pred_binary.shape) == 4: | |
| pred_slice = pred_binary[0, :, :, slice_idx] if pred_binary.shape[3] > slice_idx else pred_binary[0, :, :, 0] | |
| elif len(pred_binary.shape) == 5: | |
| pred_slice = pred_binary[0, 0, :, :, slice_idx] if pred_binary.shape[4] > slice_idx else pred_binary[0, 0, :, :, 0] | |
| else: | |
| pred_slice = pred_binary[:, :, slice_idx] if pred_binary.shape[2] > slice_idx else pred_binary[:, :, 0] | |
| if pred_slice.shape != original_slice.shape: | |
| print(f" Warning: Shape mismatch in overlay (pred_slice: {pred_slice.shape}, original_slice: {original_slice.shape}). Resizing pred_slice to match.") | |
| from scipy.ndimage import zoom | |
| zoom_factors = (original_slice.shape[0] / pred_slice.shape[0], original_slice.shape[1] / pred_slice.shape[1]) | |
| pred_slice = zoom(pred_slice, zoom_factors, order=0) | |
| overlay = np.zeros((*original_slice.shape, 3), dtype=np.uint8) | |
| overlay[:, :, 0] = (original_slice * 255).astype(np.uint8) | |
| overlay[:, :, 1] = (original_slice * 255).astype(np.uint8) | |
| overlay[:, :, 2] = (original_slice * 255).astype(np.uint8) | |
| mask = pred_slice > 0 | |
| overlay[mask, 0] = 0 | |
| overlay[mask, 1] = 255 | |
| overlay[mask, 2] = 0 | |
| overlay_img = Image.fromarray(overlay) | |
| print("Saving segmentation mask...") | |
| aff = nifti_img.affine | |
| hdr = nifti_img.header.copy() | |
| hdr.set_data_dtype(np.uint8) | |
| pred_save = pred_binary[0] if pred_binary.ndim == 4 else pred_binary | |
| if pred_binary.ndim == 5: | |
| pred_save = pred_binary[0, 0] | |
| original_shape = original_data.shape | |
| pred_shape = pred_save.shape | |
| if pred_shape != original_shape: | |
| print(f" Resampling prediction from {pred_shape} to original shape {original_shape}...") | |
| from scipy.ndimage import zoom | |
| zoom_factors = tuple(orig / pred for orig, pred in zip(original_shape, pred_shape)) | |
| pred_save = zoom(pred_save.astype(np.float32), zoom_factors, order=0).astype(np.uint8) | |
| print(f" β Resampled to match original dimensions") | |
| with tempfile.NamedTemporaryFile(suffix='.nii.gz', delete=False) as tmp_file: | |
| nifti_pred = nib.Nifti1Image(pred_save.astype(np.uint8), affine=aff, header=hdr) | |
| nib.save(nifti_pred, tmp_file.name) | |
| output_path = tmp_file.name | |
        total_voxels = pred_binary.size
        liver_voxels = pred_binary.sum()
        liver_percentage = (liver_voxels / total_voxels) * 100
        print(f"✓ Segmentation complete. Liver: {liver_voxels:,} voxels ({liver_percentage:.2f}%)")
        volume_ml = calculate_liver_volume(pred_binary, voxel_spacing)
        if len(pred_binary.shape) == 4:
            mask_3d = pred_binary[0]
        elif len(pred_binary.shape) == 5:
            mask_3d = pred_binary[0, 0]
        else:
            mask_3d = pred_binary
        labels_final, num_components_final = ndimage.label(mask_3d)
        morphology = analyze_liver_morphology(pred_binary)
        from processing import check_volume_sanity
        volume_sanity_status, volume_sanity_msg = check_volume_sanity(volume_ml)
        if volume_sanity_status != "OK":
            print(f"  → {volume_sanity_status}: {volume_sanity_msg}")
        if num_components_final != 1:
            print(f"  → WARNING: Expected 1 connected component, found {num_components_final}. Quality check recommended.")
        medical_report = generate_medical_report(
            {
                "volume_shape": list(pred_binary.shape),
                "liver_voxels": int(liver_voxels),
                "total_voxels": int(total_voxels),
                "liver_percentage": float(liver_percentage),
                "slice_index": int(slice_idx),
                "total_slices": int(pred_binary.shape[-1]),
                "modality": modality,
                "confidence_score": float(confidence_score),
            },
            volume_ml,
            morphology,
            modality,
            confidence_score,
        )
| info_text = f""" | |
| ## Segmentation Results | |
| - **Volume Shape:** {pred_binary.shape} | |
| - **Liver Voxels:** {liver_voxels:,} ({liver_percentage:.2f}% of volume) | |
| - **Liver Volume:** {volume_ml:.2f} ml | |
| - **Slice:** {slice_idx + 1} of {pred_binary.shape[-1]} shown | |
| - **Modality:** {modality} | |
| """ | |
| severity_text = medical_report['severity'].upper().replace('_', ' ') | |
| confidence = medical_report['measurements'].get('confidence_score', 0.0) | |
| report_text = f""" | |
| ## Automated Liver Segmentation Report | |
| --- | |
| ### Study Information | |
| - **Study Date & Time:** {medical_report['study_date']} | |
| - **Imaging Modality:** {medical_report['modality']} | |
| - **Report Status:** **{severity_text}** | |
| - **Confidence Score:** **{confidence:.1f}%** | |
| --- | |
| ### Key Findings | |
| {chr(10).join(f"{finding}" for finding in medical_report['findings'])} | |
| --- | |
| ### Quantitative Measurements | |
| **Volume Analysis:** | |
| - **Liver Volume:** **{medical_report['measurements']['liver_volume_ml']} ml** ({medical_report['measurements']['liver_volume_liters']} L) | |
| - **Liver Percentage of Scan:** {medical_report['measurements']['liver_percentage']}% | |
| - **Segmented Voxels:** {medical_report['measurements']['liver_voxels']:,} / {medical_report['measurements']['total_voxels']:,} total voxels | |
| **Morphological Analysis:** | |
| - **Connected Components:** {medical_report['measurements']['morphology']['connected_components']} | |
| - **Fragmentation Level:** {medical_report['measurements']['morphology']['fragmentation'].upper()} | |
| - **Largest Component Ratio:** {medical_report['measurements']['morphology']['largest_component_ratio']*100:.1f}% | |
| **Image Characteristics:** | |
| - **Volume Dimensions:** {medical_report['measurements']['volume_shape']} | |
| **Confidence Metrics:** | |
| - **Overall Confidence:** {confidence:.1f}% | |
| - **Confidence Interpretation:** {'High' if confidence >= 80 else 'Moderate' if confidence >= 60 else 'Low'} confidence segmentation | |
| --- | |
| ### Quality Assessment | |
| {chr(10).join(f"- {note}" for note in medical_report.get('quality_assessment', []))} | |
| --- | |
| ### Clinical Context | |
| {chr(10).join(f"- {note}" for note in medical_report.get('clinical_notes', [])) if medical_report.get('clinical_notes') else "- No additional clinical notes at this time."} | |
| --- | |
| ### Impression | |
| {medical_report['impression']} | |
| --- | |
| {f"### Recommendations{chr(10)}{chr(10)}{chr(10).join(f'- {rec}' for rec in medical_report['recommendations'])}{chr(10)}" if medical_report['recommendations'] else ""} | |
| ### Methodology | |
| {medical_report.get('methodology', 'SRMA-Mamba deep learning model for automated liver segmentation')} | |
| --- | |
| ### Important Notice | |
| {medical_report['disclaimer']} | |
| --- | |
| *Report generated automatically by SRMA-Mamba Liver Segmentation System* | |
| """ | |
        del original_volume, pred_binary
        if model_loader.DEVICE.type == 'cuda':
            torch.cuda.empty_cache()
        return overlay_img, info_text, report_text, output_path
    except FileNotFoundError as e:
        error_msg = f"**Checkpoint Error:** {str(e)}\n\nPlease ensure checkpoint files are uploaded to the Space."
        print(f"❌ {error_msg}")
        import traceback
        traceback.print_exc()
        return None, f"## ❌ Error\n\n{error_msg}", "", None
    except RuntimeError as e:
        error_msg = f"**Runtime Error:** {str(e)}\n\nThis might be a GPU/CUDA issue. The model will try to use CPU instead."
        print(f"❌ {error_msg}")
        import traceback
        traceback.print_exc()
        return None, f"## ❌ Error\n\n{error_msg}", "", None
    except Exception as e:
        error_msg = f"**Error during inference:** {str(e)}\n\nPlease check the logs for more details."
        print(f"❌ {error_msg}")
        import traceback
        tb = traceback.format_exc()
        print(f"Full traceback:\n{tb}")
        return None, f"## ❌ Error\n\n{error_msg}\n\n**Error Type:** `{type(e).__name__}`", "", None
    finally:
        PROCESSING_LOCK.release()
        # image_tensor and pred are bound to None before the try block, so no
        # NameError guard is needed here.
        if image_tensor is not None:
            del image_tensor
        if pred is not None:
            del pred
        if model_loader.DEVICE.type == 'cuda':
            torch.cuda.empty_cache()
        print("✓ Memory cleaned up, lock released")
def predict_volume_api(file_path: str, modality: str = 'T1', slice_idx: Optional[int] = None):
    confidence_score = 50.0
    try:
        print("=" * 60)
        print(f"API: Starting prediction for modality: {modality}, file: {file_path}")
        print("=" * 60)
        if modality == 'T1':
            if model_loader.MODEL_T1 is None:
                print("API: Loading T1 model...")
                model_loader.MODEL_T1 = model_loader.load_model('T1')
            model_instance = model_loader.MODEL_T1
        else:
            if model_loader.MODEL_T2 is None:
                print("API: Loading T2 model...")
                model_loader.MODEL_T2 = model_loader.load_model('T2')
            model_instance = model_loader.MODEL_T2
        print("API: Preprocessing NIfTI file...")
        from pathlib import Path
        file_path = str(Path(file_path).resolve())
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")
        nifti_img = nib.load(file_path)
        nifti_img = nib.as_closest_canonical(nifti_img)
        voxel_spacing = nifti_img.header.get_zooms()[:3] if len(nifti_img.header.get_zooms()) >= 3 else (1.0, 1.0, 1.0)
        aff = nifti_img.affine
        hdr = nifti_img.header.copy()
        original_data = nifti_img.get_fdata(dtype=np.float32)
        if model_loader.DEVICE.type == 'cuda':
            if not torch.cuda.is_available():
                raise RuntimeError("CUDA not available but device is set to CUDA")
            torch.cuda.empty_cache()
            free_mem = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / (1024**3)
            if free_mem < 0.5:
                raise RuntimeError(f"Insufficient GPU memory: {free_mem:.2f} GB free. Please restart.")
        image_data = preprocess_nifti(file_path, device=model_loader.DEVICE)
        if len(image_data.shape) == 3:
            image_tensor = image_data.unsqueeze(0).unsqueeze(0)
        elif len(image_data.shape) == 4:
            image_tensor = image_data.unsqueeze(0)
        else:
            image_tensor = image_data
        print(f"API: Input tensor shape: {image_tensor.shape}")
        if model_loader.WINDOW_INFER is not None:
            print("API: Adjusting ROI size based on volume dimensions...")
            adjust_roi_for_volume(image_tensor.shape)
        print("API: Running inference...")
        if model_loader.DEVICE.type == 'cuda':
            allocated = torch.cuda.memory_allocated(0) / (1024**3)
            free_memory_gb = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / (1024**3)
            print(f"API: Pre-inference GPU memory: {allocated:.2f} GB allocated, {free_memory_gb:.2f} GB free")
            if free_memory_gb < 0.5:
                print(f"API: → WARNING: Very low free memory ({free_memory_gb:.2f} GB). Adjusting settings...")
                if model_loader.WINDOW_INFER is not None:
                    model_loader.WINDOW_INFER.sw_batch_size = 1
                    model_loader.WINDOW_INFER.roi_size = [192, 192, 32]
                    model_loader.WINDOW_INFER.overlap = 0.25
                torch.cuda.empty_cache()
                free_memory_gb = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / (1024**3)
                print(f"API: → After cache clear: {free_memory_gb:.2f} GB free")
                if free_memory_gb < 0.3:
                    error_msg = f"Insufficient GPU memory ({free_memory_gb:.2f} GB free). Please use a smaller input volume."
                    print(f"API: ❌ {error_msg}")
                    raise RuntimeError(error_msg)
        with torch.no_grad():
            if model_loader.DEVICE.type == 'cuda':
                print("API: AMP (Automatic Mixed Precision) enabled: YES")
                sys.stdout.flush()
                try:
                    with autocast(device_type='cuda'):
                        y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)
                except RuntimeError as e:
                    if "out of memory" in str(e):
                        print(f"API: ❌ GPU OOM error: {e}")
                        print("API: → Clearing cache and retrying with minimal settings...")
                        torch.cuda.empty_cache()
                        torch.cuda.synchronize()
                        original_batch = model_loader.WINDOW_INFER.sw_batch_size
                        original_roi = model_loader.WINDOW_INFER.roi_size
                        original_overlap = model_loader.WINDOW_INFER.overlap
                        model_loader.WINDOW_INFER.sw_batch_size = 1
                        model_loader.WINDOW_INFER.roi_size = [192, 192, 32]
                        model_loader.WINDOW_INFER.overlap = 0.25
                        try:
                            with autocast(device_type='cuda'):
                                y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)
                            print("API: → Retry successful with minimal settings")
                        except RuntimeError as e2:
                            allocated = torch.cuda.memory_allocated(0) / (1024**3)
                            total_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
                            free_memory_gb = total_gb - allocated
                            error_msg = f"GPU out of memory even with minimal settings. GPU is {allocated:.2f} GB / {total_gb:.2f} GB ({100*allocated/total_gb:.1f}% full), only {free_memory_gb:.2f} GB free. The model requires CUDA. Please restart the Space or use a smaller input volume. Error: {e2}"
                            print(f"API: ❌ {error_msg}")
                            raise RuntimeError(error_msg)
                        finally:
                            # Restore the caller-visible inferer settings whether or
                            # not the retry succeeded.
                            model_loader.WINDOW_INFER.sw_batch_size = original_batch
                            model_loader.WINDOW_INFER.roi_size = original_roi
                            model_loader.WINDOW_INFER.overlap = original_overlap
                    else:
                        raise
            else:
                print("API: AMP: Not available on CPU")
                y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)
        print(f"API: Raw model output shapes: y1={y1.shape if hasattr(y1, 'shape') else 'N/A'}")
        sys.stdout.flush()
        pred = y1
        if isinstance(pred, torch.Tensor):
            raw_min = pred.min().item()
            raw_max = pred.max().item()
            raw_mean = pred.mean().item()
            print(f"API: Raw logits (before sigmoid): min={raw_min:.6f}, max={raw_max:.6f}, mean={raw_mean:.6f}")
            sys.stdout.flush()
        print("API: Applying sigmoid activation (confirming: using sigmoid, not softmax)")
        sys.stdout.flush()
        pred = torch.sigmoid(pred)
        if len(pred.shape) == 5:
            if pred.shape[1] > 1:
                print(f"API: → WARNING: Multi-channel output detected (C={pred.shape[1]}). Using channel 0 for binary segmentation.")
            pred = pred[0, 0]
        elif len(pred.shape) == 4:
            pred = pred[0] if pred.shape[0] == 1 else pred
        print(f"API: Final prediction shape after sigmoid: {pred.shape}")
        sys.stdout.flush()
        print("API: Post-processing predictions...")
        sys.stdout.flush()
        pred_np = pred.detach().cpu().numpy()
        print("=" * 60)
        print("API: DIAGNOSTIC: SIGMOID OUTPUT STATS (before thresholding)")
        print("=" * 60)
        print(f"  Max: {pred_np.max():.6f}")
        print(f"  Mean: {pred_np.mean():.6f}")
        print(f"  Min: {pred_np.min():.6f}")
        print(f"  Std: {pred_np.std():.6f}")
        print(f"  Shape: {pred_np.shape}")
        sys.stdout.flush()
        default_threshold = float(os.environ.get("SEGMENTATION_THRESHOLD", "0.5"))
        count_035 = np.sum(pred_np > 0.35)
        count_03 = np.sum(pred_np > 0.3)
        count_05 = np.sum(pred_np > 0.5)
        total_voxels = pred_np.size
        threshold = default_threshold
        if count_05 == 0 and default_threshold >= 0.5:
            if count_035 > 0:
                print(f"API: → WARNING: No voxels > 0.5, but {count_035:,} voxels > 0.35. Trying threshold 0.35...")
                threshold = 0.35
            elif count_03 > 0:
                print(f"API: → WARNING: No voxels > 0.5 or 0.35, but {count_03:,} voxels > 0.3. Trying threshold 0.3...")
                threshold = 0.3
            else:
                percentile_thresh = float(np.quantile(pred_np.flatten(), 0.995))
                threshold = max(0.3, percentile_thresh)
                print(f"API: → WARNING: No voxels > 0.5. Using percentile threshold: {threshold:.4f}")
        elif count_035 == 0 and default_threshold < 0.5:
            if count_03 > 0:
                print(f"API: → WARNING: No voxels > {default_threshold}, but {count_03:,} voxels > 0.3. Trying threshold 0.3...")
                threshold = 0.3
            elif count_05 == 0:
                percentile_thresh = float(np.quantile(pred_np.flatten(), 0.995))
                threshold = max(0.3, percentile_thresh)
                print(f"API: → WARNING: No voxels > {default_threshold}. Using percentile threshold: {threshold:.4f}")
        pred_binary = (pred_np > threshold).astype(np.uint8)
        fraction = pred_binary.sum() / total_voxels if total_voxels > 0 else 0.0
        if fraction < 0.005 or fraction > 0.040:
            print(f"API: → WARNING: Mask fraction {100*fraction:.3f}% outside typical range (0.5%-4.0%). May indicate segmentation issues.")
        count_after_thresh = pred_binary.sum()
        print(f"API: After threshold ({threshold}): {count_after_thresh:,} voxels ({100*count_after_thresh/total_voxels:.2f}%)")
        if count_05 > 0 and count_after_thresh == 0 and threshold == 0.5:
            print(f"API: → WARNING: Had {count_05:,} voxels > 0.5 but 0 after thresholding - check threshold logic!")
        sys.stdout.flush()
| print("API: Refining liver segmentation mask...") | |
| try: | |
| pred_np_copy = pred_np.copy() | |
| pred_binary_refined, refinement_metrics, confidence_score = refine_liver_mask_enhanced( | |
| pred_binary, voxel_spacing, pred_np_copy, threshold, modality | |
| ) | |
| pred_binary = pred_binary_refined | |
| print(f"API: β Refinement applied: {refinement_metrics['connected_components_before']} β {refinement_metrics['connected_components_after']} components") | |
| print(f"API: β Volume change: {refinement_metrics['volume_change_percent']:.2f}% ({refinement_metrics['volume_change_ml']:.2f} ml)") | |
| print(f"API: β Confidence score: {confidence_score:.1f}%") | |
| except Exception as e: | |
| print(f"API: β Refinement failed (using original mask): {e}") | |
| confidence_score = 50.0 | |
| print("API: Loading original volume for visualization...") | |
| original_volume = original_data.copy() | |
| if original_volume.max() > original_volume.min(): | |
| original_volume = (original_volume - original_volume.min()) / (original_volume.max() - original_volume.min()) | |
| else: | |
| original_volume = np.zeros_like(original_volume) | |
| slice_idx = max(0, min(slice_idx if slice_idx is not None else (pred_binary.shape[-1] // 2), pred_binary.shape[-1] - 1)) | |
| if len(original_volume.shape) == 3: | |
| original_slice = original_volume[:, :, slice_idx] | |
| else: | |
| original_slice = original_volume[:, :, slice_idx] if original_volume.shape[2] > slice_idx else original_volume[:, :, 0] | |
| if len(pred_binary.shape) == 4: | |
| pred_slice = pred_binary[0, :, :, slice_idx] if pred_binary.shape[3] > slice_idx else pred_binary[0, :, :, 0] | |
| elif len(pred_binary.shape) == 5: | |
| pred_slice = pred_binary[0, 0, :, :, slice_idx] if pred_binary.shape[4] > slice_idx else pred_binary[0, 0, :, :, 0] | |
| else: | |
| pred_slice = pred_binary[:, :, slice_idx] if pred_binary.shape[2] > slice_idx else pred_binary[:, :, 0] | |
| if pred_slice.shape != original_slice.shape: | |
| print(f"API: Warning: Shape mismatch in overlay (pred_slice: {pred_slice.shape}, original_slice: {original_slice.shape}). Resizing pred_slice to match.") | |
| from scipy.ndimage import zoom | |
| zoom_factors = (original_slice.shape[0] / pred_slice.shape[0], original_slice.shape[1] / pred_slice.shape[1]) | |
| pred_slice = zoom(pred_slice, zoom_factors, order=0) | |
| overlay = np.zeros((*original_slice.shape, 3), dtype=np.uint8) | |
| overlay[:, :, 0] = (original_slice * 255).astype(np.uint8) | |
| overlay[:, :, 1] = (original_slice * 255).astype(np.uint8) | |
| overlay[:, :, 2] = (original_slice * 255).astype(np.uint8) | |
| mask = pred_slice > 0 | |
| overlay[mask, 0] = 0 | |
| overlay[mask, 1] = 255 | |
| overlay[mask, 2] = 0 | |
| overlay_img = Image.fromarray(overlay) | |
| img_buffer = io.BytesIO() | |
| overlay_img.save(img_buffer, format='PNG') | |
| img_base64 = base64.b64encode(img_buffer.getvalue()).decode('utf-8') | |
| print("API: Saving segmentation mask...") | |
| aff = nifti_img.affine | |
| hdr = nifti_img.header.copy() | |
| hdr.set_data_dtype(np.uint8) | |
| pred_save = pred_binary[0] if pred_binary.ndim == 4 else pred_binary | |
| if pred_binary.ndim == 5: | |
| pred_save = pred_binary[0, 0] | |
| original_shape = original_data.shape | |
| pred_shape = pred_save.shape | |
| if pred_shape != original_shape: | |
| print(f"API: Resampling prediction from {pred_shape} to original shape {original_shape}...") | |
| from scipy.ndimage import zoom | |
| zoom_factors = tuple(orig / pred for orig, pred in zip(original_shape, pred_shape)) | |
| pred_save = zoom(pred_save.astype(np.float32), zoom_factors, order=0).astype(np.uint8) | |
| print(f"API: β Resampled to match original dimensions") | |
| with tempfile.NamedTemporaryFile(suffix='.nii.gz', delete=False) as tmp_file: | |
| nifti_pred = nib.Nifti1Image(pred_save.astype(np.uint8), affine=aff, header=hdr) | |
| nib.save(nifti_pred, tmp_file.name) | |
| output_path = tmp_file.name | |
        total_voxels = pred_binary.size
        liver_voxels = pred_binary.sum()
        liver_percentage = (liver_voxels / total_voxels) * 100
        print("API: Calculating volume and morphology...")
        volume_ml = calculate_liver_volume(pred_binary, voxel_spacing)
        if len(pred_binary.shape) == 4:
            mask_3d = pred_binary[0]
        elif len(pred_binary.shape) == 5:
            mask_3d = pred_binary[0, 0]
        else:
            mask_3d = pred_binary
        labels_final, num_components_final = ndimage.label(mask_3d)
        morphology = analyze_liver_morphology(pred_binary)
        from processing import check_volume_sanity
        volume_sanity_status, volume_sanity_msg = check_volume_sanity(volume_ml)
        if volume_sanity_status != "OK":
            print(f"API: → {volume_sanity_status}: {volume_sanity_msg}")
        if num_components_final != 1:
            print(f"API: → WARNING: Expected 1 connected component, found {num_components_final}. Quality check recommended.")
        medical_report = generate_medical_report(
            {
                "volume_shape": list(pred_binary.shape),
                "liver_voxels": int(liver_voxels),
                "total_voxels": int(total_voxels),
                "liver_percentage": float(liver_percentage),
                "slice_index": int(slice_idx),
                "total_slices": int(pred_binary.shape[-1]),
                "modality": modality,
                "confidence_score": float(confidence_score),
            },
            volume_ml,
            morphology,
            modality,
            confidence_score,
        )
        print(f"API: ✓ Segmentation complete. Liver: {liver_voxels:,} voxels ({liver_percentage:.2f}%), Volume: {volume_ml:.2f} ml")
        return {
            "success": True,
            "overlay_image": f"data:image/png;base64,{img_base64}",
            "segmentation_path": output_path,
            "statistics": {
                "volume_shape": list(pred_binary.shape),
                "liver_voxels": int(liver_voxels),
                "total_voxels": int(total_voxels),
                "liver_percentage": float(liver_percentage),
                "slice_index": int(slice_idx),
                "total_slices": int(pred_binary.shape[-1]),
                "modality": modality,
                "liver_volume_ml": round(volume_ml, 2),
            },
            "medical_report": medical_report,
        }
    except FileNotFoundError as e:
        error_msg = f"Checkpoint file not found: {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return {
            "success": False,
            "error": error_msg,
            "error_type": "checkpoint_not_found",
        }
    except RuntimeError as e:
        error_msg = f"Runtime error (possibly CUDA/GPU): {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return {
            "success": False,
            "error": error_msg,
            "error_type": "runtime_error",
        }
    except Exception as e:
        error_msg = f"Error during inference: {str(e)}"
        print(error_msg)
        import traceback
        tb = traceback.format_exc()
        print(f"Full traceback: {tb}")
        return {
            "success": False,
            "error": error_msg,
            "error_type": type(e).__name__,
        }
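# Illustrative helper (hypothetical; nothing in this module calls it): decode
# the "overlay_image" data URL returned by predict_volume_api back into a PIL
# image, assuming the "data:image/png;base64," prefix produced above.
def _decode_overlay_data_url(data_url: str) -> Image.Image:
    _, sep, payload = data_url.partition("base64,")
    if not sep:
        raise ValueError("Not a base64 data URL")
    return Image.open(io.BytesIO(base64.b64decode(payload)))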
def safe_predict_volume(nifti_file, modality, slice_idx=None):
    try:
        result = predict_volume(nifti_file, modality, slice_idx)
        if result is None or len(result) != 4:
            return None, "## ❌ Error\n\nUnexpected return value from prediction function.", "", None
        overlay_img, info_text, report_text, output_path = result
        if overlay_img is None and info_text and "Error" in info_text:
            return None, info_text, report_text or "", output_path or None
        return overlay_img, info_text, report_text, output_path
    except Exception as e:
        error_msg = f"**Unexpected error:** {str(e)}\n\nPlease check the logs for details."
        print(f"❌ Safe wrapper caught error: {error_msg}")
        import traceback
        traceback.print_exc()
        return None, f"## ❌ Error\n\n{error_msg}", "", None