"""srmamamba-liver-segmentation / model_loader.py

Builds the SRMA-Mamba network, loads modality-specific checkpoint weights, and
prepares a MONAI sliding-window inferer for liver segmentation inference.
"""
import os
import time
import torch
from monai.inferers import SlidingWindowInferer
from config import BUILD_SRMAMAMBA_AVAILABLE, build_SRMAMamba, SRMA_MAMBA_DIR
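# Module-level cache: one model per MRI modality, the active torch device, and a
# shared MONAI sliding-window inferer (all populated by load_model()).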
MODEL_T1 = None
MODEL_T2 = None
DEVICE = torch.device('cpu')
WINDOW_INFER = None
def clear_gpu_memory():
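    """Unload any cached models and the sliding-window inferer and release cached CUDA memory."""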
global MODEL_T1, MODEL_T2, WINDOW_INFER
if torch.cuda.is_available():
if MODEL_T1 is not None:
del MODEL_T1
MODEL_T1 = None
if MODEL_T2 is not None:
del MODEL_T2
MODEL_T2 = None
if WINDOW_INFER is not None:
del WINDOW_INFER
WINDOW_INFER = None
torch.cuda.empty_cache()
torch.cuda.synchronize()
print(" β†’ GPU memory cleared (models unloaded)")
def load_model(modality='T1'):
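    """Build the SRMA-Mamba network for the given modality ('T1' or 'T2') and return it ready for inference.

    Loads checkpoint weights (local paths first, Hugging Face Hub as a fallback),
    picks sliding-window settings based on free GPU memory, optionally compiles the
    model, and caches the result in MODEL_T1/MODEL_T2 alongside a shared WINDOW_INFER.
    """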
global MODEL_T1, MODEL_T2, DEVICE, WINDOW_INFER, BUILD_SRMAMAMBA_AVAILABLE
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
if not BUILD_SRMAMAMBA_AVAILABLE or build_SRMAMamba is None:
error_msg = "Model builder (build_SRMAMamba) is not available. Please check the logs for import errors."
print(f"βœ— {error_msg}")
raise ImportError(error_msg)
print(f"Loading {modality} model...")
if torch.cuda.is_available():
try:
max_retries = 3
retry_delay = 2
for attempt in range(max_retries):
try:
torch.cuda.empty_cache()
test_tensor = torch.zeros(1).cuda()
del test_tensor
torch.cuda.synchronize()
DEVICE = torch.device('cuda')
print(f"βœ“ Using device: {DEVICE}")
break
except RuntimeError as e:
if "CUDA" in str(e) and attempt < max_retries - 1:
print(f"⚠ GPU wake-up attempt {attempt + 1}/{max_retries}: {e}")
print(f"⚠ Waiting {retry_delay}s for GPU to wake up...")
time.sleep(retry_delay)
retry_delay *= 2
else:
raise
except Exception as e:
print(f"⚠ CUDA available but failed to initialize: {e}. Falling back to CPU.")
DEVICE = torch.device('cpu')
else:
DEVICE = torch.device('cpu')
print(f"β„Ή CUDA not available. Using device: {DEVICE}")
if DEVICE.type == 'cuda':
torch.cuda.empty_cache()
torch.cuda.synchronize()
allocated = torch.cuda.memory_allocated(0) / (1024**3)
reserved = torch.cuda.memory_reserved(0) / (1024**3)
total = torch.cuda.get_device_properties(0).total_memory / (1024**3)
free_memory_gb = total - allocated
print(f" β†’ GPU memory: {allocated:.2f} GB allocated, {reserved:.2f} GB reserved, {free_memory_gb:.2f} GB free (total: {total:.2f} GB)")
if free_memory_gb < 1.0:
print(f" ⚠ CRITICAL: Very low free memory ({free_memory_gb:.2f} GB). Using ultra-minimal settings.")
size = [192, 192, 32]
batch_size = 1
overlap = 0.25
elif free_memory_gb < 2.0:
print(f" ⚠ WARNING: Very low free memory ({free_memory_gb:.2f} GB). Using minimal settings.")
size = [192, 192, 32]
batch_size = 1
overlap = 0.25
elif free_memory_gb < 5.0:
size = [224, 224, 48]
batch_size = 1
overlap = 0.2
elif free_memory_gb > 40:
print(f" Very high VRAM GPU detected ({free_memory_gb:.2f} GB free). Using optimal settings for maximum speed.")
size = [256, 256, 80]
batch_size = 2
overlap = 0.1
elif free_memory_gb > 30:
print(f" High VRAM GPU detected ({free_memory_gb:.2f} GB free). Using optimal settings for speed.")
size = [256, 256, 64]
batch_size = 2
overlap = 0.1
elif free_memory_gb > 25:
print(f" βœ“ Large VRAM GPU detected ({free_memory_gb:.2f} GB free). Using optimal settings.")
size = [256, 256, 64]
batch_size = 1
overlap = 0.1
elif free_memory_gb > 20:
size = [256, 256, 64]
batch_size = 1
overlap = 0.1
elif free_memory_gb > 15:
size = [256, 256, 64]
batch_size = 1
overlap = 0.1
elif free_memory_gb > 10:
size = [224, 224, 64]
batch_size = 1
overlap = 0.1
elif free_memory_gb > 8:
size = [224, 224, 48]
batch_size = 1
overlap = 0.2
else:
size = [192, 192, 48]
batch_size = 1
overlap = 0.2
else:
size = [224, 224, 64]
batch_size = 1
overlap = 0.15
print(f" β†’ Sliding window config: roi_size={size}, sw_batch_size={batch_size}, overlap={overlap}")
print("Building model architecture...")
if SRMA_MAMBA_DIR:
original_cwd = os.getcwd()
try:
os.chdir(SRMA_MAMBA_DIR)
print(f"Changed working directory to: {SRMA_MAMBA_DIR}")
model = build_SRMAMamba()
print("βœ“ Model architecture built")
finally:
os.chdir(original_cwd)
else:
model = build_SRMAMamba()
print("βœ“ Model architecture built")
model = model.to(DEVICE)
print(f"βœ“ Model moved to {DEVICE}")
checkpoint_path = f"checkpoint_{modality}.pth"
possible_paths = [
checkpoint_path,
os.path.join(os.path.dirname(__file__), checkpoint_path),
f"../../Chkpoints/checkpoint_{modality}.pth",
f"Chkpoints/checkpoint_{modality}.pth",
f"../Chkpoints/checkpoint_{modality}.pth",
f"Model/Chkpoints/checkpoint_{modality}.pth",
os.path.join(os.path.dirname(__file__), f"Chkpoints/checkpoint_{modality}.pth"),
]
found = False
for path in possible_paths:
abs_path = os.path.abspath(path)
if os.path.exists(path) or os.path.exists(abs_path):
checkpoint_path = path if os.path.exists(path) else abs_path
found = True
print(f"βœ“ Found checkpoint at: {checkpoint_path}")
break
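    # No local copy: fall back to downloading the checkpoint from the Hugging Face
    # Hub (repo overridable via HF_MODEL_REPO).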
if not found:
try:
from huggingface_hub import hf_hub_download
repo_id = os.environ.get("HF_MODEL_REPO", "HarshithReddy01/srmamamba-liver-segmentation")
print(f"Attempting to download checkpoint from Hugging Face: {repo_id}")
checkpoint_path = hf_hub_download(
repo_id=repo_id,
filename=f"checkpoint_{modality}.pth",
cache_dir="."
)
found = True
print(f"βœ“ Downloaded checkpoint to: {checkpoint_path}")
except Exception as e:
error_msg = f"Checkpoint not found. Searched: {possible_paths}. Hugging Face download failed: {str(e)}"
print(f"βœ— {error_msg}")
raise FileNotFoundError(error_msg)
print(f"Loading checkpoint weights from: {checkpoint_path}")
try:
checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
model.load_state_dict(checkpoint['state_dict'])
else:
model.load_state_dict(checkpoint)
print("βœ“ Checkpoint loaded successfully")
except Exception as e:
print(f"βœ— Failed to load checkpoint: {e}")
raise
model.eval()
print("βœ“ Model set to evaluation mode")
if DEVICE.type == 'cuda':
import config
from packaging import version
torch_version = version.parse(torch.__version__)
if torch_version >= version.parse("2.9.0"):
torch.backends.cuda.matmul.fp32_precision = 'tf32'
torch.backends.cudnn.conv.fp32_precision = 'tf32'
tf32_matmul = torch.backends.cuda.matmul.fp32_precision
tf32_conv = torch.backends.cudnn.conv.fp32_precision
else:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
tf32_matmul = 'tf32' if torch.backends.cuda.matmul.allow_tf32 else 'ieee'
tf32_conv = 'tf32' if torch.backends.cudnn.allow_tf32 else 'ieee'
torch.backends.cudnn.benchmark = True
print(f"TF32 enabled: matmul={tf32_matmul}, conv={tf32_conv}")
print("cuDNN benchmarking enabled")
if config.ENABLE_TORCH_COMPILE:
try:
compile_mode = os.environ.get('TORCH_COMPILE_MODE', 'reduce-overhead')
if compile_mode == 'max-autotune':
print(f" β†’ Compiling with max-autotune (may take 2-5 min on first run)...")
model = torch.compile(model, mode='max-autotune', fullgraph=False)
print(f"βœ“ Model compiled with torch.compile (mode=max-autotune, fullgraph=False)")
elif compile_mode == 'default':
print(f" β†’ Compiling with default mode (may take 1-3 min on first run)...")
model = torch.compile(model, fullgraph=False)
print(f"βœ“ Model compiled with torch.compile (mode=default, fullgraph=False)")
else:
print(f" β†’ Compiling with reduce-overhead (faster first run, ~30-60s)...")
model = torch.compile(model, mode='reduce-overhead', fullgraph=False)
print(f"βœ“ Model compiled with torch.compile (mode=reduce-overhead, fullgraph=False)")
except Exception as e:
print(f" ⚠ torch.compile failed: {e}. Continuing without compilation.")
else:
print(" β„Ή torch.compile disabled (set ENABLE_TORCH_COMPILE=true to enable)")
torch.cuda.empty_cache()
torch.cuda.synchronize()
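        # Re-check free memory now that the weights are resident and adjust the
        # sliding-window settings accordingly.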
allocated_after_load = torch.cuda.memory_allocated(0) / (1024**3)
free_after_load = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / (1024**3)
print(f" β†’ GPU memory after model load: {allocated_after_load:.2f} GB allocated, {free_after_load:.2f} GB free")
if free_after_load < 1.0:
print(f" ⚠ CRITICAL: Only {free_after_load:.2f} GB free after model load. Using ultra-minimal settings.")
size = [192, 192, 32]
batch_size = 1
overlap = 0.25
elif free_after_load < 2.0:
print(f" ⚠ WARNING: Low free memory ({free_after_load:.2f} GB) after model load. Adjusting to minimal settings.")
size = [192, 192, 32]
batch_size = 1
overlap = 0.25
elif free_after_load > 40:
print(f" Excellent free memory ({free_after_load:.2f} GB) after model load. Using optimal settings for maximum speed.")
size = [256, 256, 80]
batch_size = 2
overlap = 0.1
elif free_after_load > 30:
print(f" Excellent free memory ({free_after_load:.2f} GB) after model load. Using optimal settings for speed.")
size = [256, 256, 64]
batch_size = 2
overlap = 0.1
elif free_after_load > 25:
print(f" βœ“ Good free memory ({free_after_load:.2f} GB) after model load. Using optimal settings.")
size = [256, 256, 64]
batch_size = 1
overlap = 0.1
elif free_after_load > 20:
print(f" βœ“ Good free memory ({free_after_load:.2f} GB) after model load. Using optimal settings.")
size = [256, 256, 64]
batch_size = 1
overlap = 0.1
elif free_after_load > 15:
size = [256, 256, 64]
batch_size = 1
overlap = 0.1
elif free_after_load < 5.0 and (size[0] > 224 or batch_size > 1):
print(f" ⚠ WARNING: Limited free memory ({free_after_load:.2f} GB). Reducing window size and batch size.")
size = [224, 224, 48]
batch_size = 1
overlap = 0.1
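        # Aggregate windows on the GPU when there is headroom; below ~2 GB free,
        # aggregate on the CPU to avoid running out of memory.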
aggregation_device = 'cuda'
if free_after_load < 2.0:
aggregation_device = 'cpu'
print(f" β†’ Very low VRAM ({free_after_load:.2f} GB), using CPU aggregation to prevent OOM")
else:
print(f" β†’ Using GPU aggregation for maximum speed (VRAM: {free_after_load:.2f} GB free)")
WINDOW_INFER = SlidingWindowInferer(
roi_size=size,
sw_batch_size=batch_size,
overlap=overlap,
sw_device='cuda',
device=aggregation_device
)
print(f"βœ“ Sliding window inferer created (GPU compute, {aggregation_device.upper()} aggregation)")
if DEVICE.type == 'cuda':
if config.ENABLE_TORCH_COMPILE:
print(" Running warm-up inference to trigger compilation and kernel autotuning...")
print(" This may take 30-60s (reduce-overhead) or 2-5min (max-autotune) on first run...")
else:
print(" Running warm-up inference to trigger kernel autotuning...")
try:
dummy_input = torch.randn(1, 1, size[0], size[1], size[2], device=DEVICE, dtype=torch.float32)
dummy_input = dummy_input.contiguous(memory_format=torch.channels_last_3d)
warmup_start = time.time()
with torch.no_grad():
from torch.amp import autocast
with autocast(device_type='cuda'):
_ = model(dummy_input)
torch.cuda.synchronize()
warmup_time = time.time() - warmup_start
del dummy_input, _
torch.cuda.empty_cache()
if config.ENABLE_TORCH_COMPILE:
print(f" Warm-up completed in {warmup_time:.1f}s (compilation + kernel autotuning)")
else:
print(f" Warm-up completed in {warmup_time:.1f}s (kernels autotuned)")
except RuntimeError as e:
if "out of memory" in str(e):
print(f" Warm-up OOM (non-critical): {e}")
print(f" Will use progressive fallback during inference")
else:
print(f" Warm-up failed (non-critical): {e}")
except Exception as e:
print(f" Warm-up failed (non-critical): {e}")
if modality == 'T1':
MODEL_T1 = model
else:
MODEL_T2 = model
print(f"βœ“ {modality} model loaded and ready")
return model
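
# Example usage (illustrative sketch only, not called anywhere in this module).
# Assumes load_model() ran on a CUDA machine so WINDOW_INFER was created, and that
# `volume` is a hypothetical preprocessed (1, 1, H, W, D) float tensor produced
# elsewhere in the pipeline:
#
#     import torch
#     import model_loader
#
#     model = model_loader.load_model('T1')
#     with torch.no_grad():
#         logits = model_loader.WINDOW_INFER(volume, model)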