Harshith Reddy
commited on
Commit
Β·
360f3b6
1
Parent(s):
8cd2cb8
Fix app mounting for Hugging Face Spaces: Export app at module level, add GPU wake-up retry logic, and fix indentation
Browse files
app.py
CHANGED
|
@@ -166,9 +166,27 @@ def load_model(modality='T1'):
|
|
| 166 |
|
| 167 |
if torch.cuda.is_available():
|
| 168 |
try:
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
except Exception as e:
|
| 173 |
print(f"β CUDA available but failed to initialize: {e}. Falling back to CPU.")
|
| 174 |
DEVICE = torch.device('cpu')
|
|
@@ -968,15 +986,14 @@ def create_interface():
|
|
| 968 |
|
| 969 |
return demo
|
| 970 |
|
| 971 |
-
|
| 972 |
-
|
| 973 |
-
|
| 974 |
-
|
| 975 |
-
|
| 976 |
-
|
| 977 |
-
|
| 978 |
-
|
| 979 |
-
except Exception as e:
|
| 980 |
print(f"β Failed to create interface: {e}")
|
| 981 |
import traceback
|
| 982 |
tb = traceback.format_exc()
|
|
@@ -1007,145 +1024,157 @@ if __name__ == "__main__":
|
|
| 1007 |
print(f"β Even error interface failed: {e2}")
|
| 1008 |
raise
|
| 1009 |
|
| 1010 |
-
|
| 1011 |
-
|
| 1012 |
-
|
| 1013 |
-
|
| 1014 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1015 |
|
| 1016 |
-
|
| 1017 |
-
|
| 1018 |
-
|
| 1019 |
-
|
| 1020 |
-
|
| 1021 |
-
|
| 1022 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1023 |
|
| 1024 |
try:
|
| 1025 |
-
|
| 1026 |
-
|
| 1027 |
-
|
| 1028 |
-
|
| 1029 |
-
|
| 1030 |
-
if not file:
|
| 1031 |
-
return JSONResponse({"error": "No file provided"}, status_code=400, headers={"Access-Control-Allow-Origin": "*"})
|
| 1032 |
-
|
| 1033 |
-
if modality not in ['T1', 'T2']:
|
| 1034 |
-
return JSONResponse({"error": "Modality must be 'T1' or 'T2'"}, status_code=400, headers={"Access-Control-Allow-Origin": "*"})
|
| 1035 |
-
|
| 1036 |
-
with tempfile.NamedTemporaryFile(delete=False, suffix='.nii.gz') as tmp_file:
|
| 1037 |
-
content = await file.read()
|
| 1038 |
-
tmp_file.write(content)
|
| 1039 |
-
tmp_path = tmp_file.name
|
| 1040 |
-
|
| 1041 |
-
try:
|
| 1042 |
-
result = predict_volume_api(tmp_path, modality, int(slice_idx) if slice_idx else None)
|
| 1043 |
-
if not result["success"]:
|
| 1044 |
-
error_detail = result.get("error", "Unknown error")
|
| 1045 |
-
print(f"API prediction failed: {error_detail}")
|
| 1046 |
-
return JSONResponse({
|
| 1047 |
-
"success": False,
|
| 1048 |
-
"error": error_detail,
|
| 1049 |
-
"error_type": "prediction_failed"
|
| 1050 |
-
}, status_code=500, headers={"Access-Control-Allow-Origin": "*"})
|
| 1051 |
-
|
| 1052 |
-
with open(result["segmentation_path"], "rb") as seg_file:
|
| 1053 |
-
seg_data = seg_file.read()
|
| 1054 |
-
seg_base64 = base64.b64encode(seg_data).decode('utf-8')
|
| 1055 |
-
|
| 1056 |
-
result["segmentation_file"] = f"data:application/octet-stream;base64,{seg_base64}"
|
| 1057 |
-
os.unlink(tmp_path)
|
| 1058 |
-
os.unlink(result["segmentation_path"])
|
| 1059 |
-
|
| 1060 |
-
return JSONResponse(result, headers={"Access-Control-Allow-Origin": "*"})
|
| 1061 |
-
except Exception as e:
|
| 1062 |
-
error_msg = str(e)
|
| 1063 |
-
import traceback
|
| 1064 |
-
tb = traceback.format_exc()
|
| 1065 |
-
print(f"Exception in handle_segment: {error_msg}")
|
| 1066 |
-
print(f"Traceback: {tb}")
|
| 1067 |
-
if os.path.exists(tmp_path):
|
| 1068 |
-
os.unlink(tmp_path)
|
| 1069 |
return JSONResponse({
|
| 1070 |
"success": False,
|
| 1071 |
-
"error":
|
| 1072 |
-
"error_type": "
|
| 1073 |
-
"traceback": tb if "traceback" in str(type(e)) else None
|
| 1074 |
}, status_code=500, headers={"Access-Control-Allow-Origin": "*"})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1075 |
except Exception as e:
|
| 1076 |
error_msg = str(e)
|
| 1077 |
import traceback
|
| 1078 |
tb = traceback.format_exc()
|
| 1079 |
-
print(f"Exception in handle_segment
|
| 1080 |
print(f"Traceback: {tb}")
|
|
|
|
|
|
|
| 1081 |
return JSONResponse({
|
| 1082 |
"success": False,
|
| 1083 |
"error": error_msg,
|
| 1084 |
-
"error_type": "
|
|
|
|
| 1085 |
}, status_code=500, headers={"Access-Control-Allow-Origin": "*"})
|
| 1086 |
-
|
| 1087 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1088 |
return JSONResponse({
|
| 1089 |
-
"
|
| 1090 |
-
"
|
| 1091 |
-
"
|
| 1092 |
-
|
| 1093 |
-
|
| 1094 |
-
|
| 1095 |
-
|
| 1096 |
-
|
| 1097 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1098 |
try:
|
| 1099 |
-
demo.app.
|
| 1100 |
-
print("
|
| 1101 |
-
|
| 1102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1103 |
try:
|
| 1104 |
-
|
| 1105 |
-
print(f"Existing routes: {existing_routes[:5]}...")
|
| 1106 |
-
|
| 1107 |
-
routes_to_add = [
|
| 1108 |
starlette.routing.Route("/api/segment", handle_segment, methods=["POST", "OPTIONS"]),
|
| 1109 |
starlette.routing.Route("/api/health", handle_health, methods=["GET", "OPTIONS"])
|
| 1110 |
-
]
|
| 1111 |
-
|
| 1112 |
-
|
| 1113 |
-
|
| 1114 |
-
|
| 1115 |
-
|
| 1116 |
-
|
| 1117 |
-
|
| 1118 |
-
|
| 1119 |
-
|
| 1120 |
-
|
| 1121 |
-
|
| 1122 |
-
|
| 1123 |
-
|
| 1124 |
-
|
| 1125 |
-
|
| 1126 |
-
|
| 1127 |
-
|
| 1128 |
-
|
| 1129 |
-
|
| 1130 |
-
|
| 1131 |
-
|
| 1132 |
-
|
| 1133 |
-
|
| 1134 |
-
|
| 1135 |
-
|
| 1136 |
-
|
| 1137 |
-
|
| 1138 |
-
|
| 1139 |
-
|
| 1140 |
-
|
| 1141 |
-
|
| 1142 |
-
|
| 1143 |
-
|
| 1144 |
-
|
| 1145 |
-
|
| 1146 |
-
|
| 1147 |
-
|
| 1148 |
-
print(
|
| 1149 |
-
|
| 1150 |
-
traceback.print_exc()
|
| 1151 |
-
raise
|
|
|
|
| 166 |
|
| 167 |
if torch.cuda.is_available():
|
| 168 |
try:
|
| 169 |
+
import time
|
| 170 |
+
max_retries = 3
|
| 171 |
+
retry_delay = 2
|
| 172 |
+
|
| 173 |
+
for attempt in range(max_retries):
|
| 174 |
+
try:
|
| 175 |
+
torch.cuda.empty_cache()
|
| 176 |
+
test_tensor = torch.zeros(1).cuda()
|
| 177 |
+
del test_tensor
|
| 178 |
+
torch.cuda.synchronize()
|
| 179 |
+
DEVICE = torch.device('cuda')
|
| 180 |
+
print(f"β Using device: {DEVICE}")
|
| 181 |
+
break
|
| 182 |
+
except RuntimeError as e:
|
| 183 |
+
if "CUDA" in str(e) and attempt < max_retries - 1:
|
| 184 |
+
print(f"β GPU wake-up attempt {attempt + 1}/{max_retries}: {e}")
|
| 185 |
+
print(f"β Waiting {retry_delay}s for GPU to wake up...")
|
| 186 |
+
time.sleep(retry_delay)
|
| 187 |
+
retry_delay *= 2
|
| 188 |
+
else:
|
| 189 |
+
raise
|
| 190 |
except Exception as e:
|
| 191 |
print(f"β CUDA available but failed to initialize: {e}. Falling back to CPU.")
|
| 192 |
DEVICE = torch.device('cpu')
|
|
|
|
| 986 |
|
| 987 |
return demo
|
| 988 |
|
| 989 |
+
print("=" * 60)
|
| 990 |
+
print("Initializing SRMA-Mamba Liver Segmentation Space...")
|
| 991 |
+
print("=" * 60)
|
| 992 |
+
|
| 993 |
+
try:
|
| 994 |
+
demo = create_interface()
|
| 995 |
+
print("β Gradio interface created successfully")
|
| 996 |
+
except Exception as e:
|
|
|
|
| 997 |
print(f"β Failed to create interface: {e}")
|
| 998 |
import traceback
|
| 999 |
tb = traceback.format_exc()
|
|
|
|
| 1024 |
print(f"β Even error interface failed: {e2}")
|
| 1025 |
raise
|
| 1026 |
|
| 1027 |
+
print("βΉ Models will be loaded on first inference")
|
| 1028 |
+
print("=" * 60)
|
| 1029 |
+
|
| 1030 |
+
from starlette.responses import JSONResponse
|
| 1031 |
+
from starlette.requests import Request
|
| 1032 |
+
|
| 1033 |
+
async def handle_segment(request: Request):
|
| 1034 |
+
if request.method == "OPTIONS":
|
| 1035 |
+
return JSONResponse({}, headers={
|
| 1036 |
+
"Access-Control-Allow-Origin": "*",
|
| 1037 |
+
"Access-Control-Allow-Methods": "POST, OPTIONS",
|
| 1038 |
+
"Access-Control-Allow-Headers": "*"
|
| 1039 |
+
})
|
| 1040 |
|
| 1041 |
+
try:
|
| 1042 |
+
form = await request.form()
|
| 1043 |
+
file = form.get("file")
|
| 1044 |
+
modality = form.get("modality", "T1")
|
| 1045 |
+
slice_idx = form.get("slice_idx")
|
| 1046 |
+
|
| 1047 |
+
if not file:
|
| 1048 |
+
return JSONResponse({"error": "No file provided"}, status_code=400, headers={"Access-Control-Allow-Origin": "*"})
|
| 1049 |
+
|
| 1050 |
+
if modality not in ['T1', 'T2']:
|
| 1051 |
+
return JSONResponse({"error": "Modality must be 'T1' or 'T2'"}, status_code=400, headers={"Access-Control-Allow-Origin": "*"})
|
| 1052 |
+
|
| 1053 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix='.nii.gz') as tmp_file:
|
| 1054 |
+
content = await file.read()
|
| 1055 |
+
tmp_file.write(content)
|
| 1056 |
+
tmp_path = tmp_file.name
|
| 1057 |
|
| 1058 |
try:
|
| 1059 |
+
result = predict_volume_api(tmp_path, modality, int(slice_idx) if slice_idx else None)
|
| 1060 |
+
if not result["success"]:
|
| 1061 |
+
error_detail = result.get("error", "Unknown error")
|
| 1062 |
+
print(f"API prediction failed: {error_detail}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1063 |
return JSONResponse({
|
| 1064 |
"success": False,
|
| 1065 |
+
"error": error_detail,
|
| 1066 |
+
"error_type": "prediction_failed"
|
|
|
|
| 1067 |
}, status_code=500, headers={"Access-Control-Allow-Origin": "*"})
|
| 1068 |
+
|
| 1069 |
+
with open(result["segmentation_path"], "rb") as seg_file:
|
| 1070 |
+
seg_data = seg_file.read()
|
| 1071 |
+
seg_base64 = base64.b64encode(seg_data).decode('utf-8')
|
| 1072 |
+
|
| 1073 |
+
result["segmentation_file"] = f"data:application/octet-stream;base64,{seg_base64}"
|
| 1074 |
+
os.unlink(tmp_path)
|
| 1075 |
+
os.unlink(result["segmentation_path"])
|
| 1076 |
+
|
| 1077 |
+
return JSONResponse(result, headers={"Access-Control-Allow-Origin": "*"})
|
| 1078 |
except Exception as e:
|
| 1079 |
error_msg = str(e)
|
| 1080 |
import traceback
|
| 1081 |
tb = traceback.format_exc()
|
| 1082 |
+
print(f"Exception in handle_segment: {error_msg}")
|
| 1083 |
print(f"Traceback: {tb}")
|
| 1084 |
+
if os.path.exists(tmp_path):
|
| 1085 |
+
os.unlink(tmp_path)
|
| 1086 |
return JSONResponse({
|
| 1087 |
"success": False,
|
| 1088 |
"error": error_msg,
|
| 1089 |
+
"error_type": "exception",
|
| 1090 |
+
"traceback": tb if "traceback" in str(type(e)) else None
|
| 1091 |
}, status_code=500, headers={"Access-Control-Allow-Origin": "*"})
|
| 1092 |
+
except Exception as e:
|
| 1093 |
+
error_msg = str(e)
|
| 1094 |
+
import traceback
|
| 1095 |
+
tb = traceback.format_exc()
|
| 1096 |
+
print(f"Exception in handle_segment (outer): {error_msg}")
|
| 1097 |
+
print(f"Traceback: {tb}")
|
| 1098 |
return JSONResponse({
|
| 1099 |
+
"success": False,
|
| 1100 |
+
"error": error_msg,
|
| 1101 |
+
"error_type": "outer_exception"
|
| 1102 |
+
}, status_code=500, headers={"Access-Control-Allow-Origin": "*"})
|
| 1103 |
+
|
| 1104 |
+
async def handle_health(request: Request):
|
| 1105 |
+
return JSONResponse({
|
| 1106 |
+
"status": "healthy",
|
| 1107 |
+
"device": str(DEVICE) if DEVICE else "not initialized",
|
| 1108 |
+
"model_t1_loaded": MODEL_T1 is not None,
|
| 1109 |
+
"model_t2_loaded": MODEL_T2 is not None
|
| 1110 |
+
}, headers={"Access-Control-Allow-Origin": "*"})
|
| 1111 |
+
|
| 1112 |
+
import starlette.routing
|
| 1113 |
+
|
| 1114 |
+
def add_api_routes():
|
| 1115 |
+
try:
|
| 1116 |
+
demo.app.mount("/api", api_app)
|
| 1117 |
+
print("β FastAPI app mounted at /api")
|
| 1118 |
+
except Exception as mount_error:
|
| 1119 |
+
print(f"Mounting failed: {mount_error}, trying direct routes...")
|
| 1120 |
try:
|
| 1121 |
+
existing_routes = [str(r) for r in demo.app.routes]
|
| 1122 |
+
print(f"Existing routes: {existing_routes[:5]}...")
|
| 1123 |
+
|
| 1124 |
+
routes_to_add = [
|
| 1125 |
+
starlette.routing.Route("/api/segment", handle_segment, methods=["POST", "OPTIONS"]),
|
| 1126 |
+
starlette.routing.Route("/api/health", handle_health, methods=["GET", "OPTIONS"])
|
| 1127 |
+
]
|
| 1128 |
+
|
| 1129 |
+
if hasattr(demo.app, 'router') and hasattr(demo.app.router, 'routes'):
|
| 1130 |
+
demo.app.router.routes.extend(routes_to_add)
|
| 1131 |
+
print("β API routes added to router.routes")
|
| 1132 |
+
else:
|
| 1133 |
+
demo.app.routes.extend(routes_to_add)
|
| 1134 |
+
print("β API routes added to app.routes")
|
| 1135 |
+
except Exception as route_error:
|
| 1136 |
+
print(f"Router routes failed: {route_error}, trying app.routes...")
|
| 1137 |
try:
|
| 1138 |
+
demo.app.routes.extend([
|
|
|
|
|
|
|
|
|
|
| 1139 |
starlette.routing.Route("/api/segment", handle_segment, methods=["POST", "OPTIONS"]),
|
| 1140 |
starlette.routing.Route("/api/health", handle_health, methods=["GET", "OPTIONS"])
|
| 1141 |
+
])
|
| 1142 |
+
print("β API routes added to app.routes")
|
| 1143 |
+
except Exception as e2:
|
| 1144 |
+
print(f"All route addition methods failed: {e2}")
|
| 1145 |
+
import traceback
|
| 1146 |
+
traceback.print_exc()
|
| 1147 |
+
|
| 1148 |
+
try:
|
| 1149 |
+
add_api_routes()
|
| 1150 |
+
print("β API routes configured")
|
| 1151 |
+
except Exception as route_err:
|
| 1152 |
+
print(f"β API route configuration failed (non-critical): {route_err}")
|
| 1153 |
+
import traceback
|
| 1154 |
+
traceback.print_exc()
|
| 1155 |
+
|
| 1156 |
+
print("β App initialization complete")
|
| 1157 |
+
print("=" * 60)
|
| 1158 |
+
|
| 1159 |
+
is_spaces = bool(os.getenv("SPACE_ID") or os.getenv("SYSTEM") == "spaces" or os.getenv("HF_SPACE"))
|
| 1160 |
+
|
| 1161 |
+
app = demo
|
| 1162 |
+
|
| 1163 |
+
if __name__ == "__main__":
|
| 1164 |
+
if not is_spaces:
|
| 1165 |
+
print("π Running locally - launching with demo.launch()")
|
| 1166 |
+
try:
|
| 1167 |
+
demo.launch(
|
| 1168 |
+
server_name="0.0.0.0",
|
| 1169 |
+
server_port=7860,
|
| 1170 |
+
share=False,
|
| 1171 |
+
show_error=True
|
| 1172 |
+
)
|
| 1173 |
+
except Exception as launch_err:
|
| 1174 |
+
print(f"β Failed to launch interface: {launch_err}")
|
| 1175 |
+
import traceback
|
| 1176 |
+
traceback.print_exc()
|
| 1177 |
+
raise
|
| 1178 |
+
else:
|
| 1179 |
+
print("π Running on Hugging Face Spaces - app exported for Spaces to serve")
|
| 1180 |
+
print("β App ready for Spaces to serve")
|
|
|
|
|
|