haodongli commited on
Commit
bc88909
·
verified ·
1 Parent(s): 19e6f4e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -9
app.py CHANGED
@@ -27,13 +27,9 @@ from datetime import (
27
  import cv2
28
  import numpy as np
29
 
30
- last_glb_path = None
31
-
32
  def prepare_to_run_demo():
33
  config = load_config('configs/infer.json')
34
  kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=config['accelerator']['timeout']))
35
- output_dir = f'output/infer'
36
- if not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True)
37
  accu_steps = config['accelerator']['accumulation_nsteps']
38
  accelerator = Accelerator(
39
  gradient_accumulation_steps=accu_steps,
@@ -70,7 +66,7 @@ def ply2glb(ply_path, glb_path):
70
  os.remove(ply_path)
71
 
72
  def fn(image_path, mask_path):
73
- global last_glb_path
74
  config, accelerator = prepare_to_run_demo()
75
  model = load_model(config, accelerator)
76
  image, cv2_image, mask = load_infer_data_demo(image_path, mask_path,
@@ -81,11 +77,8 @@ def fn(image_path, mask_path):
81
  autocast_ctx = torch.autocast(accelerator.device.type)
82
  with autocast_ctx, torch.no_grad():
83
  distance = model(image).cpu().numpy()[0]
84
- if last_glb_path is not None:
85
- os.remove(last_glb_path)
86
  distance_vis = colorize_distance(distance, mask)
87
- save_path = f'cache/tmp_{datetime.now().strftime("%Y%m%d_%H%M%S")}.glb'
88
- last_glb_path = save_path
89
  normal_image = distance2pointcloud(distance, cv2_image, mask, save_path=save_path.replace('.glb', '.ply'), return_normal=True, save_distance=False)
90
  ply2glb(save_path.replace('.glb', '.ply'), save_path)
91
  return save_path, [distance_vis, normal_image]
 
27
  import cv2
28
  import numpy as np
29
 
 
 
30
  def prepare_to_run_demo():
31
  config = load_config('configs/infer.json')
32
  kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=config['accelerator']['timeout']))
 
 
33
  accu_steps = config['accelerator']['accumulation_nsteps']
34
  accelerator = Accelerator(
35
  gradient_accumulation_steps=accu_steps,
 
66
  os.remove(ply_path)
67
 
68
  def fn(image_path, mask_path):
69
+ name_base, _ = os.path.splitext(os.path.basename(image_path))
70
  config, accelerator = prepare_to_run_demo()
71
  model = load_model(config, accelerator)
72
  image, cv2_image, mask = load_infer_data_demo(image_path, mask_path,
 
77
  autocast_ctx = torch.autocast(accelerator.device.type)
78
  with autocast_ctx, torch.no_grad():
79
  distance = model(image).cpu().numpy()[0]
 
 
80
  distance_vis = colorize_distance(distance, mask)
81
+ save_path = f'files/cache/{name_base}.glb'
 
82
  normal_image = distance2pointcloud(distance, cv2_image, mask, save_path=save_path.replace('.glb', '.ply'), return_normal=True, save_distance=False)
83
  ply2glb(save_path.replace('.glb', '.ply'), save_path)
84
  return save_path, [distance_vis, normal_image]