Harshith Reddy committed
Commit f9a3113 · 1 parent: 5b260e5

Fix UnboundLocalError and Triton cache permission error

Files changed (2):
  1. app.py +8 -0
  2. inference.py +9 -2
app.py CHANGED
@@ -3,6 +3,14 @@ import os
 if 'PYTORCH_ALLOC_CONF' not in os.environ:
     os.environ['PYTORCH_ALLOC_CONF'] = 'expandable_segments:True,max_split_size_mb=128'
 
+if 'TRITON_CACHE_DIR' not in os.environ:
+    triton_cache = '/tmp/.triton'
+    os.environ['TRITON_CACHE_DIR'] = triton_cache
+    try:
+        os.makedirs(triton_cache, exist_ok=True)
+    except:
+        pass
+
 import tempfile
 import base64
 import torch
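
For context: Triton compiles GPU kernels at runtime and caches the artifacts on disk, by default under the user's home directory (~/.triton). On hosts where the home directory is not writable, as is common in containerized Space deployments, kernel compilation fails with a permission error unless TRITON_CACHE_DIR points at a writable location; the commit sets the variable before importing torch so it is already visible whenever Triton first resolves its cache path. A minimal sketch of the same pattern, with a fallback added for illustration (ensure_triton_cache and the mkdtemp() fallback are not part of the commit):

import os
import tempfile

def ensure_triton_cache(preferred='/tmp/.triton'):
    # Respect an explicit setting from the environment.
    if 'TRITON_CACHE_DIR' in os.environ:
        return os.environ['TRITON_CACHE_DIR']
    # Try the preferred path first; fall back to a fresh temp
    # directory if it cannot be created.
    try:
        os.makedirs(preferred, exist_ok=True)
        cache_dir = preferred
    except OSError:
        cache_dir = tempfile.mkdtemp(prefix='triton-cache-')
    os.environ['TRITON_CACHE_DIR'] = cache_dir
    return cache_dir

ensure_triton_cache()
import torch  # noqa: E402 -- import after the cache dir is set

One nit on the committed version: the bare except: silently swallows any failure of os.makedirs, leaving TRITON_CACHE_DIR pointing at a path that may not exist. Catching OSError and falling back, as in the sketch above, keeps the intent while failing less quietly; with exist_ok=True the call only raises when /tmp itself is unwritable.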
inference.py CHANGED
@@ -148,6 +148,8 @@ def predict_volume(nifti_file, modality, slice_idx=None):
     print("Running inference...")
     inference_start = time.time()
 
+    y1, y2, y3, y4 = None, None, None, None
+
     gpu_handle = None
     if model_loader.DEVICE.type == 'cuda':
         allocated = torch.cuda.memory_allocated(0) / (1024**3)
@@ -370,8 +372,13 @@ def predict_volume(nifti_file, modality, slice_idx=None):
     finally:
         timeout_timer.cancel()
 
-    print(f" → Raw model output shapes: y1={y1.shape if hasattr(y1, 'shape') else 'N/A'}, y2={y2.shape if hasattr(y2, 'shape') else 'N/A'}")
-    sys.stdout.flush()
+    if y1 is not None:
+        print(f" → Raw model output shapes: y1={y1.shape if hasattr(y1, 'shape') else 'N/A'}, y2={y2.shape if hasattr(y2, 'shape') else 'N/A'}")
+        sys.stdout.flush()
+    else:
+        print(" → Inference failed before model output was generated")
+        sys.stdout.flush()
+        raise RuntimeError("Inference failed: model did not produce output")
 
     pred = y1
     if isinstance(pred, torch.Tensor):
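
For context: the UnboundLocalError arises because y1..y4 are assigned inside the try block. If inference raises before the assignment (a CUDA OOM, or the timeout firing), the logging line that references y1.shape touches a local that was never bound, so the real error is masked by an UnboundLocalError at the print. Pre-initializing the variables to None and checking y1 is not None converts the crash into an explicit RuntimeError. A minimal sketch of the failure mode and the fix, where run_model() is a stand-in for the actual inference call:

def before(run_model):
    try:
        y1, y2 = run_model()   # raises before y1/y2 are bound...
    except Exception as e:
        print(f"inference error: {e}")
    print(y1.shape)            # ...so this line is UnboundLocalError

def after(run_model):
    y1, y2 = None, None        # bound up front, so always defined
    try:
        y1, y2 = run_model()
    except Exception as e:
        print(f"inference error: {e}")
    if y1 is None:             # fail loudly with a clear message
        raise RuntimeError("Inference failed: model did not produce output")
    print(y1.shape)

Raising RuntimeError here means callers see the actual failure ("model did not produce output") instead of a misleading UnboundLocalError at the logging line, and the downstream pred = y1 path only runs when there is real output to process.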