Harshith Reddy commited on
Commit
360f3b6
Β·
1 Parent(s): 8cd2cb8

Fix app mounting for Hugging Face Spaces: Export app at module level, add GPU wake-up retry logic, and fix indentation

Browse files
Files changed (1) hide show
  1. app.py +163 -134
app.py CHANGED
@@ -166,9 +166,27 @@ def load_model(modality='T1'):
166
 
167
  if torch.cuda.is_available():
168
  try:
169
- torch.cuda.empty_cache()
170
- DEVICE = torch.device('cuda')
171
- print(f"βœ“ Using device: {DEVICE}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  except Exception as e:
173
  print(f"⚠ CUDA available but failed to initialize: {e}. Falling back to CPU.")
174
  DEVICE = torch.device('cpu')
@@ -968,15 +986,14 @@ def create_interface():
968
 
969
  return demo
970
 
971
- if __name__ == "__main__":
972
- print("=" * 60)
973
- print("Initializing SRMA-Mamba Liver Segmentation Space...")
974
- print("=" * 60)
975
-
976
- try:
977
- demo = create_interface()
978
- print("βœ“ Gradio interface created successfully")
979
- except Exception as e:
980
  print(f"βœ— Failed to create interface: {e}")
981
  import traceback
982
  tb = traceback.format_exc()
@@ -1007,145 +1024,157 @@ if __name__ == "__main__":
1007
  print(f"βœ— Even error interface failed: {e2}")
1008
  raise
1009
 
1010
- print("β„Ή Models will be loaded on first inference")
1011
- print("=" * 60)
1012
-
1013
- from starlette.responses import JSONResponse
1014
- from starlette.requests import Request
 
 
 
 
 
 
 
 
1015
 
1016
- async def handle_segment(request: Request):
1017
- if request.method == "OPTIONS":
1018
- return JSONResponse({}, headers={
1019
- "Access-Control-Allow-Origin": "*",
1020
- "Access-Control-Allow-Methods": "POST, OPTIONS",
1021
- "Access-Control-Allow-Headers": "*"
1022
- })
 
 
 
 
 
 
 
 
 
1023
 
1024
  try:
1025
- form = await request.form()
1026
- file = form.get("file")
1027
- modality = form.get("modality", "T1")
1028
- slice_idx = form.get("slice_idx")
1029
-
1030
- if not file:
1031
- return JSONResponse({"error": "No file provided"}, status_code=400, headers={"Access-Control-Allow-Origin": "*"})
1032
-
1033
- if modality not in ['T1', 'T2']:
1034
- return JSONResponse({"error": "Modality must be 'T1' or 'T2'"}, status_code=400, headers={"Access-Control-Allow-Origin": "*"})
1035
-
1036
- with tempfile.NamedTemporaryFile(delete=False, suffix='.nii.gz') as tmp_file:
1037
- content = await file.read()
1038
- tmp_file.write(content)
1039
- tmp_path = tmp_file.name
1040
-
1041
- try:
1042
- result = predict_volume_api(tmp_path, modality, int(slice_idx) if slice_idx else None)
1043
- if not result["success"]:
1044
- error_detail = result.get("error", "Unknown error")
1045
- print(f"API prediction failed: {error_detail}")
1046
- return JSONResponse({
1047
- "success": False,
1048
- "error": error_detail,
1049
- "error_type": "prediction_failed"
1050
- }, status_code=500, headers={"Access-Control-Allow-Origin": "*"})
1051
-
1052
- with open(result["segmentation_path"], "rb") as seg_file:
1053
- seg_data = seg_file.read()
1054
- seg_base64 = base64.b64encode(seg_data).decode('utf-8')
1055
-
1056
- result["segmentation_file"] = f"data:application/octet-stream;base64,{seg_base64}"
1057
- os.unlink(tmp_path)
1058
- os.unlink(result["segmentation_path"])
1059
-
1060
- return JSONResponse(result, headers={"Access-Control-Allow-Origin": "*"})
1061
- except Exception as e:
1062
- error_msg = str(e)
1063
- import traceback
1064
- tb = traceback.format_exc()
1065
- print(f"Exception in handle_segment: {error_msg}")
1066
- print(f"Traceback: {tb}")
1067
- if os.path.exists(tmp_path):
1068
- os.unlink(tmp_path)
1069
  return JSONResponse({
1070
  "success": False,
1071
- "error": error_msg,
1072
- "error_type": "exception",
1073
- "traceback": tb if "traceback" in str(type(e)) else None
1074
  }, status_code=500, headers={"Access-Control-Allow-Origin": "*"})
 
 
 
 
 
 
 
 
 
 
1075
  except Exception as e:
1076
  error_msg = str(e)
1077
  import traceback
1078
  tb = traceback.format_exc()
1079
- print(f"Exception in handle_segment (outer): {error_msg}")
1080
  print(f"Traceback: {tb}")
 
 
1081
  return JSONResponse({
1082
  "success": False,
1083
  "error": error_msg,
1084
- "error_type": "outer_exception"
 
1085
  }, status_code=500, headers={"Access-Control-Allow-Origin": "*"})
1086
-
1087
- async def handle_health(request: Request):
 
 
 
 
1088
  return JSONResponse({
1089
- "status": "healthy",
1090
- "device": str(DEVICE) if DEVICE else "not initialized",
1091
- "model_t1_loaded": MODEL_T1 is not None,
1092
- "model_t2_loaded": MODEL_T2 is not None
1093
- }, headers={"Access-Control-Allow-Origin": "*"})
1094
-
1095
- import starlette.routing
1096
-
1097
- def add_api_routes():
 
 
 
 
 
 
 
 
 
 
 
 
1098
  try:
1099
- demo.app.mount("/api", api_app)
1100
- print("βœ“ FastAPI app mounted at /api")
1101
- except Exception as mount_error:
1102
- print(f"Mounting failed: {mount_error}, trying direct routes...")
 
 
 
 
 
 
 
 
 
 
 
 
1103
  try:
1104
- existing_routes = [str(r) for r in demo.app.routes]
1105
- print(f"Existing routes: {existing_routes[:5]}...")
1106
-
1107
- routes_to_add = [
1108
  starlette.routing.Route("/api/segment", handle_segment, methods=["POST", "OPTIONS"]),
1109
  starlette.routing.Route("/api/health", handle_health, methods=["GET", "OPTIONS"])
1110
- ]
1111
-
1112
- if hasattr(demo.app, 'router') and hasattr(demo.app.router, 'routes'):
1113
- demo.app.router.routes.extend(routes_to_add)
1114
- print("βœ“ API routes added to router.routes")
1115
- else:
1116
- demo.app.routes.extend(routes_to_add)
1117
- print("βœ“ API routes added to app.routes")
1118
- except Exception as route_error:
1119
- print(f"Router routes failed: {route_error}, trying app.routes...")
1120
- try:
1121
- demo.app.routes.extend([
1122
- starlette.routing.Route("/api/segment", handle_segment, methods=["POST", "OPTIONS"]),
1123
- starlette.routing.Route("/api/health", handle_health, methods=["GET", "OPTIONS"])
1124
- ])
1125
- print("βœ“ API routes added to app.routes")
1126
- except Exception as e2:
1127
- print(f"All route addition methods failed: {e2}")
1128
- import traceback
1129
- traceback.print_exc()
1130
-
1131
- try:
1132
- add_api_routes()
1133
- print("βœ“ API routes configured")
1134
- except Exception as route_err:
1135
- print(f"⚠ API route configuration failed (non-critical): {route_err}")
1136
- import traceback
1137
- traceback.print_exc()
1138
-
1139
- try:
1140
- print("πŸš€ Launching Gradio interface...")
1141
- demo.launch(
1142
- server_name="0.0.0.0",
1143
- server_port=7860,
1144
- share=False,
1145
- show_error=True
1146
- )
1147
- except Exception as launch_err:
1148
- print(f"βœ— Failed to launch interface: {launch_err}")
1149
- import traceback
1150
- traceback.print_exc()
1151
- raise
 
166
 
167
  if torch.cuda.is_available():
168
  try:
169
+ import time
170
+ max_retries = 3
171
+ retry_delay = 2
172
+
173
+ for attempt in range(max_retries):
174
+ try:
175
+ torch.cuda.empty_cache()
176
+ test_tensor = torch.zeros(1).cuda()
177
+ del test_tensor
178
+ torch.cuda.synchronize()
179
+ DEVICE = torch.device('cuda')
180
+ print(f"βœ“ Using device: {DEVICE}")
181
+ break
182
+ except RuntimeError as e:
183
+ if "CUDA" in str(e) and attempt < max_retries - 1:
184
+ print(f"⚠ GPU wake-up attempt {attempt + 1}/{max_retries}: {e}")
185
+ print(f"⚠ Waiting {retry_delay}s for GPU to wake up...")
186
+ time.sleep(retry_delay)
187
+ retry_delay *= 2
188
+ else:
189
+ raise
190
  except Exception as e:
191
  print(f"⚠ CUDA available but failed to initialize: {e}. Falling back to CPU.")
192
  DEVICE = torch.device('cpu')
 
986
 
987
  return demo
988
 
989
+ print("=" * 60)
990
+ print("Initializing SRMA-Mamba Liver Segmentation Space...")
991
+ print("=" * 60)
992
+
993
+ try:
994
+ demo = create_interface()
995
+ print("βœ“ Gradio interface created successfully")
996
+ except Exception as e:
 
997
  print(f"βœ— Failed to create interface: {e}")
998
  import traceback
999
  tb = traceback.format_exc()
 
1024
  print(f"βœ— Even error interface failed: {e2}")
1025
  raise
1026
 
1027
+ print("β„Ή Models will be loaded on first inference")
1028
+ print("=" * 60)
1029
+
1030
+ from starlette.responses import JSONResponse
1031
+ from starlette.requests import Request
1032
+
1033
+ async def handle_segment(request: Request):
1034
+ if request.method == "OPTIONS":
1035
+ return JSONResponse({}, headers={
1036
+ "Access-Control-Allow-Origin": "*",
1037
+ "Access-Control-Allow-Methods": "POST, OPTIONS",
1038
+ "Access-Control-Allow-Headers": "*"
1039
+ })
1040
 
1041
+ try:
1042
+ form = await request.form()
1043
+ file = form.get("file")
1044
+ modality = form.get("modality", "T1")
1045
+ slice_idx = form.get("slice_idx")
1046
+
1047
+ if not file:
1048
+ return JSONResponse({"error": "No file provided"}, status_code=400, headers={"Access-Control-Allow-Origin": "*"})
1049
+
1050
+ if modality not in ['T1', 'T2']:
1051
+ return JSONResponse({"error": "Modality must be 'T1' or 'T2'"}, status_code=400, headers={"Access-Control-Allow-Origin": "*"})
1052
+
1053
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.nii.gz') as tmp_file:
1054
+ content = await file.read()
1055
+ tmp_file.write(content)
1056
+ tmp_path = tmp_file.name
1057
 
1058
  try:
1059
+ result = predict_volume_api(tmp_path, modality, int(slice_idx) if slice_idx else None)
1060
+ if not result["success"]:
1061
+ error_detail = result.get("error", "Unknown error")
1062
+ print(f"API prediction failed: {error_detail}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1063
  return JSONResponse({
1064
  "success": False,
1065
+ "error": error_detail,
1066
+ "error_type": "prediction_failed"
 
1067
  }, status_code=500, headers={"Access-Control-Allow-Origin": "*"})
1068
+
1069
+ with open(result["segmentation_path"], "rb") as seg_file:
1070
+ seg_data = seg_file.read()
1071
+ seg_base64 = base64.b64encode(seg_data).decode('utf-8')
1072
+
1073
+ result["segmentation_file"] = f"data:application/octet-stream;base64,{seg_base64}"
1074
+ os.unlink(tmp_path)
1075
+ os.unlink(result["segmentation_path"])
1076
+
1077
+ return JSONResponse(result, headers={"Access-Control-Allow-Origin": "*"})
1078
  except Exception as e:
1079
  error_msg = str(e)
1080
  import traceback
1081
  tb = traceback.format_exc()
1082
+ print(f"Exception in handle_segment: {error_msg}")
1083
  print(f"Traceback: {tb}")
1084
+ if os.path.exists(tmp_path):
1085
+ os.unlink(tmp_path)
1086
  return JSONResponse({
1087
  "success": False,
1088
  "error": error_msg,
1089
+ "error_type": "exception",
1090
+ "traceback": tb if "traceback" in str(type(e)) else None
1091
  }, status_code=500, headers={"Access-Control-Allow-Origin": "*"})
1092
+ except Exception as e:
1093
+ error_msg = str(e)
1094
+ import traceback
1095
+ tb = traceback.format_exc()
1096
+ print(f"Exception in handle_segment (outer): {error_msg}")
1097
+ print(f"Traceback: {tb}")
1098
  return JSONResponse({
1099
+ "success": False,
1100
+ "error": error_msg,
1101
+ "error_type": "outer_exception"
1102
+ }, status_code=500, headers={"Access-Control-Allow-Origin": "*"})
1103
+
1104
+ async def handle_health(request: Request):
1105
+ return JSONResponse({
1106
+ "status": "healthy",
1107
+ "device": str(DEVICE) if DEVICE else "not initialized",
1108
+ "model_t1_loaded": MODEL_T1 is not None,
1109
+ "model_t2_loaded": MODEL_T2 is not None
1110
+ }, headers={"Access-Control-Allow-Origin": "*"})
1111
+
1112
+ import starlette.routing
1113
+
1114
+ def add_api_routes():
1115
+ try:
1116
+ demo.app.mount("/api", api_app)
1117
+ print("βœ“ FastAPI app mounted at /api")
1118
+ except Exception as mount_error:
1119
+ print(f"Mounting failed: {mount_error}, trying direct routes...")
1120
  try:
1121
+ existing_routes = [str(r) for r in demo.app.routes]
1122
+ print(f"Existing routes: {existing_routes[:5]}...")
1123
+
1124
+ routes_to_add = [
1125
+ starlette.routing.Route("/api/segment", handle_segment, methods=["POST", "OPTIONS"]),
1126
+ starlette.routing.Route("/api/health", handle_health, methods=["GET", "OPTIONS"])
1127
+ ]
1128
+
1129
+ if hasattr(demo.app, 'router') and hasattr(demo.app.router, 'routes'):
1130
+ demo.app.router.routes.extend(routes_to_add)
1131
+ print("βœ“ API routes added to router.routes")
1132
+ else:
1133
+ demo.app.routes.extend(routes_to_add)
1134
+ print("βœ“ API routes added to app.routes")
1135
+ except Exception as route_error:
1136
+ print(f"Router routes failed: {route_error}, trying app.routes...")
1137
  try:
1138
+ demo.app.routes.extend([
 
 
 
1139
  starlette.routing.Route("/api/segment", handle_segment, methods=["POST", "OPTIONS"]),
1140
  starlette.routing.Route("/api/health", handle_health, methods=["GET", "OPTIONS"])
1141
+ ])
1142
+ print("βœ“ API routes added to app.routes")
1143
+ except Exception as e2:
1144
+ print(f"All route addition methods failed: {e2}")
1145
+ import traceback
1146
+ traceback.print_exc()
1147
+
1148
+ try:
1149
+ add_api_routes()
1150
+ print("βœ“ API routes configured")
1151
+ except Exception as route_err:
1152
+ print(f"⚠ API route configuration failed (non-critical): {route_err}")
1153
+ import traceback
1154
+ traceback.print_exc()
1155
+
1156
+ print("βœ“ App initialization complete")
1157
+ print("=" * 60)
1158
+
1159
+ is_spaces = bool(os.getenv("SPACE_ID") or os.getenv("SYSTEM") == "spaces" or os.getenv("HF_SPACE"))
1160
+
1161
+ app = demo
1162
+
1163
+ if __name__ == "__main__":
1164
+ if not is_spaces:
1165
+ print("πŸš€ Running locally - launching with demo.launch()")
1166
+ try:
1167
+ demo.launch(
1168
+ server_name="0.0.0.0",
1169
+ server_port=7860,
1170
+ share=False,
1171
+ show_error=True
1172
+ )
1173
+ except Exception as launch_err:
1174
+ print(f"βœ— Failed to launch interface: {launch_err}")
1175
+ import traceback
1176
+ traceback.print_exc()
1177
+ raise
1178
+ else:
1179
+ print("🌐 Running on Hugging Face Spaces - app exported for Spaces to serve")
1180
+ print("βœ“ App ready for Spaces to serve")