import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from .base import (
    fourier_dimension_expansion,
    flatten,
    DimensionAligner,
    AttentionSeq,
    ResidualUpsampler,
)


class _ViT_w_Esphere(nn.Module):
    def __init__(
        self,
        hidden_dim: int,
        num_heads: int = 8,
        expansion: int = 4,
        num_layers_head: int | list[int] = 4,
        dropout: float = 0.0,
        kernel_size: int = 7,
        layer_scale: float = 1.0,
        out_dim: int = 1,
        num_prompt_blocks: int = 1,
        use_norm: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        self.up_sampler = nn.ModuleList([])
        self.pred_head = nn.ModuleList([])
        self.process_features = nn.ModuleList([])
        self.prompt_camera = nn.ModuleList([])
        mult = 2
        self.to_latents = nn.Linear(hidden_dim, hidden_dim)
        # Normalize the int form of `num_layers_head` to a per-stage list.
        # Three decoder stages are assumed, matching the four conditioning
        # blocks below (one for the initial latents plus one per stage).
        if isinstance(num_layers_head, int):
            num_layers_head = [num_layers_head] * 3
        # One camera-conditioning attention block per input feature map.
        for _ in range(4):
            self.prompt_camera.append(
                AttentionSeq(
                    num_blocks=num_prompt_blocks,
                    dim=hidden_dim,
                    num_heads=num_heads,
                    expansion=expansion,
                    dropout=dropout,
                    layer_scale=-1.0,
                    context_dim=hidden_dim,
                )
            )
        # Per-stage channel schedule: dimensions halve as resolution grows.
        for i, depth in enumerate(num_layers_head):
            current_dim = min(hidden_dim, mult * hidden_dim // int(2**i))
            next_dim = mult * hidden_dim // int(2 ** (i + 1))
            output_dim = max(next_dim, out_dim)
            # Transposed conv lifts the i-th conditioned feature map to the
            # resolution of the current decoder stage.
            self.process_features.append(
                nn.ConvTranspose2d(
                    hidden_dim,
                    current_dim,
                    kernel_size=max(1, 2 * i),
                    stride=max(1, 2 * i),
                    padding=0,
                )
            )
            self.up_sampler.append(
                ResidualUpsampler(
                    current_dim,
                    output_dim=output_dim,
                    expansion=expansion,
                    layer_scale=layer_scale,
                    kernel_size=kernel_size,
                    num_layers=depth,
                    use_norm=use_norm,
                )
            )
            # Only the last stage gets a real projection head; earlier stages
            # pass through unchanged.
            pred_head = (
                nn.Sequential(nn.LayerNorm(next_dim), nn.Linear(next_dim, output_dim))
                if i == len(num_layers_head) - 1
                else nn.Identity()
            )
            self.pred_head.append(pred_head)
        # `output_dim` here still holds the value from the final decoder stage.
        self.to_depth_lr = nn.Conv2d(
            output_dim,
            output_dim // 2,
            kernel_size=3,
            padding=1,
            padding_mode='reflect',
        )
        self.to_confidence_lr = nn.Conv2d(
            output_dim,
            output_dim // 2,
            kernel_size=3,
            padding=1,
            padding_mode='reflect',
        )
        self.to_depth_hr = nn.Sequential(
            nn.Conv2d(
                output_dim // 2, 32, kernel_size=3, padding=1, padding_mode='reflect'
            ),
            nn.LeakyReLU(),
            nn.Conv2d(32, 1, kernel_size=1),
        )
        self.to_confidence_hr = nn.Sequential(
            nn.Conv2d(
                output_dim // 2, 32, kernel_size=3, padding=1, padding_mode='reflect'
            ),
            nn.LeakyReLU(),
            nn.Conv2d(32, 1, kernel_size=1),
        )

    def set_original_shapes(self, shapes: tuple[int, int]):
        self.original_shapes = shapes

    def set_shapes(self, shapes: tuple[int, int]):
        self.shapes = shapes

    def embed_sphere_dirs(self, sphere_dirs):
        sphere_embedding = flatten(
            sphere_dirs, old=self.original_shapes, new=self.shapes
        )
        # Component layout: index 0 -> Y, index 1 -> Z, index 2 -> X.
        r1, r2, r3 = (
            sphere_embedding[..., 0],
            sphere_embedding[..., 1],
            sphere_embedding[..., 2],
        )
        polar = torch.asin(r2)
        # Clamp |r3| away from zero while preserving its sign, so atan2 stays
        # numerically stable near the poles.
        r3_clipped = r3.abs().clip(min=1e-5) * (2 * (r3 >= 0).int() - 1)
        azimuth = torch.atan2(r1, r3_clipped)
        # [polar, azimuth] is the angle field.
        sphere_embedding = torch.stack([polar, azimuth], dim=-1)
        # Lift the 2-channel angle field to the image feature dimension via a
        # Fourier basis (sine-only, since use_cos=False).
        sphere_embedding = fourier_dimension_expansion(
            sphere_embedding,
            dim=self.hidden_dim,
            max_freq=max(self.shapes) // 2,
            use_cos=False,
        )
        return sphere_embedding

    def condition(self, feat, sphere_embeddings):
        # Cross-attend each feature map to the spherical camera embedding.
        conditioned_features = [
            prompter(rearrange(feature, 'b h w c -> b (h w) c'), sphere_embeddings)
            for prompter, feature in zip(self.prompt_camera, feat)
        ]
        return conditioned_features

    def process(self, features_list, sphere_embeddings):
        conditioned_features = self.condition(features_list, sphere_embeddings)
        # The first conditioned feature map seeds the decoder latents; the
        # remaining ones are injected stage by stage below.
        init_latents = self.to_latents(conditioned_features[0])
        init_latents = rearrange(
            init_latents, 'b (h w) c -> b c h w', h=self.shapes[0], w=self.shapes[1]
        ).contiguous()
        conditioned_features = [
            rearrange(
                x, 'b (h w) c -> b c h w', h=self.shapes[0], w=self.shapes[1]
            ).contiguous()
            for x in conditioned_features
        ]
        latents = init_latents
        out_features = []
        # Pyramid-like multi-stage convolutional decoding: add the matching
        # conditioned feature map, then upsample.
        for i, up in enumerate(self.up_sampler):
            latents = latents + self.process_features[i](conditioned_features[i + 1])
            latents = up(latents)
            out_features.append(latents)
        return out_features

    def prediction_head(self, out_features):
        h_out, w_out = out_features[-1].shape[-2:]
        # Only the last entry of `pred_head` is a real projection (the rest
        # are nn.Identity), so only the final feature map needs decoding; the
        # projection is applied channels-last.
        out_depth_features = (
            self.pred_head[-1](out_features[-1].permute(0, 2, 3, 1))
            .permute(0, 3, 1, 2)
        )
        out_depth_features = F.interpolate(
            out_depth_features, size=(h_out, w_out), mode='bilinear', align_corners=True
        )
        distance = self.to_depth_lr(out_depth_features)
        distance = F.interpolate(
            distance, size=self.original_shapes, mode='bilinear', align_corners=True
        )
        distance = self.to_depth_hr(distance)
        return distance

    def forward(
        self,
        features: list[torch.Tensor],
        sphere_dirs: torch.Tensor,
    ) -> torch.Tensor:
        sphere_embeddings = self.embed_sphere_dirs(sphere_dirs)
        features = self.process(features, sphere_embeddings)
        distance = self.prediction_head(features)
        return distance


class ViT_w_Esphere(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.dim_aligner = DimensionAligner(
            input_dims=config['input_dims'],
            hidden_dim=config['hidden_dim'],
        )
        self._vit_w_esphere = _ViT_w_Esphere(**config)

    def forward(self, images, features, sphere_dirs) -> torch.Tensor:
        _, _, H, W = images.shape
        common_shape = features[0].shape[1:3]
        features = self.dim_aligner(features)
        sphere_dirs = rearrange(sphere_dirs, 'b c h w -> b (h w) c')
        self._vit_w_esphere.set_shapes(common_shape)
        self._vit_w_esphere.set_original_shapes((H, W))
        logdistance = self._vit_w_esphere(
            features=features,
            sphere_dirs=sphere_dirs,
        )
        # The head predicts log-distance; clip for numerical stability before
        # exponentiating, then normalize by the 98th-percentile distance.
        distance = torch.exp(logdistance.clip(min=-8.0, max=8.0) + 2.0)
        distance = distance / torch.quantile(distance, 0.98)
        return distance
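

# --- Hypothetical usage sketch (not part of the original source) ---
# A minimal smoke test, assuming the `.base` helpers behave as the code above
# implies: DimensionAligner maps each backbone feature map to a channels-last
# (B, h, w, hidden_dim) tensor, `flatten` resamples the per-pixel rays from
# the image resolution down to the feature resolution, and each decoder stage
# doubles spatial resolution. The config keys and shapes below are illustrative
# assumptions, not the project's actual configuration.
if __name__ == "__main__":
    config = {
        "input_dims": [1024, 1024, 1024, 1024],  # assumed per-level ViT dims
        "hidden_dim": 256,
        "num_layers_head": [2, 2, 2],  # three decoder stages, four feature maps
        "out_dim": 64,
    }
    model = ViT_w_Esphere(config)
    B, H, W, h, w = 1, 224, 224, 16, 16
    images = torch.randn(B, 3, H, W)
    # Four channels-last backbone feature maps at the shared patch resolution.
    features = [torch.randn(B, h, w, d) for d in config["input_dims"]]
    # Unit ray directions per pixel, (B, 3, H, W).
    sphere_dirs = F.normalize(torch.randn(B, 3, H, W), dim=1)
    with torch.no_grad():
        distance = model(images, features, sphere_dirs)
    print(distance.shape)  # expected: (B, 1, H, W), quantile-normalized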