import os
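# Allocator and kernel-cache setup must happen before torch/triton are imported.
# expandable_segments plus a 128 MB max split size reduce CUDA memory fragmentation on
# long-running inference, and pointing TRITON_CACHE_DIR at /tmp keeps the Triton cache
# on a path that is writable inside the Space container.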
if 'PYTORCH_ALLOC_CONF' not in os.environ:
os.environ['PYTORCH_ALLOC_CONF'] = 'expandable_segments:True,max_split_size_mb=128'
if 'TRITON_CACHE_DIR' not in os.environ:
triton_cache = '/tmp/.triton'
os.environ['TRITON_CACHE_DIR'] = triton_cache
try:
os.makedirs(triton_cache, exist_ok=True)
    except OSError:
        pass
import tempfile
import base64
import torch
import nibabel as nib
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
from typing import Optional
# Import gradio first (needed for patching)
import gradio as gr
# Fix for Gradio 4.44.x schema bug: convert boolean additionalProperties to dict
def fix_gradio_schema_bug():
"""
Monkeypatch to fix Gradio 4.44.x crash when additionalProperties is boolean instead of dict.
The error occurs in gradio_client/utils.py -> get_type() when it tries:
if "const" in schema: # crashes if schema is bool (True/False)
This happens when Gradio infers schemas with additionalProperties: true
and then tries to process them as dicts.
"""
try:
# Patch gradio_client.utils.get_type - this is where the crash happens
try:
import gradio_client.utils as gradio_utils
if hasattr(gradio_utils, 'get_type'):
original_get_type = gradio_utils.get_type
def patched_get_type(schema):
"""Fix boolean additionalProperties and handle bool schemas."""
# Handle case where schema itself is a boolean (the actual bug)
if isinstance(schema, bool):
return "Any"
# Normalize additionalProperties: true -> {}
if isinstance(schema, dict):
# Recursively fix nested schemas in properties
if "properties" in schema and isinstance(schema["properties"], dict):
for prop_name, prop_schema in schema["properties"].items():
if isinstance(prop_schema, dict):
patched_get_type(prop_schema)
# Fix additionalProperties: true -> additionalProperties: {}
if "additionalProperties" in schema:
if schema["additionalProperties"] is True:
schema["additionalProperties"] = {}
elif schema["additionalProperties"] is False:
# False means no additional properties allowed
schema.pop("additionalProperties", None)
elif isinstance(schema["additionalProperties"], dict):
# Recursively fix nested additionalProperties
patched_get_type(schema["additionalProperties"])
return original_get_type(schema)
gradio_utils.get_type = patched_get_type
print("β Applied Gradio schema bug fix (gradio_client.utils.get_type)")
except (ImportError, AttributeError) as e:
print(f"β Could not patch gradio_client.utils: {e}")
# Also patch gradio's API info generation to normalize schemas
try:
if hasattr(gr, 'Blocks'):
# Patch the _get_api_info method if it exists
if hasattr(gr.Blocks, '_get_api_info'):
original_get_api_info = gr.Blocks._get_api_info
def patched_get_api_info(self):
"""Normalize schemas before API info generation."""
api_info = original_get_api_info(self)
if api_info and isinstance(api_info, dict):
def normalize_schema(schema):
if isinstance(schema, dict):
if "additionalProperties" in schema and schema["additionalProperties"] is True:
schema["additionalProperties"] = {}
if "properties" in schema:
for prop in schema["properties"].values():
normalize_schema(prop)
# Normalize all schemas in api_info
if "named_endpoints" in api_info:
for endpoint_info in api_info["named_endpoints"].values():
if "parameters" in endpoint_info:
for param in endpoint_info["parameters"]:
if "component" in param and "serializer" in param["component"]:
if "schema" in param["component"]["serializer"]:
normalize_schema(param["component"]["serializer"]["schema"])
return api_info
gr.Blocks._get_api_info = patched_get_api_info
print("β Applied Gradio schema bug fix (Blocks._get_api_info)")
except Exception as e:
print(f"β Could not patch Blocks._get_api_info: {e}")
print("β Gradio schema bug fix applied successfully")
except Exception as e:
print(f"β Could not apply Gradio schema fix: {e}")
fix_gradio_schema_bug()
import config
import model_loader
from inference import predict_volume_api, safe_predict_volume, PROCESSING_LOCK
def log_startup_health():
print("=" * 60)
print("STARTUP HEALTH CHECK")
print("=" * 60)
import torch
import os
    try:
        is_docker = os.path.exists('/.dockerenv')
        if not is_docker and os.path.exists('/proc/self/cgroup'):
            with open('/proc/self/cgroup') as cgroup_file:
                is_docker = 'docker' in cgroup_file.read()
    except Exception:
        is_docker = False
is_python_space = '/home/user/.pyenv' in os.environ.get('PATH', '') or os.path.exists('/home/user/.pyenv')
if is_python_space and not is_docker:
print("WARNING: Running on Python Space (managed environment)")
print("Python Spaces do NOT have CUDA toolchain (nvcc) - CUDA extensions CANNOT build")
print("ACTION REQUIRED: Switch to Docker Space in HF Settings -> Runtime -> Docker")
print("Without Docker, CUDA extensions will remain NOT INSTALLED (slow fallback)")
print("=" * 60)
elif is_docker:
print("Running in Docker container - CUDA extensions should be available")
if torch.cuda.is_available():
import config
from packaging import version
torch_version = version.parse(torch.__version__)
if torch_version >= version.parse("2.9.0"):
torch.backends.cuda.matmul.fp32_precision = 'tf32'
torch.backends.cudnn.conv.fp32_precision = 'tf32'
else:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
print(f"PyTorch: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"CUDA Version: {torch.version.cuda}")
print(f"cuDNN Version: {torch.backends.cudnn.version()}")
gpu_name = torch.cuda.get_device_name(0)
gpu_props = torch.cuda.get_device_properties(0)
total_memory_gb = gpu_props.total_memory / (1024**3)
print(f"GPU: {gpu_name}")
print(f"GPU Memory: {total_memory_gb:.1f} GB total")
allocated = torch.cuda.memory_allocated(0) / (1024**3)
reserved = torch.cuda.memory_reserved(0) / (1024**3)
free = total_memory_gb - allocated
print(f"GPU Memory Status: {allocated:.2f} GB allocated, {reserved:.2f} GB reserved, {free:.2f} GB free")
from packaging import version
torch_version = version.parse(torch.__version__)
if torch_version >= version.parse("2.9.0"):
tf32_matmul = getattr(torch.backends.cuda.matmul, 'fp32_precision', 'unknown')
tf32_conv = getattr(torch.backends.cudnn.conv, 'fp32_precision', 'unknown')
else:
tf32_matmul = 'tf32' if torch.backends.cuda.matmul.allow_tf32 else 'ieee'
tf32_conv = 'tf32' if torch.backends.cudnn.allow_tf32 else 'ieee'
cudnn_benchmark = torch.backends.cudnn.benchmark
print(f"TF32 Matmul: {tf32_matmul}")
print(f"TF32 Conv: {tf32_conv}")
print(f"cuDNN Benchmark: {cudnn_benchmark}")
compile_enabled = config.ENABLE_TORCH_COMPILE
compile_mode = os.environ.get('TORCH_COMPILE_MODE', 'reduce-overhead') if compile_enabled else 'disabled'
print(f"torch.compile: {compile_enabled} (mode: {compile_mode})")
    try:
        import monai
        print(f"MONAI: {monai.__version__}")
    except ImportError:
        print("MONAI: NOT FOUND")
    try:
        import gradio
        print(f"Gradio: {gradio.__version__}")
    except ImportError:
        print("Gradio: NOT FOUND")
    try:
        import nibabel
        print(f"NiBabel: {nibabel.__version__}")
    except ImportError:
        print("NiBabel: NOT FOUND")
print("\nCUDA Extensions Status:")
    try:
        import mamba_ssm
        # Avoid shadowing `version` (from packaging), which is used earlier in this function
        mamba_version = getattr(mamba_ssm, '__version__', None)
        if mamba_version:
            print(f" mamba_ssm: INSTALLED (version: {mamba_version})")
        else:
            print(" mamba_ssm: INSTALLED")
    except ImportError:
        print(" mamba_ssm: NOT INSTALLED (will use fallback - CRITICAL for speed)")
try:
import selective_scan_cuda_oflex
print(" selective_scan_cuda_oflex: INSTALLED")
except ImportError:
print(" selective_scan_cuda_oflex: NOT INSTALLED (will use fallback - CRITICAL for speed)")
print("\nEnvironment Variables:")
alloc_conf = os.environ.get('PYTORCH_ALLOC_CONF', 'not set')
print(f" PYTORCH_ALLOC_CONF: {alloc_conf}")
print(f" ENABLE_CUDNN_BENCHMARK: {config.ENABLE_CUDNN_BENCHMARK}")
print(f" ENABLE_TORCH_COMPILE: {config.ENABLE_TORCH_COMPILE}")
print(f" INFERENCE_TIMEOUT: {config.INFERENCE_TIMEOUT}s")
print(f" MAX_GRADIO_CONCURRENCY: {config.MAX_GRADIO_CONCURRENCY}")
print("=" * 60)
log_startup_health()
api_app = FastAPI(title="SRMA-Mamba API", version="1.0.0")
@api_app.middleware("http")
async def log_requests(request, call_next):
path = str(request.url.path)
print(f"[REQUEST] {request.method} {path}")
content_type = request.headers.get('content-type', 'N/A')
print(f"[REQUEST] Content-Type: {content_type}")
if '/segment' in path:
is_multipart = 'multipart' in content_type.lower()
print(f"[REQUEST] Is multipart: {is_multipart}")
if not is_multipart:
print(f"[REQUEST] WARNING: Expected multipart/form-data but got: {content_type}")
print(f"[REQUEST] Headers: {dict(request.headers)}")
response = await call_next(request)
print(f"[RESPONSE] {request.method} {path} -> {response.status_code}")
return response
@api_app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
print(f"[ERROR] RequestValidationError on {request.method} {request.url.path}")
content_type = request.headers.get('content-type', 'N/A')
print(f"[ERROR] Content-Type: {content_type}")
print(f"[ERROR] First error: {exc.errors()[0] if exc.errors() else 'No errors'}")
error_detail = []
for error in exc.errors():
error_detail.append({
"type": error.get("type"),
"loc": error.get("loc"),
"msg": error.get("msg")
})
return JSONResponse(
status_code=422,
content={
"detail": error_detail,
"message": "Request validation failed. Ensure Content-Type is multipart/form-data for file uploads.",
"content_type_received": content_type
}
)
@api_app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
return JSONResponse(
status_code=exc.status_code,
content={"detail": exc.detail}
)
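# Explicit origin allow-list: browsers reject a wildcard Access-Control-Allow-Origin when
# credentials are sent, so every frontend that calls this API must be listed here.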
allowed_origins = [
"https://harshithreddy01.github.io",
"https://harshithreddy01.github.io/frontend-SRMA-Liver",
"https://app.paninsight.org",
"http://localhost:5173",
"http://localhost:3000",
"http://localhost:8080",
]
api_app.add_middleware(
CORSMiddleware,
allow_origins=allowed_origins,
allow_credentials=True,
allow_methods=["GET", "POST", "OPTIONS", "HEAD"],
allow_headers=["*"],
expose_headers=["*"],
)
print(f"β CORS configured for origins: {allowed_origins}")
@api_app.get("/")
@api_app.get("/api")
async def root():
return {
"message": "SRMA-Mamba Liver Segmentation API",
"version": "1.0.0",
"endpoints": {
"/api/segment": "POST - Upload NIfTI file for segmentation",
"/api/health": "GET - Health check",
"/api/download/{token}": "GET - Download segmentation mask",
"/docs": "API documentation"
}
}
@api_app.get("/health")
@api_app.get("/api/health")
async def health_check():
gpu_name = "unknown"
gpu_memory_gb = 0.0
if torch.cuda.is_available() and model_loader.DEVICE and model_loader.DEVICE.type == 'cuda':
try:
gpu_name = torch.cuda.get_device_name(0)
gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        except Exception:
            pass
return {
"status": "healthy",
"device": str(model_loader.DEVICE) if model_loader.DEVICE else "not initialized",
"model_t1_loaded": model_loader.MODEL_T1 is not None,
"model_t2_loaded": model_loader.MODEL_T2 is not None,
"gpu_name": gpu_name,
"gpu_memory_gb": round(gpu_memory_gb, 1)
}
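# Example /api/health response (all values illustrative):
#   {"status": "healthy", "device": "cuda:0", "model_t1_loaded": true,
#    "model_t2_loaded": true, "gpu_name": "NVIDIA A10G", "gpu_memory_gb": 22.5}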
@api_app.get("/api/download/{token}")
async def download_mask(token: str):
    token_dir = "/tmp/seg_tokens"
    # Tokens are URL-safe base64; reject anything else so a crafted token cannot escape token_dir
    if not all(c.isalnum() or c in '-_' for c in token):
        raise HTTPException(status_code=400, detail="Invalid token")
    token_path = os.path.join(token_dir, f"{token}.nii.gz")
if not os.path.exists(token_path):
raise HTTPException(status_code=404, detail="Token not found or expired")
from fastapi.responses import FileResponse
return FileResponse(
token_path,
media_type="application/gzip",
filename=f"liver_segmentation_{token[:8]}.nii.gz",
headers={"X-Token": token}
)
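# Download tokens are minted by /api/segment below when a mask exceeds the ~2 GB response
# limit; each token maps to a file stored under /tmp/seg_tokens.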
@api_app.post("/segment", response_class=JSONResponse, include_in_schema=True)
@api_app.post("/api/segment", response_class=JSONResponse, include_in_schema=True)
async def segment_liver(
request: Request,
file: UploadFile = File(..., description="NIfTI file to segment"),
modality: str = Form("T1", description="MRI modality: T1 or T2"),
slice_idx: Optional[int] = Form(None, description="Optional slice index")
):
print("=" * 60)
print(f"API REQUEST RECEIVED: /api/segment")
print(f" Origin: {request.headers.get('origin', 'N/A')}")
print(f" Content-Type: {request.headers.get('content-type', 'N/A')}")
print(f" File: {file.filename if file else 'None'}")
print(f" File Content-Type: {file.content_type if file else 'None'}")
print(f" Modality: {modality}")
print(f" Slice Index: {slice_idx}")
print("=" * 60)
if not file or not file.filename:
raise HTTPException(status_code=400, detail="File is required")
if not file.filename.endswith(('.nii', '.nii.gz', '.gz')):
raise HTTPException(status_code=400, detail="File must be a NIfTI file (.nii or .nii.gz)")
if modality not in ['T1', 'T2']:
raise HTTPException(status_code=400, detail="Modality must be 'T1' or 'T2'")
content = await file.read()
file_size_mb = len(content) / (1024**2)
print(f"API: File size: {file_size_mb:.2f} MB")
if file_size_mb > 2000:
raise HTTPException(
status_code=413,
detail=f"File too large: {file_size_mb:.1f} MB. Maximum upload size is 2 GB. Please compress or resample your NIfTI file."
)
with tempfile.NamedTemporaryFile(delete=False, suffix='.nii.gz') as tmp_file:
tmp_file.write(content)
tmp_path = tmp_file.name
import secrets
import shutil
try:
print(f"API: Starting inference for {modality} modality...")
result = predict_volume_api(tmp_path, modality, slice_idx)
print(f"API: Inference completed. Success: {result.get('success', False)}")
if not result["success"]:
raise HTTPException(status_code=500, detail=result.get("error", "Unknown error"))
seg_path = result["segmentation_path"]
seg_file_size = os.path.getsize(seg_path) / (1024**2)
if seg_file_size > 2000:
token = secrets.token_urlsafe(16)
token_dir = "/tmp/seg_tokens"
os.makedirs(token_dir, exist_ok=True)
token_path = os.path.join(token_dir, f"{token}.nii.gz")
shutil.copy2(seg_path, token_path)
result["mask_path_token"] = token
result["mask_download_url"] = f"/api/download/{token}"
result["segmentation_file"] = None
print(f"API: Large mask file ({seg_file_size:.1f} MB). Using token-based download: {token}")
else:
with open(seg_path, "rb") as seg_file:
seg_data = seg_file.read()
seg_base64 = base64.b64encode(seg_data).decode('utf-8')
result["segmentation_file"] = f"data:application/octet-stream;base64,{seg_base64}"
result["mask_path_token"] = None
result["mask_download_url"] = None
os.unlink(tmp_path)
if seg_file_size <= 2000:
os.unlink(seg_path)
return JSONResponse(content=result)
    except Exception as e:
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
        # Propagate HTTPExceptions (e.g. the 500 raised above) unchanged instead of re-wrapping them
        if isinstance(e, HTTPException):
            raise
        raise HTTPException(status_code=500, detail=str(e))
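# Illustrative client for the endpoint above (not executed by this module). The URL is a
# placeholder; the field names match the File/Form parameters of segment_liver:
#
#   import requests
#   BASE_URL = "https://<your-space>.hf.space"  # placeholder
#   with open("scan.nii.gz", "rb") as f:
#       resp = requests.post(
#           f"{BASE_URL}/api/segment",
#           files={"file": ("scan.nii.gz", f, "application/gzip")},
#           data={"modality": "T1"},
#           timeout=600,
#       )
#   resp.raise_for_status()
#   print(resp.json().get("success"))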
def create_interface():
with gr.Blocks(title="SRMA-Mamba: Liver Segmentation", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# SRMA-Mamba: Liver Segmentation in MRI Volumes
Upload a 3D NIfTI MRI volume (.nii.gz) to perform automatic liver segmentation.
**Model Performance:**
- Pixel Accuracy: 99.09%
- IoU: 75%
- PSNR: 29.64 dB
**Supported Modalities:** T1-weighted and T2-weighted MRI
**API Available:** This Space also provides a REST API. See `/docs` for documentation.
""")
with gr.Row():
with gr.Column():
nifti_input = gr.File(
file_count="single",
file_types=[".nii.gz", ".nii"],
label="Upload NIfTI File (max 2 GB)"
)
modality = gr.Radio(
choices=["T1", "T2"],
value="T1",
label="MRI Modality"
)
slice_slider = gr.Slider(
minimum=0,
maximum=100,
value=50,
step=1,
label="Slice Index (will auto-update based on volume)",
interactive=True
)
predict_btn = gr.Button("Segment Liver", variant="primary")
reset_btn = gr.Button("Reset", variant="secondary")
with gr.Column():
output_image = gr.Image(
label="Segmentation Overlay",
type="pil"
)
output_info = gr.Markdown(
label="Statistics"
)
output_report = gr.Markdown(
label="Medical Report"
)
output_file = gr.File(
label="Download 3D Segmentation (.nii.gz)"
)
gr.Markdown("""
## Instructions
1. Upload a 3D NIfTI MRI volume (.nii.gz format)
2. Select the MRI modality (T1 or T2)
3. Click "Segment Liver" to run inference
4. View the segmentation overlay and download the 3D mask
**Note:** First inference may take longer as the model loads.
## API Usage
This Space provides a REST API for programmatic access:
- **Endpoint**: `POST /api/segment`
- **Documentation**: Visit `/docs` for interactive API docs
- **Example**: Use from your frontend with `fetch()` or `axios()`
""")
use_gpu = False
predict_fn = safe_predict_volume
if config.HAS_SPACES:
import spaces
if torch.cuda.is_available():
try:
test_tensor = torch.zeros(1).cuda()
del test_tensor
torch.cuda.empty_cache()
use_gpu = True
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "Unknown"
total_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3) if torch.cuda.is_available() else 0
predict_fn = spaces.GPU(safe_predict_volume)
print(f"β Using GPU acceleration: {gpu_name} ({total_vram:.1f} GB VRAM)")
except ImportError:
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "Unknown"
total_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3) if torch.cuda.is_available() else 0
predict_fn = safe_predict_volume
print(f"β Using GPU acceleration: {gpu_name} ({total_vram:.1f} GB VRAM)")
except Exception as e:
print(f"β GPU available but failed to initialize: {e}")
print("Falling back to CPU mode")
use_gpu = False
predict_fn = safe_predict_volume
else:
print("βΉ No GPU available. Running on CPU (slower but free)")
predict_fn = safe_predict_volume
else:
if torch.cuda.is_available():
gpu_name = torch.cuda.get_device_name(0)
total_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3)
print(f"β CUDA available, using GPU: {gpu_name} ({total_vram:.1f} GB VRAM)")
use_gpu = True
else:
print("βΉ Running on CPU")
predict_fn = safe_predict_volume
predict_btn.click(
fn=predict_fn,
inputs=[nifti_input, modality, slice_slider],
outputs=[output_image, output_info, output_report, output_file]
)
def update_slice_slider(file):
if file is None:
return gr.update(maximum=100, value=50)
try:
file_path = file.name if hasattr(file, 'name') else str(file)
if not os.path.exists(file_path):
return gr.update(maximum=100, value=50)
volume = nib.load(file_path).get_fdata()
max_slices = volume.shape[-1] if len(volume.shape) == 3 else volume.shape[2]
return gr.update(maximum=max_slices - 1, value=max_slices // 2)
except Exception as e:
print(f"Error updating slice slider: {e}")
return gr.update(maximum=100, value=50)
def reset_interface():
import inference
inference.PROCESSING_LOCK = False
if model_loader.DEVICE and model_loader.DEVICE.type == 'cuda':
torch.cuda.empty_cache()
return None, None, "", None
nifti_input.change(
fn=update_slice_slider,
inputs=[nifti_input],
outputs=[slice_slider],
show_progress=False
)
reset_btn.click(
fn=reset_interface,
inputs=[],
outputs=[output_image, output_info, output_report, output_file]
)
return demo
demo = create_interface()
import config
demo.queue(max_size=config.MAX_GRADIO_CONCURRENCY)
app = api_app
if config.HAS_MOUNT_GRADIO_APP:
from gradio.routes import mount_gradio_app
print("Mounting Gradio interface on FastAPI app...")
print("FastAPI routes registered before mounting:")
for route in api_app.routes:
if hasattr(route, 'path') and hasattr(route, 'methods'):
print(f" - {list(route.methods)} {route.path}")
print("\nMounting Gradio at path '/gradio' to avoid interfering with API routes")
mount_gradio_app(app=api_app, blocks=demo, path="/gradio")
print("\nFastAPI routes after mounting:")
for route in api_app.routes:
if hasattr(route, 'path') and hasattr(route, 'methods'):
methods = list(route.methods) if hasattr(route, 'methods') else []
if 'POST' in methods and '/segment' in route.path:
print(f" β {methods} {route.path} - File upload endpoint")
print("\nβ FastAPI app with Gradio mounted at /gradio")
print("β API endpoints available at:")
print(" - GET /api/health")
print(" - POST /api/segment (multipart/form-data)")
print(" - GET /api/download/{token}")
print(" - GET /docs (FastAPI docs)")
print(" - GET /gradio (Gradio UI - optional)")
else:
print("β Gradio mount not available. FastAPI will run without Gradio UI.")
print("β API endpoints available at:")
print(" - GET /api/health")
print(" - POST /api/segment (multipart/form-data)")
print(" - GET /api/download/{token}")
print(" - GET /docs (FastAPI docs)")
print(f"\nβ Production server: FastAPI (app = api_app)")
print(f"β ASGI app exported as 'app' for Uvicorn")
if __name__ == "__main__":
import uvicorn
import os
port = int(os.getenv("PORT", 7860))
print(f"\nStarting FastAPI server with Uvicorn on port {port}")
print("This is for local development only.")
print("In production, Hugging Face Spaces will use the 'app' variable directly.")
uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")
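    # Local run (illustrative), assuming this module is saved as app.py:
    #   python app.py
    #   # or equivalently:
    #   uvicorn app:app --host 0.0.0.0 --port 7860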