import os
import time

import torch
from monai.inferers import SlidingWindowInferer

from config import BUILD_SRMAMAMBA_AVAILABLE, build_SRMAMamba, SRMA_MAMBA_DIR

MODEL_T1 = None
MODEL_T2 = None
DEVICE = torch.device('cpu')
WINDOW_INFER = None


def clear_gpu_memory():
    """Unload any cached models and release the GPU memory held by this module."""
    global MODEL_T1, MODEL_T2, WINDOW_INFER
    if torch.cuda.is_available():
        if MODEL_T1 is not None:
            del MODEL_T1
            MODEL_T1 = None
        if MODEL_T2 is not None:
            del MODEL_T2
            MODEL_T2 = None
        if WINDOW_INFER is not None:
            del WINDOW_INFER
            WINDOW_INFER = None
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        print(" → GPU memory cleared (models unloaded)")


def load_model(modality='T1'):
    """Build the SRMA-Mamba model, load the checkpoint for the given modality
    ('T1' or 'T2'), configure a MONAI SlidingWindowInferer sized to the
    available GPU memory, and cache the model in MODEL_T1 / MODEL_T2."""
    global MODEL_T1, MODEL_T2, DEVICE, WINDOW_INFER, BUILD_SRMAMAMBA_AVAILABLE

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    if not BUILD_SRMAMAMBA_AVAILABLE or build_SRMAMamba is None:
        error_msg = "Model builder (build_SRMAMamba) is not available. Please check the logs for import errors."
        print(f"✗ {error_msg}")
        raise ImportError(error_msg)

    print(f"Loading {modality} model...")

    # Select the compute device, retrying a few times in case the GPU needs a
    # moment to become responsive.
    if torch.cuda.is_available():
        try:
            max_retries = 3
            retry_delay = 2
            for attempt in range(max_retries):
                try:
                    torch.cuda.empty_cache()
                    test_tensor = torch.zeros(1).cuda()
                    del test_tensor
                    torch.cuda.synchronize()
                    DEVICE = torch.device('cuda')
                    print(f"✓ Using device: {DEVICE}")
                    break
                except RuntimeError as e:
                    if "CUDA" in str(e) and attempt < max_retries - 1:
                        print(f"⚠ GPU wake-up attempt {attempt + 1}/{max_retries}: {e}")
                        print(f"⚠ Waiting {retry_delay}s for GPU to wake up...")
                        time.sleep(retry_delay)
                        retry_delay *= 2
                    else:
                        raise
        except Exception as e:
            print(f"⚠ CUDA available but failed to initialize: {e}. Falling back to CPU.")
            DEVICE = torch.device('cpu')
    else:
        DEVICE = torch.device('cpu')
        print(f"ℹ CUDA not available. Using device: {DEVICE}")

    # Pick sliding-window settings based on free VRAM before the model is loaded.
    if DEVICE.type == 'cuda':
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        allocated = torch.cuda.memory_allocated(0) / (1024**3)
        reserved = torch.cuda.memory_reserved(0) / (1024**3)
        total = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        free_memory_gb = total - allocated
        print(f" → GPU memory: {allocated:.2f} GB allocated, {reserved:.2f} GB reserved, "
              f"{free_memory_gb:.2f} GB free (total: {total:.2f} GB)")
        if free_memory_gb < 1.0:
            print(f" ⚠ CRITICAL: Very low free memory ({free_memory_gb:.2f} GB). Using ultra-minimal settings.")
            size = [192, 192, 32]
            batch_size = 1
            overlap = 0.25
        elif free_memory_gb < 2.0:
            print(f" ⚠ WARNING: Very low free memory ({free_memory_gb:.2f} GB). Using minimal settings.")
            size = [192, 192, 32]
            batch_size = 1
            overlap = 0.25
        elif free_memory_gb < 5.0:
            size = [224, 224, 48]
            batch_size = 1
            overlap = 0.2
        elif free_memory_gb > 40:
            print(f" Very high VRAM GPU detected ({free_memory_gb:.2f} GB free). Using optimal settings for maximum speed.")
            size = [256, 256, 80]
            batch_size = 2
            overlap = 0.1
        elif free_memory_gb > 30:
            print(f" High VRAM GPU detected ({free_memory_gb:.2f} GB free). Using optimal settings for speed.")
            size = [256, 256, 64]
            batch_size = 2
            overlap = 0.1
        elif free_memory_gb > 25:
            print(f" ✓ Large VRAM GPU detected ({free_memory_gb:.2f} GB free). Using optimal settings.")
            size = [256, 256, 64]
            batch_size = 1
            overlap = 0.1
        elif free_memory_gb > 20:
            size = [256, 256, 64]
            batch_size = 1
            overlap = 0.1
        elif free_memory_gb > 15:
            size = [256, 256, 64]
            batch_size = 1
            overlap = 0.1
        elif free_memory_gb > 10:
            size = [224, 224, 64]
            batch_size = 1
            overlap = 0.1
        elif free_memory_gb > 8:
            size = [224, 224, 48]
            batch_size = 1
            overlap = 0.2
        else:
            size = [192, 192, 48]
            batch_size = 1
            overlap = 0.2
    else:
        size = [224, 224, 64]
        batch_size = 1
        overlap = 0.15

    print(f" → Sliding window config: roi_size={size}, sw_batch_size={batch_size}, overlap={overlap}")

    # Build the architecture; the builder may expect to run from the SRMA-Mamba
    # source directory, so temporarily change into it if it is configured.
    print("Building model architecture...")
    if SRMA_MAMBA_DIR:
        original_cwd = os.getcwd()
        try:
            os.chdir(SRMA_MAMBA_DIR)
            print(f"Changed working directory to: {SRMA_MAMBA_DIR}")
            model = build_SRMAMamba()
            print("✓ Model architecture built")
        finally:
            os.chdir(original_cwd)
    else:
        model = build_SRMAMamba()
        print("✓ Model architecture built")

    model = model.to(DEVICE)
    print(f"✓ Model moved to {DEVICE}")

    # Locate the checkpoint locally, falling back to a Hugging Face Hub download.
    checkpoint_path = f"checkpoint_{modality}.pth"
    possible_paths = [
        checkpoint_path,
        os.path.join(os.path.dirname(__file__), checkpoint_path),
        f"../../Chkpoints/checkpoint_{modality}.pth",
        f"Chkpoints/checkpoint_{modality}.pth",
        f"../Chkpoints/checkpoint_{modality}.pth",
        f"Model/Chkpoints/checkpoint_{modality}.pth",
        os.path.join(os.path.dirname(__file__), f"Chkpoints/checkpoint_{modality}.pth"),
    ]

    found = False
    for path in possible_paths:
        abs_path = os.path.abspath(path)
        if os.path.exists(path) or os.path.exists(abs_path):
            checkpoint_path = path if os.path.exists(path) else abs_path
            found = True
            print(f"✓ Found checkpoint at: {checkpoint_path}")
            break

    if not found:
        try:
            from huggingface_hub import hf_hub_download
            repo_id = os.environ.get("HF_MODEL_REPO", "HarshithReddy01/srmamamba-liver-segmentation")
            print(f"Attempting to download checkpoint from Hugging Face: {repo_id}")
            checkpoint_path = hf_hub_download(
                repo_id=repo_id,
                filename=f"checkpoint_{modality}.pth",
                cache_dir="."
            )
            found = True
            print(f"✓ Downloaded checkpoint to: {checkpoint_path}")
        except Exception as e:
            error_msg = f"Checkpoint not found. Searched: {possible_paths}. Hugging Face download failed: {str(e)}"
            print(f"✗ {error_msg}")
            raise FileNotFoundError(error_msg)

    print(f"Loading checkpoint weights from: {checkpoint_path}")
    try:
        checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            model.load_state_dict(checkpoint['state_dict'])
        else:
            model.load_state_dict(checkpoint)
        print("✓ Checkpoint loaded successfully")
    except Exception as e:
        print(f"✗ Failed to load checkpoint: {e}")
        raise

    model.eval()
    print("✓ Model set to evaluation mode")

    if DEVICE.type == 'cuda':
        import config
        from packaging import version

        # Enable TF32 and cuDNN autotuning for faster inference.
        torch_version = version.parse(torch.__version__)
        if torch_version >= version.parse("2.9.0"):
            torch.backends.cuda.matmul.fp32_precision = 'tf32'
            torch.backends.cudnn.conv.fp32_precision = 'tf32'
            tf32_matmul = torch.backends.cuda.matmul.fp32_precision
            tf32_conv = torch.backends.cudnn.conv.fp32_precision
        else:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            tf32_matmul = 'tf32' if torch.backends.cuda.matmul.allow_tf32 else 'ieee'
            tf32_conv = 'tf32' if torch.backends.cudnn.allow_tf32 else 'ieee'
        torch.backends.cudnn.benchmark = True
        print(f"TF32 enabled: matmul={tf32_matmul}, conv={tf32_conv}")
        print("cuDNN benchmarking enabled")

        # Optionally compile the model; the mode is chosen via TORCH_COMPILE_MODE.
        if config.ENABLE_TORCH_COMPILE:
            try:
                compile_mode = os.environ.get('TORCH_COMPILE_MODE', 'reduce-overhead')
                if compile_mode == 'max-autotune':
                    print(" → Compiling with max-autotune (may take 2-5 min on first run)...")
                    model = torch.compile(model, mode='max-autotune', fullgraph=False)
                    print("✓ Model compiled with torch.compile (mode=max-autotune, fullgraph=False)")
                elif compile_mode == 'default':
                    print(" → Compiling with default mode (may take 1-3 min on first run)...")
                    model = torch.compile(model, fullgraph=False)
                    print("✓ Model compiled with torch.compile (mode=default, fullgraph=False)")
                else:
                    print(" → Compiling with reduce-overhead (faster first run, ~30-60s)...")
                    model = torch.compile(model, mode='reduce-overhead', fullgraph=False)
                    print("✓ Model compiled with torch.compile (mode=reduce-overhead, fullgraph=False)")
            except Exception as e:
                print(f" ⚠ torch.compile failed: {e}. Continuing without compilation.")
        else:
            print(" ℹ torch.compile disabled (set ENABLE_TORCH_COMPILE=true to enable)")

        # Re-measure free VRAM now that the weights are resident and adjust the
        # sliding-window settings if needed.
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        allocated_after_load = torch.cuda.memory_allocated(0) / (1024**3)
        free_after_load = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / (1024**3)
        print(f" → GPU memory after model load: {allocated_after_load:.2f} GB allocated, {free_after_load:.2f} GB free")

        if free_after_load < 1.0:
            print(f" ⚠ CRITICAL: Only {free_after_load:.2f} GB free after model load. Using ultra-minimal settings.")
            size = [192, 192, 32]
            batch_size = 1
            overlap = 0.25
        elif free_after_load < 2.0:
            print(f" ⚠ WARNING: Low free memory ({free_after_load:.2f} GB) after model load. Adjusting to minimal settings.")
            size = [192, 192, 32]
            batch_size = 1
            overlap = 0.25
        elif free_after_load > 40:
            print(f" Excellent free memory ({free_after_load:.2f} GB) after model load. Using optimal settings for maximum speed.")
            size = [256, 256, 80]
            batch_size = 2
            overlap = 0.1
        elif free_after_load > 30:
            print(f" Excellent free memory ({free_after_load:.2f} GB) after model load. Using optimal settings for speed.")
            size = [256, 256, 64]
            batch_size = 2
            overlap = 0.1
        elif free_after_load > 25:
            print(f" ✓ Good free memory ({free_after_load:.2f} GB) after model load. Using optimal settings.")
            size = [256, 256, 64]
            batch_size = 1
            overlap = 0.1
        elif free_after_load > 20:
            print(f" ✓ Good free memory ({free_after_load:.2f} GB) after model load. Using optimal settings.")
            size = [256, 256, 64]
            batch_size = 1
            overlap = 0.1
        elif free_after_load > 15:
            size = [256, 256, 64]
            batch_size = 1
            overlap = 0.1
        elif free_after_load < 5.0 and (size[0] > 224 or batch_size > 1):
            print(f" ⚠ WARNING: Limited free memory ({free_after_load:.2f} GB). Reducing window size and batch size.")
            size = [224, 224, 48]
            batch_size = 1
            overlap = 0.1

        # Aggregate on the GPU when there is headroom, otherwise on the CPU.
        aggregation_device = 'cuda'
        if free_after_load < 2.0:
            aggregation_device = 'cpu'
            print(f" → Very low VRAM ({free_after_load:.2f} GB), using CPU aggregation to prevent OOM")
        else:
            print(f" → Using GPU aggregation for maximum speed (VRAM: {free_after_load:.2f} GB free)")
    else:
        # CPU-only path: both window compute and aggregation stay on the CPU.
        aggregation_device = 'cpu'

    WINDOW_INFER = SlidingWindowInferer(
        roi_size=size,
        sw_batch_size=batch_size,
        overlap=overlap,
        sw_device=DEVICE.type,
        device=aggregation_device
    )
    print(f"✓ Sliding window inferer created ({DEVICE.type.upper()} compute, {aggregation_device.upper()} aggregation)")

    # Warm-up pass to trigger compilation / kernel autotuning before the first real request.
    if DEVICE.type == 'cuda':
        if config.ENABLE_TORCH_COMPILE:
            print(" Running warm-up inference to trigger compilation and kernel autotuning...")
            print(" This may take 30-60s (reduce-overhead) or 2-5min (max-autotune) on first run...")
        else:
            print(" Running warm-up inference to trigger kernel autotuning...")
        try:
            dummy_input = torch.randn(1, 1, size[0], size[1], size[2], device=DEVICE, dtype=torch.float32)
            dummy_input = dummy_input.contiguous(memory_format=torch.channels_last_3d)
            warmup_start = time.time()
            with torch.no_grad():
                from torch.amp import autocast
                with autocast(device_type='cuda'):
                    _ = model(dummy_input)
            torch.cuda.synchronize()
            warmup_time = time.time() - warmup_start
            del dummy_input, _
            torch.cuda.empty_cache()
            if config.ENABLE_TORCH_COMPILE:
                print(f" Warm-up completed in {warmup_time:.1f}s (compilation + kernel autotuning)")
            else:
                print(f" Warm-up completed in {warmup_time:.1f}s (kernels autotuned)")
        except RuntimeError as e:
            if "out of memory" in str(e):
                print(f" Warm-up OOM (non-critical): {e}")
                print(" Will use progressive fallback during inference")
            else:
                print(f" Warm-up failed (non-critical): {e}")
        except Exception as e:
            print(f" Warm-up failed (non-critical): {e}")

    if modality == 'T1':
        MODEL_T1 = model
    else:
        MODEL_T2 = model

    print(f"✓ {modality} model loaded and ready")
    return model