"""SRMA-Mamba liver segmentation service.

Builds a FastAPI application (``api_app``, exported as ``app``) with
``/api/segment``, ``/api/health`` and ``/api/download/{token}`` endpoints and,
when ``config.HAS_MOUNT_GRADIO_APP`` is set, mounts a Gradio UI at ``/gradio``.
Running this file directly starts Uvicorn for local development.
"""
import os

# These variables must be set before torch (and Triton) are imported below.
if 'PYTORCH_ALLOC_CONF' not in os.environ:
    # The allocator config is a comma-separated list of ``option:value`` pairs;
    # ``option=value`` is not a recognized form.
    # NOTE(review): older PyTorch releases read PYTORCH_CUDA_ALLOC_CONF rather
    # than PYTORCH_ALLOC_CONF -- confirm the deployed version honours this one.
    os.environ['PYTORCH_ALLOC_CONF'] = 'expandable_segments:True,max_split_size_mb:128'
if 'TRITON_CACHE_DIR' not in os.environ:
    triton_cache = '/tmp/.triton'
    os.environ['TRITON_CACHE_DIR'] = triton_cache
    try:
        os.makedirs(triton_cache, exist_ok=True)
    except OSError:
        # Best effort: an uncreatable cache directory must not block startup.
        pass

import base64
import importlib
import re
import secrets
import shutil
import tempfile
import threading
from typing import Optional

import torch
import nibabel as nib
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException as StarletteHTTPException

# Import gradio first: the schema patch below must be installed before any
# Blocks are constructed.
import gradio as gr


# ---------------------------------------------------------------------------
# Gradio 4.44.x schema workaround
# ---------------------------------------------------------------------------

def _normalize_bool_additional_properties(schema):
    """Rewrite boolean ``additionalProperties`` in a JSON-schema dict, in place.

    ``True`` becomes ``{}`` and ``False`` is removed, so downstream code never
    meets a bare bool where it expects a schema dict. Nested ``properties``
    entries and dict-valued ``additionalProperties`` are normalized recursively.
    Non-dict input is ignored.
    """
    if not isinstance(schema, dict):
        return
    properties = schema.get("properties")
    if isinstance(properties, dict):
        for prop_schema in properties.values():
            _normalize_bool_additional_properties(prop_schema)
    if "additionalProperties" in schema:
        extra = schema["additionalProperties"]
        if extra is True:
            schema["additionalProperties"] = {}
        elif extra is False:
            # Drops the "no additional properties" restriction; only the type
            # string derived from the schema is needed by the caller.
            del schema["additionalProperties"]
        elif isinstance(extra, dict):
            _normalize_bool_additional_properties(extra)


def _patch_gradio_client_get_type():
    """Wrap ``gradio_client.utils.get_type`` so it tolerates boolean schemas.

    Import/attribute failures are printed and otherwise ignored.
    """
    try:
        import gradio_client.utils as gradio_utils
        if not hasattr(gradio_utils, 'get_type'):
            return
        original_get_type = gradio_utils.get_type

        def patched_get_type(schema):
            """Return "Any" for a bool schema; normalize dict schemas, then delegate."""
            if isinstance(schema, bool):
                return "Any"
            _normalize_bool_additional_properties(schema)
            return original_get_type(schema)

        gradio_utils.get_type = patched_get_type
        print("✓ Applied Gradio schema bug fix (gradio_client.utils.get_type)")
    except (ImportError, AttributeError) as e:
        print(f"⚠ Could not patch gradio_client.utils: {e}")


def _normalize_api_info_schema(schema):
    """Replace ``additionalProperties: true`` with ``{}`` in *schema* and its ``properties``, in place."""
    if not isinstance(schema, dict):
        return
    if schema.get("additionalProperties") is True:
        schema["additionalProperties"] = {}
    properties = schema.get("properties")
    if isinstance(properties, dict):
        for prop in properties.values():
            _normalize_api_info_schema(prop)


def _patch_blocks_api_info():
    """Wrap ``gr.Blocks._get_api_info`` so serializer schemas are normalized before use.

    Any failure while patching is printed and otherwise ignored.
    """
    try:
        if not (hasattr(gr, 'Blocks') and hasattr(gr.Blocks, '_get_api_info')):
            return
        original_get_api_info = gr.Blocks._get_api_info

        def patched_get_api_info(self):
            """Normalize schemas before API info generation."""
            api_info = original_get_api_info(self)
            if api_info and isinstance(api_info, dict):
                for endpoint_info in api_info.get("named_endpoints", {}).values():
                    for param in endpoint_info.get("parameters", ()):
                        component = param.get("component")
                        # Only dict components can carry a serializer schema.
                        serializer = component.get("serializer") if isinstance(component, dict) else None
                        if isinstance(serializer, dict) and "schema" in serializer:
                            _normalize_api_info_schema(serializer["schema"])
            return api_info

        gr.Blocks._get_api_info = patched_get_api_info
        print("✓ Applied Gradio schema bug fix (Blocks._get_api_info)")
    except Exception as e:
        print(f"⚠ Could not patch Blocks._get_api_info: {e}")


def fix_gradio_schema_bug():
    """Monkeypatch around a Gradio 4.44.x crash on boolean JSON schemas.

    ``gradio_client.utils.get_type`` evaluates ``"const" in schema``, which raises
    ``TypeError`` when ``schema`` is a bool -- e.g. an ``additionalProperties: true``
    inferred by Gradio. This wraps ``get_type`` to tolerate bool schemas and
    normalizes the output of ``Blocks._get_api_info``. Every step is best
    effort: failures are printed and startup continues.
    """
    try:
        _patch_gradio_client_get_type()
        _patch_blocks_api_info()
        print("✓ Gradio schema bug fix applied successfully")
    except Exception as e:
        print(f"⚠ Could not apply Gradio schema fix: {e}")


fix_gradio_schema_bug()

import config
import model_loader
from inference import predict_volume_api, safe_predict_volume, PROCESSING_LOCK


# ---------------------------------------------------------------------------
# Startup diagnostics
# ---------------------------------------------------------------------------

def _torch_at_least(min_version):
    """Return True when the installed torch version is >= *min_version*."""
    from packaging import version
    return version.parse(torch.__version__) >= version.parse(min_version)


def _running_in_docker():
    """Best-effort Docker detection via ``/.dockerenv`` or a 'docker' entry in ``/proc/self/cgroup``."""
    try:
        if os.path.exists('/.dockerenv'):
            return True
        if os.path.exists('/proc/self/cgroup'):
            with open('/proc/self/cgroup') as cgroup:
                return 'docker' in cgroup.read()
        return False
    except (OSError, ValueError):
        return False


def _configure_cuda_backends():
    """Enable TF32 for matmul/conv and set cuDNN autotuning from ``config.ENABLE_CUDNN_BENCHMARK``."""
    if _torch_at_least("2.9.0"):
        # Newer torch exposes TF32 control through ``fp32_precision``.
        torch.backends.cuda.matmul.fp32_precision = 'tf32'
        torch.backends.cudnn.conv.fp32_precision = 'tf32'
    else:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    # Honour the config flag that is reported in the health log below.
    torch.backends.cudnn.benchmark = bool(config.ENABLE_CUDNN_BENCHMARK)


def _print_gpu_status():
    """Print CUDA/cuDNN versions, device-0 memory usage, TF32/cuDNN flags and torch.compile mode."""
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"cuDNN Version: {torch.backends.cudnn.version()}")
    gpu_name = torch.cuda.get_device_name(0)
    total_memory_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f"GPU: {gpu_name}")
    print(f"GPU Memory: {total_memory_gb:.1f} GB total")
    allocated = torch.cuda.memory_allocated(0) / (1024**3)
    reserved = torch.cuda.memory_reserved(0) / (1024**3)
    # NOTE: "free" here is total minus allocated; memory reserved by the caching
    # allocator is not subtracted.
    free = total_memory_gb - allocated
    print(f"GPU Memory Status: {allocated:.2f} GB allocated, {reserved:.2f} GB reserved, {free:.2f} GB free")

    if _torch_at_least("2.9.0"):
        tf32_matmul = getattr(torch.backends.cuda.matmul, 'fp32_precision', 'unknown')
        tf32_conv = getattr(torch.backends.cudnn.conv, 'fp32_precision', 'unknown')
    else:
        tf32_matmul = 'tf32' if torch.backends.cuda.matmul.allow_tf32 else 'ieee'
        tf32_conv = 'tf32' if torch.backends.cudnn.allow_tf32 else 'ieee'
    print(f"TF32 Matmul: {tf32_matmul}")
    print(f"TF32 Conv: {tf32_conv}")
    print(f"cuDNN Benchmark: {torch.backends.cudnn.benchmark}")

    compile_enabled = config.ENABLE_TORCH_COMPILE
    compile_mode = os.environ.get('TORCH_COMPILE_MODE', 'reduce-overhead') if compile_enabled else 'disabled'
    print(f"torch.compile: {compile_enabled} (mode: {compile_mode})")


def log_startup_health():
    """Print a startup diagnostic report.

    Side effect: when CUDA is available, TF32 and cuDNN benchmark settings are
    applied (see ``_configure_cuda_backends``) before the report is printed.
    """
    print("=" * 60)
    print("STARTUP HEALTH CHECK")
    print("=" * 60)

    is_docker = _running_in_docker()
    is_python_space = '/home/user/.pyenv' in os.environ.get('PATH', '') or os.path.exists('/home/user/.pyenv')
    if is_python_space and not is_docker:
        print("WARNING: Running on Python Space (managed environment)")
        print("Python Spaces do NOT have CUDA toolchain (nvcc) - CUDA extensions CANNOT build")
        print("ACTION REQUIRED: Switch to Docker Space in HF Settings -> Runtime -> Docker")
        print("Without Docker, CUDA extensions will remain NOT INSTALLED (slow fallback)")
        print("=" * 60)
    elif is_docker:
        print("Running in Docker container - CUDA extensions should be available")

    cuda_available = torch.cuda.is_available()
    if cuda_available:
        _configure_cuda_backends()

    print(f"PyTorch: {torch.__version__}")
    print(f"CUDA Available: {cuda_available}")
    if cuda_available:
        _print_gpu_status()

    for label, module_name in (("MONAI", "monai"), ("Gradio", "gradio"), ("NiBabel", "nibabel")):
        try:
            module = importlib.import_module(module_name)
            print(f"{label}: {module.__version__}")
        except Exception:
            print(f"{label}: NOT FOUND")

    print("\nCUDA Extensions Status:")
    try:
        import mamba_ssm
    except ImportError:
        print(" mamba_ssm: NOT INSTALLED (will use fallback - CRITICAL for speed)")
    else:
        mamba_version = getattr(mamba_ssm, '__version__', None)
        if mamba_version is None:
            print(" mamba_ssm: INSTALLED")
        else:
            print(f" mamba_ssm: INSTALLED (version: {mamba_version})")
    try:
        import selective_scan_cuda_oflex  # noqa: F401 -- availability probe only
        print(" selective_scan_cuda_oflex: INSTALLED")
    except ImportError:
        print(" selective_scan_cuda_oflex: NOT INSTALLED (will use fallback - CRITICAL for speed)")

    print("\nEnvironment Variables:")
    alloc_conf = os.environ.get('PYTORCH_ALLOC_CONF', 'not set')
    print(f" PYTORCH_ALLOC_CONF: {alloc_conf}")
    print(f" ENABLE_CUDNN_BENCHMARK: {config.ENABLE_CUDNN_BENCHMARK}")
    print(f" ENABLE_TORCH_COMPILE: {config.ENABLE_TORCH_COMPILE}")
    print(f" INFERENCE_TIMEOUT: {config.INFERENCE_TIMEOUT}s")
    print(f" MAX_GRADIO_CONCURRENCY: {config.MAX_GRADIO_CONCURRENCY}")
    print("=" * 60)


log_startup_health()


# ---------------------------------------------------------------------------
# FastAPI application
# ---------------------------------------------------------------------------

api_app = FastAPI(title="SRMA-Mamba API", version="1.0.0")

# Headers whose values must never be written to logs.
_SENSITIVE_HEADERS = frozenset({"authorization", "proxy-authorization", "cookie", "x-api-key"})


def _redacted_headers(headers):
    """Return a dict copy of *headers* with credential-bearing values replaced."""
    return {
        name: ("<redacted>" if name.lower() in _SENSITIVE_HEADERS else value)
        for name, value in headers.items()
    }


@api_app.middleware("http")
async def log_requests(request, call_next):
    """Log every request/response; for ``/segment`` also log multipart status and (redacted) headers."""
    path = str(request.url.path)
    print(f"[REQUEST] {request.method} {path}")
    content_type = request.headers.get('content-type', 'N/A')
    print(f"[REQUEST] Content-Type: {content_type}")
    if '/segment' in path:
        is_multipart = 'multipart' in content_type.lower()
        print(f"[REQUEST] Is multipart: {is_multipart}")
        if not is_multipart:
            print(f"[REQUEST] WARNING: Expected multipart/form-data but got: {content_type}")
        print(f"[REQUEST] Headers: {_redacted_headers(request.headers)}")
    response = await call_next(request)
    print(f"[RESPONSE] {request.method} {path} -> {response.status_code}")
    return response


@api_app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Return a 422 listing each validation error's type/loc/msg plus the received Content-Type."""
    errors = exc.errors()
    print(f"[ERROR] RequestValidationError on {request.method} {request.url.path}")
    content_type = request.headers.get('content-type', 'N/A')
    print(f"[ERROR] Content-Type: {content_type}")
    print(f"[ERROR] First error: {errors[0] if errors else 'No errors'}")
    error_detail = [
        {"type": error.get("type"), "loc": error.get("loc"), "msg": error.get("msg")}
        for error in errors
    ]
    return JSONResponse(
        status_code=422,
        content={
            "detail": error_detail,
            "message": "Request validation failed. Ensure Content-Type is multipart/form-data for file uploads.",
            "content_type_received": content_type,
        },
    )


@api_app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    """Render HTTP exceptions as ``{"detail": ...}`` JSON with the original status code."""
    return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})


allowed_origins = [
    "https://harshithreddy01.github.io",
    "https://harshithreddy01.github.io/frontend-SRMA-Liver",
    "https://app.paninsight.org",
    "http://localhost:5173",
    "http://localhost:3000",
    "http://localhost:8080",
]
api_app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_credentials=True,
    allow_methods=["GET", "POST", "OPTIONS", "HEAD"],
    allow_headers=["*"],
    expose_headers=["*"],
)
print(f"✓ CORS configured for origins: {allowed_origins}")


@api_app.get("/")
@api_app.get("/api")
async def root():
    """Describe the API and list its endpoints."""
    return {
        "message": "SRMA-Mamba Liver Segmentation API",
        "version": "1.0.0",
        "endpoints": {
            "/api/segment": "POST - Upload NIfTI file for segmentation",
            "/api/health": "GET - Health check",
            "/api/download/{token}": "GET - Download segmentation mask",
            "/docs": "API documentation",
        },
    }


@api_app.get("/health")
@api_app.get("/api/health")
async def health_check():
    """Report device, model-load state and, on CUDA, the GPU name and total memory."""
    gpu_name = "unknown"
    gpu_memory_gb = 0.0
    device = model_loader.DEVICE
    if torch.cuda.is_available() and device and device.type == 'cuda':
        try:
            gpu_name = torch.cuda.get_device_name(0)
            gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        except Exception:
            # A failing CUDA query must not make the health endpoint fail.
            pass
    return {
        "status": "healthy",
        "device": str(device) if device else "not initialized",
        "model_t1_loaded": model_loader.MODEL_T1 is not None,
        "model_t2_loaded": model_loader.MODEL_T2 is not None,
        "gpu_name": gpu_name,
        "gpu_memory_gb": round(gpu_memory_gb, 1),
    }


# Directory holding masks that are too large to inline and are served by token.
_SEG_TOKEN_DIR = "/tmp/seg_tokens"
# Tokens come from secrets.token_urlsafe, i.e. the URL-safe base64 alphabet.
_TOKEN_RE = re.compile(r"[A-Za-z0-9_-]{1,128}")
_MAX_UPLOAD_MB = 2000
# NOTE(review): masks up to this size are base64-inlined into the JSON body;
# 2000 MB is very large for an inline payload -- confirm the intended limit.
_INLINE_MASK_MAX_MB = 2000
_COPY_CHUNK_BYTES = 1024 * 1024
# Serializes API inferences now that they run in worker threads; previously the
# blocking call on the event loop serialized them implicitly.
_API_INFERENCE_LOCK = threading.Lock()


def _run_api_inference(path, modality, slice_idx):
    """Call ``predict_volume_api`` while holding ``_API_INFERENCE_LOCK`` (one API inference at a time)."""
    with _API_INFERENCE_LOCK:
        return predict_volume_api(path, modality, slice_idx)


def _remove_quietly(path):
    """Delete *path* if it is set and exists; failures are logged rather than raised."""
    if path and os.path.exists(path):
        try:
            os.unlink(path)
        except OSError as e:
            print(f"API: Could not remove temporary file {path}: {e}")


@api_app.get("/api/download/{token}")
async def download_mask(token: str):
    """Serve a mask previously stored under *token* by ``segment_liver``.

    Raises:
        HTTPException(404): the token is malformed or no file exists for it.
    """
    # Reject anything outside the token alphabet so the value cannot steer the path.
    if not _TOKEN_RE.fullmatch(token):
        raise HTTPException(status_code=404, detail="Token not found or expired")
    token_path = os.path.join(_SEG_TOKEN_DIR, f"{token}.nii.gz")
    if not os.path.isfile(token_path):
        raise HTTPException(status_code=404, detail="Token not found or expired")
    return FileResponse(
        token_path,
        media_type="application/gzip",
        filename=f"liver_segmentation_{token[:8]}.nii.gz",
        headers={"X-Token": token},
    )


@api_app.post("/segment", response_class=JSONResponse, include_in_schema=True)
@api_app.post("/api/segment", response_class=JSONResponse, include_in_schema=True)
async def segment_liver(
    request: Request,
    file: UploadFile = File(..., description="NIfTI file to segment"),
    modality: str = Form("T1", description="MRI modality: T1 or T2"),
    slice_idx: Optional[int] = Form(None, description="Optional slice index"),
):
    """Segment an uploaded NIfTI volume and return the inference result as JSON.

    The mask is either base64-inlined in ``segmentation_file`` or, above
    ``_INLINE_MASK_MAX_MB``, stored under a token served by ``download_mask``.

    Raises:
        HTTPException(400): missing file, wrong extension or unknown modality.
        HTTPException(413): upload larger than ``_MAX_UPLOAD_MB``.
        HTTPException(500): inference reported failure or raised.
    """
    print("=" * 60)
    print("API REQUEST RECEIVED: /api/segment")
    print(f" Origin: {request.headers.get('origin', 'N/A')}")
    print(f" Content-Type: {request.headers.get('content-type', 'N/A')}")
    print(f" File: {file.filename if file else 'None'}")
    print(f" File Content-Type: {file.content_type if file else 'None'}")
    print(f" Modality: {modality}")
    print(f" Slice Index: {slice_idx}")
    print("=" * 60)

    if not file or not file.filename:
        raise HTTPException(status_code=400, detail="File is required")
    filename_lower = file.filename.lower()
    if not filename_lower.endswith(('.nii', '.nii.gz', '.gz')):
        raise HTTPException(status_code=400, detail="File must be a NIfTI file (.nii or .nii.gz)")
    if modality not in ['T1', 'T2']:
        raise HTTPException(status_code=400, detail="Modality must be 'T1' or 'T2'")

    # Size the already-spooled upload without pulling it into memory.
    upload = file.file
    upload.seek(0, os.SEEK_END)
    file_size_mb = upload.tell() / (1024**2)
    upload.seek(0)
    print(f"API: File size: {file_size_mb:.2f} MB")
    if file_size_mb > _MAX_UPLOAD_MB:
        raise HTTPException(
            status_code=413,
            detail=f"File too large: {file_size_mb:.1f} MB. Maximum upload size is 2 GB. Please compress or resample your NIfTI file.",
        )

    # nibabel chooses gzip decoding from the extension, so the temp file must
    # keep the upload's compression suffix (plain .nii must not become .nii.gz).
    suffix = '.nii.gz' if filename_lower.endswith('.gz') else '.nii'
    tmp_path = None
    seg_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            tmp_path = tmp_file.name
            await run_in_threadpool(shutil.copyfileobj, upload, tmp_file, _COPY_CHUNK_BYTES)

        print(f"API: Starting inference for {modality} modality...")
        # Run the blocking inference off the event loop so other requests
        # (e.g. /health) are still served meanwhile.
        result = await run_in_threadpool(_run_api_inference, tmp_path, modality, slice_idx)
        print(f"API: Inference completed. Success: {result.get('success', False)}")
        if not result["success"]:
            raise HTTPException(status_code=500, detail=result.get("error", "Unknown error"))

        seg_path = result["segmentation_path"]
        seg_file_size = os.path.getsize(seg_path) / (1024**2)
        if seg_file_size > _INLINE_MASK_MAX_MB:
            token = secrets.token_urlsafe(16)
            os.makedirs(_SEG_TOKEN_DIR, exist_ok=True)
            shutil.move(seg_path, os.path.join(_SEG_TOKEN_DIR, f"{token}.nii.gz"))
            seg_path = None  # now owned by the token directory; do not delete
            result["mask_path_token"] = token
            result["mask_download_url"] = f"/api/download/{token}"
            result["segmentation_file"] = None
            print(f"API: Large mask file ({seg_file_size:.1f} MB). Using token-based download: {token}")
        else:
            with open(seg_path, "rb") as seg_file:
                seg_base64 = base64.b64encode(seg_file.read()).decode('utf-8')
            result["segmentation_file"] = f"data:application/octet-stream;base64,{seg_base64}"
            result["mask_path_token"] = None
            result["mask_download_url"] = None
        # JSONResponse renders the body immediately, so the files can be removed afterwards.
        return JSONResponse(content=result)
    except HTTPException:
        # Keep intended status codes/details instead of re-wrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        _remove_quietly(tmp_path)
        _remove_quietly(seg_path)


# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------

def _gpu_summary():
    """Return ``(name, total VRAM in GiB)`` for CUDA device 0."""
    return torch.cuda.get_device_name(0), torch.cuda.get_device_properties(0).total_memory / (1024**3)


def _select_predict_fn():
    """Choose the Gradio click handler and print which compute path is used.

    On Hugging Face Spaces (``config.HAS_SPACES``) with a working GPU the
    handler is ``spaces.GPU(safe_predict_volume)``; in every other case it is
    ``safe_predict_volume`` itself.
    """
    if not config.HAS_SPACES:
        if torch.cuda.is_available():
            gpu_name, total_vram = _gpu_summary()
            print(f"✓ CUDA available, using GPU: {gpu_name} ({total_vram:.1f} GB VRAM)")
        else:
            print("ℹ Running on CPU")
        return safe_predict_volume

    if not torch.cuda.is_available():
        print("ℹ No GPU available. Running on CPU (slower but free)")
        return safe_predict_volume

    try:
        import spaces  # only importable on Hugging Face Spaces
        # Probe the device once so a broken CUDA setup falls back to CPU here.
        probe = torch.zeros(1).cuda()
        del probe
        torch.cuda.empty_cache()
        gpu_name, total_vram = _gpu_summary()
        predict_fn = spaces.GPU(safe_predict_volume)
    except ImportError:
        gpu_name, total_vram = _gpu_summary()
        predict_fn = safe_predict_volume
    except Exception as e:
        print(f"⚠ GPU available but failed to initialize: {e}")
        print("Falling back to CPU mode")
        return safe_predict_volume
    print(f"✓ Using GPU acceleration: {gpu_name} ({total_vram:.1f} GB VRAM)")
    return predict_fn


def create_interface():
    """Build the Gradio Blocks UI for uploading a volume and viewing its segmentation."""
    with gr.Blocks(title="SRMA-Mamba: Liver Segmentation", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # SRMA-Mamba: Liver Segmentation in MRI Volumes

        Upload a 3D NIfTI MRI volume (.nii.gz) to perform automatic liver segmentation.

        **Model Performance:**
        - Pixel Accuracy: 99.09%
        - IoU: 75%
        - PSNR: 29.64 dB

        **Supported Modalities:** T1-weighted and T2-weighted MRI

        **API Available:** This Space also provides a REST API. See `/docs` for documentation.
        """)

        with gr.Row():
            with gr.Column():
                nifti_input = gr.File(
                    file_count="single",
                    file_types=[".nii.gz", ".nii"],
                    label="Upload NIfTI File (max 2 GB)",
                )
                modality = gr.Radio(choices=["T1", "T2"], value="T1", label="MRI Modality")
                slice_slider = gr.Slider(
                    minimum=0,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Slice Index (will auto-update based on volume)",
                    interactive=True,
                )
                predict_btn = gr.Button("Segment Liver", variant="primary")
                reset_btn = gr.Button("Reset", variant="secondary")
            with gr.Column():
                output_image = gr.Image(label="Segmentation Overlay", type="pil")
                output_info = gr.Markdown(label="Statistics")
                output_report = gr.Markdown(label="Medical Report")
                output_file = gr.File(label="Download 3D Segmentation (.nii.gz)")

        gr.Markdown("""
        ## Instructions
        1. Upload a 3D NIfTI MRI volume (.nii.gz format)
        2. Select the MRI modality (T1 or T2)
        3. Click "Segment Liver" to run inference
        4. View the segmentation overlay and download the 3D mask

        **Note:** First inference may take longer as the model loads.

        ## API Usage
        This Space provides a REST API for programmatic access:
        - **Endpoint**: `POST /api/segment`
        - **Documentation**: Visit `/docs` for interactive API docs
        - **Example**: Use from your frontend with `fetch()` or `axios()`
        """)

        predict_fn = _select_predict_fn()
        predict_btn.click(
            fn=predict_fn,
            inputs=[nifti_input, modality, slice_slider],
            outputs=[output_image, output_info, output_report, output_file],
        )

        def update_slice_slider(file):
            """Resize the slider to the uploaded volume's slice count (defaults on any failure)."""
            default = gr.update(maximum=100, value=50)
            if file is None:
                return default
            try:
                file_path = file.name if hasattr(file, 'name') else str(file)
                if not os.path.exists(file_path):
                    return default
                # The header-derived shape is enough; loading voxel data of a
                # multi-GB volume just to size a slider would be slow.
                shape = nib.load(file_path).shape
                max_slices = shape[2]  # axis 2 is the slice axis (last axis of a 3D volume)
                return gr.update(maximum=max_slices - 1, value=max_slices // 2)
            except Exception as e:
                print(f"Error updating slice slider: {e}")
                return default

        def reset_interface():
            """Clear outputs, reset the inference module's PROCESSING_LOCK and free cached CUDA memory."""
            import inference
            # NOTE(review): rebinding to False only makes sense if PROCESSING_LOCK is
            # a boolean flag; if it is a lock object this replaces it -- confirm in inference.py.
            inference.PROCESSING_LOCK = False
            if model_loader.DEVICE and model_loader.DEVICE.type == 'cuda':
                torch.cuda.empty_cache()
            return None, None, "", None

        nifti_input.change(
            fn=update_slice_slider,
            inputs=[nifti_input],
            outputs=[slice_slider],
            show_progress=False,
        )
        reset_btn.click(
            fn=reset_interface,
            inputs=[],
            outputs=[output_image, output_info, output_report, output_file],
        )
    return demo


def _print_api_endpoints(include_gradio):
    """Print the list of exposed endpoints, optionally including the Gradio UI path."""
    print("✓ API endpoints available at:")
    print(" - GET /api/health")
    print(" - POST /api/segment (multipart/form-data)")
    print(" - GET /api/download/{token}")
    print(" - GET /docs (FastAPI docs)")
    if include_gradio:
        print(" - GET /gradio (Gradio UI - optional)")


demo = create_interface()
# NOTE(review): ``max_size`` bounds the queue length, not worker concurrency --
# confirm this is what MAX_GRADIO_CONCURRENCY is meant to control.
demo.queue(max_size=config.MAX_GRADIO_CONCURRENCY)

app = api_app

if config.HAS_MOUNT_GRADIO_APP:
    from gradio.routes import mount_gradio_app

    print("Mounting Gradio interface on FastAPI app...")
    print("FastAPI routes registered before mounting:")
    for route in api_app.routes:
        route_methods = getattr(route, 'methods', None)
        if hasattr(route, 'path') and route_methods:
            print(f" - {list(route_methods)} {route.path}")

    print("\nMounting Gradio at path '/gradio' to avoid interfering with API routes")
    mount_gradio_app(app=api_app, blocks=demo, path="/gradio")

    print("\nFastAPI routes after mounting:")
    for route in api_app.routes:
        route_methods = getattr(route, 'methods', None)
        route_path = getattr(route, 'path', None)
        if route_methods and route_path and 'POST' in route_methods and '/segment' in route_path:
            print(f" ✓ {list(route_methods)} {route_path} - File upload endpoint")

    print("\n✓ FastAPI app with Gradio mounted at /gradio")
    _print_api_endpoints(include_gradio=True)
else:
    print("⚠ Gradio mount not available. FastAPI will run without Gradio UI.")
    _print_api_endpoints(include_gradio=False)

print("\n✓ Production server: FastAPI (app = api_app)")
print("✓ ASGI app exported as 'app' for Uvicorn")

if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", 7860))
    print(f"\nStarting FastAPI server with Uvicorn on port {port}")
    print("This is for local development only.")
    print("In production, Hugging Face Spaces will use the 'app' variable directly.")
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")