import torch from torch import nn from typing import Optional from diffusers.models.embeddings import Timesteps, TimestepEmbedding, LabelEmbedding class FinalLayer(nn.Module): """ Final layer of the diffusion model that outputs the final logits. """ def __init__(self, in_ch, out_ch=None, dropout=0.0): super().__init__() out_ch = in_ch if out_ch is None else out_ch self.linear = nn.Linear(in_ch, out_ch) self.norm = AdaLayerNormTC(in_ch, 2 * in_ch, dropout) def forward(self, x, t, cond=None): assert cond is not None x = self.norm(x, t, cond) x = self.linear(x) return x class AdaLayerNormTC(nn.Module): """ Norm layer modified to incorporate timestep and condition embeddings. """ def __init__(self, embedding_dim, num_embeddings, dropout): super().__init__() self.emb = CombinedTimestepLabelEmbeddings( num_embeddings, embedding_dim, dropout ) self.silu = nn.SiLU() self.linear = nn.Linear(embedding_dim, embedding_dim * 2) self.norm = nn.LayerNorm( embedding_dim, elementwise_affine=False, eps=torch.finfo(torch.float16).eps ) def forward(self, x, timestep, cond): emb = self.linear(self.silu(self.emb(timestep, cond, hidden_dtype=None))) scale, shift = torch.chunk(emb, 2, dim=1) x = self.norm(x) * (1 + scale[:, None]) + shift[:, None] return x class PEmbeder(nn.Module): """ Positional embedding layer. """ def __init__(self, vocab_size, d_model): super().__init__() self.embed = nn.Embedding(vocab_size, d_model) self._init_embeddings() def _init_embeddings(self): nn.init.kaiming_normal_(self.embed.weight, mode="fan_in") def forward(self, x, idx=None): if idx is None: idx = torch.arange(x.shape[1], device=x.device, dtype=torch.long) return x + self.embed(idx) class CombinedTimestepLabelEmbeddings(nn.Module): '''Modified from diffusers.models.embeddings.CombinedTimestepLabelEmbeddings''' def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): super().__init__() self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob) def forward(self, timestep, class_labels, hidden_dtype=None, label_free=False): timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) force_drop_ids = None # training mode if label_free: # inference mode, force_drop_ids is set to all ones to be dropped in class_embedder force_drop_ids = torch.ones_like(class_labels, dtype=torch.bool, device=class_labels.device) class_labels = self.class_embedder(class_labels, force_drop_ids) # (N, D) conditioning = timesteps_emb + class_labels # (N, D) return conditioning class MyAdaLayerNormZero(nn.Module): """ Adaptive layer norm zero (adaLN-Zero), borrowed from diffusers.models.attention.AdaLayerNormZero. Extended to incorporate scale parameters (gate_2, gate_3) for intermidate attention layers. """ def __init__(self, embedding_dim, num_embeddings, class_dropout_prob): super().__init__() self.emb = CombinedTimestepLabelEmbeddings( num_embeddings, embedding_dim, class_dropout_prob ) self.silu = nn.SiLU() self.linear = nn.Linear(embedding_dim, 8 * embedding_dim, bias=True) self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) def forward(self, x, timestep, class_labels, hidden_dtype=None, label_free=False): emb_t_cls = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype, label_free=label_free) emb = self.linear(self.silu(emb_t_cls)) ( shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_2, gate_3, ) = emb.chunk(8, dim=1) x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_2, gate_3 class VisAttnProcessor: r""" This code is adapted from diffusers.models.attention_processor.AttnProcessor. Used for visualizing the attention maps when testing, NOT for training. """ def __call__( self, attn, hidden_states, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, *args, **kwargs, ) -> torch.Tensor: # Removed # if len(args) > 0 or kwargs.get("scale", None) is not None: # deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." # deprecate("scale", "1.0.0", deprecation_message) residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) query = attn.head_to_batch_dim(query) # (40, 160, 16) key = attn.head_to_batch_dim(key) # (40, 256, 16) value = attn.head_to_batch_dim(value) # (40, 256, 16) if attention_mask is not None: if attention_mask.dtype == torch.bool: attn_mask = torch.zeros_like(attention_mask, dtype=query.dtype, device=query.device) attn_mask = attn_mask.masked_fill_(attention_mask.logical_not(), float("-inf")) else: attn_mask = attention_mask assert attn_mask.dtype == query.dtype, f"query and attention_mask must have the same dtype, but got {query.dtype} and {attention_mask.dtype}." else: attn_mask = None attention_probs = attn.get_attention_scores(query, key, attn_mask) # (40, 160, 256) hidden_states = torch.bmm(attention_probs, value) # (40, 160, 16) hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) if attn.residual_connection: hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor attention_probs = attention_probs.reshape(batch_size, attn.heads, query.shape[1], sequence_length) return hidden_states, attention_probs