| | |
| | import math |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| |
|
def my_scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    special_token_weight=1.0,
    special_token_indices=None,
) -> "tuple[torch.Tensor, torch.Tensor]":
    """
    Compute scaled dot-product attention and also return the attention map.

    Follows the same math as ``torch.nn.functional.scaled_dot_product_attention``
    (``softmax(Q @ K^T * scale + bias) @ V``). It additionally returns the
    post-softmax, post-dropout attention weights. It can also boost the
    pre-softmax scores of selected key positions.

    Args:
        query: Tensor of shape (..., L, E).
        key: Tensor of shape (..., S, E).
        value: Tensor of shape (..., S, Ev).
        attn_mask: Optional mask broadcastable to the score shape (..., L, S).
            For a bool mask, True means "may attend"; False becomes -inf. A
            float mask is added to the scores.
        dropout_p: Dropout probability on the attention weights. Dropout is
            applied whenever ``dropout_p > 0``, independent of any module
            training flag.
        is_causal: If True, apply a lower-triangular causal mask. Must not be
            combined with ``attn_mask``.
        scale: Score scale; defaults to ``1 / sqrt(E)``.
        special_token_weight: Each score at a key position in
            ``special_token_indices`` becomes
            ``max(score, score * special_token_weight)``. Skipped when 1.0.
        special_token_indices: Optional index tensor, indexed together with
            ``arange(batch)``. It must therefore broadcast against (batch,): one
            key index per batch element, or a single shared index. Requires
            4-D scores (batch, heads, L, S).

    Returns:
        Tuple ``(output, attn_weight)`` with ``output = attn_weight @ value``.

    Raises:
        AssertionError: If ``is_causal`` is combined with ``attn_mask``.
    """
    tgt_len, src_len = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale

    # Build the additive bias on the inputs' own device and dtype. The previous
    # hard-coded .cuda() broke CPU and non-default-GPU execution.
    attn_bias = torch.zeros(tgt_len, src_len, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        causal = torch.ones(
            tgt_len, src_len, dtype=torch.bool, device=query.device
        ).tril(diagonal=0)
        attn_bias = attn_bias.masked_fill(causal.logical_not(), float("-inf"))

    if attn_mask is not None:
        # Out-of-place ops let batched masks (e.g. (B, H, L, S)) broadcast.
        # In-place writes into the (L, S) bias would fail for those shapes.
        if attn_mask.dtype == torch.bool:
            attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias = attn_bias + attn_mask.to(attn_bias.dtype)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight = attn_weight + attn_bias

    if special_token_indices is not None and special_token_weight != 1.0:
        rows = torch.arange(attn_weight.shape[0], device=attn_weight.device)
        # The index tensors are separated by slices, so the broadcast index dim
        # comes first: picked has shape (batch, heads, L).
        picked = attn_weight[rows, :, :, special_token_indices]
        attn_weight[rows, :, :, special_token_indices] = torch.maximum(
            picked, picked * special_token_weight
        )

    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value, attn_weight
| |
|
| |
|
class AttnProcessor(torch.nn.Module):
    r"""
    Plain attention processor built on ``F.scaled_dot_product_attention`` (PyTorch 2.0).

    Runs the wrapped ``attn`` module's pipeline: projections, multi-head
    attention, then output layers. The keyword arguments ``qformer_tokens_out``,
    ``special_token_indices``, ``inference_mode`` and ``special_token_weight``
    are accepted but ignored. They mirror the signature of
    ``NestedAttnProcessor.__call__``.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        """
        Args:
            hidden_size: Unused by this processor.
            cross_attention_dim: Unused by this processor.

        Raises:
            ImportError: If ``F.scaled_dot_product_attention`` is unavailable.
        """
        super().__init__()
        if hasattr(F, "scaled_dot_product_attention"):
            return
        raise ImportError(
            "AttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
        )

    def __call__(
        self,
        attn,
        hidden_states,
        qformer_tokens_out=None,
        special_token_indices=None,
        inference_mode=None,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        special_token_weight=None,
    ):
        """
        Apply ``attn`` to ``hidden_states``.

        This is self-attention, or cross-attention when ``encoder_hidden_states``
        is given. A 4-D input (batch, channel, height, width) is flattened to a
        sequence and reshaped back at the end.

        Returns:
            The output-layer result. The original input is added when
            ``attn.residual_connection`` is set, and the whole is divided by
            ``attn.rescale_output_factor``.
        """
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        is_feature_map = hidden_states.ndim == 4
        if is_feature_map:
            n, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(n, channel, height * width).transpose(1, 2)

        # Only the shape is read here: batch size and key/value length come
        # from whichever tensor will supply the keys.
        batch, kv_len, _ = (
            encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        ).shape

        if attention_mask is not None:
            prepared = attn.prepare_attention_mask(attention_mask, kv_len, batch)
            # Split the prepared mask per head: (batch, heads, -1, key_len).
            attention_mask = prepared.view(batch, attn.heads, -1, prepared.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        q = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        k = attn.to_k(encoder_hidden_states)
        v = attn.to_v(encoder_hidden_states)

        head_dim = k.shape[-1] // attn.heads

        def to_heads(t):
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return t.view(batch, -1, attn.heads, head_dim).transpose(1, 2)

        q, k, v = to_heads(q), to_heads(k), to_heads(v)

        attended = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        merged = attended.transpose(1, 2).reshape(batch, -1, attn.heads * head_dim)
        merged = merged.to(q.dtype)

        # The module's output layers: to_out[0], then to_out[1].
        out = attn.to_out[1](attn.to_out[0](merged))

        if is_feature_map:
            out = out.transpose(-1, -2).reshape(batch, channel, height, width)

        if attn.residual_connection:
            out = out + shortcut

        return out / attn.rescale_output_factor
| |
|
| |
|
class NestedAttnProcessor(torch.nn.Module):
    r"""
    Nested-attention processor for IP-Adapter-style conditioning (PyTorch 2.0).

    Works like a standard attention processor, except for the key positions
    listed in ``special_token_indices``. Their own value vectors are removed and
    replaced by a "nested" attention result. For that result, the layer's
    queries attend over ``qformer_tokens_out``, projected by ``nested_to_k`` /
    ``nested_to_v``. The result is then L2-normalized, multiplied by
    ``normalize_factor`` and by the per-head norm of the replaced value vector.
    Finally it is mixed into the output with the special token's attention
    weight.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, normalize_factor=1.0):
        """
        Args:
            hidden_size: Output width of the nested key/value projections.
                Presumably expected to match the wrapped layer's key inner dim
                (``__call__`` splits it into ``attn.heads`` heads of the key's
                head_dim) — TODO confirm.
            cross_attention_dim: Input width of the nested projections, i.e. the
                last dimension of ``qformer_tokens_out``. Falls back to
                ``hidden_size`` when not given.
            normalize_factor: Extra scale applied after L2-normalizing the
                nested attention output.

        Raises:
            ImportError: If ``F.scaled_dot_product_attention`` is unavailable.
        """
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "NestedAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim

        self.normalize_factor = normalize_factor

        # Trainable projections of qformer_tokens_out into key/value space.
        self.nested_to_k = nn.Linear(
            cross_attention_dim or hidden_size, hidden_size, bias=False
        )
        self.nested_to_v = nn.Linear(
            cross_attention_dim or hidden_size, hidden_size, bias=False
        )

    def __call__(
        self,
        attn,
        hidden_states,
        qformer_tokens_out,
        special_token_indices,
        inference_mode=False,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        special_token_weight=1.0,
    ):
        """
        Run attention with nested values substituted for the special tokens.

        Args:
            attn: Attention module providing ``to_q``/``to_k``/``to_v``/``to_out``,
                ``heads``, ``spatial_norm``, ``group_norm``, ``norm_cross``,
                ``prepare_attention_mask``, ``residual_connection`` and
                ``rescale_output_factor``. Presumably a diffusers ``Attention``
                — confirm.
            hidden_states: Query-side input, either 3-D (batch, seq, features)
                or a 4-D feature map (batch, channel, height, width). A 4-D
                input is flattened to a sequence.
            qformer_tokens_out: Input to ``nested_to_k`` / ``nested_to_v``. Its
                last dim must equal the projections' input width.
            special_token_indices: Non-empty index tensor of key positions whose
                values are replaced. Outside inference mode it is indexed together
                with ``torch.arange(bs)``, i.e. one index per sample.
            inference_mode: If truthy, only the second half of the batch
                (``bs // 2 : bs``) has its special-token values zeroed and
                replaced. The first half keeps its ordinary values (presumably
                the unconditional half of a classifier-free-guidance batch —
                confirm).
            encoder_hidden_states: Optional key/value source. Defaults to
                ``hidden_states`` after ``group_norm``.
            attention_mask: Prepared and reshaped, but NOT passed to either
                attention call below (see NOTE).
            temb: Forwarded to ``attn.spatial_norm`` when that is set.
            special_token_weight: Forwarded to
                ``my_scaled_dot_product_attention`` to boost the special tokens'
                pre-softmax scores. That call covers the whole batch, including
                the first half in inference mode.

        Returns:
            The attention output after ``attn.to_out``. For 4-D input it is
            reshaped back to (batch, channel, height, width). The residual is
            added when ``attn.residual_connection`` is set, and the result is
            divided by ``attn.rescale_output_factor``.

        Raises:
            AssertionError: If ``special_token_indices`` is empty.
        """
        assert (
            special_token_indices.shape[0] > 0
        ), "special_token_indices should not be empty"

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        # Reassigned below from whichever tensor supplies keys/values.
        bs = hidden_states.shape[0]

        if input_ndim == 4:
            bs, channel, height, width = hidden_states.shape
            # (bs, C, H, W) -> (bs, H*W, C)
            hidden_states = hidden_states.view(bs, channel, height * width).transpose(
                1, 2
            )

        bs, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        # NOTE(review): this prepared mask is never used. Both attention calls
        # below pass attn_mask=None. Confirm whether masking is intentionally
        # disabled for this processor.
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, bs
            )

            attention_mask = attention_mask.view(
                bs, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(
                    encoder_hidden_states
                )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # Split heads: (bs, seq, heads * head_dim) -> (bs, heads, seq, head_dim)
        query = query.view(bs, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(bs, -1, attn.heads, head_dim).transpose(1, 2)

        # Nested attention: the same queries attend over the projected
        # qformer_tokens_out, giving one nested value per query position.
        nested_key = self.nested_to_k(qformer_tokens_out)
        nested_value = self.nested_to_v(qformer_tokens_out)

        nested_key = nested_key.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
        nested_value = nested_value.view(bs, -1, attn.heads, head_dim).transpose(1, 2)

        nested_hidden_states = F.scaled_dot_product_attention(
            query,
            nested_key,
            nested_value,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=False,
        )

        # (bs, heads): per-head L2 norm of each sample's special-token value
        # vector. Computed from the original `value`, before it is zeroed below.
        textual_values_norms = torch.norm(
            value[torch.arange(bs), :, special_token_indices], dim=-1
        )
        # Unit-normalize each nested vector, then rescale it to the special
        # token's value norm (times normalize_factor).
        nested_hidden_states = (
            torch.nn.functional.normalize(nested_hidden_states, p=2, dim=-1)
            * self.normalize_factor
        )
        nested_hidden_states = (
            textual_values_norms.view(bs, -1, 1, 1) * nested_hidden_states
        )

        # Zero the special tokens' own values, so their attention mass
        # contributes nothing to the regular attention output. Their share is
        # re-added below using the nested values instead.
        value_without_special_tokens = value.clone()
        if inference_mode:
            # Only the second half of the batch. Every listed index is zeroed
            # for each of those rows.
            value_without_special_tokens[bs // 2 : bs, :, special_token_indices, :] = (
                0.0
            )
        else:
            # One (sample, index) pair per batch element.
            value_without_special_tokens[
                torch.arange(bs), :, special_token_indices, :
            ] = 0.0
        # attn_weight holds the post-softmax weights. They are needed to weight
        # the nested values.
        hidden_states_without_special_tokens, attn_weight = (
            my_scaled_dot_product_attention(
                query,
                key,
                value_without_special_tokens,
                attn_mask=None,
                dropout_p=0.0,
                is_causal=False,
                special_token_weight=special_token_weight,
                special_token_indices=special_token_indices,
            )
        )

        # Re-insert the special tokens' contribution:
        # attention weight on the special token * nested value.
        if inference_mode:
            # Shape (bs // 2, heads, L, len(special_token_indices)).
            special_token_attn_weight = attn_weight[
                bs // 2 : bs, :, :, special_token_indices
            ]
        else:
            # Shape (bs, heads, L): one special column per sample.
            special_token_attn_weight = attn_weight[
                torch.arange(bs), :, :, special_token_indices
            ]
        if inference_mode:
            # NOTE(review): no unsqueeze here. The trailing index dim
            # (len(special_token_indices)) broadcasts against head_dim, which
            # presumably assumes a single special token index in inference
            # mode — confirm.
            special_token_weighted_values = (
                special_token_attn_weight * nested_hidden_states[bs // 2 : bs]
            )
        else:
            special_token_weighted_values = (
                special_token_attn_weight.unsqueeze(-1) * nested_hidden_states
            )
        if inference_mode:
            hidden_states = hidden_states_without_special_tokens
            # In-place add into the second half of the batch only.
            hidden_states[bs // 2 : bs] += special_token_weighted_values
        else:
            hidden_states = (
                hidden_states_without_special_tokens + special_token_weighted_values
            )

        # Merge heads: (bs, heads, seq, head_dim) -> (bs, seq, heads * head_dim)
        hidden_states = hidden_states.transpose(1, 2).reshape(
            bs, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # The module's output layers: to_out[0], then to_out[1].
        hidden_states = attn.to_out[0](hidden_states)

        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                bs, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
| |
|