Harshith Reddy
Fix: Always use FastAPI as server, remove Gradio-only fallback, remove demo.launch()
1bb4c2b
import os

if 'PYTORCH_ALLOC_CONF' not in os.environ:
    os.environ['PYTORCH_ALLOC_CONF'] = 'expandable_segments:True,max_split_size_mb=128'
if 'TRITON_CACHE_DIR' not in os.environ:
    triton_cache = '/tmp/.triton'
    os.environ['TRITON_CACHE_DIR'] = triton_cache
    try:
        os.makedirs(triton_cache, exist_ok=True)
    except Exception:
        pass

import tempfile
import base64
import torch
import nibabel as nib
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
from typing import Optional

# Import gradio first (needed for patching)
import gradio as gr


# Fix for Gradio 4.44.x schema bug: convert boolean additionalProperties to dict
def fix_gradio_schema_bug():
    """
    Monkeypatch to fix the Gradio 4.44.x crash when additionalProperties is a boolean instead of a dict.
    The error occurs in gradio_client/utils.py -> get_type() when it tries:
        if "const" in schema:  # crashes if schema is bool (True/False)
    This happens when Gradio infers schemas with additionalProperties: true
    and then tries to process them as dicts.
    """
    try:
        # Patch gradio_client.utils.get_type - this is where the crash happens
        try:
            import gradio_client.utils as gradio_utils
            if hasattr(gradio_utils, 'get_type'):
                original_get_type = gradio_utils.get_type

                def patched_get_type(schema):
                    """Fix boolean additionalProperties and handle bool schemas."""
                    # Handle case where schema itself is a boolean (the actual bug)
                    if isinstance(schema, bool):
                        return "Any"
                    # Normalize additionalProperties: true -> {}
                    if isinstance(schema, dict):
                        # Recursively fix nested schemas in properties
                        if "properties" in schema and isinstance(schema["properties"], dict):
                            for prop_name, prop_schema in schema["properties"].items():
                                if isinstance(prop_schema, dict):
                                    patched_get_type(prop_schema)
                        # Fix additionalProperties: true -> additionalProperties: {}
                        if "additionalProperties" in schema:
                            if schema["additionalProperties"] is True:
                                schema["additionalProperties"] = {}
                            elif schema["additionalProperties"] is False:
                                # False means no additional properties allowed
                                schema.pop("additionalProperties", None)
                            elif isinstance(schema["additionalProperties"], dict):
                                # Recursively fix nested additionalProperties
                                patched_get_type(schema["additionalProperties"])
                    return original_get_type(schema)

                gradio_utils.get_type = patched_get_type
                print("✓ Applied Gradio schema bug fix (gradio_client.utils.get_type)")
        except (ImportError, AttributeError) as e:
            print(f"⚠ Could not patch gradio_client.utils: {e}")

        # Also patch gradio's API info generation to normalize schemas
        try:
            if hasattr(gr, 'Blocks'):
                # Patch the _get_api_info method if it exists
                if hasattr(gr.Blocks, '_get_api_info'):
                    original_get_api_info = gr.Blocks._get_api_info

                    def patched_get_api_info(self):
                        """Normalize schemas before API info generation."""
                        api_info = original_get_api_info(self)
                        if api_info and isinstance(api_info, dict):
                            def normalize_schema(schema):
                                if isinstance(schema, dict):
                                    if "additionalProperties" in schema and schema["additionalProperties"] is True:
                                        schema["additionalProperties"] = {}
                                    if "properties" in schema:
                                        for prop in schema["properties"].values():
                                            normalize_schema(prop)

                            # Normalize all schemas in api_info
                            if "named_endpoints" in api_info:
                                for endpoint_info in api_info["named_endpoints"].values():
                                    if "parameters" in endpoint_info:
                                        for param in endpoint_info["parameters"]:
                                            if "component" in param and "serializer" in param["component"]:
                                                if "schema" in param["component"]["serializer"]:
                                                    normalize_schema(param["component"]["serializer"]["schema"])
                        return api_info

                    gr.Blocks._get_api_info = patched_get_api_info
                    print("✓ Applied Gradio schema bug fix (Blocks._get_api_info)")
        except Exception as e:
            print(f"⚠ Could not patch Blocks._get_api_info: {e}")

        print("✓ Gradio schema bug fix applied successfully")
    except Exception as e:
        print(f"⚠ Could not apply Gradio schema fix: {e}")


fix_gradio_schema_bug()
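

# Illustrative sketch (not called anywhere in this app): shows the effect of the
# schema normalization that fix_gradio_schema_bug() installs. The sample schema
# and the helper name are hypothetical, not produced by this Space.
def _example_schema_normalization():
    sample = {
        "type": "object",
        "properties": {"meta": {"type": "object", "additionalProperties": True}},
        "additionalProperties": True,
    }
    # After the patched get_type() has walked a schema like `sample`, every
    # `additionalProperties: true` becomes an empty dict and
    # `additionalProperties: false` is dropped, so gradio_client keeps treating
    # schema nodes as dicts instead of crashing on booleans.
    expected = {
        "type": "object",
        "properties": {"meta": {"type": "object", "additionalProperties": {}}},
        "additionalProperties": {},
    }
    return sample, expected
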
import config
import model_loader
from inference import predict_volume_api, safe_predict_volume, PROCESSING_LOCK


def log_startup_health():
    print("=" * 60)
    print("STARTUP HEALTH CHECK")
    print("=" * 60)
    import torch
    import os
    try:
        is_docker = os.path.exists('/.dockerenv') or (os.path.exists('/proc/self/cgroup') and 'docker' in open('/proc/self/cgroup').read())
    except Exception:
        is_docker = False
    is_python_space = '/home/user/.pyenv' in os.environ.get('PATH', '') or os.path.exists('/home/user/.pyenv')
    if is_python_space and not is_docker:
        print("WARNING: Running on Python Space (managed environment)")
        print("Python Spaces do NOT have a CUDA toolchain (nvcc) - CUDA extensions CANNOT build")
        print("ACTION REQUIRED: Switch to Docker Space in HF Settings -> Runtime -> Docker")
        print("Without Docker, CUDA extensions will remain NOT INSTALLED (slow fallback)")
        print("=" * 60)
    elif is_docker:
        print("Running in Docker container - CUDA extensions should be available")
    if torch.cuda.is_available():
        from packaging import version
        torch_version = version.parse(torch.__version__)
        if torch_version >= version.parse("2.9.0"):
            torch.backends.cuda.matmul.fp32_precision = 'tf32'
            torch.backends.cudnn.conv.fp32_precision = 'tf32'
        else:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
    print(f"PyTorch: {torch.__version__}")
    print(f"CUDA Available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA Version: {torch.version.cuda}")
        print(f"cuDNN Version: {torch.backends.cudnn.version()}")
        gpu_name = torch.cuda.get_device_name(0)
        gpu_props = torch.cuda.get_device_properties(0)
        total_memory_gb = gpu_props.total_memory / (1024**3)
        print(f"GPU: {gpu_name}")
        print(f"GPU Memory: {total_memory_gb:.1f} GB total")
        allocated = torch.cuda.memory_allocated(0) / (1024**3)
        reserved = torch.cuda.memory_reserved(0) / (1024**3)
        free = total_memory_gb - allocated
        print(f"GPU Memory Status: {allocated:.2f} GB allocated, {reserved:.2f} GB reserved, {free:.2f} GB free")
        from packaging import version
        torch_version = version.parse(torch.__version__)
        if torch_version >= version.parse("2.9.0"):
            tf32_matmul = getattr(torch.backends.cuda.matmul, 'fp32_precision', 'unknown')
            tf32_conv = getattr(torch.backends.cudnn.conv, 'fp32_precision', 'unknown')
        else:
            tf32_matmul = 'tf32' if torch.backends.cuda.matmul.allow_tf32 else 'ieee'
            tf32_conv = 'tf32' if torch.backends.cudnn.allow_tf32 else 'ieee'
        cudnn_benchmark = torch.backends.cudnn.benchmark
        print(f"TF32 Matmul: {tf32_matmul}")
        print(f"TF32 Conv: {tf32_conv}")
        print(f"cuDNN Benchmark: {cudnn_benchmark}")
    compile_enabled = config.ENABLE_TORCH_COMPILE
    compile_mode = os.environ.get('TORCH_COMPILE_MODE', 'reduce-overhead') if compile_enabled else 'disabled'
    print(f"torch.compile: {compile_enabled} (mode: {compile_mode})")
    try:
        import monai
        print(f"MONAI: {monai.__version__}")
    except Exception:
        print("MONAI: NOT FOUND")
    try:
        import gradio
        print(f"Gradio: {gradio.__version__}")
    except Exception:
        print("Gradio: NOT FOUND")
    try:
        import nibabel
        print(f"NiBabel: {nibabel.__version__}")
    except Exception:
        print("NiBabel: NOT FOUND")
    print("\nCUDA Extensions Status:")
    try:
        import mamba_ssm
        try:
            mamba_version = mamba_ssm.__version__
            print(f" mamba_ssm: INSTALLED (version: {mamba_version})")
        except Exception:
            print(" mamba_ssm: INSTALLED")
    except ImportError:
        print(" mamba_ssm: NOT INSTALLED (will use fallback - CRITICAL for speed)")
    try:
        import selective_scan_cuda_oflex
        print(" selective_scan_cuda_oflex: INSTALLED")
    except ImportError:
        print(" selective_scan_cuda_oflex: NOT INSTALLED (will use fallback - CRITICAL for speed)")
    print("\nEnvironment Variables:")
    alloc_conf = os.environ.get('PYTORCH_ALLOC_CONF', 'not set')
    print(f" PYTORCH_ALLOC_CONF: {alloc_conf}")
    print(f" ENABLE_CUDNN_BENCHMARK: {config.ENABLE_CUDNN_BENCHMARK}")
    print(f" ENABLE_TORCH_COMPILE: {config.ENABLE_TORCH_COMPILE}")
    print(f" INFERENCE_TIMEOUT: {config.INFERENCE_TIMEOUT}s")
    print(f" MAX_GRADIO_CONCURRENCY: {config.MAX_GRADIO_CONCURRENCY}")
    print("=" * 60)


log_startup_health()
api_app = FastAPI(title="SRMA-Mamba API", version="1.0.0")


@api_app.middleware("http")
async def log_requests(request, call_next):
    path = str(request.url.path)
    print(f"[REQUEST] {request.method} {path}")
    content_type = request.headers.get('content-type', 'N/A')
    print(f"[REQUEST] Content-Type: {content_type}")
    if '/segment' in path:
        is_multipart = 'multipart' in content_type.lower()
        print(f"[REQUEST] Is multipart: {is_multipart}")
        if not is_multipart:
            print(f"[REQUEST] WARNING: Expected multipart/form-data but got: {content_type}")
            print(f"[REQUEST] Headers: {dict(request.headers)}")
    response = await call_next(request)
    print(f"[RESPONSE] {request.method} {path} -> {response.status_code}")
    return response


@api_app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    print(f"[ERROR] RequestValidationError on {request.method} {request.url.path}")
    content_type = request.headers.get('content-type', 'N/A')
    print(f"[ERROR] Content-Type: {content_type}")
    print(f"[ERROR] First error: {exc.errors()[0] if exc.errors() else 'No errors'}")
    error_detail = []
    for error in exc.errors():
        error_detail.append({
            "type": error.get("type"),
            "loc": error.get("loc"),
            "msg": error.get("msg")
        })
    return JSONResponse(
        status_code=422,
        content={
            "detail": error_detail,
            "message": "Request validation failed. Ensure Content-Type is multipart/form-data for file uploads.",
            "content_type_received": content_type
        }
    )


@api_app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    return JSONResponse(
        status_code=exc.status_code,
        content={"detail": exc.detail}
    )


allowed_origins = [
    "https://harshithreddy01.github.io",
    "https://harshithreddy01.github.io/frontend-SRMA-Liver",
    "https://app.paninsight.org",
    "http://localhost:5173",
    "http://localhost:3000",
    "http://localhost:8080",
]
api_app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_credentials=True,
    allow_methods=["GET", "POST", "OPTIONS", "HEAD"],
    allow_headers=["*"],
    expose_headers=["*"],
)
print(f"✓ CORS configured for origins: {allowed_origins}")
@api_app.get("/")
async def root():
    return {
        "message": "SRMA-Mamba Liver Segmentation API",
        "version": "1.0.0",
        "endpoints": {
            "/api/segment": "POST - Upload NIfTI file for segmentation",
            "/api/health": "GET - Health check",
            "/api/download/{token}": "GET - Download segmentation mask",
            "/docs": "API documentation"
        }
    }


@api_app.get("/api/health")
async def health_check():
    gpu_name = "unknown"
    gpu_memory_gb = 0.0
    if torch.cuda.is_available() and model_loader.DEVICE and model_loader.DEVICE.type == 'cuda':
        try:
            gpu_name = torch.cuda.get_device_name(0)
            gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        except Exception:
            pass
    return {
        "status": "healthy",
        "device": str(model_loader.DEVICE) if model_loader.DEVICE else "not initialized",
        "model_t1_loaded": model_loader.MODEL_T1 is not None,
        "model_t2_loaded": model_loader.MODEL_T2 is not None,
        "gpu_name": gpu_name,
        "gpu_memory_gb": round(gpu_memory_gb, 1)
    }
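

# Minimal client-side sketch (not used by the app): poll /api/health until the
# API reports itself healthy, e.g. before submitting work right after a Space
# restart. base_url, the retry settings, and the helper name are illustrative
# assumptions; `requests` is assumed to be installed on the client.
def _example_wait_until_healthy(base_url="http://localhost:7860", retries=12, delay_s=5):
    import time
    import requests

    for _ in range(retries):
        try:
            data = requests.get(f"{base_url}/api/health", timeout=10).json()
            if data.get("status") == "healthy":
                return data
        except requests.RequestException:
            pass
        time.sleep(delay_s)
    raise RuntimeError("API did not report healthy in time")
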
@api_app.get("/api/download/{token}")
async def download_mask(token: str):
    token_dir = "/tmp/seg_tokens"
    token_path = os.path.join(token_dir, f"{token}.nii.gz")
    if not os.path.exists(token_path):
        raise HTTPException(status_code=404, detail="Token not found or expired")
    from fastapi.responses import FileResponse
    return FileResponse(
        token_path,
        media_type="application/gzip",
        filename=f"liver_segmentation_{token[:8]}.nii.gz",
        headers={"X-Token": token}
    )


@api_app.post("/api/segment")
async def segment_liver(
    request: Request,
    file: UploadFile = File(..., description="NIfTI file to segment"),
    modality: str = Form("T1", description="MRI modality: T1 or T2"),
    slice_idx: Optional[int] = Form(None, description="Optional slice index")
):
    print("=" * 60)
    print("API REQUEST RECEIVED: /api/segment")
    print(f" Origin: {request.headers.get('origin', 'N/A')}")
    print(f" Content-Type: {request.headers.get('content-type', 'N/A')}")
    print(f" File: {file.filename if file else 'None'}")
    print(f" File Content-Type: {file.content_type if file else 'None'}")
    print(f" Modality: {modality}")
    print(f" Slice Index: {slice_idx}")
    print("=" * 60)
    if not file or not file.filename:
        raise HTTPException(status_code=400, detail="File is required")
    if not file.filename.endswith(('.nii', '.nii.gz', '.gz')):
        raise HTTPException(status_code=400, detail="File must be a NIfTI file (.nii or .nii.gz)")
    if modality not in ['T1', 'T2']:
        raise HTTPException(status_code=400, detail="Modality must be 'T1' or 'T2'")
    content = await file.read()
    file_size_mb = len(content) / (1024**2)
    print(f"API: File size: {file_size_mb:.2f} MB")
    if file_size_mb > 2000:
        raise HTTPException(
            status_code=413,
            detail=f"File too large: {file_size_mb:.1f} MB. Maximum upload size is 2 GB. Please compress or resample your NIfTI file."
        )
    with tempfile.NamedTemporaryFile(delete=False, suffix='.nii.gz') as tmp_file:
        tmp_file.write(content)
        tmp_path = tmp_file.name
    import secrets
    import shutil
    try:
        print(f"API: Starting inference for {modality} modality...")
        result = predict_volume_api(tmp_path, modality, slice_idx)
        print(f"API: Inference completed. Success: {result.get('success', False)}")
        if not result["success"]:
            raise HTTPException(status_code=500, detail=result.get("error", "Unknown error"))
        seg_path = result["segmentation_path"]
        seg_file_size = os.path.getsize(seg_path) / (1024**2)
        if seg_file_size > 2000:
            token = secrets.token_urlsafe(16)
            token_dir = "/tmp/seg_tokens"
            os.makedirs(token_dir, exist_ok=True)
            token_path = os.path.join(token_dir, f"{token}.nii.gz")
            shutil.copy2(seg_path, token_path)
            result["mask_path_token"] = token
            result["mask_download_url"] = f"/api/download/{token}"
            result["segmentation_file"] = None
            print(f"API: Large mask file ({seg_file_size:.1f} MB). Using token-based download: {token}")
        else:
            with open(seg_path, "rb") as seg_file:
                seg_data = seg_file.read()
            seg_base64 = base64.b64encode(seg_data).decode('utf-8')
            result["segmentation_file"] = f"data:application/octet-stream;base64,{seg_base64}"
            result["mask_path_token"] = None
            result["mask_download_url"] = None
        os.unlink(tmp_path)
        if seg_file_size <= 2000:
            os.unlink(seg_path)
        return JSONResponse(content=result)
    except Exception as e:
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
        raise HTTPException(status_code=500, detail=str(e))
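

# Minimal client-side sketch (not used by the app): upload a NIfTI volume to
# /api/segment and save the returned mask, handling both response shapes the
# endpoint above produces (inline base64 data URL for smaller masks, token-based
# download URL for very large ones). The helper name, file paths, and default
# base_url are illustrative assumptions; `requests` is assumed to be installed
# on the client.
def _example_segment_client(nifti_path, base_url="http://localhost:7860", out_path="liver_segmentation.nii.gz"):
    import requests

    with open(nifti_path, "rb") as f:
        resp = requests.post(
            f"{base_url}/api/segment",
            files={"file": (os.path.basename(nifti_path), f, "application/gzip")},
            data={"modality": "T1"},
            timeout=3600,
        )
    resp.raise_for_status()
    result = resp.json()
    if result.get("segmentation_file"):
        # Small masks come back inline as a base64 data URL.
        b64_payload = result["segmentation_file"].split(",", 1)[1]
        with open(out_path, "wb") as out:
            out.write(base64.b64decode(b64_payload))
    elif result.get("mask_download_url"):
        # Large masks are fetched separately via the token endpoint.
        mask = requests.get(base_url + result["mask_download_url"], timeout=600)
        mask.raise_for_status()
        with open(out_path, "wb") as out:
            out.write(mask.content)
    return result
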
def create_interface():
    with gr.Blocks(title="SRMA-Mamba: Liver Segmentation", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # SRMA-Mamba: Liver Segmentation in MRI Volumes

        Upload a 3D NIfTI MRI volume (.nii.gz) to perform automatic liver segmentation.

        **Model Performance:**
        - Pixel Accuracy: 99.09%
        - IoU: 75%
        - PSNR: 29.64 dB

        **Supported Modalities:** T1-weighted and T2-weighted MRI

        **API Available:** This Space also provides a REST API. See `/docs` for documentation.
        """)
        with gr.Row():
            with gr.Column():
                nifti_input = gr.File(
                    file_count="single",
                    file_types=[".nii.gz", ".nii"],
                    label="Upload NIfTI File (max 2 GB)"
                )
                modality = gr.Radio(
                    choices=["T1", "T2"],
                    value="T1",
                    label="MRI Modality"
                )
                slice_slider = gr.Slider(
                    minimum=0,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Slice Index (will auto-update based on volume)",
                    interactive=True
                )
                predict_btn = gr.Button("Segment Liver", variant="primary")
                reset_btn = gr.Button("Reset", variant="secondary")
            with gr.Column():
                output_image = gr.Image(
                    label="Segmentation Overlay",
                    type="pil"
                )
                output_info = gr.Markdown(
                    label="Statistics"
                )
                output_report = gr.Markdown(
                    label="Medical Report"
                )
                output_file = gr.File(
                    label="Download 3D Segmentation (.nii.gz)"
                )
        gr.Markdown("""
        ## Instructions
        1. Upload a 3D NIfTI MRI volume (.nii.gz format)
        2. Select the MRI modality (T1 or T2)
        3. Click "Segment Liver" to run inference
        4. View the segmentation overlay and download the 3D mask

        **Note:** First inference may take longer as the model loads.

        ## API Usage
        This Space provides a REST API for programmatic access:
        - **Endpoint**: `POST /api/segment`
        - **Documentation**: Visit `/docs` for interactive API docs
        - **Example**: Use from your frontend with `fetch()` or `axios()`; see the sketch below.
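
        A minimal Python example (illustrative sketch; replace the placeholder URL with this Space's own URL, and `volume.nii.gz` with your file):

        ```python
        import requests

        with open("volume.nii.gz", "rb") as f:
            r = requests.post(
                "https://<your-space-url>/api/segment",
                files={"file": ("volume.nii.gz", f, "application/gzip")},
                data={"modality": "T1"},
            )
        print(r.json().get("success"))
        ```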
| """) | |
| use_gpu = False | |
| predict_fn = safe_predict_volume | |
| if config.HAS_SPACES: | |
| import spaces | |
| if torch.cuda.is_available(): | |
| try: | |
| test_tensor = torch.zeros(1).cuda() | |
| del test_tensor | |
| torch.cuda.empty_cache() | |
| use_gpu = True | |
| gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "Unknown" | |
| total_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3) if torch.cuda.is_available() else 0 | |
| predict_fn = spaces.GPU(safe_predict_volume) | |
| print(f"β Using GPU acceleration: {gpu_name} ({total_vram:.1f} GB VRAM)") | |
| except ImportError: | |
| gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "Unknown" | |
| total_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3) if torch.cuda.is_available() else 0 | |
| predict_fn = safe_predict_volume | |
| print(f"β Using GPU acceleration: {gpu_name} ({total_vram:.1f} GB VRAM)") | |
| except Exception as e: | |
| print(f"β GPU available but failed to initialize: {e}") | |
| print("Falling back to CPU mode") | |
| use_gpu = False | |
| predict_fn = safe_predict_volume | |
| else: | |
| print("βΉ No GPU available. Running on CPU (slower but free)") | |
| predict_fn = safe_predict_volume | |
| else: | |
| if torch.cuda.is_available(): | |
| gpu_name = torch.cuda.get_device_name(0) | |
| total_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3) | |
| print(f"β CUDA available, using GPU: {gpu_name} ({total_vram:.1f} GB VRAM)") | |
| use_gpu = True | |
| else: | |
| print("βΉ Running on CPU") | |
| predict_fn = safe_predict_volume | |
| predict_btn.click( | |
| fn=predict_fn, | |
| inputs=[nifti_input, modality, slice_slider], | |
| outputs=[output_image, output_info, output_report, output_file] | |
| ) | |
| def update_slice_slider(file): | |
| if file is None: | |
| return gr.update(maximum=100, value=50) | |
| try: | |
| file_path = file.name if hasattr(file, 'name') else str(file) | |
| if not os.path.exists(file_path): | |
| return gr.update(maximum=100, value=50) | |
| volume = nib.load(file_path).get_fdata() | |
| max_slices = volume.shape[-1] if len(volume.shape) == 3 else volume.shape[2] | |
| return gr.update(maximum=max_slices - 1, value=max_slices // 2) | |
| except Exception as e: | |
| print(f"Error updating slice slider: {e}") | |
| return gr.update(maximum=100, value=50) | |
| def reset_interface(): | |
| import inference | |
| inference.PROCESSING_LOCK = False | |
| if model_loader.DEVICE and model_loader.DEVICE.type == 'cuda': | |
| torch.cuda.empty_cache() | |
| return None, None, "", None | |
| nifti_input.change( | |
| fn=update_slice_slider, | |
| inputs=[nifti_input], | |
| outputs=[slice_slider], | |
| show_progress=False | |
| ) | |
| reset_btn.click( | |
| fn=reset_interface, | |
| inputs=[], | |
| outputs=[output_image, output_info, output_report, output_file] | |
| ) | |
| return demo | |
demo = create_interface()
import config
demo.queue(max_size=config.MAX_GRADIO_CONCURRENCY)

app = api_app
if config.HAS_MOUNT_GRADIO_APP:
    from gradio.routes import mount_gradio_app
    print("Mounting Gradio interface on FastAPI app...")
    print("FastAPI routes registered before mounting:")
    for route in api_app.routes:
        if hasattr(route, 'path') and hasattr(route, 'methods'):
            print(f" - {list(route.methods)} {route.path}")
    print("\nMounting Gradio at path '/gradio' to avoid interfering with API routes")
    mount_gradio_app(app=api_app, blocks=demo, path="/gradio")
    print("\nFastAPI routes after mounting:")
    for route in api_app.routes:
        if hasattr(route, 'path') and hasattr(route, 'methods'):
            methods = list(route.methods) if hasattr(route, 'methods') else []
            if 'POST' in methods and '/segment' in route.path:
                print(f" ✓ {methods} {route.path} - File upload endpoint")
    print("\n✓ FastAPI app with Gradio mounted at /gradio")
    print("✓ API endpoints available at:")
    print(" - GET /api/health")
    print(" - POST /api/segment (multipart/form-data)")
    print(" - GET /api/download/{token}")
    print(" - GET /docs (FastAPI docs)")
    print(" - GET /gradio (Gradio UI - optional)")
else:
    print("⚠ Gradio mount not available. FastAPI will run without Gradio UI.")
    print("✓ API endpoints available at:")
    print(" - GET /api/health")
    print(" - POST /api/segment (multipart/form-data)")
    print(" - GET /api/download/{token}")
    print(" - GET /docs (FastAPI docs)")

print("\n✓ Production server: FastAPI (app = api_app)")
print("✓ ASGI app exported as 'app' for Uvicorn")

if __name__ == "__main__":
    import uvicorn
    import os
    port = int(os.getenv("PORT", 7860))
    print(f"\nStarting FastAPI server with Uvicorn on port {port}")
    print("This is for local development only.")
    print("In production, Hugging Face Spaces will use the 'app' variable directly.")
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")