Harshith Reddy
Fix: Always use FastAPI as server, remove Gradio-only fallback, remove demo.launch()
1bb4c2b
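# Configure the PyTorch allocator and Triton cache directory before torch is
# imported so the settings apply to the whole process.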
import os
if 'PYTORCH_ALLOC_CONF' not in os.environ:
os.environ['PYTORCH_ALLOC_CONF'] = 'expandable_segments:True,max_split_size_mb=128'
if 'TRITON_CACHE_DIR' not in os.environ:
triton_cache = '/tmp/.triton'
os.environ['TRITON_CACHE_DIR'] = triton_cache
try:
os.makedirs(triton_cache, exist_ok=True)
    except OSError:
        pass  # cache dir creation is best-effort
import tempfile
import base64
import torch
import nibabel as nib
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
from typing import Optional
# Import gradio first (needed for patching)
import gradio as gr
# Fix for Gradio 4.44.x schema bug: convert boolean additionalProperties to dict
def fix_gradio_schema_bug():
"""
Monkeypatch to fix Gradio 4.44.x crash when additionalProperties is boolean instead of dict.
The error occurs in gradio_client/utils.py -> get_type() when it tries:
if "const" in schema: # crashes if schema is bool (True/False)
This happens when Gradio infers schemas with additionalProperties: true
and then tries to process them as dicts.
"""
try:
# Patch gradio_client.utils.get_type - this is where the crash happens
try:
import gradio_client.utils as gradio_utils
if hasattr(gradio_utils, 'get_type'):
original_get_type = gradio_utils.get_type
def patched_get_type(schema):
"""Fix boolean additionalProperties and handle bool schemas."""
# Handle case where schema itself is a boolean (the actual bug)
if isinstance(schema, bool):
return "Any"
# Normalize additionalProperties: true -> {}
if isinstance(schema, dict):
# Recursively fix nested schemas in properties
if "properties" in schema and isinstance(schema["properties"], dict):
for prop_name, prop_schema in schema["properties"].items():
if isinstance(prop_schema, dict):
patched_get_type(prop_schema)
# Fix additionalProperties: true -> additionalProperties: {}
if "additionalProperties" in schema:
if schema["additionalProperties"] is True:
schema["additionalProperties"] = {}
elif schema["additionalProperties"] is False:
# False means no additional properties allowed
schema.pop("additionalProperties", None)
elif isinstance(schema["additionalProperties"], dict):
# Recursively fix nested additionalProperties
patched_get_type(schema["additionalProperties"])
return original_get_type(schema)
gradio_utils.get_type = patched_get_type
print("βœ“ Applied Gradio schema bug fix (gradio_client.utils.get_type)")
except (ImportError, AttributeError) as e:
print(f"⚠ Could not patch gradio_client.utils: {e}")
# Also patch gradio's API info generation to normalize schemas
try:
if hasattr(gr, 'Blocks'):
# Patch the _get_api_info method if it exists
if hasattr(gr.Blocks, '_get_api_info'):
original_get_api_info = gr.Blocks._get_api_info
def patched_get_api_info(self):
"""Normalize schemas before API info generation."""
api_info = original_get_api_info(self)
if api_info and isinstance(api_info, dict):
def normalize_schema(schema):
if isinstance(schema, dict):
if "additionalProperties" in schema and schema["additionalProperties"] is True:
schema["additionalProperties"] = {}
if "properties" in schema:
for prop in schema["properties"].values():
normalize_schema(prop)
# Normalize all schemas in api_info
if "named_endpoints" in api_info:
for endpoint_info in api_info["named_endpoints"].values():
if "parameters" in endpoint_info:
for param in endpoint_info["parameters"]:
if "component" in param and "serializer" in param["component"]:
if "schema" in param["component"]["serializer"]:
normalize_schema(param["component"]["serializer"]["schema"])
return api_info
gr.Blocks._get_api_info = patched_get_api_info
print("βœ“ Applied Gradio schema bug fix (Blocks._get_api_info)")
except Exception as e:
print(f"⚠ Could not patch Blocks._get_api_info: {e}")
print("βœ“ Gradio schema bug fix applied successfully")
except Exception as e:
print(f"⚠ Could not apply Gradio schema fix: {e}")
fix_gradio_schema_bug()
import config
import model_loader
from inference import predict_volume_api, safe_predict_volume, PROCESSING_LOCK
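# Startup diagnostics: report environment type, GPU/TF32 status, key package
# versions, and whether the CUDA extensions (mamba_ssm, selective_scan) are installed.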
def log_startup_health():
print("=" * 60)
print("STARTUP HEALTH CHECK")
print("=" * 60)
import torch
import os
    try:
        is_docker = os.path.exists('/.dockerenv')
        if not is_docker and os.path.exists('/proc/self/cgroup'):
            with open('/proc/self/cgroup') as cgroup_file:
                is_docker = 'docker' in cgroup_file.read()
    except OSError:
        is_docker = False
is_python_space = '/home/user/.pyenv' in os.environ.get('PATH', '') or os.path.exists('/home/user/.pyenv')
if is_python_space and not is_docker:
print("WARNING: Running on Python Space (managed environment)")
print("Python Spaces do NOT have CUDA toolchain (nvcc) - CUDA extensions CANNOT build")
print("ACTION REQUIRED: Switch to Docker Space in HF Settings -> Runtime -> Docker")
print("Without Docker, CUDA extensions will remain NOT INSTALLED (slow fallback)")
print("=" * 60)
elif is_docker:
print("Running in Docker container - CUDA extensions should be available")
if torch.cuda.is_available():
import config
from packaging import version
torch_version = version.parse(torch.__version__)
if torch_version >= version.parse("2.9.0"):
torch.backends.cuda.matmul.fp32_precision = 'tf32'
torch.backends.cudnn.conv.fp32_precision = 'tf32'
else:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
print(f"PyTorch: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"CUDA Version: {torch.version.cuda}")
print(f"cuDNN Version: {torch.backends.cudnn.version()}")
gpu_name = torch.cuda.get_device_name(0)
gpu_props = torch.cuda.get_device_properties(0)
total_memory_gb = gpu_props.total_memory / (1024**3)
print(f"GPU: {gpu_name}")
print(f"GPU Memory: {total_memory_gb:.1f} GB total")
allocated = torch.cuda.memory_allocated(0) / (1024**3)
reserved = torch.cuda.memory_reserved(0) / (1024**3)
free = total_memory_gb - allocated
print(f"GPU Memory Status: {allocated:.2f} GB allocated, {reserved:.2f} GB reserved, {free:.2f} GB free")
from packaging import version
torch_version = version.parse(torch.__version__)
if torch_version >= version.parse("2.9.0"):
tf32_matmul = getattr(torch.backends.cuda.matmul, 'fp32_precision', 'unknown')
tf32_conv = getattr(torch.backends.cudnn.conv, 'fp32_precision', 'unknown')
else:
tf32_matmul = 'tf32' if torch.backends.cuda.matmul.allow_tf32 else 'ieee'
tf32_conv = 'tf32' if torch.backends.cudnn.allow_tf32 else 'ieee'
cudnn_benchmark = torch.backends.cudnn.benchmark
print(f"TF32 Matmul: {tf32_matmul}")
print(f"TF32 Conv: {tf32_conv}")
print(f"cuDNN Benchmark: {cudnn_benchmark}")
compile_enabled = config.ENABLE_TORCH_COMPILE
compile_mode = os.environ.get('TORCH_COMPILE_MODE', 'reduce-overhead') if compile_enabled else 'disabled'
print(f"torch.compile: {compile_enabled} (mode: {compile_mode})")
    try:
        import monai
        print(f"MONAI: {monai.__version__}")
    except ImportError:
        print("MONAI: NOT FOUND")
    try:
        import gradio
        print(f"Gradio: {gradio.__version__}")
    except ImportError:
        print("Gradio: NOT FOUND")
    try:
        import nibabel
        print(f"NiBabel: {nibabel.__version__}")
    except ImportError:
        print("NiBabel: NOT FOUND")
print("\nCUDA Extensions Status:")
try:
import mamba_ssm
        mamba_version = getattr(mamba_ssm, '__version__', None)
        if mamba_version:
            print(f" mamba_ssm: INSTALLED (version: {mamba_version})")
        else:
            print(" mamba_ssm: INSTALLED")
except ImportError:
print(" mamba_ssm: NOT INSTALLED (will use fallback - CRITICAL for speed)")
try:
import selective_scan_cuda_oflex
print(" selective_scan_cuda_oflex: INSTALLED")
except ImportError:
print(" selective_scan_cuda_oflex: NOT INSTALLED (will use fallback - CRITICAL for speed)")
print("\nEnvironment Variables:")
alloc_conf = os.environ.get('PYTORCH_ALLOC_CONF', 'not set')
print(f" PYTORCH_ALLOC_CONF: {alloc_conf}")
print(f" ENABLE_CUDNN_BENCHMARK: {config.ENABLE_CUDNN_BENCHMARK}")
print(f" ENABLE_TORCH_COMPILE: {config.ENABLE_TORCH_COMPILE}")
print(f" INFERENCE_TIMEOUT: {config.INFERENCE_TIMEOUT}s")
print(f" MAX_GRADIO_CONCURRENCY: {config.MAX_GRADIO_CONCURRENCY}")
print("=" * 60)
log_startup_health()
api_app = FastAPI(title="SRMA-Mamba API", version="1.0.0")
@api_app.middleware("http")
async def log_requests(request, call_next):
path = str(request.url.path)
print(f"[REQUEST] {request.method} {path}")
content_type = request.headers.get('content-type', 'N/A')
print(f"[REQUEST] Content-Type: {content_type}")
if '/segment' in path:
is_multipart = 'multipart' in content_type.lower()
print(f"[REQUEST] Is multipart: {is_multipart}")
if not is_multipart:
print(f"[REQUEST] WARNING: Expected multipart/form-data but got: {content_type}")
print(f"[REQUEST] Headers: {dict(request.headers)}")
response = await call_next(request)
print(f"[RESPONSE] {request.method} {path} -> {response.status_code}")
return response
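# Surface 422 validation failures with enough detail for API clients to see why
# an upload was rejected (a missing multipart Content-Type is the common case).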
@api_app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
print(f"[ERROR] RequestValidationError on {request.method} {request.url.path}")
content_type = request.headers.get('content-type', 'N/A')
print(f"[ERROR] Content-Type: {content_type}")
print(f"[ERROR] First error: {exc.errors()[0] if exc.errors() else 'No errors'}")
error_detail = []
for error in exc.errors():
error_detail.append({
"type": error.get("type"),
"loc": error.get("loc"),
"msg": error.get("msg")
})
return JSONResponse(
status_code=422,
content={
"detail": error_detail,
"message": "Request validation failed. Ensure Content-Type is multipart/form-data for file uploads.",
"content_type_received": content_type
}
)
@api_app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
return JSONResponse(
status_code=exc.status_code,
content={"detail": exc.detail}
)
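# CORS: explicit allow-list of frontend origins permitted to call the API.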
allowed_origins = [
"https://harshithreddy01.github.io",
"https://harshithreddy01.github.io/frontend-SRMA-Liver",
"https://app.paninsight.org",
"http://localhost:5173",
"http://localhost:3000",
"http://localhost:8080",
]
api_app.add_middleware(
CORSMiddleware,
allow_origins=allowed_origins,
allow_credentials=True,
allow_methods=["GET", "POST", "OPTIONS", "HEAD"],
allow_headers=["*"],
expose_headers=["*"],
)
print(f"βœ“ CORS configured for origins: {allowed_origins}")
@api_app.get("/")
@api_app.get("/api")
async def root():
return {
"message": "SRMA-Mamba Liver Segmentation API",
"version": "1.0.0",
"endpoints": {
"/api/segment": "POST - Upload NIfTI file for segmentation",
"/api/health": "GET - Health check",
"/api/download/{token}": "GET - Download segmentation mask",
"/docs": "API documentation"
}
}
@api_app.get("/health")
@api_app.get("/api/health")
async def health_check():
gpu_name = "unknown"
gpu_memory_gb = 0.0
if torch.cuda.is_available() and model_loader.DEVICE and model_loader.DEVICE.type == 'cuda':
try:
gpu_name = torch.cuda.get_device_name(0)
gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        except Exception:
            pass  # GPU details are optional in the health payload
return {
"status": "healthy",
"device": str(model_loader.DEVICE) if model_loader.DEVICE else "not initialized",
"model_t1_loaded": model_loader.MODEL_T1 is not None,
"model_t2_loaded": model_loader.MODEL_T2 is not None,
"gpu_name": gpu_name,
"gpu_memory_gb": round(gpu_memory_gb, 1)
}
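# Token-based download path for segmentation masks too large to inline as
# base64 in the /api/segment response.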
@api_app.get("/api/download/{token}")
async def download_mask(token: str):
    token_dir = "/tmp/seg_tokens"
    token_path = os.path.join(token_dir, f"{token}.nii.gz")
if not os.path.exists(token_path):
raise HTTPException(status_code=404, detail="Token not found or expired")
from fastapi.responses import FileResponse
return FileResponse(
token_path,
media_type="application/gzip",
filename=f"liver_segmentation_{token[:8]}.nii.gz",
headers={"X-Token": token}
)
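# Primary segmentation endpoint. Expects multipart/form-data with a NIfTI file,
# a modality ("T1" or "T2"), and an optional slice index. Example request for
# local development ("volume.nii.gz" is a placeholder filename):
#   curl -X POST http://localhost:7860/api/segment \
#        -F "file=@volume.nii.gz" \
#        -F "modality=T1"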
@api_app.post("/segment", response_class=JSONResponse, include_in_schema=True)
@api_app.post("/api/segment", response_class=JSONResponse, include_in_schema=True)
async def segment_liver(
request: Request,
file: UploadFile = File(..., description="NIfTI file to segment"),
modality: str = Form("T1", description="MRI modality: T1 or T2"),
slice_idx: Optional[int] = Form(None, description="Optional slice index")
):
print("=" * 60)
print(f"API REQUEST RECEIVED: /api/segment")
print(f" Origin: {request.headers.get('origin', 'N/A')}")
print(f" Content-Type: {request.headers.get('content-type', 'N/A')}")
print(f" File: {file.filename if file else 'None'}")
print(f" File Content-Type: {file.content_type if file else 'None'}")
print(f" Modality: {modality}")
print(f" Slice Index: {slice_idx}")
print("=" * 60)
if not file or not file.filename:
raise HTTPException(status_code=400, detail="File is required")
if not file.filename.endswith(('.nii', '.nii.gz', '.gz')):
raise HTTPException(status_code=400, detail="File must be a NIfTI file (.nii or .nii.gz)")
if modality not in ['T1', 'T2']:
raise HTTPException(status_code=400, detail="Modality must be 'T1' or 'T2'")
content = await file.read()
file_size_mb = len(content) / (1024**2)
print(f"API: File size: {file_size_mb:.2f} MB")
if file_size_mb > 2000:
raise HTTPException(
status_code=413,
detail=f"File too large: {file_size_mb:.1f} MB. Maximum upload size is 2 GB. Please compress or resample your NIfTI file."
)
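    # Persist the upload to a temporary file so the inference pipeline can read it from disk.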
with tempfile.NamedTemporaryFile(delete=False, suffix='.nii.gz') as tmp_file:
tmp_file.write(content)
tmp_path = tmp_file.name
import secrets
import shutil
try:
print(f"API: Starting inference for {modality} modality...")
result = predict_volume_api(tmp_path, modality, slice_idx)
print(f"API: Inference completed. Success: {result.get('success', False)}")
if not result["success"]:
raise HTTPException(status_code=500, detail=result.get("error", "Unknown error"))
seg_path = result["segmentation_path"]
seg_file_size = os.path.getsize(seg_path) / (1024**2)
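        # Inline small masks as base64; hand out a download token for large ones.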
if seg_file_size > 2000:
token = secrets.token_urlsafe(16)
token_dir = "/tmp/seg_tokens"
os.makedirs(token_dir, exist_ok=True)
token_path = os.path.join(token_dir, f"{token}.nii.gz")
shutil.copy2(seg_path, token_path)
result["mask_path_token"] = token
result["mask_download_url"] = f"/api/download/{token}"
result["segmentation_file"] = None
print(f"API: Large mask file ({seg_file_size:.1f} MB). Using token-based download: {token}")
else:
with open(seg_path, "rb") as seg_file:
seg_data = seg_file.read()
seg_base64 = base64.b64encode(seg_data).decode('utf-8')
result["segmentation_file"] = f"data:application/octet-stream;base64,{seg_base64}"
result["mask_path_token"] = None
result["mask_download_url"] = None
os.unlink(tmp_path)
if seg_file_size <= 2000:
os.unlink(seg_path)
return JSONResponse(content=result)
    except HTTPException:
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
        raise
    except Exception as e:
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
        raise HTTPException(status_code=500, detail=str(e))
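# Gradio UI (optional): a simple front end for manual uploads. It is mounted
# under /gradio below so it does not shadow the FastAPI routes.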
def create_interface():
with gr.Blocks(title="SRMA-Mamba: Liver Segmentation", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# SRMA-Mamba: Liver Segmentation in MRI Volumes
Upload a 3D NIfTI MRI volume (.nii.gz) to perform automatic liver segmentation.
**Model Performance:**
- Pixel Accuracy: 99.09%
- IoU: 75%
- PSNR: 29.64 dB
**Supported Modalities:** T1-weighted and T2-weighted MRI
**API Available:** This Space also provides a REST API. See `/docs` for documentation.
""")
with gr.Row():
with gr.Column():
nifti_input = gr.File(
file_count="single",
file_types=[".nii.gz", ".nii"],
label="Upload NIfTI File (max 2 GB)"
)
modality = gr.Radio(
choices=["T1", "T2"],
value="T1",
label="MRI Modality"
)
slice_slider = gr.Slider(
minimum=0,
maximum=100,
value=50,
step=1,
label="Slice Index (will auto-update based on volume)",
interactive=True
)
predict_btn = gr.Button("Segment Liver", variant="primary")
reset_btn = gr.Button("Reset", variant="secondary")
with gr.Column():
output_image = gr.Image(
label="Segmentation Overlay",
type="pil"
)
output_info = gr.Markdown(
label="Statistics"
)
output_report = gr.Markdown(
label="Medical Report"
)
output_file = gr.File(
label="Download 3D Segmentation (.nii.gz)"
)
gr.Markdown("""
## Instructions
1. Upload a 3D NIfTI MRI volume (.nii.gz format)
2. Select the MRI modality (T1 or T2)
3. Click "Segment Liver" to run inference
4. View the segmentation overlay and download the 3D mask
**Note:** First inference may take longer as the model loads.
## API Usage
This Space provides a REST API for programmatic access:
- **Endpoint**: `POST /api/segment`
- **Documentation**: Visit `/docs` for interactive API docs
- **Example**: Use from your frontend with `fetch()` or `axios()`
""")
use_gpu = False
predict_fn = safe_predict_volume
if config.HAS_SPACES:
import spaces
if torch.cuda.is_available():
try:
test_tensor = torch.zeros(1).cuda()
del test_tensor
torch.cuda.empty_cache()
use_gpu = True
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "Unknown"
total_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3) if torch.cuda.is_available() else 0
predict_fn = spaces.GPU(safe_predict_volume)
print(f"βœ“ Using GPU acceleration: {gpu_name} ({total_vram:.1f} GB VRAM)")
except ImportError:
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "Unknown"
total_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3) if torch.cuda.is_available() else 0
predict_fn = safe_predict_volume
print(f"βœ“ Using GPU acceleration: {gpu_name} ({total_vram:.1f} GB VRAM)")
except Exception as e:
print(f"⚠ GPU available but failed to initialize: {e}")
print("Falling back to CPU mode")
use_gpu = False
predict_fn = safe_predict_volume
else:
print("β„Ή No GPU available. Running on CPU (slower but free)")
predict_fn = safe_predict_volume
else:
if torch.cuda.is_available():
gpu_name = torch.cuda.get_device_name(0)
total_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3)
print(f"βœ“ CUDA available, using GPU: {gpu_name} ({total_vram:.1f} GB VRAM)")
use_gpu = True
else:
print("β„Ή Running on CPU")
predict_fn = safe_predict_volume
predict_btn.click(
fn=predict_fn,
inputs=[nifti_input, modality, slice_slider],
outputs=[output_image, output_info, output_report, output_file]
)
def update_slice_slider(file):
if file is None:
return gr.update(maximum=100, value=50)
try:
file_path = file.name if hasattr(file, 'name') else str(file)
if not os.path.exists(file_path):
return gr.update(maximum=100, value=50)
                # Only the header shape is needed; avoid loading the full volume into memory.
                vol_shape = nib.load(file_path).shape
                max_slices = vol_shape[-1] if len(vol_shape) == 3 else vol_shape[2]
return gr.update(maximum=max_slices - 1, value=max_slices // 2)
except Exception as e:
print(f"Error updating slice slider: {e}")
return gr.update(maximum=100, value=50)
def reset_interface():
import inference
inference.PROCESSING_LOCK = False
if model_loader.DEVICE and model_loader.DEVICE.type == 'cuda':
torch.cuda.empty_cache()
return None, None, "", None
nifti_input.change(
fn=update_slice_slider,
inputs=[nifti_input],
outputs=[slice_slider],
show_progress=False
)
reset_btn.click(
fn=reset_interface,
inputs=[],
outputs=[output_image, output_info, output_report, output_file]
)
return demo
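# Build the UI, enable request queueing, and expose the FastAPI app as the ASGI
# entry point ("app"); Gradio is mounted onto it when mount_gradio_app is available.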
demo = create_interface()
demo.queue(max_size=config.MAX_GRADIO_CONCURRENCY)
app = api_app
if config.HAS_MOUNT_GRADIO_APP:
from gradio.routes import mount_gradio_app
print("Mounting Gradio interface on FastAPI app...")
print("FastAPI routes registered before mounting:")
for route in api_app.routes:
if hasattr(route, 'path') and hasattr(route, 'methods'):
print(f" - {list(route.methods)} {route.path}")
print("\nMounting Gradio at path '/gradio' to avoid interfering with API routes")
mount_gradio_app(app=api_app, blocks=demo, path="/gradio")
print("\nFastAPI routes after mounting:")
for route in api_app.routes:
if hasattr(route, 'path') and hasattr(route, 'methods'):
            methods = list(route.methods)
            if 'POST' in methods and '/segment' in route.path:
                print(f" ✓ {methods} {route.path} - File upload endpoint")
print("\nβœ“ FastAPI app with Gradio mounted at /gradio")
print("βœ“ API endpoints available at:")
print(" - GET /api/health")
print(" - POST /api/segment (multipart/form-data)")
print(" - GET /api/download/{token}")
print(" - GET /docs (FastAPI docs)")
print(" - GET /gradio (Gradio UI - optional)")
else:
print("⚠ Gradio mount not available. FastAPI will run without Gradio UI.")
print("βœ“ API endpoints available at:")
print(" - GET /api/health")
print(" - POST /api/segment (multipart/form-data)")
print(" - GET /api/download/{token}")
print(" - GET /docs (FastAPI docs)")
print(f"\nβœ“ Production server: FastAPI (app = api_app)")
print(f"βœ“ ASGI app exported as 'app' for Uvicorn")
if __name__ == "__main__":
import uvicorn
import os
port = int(os.getenv("PORT", 7860))
print(f"\nStarting FastAPI server with Uvicorn on port {port}")
print("This is for local development only.")
print("In production, Hugging Face Spaces will use the 'app' variable directly.")
uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")