# Commit a1bd718 (Harshith Reddy): Add detailed logging for API requests
# to show in Hugging Face logs.
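"""Inference entry points for the SRMA-Mamba liver segmentation Space.

Each request: load the modality-specific model (T1/T2), canonicalize and
preprocess the NIfTI volume, adapt the sliding-window ROI to the volume,
run AMP inference with progressive OOM fallbacks, threshold and refine the
sigmoid output, and return an overlay image, a saved NIfTI mask, and an
auto-generated medical report.
"""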
import os
import sys
import time
import tempfile
import base64
import io
import warnings
import threading
import numpy as np
import torch
import nibabel as nib
from PIL import Image
from typing import Optional
from torch.amp import autocast
from scipy import ndimage
import model_loader
from processing import preprocess_nifti, calculate_liver_volume, analyze_liver_morphology, generate_medical_report, refine_liver_mask_enhanced
warnings.filterwarnings('ignore', category=UserWarning, message='.*non-tuple sequence for multidimensional indexing.*')
warnings.filterwarnings('ignore', category=FutureWarning, message='.*non-tuple sequence.*')
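# Single global lock so only one inference runs at a time; predict_volume
# acquires it non-blockingly and returns a "busy" message if a run is active.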
PROCESSING_LOCK = threading.Lock()
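# Shrink the sliding-window ROI and overlap to fit the incoming volume: cap
# each ROI dimension at the volume extent, lower overlap for very large
# volumes, and trim ROI depth for thin stacks. Mutates
# model_loader.WINDOW_INFER in place.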
def adjust_roi_for_volume(volume_shape):
if model_loader.WINDOW_INFER is None:
return
if len(volume_shape) < 4:
return
if len(volume_shape) == 5:
_, _, depth, height, width = volume_shape
elif len(volume_shape) == 4:
_, depth, height, width = volume_shape
else:
depth, height, width = volume_shape[-3:]
current_roi = list(model_loader.WINDOW_INFER.roi_size)
original_roi = current_roi.copy()
roi_d, roi_h, roi_w = current_roi
adjusted = False
if roi_d > depth:
current_roi[0] = min(depth, 64)
adjusted = True
print(f" ⚠ ROI depth ({roi_d}) > volume depth ({depth}). Adjusting to {current_roi[0]}")
if roi_h > height:
current_roi[1] = min(height, current_roi[1])
adjusted = True
print(f" ⚠ ROI height ({roi_h}) > volume height ({height}). Adjusting to {current_roi[1]}")
if roi_w > width:
current_roi[2] = min(width, current_roi[2])
adjusted = True
print(f" ⚠ ROI width ({roi_w}) > volume width ({width}). Adjusting to {current_roi[2]}")
total_volume = depth * height * width
roi_volume = current_roi[0] * current_roi[1] * current_roi[2]
if total_volume > 20_000_000:
overlap = 0.1
if model_loader.WINDOW_INFER.overlap > 0.1:
model_loader.WINDOW_INFER.overlap = overlap
adjusted = True
print(f" β†’ Large volume detected ({total_volume:,} voxels). Reducing overlap to {overlap} for faster processing")
elif total_volume > 10_000_000:
overlap = 0.12
if model_loader.WINDOW_INFER.overlap > 0.12:
model_loader.WINDOW_INFER.overlap = overlap
adjusted = True
print(f" β†’ Medium-large volume detected ({total_volume:,} voxels). Reducing overlap to {overlap}")
if depth < 64 and current_roi[0] > depth * 0.8:
current_roi[0] = max(32, int(depth * 0.8))
adjusted = True
print(f" β†’ Small depth dimension ({depth}). Optimizing ROI depth to {current_roi[0]}")
if adjusted:
model_loader.WINDOW_INFER.roi_size = tuple(current_roi)
num_windows_approx = ((depth / current_roi[0]) * (height / current_roi[1]) * (width / current_roi[2])) / (1 - model_loader.WINDOW_INFER.overlap)
print(f" βœ“ ROI adjusted: {original_roi} β†’ {current_roi}")
print(f" β†’ Estimated windows: ~{int(num_windows_approx)} (overlap={model_loader.WINDOW_INFER.overlap:.2f})")
def predict_volume(nifti_file, modality, slice_idx=None):
if not PROCESSING_LOCK.acquire(blocking=False):
return None, "**Already processing a file. Please wait for the current inference to complete.**", "", None
if nifti_file is None:
PROCESSING_LOCK.release()
return None, "Please upload a NIfTI file (.nii.gz)", "", None
image_tensor = None
pred = None
output_path = None
try:
print(f"Starting prediction for modality: {modality}")
if modality == 'T1':
if model_loader.MODEL_T1 is None:
print("Loading T1 model_loader...")
model_loader.MODEL_T1 = model_loader.load_model('T1')
model_instance = model_loader.MODEL_T1
else:
if model_loader.MODEL_T2 is None:
print("Loading T2 model_loader...")
model_loader.MODEL_T2 = model_loader.load_model('T2')
model_instance = model_loader.MODEL_T2
print("Preprocessing NIfTI file...")
from pathlib import Path
fp = getattr(nifti_file, "name", nifti_file)
file_path = str(Path(fp).resolve())
if not os.path.exists(file_path):
raise FileNotFoundError(f"Uploaded file not found: {file_path}")
nifti_img = nib.load(file_path)
nifti_img = nib.as_closest_canonical(nifti_img)
voxel_spacing = nifti_img.header.get_zooms()[:3] if len(nifti_img.header.get_zooms()) >= 3 else (1.0, 1.0, 1.0)
aff = nifti_img.affine
hdr = nifti_img.header.copy()
original_data = nifti_img.get_fdata(dtype=np.float32)
if model_loader.DEVICE.type == 'cuda':
if not torch.cuda.is_available():
raise RuntimeError("CUDA not available but device is set to CUDA")
torch.cuda.empty_cache()
free_mem = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / (1024**3)
if free_mem < 0.5:
raise RuntimeError(f"Insufficient GPU memory: {free_mem:.2f} GB free. Please restart.")
image_data = preprocess_nifti(file_path, device=model_loader.DEVICE)
if len(image_data.shape) == 3:
image_tensor = image_data.unsqueeze(0).unsqueeze(0)
elif len(image_data.shape) == 4:
image_tensor = image_data.unsqueeze(0)
else:
image_tensor = image_data
if model_loader.DEVICE.type == 'cuda' and len(image_tensor.shape) >= 4:
try:
image_tensor = image_tensor.contiguous(memory_format=torch.channels_last_3d)
if image_tensor.is_contiguous(memory_format=torch.channels_last_3d):
print(f" β†’ Using channels-last 3D memory layout (optimized for GPU)")
except Exception:
pass
print(f"Input tensor shape: {image_tensor.shape}")
if model_loader.WINDOW_INFER is not None:
print("Adjusting ROI size based on volume dimensions...")
adjust_roi_for_volume(image_tensor.shape)
print("Running inference...")
inference_start = time.time()
y1, y2, y3, y4 = None, None, None, None
gpu_handle = None
if model_loader.DEVICE.type == 'cuda':
allocated = torch.cuda.memory_allocated(0) / (1024**3)
free_memory_gb = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / (1024**3)
print(f" β†’ Pre-inference GPU memory: {allocated:.2f} GB allocated, {free_memory_gb:.2f} GB free")
try:
import pynvml
pynvml.nvmlInit()
gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(0)
util = pynvml.nvmlDeviceGetUtilizationRates(gpu_handle)
print(f" β†’ GPU utilization: {util.gpu}%")
except Exception:
pass
if free_memory_gb < 1.0:
print(f" ⚠ WARNING: Very low free memory ({free_memory_gb:.2f} GB). Adjusting settings...")
if model_loader.WINDOW_INFER is not None:
model_loader.WINDOW_INFER.sw_batch_size = 1
if free_memory_gb < 0.5:
model_loader.WINDOW_INFER.roi_size = [192, 192, 32]
model_loader.WINDOW_INFER.overlap = 0.25
if free_memory_gb < 0.5:
torch.cuda.empty_cache()
free_memory_gb = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / (1024**3)
print(f" β†’ After cache clear: {free_memory_gb:.2f} GB free")
if free_memory_gb < 0.5:
error_msg = f"Insufficient GPU memory ({free_memory_gb:.2f} GB free). The model is using {allocated:.2f} GB. Please use a smaller input volume."
print(f" βœ— {error_msg}")
raise RuntimeError(error_msg)
import config
timeout_occurred = threading.Event()
def timeout_handler():
timeout_occurred.set()
timeout_timer = threading.Timer(config.INFERENCE_TIMEOUT, timeout_handler)
timeout_timer.start()
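# Note: the timer only sets an event; the flag is checked before and after
# the sliding-window call, so a timeout cannot interrupt a window mid-run.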
try:
with torch.no_grad():
if model_loader.DEVICE.type == 'cuda':
torch.cuda.synchronize()
print(" β†’ AMP enabled: YES")
sys.stdout.flush()
try:
with autocast(device_type='cuda'):
print(f" β†’ Starting inference (timeout: {config.INFERENCE_TIMEOUT}s)...")
sys.stdout.flush()
start_time = time.time()
if timeout_occurred.is_set():
raise TimeoutError(f"Inference timeout: {config.INFERENCE_TIMEOUT}s")
gpu_utils = []
if gpu_handle is not None:
monitor_active = threading.Event()
monitor_active.set()
def monitor_gpu():
# Sample GPU utilization roughly once per second while inference runs.
import pynvml
while monitor_active.is_set():
try:
util = pynvml.nvmlDeviceGetUtilizationRates(gpu_handle)
gpu_utils.append(util.gpu)
time.sleep(1.0)
except Exception:
break
monitor_thread = threading.Thread(target=monitor_gpu, daemon=True)
monitor_thread.start()
torch.cuda.synchronize()
inference_start_sync = time.time()
y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)
if gpu_handle is not None:
monitor_active.clear()
if monitor_thread.is_alive():
monitor_thread.join(timeout=2.0)
if timeout_occurred.is_set():
raise TimeoutError(f"Inference timeout: {config.INFERENCE_TIMEOUT}s")
torch.cuda.synchronize()
inference_time = time.time() - inference_start_sync
elapsed = time.time() - start_time
if gpu_handle is not None and len(gpu_utils) > 0:
try:
import pynvml
final_util = pynvml.nvmlDeviceGetUtilizationRates(gpu_handle)
avg_util = sum(gpu_utils) / len(gpu_utils) if gpu_utils else final_util.gpu
max_util = max(gpu_utils) if gpu_utils else final_util.gpu
print(f" βœ“ Inference completed in {inference_time:.1f}s (total: {elapsed:.1f}s)")
print(f" β†’ GPU utilization: avg={avg_util:.1f}%, max={max_util:.1f}%, final={final_util.gpu}%")
if avg_util < 50:
print(f" ⚠ WARNING: Low GPU utilization ({avg_util:.1f}%). May be CPU/transfer-bound.")
print(f" β†’ Consider: 1) Install CUDA extensions (mamba_ssm, selective_scan_cuda_oflex)")
print(f" β†’ 2) Check data transfer efficiency (channels-last enabled)")
print(f" β†’ 3) Verify GPU aggregation is enabled (check logs above)")
except:
print(f" βœ“ Inference completed in {inference_time:.1f}s (total: {elapsed:.1f}s)")
else:
print(f" βœ“ Inference completed in {inference_time:.1f}s (total: {elapsed:.1f}s)")
except TimeoutError:
timeout_timer.cancel()
raise
except RuntimeError as e:
if "out of memory" in str(e):
print(f" GPU OOM error: {e}")
print(" Applying progressive OOM fallback...")
torch.cuda.empty_cache()
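# Progressive fallbacks, tried in order until one fits in memory:
# 1) sw_batch_size -> 1, 2) ROI depth -> 48, 3) ROI depth -> 32,
# 4) move window aggregation to CPU. Settings are restored afterwards.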
original_batch = model_loader.WINDOW_INFER.sw_batch_size
original_roi = list(model_loader.WINDOW_INFER.roi_size)
original_overlap = model_loader.WINDOW_INFER.overlap
original_device = model_loader.WINDOW_INFER.device
original_device_str = str(original_device)
fallback_applied = None
try:
if original_batch > 1:
print(f" Fallback 1: Reducing sw_batch_size {original_batch} -> 1")
model_loader.WINDOW_INFER.sw_batch_size = 1
with autocast(device_type='cuda'):
y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)
torch.cuda.synchronize()
fallback_applied = "sw_batch_size=1"
print(f" Retry successful with {fallback_applied}")
else:
raise RuntimeError("Already at batch_size=1")
except RuntimeError as e2:
if "out of memory" not in str(e2):
raise
try:
if original_roi[0] > 48:
new_depth = 48
print(f" Fallback 2: Reducing ROI depth {original_roi[0]} -> {new_depth}")
model_loader.WINDOW_INFER.roi_size = (new_depth, original_roi[1], original_roi[2])
with autocast(device_type='cuda'):
y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)
torch.cuda.synchronize()
fallback_applied = f"roi_depth={new_depth}"
print(f" Retry successful with {fallback_applied}")
else:
raise RuntimeError("Already at depth=48")
except RuntimeError as e3:
if "out of memory" not in str(e3):
raise
try:
if original_roi[0] > 32:
new_depth = 32
print(f" Fallback 3: Reducing ROI depth {original_roi[0]} -> {new_depth}")
model_loader.WINDOW_INFER.roi_size = (new_depth, original_roi[1], original_roi[2])
with autocast(device_type='cuda'):
y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)
torch.cuda.synchronize()
fallback_applied = f"roi_depth={new_depth}"
print(f" Retry successful with {fallback_applied}")
else:
raise RuntimeError("Already at depth=32")
except RuntimeError as e4:
if "out of memory" not in str(e4):
raise
try:
# torch.device == 'cuda' string comparison is unreliable; use the saved string.
if original_device_str.startswith('cuda'):
print(f" Fallback 4: Switching aggregation to CPU")
model_loader.WINDOW_INFER.device = torch.device('cpu')
with autocast(device_type='cuda'):
y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)
torch.cuda.synchronize()
fallback_applied = "cpu_aggregation"
print(f" Retry successful with {fallback_applied}")
else:
raise RuntimeError("Already using CPU aggregation")
except RuntimeError as e5:
model_loader.WINDOW_INFER.sw_batch_size = original_batch
model_loader.WINDOW_INFER.roi_size = tuple(original_roi)
model_loader.WINDOW_INFER.overlap = original_overlap
model_loader.WINDOW_INFER.device = original_device if isinstance(original_device, torch.device) else torch.device(original_device_str)
allocated = torch.cuda.memory_allocated(0) / (1024**3)
total_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
free_memory_gb = (total_gb - allocated)
error_msg = f"GPU out of memory after all fallbacks. GPU: {allocated:.2f} GB / {total_gb:.2f} GB ({100*allocated/total_gb:.1f}% full), {free_memory_gb:.2f} GB free. Please restart Space or use smaller input. Error: {e5}"
print(f" {error_msg}")
raise RuntimeError(error_msg)
finally:
if fallback_applied:
print(f" Note: Applied fallback ({fallback_applied}). Restoring original settings after inference.")
model_loader.WINDOW_INFER.sw_batch_size = original_batch
model_loader.WINDOW_INFER.roi_size = tuple(original_roi)
model_loader.WINDOW_INFER.overlap = original_overlap
model_loader.WINDOW_INFER.device = original_device
else:
raise
else:
print(" β†’ AMP: Not available on CPU")
print(" β†’ Starting sliding window inference (this may take 5-10 minutes on CPU)...")
start_time = time.time()
y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)
elapsed = time.time() - start_time
print(f" βœ“ Inference completed in {elapsed:.1f}s")
finally:
timeout_timer.cancel()
if y1 is not None:
print(f" β†’ Raw model output shapes: y1={y1.shape if hasattr(y1, 'shape') else 'N/A'}, y2={y2.shape if hasattr(y2, 'shape') else 'N/A'}")
sys.stdout.flush()
else:
print(" β†’ Inference failed before model output was generated")
sys.stdout.flush()
raise RuntimeError("Inference failed: model did not produce output")
pred = y1
if isinstance(pred, torch.Tensor):
raw_min = pred.min().item()
raw_max = pred.max().item()
raw_mean = pred.mean().item()
print(f" β†’ Raw logits (before sigmoid): min={raw_min:.6f}, max={raw_max:.6f}, mean={raw_mean:.6f}")
sys.stdout.flush()
print(" β†’ Applying sigmoid activation (confirming: using sigmoid, not softmax)")
sys.stdout.flush()
pred = torch.sigmoid(pred)
if len(pred.shape) == 5:
if pred.shape[1] > 1:
print(f" ⚠ WARNING: Multi-channel output detected (C={pred.shape[1]}). Using channel 0 for binary segmentation.")
pred = pred[0, 0]
else:
pred = pred[0, 0]
elif len(pred.shape) == 4:
pred = pred[0] if pred.shape[0] == 1 else pred
print(f" β†’ Final prediction shape after sigmoid: {pred.shape}")
sys.stdout.flush()
inference_time = time.time() - inference_start
print(f"βœ“ Inference completed in {inference_time:.1f} seconds")
del y1, y2, y3, y4
if model_loader.DEVICE.type == 'cuda' and (torch.cuda.memory_allocated(0) / (1024**3)) > 40:
torch.cuda.empty_cache()
print("Post-processing predictions...")
sys.stdout.flush()
pred = pred.detach().float().cpu()
if pred.ndim == 5:
pred_np = pred[0, 0].numpy()
elif pred.ndim == 4 and pred.shape[0] == 1:
pred_np = pred[0].numpy()
else:
pred_np = pred.numpy()
pred_max = pred_np.max()
pred_min = pred_np.min()
pred_mean = pred_np.mean()
if pred_max < 0.3:
print(f" ⚠ WARNING: Maximum prediction value is very low ({pred_max:.4f}). Model confidence may be low.")
print(f" This can occur with:")
print(f" - Low resolution/compressed input (lost texture cues)")
print(f" - Normalization mismatch (input not properly preprocessed)")
print(f" - Missing metadata (wrong voxel spacing/scaling)")
if pred_mean < 0.1:
print(f" ⚠ WARNING: Average prediction is very low ({pred_mean:.4f}). Most voxels have low confidence.")
print(f" Consider checking input quality and preprocessing.")
total_voxels = pred_np.size
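# Threshold grid search: scan modality-specific candidate thresholds and
# keep the one whose mask lands in a plausible liver-fraction band, scoring
# by largest-component size minus a penalty per extra connected component.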
if modality.upper() == "T1":
candidates = np.linspace(0.60, 0.80, 10)
frac_min, frac_max = 0.015, 0.040
else:
candidates = np.linspace(0.30, 0.70, 9)
frac_min, frac_max = 0.008, 0.040
best = None
for t in candidates:
mb = (pred_np > t).astype(np.uint8)
if mb.sum() < 1000:
continue
frac = mb.mean()
if not (frac_min <= frac <= frac_max):
continue
lbl, n = ndimage.label(mb)
if n == 0:
continue
sizes = ndimage.sum(mb, lbl, range(1, n + 1))
score = sizes.max() - (n - 1) * 1e5
if best is None or score > best[0]:
best = (score, t, mb)
if best:
_, threshold, pred_binary = best
print(f" Grid search selected threshold: {threshold:.3f}")
else:
threshold = default_threshold
pred_binary = (pred_np > threshold).astype(np.uint8)
print(f" Grid search failed, using default threshold: {threshold:.3f}")
count_default = pred_binary.sum()
count_035 = (pred_np > 0.35).astype(np.uint8).sum()
count_03 = (pred_np > 0.3).astype(np.uint8).sum()
if count_default == 0 and default_threshold >= 0.5:
if count_035 > 0:
print(f" ⚠ WARNING: No voxels > {default_threshold}, but {count_035:,} voxels > 0.35. Trying threshold 0.35...")
threshold = 0.35
elif count_03 > 0:
print(f" ⚠ WARNING: No voxels > {default_threshold} or 0.35, but {count_03:,} voxels > 0.3. Trying threshold 0.3...")
threshold = 0.3
else:
percentile_thresh = float(np.quantile(pred_np.flatten(), 0.995))
threshold = max(0.3, percentile_thresh)
print(f" ⚠ WARNING: No voxels > {default_threshold}. Using percentile threshold: {threshold:.4f}")
pred_binary = (pred_np > threshold).astype(np.uint8)
elif count_default == 0 and default_threshold < 0.5:
if count_03 > 0:
print(f" ⚠ WARNING: No voxels > {default_threshold}, but {count_03:,} voxels > 0.3. Trying threshold 0.3...")
threshold = 0.3
else:
percentile_thresh = float(np.quantile(pred_np.flatten(), 0.995))
threshold = max(0.3, percentile_thresh)
print(f" ⚠ WARNING: No voxels > {default_threshold}. Using percentile threshold: {threshold:.4f}")
pred_binary = (pred_np > threshold).astype(np.uint8)
if modality.upper() == "T1":
try:
h, w, d = original_data.shape
is_prenormalized = (original_data.max() <= 1.0 and original_data.min() >= 0.0)
if not is_prenormalized:
right_upper_quadrant = original_data[h//4:3*h//4, w//2:, d//3:2*d//3]
if right_upper_quadrant.size > 0 and (right_upper_quadrant > 0).sum() > 0:
median_intensity = np.median(right_upper_quadrant[right_upper_quadrant > 0])
mad = np.median(np.abs(right_upper_quadrant[right_upper_quadrant > 0] - median_intensity))
k = 1.5
intensity_lower = median_intensity - k * mad
intensity_upper = median_intensity + k * mad
if intensity_upper - intensity_lower > 0.5:
if len(pred_np.shape) == 4:
pred_3d = pred_np[0, 0]
elif len(pred_np.shape) == 5:
pred_3d = pred_np[0, 0, 0]
else:
pred_3d = pred_np
if pred_3d.shape == original_data.shape:
intensity_mask = (original_data >= intensity_lower) & (original_data <= intensity_upper)
pred_3d_clamped = pred_3d.copy()
pred_3d_clamped[~intensity_mask] = np.minimum(pred_3d_clamped[~intensity_mask], threshold * 0.6)
if pred_3d_clamped.std() > 1e-6:
if len(pred_np.shape) == 4:
pred_np[0, 0] = pred_3d_clamped
elif len(pred_np.shape) == 5:
pred_np[0, 0, 0] = pred_3d_clamped
else:
pred_np = pred_3d_clamped
print(f" Intensity gate: Clamped predictions outside liver-like intensity range [{intensity_lower:.2f}, {intensity_upper:.2f}]")
else:
print(f" Intensity gate: Shape mismatch (pred: {pred_3d.shape}, original: {original_data.shape}). Skipping intensity gate.")
else:
left_upper_region = original_data[h//4:3*h//4, :w//2, d//3:2*d//3]
if left_upper_region.size > 0 and (left_upper_region > 0).sum() > 100:
stomach_like_intensity = np.percentile(left_upper_region[left_upper_region > 0], 75)
if len(pred_np.shape) == 4:
pred_3d = pred_np[0, 0]
elif len(pred_np.shape) == 5:
pred_3d = pred_np[0, 0, 0]
else:
pred_3d = pred_np
if pred_3d.shape == original_data.shape:
stomach_mask = (original_data > stomach_like_intensity * 0.9) & (original_data < stomach_like_intensity * 1.1)
pred_3d_clamped = pred_3d.copy()
pred_3d_clamped[stomach_mask] = np.minimum(pred_3d_clamped[stomach_mask], threshold * 0.5)
if pred_3d_clamped.std() > 1e-6:
if len(pred_np.shape) == 4:
pred_np[0, 0] = pred_3d_clamped
elif len(pred_np.shape) == 5:
pred_np[0, 0, 0] = pred_3d_clamped
else:
pred_np = pred_3d_clamped
print(f" Intensity gate: Suppressed stomach-like intensities (threshold: {stomach_like_intensity:.3f})")
else:
print(f" Intensity gate: Shape mismatch (pred: {pred_3d.shape}, original: {original_data.shape}). Skipping intensity gate.")
except Exception as e:
print(f" Intensity gate failed (non-critical): {e}")
pred_binary = (pred_np > threshold).astype(np.uint8)
count_after_thresh = pred_binary.sum()
fraction = count_after_thresh / total_voxels if total_voxels > 0 else 0.0
volume_ml = calculate_liver_volume(pred_binary, voxel_spacing)
print(f" After threshold ({threshold:.3f}): {count_after_thresh:,} voxels ({100*fraction:.2f}%), volume: {volume_ml:.1f} ml")
if fraction > 0.040 or volume_ml > 2200:
threshold_max = float(os.environ.get("THRESHOLD_MAX", "0.90"))
new_threshold = min(threshold_max, threshold + 0.05)
if new_threshold > threshold:
print(f" Size-aware auto-tune: Mask too large (fraction={100*fraction:.2f}%, volume={volume_ml:.1f}ml). Increasing threshold {threshold:.3f} -> {new_threshold:.3f}")
threshold = new_threshold
pred_binary = (pred_np > threshold).astype(np.uint8)
count_after_thresh = pred_binary.sum()
fraction = count_after_thresh / total_voxels if total_voxels > 0 else 0.0
volume_ml = calculate_liver_volume(pred_binary, voxel_spacing)
print(f" After auto-tune threshold ({threshold:.3f}): {count_after_thresh:,} voxels ({100*fraction:.2f}%), volume: {volume_ml:.1f} ml")
if fraction < 0.005 or fraction > 0.040:
print(f" WARNING: Mask fraction {100*fraction:.3f}% outside typical range (0.5%-4.0%). May indicate segmentation issues.")
sys.stdout.flush()
print("Refining liver segmentation mask...")
sys.stdout.flush()
pred_np_copy = pred_np.copy()
try:
pred_binary_refined, refinement_metrics, confidence_score = refine_liver_mask_enhanced(
pred_binary, voxel_spacing, pred_np_copy, threshold, modality
)
pred_binary = pred_binary_refined
guards_ok = refinement_metrics.get('guards_ok', True)
print(f" Refinement applied: {refinement_metrics['connected_components_before']} -> {refinement_metrics['connected_components_after']} components")
print(f" Volume change: {refinement_metrics['volume_change_percent']:.2f}% ({refinement_metrics['volume_change_ml']:.2f} ml)")
print(f" Confidence score: {confidence_score:.1f}% {'(guards OK)' if guards_ok else '(guards triggered)'}")
except Exception as e:
print(f" Refinement failed (using original mask): {e}")
import traceback
traceback.print_exc()
confidence_score = 50.0
guards_ok = False
del pred_np, pred_np_copy
print("Loading original volume for visualization...")
original_volume = original_data.copy()
if original_volume.max() > original_volume.min():
original_volume = (original_volume - original_volume.min()) / (original_volume.max() - original_volume.min())
else:
original_volume = np.zeros_like(original_volume)
slice_idx = max(0, min(slice_idx if slice_idx is not None else (pred_binary.shape[-1] // 2), pred_binary.shape[-1] - 1))
if len(original_volume.shape) == 3:
original_slice = original_volume[:, :, slice_idx]
else:
original_slice = original_volume[:, :, slice_idx] if original_volume.shape[2] > slice_idx else original_volume[:, :, 0]
if len(pred_binary.shape) == 4:
pred_slice = pred_binary[0, :, :, slice_idx] if pred_binary.shape[3] > slice_idx else pred_binary[0, :, :, 0]
elif len(pred_binary.shape) == 5:
pred_slice = pred_binary[0, 0, :, :, slice_idx] if pred_binary.shape[4] > slice_idx else pred_binary[0, 0, :, :, 0]
else:
pred_slice = pred_binary[:, :, slice_idx] if pred_binary.shape[2] > slice_idx else pred_binary[:, :, 0]
if pred_slice.shape != original_slice.shape:
print(f" Warning: Shape mismatch in overlay (pred_slice: {pred_slice.shape}, original_slice: {original_slice.shape}). Resizing pred_slice to match.")
from scipy.ndimage import zoom
zoom_factors = (original_slice.shape[0] / pred_slice.shape[0], original_slice.shape[1] / pred_slice.shape[1])
pred_slice = zoom(pred_slice, zoom_factors, order=0)
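# Build the RGB overlay: grayscale anatomy replicated across all three
# channels, with segmented voxels painted pure green.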
overlay = np.zeros((*original_slice.shape, 3), dtype=np.uint8)
overlay[:, :, 0] = (original_slice * 255).astype(np.uint8)
overlay[:, :, 1] = (original_slice * 255).astype(np.uint8)
overlay[:, :, 2] = (original_slice * 255).astype(np.uint8)
mask = pred_slice > 0
overlay[mask, 0] = 0
overlay[mask, 1] = 255
overlay[mask, 2] = 0
overlay_img = Image.fromarray(overlay)
print("Saving segmentation mask...")
aff = nifti_img.affine
hdr = nifti_img.header.copy()
hdr.set_data_dtype(np.uint8)
pred_save = pred_binary[0] if pred_binary.ndim == 4 else pred_binary
if pred_binary.ndim == 5:
pred_save = pred_binary[0, 0]
original_shape = original_data.shape
pred_shape = pred_save.shape
if pred_shape != original_shape:
print(f" Resampling prediction from {pred_shape} to original shape {original_shape}...")
from scipy.ndimage import zoom
zoom_factors = tuple(orig / pred for orig, pred in zip(original_shape, pred_shape))
pred_save = zoom(pred_save.astype(np.float32), zoom_factors, order=0).astype(np.uint8)
print(f" βœ“ Resampled to match original dimensions")
with tempfile.NamedTemporaryFile(suffix='.nii.gz', delete=False) as tmp_file:
nifti_pred = nib.Nifti1Image(pred_save.astype(np.uint8), affine=aff, header=hdr)
nib.save(nifti_pred, tmp_file.name)
output_path = tmp_file.name
total_voxels = pred_binary.size
liver_voxels = pred_binary.sum()
liver_percentage = (liver_voxels / total_voxels) * 100
print(f"βœ“ Segmentation complete. Liver: {liver_voxels:,} voxels ({liver_percentage:.2f}%)")
volume_ml = calculate_liver_volume(pred_binary, voxel_spacing)
if len(pred_binary.shape) == 4:
mask_3d = pred_binary[0]
elif len(pred_binary.shape) == 5:
mask_3d = pred_binary[0, 0]
else:
mask_3d = pred_binary
labels_final, num_components_final = ndimage.label(mask_3d)
morphology = analyze_liver_morphology(pred_binary)
from processing import check_volume_sanity
volume_sanity_status, volume_sanity_msg = check_volume_sanity(volume_ml)
if volume_sanity_status != "OK":
print(f" ⚠ {volume_sanity_status}: {volume_sanity_msg}")
if num_components_final != 1:
print(f" ⚠ WARNING: Expected 1 connected component, found {num_components_final}. Quality check recommended.")
medical_report = generate_medical_report(
{
"volume_shape": list(pred_binary.shape),
"liver_voxels": int(liver_voxels),
"total_voxels": int(total_voxels),
"liver_percentage": float(liver_percentage),
"slice_index": int(slice_idx),
"total_slices": int(pred_binary.shape[-1]),
"modality": modality,
"confidence_score": float(confidence_score)
},
volume_ml,
morphology,
modality,
confidence_score
)
info_text = f"""
## Segmentation Results
- **Volume Shape:** {pred_binary.shape}
- **Liver Voxels:** {liver_voxels:,} ({liver_percentage:.2f}% of volume)
- **Liver Volume:** {volume_ml:.2f} ml
- **Slice:** {slice_idx + 1} of {pred_binary.shape[-1]} shown
- **Modality:** {modality}
"""
severity_text = medical_report['severity'].upper().replace('_', ' ')
confidence = medical_report['measurements'].get('confidence_score', 0.0)
report_text = f"""
## Automated Liver Segmentation Report
---
### Study Information
- **Study Date & Time:** {medical_report['study_date']}
- **Imaging Modality:** {medical_report['modality']}
- **Report Status:** **{severity_text}**
- **Confidence Score:** **{confidence:.1f}%**
---
### Key Findings
{chr(10).join(f"{finding}" for finding in medical_report['findings'])}
---
### Quantitative Measurements
**Volume Analysis:**
- **Liver Volume:** **{medical_report['measurements']['liver_volume_ml']} ml** ({medical_report['measurements']['liver_volume_liters']} L)
- **Liver Percentage of Scan:** {medical_report['measurements']['liver_percentage']}%
- **Segmented Voxels:** {medical_report['measurements']['liver_voxels']:,} / {medical_report['measurements']['total_voxels']:,} total voxels
**Morphological Analysis:**
- **Connected Components:** {medical_report['measurements']['morphology']['connected_components']}
- **Fragmentation Level:** {medical_report['measurements']['morphology']['fragmentation'].upper()}
- **Largest Component Ratio:** {medical_report['measurements']['morphology']['largest_component_ratio']*100:.1f}%
**Image Characteristics:**
- **Volume Dimensions:** {medical_report['measurements']['volume_shape']}
**Confidence Metrics:**
- **Overall Confidence:** {confidence:.1f}%
- **Confidence Interpretation:** {'High' if confidence >= 80 else 'Moderate' if confidence >= 60 else 'Low'} confidence segmentation
---
### Quality Assessment
{chr(10).join(f"- {note}" for note in medical_report.get('quality_assessment', []))}
---
### Clinical Context
{chr(10).join(f"- {note}" for note in medical_report.get('clinical_notes', [])) if medical_report.get('clinical_notes') else "- No additional clinical notes at this time."}
---
### Impression
{medical_report['impression']}
---
{f"### Recommendations{chr(10)}{chr(10)}{chr(10).join(f'- {rec}' for rec in medical_report['recommendations'])}{chr(10)}" if medical_report['recommendations'] else ""}
### Methodology
{medical_report.get('methodology', 'SRMA-Mamba deep learning model for automated liver segmentation')}
---
### Important Notice
{medical_report['disclaimer']}
---
*Report generated automatically by SRMA-Mamba Liver Segmentation System*
"""
del original_volume, pred_binary
if model_loader.DEVICE.type == 'cuda':
torch.cuda.empty_cache()
return overlay_img, info_text, report_text, output_path
except FileNotFoundError as e:
error_msg = f"**Checkpoint Error:** {str(e)}\n\nPlease ensure checkpoint files are uploaded to the Space."
print(f"βœ— {error_msg}")
import traceback
traceback.print_exc()
return None, f"## ❌ Error\n\n{error_msg}", "", None
except RuntimeError as e:
error_msg = f"**Runtime Error:** {str(e)}\n\nThis might be a GPU/CUDA issue. The model will try to use CPU instead."
print(f"βœ— {error_msg}")
import traceback
traceback.print_exc()
return None, f"## ❌ Error\n\n{error_msg}", "", None
except Exception as e:
error_msg = f"**Error during inference:** {str(e)}\n\nPlease check the logs for more details."
print(f"βœ— {error_msg}")
import traceback
tb = traceback.format_exc()
print(f"Full traceback:\n{tb}")
return None, f"## ❌ Error\n\n{error_msg}\n\n**Error Type:** `{type(e).__name__}`", "", None
finally:
PROCESSING_LOCK.release()
try:
if image_tensor is not None:
del image_tensor
except (NameError, UnboundLocalError):
pass
try:
if pred is not None:
del pred
except (NameError, UnboundLocalError):
pass
if model_loader.DEVICE.type == 'cuda':
torch.cuda.empty_cache()
print("βœ“ Memory cleaned up, lock released")
def predict_volume_api(file_path: str, modality: str = 'T1', slice_idx: Optional[int] = None):
confidence_score = 50.0
try:
print("=" * 60)
print(f"API: Starting prediction for modality: {modality}, file: {file_path}")
print("=" * 60)
if modality == 'T1':
if model_loader.MODEL_T1 is None:
print("API: Loading T1 model_loader...")
model_loader.MODEL_T1 = model_loader.load_model('T1')
model_instance = model_loader.MODEL_T1
else:
if model_loader.MODEL_T2 is None:
print("API: Loading T2 model_loader...")
model_loader.MODEL_T2 = model_loader.load_model('T2')
model_instance = model_loader.MODEL_T2
print("API: Preprocessing NIfTI file...")
from pathlib import Path
file_path = str(Path(file_path).resolve())
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found: {file_path}")
nifti_img = nib.load(file_path)
nifti_img = nib.as_closest_canonical(nifti_img)
voxel_spacing = nifti_img.header.get_zooms()[:3] if len(nifti_img.header.get_zooms()) >= 3 else (1.0, 1.0, 1.0)
aff = nifti_img.affine
hdr = nifti_img.header.copy()
original_data = nifti_img.get_fdata(dtype=np.float32)
if model_loader.DEVICE.type == 'cuda':
if not torch.cuda.is_available():
raise RuntimeError("CUDA not available but device is set to CUDA")
torch.cuda.empty_cache()
free_mem = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / (1024**3)
if free_mem < 0.5:
raise RuntimeError(f"Insufficient GPU memory: {free_mem:.2f} GB free. Please restart.")
image_data = preprocess_nifti(file_path, device=model_loader.DEVICE)
if len(image_data.shape) == 3:
image_tensor = image_data.unsqueeze(0).unsqueeze(0)
elif len(image_data.shape) == 4:
image_tensor = image_data.unsqueeze(0)
else:
image_tensor = image_data
print(f"API: Input tensor shape: {image_tensor.shape}")
if model_loader.WINDOW_INFER is not None:
print("API: Adjusting ROI size based on volume dimensions...")
adjust_roi_for_volume(image_tensor.shape)
print("API: Running inference...")
if model_loader.DEVICE.type == 'cuda':
allocated = torch.cuda.memory_allocated(0) / (1024**3)
free_memory_gb = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / (1024**3)
print(f"API: Pre-inference GPU memory: {allocated:.2f} GB allocated, {free_memory_gb:.2f} GB free")
if free_memory_gb < 0.5:
print(f"API: ⚠ WARNING: Very low free memory ({free_memory_gb:.2f} GB). Adjusting settings...")
if model_loader.WINDOW_INFER is not None:
model_loader.WINDOW_INFER.sw_batch_size = 1
model_loader.WINDOW_INFER.roi_size = [192, 192, 32]
model_loader.WINDOW_INFER.overlap = 0.25
torch.cuda.empty_cache()
free_memory_gb = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / (1024**3)
print(f"API: β†’ After cache clear: {free_memory_gb:.2f} GB free")
if free_memory_gb < 0.3:
error_msg = f"API: Insufficient GPU memory ({free_memory_gb:.2f} GB free). Please use a smaller input volume."
print(f"API: βœ— {error_msg}")
raise RuntimeError(error_msg)
with torch.no_grad():
if model_loader.DEVICE.type == 'cuda':
print("API: AMP (Automatic Mixed Precision) enabled: YES")
sys.stdout.flush()
try:
with autocast(device_type='cuda'):
y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)
except RuntimeError as e:
if "out of memory" in str(e):
print(f"API: ⚠ GPU OOM error: {e}")
print("API: β†’ Clearing cache and retrying with minimal settings...")
torch.cuda.empty_cache()
torch.cuda.synchronize()
original_batch = model_loader.WINDOW_INFER.sw_batch_size
original_roi = model_loader.WINDOW_INFER.roi_size
original_overlap = model_loader.WINDOW_INFER.overlap
model_loader.WINDOW_INFER.sw_batch_size = 1
model_loader.WINDOW_INFER.roi_size = [192, 192, 32]
model_loader.WINDOW_INFER.overlap = 0.25
try:
with autocast(device_type='cuda'):
y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)
print("API: βœ“ Retry successful with minimal settings")
except RuntimeError as e2:
model_loader.WINDOW_INFER.sw_batch_size = original_batch
model_loader.WINDOW_INFER.roi_size = original_roi
model_loader.WINDOW_INFER.overlap = original_overlap
allocated = torch.cuda.memory_allocated(0) / (1024**3)
total_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
free_memory_gb = (total_gb - allocated)
error_msg = f"GPU out of memory even with minimal settings. GPU is {allocated:.2f} GB / {total_gb:.2f} GB ({100*allocated/total_gb:.1f}% full), only {free_memory_gb:.2f} GB free. The model requires CUDA. Please restart the Space or use a smaller input volume. Error: {e2}"
print(f"API: βœ— {error_msg}")
raise RuntimeError(error_msg)
finally:
model_loader.WINDOW_INFER.sw_batch_size = original_batch
model_loader.WINDOW_INFER.roi_size = original_roi
model_loader.WINDOW_INFER.overlap = original_overlap
else:
raise
else:
print("API: AMP: Not available on CPU")
y1, y2, y3, y4 = model_loader.WINDOW_INFER(image_tensor, model_instance)
print(f"API: Raw model output shapes: y1={y1.shape if hasattr(y1, 'shape') else 'N/A'}")
sys.stdout.flush()
pred = y1
if isinstance(pred, torch.Tensor):
raw_min = pred.min().item()
raw_max = pred.max().item()
raw_mean = pred.mean().item()
print(f"API: Raw logits (before sigmoid): min={raw_min:.6f}, max={raw_max:.6f}, mean={raw_mean:.6f}")
sys.stdout.flush()
print("API: Applying sigmoid activation (confirming: using sigmoid, not softmax)")
sys.stdout.flush()
pred = torch.sigmoid(pred)
if len(pred.shape) == 5:
if pred.shape[1] > 1:
print(f"API: ⚠ WARNING: Multi-channel output detected (C={pred.shape[1]}). Using channel 0 for binary segmentation.")
pred = pred[0, 0]
else:
pred = pred[0, 0]
elif len(pred.shape) == 4:
pred = pred[0] if pred.shape[0] == 1 else pred
print(f"API: Final prediction shape after sigmoid: {pred.shape}")
sys.stdout.flush()
print("API: Post-processing predictions...")
sys.stdout.flush()
pred_np = pred.detach().cpu().numpy()
print("=" * 60)
print("API: DIAGNOSTIC: SIGMOID OUTPUT STATS (before thresholding)")
print("=" * 60)
print(f" Max: {pred_np.max():.6f}")
print(f" Mean: {pred_np.mean():.6f}")
print(f" Min: {pred_np.min():.6f}")
print(f" Std: {pred_np.std():.6f}")
print(f" Shape: {pred_np.shape}")
sys.stdout.flush()
default_threshold = float(os.environ.get("SEGMENTATION_THRESHOLD", "0.5"))
count_035 = np.sum(pred_np > 0.35)
count_03 = np.sum(pred_np > 0.3)
count_05 = np.sum(pred_np > 0.5)
total_voxels = pred_np.size
threshold = default_threshold
count_default = int(np.sum(pred_np > default_threshold))
if count_default == 0 and default_threshold >= 0.5:
if count_035 > 0:
print(f"API: ⚠ WARNING: No voxels > {default_threshold}, but {count_035:,} voxels > 0.35. Trying threshold 0.35...")
threshold = 0.35
elif count_03 > 0:
print(f"API: ⚠ WARNING: No voxels > {default_threshold} or 0.35, but {count_03:,} voxels > 0.3. Trying threshold 0.3...")
threshold = 0.3
else:
percentile_thresh = float(np.quantile(pred_np.flatten(), 0.995))
threshold = max(0.3, percentile_thresh)
print(f"API: ⚠ WARNING: No voxels > {default_threshold}. Using percentile threshold: {threshold:.4f}")
elif count_default == 0 and default_threshold < 0.5:
if count_03 > 0:
print(f"API: ⚠ WARNING: No voxels > {default_threshold}, but {count_03:,} voxels > 0.3. Trying threshold 0.3...")
threshold = 0.3
else:
percentile_thresh = float(np.quantile(pred_np.flatten(), 0.995))
threshold = max(0.3, percentile_thresh)
print(f"API: ⚠ WARNING: No voxels > {default_threshold}. Using percentile threshold: {threshold:.4f}")
pred_binary = (pred_np > threshold).astype(np.uint8)
fraction = pred_binary.sum() / total_voxels if total_voxels > 0 else 0.0
if fraction < 0.005 or fraction > 0.040:
print(f"API: ⚠ WARNING: Mask fraction {100*fraction:.3f}% outside typical range (0.5%-4.0%). May indicate segmentation issues.")
count_after_thresh = pred_binary.sum()
print(f"API: After threshold ({threshold}): {count_after_thresh:,} voxels ({100*count_after_thresh/total_voxels:.2f}%)")
if count_05 > 0 and count_after_thresh == 0 and threshold == 0.5:
print(f"API: ⚠ WARNING: Had {count_05:,} voxels > 0.5 but 0 after thresholding - check threshold logic!")
sys.stdout.flush()
print("API: Refining liver segmentation mask...")
try:
pred_np_copy = pred_np.copy()
pred_binary_refined, refinement_metrics, confidence_score = refine_liver_mask_enhanced(
pred_binary, voxel_spacing, pred_np_copy, threshold, modality
)
pred_binary = pred_binary_refined
print(f"API: βœ“ Refinement applied: {refinement_metrics['connected_components_before']} β†’ {refinement_metrics['connected_components_after']} components")
print(f"API: β†’ Volume change: {refinement_metrics['volume_change_percent']:.2f}% ({refinement_metrics['volume_change_ml']:.2f} ml)")
print(f"API: β†’ Confidence score: {confidence_score:.1f}%")
except Exception as e:
print(f"API: ⚠ Refinement failed (using original mask): {e}")
confidence_score = 50.0
print("API: Loading original volume for visualization...")
original_volume = original_data.copy()
if original_volume.max() > original_volume.min():
original_volume = (original_volume - original_volume.min()) / (original_volume.max() - original_volume.min())
else:
original_volume = np.zeros_like(original_volume)
slice_idx = max(0, min(slice_idx if slice_idx is not None else (pred_binary.shape[-1] // 2), pred_binary.shape[-1] - 1))
if len(original_volume.shape) == 3:
original_slice = original_volume[:, :, slice_idx]
else:
original_slice = original_volume[:, :, slice_idx] if original_volume.shape[2] > slice_idx else original_volume[:, :, 0]
if len(pred_binary.shape) == 4:
pred_slice = pred_binary[0, :, :, slice_idx] if pred_binary.shape[3] > slice_idx else pred_binary[0, :, :, 0]
elif len(pred_binary.shape) == 5:
pred_slice = pred_binary[0, 0, :, :, slice_idx] if pred_binary.shape[4] > slice_idx else pred_binary[0, 0, :, :, 0]
else:
pred_slice = pred_binary[:, :, slice_idx] if pred_binary.shape[2] > slice_idx else pred_binary[:, :, 0]
if pred_slice.shape != original_slice.shape:
print(f"API: Warning: Shape mismatch in overlay (pred_slice: {pred_slice.shape}, original_slice: {original_slice.shape}). Resizing pred_slice to match.")
from scipy.ndimage import zoom
zoom_factors = (original_slice.shape[0] / pred_slice.shape[0], original_slice.shape[1] / pred_slice.shape[1])
pred_slice = zoom(pred_slice, zoom_factors, order=0)
overlay = np.zeros((*original_slice.shape, 3), dtype=np.uint8)
overlay[:, :, 0] = (original_slice * 255).astype(np.uint8)
overlay[:, :, 1] = (original_slice * 255).astype(np.uint8)
overlay[:, :, 2] = (original_slice * 255).astype(np.uint8)
mask = pred_slice > 0
overlay[mask, 0] = 0
overlay[mask, 1] = 255
overlay[mask, 2] = 0
overlay_img = Image.fromarray(overlay)
img_buffer = io.BytesIO()
overlay_img.save(img_buffer, format='PNG')
img_base64 = base64.b64encode(img_buffer.getvalue()).decode('utf-8')
print("API: Saving segmentation mask...")
aff = nifti_img.affine
hdr = nifti_img.header.copy()
hdr.set_data_dtype(np.uint8)
pred_save = pred_binary[0] if pred_binary.ndim == 4 else pred_binary
if pred_binary.ndim == 5:
pred_save = pred_binary[0, 0]
original_shape = original_data.shape
pred_shape = pred_save.shape
if pred_shape != original_shape:
print(f"API: Resampling prediction from {pred_shape} to original shape {original_shape}...")
from scipy.ndimage import zoom
zoom_factors = tuple(orig / pred for orig, pred in zip(original_shape, pred_shape))
pred_save = zoom(pred_save.astype(np.float32), zoom_factors, order=0).astype(np.uint8)
print(f"API: βœ“ Resampled to match original dimensions")
with tempfile.NamedTemporaryFile(suffix='.nii.gz', delete=False) as tmp_file:
nifti_pred = nib.Nifti1Image(pred_save.astype(np.uint8), affine=aff, header=hdr)
nib.save(nifti_pred, tmp_file.name)
output_path = tmp_file.name
total_voxels = pred_binary.size
liver_voxels = pred_binary.sum()
liver_percentage = (liver_voxels / total_voxels) * 100
print(f"API: Calculating volume and morphology...")
volume_ml = calculate_liver_volume(pred_binary, voxel_spacing)
if len(pred_binary.shape) == 4:
mask_3d = pred_binary[0]
elif len(pred_binary.shape) == 5:
mask_3d = pred_binary[0, 0]
else:
mask_3d = pred_binary
labels_final, num_components_final = ndimage.label(mask_3d)
morphology = analyze_liver_morphology(pred_binary)
from processing import check_volume_sanity
volume_sanity_status, volume_sanity_msg = check_volume_sanity(volume_ml)
if volume_sanity_status != "OK":
print(f"API: ⚠ {volume_sanity_status}: {volume_sanity_msg}")
if num_components_final != 1:
print(f"API: ⚠ WARNING: Expected 1 connected component, found {num_components_final}. Quality check recommended.")
medical_report = generate_medical_report(
{
"volume_shape": list(pred_binary.shape),
"liver_voxels": int(liver_voxels),
"total_voxels": int(total_voxels),
"liver_percentage": float(liver_percentage),
"slice_index": int(slice_idx),
"total_slices": int(pred_binary.shape[-1]),
"modality": modality,
"confidence_score": float(confidence_score)
},
volume_ml,
morphology,
modality,
confidence_score
)
print(f"API: βœ“ Segmentation complete. Liver: {liver_voxels:,} voxels ({liver_percentage:.2f}%), Volume: {volume_ml:.2f} ml")
return {
"success": True,
"overlay_image": f"data:image/png;base64,{img_base64}",
"segmentation_path": output_path,
"statistics": {
"volume_shape": list(pred_binary.shape),
"liver_voxels": int(liver_voxels),
"total_voxels": int(total_voxels),
"liver_percentage": float(liver_percentage),
"slice_index": int(slice_idx),
"total_slices": int(pred_binary.shape[-1]),
"modality": modality,
"liver_volume_ml": round(volume_ml, 2)
},
"medical_report": medical_report
}
except FileNotFoundError as e:
error_msg = f"Checkpoint file not found: {str(e)}"
print(error_msg)
import traceback
traceback.print_exc()
return {
"success": False,
"error": error_msg,
"error_type": "checkpoint_not_found"
}
except RuntimeError as e:
error_msg = f"Runtime error (possibly CUDA/GPU): {str(e)}"
print(error_msg)
import traceback
traceback.print_exc()
return {
"success": False,
"error": error_msg,
"error_type": "runtime_error"
}
except Exception as e:
error_msg = f"Error during inference: {str(e)}"
print(error_msg)
import traceback
tb = traceback.format_exc()
print(f"Full traceback: {tb}")
return {
"success": False,
"error": error_msg,
"error_type": type(e).__name__
}
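# Defensive wrapper for the UI: guarantees a well-formed 4-tuple even if
# predict_volume raises or returns something unexpected.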
def safe_predict_volume(nifti_file, modality, slice_idx=None):
try:
result = predict_volume(nifti_file, modality, slice_idx)
if result is None or len(result) != 4:
return None, "## ❌ Error\n\nUnexpected return value from prediction function.", "", None
overlay_img, info_text, report_text, output_path = result
if overlay_img is None and info_text and "Error" in info_text:
return None, info_text, report_text or "", output_path or None
return overlay_img, info_text, report_text, output_path
except Exception as e:
error_msg = f"**Unexpected error:** {str(e)}\n\nPlease check the logs for details."
print(f"βœ— Safe wrapper caught error: {error_msg}")
import traceback
traceback.print_exc()
return None, f"## ❌ Error\n\n{error_msg}", "", None