Update models/depth_normal_pipeline_clip.py
models/depth_normal_pipeline_clip.py
CHANGED
@@ -79,7 +79,6 @@ class DepthNormalEstimationPipeline(DiffusionPipeline):
         match_input_res:bool =True,
         batch_size:int = 0,
         domain: str = "indoor",
-        #seed: int = 0,
         color_map: str="Spectral",
         show_progress_bar:bool = True,
         ensemble_kwargs: Dict = None,
@@ -148,7 +147,6 @@ class DepthNormalEstimationPipeline(DiffusionPipeline):
                 input_rgb=batched_image,
                 num_inference_steps=denoising_steps,
                 domain=domain,
-                #seed=seed,
                 show_pbar=show_progress_bar,
             )
             depth_pred_ls.append(depth_pred_raw.detach().clone())
@@ -232,7 +230,6 @@ class DepthNormalEstimationPipeline(DiffusionPipeline):
     def single_infer(self,input_rgb:torch.Tensor,
                      num_inference_steps:int,
                      domain:str,
-                     #seed: int,
                      show_pbar:bool,):
 
         device = input_rgb.device
@@ -244,9 +241,7 @@ class DepthNormalEstimationPipeline(DiffusionPipeline):
         # encode image
         rgb_latent = self.encode_RGB(input_rgb)
 
-        # Initial
-        #if seed >= 0:
-        #torch.manual_seed(0)
+        # Initial geometric maps (Gaussian noise)
         geo_latent = torch.randn(rgb_latent.shape, device=device, dtype=self.dtype).repeat(2,1,1,1)
         rgb_latent = rgb_latent.repeat(2,1,1,1)
 
@@ -258,7 +253,7 @@ class DepthNormalEstimationPipeline(DiffusionPipeline):
             (rgb_latent.shape[0], 1, 1)
         ) # [B, 1, 768]
 
-        # hybrid
+        # hybrid switcher
        geo_class = torch.tensor([[0., 1.], [1, 0]], device=device, dtype=self.dtype)
        geo_embedding = torch.cat([torch.sin(geo_class), torch.cos(geo_class)], dim=-1)
 
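The dropped seed plumbing was dead code in any case: the commented-out guard called `torch.manual_seed(0)` with a hardcoded zero, ignoring the `seed` argument and mutating global RNG state. If seeded, reproducible latents are wanted later, a minimal sketch of how they could be reinstated (hypothetical helper, not part of this commit) is to thread a local `torch.Generator` through `torch.randn` instead:

import torch

def make_geo_latent(rgb_latent: torch.Tensor, seed: int = -1) -> torch.Tensor:
    # Hypothetical helper: draw the initial geometric latent from Gaussian
    # noise, optionally seeded. A per-call torch.Generator keeps the global
    # RNG untouched, unlike the removed torch.manual_seed(0) (which also
    # ignored the seed that was passed in).
    generator = None
    if seed >= 0:
        generator = torch.Generator(device=rgb_latent.device).manual_seed(seed)
    noise = torch.randn(
        rgb_latent.shape,
        generator=generator,
        device=rgb_latent.device,
        dtype=rgb_latent.dtype,
    )
    # Duplicate along the batch dim as the pipeline does: one copy for the
    # depth branch and one for the normal branch.
    return noise.repeat(2, 1, 1, 1)

`single_infer` would then call `make_geo_latent(rgb_latent, seed)` in place of the bare `torch.randn(...).repeat(2,1,1,1)` line.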
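The renamed "hybrid switcher" comment labels the conditioning that follows it: `geo_class` holds one one-hot row per half of the doubled batch (depth and normal; the order is an assumption, the diff does not name it), and the sin/cos concatenation turns each row into a fixed 4-dim embedding the denoiser can use to tell the two tasks apart. A standalone sketch of just that computation:

import torch

# Standalone sketch of the hybrid switcher embedding shown in the diff above.
# Row order (depth vs. normal) is an assumption; the source does not name it.
geo_class = torch.tensor([[0., 1.], [1., 0.]])
geo_embedding = torch.cat([torch.sin(geo_class), torch.cos(geo_class)], dim=-1)
print(geo_embedding.shape)  # torch.Size([2, 4])

Because the embedding is a fixed function of the class tensor, it adds no learned parameters; each of the two noise latents created by `repeat(2,1,1,1)` gets its own row.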