import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))

import torch
import models
from torch import nn
from diffusers.models.attention import Attention, FeedForward
from models.utils import (
    PEmbeder,
    FinalLayer,
    VisAttnProcessor,
    MyAdaLayerNormZero,
)


class RAPCrossAttnBlock(nn.Module):
    """
    Fuses the patch features of two input images: the second image's features
    iteratively cross-attend to the first image's features.
    """

    def __init__(self, dim, num_layers, num_heads, head_dim, dropout=0.0, img_emb_dims=None):
        super().__init__()
        self.layers = nn.ModuleList([
            Attention(
                query_dim=dim,
                cross_attention_dim=dim,
                heads=num_heads,
                dim_head=head_dim,
                dropout=dropout,
                bias=True,
                cross_attention_norm="layer_norm",
                processor=VisAttnProcessor(),
            )
            for _ in range(num_layers)
        ])
        self.norms = nn.ModuleList([nn.LayerNorm(dim) for _ in range(num_layers)])

        # image patch embedding MLP: Linear + LeakyReLU blocks, no activation after the last Linear
        img_emb_layers = []
        for i in range(len(img_emb_dims) - 1):
            img_emb_layers.append(nn.Linear(img_emb_dims[i], img_emb_dims[i + 1]))
            img_emb_layers.append(nn.LeakyReLU(inplace=True))
        img_emb_layers.pop(-1)
        self.img_emb = nn.Sequential(*img_emb_layers)
        self.init_img_emb_weights()

    def init_img_emb_weights(self):
        for m in self.img_emb.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode="fan_in")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, img_first, img_second):
        """
        Inputs:
            img_first: (B, Np, C_img) patch features of the first image, C_img = img_emb_dims[0]
            img_second: (B, Np, C_img) patch features of the second image
        Output:
            fused_feat: (B, Np, D) fused feature, D = img_emb_dims[-1]
        """
        img_first = self.img_emb(img_first)
        img_second = self.img_emb(img_second)
        fused = img_second
        for norm, attn in zip(self.norms, self.layers):
            normed = norm(fused)
            delta, _ = attn(normed, encoder_hidden_states=img_first, attention_mask=None)
            fused = fused + delta  # residual connection
        return fused
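# Illustrative usage sketch (not part of the model): a minimal smoke test for
# RAPCrossAttnBlock with dummy tensors. The dimensions below (768-dim input
# patches, 256-dim attention width, 196 patches) are assumptions made for this
# sketch only, and the helper name `_example_rap_cross_attn` is hypothetical.
# It assumes VisAttnProcessor returns (hidden_states, attention_map), matching
# how it is unpacked elsewhere in this file.
def _example_rap_cross_attn():
    block = RAPCrossAttnBlock(
        dim=256,
        num_layers=2,
        num_heads=4,
        head_dim=64,
        dropout=0.0,
        img_emb_dims=[768, 512, 256],  # last entry must match `dim`
    )
    img_first = torch.randn(2, 196, 768)   # (B, Np, C_img)
    img_second = torch.randn(2, 196, 768)  # (B, Np, C_img)
    fused = block(img_first, img_second)   # (B, Np, dim) == (2, 196, 256)
    return fused.shape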
class Attn_Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: int = None,
        attention_bias: bool = False,
        norm_elementwise_affine: bool = True,
        final_dropout: bool = False,
        class_dropout_prob: float = 0.0,  # for classifier-free guidance
        img_emb_dims=None,
    ):
        super().__init__()
        self.norm1 = MyAdaLayerNormZero(dim, num_embeds_ada_norm, class_dropout_prob)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.norm4 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.norm5 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.norm6 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)

        self.local_attn = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
        )
        self.global_attn = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
        )
        self.graph_attn = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
        )
        self.img_attn = Attention(
            query_dim=dim,
            cross_attention_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_norm="layer_norm",
            processor=VisAttnProcessor(),  # to be removed for release model
        )
        self.img_attn_second = Attention(
            query_dim=dim,
            cross_attention_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_norm="layer_norm",
            processor=VisAttnProcessor(),  # to be removed for release model
        )
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
        )

        # image embedding layers
        layers = []
        for i in range(len(img_emb_dims) - 1):
            layers.append(nn.Linear(img_emb_dims[i], img_emb_dims[i + 1]))
            layers.append(nn.LeakyReLU(inplace=True))
        layers.pop(-1)
        self.img_emb = nn.Sequential(*layers)
        self.init_img_emb_weights()

    def init_img_emb_weights(self):
        for m in self.img_emb.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode="fan_in")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(
        self,
        hidden_states,
        img_patches,
        fuse_feat,
        pad_mask,
        attr_mask,
        graph_mask,
        timestep,
        class_labels,
        label_free=False,
    ):
        # image patches embedding
        img_emb = self.img_emb(img_patches)

        # adaptive normalization, taking timestep and class_labels as conditioning inputs
        norm_hidden_states, gate_1, shift_mlp, scale_mlp, gate_mlp, gate_2, gate_3 = self.norm1(
            hidden_states,
            timestep,
            class_labels,
            hidden_dtype=hidden_states.dtype,
            label_free=label_free,
        )

        # local attribute self-attention
        attr_out = self.local_attn(norm_hidden_states, attention_mask=attr_mask)
        attr_out = gate_1.unsqueeze(1) * attr_out
        hidden_states = hidden_states + attr_out

        # global attribute self-attention
        norm_hidden_states = self.norm2(hidden_states)
        global_out = self.global_attn(norm_hidden_states, attention_mask=pad_mask)
        global_out = gate_2.unsqueeze(1) * global_out
        hidden_states = hidden_states + global_out

        # graph relation self-attention
        norm_hidden_states = self.norm3(hidden_states)
        graph_out = self.graph_attn(norm_hidden_states, attention_mask=graph_mask)
        graph_out = gate_3.unsqueeze(1) * graph_out
        hidden_states = hidden_states + graph_out

        img_first, img_second = img_emb.chunk(2, dim=1)

        # cross attention with image patches
        norm_hidden_states = self.norm4(hidden_states)
        B, Na, D = norm_hidden_states.shape
        Np = img_first.shape[1]  # number of image patches
        mode_num = Na // 32  # assumes K = 32 nodes
        reshaped = norm_hidden_states.reshape(B, Na // mode_num, mode_num, D)
        bboxes = reshaped[:, :, 0, :]  # (B, K, D)
        # cross attention between bbox attributes and image patches
        bbox_img_out, bbox_cross_attn_map = self.img_attn(
            bboxes,
            encoder_hidden_states=img_first,
            attention_mask=None,
        )  # cross_attn_map: (B, n_head, K, Np)
        # to reshape the cross_attn_map back to (B, n_head, Na*5, Np), redundant for other attributes, fix later
        # cross_attn_map_reshape = torch.zeros(size=(B, bbox_cross_attn_map.shape[1], Na // mode_num, mode_num, Np), device=bbox_cross_attn_map.device)
        # cross_attn_map_reshape[:, :, :, 0, :] = bbox_cross_attn_map
        # cross_attn_map = cross_attn_map_reshape.reshape(B, bbox_cross_attn_map.shape[1], Na, Np)

        # assemble the output of cross attention with bbox attributes and other attributes
        img_out = torch.empty(
            size=(B, Na // mode_num, mode_num, D),
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        img_out[:, :, 0, :] = bbox_img_out
        img_out[:, :, 1:, :] = reshaped[:, :, 1:, :]
        img_out = img_out.reshape(B, Na, D)
        hidden_states = hidden_states + img_out

        norm_hidden_states = self.norm6(hidden_states)
        B, Na, D = norm_hidden_states.shape
        Np = img_second.shape[1]  # number of image patches
        mode_num = Na // 32  # assumes K = 32 nodes
        reshaped = norm_hidden_states.reshape(B, Na // mode_num, mode_num, D)
        joints = reshaped  # (B, K, 5, D)
        joints = joints.reshape(B, Na // mode_num * 5, D)
        # cross attention between all attribute tokens and the fused image features
        joint_img_out, bbox_cross_attn_map = self.img_attn_second(
            joints,
            encoder_hidden_states=fuse_feat,
            attention_mask=None,
        )  # cross_attn_map: (B, n_head, K*5, Np)
        # to reshape the cross_attn_map back to (B, n_head, Na*5, Np), redundant for other attributes, fix later
        # cross_attn_map_reshape = torch.zeros(size=(B, bbox_cross_attn_map.shape[1], Na // mode_num, mode_num, Np), device=bbox_cross_attn_map.device)
        # cross_attn_map_reshape[:, :, :, 1:5, :] = bbox_cross_attn_map.reshape(
        #     B, bbox_cross_attn_map.shape[1], Na // mode_num, 4, Np
        # )
        # cross_attn_map = cross_attn_map_reshape.reshape(B, bbox_cross_attn_map.shape[1], Na, Np)

        # assemble the output of cross attention over all attribute tokens
        img_out = joint_img_out.reshape(B, Na // mode_num, 5, D)
        img_out = img_out.reshape(B, Na, D)
        hidden_states = hidden_states + img_out

        # feed-forward
        norm_hidden_states = self.norm5(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output
        hidden_states = ff_output + hidden_states

        return hidden_states


@models.register("denoiser")
class Denoiser(nn.Module):
    """
    Denoiser based on CAGE's attribute attention block + our ICA module,
    with 4 sequential attentions: LA -> GA -> GRA -> ICA.
    Different image adapters are used for each layer.
    The image cross attention uses key-padding masks (object mask, part mask).
    *** The ICA only applies to the bbox attributes, not to other attributes such as motion params. ***
    (A minimal usage sketch is provided at the bottom of this file.)
    """

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.K = self.hparams.get("K", 32)
        in_ch = hparams.in_ch
        attn_dim = hparams.attn_dim
        mid_dim = attn_dim // 2
        n_head = hparams.n_head
        head_dim = attn_dim // n_head
        num_embeds_ada_norm = 6 * attn_dim

        # embedding layers for different node attributes
        self.aabb_emb = nn.Sequential(
            nn.Linear(in_ch, mid_dim),
            nn.ReLU(inplace=True),
            nn.Linear(mid_dim, attn_dim),
        )
        self.jaxis_emb = nn.Sequential(
            nn.Linear(in_ch, mid_dim),
            nn.ReLU(inplace=True),
            nn.Linear(mid_dim, attn_dim),
        )
        self.range_emb = nn.Sequential(
            nn.Linear(in_ch, mid_dim),
            nn.ReLU(inplace=True),
            nn.Linear(mid_dim, attn_dim),
        )
        self.label_emb = nn.Sequential(
            nn.Linear(in_ch, mid_dim),
            nn.ReLU(inplace=True),
            nn.Linear(mid_dim, attn_dim),
        )
        self.jtype_emb = nn.Sequential(
            nn.Linear(in_ch, mid_dim),
            nn.ReLU(inplace=True),
            nn.Linear(mid_dim, attn_dim),
        )
        # self.node_type_emb = nn.Sequential(
        #     nn.Linear(in_ch, mid_dim),
        #     nn.ReLU(inplace=True),
        #     nn.Linear(mid_dim, attn_dim),
        # )

        # positional encoding for nodes and attributes
        self.pe_node = PEmbeder(self.K, attn_dim)
        self.pe_attr = PEmbeder(self.hparams.mode_num, attn_dim)

        # attention layers
        self.attn_layers = nn.ModuleList(
            [
                Attn_Block(
                    dim=attn_dim,
                    num_attention_heads=n_head,
                    attention_head_dim=head_dim,
                    class_dropout_prob=hparams.get("cat_drop_prob", 0.0),
                    dropout=hparams.dropout,
                    activation_fn="geglu",
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=False,
                    norm_elementwise_affine=True,
                    final_dropout=False,
                    img_emb_dims=hparams.get("img_emb_dims", None),
                )
                for _ in range(hparams.n_layers)
            ]
        )

        self.image_interaction = RAPCrossAttnBlock(
            dim=attn_dim,
            num_layers=6,
            num_heads=n_head,
            head_dim=head_dim,
            dropout=hparams.dropout,
            img_emb_dims=hparams.get("img_emb_dims", None),
        )

        self.final_layer = FinalLayer(attn_dim, in_ch)

    def forward(
        self,
        x,
        cat,
        timesteps,
        feat,
        key_pad_mask=None,
        graph_mask=None,
        attr_mask=None,
        label_free=False,
    ):
        B = x.shape[0]
        x = x.view(B, self.K, 5 * 6)

        # embedding layers for different attributes (6 channels per attribute)
        x_aabb = self.aabb_emb(x[..., :6])
        x_jtype = self.jtype_emb(x[..., 6:12])
        x_jaxis = self.jaxis_emb(x[..., 12:18])
        x_range = self.range_emb(x[..., 18:24])
        x_label = self.label_emb(x[..., 24:30])
        # x_node_type = self.node_type_emb(x[..., 30:36])

        # concatenate all attribute embeddings
        x_ = torch.cat(
            [x_aabb, x_jtype, x_jaxis, x_range, x_label], dim=2
        )  # (B, K, 5*attn_dim)
        x_ = x_.view(B, self.K * self.hparams.mode_num, self.hparams.attn_dim)

        # positional encoding for nodes and attributes
        idx_attr = torch.tensor(
            [0, 1, 2, 3, 4], device=x.device, dtype=torch.long
        ).repeat(self.K)
        idx_node = torch.arange(
            self.K, device=x.device, dtype=torch.long
        ).repeat_interleave(self.hparams.mode_num)
        x_ = self.pe_attr(self.pe_node(x_, idx=idx_node), idx=idx_attr)

        # fuse patch features of the two images (concatenated along dim 1 in `feat`)
        img_first, img_second = feat.chunk(2, dim=1)
        fused_img_feat = self.image_interaction(img_first, img_second)  # (B, Np, D), Np = feat.shape[1] // 2

        # attention layers
        for attn_layer in self.attn_layers:
            x_ = attn_layer(
                hidden_states=x_,
                img_patches=feat,
                fuse_feat=fused_img_feat,
                timestep=timesteps,
                class_labels=cat,
                pad_mask=key_pad_mask,
                graph_mask=graph_mask,
                attr_mask=attr_mask,
                label_free=label_free,
            )

        y = self.final_layer(x_, timesteps, cat)
        return {
            'noise_pred': y,
            'attn_maps': None,
        }
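# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative only). The hyperparameter values below are
# assumptions made for this sketch, not the training configuration, and the
# `_HParams` helper is a hypothetical stand-in for the real config object
# (it only needs attribute access plus `.get()`, which is all Denoiser uses).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _HParams(dict):
        __getattr__ = dict.__getitem__

    hparams = _HParams(
        K=32,
        in_ch=6,
        attn_dim=256,
        n_head=4,
        mode_num=5,
        dropout=0.0,
        n_layers=2,
        cat_drop_prob=0.1,
        img_emb_dims=[768, 512, 256],  # last entry must match attn_dim
    )
    denoiser = Denoiser(hparams)
    n_params = sum(p.numel() for p in denoiser.parameters())
    print(f"Denoiser built with {n_params / 1e6:.1f}M parameters")
    print("RAPCrossAttnBlock fused output shape:", _example_rap_cross_attn())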