# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
import matplotlib.pyplot as plt

from easydict import EasyDict as edict
from einops import rearrange
from sklearn.cluster import SpectralClustering

from spatracker.blocks import (
    Lie,
    BasicEncoder,
    CorrBlock,
    EUpdateFormer,
    FusionFormer,
    pix2cam,
    cam2pix,
    edgeMat,
    VitEncoder,
    DPTEnc,
    DPT_DINOv2,
    Dinov2,
)
from spatracker.feature_net import LocalSoftSplat
from spatracker.model_utils import (
    meshgrid2d, bilinear_sample2d, smart_cat, sample_features5d, vis_PCA
)
from spatracker.embeddings import (
    get_2d_embedding,
    get_3d_embedding,
    get_1d_sincos_pos_embed_from_grid,
    get_2d_sincos_pos_embed,
    get_3d_sincos_pos_embed_from_grid,
    Embedder_Fourier,
)
from spatracker.softsplat import softsplat

torch.manual_seed(0)
def get_points_on_a_grid(grid_size, interp_shape,
                         grid_center=(0, 0), device="cuda"):
    """Generate a regular grid_size x grid_size grid of [x, y] query
    points over an interp_shape = (H, W) image, keeping a small border
    margin."""
    if grid_size == 1:
        return torch.tensor([interp_shape[1] / 2,
                             interp_shape[0] / 2], device=device)[
            None, None
        ]

    grid_y, grid_x = meshgrid2d(
        1, grid_size, grid_size, stack=False, norm=False, device=device
    )
    step = interp_shape[1] // 64
    if grid_center[0] != 0 or grid_center[1] != 0:
        grid_y = grid_y - grid_size / 2.0
        grid_x = grid_x - grid_size / 2.0
    grid_y = step + grid_y.reshape(1, -1) / float(grid_size - 1) * (
        interp_shape[0] - step * 2
    )
    grid_x = step + grid_x.reshape(1, -1) / float(grid_size - 1) * (
        interp_shape[1] - step * 2
    )

    grid_y = grid_y + grid_center[0]
    grid_x = grid_x + grid_center[1]
    xy = torch.stack([grid_x, grid_y], dim=-1).to(device)
    return xy
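

# A minimal usage sketch (not part of the original file): sample a 4 x 4
# grid of query points over a 64 x 96 map on CPU. The concrete sizes are
# illustrative; shapes follow from the function above.
def _demo_grid_points():
    xy = get_points_on_a_grid(4, (64, 96), device="cpu")
    assert xy.shape == (1, 16, 2)  # (1, grid_size**2, [x, y])
    return xy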
def sample_pos_embed(grid_size, embed_dim, coords):
    """Sample sin-cos positional embeddings at the given coordinates;
    coords is (B, S, N, 2) for 2D points or (B, S, N, 3) for 3D points."""
    if coords.shape[-1] == 2:
        pos_embed = get_2d_sincos_pos_embed(embed_dim=embed_dim,
                                            grid_size=grid_size)
        pos_embed = (
            torch.from_numpy(pos_embed)
            .reshape(grid_size[0], grid_size[1], embed_dim)
            .float()
            .unsqueeze(0)
            .to(coords.device)
        )
        sampled_pos_embed = bilinear_sample2d(
            pos_embed.permute(0, 3, 1, 2),
            coords[:, 0, :, 0], coords[:, 0, :, 1]
        )
    elif coords.shape[-1] == 3:
        sampled_pos_embed = get_3d_sincos_pos_embed_from_grid(
            embed_dim, coords[:, :1, ...]
        ).float()[:, 0, ...].permute(0, 2, 1)
    return sampled_pos_embed
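

# A minimal usage sketch (not part of the original file): sample the 2D
# branch at 8 random points on a 16 x 16 embedding grid. The output
# shape noted below is what bilinear_sample2d is expected to return.
def _demo_sample_pos_embed():
    coords = torch.rand(1, 1, 8, 2) * 15  # (B, S, N, 2) in grid units
    emb = sample_pos_embed((16, 16), 64, coords)
    return emb  # expected (1, 64, 8): one embedding per query point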
class FeatureExtractor(nn.Module):
    def __init__(
        self,
        S=8,
        stride=8,
        add_space_attn=True,
        num_heads=8,
        hidden_size=384,
        space_depth=12,
        time_depth=12,
        depth_extend_margin=0.2,
        args=edict({})
    ):
        super(FeatureExtractor, self).__init__()

        # step1: configure the architecture of the model
        self.args = args
        # step1.1: set the default values of the model arguments
        if getattr(args, "depth_color", None) is None:
            self.args.depth_color = False
        if getattr(args, "if_ARAP", None) is None:
            self.args.if_ARAP = True
        if getattr(args, "flash_attn", None) is None:
            self.args.flash_attn = True
        if getattr(args, "backbone", None) is None:
            self.args.backbone = "CNN"
        if getattr(args, "Nblock", None) is None:
            self.args.Nblock = 0
        if getattr(args, "Embed3D", None) is None:
            self.args.Embed3D = True

        # step1.2: set the model hyper-parameters
        self.S = S
        self.stride = stride
        self.hidden_dim = 256
        self.latent_dim = latent_dim = 128
        self.b_latent_dim = self.latent_dim // 3
        self.corr_levels = 4
        self.corr_radius = 3
        self.add_space_attn = add_space_attn
        self.lie = Lie()
        self.depth_extend_margin = depth_extend_margin

        # step2: build the model components
        # @Encoder
        self.fnet = BasicEncoder(
            input_dim=3, output_dim=self.latent_dim, norm_fn="instance",
            dropout=0, stride=stride, Embed3D=False
        )

        # conv heads for the tri-plane features
        self.headyz = nn.Sequential(
            nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1))

        self.headxz = nn.Sequential(
            nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.latent_dim, self.latent_dim, 3, padding=1))

        # @UpdateFormer
        self.updateformer = EUpdateFormer(
            space_depth=space_depth,
            time_depth=time_depth,
            input_dim=456,
            hidden_size=hidden_size,
            num_heads=num_heads,
            output_dim=latent_dim + 3,
            mlp_ratio=4.0,
            add_space_attn=add_space_attn,
            flash=getattr(self.args, "flash_attn", True)
        )
        # NOTE: allocated on CUDA at construction time; a GPU is assumed
        self.support_features = torch.zeros(100, 384).to("cuda") + 0.1

        self.norm = nn.GroupNorm(1, self.latent_dim)

        self.ffeat_updater = nn.Sequential(
            nn.Linear(self.latent_dim, self.latent_dim),
            nn.GELU(),
        )
        self.ffeatyz_updater = nn.Sequential(
            nn.Linear(self.latent_dim, self.latent_dim),
            nn.GELU(),
        )
        self.ffeatxz_updater = nn.Sequential(
            nn.Linear(self.latent_dim, self.latent_dim),
            nn.GELU(),
        )

        # TODO @NeuralArap: optimize the arap
        self.embed_traj = Embedder_Fourier(
            input_dim=5, max_freq_log2=5.0, N_freqs=3, include_input=True
        )
        self.embed3d = Embedder_Fourier(
            input_dim=3, max_freq_log2=10.0, N_freqs=10, include_input=True
        )
        self.embedConv = nn.Conv2d(self.latent_dim + 63,
                                   self.latent_dim, 3, padding=1)

        # @Vis_predictor
        self.vis_predictor = nn.Sequential(
            nn.Linear(128, 1),
        )

        self.embedProj = nn.Linear(63, 456)
        self.zeroMLPflow = nn.Linear(195, 130)
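
    # A minimal construction sketch (not part of the original file).
    # __init__ above allocates self.support_features on CUDA, so a GPU
    # is assumed; the args shown are illustrative overrides.
    @staticmethod
    def _demo_build_extractor():
        model = FeatureExtractor(
            S=8, stride=8,
            args=edict({"Embed3D": True, "flash_attn": False})
        )
        return model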
    def prepare_track(self, rgbds, queries):
        """
        Normalize the RGB channels and sort the queries by the frame in
        which they first appear.
        Args:
            rgbds: the input RGB-D images (B T 4 H W)
            queries: the input queries (B N 4) as [t, x, y, z]
        Returns:
            the normalized rgbds, the first-appearance indices (raw and
            sorted), the sort and inverse-sort indices, the track mask,
            the regular grid, and the sorted coords_init, vis_init and
            Traj_series tensors
        """
        assert (rgbds.shape[2] == 4) and (queries.shape[2] == 4)
        # Step1: normalize the rgb inputs to [-1, 1]
        device = rgbds.device
        rgbds[:, :, :3, ...] = 2 * (rgbds[:, :, :3, ...] / 255.0) - 1.0
        B, T, C, H, W = rgbds.shape
        B, N, __ = queries.shape
        self.traj_e = torch.zeros((B, T, N, 3), device=device)
        self.vis_e = torch.zeros((B, T, N), device=device)

        # Step2: sort the points by the frame in which they first appear
        first_positive_inds = queries[0, :, 0].long()
        __, sort_inds = torch.sort(first_positive_inds, dim=0, descending=False)
        inv_sort_inds = torch.argsort(sort_inds, dim=0)
        first_positive_sorted_inds = first_positive_inds[sort_inds]
        # check that the sort can be inverted
        assert torch.allclose(
            first_positive_inds, first_positive_inds[sort_inds][inv_sort_inds]
        )

        # mask out the frames before each point first appears in 1..T
        ind_array = torch.arange(T, device=device)
        ind_array = ind_array[None, :, None].repeat(B, 1, N)
        track_mask = (ind_array >=
                      first_positive_inds[None, None, :]).unsqueeze(-1)
        # scale the coords_init to feature resolution
        coords_init = queries[:, :, 1:].reshape(B, 1, N, 3).repeat(
            1, self.S, 1, 1
        )
        coords_init[..., :2] /= float(self.stride)

        # Step3: initialize the regular grid
        gridx = torch.linspace(0, W // self.stride - 1, W // self.stride)
        gridy = torch.linspace(0, H // self.stride - 1, H // self.stride)
        gridx, gridy = torch.meshgrid(gridx, gridy, indexing="ij")
        gridxy = torch.stack([gridx, gridy], dim=-1).to(rgbds.device).permute(
            2, 1, 0
        )
        vis_init = torch.ones((B, self.S, N, 1), device=device).float() * 10

        # Step4: initialize the trajectories for the neural arap
        T_series = torch.linspace(0, 5, T).reshape(1, T, 1, 1).to(device)  # 1 T 1 1
        T_series = T_series.repeat(B, 1, N, 1)
        # get the 3d trajectories in camera coordinates
        # (self.intrs holds the camera intrinsics and is expected to be
        # set by the caller before tracking)
        intr_init = self.intrs[:, queries[0, :, 0].long()]
        Traj_series = pix2cam(queries[:, :, None, 1:].double(),
                              intr_init.double())
        Traj_series = Traj_series.repeat(1, 1, T, 1).permute(0, 2, 1, 3).float()
        Traj_series = torch.cat([T_series, Traj_series], dim=-1)
        # get the indicator for the neural arap
        Traj_mask = -1e2 * torch.ones_like(T_series)
        Traj_series = torch.cat([Traj_series, Traj_mask], dim=-1)

        return (
            rgbds,
            first_positive_inds,
            first_positive_sorted_inds,
            sort_inds, inv_sort_inds,
            track_mask, gridxy, coords_init[..., sort_inds, :].clone(),
            vis_init, Traj_series[..., sort_inds, :].clone()
        )
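
    # A minimal sketch (not part of the original file) of the sort /
    # inverse-sort identity prepare_track relies on: points are sorted
    # by first-visible frame and the ordering is undone with
    # inv_sort_inds.
    @staticmethod
    def _demo_query_sorting():
        first_inds = torch.tensor([3, 0, 2, 1])  # first-visible frame per point
        _, sort_inds = torch.sort(first_inds)
        inv_sort_inds = torch.argsort(sort_inds)
        assert torch.equal(first_inds[sort_inds][inv_sort_inds], first_inds)
        return sort_inds, inv_sort_inds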
    def sample_trifeat(self, t,
                       coords,
                       featMapxy,
                       featMapyz,
                       featMapxz):
        """
        Sample features from the three 5D tri-plane feature maps (B S C H W).
        Args:
            t: the time index of each query point
            coords: the coordinates of the points, B S N 3
            featMapxy: the feature map B S C Hy Wx
            featMapyz: the feature map B S C Hz Wy
            featMapxz: the feature map B S C Hz Wx
        Returns:
            the features sampled on the xy, yz and xz planes, each
            (B, self.S, N, C)
        """
        # build the (t, u, v) sampling coordinates for each plane
        # (as written, the reshape below assumes B == 1, matching the
        # tracker's usage)
        queried_t = t.reshape(1, 1, -1, 1)
        xy_t = torch.cat(
            [queried_t, coords[..., [0, 1]]],
            dim=-1
        )
        yz_t = torch.cat(
            [queried_t, coords[..., [1, 2]]],
            dim=-1
        )
        xz_t = torch.cat(
            [queried_t, coords[..., [0, 2]]],
            dim=-1
        )
        featxy_init = sample_features5d(featMapxy, xy_t)
        featyz_init = sample_features5d(featMapyz, yz_t)
        featxz_init = sample_features5d(featMapxz, xz_t)

        featxy_init = featxy_init.repeat(1, self.S, 1, 1)
        featyz_init = featyz_init.repeat(1, self.S, 1, 1)
        featxz_init = featxz_init.repeat(1, self.S, 1, 1)

        return featxy_init, featyz_init, featxz_init
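
    # A minimal sketch (not part of the original file) of how the
    # (x, y, z) query points above are split into per-plane (t, u, v)
    # sampling coordinates.
    @staticmethod
    def _demo_triplane_coords():
        coords = torch.rand(1, 1, 5, 3)          # B S N 3
        t = torch.zeros(5).reshape(1, 1, -1, 1)  # query time per point
        xy_t = torch.cat([t, coords[..., [0, 1]]], dim=-1)
        yz_t = torch.cat([t, coords[..., [1, 2]]], dim=-1)
        xz_t = torch.cat([t, coords[..., [0, 2]]], dim=-1)
        return xy_t, yz_t, xz_t                  # each (1, 1, 5, 3)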
    def forward(self, rgbds, queries, num_levels=4, feat_init=None,
                is_train=False, intrs=None, wind_S=None):
        """
        Args:
            rgbds: the input RGB-D clip (B T 4 H W)
            queries: the query points (B f N 3) as [x, y, z]; x, y are in
                pixel coordinates and z is the (unnormalized) depth
        The remaining arguments are kept for interface compatibility and
        are unused in this extracted module.
        Returns:
            the tri-plane features sampled at the query points, stacked
            along the last dimension as (B S N C 3)
        """
        B, T, C, H, W = rgbds.shape
        self.Dz = Dz = W // self.stride  # depth resolution of the tri-plane
        rgbs_ = rgbds[:, :, :3, ...]
        depth_all = rgbds[:, :, 3, ...]

        # compute the near / far planes from the valid depths and queries
        d_near = self.d_near = depth_all[depth_all > 0.01].min().item()
        d_far = self.d_far = depth_all[depth_all > 0.01].max().item()
        d_near_z = queries.reshape(B, -1, 3)[:, :, 2].min().item()
        d_far_z = queries.reshape(B, -1, 3)[:, :, 2].max().item()
        d_near = min(d_near, d_near_z)
        d_far = max(d_far, d_far_z)
        # extend the range by a small margin, keeping the near plane
        # positive (the original line used min() here, which would clamp
        # the near plane up to 0.01; max() matches the stated intent)
        d_near = max(d_near - self.depth_extend_margin, 0.01)
        d_far = d_far + self.depth_extend_margin
        depths = (depth_all - d_near) / (d_far - d_near)
        # downsample per frame; frames are flattened into the batch so
        # each frame's depth becomes one (1, H', W') map
        depths_dn = F.interpolate(
            depths.reshape(B * T, 1, H, W),
            scale_factor=1.0 / self.stride, mode="nearest"
        )
        depths_dnG = depths_dn * Dz

        # Step3: initialize the regular grid at feature resolution
        gridx = torch.linspace(0, W // self.stride - 1, W // self.stride)
        gridy = torch.linspace(0, H // self.stride - 1, H // self.stride)
        gridx, gridy = torch.meshgrid(gridx, gridy, indexing="ij")
        gridxy = torch.stack([gridx, gridy], dim=-1).to(rgbds.device).permute(
            2, 1, 0
        )  # 2 H W
        gridxyz = torch.cat([gridxy[None, ...].repeat(
            depths_dn.shape[0], 1, 1, 1), depths_dnG], dim=1)  # (B*T) 3 H W
        # flows that carry the xy-plane features onto the yz / xz planes
        Fxy2yz = gridxyz[:, [1, 2], ...] - gridxyz[:, :2]
        Fxy2xz = gridxyz[:, [0, 2], ...] - gridxyz[:, :2]

        # fmaps_ caches the feature maps of the previous window in the
        # sliding-window tracker; in this standalone extractor there is
        # no previous window, so it starts as None
        fmaps_ = None
        if getattr(self.args, "Embed3D", False):
            # normalize each of x, y, z to [-1, 1] for the Fourier embedding
            gridxyz_nm = gridxyz.clone()
            for i in range(3):
                ch = gridxyz_nm[:, i, ...]
                gridxyz_nm[:, i, ...] = (ch - ch.min()) / (ch.max() - ch.min())
            gridxyz_nm = 2 * (gridxyz_nm - 0.5)
            S4, _, h4, w4 = gridxyz_nm.shape
            gridxyz_nm = gridxyz_nm.permute(0, 2, 3, 1).reshape(S4 * h4 * w4, 3)
            featPE = self.embed3d(gridxyz_nm).view(S4, h4, w4, -1).permute(0, 3, 1, 2)
            if fmaps_ is None:
                # flatten frames into the batch for the 2D encoder
                fmaps_ = torch.cat(
                    [self.fnet(rgbs_.reshape(-1, 3, H, W)), featPE], dim=1
                )
                fmaps_ = self.embedConv(fmaps_)
            else:
                fmaps_new = torch.cat(
                    [self.fnet(rgbs_[self.S // 2:]), featPE[self.S // 2:]],
                    dim=1
                )
                fmaps_new = self.embedConv(fmaps_new)
                fmaps_ = torch.cat(
                    [fmaps_[self.S // 2:], fmaps_new], dim=0
                )
        else:
            if fmaps_ is None:
                fmaps_ = self.fnet(rgbs_.reshape(-1, 3, H, W))
            else:
                fmaps_ = torch.cat(
                    [fmaps_[self.S // 2:], self.fnet(rgbs_[self.S // 2:])],
                    dim=0
                )

        fmapXY = fmaps_[:, :self.latent_dim].reshape(
            B, T, self.latent_dim, H // self.stride, W // self.stride
        )
        # splat the xy features onto the yz / xz planes (assumes B == 1,
        # matching the tracker's usage)
        fmapYZ = softsplat(fmapXY[0], Fxy2yz, None, strMode="avg",
                           tenoutH=self.Dz, tenoutW=H // self.stride)
        fmapXZ = softsplat(fmapXY[0], Fxy2xz, None, strMode="avg",
                           tenoutH=self.Dz, tenoutW=W // self.stride)

        fmapYZ = self.headyz(fmapYZ)[None, ...]
        fmapXZ = self.headxz(fmapXZ)[None, ...]

        # scale the query coordinates to feature resolution (clone so the
        # caller's queries tensor is not modified in place)
        coords_init = queries[:, :1].clone()  # B 1 N 3, the first frame
        coords_init[..., :2] /= float(self.stride)

        (featxy_init,
         featyz_init,
         featxz_init) = self.sample_trifeat(
            t=torch.zeros(B * queries.shape[2]),
            featMapxy=fmapXY, featMapyz=fmapYZ, featMapxz=fmapXZ,
            coords=coords_init  # B 1 N 3
        )

        return torch.stack([featxy_init, featyz_init, featxz_init], dim=-1)
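

# A minimal sketch (not part of the original file) of the depth
# normalization used in forward(): valid depths are mapped to [0, 1]
# between a margin-extended near / far range. The 0.2 margin mirrors the
# constructor default; the tensor sizes are illustrative.
def _demo_depth_normalization():
    depth = torch.rand(1, 4, 32, 32) * 5 + 0.5  # fake (B, T, H, W) depths
    margin = 0.2
    d_near = max(depth[depth > 0.01].min().item() - margin, 0.01)
    d_far = depth[depth > 0.01].max().item() + margin
    return (depth - d_near) / (d_far - d_near)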