# da2/model/vit_w_esphere.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .base import (
fourier_dimension_expansion,
flatten,
DimensionAligner,
AttentionSeq,
ResidualUpsampler
)
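
# Depth/distance prediction head that conditions multi-scale ViT features on
# per-pixel unit sphere ("E-sphere") directions via cross-attention, then
# decodes distance coarse-to-fine.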
class _ViT_w_Esphere(nn.Module):
def __init__(
self,
hidden_dim: int,
num_heads: int = 8,
expansion: int = 4,
num_layers_head: int | list[int] = 4,
dropout: float = 0.0,
kernel_size: int = 7,
layer_scale: float = 1.0,
out_dim: int = 1,
num_prompt_blocks: int = 1,
use_norm: bool = False,
**kwargs,
) -> None:
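        # `num_layers_head` gives the block depth of each residual upsampling
        # stage (a bare int is broadcast below); `num_prompt_blocks` sets the
        # depth of the per-scale cross-attention conditioning on the sphere
        # directions.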
super().__init__()
self.out_dim = out_dim
self.hidden_dim = hidden_dim
self.up_sampler = nn.ModuleList([])
self.pred_head = nn.ModuleList([])
self.process_features = nn.ModuleList([])
self.prompt_camera = nn.ModuleList([])
mult = 2
self.to_latents = nn.Linear(hidden_dim, hidden_dim)
for _ in range(4):
self.prompt_camera.append(
AttentionSeq(
num_blocks=num_prompt_blocks,
dim=hidden_dim,
num_heads=num_heads,
expansion=expansion,
dropout=dropout,
layer_scale=-1.0,
context_dim=hidden_dim,
)
)
        if isinstance(num_layers_head, int):
            # Assumed broadcast: three decoding stages, matching the three
            # fused feature scales (features[0] seeds the latents in
            # `process`; features[1:4] are fused one per stage).
            num_layers_head = [num_layers_head] * 3
        for i, depth in enumerate(num_layers_head):
current_dim = min(hidden_dim, mult * hidden_dim // int(2**i))
next_dim = mult * hidden_dim // int(2 ** (i + 1))
output_dim = max(next_dim, out_dim)
self.process_features.append(
nn.ConvTranspose2d(
hidden_dim,
current_dim,
kernel_size=max(1, 2 * i),
stride=max(1, 2 * i),
padding=0,
)
)
self.up_sampler.append(
ResidualUpsampler(
current_dim,
output_dim=output_dim,
expansion=expansion,
layer_scale=layer_scale,
kernel_size=kernel_size,
num_layers=depth,
use_norm=use_norm,
)
)
pred_head = (
nn.Sequential(nn.LayerNorm(next_dim), nn.Linear(next_dim, output_dim))
if i == len(num_layers_head) - 1
else nn.Identity()
)
self.pred_head.append(pred_head)
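        # Output projection heads. `output_dim` here is the value left over
        # from the final loop iteration above; the confidence branches are
        # built but never called in `prediction_head` below.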
self.to_depth_lr = nn.Conv2d(
output_dim,
output_dim // 2,
kernel_size=3,
padding=1,
padding_mode='reflect',
)
self.to_confidence_lr = nn.Conv2d(
output_dim,
output_dim // 2,
kernel_size=3,
padding=1,
padding_mode='reflect',
)
self.to_depth_hr = nn.Sequential(
nn.Conv2d(
output_dim // 2, 32, kernel_size=3, padding=1, padding_mode='reflect'
),
nn.LeakyReLU(),
nn.Conv2d(32, 1, kernel_size=1),
)
self.to_confidence_hr = nn.Sequential(
nn.Conv2d(
output_dim // 2, 32, kernel_size=3, padding=1, padding_mode='reflect'
),
nn.LeakyReLU(),
nn.Conv2d(32, 1, kernel_size=1),
)
def set_original_shapes(self, shapes: tuple[int, int]):
self.original_shapes = shapes
def set_shapes(self, shapes: tuple[int, int]):
self.shapes = shapes
def embed_sphere_dirs(self, sphere_dirs):
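        # Resample the unit-direction map from the original image resolution
        # to the feature resolution, convert it to a (polar, azimuth) angle
        # field, and lift it to `hidden_dim` channels with a Fourier basis.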
sphere_embedding = flatten(
sphere_dirs, old=self.original_shapes, new=self.shapes
)
        # channel order: index 0 -> Y, index 1 -> Z, index 2 -> X
        r1, r2, r3 = sphere_embedding.unbind(dim=-1)
polar = torch.asin(r2)
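        # clamp |r3| away from zero while preserving its sign, so atan2 never
        # sees the degenerate (0, 0) input at the poles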
r3_clipped = r3.abs().clip(min=1e-5) * (2 * (r3 >= 0).int() - 1)
azimuth = torch.atan2(r1, r3_clipped)
# [polar, azimuth] is the angle field
sphere_embedding = torch.stack([polar, azimuth], dim=-1)
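        # e.g. a direction along +X (r1 = r2 = 0, r3 = 1) maps to
        # (polar, azimuth) = (0, 0); one along +Z maps to (pi/2, 0)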
        # lift the 2-channel angle field to the image feature dimension with a
        # sine basis (cosine terms are disabled via use_cos=False)
sphere_embedding = fourier_dimension_expansion(
sphere_embedding,
dim=self.hidden_dim,
max_freq=max(self.shapes) // 2,
use_cos=False,
)
return sphere_embedding
def condition(self, feat, sphere_embeddings):
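        # each scale's image tokens (queries) cross-attend to the sphere
        # embedding (context), injecting directional geometry into the features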
conditioned_features = [
prompter(rearrange(feature, 'b h w c -> b (h w) c'), sphere_embeddings)
for prompter, feature in zip(self.prompt_camera, feat)
]
return conditioned_features
def process(self, features_list, sphere_embeddings):
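        # condition every scale on the sphere embedding, seed the latents from
        # the first scale, then fuse and upsample coarse-to-fine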
conditioned_features = self.condition(features_list, sphere_embeddings)
init_latents = self.to_latents(conditioned_features[0])
init_latents = rearrange(
init_latents, 'b (h w) c -> b c h w', h=self.shapes[0], w=self.shapes[1]
).contiguous()
conditioned_features = [
rearrange(
x, 'b (h w) c -> b c h w', h=self.shapes[0], w=self.shapes[1]
).contiguous()
for x in conditioned_features
]
latents = init_latents
out_features = []
        # coarse-to-fine pyramid: fuse the next conditioned scale into the
        # latents, then upsample
for i, up in enumerate(self.up_sampler):
latents = latents + self.process_features[i](conditioned_features[i + 1])
latents = up(latents)
out_features.append(latents)
return out_features
    def prediction_head(self, out_features):
        # Only the final scale carries a real head (LayerNorm + Linear); the
        # earlier `nn.Identity` entries in `self.pred_head` are placeholders,
        # so decoding uses the last (highest-resolution) feature map alone.
        h_out, w_out = out_features[-1].shape[-2:]
        out_depth_features = self.pred_head[-1](
            out_features[-1].permute(0, 2, 3, 1)
        ).permute(0, 3, 1, 2)
        out_depth_features = F.interpolate(
            out_depth_features,
            size=(h_out, w_out),
            mode='bilinear',
            align_corners=True,
        )
        distance = self.to_depth_lr(out_depth_features)
        distance = F.interpolate(
            distance, size=self.original_shapes, mode='bilinear', align_corners=True
        )
        distance = self.to_depth_hr(distance)
        return distance
def forward(
self,
features: list[torch.Tensor],
sphere_dirs: torch.Tensor
) -> torch.Tensor:
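        # `features`: aligned feature maps in 'b h w c' layout;
        # `sphere_dirs`: per-pixel unit directions flattened to 'b (h w) c'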
sphere_embeddings = self.embed_sphere_dirs(sphere_dirs)
features = self.process(features, sphere_embeddings)
distance = self.prediction_head(features)
return distance
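
# Public wrapper: aligns backbone feature channels to `hidden_dim` with
# `DimensionAligner`, then runs the E-sphere head and normalizes the output.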
class ViT_w_Esphere(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.dim_aligner = DimensionAligner(
input_dims=config['input_dims'],
hidden_dim=config['hidden_dim'],
)
self._vit_w_esphere = _ViT_w_Esphere(**config)
def forward(self, images, features, sphere_dirs) -> torch.Tensor:
_, _, H, W = images.shape
common_shape = features[0].shape[1:3]
features = self.dim_aligner(features)
sphere_dirs = rearrange(sphere_dirs, 'b c h w -> b (h w) c')
self._vit_w_esphere.set_shapes(common_shape)
self._vit_w_esphere.set_original_shapes((H, W))
logdistance = self._vit_w_esphere(
features=features,
sphere_dirs=sphere_dirs,
)
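        # clip the log-distance for numerical stability, exponentiate, and
        # rescale so the 98th percentile of the prediction maps to 1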
distance = torch.exp(logdistance.clip(min=-8.0, max=8.0) + 2.0)
distance = distance / torch.quantile(distance, 0.98)
return distance
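

if __name__ == '__main__':
    # Minimal standalone smoke test of the angle-field math above. This block
    # is an editorial sketch, not part of the released model: it re-derives
    # the (polar, azimuth) computation from `embed_sphere_dirs` with plain
    # torch so it can be checked in isolation. Run as a module
    # (`python -m da2.model.vit_w_esphere`) because of the relative imports.
    dirs = torch.tensor([
        [0.0, 0.0, 1.0],  # +X (straight ahead) -> polar 0,    azimuth 0
        [1.0, 0.0, 0.0],  # +Y                  -> polar 0,    azimuth pi/2
        [0.0, 1.0, 0.0],  # +Z (polar axis)     -> polar pi/2, azimuth 0
    ])
    r1, r2, r3 = dirs.unbind(dim=-1)
    polar = torch.asin(r2)
    r3_safe = r3.abs().clip(min=1e-5) * (2 * (r3 >= 0).int() - 1)
    azimuth = torch.atan2(r1, r3_safe)
    print(torch.stack([polar, azimuth], dim=-1))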