# tpu-optimized-llm/model/attention.py
"""
Attention mechanisms for the LLM model.
"""
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Optional, Tuple
import math
from model.embedding import RotaryPositionalEmbedding
class MultiHeadAttention(nn.Module):
"""
Multi-Head Attention mechanism.
Attributes:
dim: Hidden dimension
num_heads: Number of attention heads
head_dim: Dimension of each attention head
dropout_rate: Dropout probability
dtype: Data type for computations
"""
dim: int
num_heads: int
head_dim: Optional[int] = None
dropout_rate: float = 0.0
dtype: jnp.dtype = jnp.float32
def setup(self):
# Determine head dimension if not provided
self.actual_head_dim = self.head_dim or self.dim // self.num_heads
# Projection matrices
self.q_proj = nn.Dense(
features=self.num_heads * self.actual_head_dim,
dtype=self.dtype,
kernel_init=nn.initializers.normal(stddev=0.02),
name="q_proj"
)
self.k_proj = nn.Dense(
features=self.num_heads * self.actual_head_dim,
dtype=self.dtype,
kernel_init=nn.initializers.normal(stddev=0.02),
name="k_proj"
)
self.v_proj = nn.Dense(
features=self.num_heads * self.actual_head_dim,
dtype=self.dtype,
kernel_init=nn.initializers.normal(stddev=0.02),
name="v_proj"
)
self.out_proj = nn.Dense(
features=self.dim,
dtype=self.dtype,
kernel_init=nn.initializers.normal(stddev=0.02),
name="out_proj"
)
self.dropout = nn.Dropout(rate=self.dropout_rate)
def __call__(
self,
hidden_states: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
position_ids: Optional[jnp.ndarray] = None,
past_key_value: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None,
output_attentions: bool = False,
use_cache: bool = False,
deterministic: bool = True,
) -> Tuple[jnp.ndarray, ...]:
"""
Apply multi-head attention.
Args:
hidden_states: Input tensor [batch_size, seq_len, dim]
attention_mask: Attention mask [batch_size, 1, seq_len, seq_len]
position_ids: Position indices [batch_size, seq_len]
past_key_value: Cached key and value tensors for incremental decoding
output_attentions: Whether to return attention weights
use_cache: Whether to use cached key and values
deterministic: Whether to use deterministic operations (no dropout)
Returns:
Tuple of (output, attention_weights, present_key_value)
"""
batch_size, seq_len, _ = hidden_states.shape
# Project inputs to queries, keys, and values
q = self.q_proj(hidden_states) # [batch_size, seq_len, num_heads * head_dim]
k = self.k_proj(hidden_states) # [batch_size, seq_len, num_heads * head_dim]
v = self.v_proj(hidden_states) # [batch_size, seq_len, num_heads * head_dim]
# Reshape to [batch_size, seq_len, num_heads, head_dim]
q = q.reshape(batch_size, seq_len, self.num_heads, self.actual_head_dim)
k = k.reshape(batch_size, seq_len, self.num_heads, self.actual_head_dim)
v = v.reshape(batch_size, seq_len, self.num_heads, self.actual_head_dim)
# Handle cached key and values for incremental decoding
if past_key_value is not None and use_cache:
past_k, past_v = past_key_value
k = jnp.concatenate([past_k, k], axis=1)
v = jnp.concatenate([past_v, v], axis=1)
# Save key and value for future use if caching
present_key_value = (k, v) if use_cache else None
# Transpose to [batch_size, num_heads, seq_len, head_dim]
q = jnp.transpose(q, (0, 2, 1, 3))
k = jnp.transpose(k, (0, 2, 1, 3))
v = jnp.transpose(v, (0, 2, 1, 3))
# Compute attention scores
# [batch_size, num_heads, seq_len, seq_len]
attention_scores = jnp.matmul(q, jnp.transpose(k, (0, 1, 3, 2))) / math.sqrt(self.actual_head_dim)
# Apply attention mask if provided
if attention_mask is not None:
attention_scores = attention_scores + attention_mask
# Apply softmax to get attention weights
attention_weights = jax.nn.softmax(attention_scores, axis=-1)
# Apply dropout to attention weights
attention_weights = self.dropout(attention_weights, deterministic=deterministic)
# Compute attention output
# [batch_size, num_heads, seq_len, head_dim]
attention_output = jnp.matmul(attention_weights, v)
# Transpose and reshape to [batch_size, seq_len, dim]
attention_output = jnp.transpose(attention_output, (0, 2, 1, 3))
attention_output = attention_output.reshape(batch_size, seq_len, self.num_heads * self.actual_head_dim)
# Project to output dimension
output = self.out_proj(attention_output)
outputs = (output, attention_weights, present_key_value) if output_attentions else (output, None, present_key_value)
return outputs
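# Minimal usage sketch for MultiHeadAttention. The hyperparameters and shapes below are
# illustrative assumptions, not values required by this module.
def _example_multi_head_attention():
    attn = MultiHeadAttention(dim=512, num_heads=8, dropout_rate=0.1)
    x = jnp.ones((2, 16, 512), dtype=jnp.float32)  # [batch_size, seq_len, dim]
    params = attn.init(jax.random.PRNGKey(0), x)
    # deterministic=True disables dropout, so no 'dropout' RNG stream is needed here.
    output, _, _ = attn.apply(params, x, deterministic=True)
    return output  # [2, 16, 512]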
class MultiQueryAttention(nn.Module):
"""
    Multi-Query Attention mechanism.
    Shares a small number of key/value heads (typically one) across all query heads,
    which shrinks the key/value cache and speeds up incremental decoding.
Attributes:
dim: Hidden dimension
num_query_heads: Number of query heads
num_kv_heads: Number of key-value heads (usually 1 or a small number)
head_dim: Dimension of each attention head
dropout_rate: Dropout probability
dtype: Data type for computations
"""
dim: int
num_query_heads: int
num_kv_heads: int = 1
head_dim: Optional[int] = None
dropout_rate: float = 0.0
dtype: jnp.dtype = jnp.float32
def setup(self):
# Determine head dimension if not provided
self.actual_head_dim = self.head_dim or self.dim // self.num_query_heads
# Projection matrices
self.q_proj = nn.Dense(
features=self.num_query_heads * self.actual_head_dim,
dtype=self.dtype,
kernel_init=nn.initializers.normal(stddev=0.02),
name="q_proj"
)
self.k_proj = nn.Dense(
features=self.num_kv_heads * self.actual_head_dim,
dtype=self.dtype,
kernel_init=nn.initializers.normal(stddev=0.02),
name="k_proj"
)
self.v_proj = nn.Dense(
features=self.num_kv_heads * self.actual_head_dim,
dtype=self.dtype,
kernel_init=nn.initializers.normal(stddev=0.02),
name="v_proj"
)
self.out_proj = nn.Dense(
features=self.dim,
dtype=self.dtype,
kernel_init=nn.initializers.normal(stddev=0.02),
name="out_proj"
)
self.dropout = nn.Dropout(rate=self.dropout_rate)
def __call__(
self,
hidden_states: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
position_ids: Optional[jnp.ndarray] = None,
past_key_value: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None,
output_attentions: bool = False,
use_cache: bool = False,
deterministic: bool = True,
) -> Tuple[jnp.ndarray, ...]:
"""
Apply multi-query attention.
Args:
hidden_states: Input tensor [batch_size, seq_len, dim]
attention_mask: Attention mask [batch_size, 1, seq_len, seq_len]
position_ids: Position indices [batch_size, seq_len]
past_key_value: Cached key and value tensors for incremental decoding
output_attentions: Whether to return attention weights
use_cache: Whether to use cached key and values
deterministic: Whether to use deterministic operations (no dropout)
Returns:
Tuple of (output, attention_weights, present_key_value)
"""
batch_size, seq_len, _ = hidden_states.shape
# Project inputs to queries, keys, and values
q = self.q_proj(hidden_states) # [batch_size, seq_len, num_query_heads * head_dim]
k = self.k_proj(hidden_states) # [batch_size, seq_len, num_kv_heads * head_dim]
v = self.v_proj(hidden_states) # [batch_size, seq_len, num_kv_heads * head_dim]
# Reshape
q = q.reshape(batch_size, seq_len, self.num_query_heads, self.actual_head_dim)
k = k.reshape(batch_size, seq_len, self.num_kv_heads, self.actual_head_dim)
v = v.reshape(batch_size, seq_len, self.num_kv_heads, self.actual_head_dim)
# Handle cached key and values for incremental decoding
if past_key_value is not None and use_cache:
past_k, past_v = past_key_value
k = jnp.concatenate([past_k, k], axis=1)
v = jnp.concatenate([past_v, v], axis=1)
# Save key and value for future use if caching
present_key_value = (k, v) if use_cache else None
# Transpose to [batch_size, num_*_heads, seq_len, head_dim]
q = jnp.transpose(q, (0, 2, 1, 3))
k = jnp.transpose(k, (0, 2, 1, 3))
v = jnp.transpose(v, (0, 2, 1, 3))
# Repeat k and v for each query head
if self.num_kv_heads < self.num_query_heads:
# Calculate how many times to repeat
repeats = self.num_query_heads // self.num_kv_heads
# Repeat k and v along the head dimension
k = jnp.repeat(k, repeats, axis=1)
v = jnp.repeat(v, repeats, axis=1)
# Compute attention scores
# [batch_size, num_query_heads, seq_len, seq_len]
attention_scores = jnp.matmul(q, jnp.transpose(k, (0, 1, 3, 2))) / math.sqrt(self.actual_head_dim)
# Apply attention mask if provided
if attention_mask is not None:
attention_scores = attention_scores + attention_mask
# Apply softmax to get attention weights
attention_weights = jax.nn.softmax(attention_scores, axis=-1)
# Apply dropout to attention weights
attention_weights = self.dropout(attention_weights, deterministic=deterministic)
# Compute attention output
# [batch_size, num_query_heads, seq_len, head_dim]
attention_output = jnp.matmul(attention_weights, v)
# Transpose and reshape to [batch_size, seq_len, dim]
attention_output = jnp.transpose(attention_output, (0, 2, 1, 3))
attention_output = attention_output.reshape(batch_size, seq_len, self.num_query_heads * self.actual_head_dim)
# Project to output dimension
output = self.out_proj(attention_output)
outputs = (output, attention_weights, present_key_value) if output_attentions else (output, None, present_key_value)
return outputs
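# Minimal usage sketch for MultiQueryAttention, illustrating the KV-cache saving: with
# num_kv_heads=1 the cached key/value tensors are num_query_heads times smaller than in
# standard multi-head attention. Shapes and hyperparameters are illustrative assumptions.
def _example_multi_query_attention():
    attn = MultiQueryAttention(dim=512, num_query_heads=8, num_kv_heads=1)
    x = jnp.ones((2, 16, 512), dtype=jnp.float32)  # [batch_size, seq_len, dim]
    params = attn.init(jax.random.PRNGKey(0), x)
    output, _, (k_cache, v_cache) = attn.apply(params, x, use_cache=True)
    # k_cache/v_cache: [2, 16, 1, 64], versus [2, 16, 8, 64] for MultiHeadAttention.
    return output, k_cache.shape, v_cache.shape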
def flash_attention(q, k, v, mask=None, dropout_rate=0.0, deterministic=True, causal=True, block_size=128, max_context_length=131072, dropout_rng=None):
"""
    Blocked (FlashAttention-style) attention computation optimized for TPU v4-32.
    Short sequences fall back to standard attention; longer sequences are processed in
    query blocks so that only block_size x key_len score tiles are materialized at once.
Args:
q: Query tensor [batch_size, num_heads, seq_len, head_dim]
k: Key tensor [batch_size, num_heads, seq_len, head_dim]
v: Value tensor [batch_size, num_heads, seq_len, head_dim]
mask: Attention mask [batch_size, 1, seq_len, seq_len]
dropout_rate: Dropout probability
deterministic: Whether to use deterministic operations (no dropout)
causal: Whether to use causal masking
        block_size: Block size for chunked attention computation
        max_context_length: Maximum context length the blocked path is tuned for (currently unused)
        dropout_rng: Optional PRNG key for attention dropout; required when dropout_rate > 0 and not deterministic
Returns:
Output tensor [batch_size, num_heads, seq_len, head_dim]
"""
    batch_size, num_heads, seq_len, head_dim = q.shape
    kv_seq_len = k.shape[2]  # May exceed seq_len when cached keys/values are prepended
    scale = 1.0 / math.sqrt(head_dim)
# Scaled dot-product
q = q * scale
    # Dynamically adjust block size for very long sequences
    if seq_len > 32768 and block_size < 2048:
        # For extremely long sequences (e.g. 128K context), larger blocks reduce the number of scan steps
        print(f"Adjusting block size to 2048 for sequence length {seq_len}")
        block_size = 2048
# For short sequences, use standard attention
if seq_len <= block_size:
# Compute attention scores
scores = jnp.matmul(q, jnp.swapaxes(k, -2, -1))
        # Apply causal mask if needed (aligned to the last query position when kv_seq_len > seq_len)
        if causal:
            causal_mask = jnp.triu(jnp.ones((seq_len, kv_seq_len), dtype=jnp.bool_), k=kv_seq_len - seq_len + 1)
            causal_mask = causal_mask[None, None, :, :]  # [1, 1, seq_len, kv_seq_len]
            scores = jnp.where(causal_mask, jnp.finfo(scores.dtype).min, scores)
# Apply attention mask if provided
if mask is not None:
scores = scores + mask
# Apply softmax
attention_weights = jax.nn.softmax(scores, axis=-1)
        # Apply dropout to the attention weights (an explicit RNG key is required in this functional path)
        if dropout_rate > 0.0 and not deterministic and dropout_rng is not None:
            keep_prob = 1.0 - dropout_rate
            keep = jax.random.bernoulli(dropout_rng, keep_prob, attention_weights.shape)
            attention_weights = jnp.where(keep, attention_weights / keep_prob, 0.0)
# Compute attention output
output = jnp.matmul(attention_weights, v)
return output
    # For long sequences, use blocked attention computation
    # This implementation is optimized for TPU by using blocks that fit in HBM
    # Note: the blocked path assumes queries and keys share the same sequence length
    # (i.e. no incremental-decoding cache)
    # Pad sequence length to a multiple of block_size for efficient blocking
    padded_seq_len = ((seq_len + block_size - 1) // block_size) * block_size
    pad_len = padded_seq_len - seq_len
    if pad_len > 0:
        # Pad inputs
        q_padded = jnp.pad(q, ((0, 0), (0, 0), (0, pad_len), (0, 0)))
        k_padded = jnp.pad(k, ((0, 0), (0, 0), (0, pad_len), (0, 0)))
        v_padded = jnp.pad(v, ((0, 0), (0, 0), (0, pad_len), (0, 0)))
        if mask is not None:
            # Mask out padded key positions (and padded query rows, whose outputs are discarded)
            mask = jnp.pad(
                mask,
                ((0, 0), (0, 0), (0, pad_len if mask.shape[-2] > 1 else 0), (0, pad_len)),
                constant_values=jnp.finfo(q.dtype).min,
            )
    else:
        q_padded, k_padded, v_padded = q, k, v
# Initialize output
output_padded = jnp.zeros((batch_size, num_heads, padded_seq_len, head_dim), dtype=q.dtype)
# Define a scan function for processing blocks
def block_scan_fn(carry, idx):
block_start = idx * block_size
block_end = block_start + block_size
q_block = jax.lax.dynamic_slice(
q_padded, (0, 0, block_start, 0),
(batch_size, num_heads, block_size, head_dim)
)
# Compute attention for this block
attn_weights = jnp.matmul(q_block, jnp.swapaxes(k_padded, -2, -1))
# Apply causal mask if needed
if causal:
# Create causal mask for this block
row_idx = jnp.arange(block_size) + block_start
col_idx = jnp.arange(padded_seq_len)
causal_mask = jnp.less(row_idx[:, None], col_idx[None, :])
causal_mask = jnp.logical_not(causal_mask)
causal_mask = jnp.expand_dims(jnp.expand_dims(causal_mask, 0), 0)
attn_weights = jnp.where(causal_mask, jnp.finfo(attn_weights.dtype).min, attn_weights)
# Apply attention mask if provided
if mask is not None:
# Slice the mask for this block
if mask.shape[-2] == 1: # Broadcast mask
mask_block = mask
else:
mask_block = jax.lax.dynamic_slice(
mask, (0, 0, block_start, 0),
(batch_size, 1, block_size, mask.shape[-1])
)
attn_weights = attn_weights + mask_block
# Apply softmax
attn_weights = jax.nn.softmax(attn_weights, axis=-1)
        # Apply dropout (fold the block index into the RNG so each block gets a distinct key)
        if dropout_rate > 0.0 and not deterministic and dropout_rng is not None:
            block_rng = jax.random.fold_in(dropout_rng, idx)
            keep_prob = 1.0 - dropout_rate
            keep = jax.random.bernoulli(block_rng, keep_prob, attn_weights.shape)
            attn_weights = jnp.where(keep, attn_weights / keep_prob, 0.0)
# Compute output for this block
block_output = jnp.matmul(attn_weights, v_padded)
# Update output
output_padded_updated = jax.lax.dynamic_update_slice(
carry, block_output, (0, 0, block_start, 0)
)
return output_padded_updated, None
# Process blocks
num_blocks = padded_seq_len // block_size
output_padded, _ = jax.lax.scan(
block_scan_fn, output_padded, jnp.arange(num_blocks)
)
# Slice to get original sequence length
output = jax.lax.dynamic_slice(
output_padded, (0, 0, 0, 0),
(batch_size, num_heads, seq_len, head_dim)
)
return output
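# Minimal sketch of calling flash_attention directly on head-major tensors. The shapes are
# illustrative assumptions; the sequence fits in a single block, so the non-blocked path runs.
def _example_flash_attention():
    rng = jax.random.PRNGKey(0)
    q = jax.random.normal(rng, (2, 8, 64, 32))  # [batch_size, num_heads, seq_len, head_dim]
    k = jax.random.normal(rng, (2, 8, 64, 32))
    v = jax.random.normal(rng, (2, 8, 64, 32))
    out = flash_attention(q, k, v, causal=True, block_size=128)
    return out  # [2, 8, 64, 32]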
class FlashAttention(nn.Module):
"""
Optimized Flash Attention implementation for TPU v4-32 with support for very long sequences.
Attributes:
dim: Hidden dimension
num_heads: Number of attention heads
head_dim: Dimension of each attention head
dropout_rate: Dropout probability
dtype: Data type for computations
use_causal_mask: Whether to use causal masking
block_size: Block size for chunked attention computation
use_fused_attention: Whether to use fused attention operations
"""
dim: int
num_heads: int
head_dim: Optional[int] = None
dropout_rate: float = 0.0
dtype: jnp.dtype = jnp.float32
use_causal_mask: bool = True
block_size: int = 128 # Optimal block size for TPU v4-32
use_fused_attention: bool = True # Use fused attention operations when available
def setup(self):
# Determine head dimension if not provided
self.actual_head_dim = self.head_dim or self.dim // self.num_heads
        # Warn if the head dimension is not a multiple of 8 (suboptimal for TPU)
if self.actual_head_dim % 8 != 0:
print(f"Warning: Head dimension {self.actual_head_dim} is not a multiple of 8. "
f"This may reduce TPU efficiency.")
# Projection matrices with optimized initialization for stability
self.q_proj = nn.Dense(
features=self.num_heads * self.actual_head_dim,
dtype=self.dtype,
kernel_init=nn.initializers.variance_scaling(
scale=1.0, mode='fan_in', distribution='normal'
),
name="q_proj"
)
self.k_proj = nn.Dense(
features=self.num_heads * self.actual_head_dim,
dtype=self.dtype,
kernel_init=nn.initializers.variance_scaling(
scale=1.0, mode='fan_in', distribution='normal'
),
name="k_proj"
)
self.v_proj = nn.Dense(
features=self.num_heads * self.actual_head_dim,
dtype=self.dtype,
kernel_init=nn.initializers.variance_scaling(
scale=1.0, mode='fan_in', distribution='normal'
),
name="v_proj"
)
self.out_proj = nn.Dense(
features=self.dim,
dtype=self.dtype,
kernel_init=nn.initializers.variance_scaling(
scale=1.0, mode='fan_out', distribution='normal'
),
name="out_proj"
)
self.dropout = nn.Dropout(rate=self.dropout_rate)
def __call__(
self,
hidden_states: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
position_ids: Optional[jnp.ndarray] = None, # Unused but kept for API compatibility
past_key_value: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None,
output_attentions: bool = False,
use_cache: bool = False,
deterministic: bool = True,
) -> Tuple[jnp.ndarray, ...]:
"""
Apply optimized flash attention for TPU v4-32.
Args:
hidden_states: Input tensor [batch_size, seq_len, dim]
attention_mask: Attention mask [batch_size, 1, seq_len, seq_len]
position_ids: Position indices [batch_size, seq_len] (unused but kept for API compatibility)
past_key_value: Cached key and value tensors for incremental decoding
output_attentions: Whether to return attention weights
use_cache: Whether to use cached key and values
deterministic: Whether to use deterministic operations (no dropout)
Returns:
Tuple of (output, attention_weights, present_key_value)
"""
batch_size, seq_len, _ = hidden_states.shape
        # For very long sequences, use a larger block size to reduce the number of scan steps
        if seq_len > 32768 and self.block_size < 256:
            block_size = 512
            print(f"Adjusting block size to {block_size} for sequence length {seq_len}")
        else:
            block_size = self.block_size
        # Project inputs to queries, keys, and values.
        # Note: the projections are intentionally not wrapped in jax.jit here; jitting a
        # closure over a Flax submodule inside __call__ captures the parameters as constants
        # and breaks initialization. Jit the full model apply function instead.
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
# Reshape to [batch_size, seq_len, num_heads, head_dim] with optimized memory layout
# This reshaping is optimized for TPU memory access patterns
q = q.reshape(batch_size, seq_len, self.num_heads, self.actual_head_dim)
k = k.reshape(batch_size, seq_len, self.num_heads, self.actual_head_dim)
v = v.reshape(batch_size, seq_len, self.num_heads, self.actual_head_dim)
# Handle cached key and values for incremental decoding
key_seq_len = seq_len
if past_key_value is not None and use_cache:
past_k, past_v = past_key_value
k = jnp.concatenate([past_k, k], axis=1)
v = jnp.concatenate([past_v, v], axis=1)
key_seq_len = k.shape[1] # Update key sequence length
# Save key and value for future use if caching
present_key_value = (k, v) if use_cache else None
# Transpose to [batch_size, num_heads, seq_len, head_dim]
# This transpose is optimized for TPU memory access patterns
q = jnp.transpose(q, (0, 2, 1, 3))
k = jnp.transpose(k, (0, 2, 1, 3))
v = jnp.transpose(v, (0, 2, 1, 3))
        # Try JAX's built-in fused attention when it is available in the installed JAX version
        use_jax_attention = self.use_fused_attention and hasattr(jax.nn, "dot_product_attention")
        if (
            use_jax_attention
            and not output_attentions
            and past_key_value is None
            and seq_len <= 4096
            and (deterministic or self.dropout_rate == 0.0)
        ):
            try:
                # jax.nn.dot_product_attention expects [batch, seq_len, num_heads, head_dim],
                # so swap the head and sequence axes around the call
                attention_output = jax.nn.dot_product_attention(
                    jnp.swapaxes(q, 1, 2),
                    jnp.swapaxes(k, 1, 2),
                    jnp.swapaxes(v, 1, 2),
                    bias=attention_mask,
                    is_causal=self.use_causal_mask,
                )
                attention_output = jnp.swapaxes(attention_output, 1, 2)
            except (AttributeError, TypeError) as e:
                # Fall back to the custom implementation if the fused attention call fails
                print(f"Warning: JAX fused attention failed, falling back to custom implementation: {e}")
                use_jax_attention = False
        else:
            use_jax_attention = False
        if not use_jax_attention:
            # Apply the custom blocked flash attention implementation
            dropout_rng = (
                self.make_rng("dropout")
                if (self.dropout_rate > 0.0 and not deterministic)
                else None
            )
            attention_output = flash_attention(
                q=q,
                k=k,
                v=v,
                mask=attention_mask,
                dropout_rate=self.dropout_rate,
                deterministic=deterministic,
                causal=self.use_causal_mask,
                block_size=block_size,
                dropout_rng=dropout_rng,
            )
# Transpose and reshape to [batch_size, seq_len, dim] with optimized memory layout
attention_output = jnp.transpose(attention_output, (0, 2, 1, 3))
attention_output = attention_output.reshape(batch_size, seq_len, self.num_heads * self.actual_head_dim)
# Project to output dimension with optimized memory layout
output = self.out_proj(attention_output)
# For compatibility with other attention implementations
attention_weights = None
        if output_attentions:
            # Recompute attention weights for inspection
            # Note: this is not memory-efficient and should only be used for debugging
            attention_scores = jnp.matmul(q, jnp.transpose(k, (0, 1, 3, 2))) / math.sqrt(self.actual_head_dim)
            if self.use_causal_mask:
                q_len, kv_len = q.shape[2], k.shape[2]
                causal_mask = jnp.triu(jnp.ones((q_len, kv_len), dtype=jnp.bool_), k=kv_len - q_len + 1)
                attention_scores = jnp.where(causal_mask[None, None, :, :], jnp.finfo(attention_scores.dtype).min, attention_scores)
            if attention_mask is not None:
                attention_scores = attention_scores + attention_mask
            attention_weights = jax.nn.softmax(attention_scores, axis=-1)
outputs = (output, attention_weights, present_key_value) if output_attentions else (output, None, present_key_value)
return outputs
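# Minimal usage sketch for the FlashAttention module, including the 'dropout' RNG stream
# required when dropout is active. Hyperparameters and shapes are illustrative assumptions.
def _example_flash_attention_module():
    attn = FlashAttention(dim=512, num_heads=8, dropout_rate=0.1, use_causal_mask=True)
    x = jnp.ones((2, 128, 512), dtype=jnp.float32)  # [batch_size, seq_len, dim]
    params = attn.init(jax.random.PRNGKey(0), x)
    # Training-style call: dropout is enabled, so a 'dropout' RNG must be supplied
    output, _, _ = attn.apply(
        params, x, deterministic=False, rngs={"dropout": jax.random.PRNGKey(1)}
    )
    return output  # [2, 128, 512]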
class RotaryMultiQueryAttention(nn.Module):
"""
Multi-Query Attention with Rotary Position Embeddings (RoPE).
Attributes:
dim: Hidden dimension
num_query_heads: Number of query heads
num_kv_heads: Number of key-value heads
head_dim: Dimension of each attention head
max_seq_len: Maximum sequence length for RoPE
rope_base: Base for RoPE frequency computation
dropout_rate: Dropout probability
dtype: Data type for computations
"""
dim: int
num_query_heads: int
num_kv_heads: int = 1
head_dim: Optional[int] = None
max_seq_len: int = 4096
rope_base: int = 10000
dropout_rate: float = 0.0
dtype: jnp.dtype = jnp.float32
def setup(self):
# Determine head dimension if not provided
self.actual_head_dim = self.head_dim or self.dim // self.num_query_heads
# Projection matrices
self.q_proj = nn.Dense(
features=self.num_query_heads * self.actual_head_dim,
dtype=self.dtype,
kernel_init=nn.initializers.normal(stddev=0.02),
name="q_proj"
)
self.k_proj = nn.Dense(
features=self.num_kv_heads * self.actual_head_dim,
dtype=self.dtype,
kernel_init=nn.initializers.normal(stddev=0.02),
name="k_proj"
)
self.v_proj = nn.Dense(
features=self.num_kv_heads * self.actual_head_dim,
dtype=self.dtype,
kernel_init=nn.initializers.normal(stddev=0.02),
name="v_proj"
)
self.out_proj = nn.Dense(
features=self.dim,
dtype=self.dtype,
kernel_init=nn.initializers.normal(stddev=0.02),
name="out_proj"
)
self.dropout = nn.Dropout(rate=self.dropout_rate)
# Rotary position embeddings
self.rotary_emb = RotaryPositionalEmbedding(
dim=self.actual_head_dim,
max_seq_len=self.max_seq_len,
base=self.rope_base,
dtype=self.dtype
)
def __call__(
self,
hidden_states: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
position_ids: Optional[jnp.ndarray] = None,
past_key_value: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None,
output_attentions: bool = False,
use_cache: bool = False,
deterministic: bool = True,
) -> Tuple[jnp.ndarray, ...]:
"""
Apply rotary multi-query attention.
Args:
hidden_states: Input tensor [batch_size, seq_len, dim]
attention_mask: Attention mask [batch_size, 1, seq_len, seq_len]
position_ids: Position indices [batch_size, seq_len]
past_key_value: Cached key and value tensors for incremental decoding
output_attentions: Whether to return attention weights
use_cache: Whether to use cached key and values
deterministic: Whether to use deterministic operations (no dropout)
Returns:
Tuple of (output, attention_weights, present_key_value)
"""
batch_size, seq_len, _ = hidden_states.shape
# Project inputs to queries, keys, and values
q = self.q_proj(hidden_states) # [batch_size, seq_len, num_query_heads * head_dim]
k = self.k_proj(hidden_states) # [batch_size, seq_len, num_kv_heads * head_dim]
v = self.v_proj(hidden_states) # [batch_size, seq_len, num_kv_heads * head_dim]
# Reshape
q = q.reshape(batch_size, seq_len, self.num_query_heads, self.actual_head_dim)
k = k.reshape(batch_size, seq_len, self.num_kv_heads, self.actual_head_dim)
v = v.reshape(batch_size, seq_len, self.num_kv_heads, self.actual_head_dim)
        # Apply rotary position embeddings.
        # When decoding with a cache, default positions must continue from the cached length.
        if position_ids is None:
            past_length = past_key_value[0].shape[1] if past_key_value is not None else 0
            position_ids = jnp.arange(past_length, past_length + seq_len)[None, :]
# Apply rotary embeddings to q and k
q = self.rotary_emb(q, position_ids)
k = self.rotary_emb(k, position_ids)
# Handle cached key and values for incremental decoding
if past_key_value is not None and use_cache:
past_k, past_v = past_key_value
k = jnp.concatenate([past_k, k], axis=1)
v = jnp.concatenate([past_v, v], axis=1)
# Save key and value for future use if caching
present_key_value = (k, v) if use_cache else None
# Transpose to [batch_size, num_*_heads, seq_len, head_dim]
q = jnp.transpose(q, (0, 2, 1, 3))
k = jnp.transpose(k, (0, 2, 1, 3))
v = jnp.transpose(v, (0, 2, 1, 3))
# Repeat k and v for each query head
if self.num_kv_heads < self.num_query_heads:
# Calculate how many times to repeat
repeats = self.num_query_heads // self.num_kv_heads
# Repeat k and v along the head dimension
k = jnp.repeat(k, repeats, axis=1)
v = jnp.repeat(v, repeats, axis=1)
# Compute attention scores
# [batch_size, num_query_heads, seq_len, seq_len]
attention_scores = jnp.matmul(q, jnp.transpose(k, (0, 1, 3, 2))) / math.sqrt(self.actual_head_dim)
# Apply attention mask if provided
if attention_mask is not None:
attention_scores = attention_scores + attention_mask
# Apply softmax to get attention weights
attention_weights = jax.nn.softmax(attention_scores, axis=-1)
# Apply dropout to attention weights
attention_weights = self.dropout(attention_weights, deterministic=deterministic)
# Compute attention output
# [batch_size, num_query_heads, seq_len, head_dim]
attention_output = jnp.matmul(attention_weights, v)
# Transpose and reshape to [batch_size, seq_len, dim]
attention_output = jnp.transpose(attention_output, (0, 2, 1, 3))
attention_output = attention_output.reshape(batch_size, seq_len, self.num_query_heads * self.actual_head_dim)
# Project to output dimension
output = self.out_proj(attention_output)
outputs = (output, attention_weights, present_key_value) if output_attentions else (output, None, present_key_value)
return outputs
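# Minimal usage sketch for RotaryMultiQueryAttention with cached decoding. Hyperparameters,
# shapes, and the two-phase prefill/decode flow are illustrative assumptions; it relies on
# RotaryPositionalEmbedding accepting (x, position_ids) as it is called above.
def _example_rotary_multi_query_attention():
    attn = RotaryMultiQueryAttention(dim=512, num_query_heads=8, num_kv_heads=2, max_seq_len=1024)
    x = jnp.ones((2, 16, 512), dtype=jnp.float32)  # [batch_size, seq_len, dim]
    params = attn.init(jax.random.PRNGKey(0), x)
    # Prefill: cache the rotated keys/values for the first 16 positions
    _, _, cache = attn.apply(params, x, use_cache=True)
    # Decode one token at absolute position 16; positions continue from the cache length
    x_new = jnp.ones((2, 1, 512), dtype=jnp.float32)
    position_ids = jnp.array([[16]])  # broadcast over the batch, mirroring the default shape
    output, _, _ = attn.apply(params, x_new, position_ids=position_ids,
                              past_key_value=cache, use_cache=True)
    return output  # [2, 1, 512]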