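"""Model loading utilities for SRMA-Mamba liver segmentation.

Builds the network, loads the per-modality checkpoint (T1 or T2) from disk or the
Hugging Face Hub, picks the compute device, and sets up the MONAI sliding-window
inferer, caching everything in module-level globals.
"""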
import os
import time
import torch
from monai.inferers import SlidingWindowInferer
from config import BUILD_SRMAMAMBA_AVAILABLE, build_SRMAMamba, SRMA_MAMBA_DIR
# Module-level cache: one model per modality (T1/T2) plus the shared sliding-window inferer.
MODEL_T1 = None
MODEL_T2 = None
DEVICE = torch.device('cpu')
WINDOW_INFER = None
def clear_gpu_memory():
    """Unload any cached models and the inferer, then release GPU memory."""
    global MODEL_T1, MODEL_T2, WINDOW_INFER
    if torch.cuda.is_available():
        if MODEL_T1 is not None:
            del MODEL_T1
            MODEL_T1 = None
        if MODEL_T2 is not None:
            del MODEL_T2
            MODEL_T2 = None
        if WINDOW_INFER is not None:
            del WINDOW_INFER
            WINDOW_INFER = None
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        print(" ✓ GPU memory cleared (models unloaded)")
def load_model(modality='T1'):
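    """Build the SRMA-Mamba model for the given modality ('T1' or 'T2'), load its
    checkpoint (locally or from the Hugging Face Hub), move it to the best available
    device, and create the MONAI sliding-window inferer sized to the available memory.

    The loaded model is cached in MODEL_T1 / MODEL_T2 and also returned.
    """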
    global MODEL_T1, MODEL_T2, DEVICE, WINDOW_INFER, BUILD_SRMAMAMBA_AVAILABLE
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    if not BUILD_SRMAMAMBA_AVAILABLE or build_SRMAMamba is None:
        error_msg = "Model builder (build_SRMAMamba) is not available. Please check the logs for import errors."
        print(f"✗ {error_msg}")
        raise ImportError(error_msg)
print(f"Loading {modality} model...")
if torch.cuda.is_available():
try:
max_retries = 3
retry_delay = 2
for attempt in range(max_retries):
try:
torch.cuda.empty_cache()
test_tensor = torch.zeros(1).cuda()
del test_tensor
torch.cuda.synchronize()
DEVICE = torch.device('cuda')
print(f"β Using device: {DEVICE}")
break
except RuntimeError as e:
if "CUDA" in str(e) and attempt < max_retries - 1:
print(f"β GPU wake-up attempt {attempt + 1}/{max_retries}: {e}")
print(f"β Waiting {retry_delay}s for GPU to wake up...")
time.sleep(retry_delay)
retry_delay *= 2
else:
raise
except Exception as e:
print(f"β CUDA available but failed to initialize: {e}. Falling back to CPU.")
DEVICE = torch.device('cpu')
else:
DEVICE = torch.device('cpu')
print(f"βΉ CUDA not available. Using device: {DEVICE}")
    if DEVICE.type == 'cuda':
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        allocated = torch.cuda.memory_allocated(0) / (1024**3)
        reserved = torch.cuda.memory_reserved(0) / (1024**3)
        total = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        free_memory_gb = total - allocated
        print(f" ✓ GPU memory: {allocated:.2f} GB allocated, {reserved:.2f} GB reserved, {free_memory_gb:.2f} GB free (total: {total:.2f} GB)")
        if free_memory_gb < 1.0:
            print(f" ⚠ CRITICAL: Very low free memory ({free_memory_gb:.2f} GB). Using ultra-minimal settings.")
            size = [192, 192, 32]
            batch_size = 1
            overlap = 0.25
        elif free_memory_gb < 2.0:
            print(f" ⚠ WARNING: Very low free memory ({free_memory_gb:.2f} GB). Using minimal settings.")
            size = [192, 192, 32]
            batch_size = 1
            overlap = 0.25
        elif free_memory_gb < 5.0:
            size = [224, 224, 48]
            batch_size = 1
            overlap = 0.2
        elif free_memory_gb > 40:
            print(f" Very high VRAM GPU detected ({free_memory_gb:.2f} GB free). Using optimal settings for maximum speed.")
            size = [256, 256, 80]
            batch_size = 2
            overlap = 0.1
        elif free_memory_gb > 30:
            print(f" High VRAM GPU detected ({free_memory_gb:.2f} GB free). Using optimal settings for speed.")
            size = [256, 256, 64]
            batch_size = 2
            overlap = 0.1
        elif free_memory_gb > 25:
            print(f" ✓ Large VRAM GPU detected ({free_memory_gb:.2f} GB free). Using optimal settings.")
            size = [256, 256, 64]
            batch_size = 1
            overlap = 0.1
        elif free_memory_gb > 20:
            size = [256, 256, 64]
            batch_size = 1
            overlap = 0.1
        elif free_memory_gb > 15:
            size = [256, 256, 64]
            batch_size = 1
            overlap = 0.1
        elif free_memory_gb > 10:
            size = [224, 224, 64]
            batch_size = 1
            overlap = 0.1
        elif free_memory_gb > 8:
            size = [224, 224, 48]
            batch_size = 1
            overlap = 0.2
        else:
            size = [192, 192, 48]
            batch_size = 1
            overlap = 0.2
    else:
        size = [224, 224, 64]
        batch_size = 1
        overlap = 0.15
    print(f" ✓ Sliding window config: roi_size={size}, sw_batch_size={batch_size}, overlap={overlap}")
print("Building model architecture...")
if SRMA_MAMBA_DIR:
original_cwd = os.getcwd()
try:
os.chdir(SRMA_MAMBA_DIR)
print(f"Changed working directory to: {SRMA_MAMBA_DIR}")
model = build_SRMAMamba()
print("β Model architecture built")
finally:
os.chdir(original_cwd)
else:
model = build_SRMAMamba()
print("β Model architecture built")
model = model.to(DEVICE)
print(f"β Model moved to {DEVICE}")
    checkpoint_path = f"checkpoint_{modality}.pth"
    possible_paths = [
        checkpoint_path,
        os.path.join(os.path.dirname(__file__), checkpoint_path),
        f"../../Chkpoints/checkpoint_{modality}.pth",
        f"Chkpoints/checkpoint_{modality}.pth",
        f"../Chkpoints/checkpoint_{modality}.pth",
        f"Model/Chkpoints/checkpoint_{modality}.pth",
        os.path.join(os.path.dirname(__file__), f"Chkpoints/checkpoint_{modality}.pth"),
    ]
    found = False
    for path in possible_paths:
        abs_path = os.path.abspath(path)
        if os.path.exists(path) or os.path.exists(abs_path):
            checkpoint_path = path if os.path.exists(path) else abs_path
            found = True
            print(f"✓ Found checkpoint at: {checkpoint_path}")
            break
    if not found:
        try:
            from huggingface_hub import hf_hub_download
            repo_id = os.environ.get("HF_MODEL_REPO", "HarshithReddy01/srmamamba-liver-segmentation")
            print(f"Attempting to download checkpoint from Hugging Face: {repo_id}")
            checkpoint_path = hf_hub_download(
                repo_id=repo_id,
                filename=f"checkpoint_{modality}.pth",
                cache_dir="."
            )
            found = True
            print(f"✓ Downloaded checkpoint to: {checkpoint_path}")
        except Exception as e:
            error_msg = f"Checkpoint not found. Searched: {possible_paths}. Hugging Face download failed: {str(e)}"
            print(f"✗ {error_msg}")
            raise FileNotFoundError(error_msg)
print(f"Loading checkpoint weights from: {checkpoint_path}")
try:
checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
model.load_state_dict(checkpoint['state_dict'])
else:
model.load_state_dict(checkpoint)
print("β Checkpoint loaded successfully")
except Exception as e:
print(f"β Failed to load checkpoint: {e}")
raise
model.eval()
print("β Model set to evaluation mode")
    if DEVICE.type == 'cuda':
        import config
        from packaging import version
        # Newer PyTorch (this code targets 2.9+) exposes fp32_precision knobs;
        # older releases use the allow_tf32 flags instead.
        torch_version = version.parse(torch.__version__)
        if torch_version >= version.parse("2.9.0"):
            torch.backends.cuda.matmul.fp32_precision = 'tf32'
            torch.backends.cudnn.conv.fp32_precision = 'tf32'
            tf32_matmul = torch.backends.cuda.matmul.fp32_precision
            tf32_conv = torch.backends.cudnn.conv.fp32_precision
        else:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            tf32_matmul = 'tf32' if torch.backends.cuda.matmul.allow_tf32 else 'ieee'
            tf32_conv = 'tf32' if torch.backends.cudnn.allow_tf32 else 'ieee'
        torch.backends.cudnn.benchmark = True
        print(f"TF32 enabled: matmul={tf32_matmul}, conv={tf32_conv}")
        print("cuDNN benchmarking enabled")
        if config.ENABLE_TORCH_COMPILE:
            try:
                compile_mode = os.environ.get('TORCH_COMPILE_MODE', 'reduce-overhead')
                if compile_mode == 'max-autotune':
                    print(" ℹ Compiling with max-autotune (may take 2-5 min on first run)...")
                    model = torch.compile(model, mode='max-autotune', fullgraph=False)
                    print("✓ Model compiled with torch.compile (mode=max-autotune, fullgraph=False)")
                elif compile_mode == 'default':
                    print(" ℹ Compiling with default mode (may take 1-3 min on first run)...")
                    model = torch.compile(model, fullgraph=False)
                    print("✓ Model compiled with torch.compile (mode=default, fullgraph=False)")
                else:
                    print(" ℹ Compiling with reduce-overhead (faster first run, ~30-60s)...")
                    model = torch.compile(model, mode='reduce-overhead', fullgraph=False)
                    print("✓ Model compiled with torch.compile (mode=reduce-overhead, fullgraph=False)")
            except Exception as e:
                print(f" ⚠ torch.compile failed: {e}. Continuing without compilation.")
        else:
            print(" ℹ torch.compile disabled (set ENABLE_TORCH_COMPILE=true to enable)")
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        allocated_after_load = torch.cuda.memory_allocated(0) / (1024**3)
        free_after_load = (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / (1024**3)
        print(f" ✓ GPU memory after model load: {allocated_after_load:.2f} GB allocated, {free_after_load:.2f} GB free")
        if free_after_load < 1.0:
            print(f" ⚠ CRITICAL: Only {free_after_load:.2f} GB free after model load. Using ultra-minimal settings.")
            size = [192, 192, 32]
            batch_size = 1
            overlap = 0.25
        elif free_after_load < 2.0:
            print(f" ⚠ WARNING: Low free memory ({free_after_load:.2f} GB) after model load. Adjusting to minimal settings.")
            size = [192, 192, 32]
            batch_size = 1
            overlap = 0.25
        elif free_after_load > 40:
            print(f" Excellent free memory ({free_after_load:.2f} GB) after model load. Using optimal settings for maximum speed.")
            size = [256, 256, 80]
            batch_size = 2
            overlap = 0.1
        elif free_after_load > 30:
            print(f" Excellent free memory ({free_after_load:.2f} GB) after model load. Using optimal settings for speed.")
            size = [256, 256, 64]
            batch_size = 2
            overlap = 0.1
        elif free_after_load > 25:
            print(f" ✓ Good free memory ({free_after_load:.2f} GB) after model load. Using optimal settings.")
            size = [256, 256, 64]
            batch_size = 1
            overlap = 0.1
        elif free_after_load > 20:
            print(f" ✓ Good free memory ({free_after_load:.2f} GB) after model load. Using optimal settings.")
            size = [256, 256, 64]
            batch_size = 1
            overlap = 0.1
        elif free_after_load > 15:
            size = [256, 256, 64]
            batch_size = 1
            overlap = 0.1
        elif free_after_load < 5.0 and (size[0] > 224 or batch_size > 1):
            print(f" ⚠ WARNING: Limited free memory ({free_after_load:.2f} GB). Reducing window size and batch size.")
            size = [224, 224, 48]
            batch_size = 1
            overlap = 0.1
    # Build the shared sliding-window inferer. Patches run on DEVICE; results are
    # aggregated on the GPU when there is enough VRAM, otherwise on the CPU.
    aggregation_device = DEVICE.type
    if DEVICE.type == 'cuda':
        if free_after_load < 2.0:
            aggregation_device = 'cpu'
            print(f" ⚠ Very low VRAM ({free_after_load:.2f} GB), using CPU aggregation to prevent OOM")
        else:
            print(f" ✓ Using GPU aggregation for maximum speed (VRAM: {free_after_load:.2f} GB free)")
    WINDOW_INFER = SlidingWindowInferer(
        roi_size=size,
        sw_batch_size=batch_size,
        overlap=overlap,
        sw_device=DEVICE.type,
        device=aggregation_device
    )
    print(f"✓ Sliding window inferer created ({DEVICE.type.upper()} compute, {aggregation_device.upper()} aggregation)")
    if DEVICE.type == 'cuda':
        if config.ENABLE_TORCH_COMPILE:
            print(" Running warm-up inference to trigger compilation and kernel autotuning...")
            print(" This may take 30-60s (reduce-overhead) or 2-5min (max-autotune) on first run...")
        else:
            print(" Running warm-up inference to trigger kernel autotuning...")
        try:
            dummy_input = torch.randn(1, 1, size[0], size[1], size[2], device=DEVICE, dtype=torch.float32)
            dummy_input = dummy_input.contiguous(memory_format=torch.channels_last_3d)
            warmup_start = time.time()
            with torch.no_grad():
                from torch.amp import autocast
                with autocast(device_type='cuda'):
                    _ = model(dummy_input)
            torch.cuda.synchronize()
            warmup_time = time.time() - warmup_start
            del dummy_input, _
            torch.cuda.empty_cache()
            if config.ENABLE_TORCH_COMPILE:
                print(f" Warm-up completed in {warmup_time:.1f}s (compilation + kernel autotuning)")
            else:
                print(f" Warm-up completed in {warmup_time:.1f}s (kernels autotuned)")
        except RuntimeError as e:
            if "out of memory" in str(e):
                print(f" Warm-up OOM (non-critical): {e}")
                print(" Will use progressive fallback during inference")
            else:
                print(f" Warm-up failed (non-critical): {e}")
        except Exception as e:
            print(f" Warm-up failed (non-critical): {e}")
    if modality == 'T1':
        MODEL_T1 = model
    else:
        MODEL_T2 = model
    print(f"✓ {modality} model loaded and ready")
    return model