from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.models.activations import get_activation
from diffusers.models.autoencoders.autoencoder_dc import EfficientViTBlock
from diffusers.models.normalization import RMSNorm, get_normalization


class ResBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm_type: str = "batch_norm",
        act_fn: Optional[str] = "relu6",
        rank_ratio: Optional[int] = 16,
    ) -> None:
        super().__init__()

        self.norm_type = norm_type
        self.nonlinearity = get_activation(act_fn) if act_fn is not None else nn.Identity()

        if rank_ratio is None:
            self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
            self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False)
        else:
            # Low-rank factorization: each 3x3 conv becomes a bottleneck pair whose
            # inner width is `channels // rank_ratio`.
            assert isinstance(rank_ratio, int)
            self.conv1 = nn.Sequential(
                nn.Conv2d(in_channels, in_channels // rank_ratio, 3, 1, 1, bias=False),
                nn.Conv2d(in_channels // rank_ratio, in_channels, 3, 1, 1),
            )
            self.conv2 = nn.Sequential(
                nn.Conv2d(in_channels, out_channels // rank_ratio, 3, 1, 1, bias=False),
                nn.Conv2d(out_channels // rank_ratio, out_channels, 3, 1, 1, bias=False),
            )

        self.norm = get_normalization(norm_type, out_channels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.norm_type == "rms_norm":
            # Move channels to the last dimension so RMSNorm is applied across the channel dimension.
            hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
        else:
            hidden_states = self.norm(hidden_states)

        return hidden_states + residual


def get_block(
    block_type: str,
    in_channels: int,
    out_channels: int,
    attention_head_dim: int,
    norm_type: str,
    act_fn: str,
    qkv_multiscales: Tuple[int, ...] = (),
) -> nn.Module:
    if block_type == "ResBlock":
        block = ResBlock(in_channels, out_channels, norm_type, act_fn)
    elif block_type == "EfficientViTBlock":
        block = EfficientViTBlock(
            in_channels, attention_head_dim=attention_head_dim, norm_type=norm_type, qkv_multiscales=qkv_multiscales
        )
    else:
        raise ValueError(f"Block with {block_type=} is not supported.")

    return block


class DCDownBlock2d(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, downsample: bool = False, shortcut: bool = True) -> None:
        super().__init__()

        self.downsample = downsample
        self.factor = 2
        # With pixel_unshuffle downsampling the conv keeps stride 1; otherwise a strided conv downsamples.
        self.stride = 1 if downsample else 2
        self.group_size = in_channels * self.factor**2 // out_channels
        self.shortcut = shortcut

        out_ratio = self.factor**2
        if downsample:
            assert out_channels % out_ratio == 0
            out_channels = out_channels // out_ratio

        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=self.stride,
            padding=1,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        x = self.conv(hidden_states)
        if self.downsample:
            x = F.pixel_unshuffle(x, self.factor)

        if self.shortcut:
            # Parameter-free shortcut: pixel_unshuffle, then average channel groups down to out_channels.
            y = F.pixel_unshuffle(hidden_states, self.factor)
            y = y.unflatten(1, (-1, self.group_size))
            y = y.mean(dim=2)
            hidden_states = x + y
        else:
            hidden_states = x

        return hidden_states
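

# Hypothetical sanity-check helper (added for illustration, not part of the original
# module): a minimal sketch of what the DCDownBlock2d shortcut computes. The main path
# is a 3x3 conv (plus pixel_unshuffle when downsampling); the shortcut is parameter-free,
# expanding channels by factor**2 via pixel_unshuffle and averaging groups of
# `group_size` channels back down to `out_channels`.
def _demo_downsample_shortcut() -> None:
    block = DCDownBlock2d(in_channels=128, out_channels=256, downsample=True, shortcut=True)
    x = torch.randn(1, 128, 32, 32)
    # 128 * 2**2 = 512 unshuffled channels, averaged in groups of 512 // 256 = 2.
    assert block(x).shape == (1, 256, 16, 16)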


class DCUpBlock2d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        interpolate: bool = False,
        shortcut: bool = True,
        interpolation_mode: str = "nearest",
    ) -> None:
        super().__init__()

        self.interpolate = interpolate
        self.interpolation_mode = interpolation_mode
        self.shortcut = shortcut
        self.factor = 2
        self.repeats = out_channels * self.factor**2 // in_channels

        out_ratio = self.factor**2
        if not interpolate:
            # pixel_shuffle trades channels for resolution, so the conv must produce factor**2 times more channels.
            out_channels = out_channels * out_ratio

        self.conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.interpolate:
            x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode)
            x = self.conv(x)
        else:
            x = self.conv(hidden_states)
            x = F.pixel_shuffle(x, self.factor)

        if self.shortcut:
            # Parameter-free shortcut: duplicate channels, then pixel_shuffle them into space.
            y = hidden_states.repeat_interleave(
                self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats
            )
            y = F.pixel_shuffle(y, self.factor)
            hidden_states = x + y
        else:
            hidden_states = x

        return hidden_states


class DCAEEncoder(nn.Module):
    def __init__(
        self,
        in_channels: int,
        latent_channels: int,
        attention_head_dim: int = 32,
        block_type: Union[str, Tuple[str, ...]] = "ResBlock",
        block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
        layers_per_block: Tuple[int, ...] = (2, 2, 2, 2, 2, 2),
        qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
        downsample_block_type: str = "pixel_unshuffle",
        out_shortcut: bool = True,
    ):
        super().__init__()

        num_blocks = len(block_out_channels)

        if isinstance(block_type, str):
            block_type = (block_type,) * num_blocks

        if layers_per_block[0] > 0:
            self.conv_in = nn.Conv2d(
                in_channels,
                block_out_channels[0],
                kernel_size=3,
                stride=1,
                padding=1,
            )
        else:
            # An empty first stage means the stem itself downsamples.
            self.conv_in = DCDownBlock2d(
                in_channels=in_channels,
                out_channels=block_out_channels[1],
                downsample=downsample_block_type == "pixel_unshuffle",
                shortcut=False,
            )

        down_blocks = []
        for i, (out_channel, num_layers) in enumerate(zip(block_out_channels, layers_per_block)):
            down_block_list = []

            for _ in range(num_layers):
                block = get_block(
                    block_type[i],
                    out_channel,
                    out_channel,
                    attention_head_dim=attention_head_dim,
                    norm_type="rms_norm",
                    act_fn="silu",
                    qkv_multiscales=qkv_multiscales[i],
                )
                down_block_list.append(block)

            if i < num_blocks - 1 and num_layers > 0:
                downsample_block = DCDownBlock2d(
                    in_channels=out_channel,
                    out_channels=block_out_channels[i + 1],
                    downsample=downsample_block_type == "pixel_unshuffle",
                    shortcut=True,
                )
                down_block_list.append(downsample_block)

            down_blocks.append(nn.Sequential(*down_block_list))

        self.down_blocks = nn.ModuleList(down_blocks)

        self.conv_out = nn.Conv2d(block_out_channels[-1], latent_channels, 3, 1, 1)

        self.out_shortcut = out_shortcut
        if out_shortcut:
            self.out_shortcut_average_group_size = block_out_channels[-1] // latent_channels

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.conv_in(hidden_states)
        for down_block in self.down_blocks:
            hidden_states = down_block(hidden_states)

        if self.out_shortcut:
            # Average channel groups so the shortcut matches the latent width.
            x = hidden_states.unflatten(1, (-1, self.out_shortcut_average_group_size))
            x = x.mean(dim=2)
            hidden_states = self.conv_out(hidden_states) + x
        else:
            hidden_states = self.conv_out(hidden_states)

        return hidden_states
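

# Hypothetical shape-arithmetic check (added for illustration, not part of the original
# module): one DCDownBlock2d sits between each pair of consecutive stages, so the spatial
# reduction is 2 ** (num_stages - 1) when layers_per_block[0] > 0; an empty first stage
# would add one more factor of 2 via the downsampling stem.
def _demo_encoder_downsampling() -> None:
    encoder = DCAEEncoder(in_channels=3, latent_channels=32)
    x = torch.randn(1, 3, 64, 64)
    # 6 default stages -> 5 downsamples -> 64 / 2**5 = 2.
    assert encoder(x).shape == (1, 32, 2, 2)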


class DCAEDecoder(nn.Module):
    def __init__(
        self,
        in_channels: int,
        latent_channels: int,
        attention_head_dim: int = 32,
        block_type: Union[str, Tuple[str, ...]] = "ResBlock",
        block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
        layers_per_block: Tuple[int, ...] = (2, 2, 2, 2, 2, 2),
        qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
        norm_type: Union[str, Tuple[str, ...]] = "rms_norm",
        act_fn: Union[str, Tuple[str, ...]] = "silu",
        upsample_block_type: str = "pixel_shuffle",
        in_shortcut: bool = True,
    ):
        super().__init__()

        num_blocks = len(block_out_channels)

        if isinstance(block_type, str):
            block_type = (block_type,) * num_blocks
        if isinstance(norm_type, str):
            norm_type = (norm_type,) * num_blocks
        if isinstance(act_fn, str):
            act_fn = (act_fn,) * num_blocks

        self.conv_in = nn.Conv2d(latent_channels, block_out_channels[-1], 3, 1, 1)

        self.in_shortcut = in_shortcut
        if in_shortcut:
            self.in_shortcut_repeats = block_out_channels[-1] // latent_channels

        up_blocks = []
        for i, (out_channel, num_layers) in reversed(list(enumerate(zip(block_out_channels, layers_per_block)))):
            up_block_list = []

            if i < num_blocks - 1 and num_layers > 0:
                upsample_block = DCUpBlock2d(
                    block_out_channels[i + 1],
                    out_channel,
                    interpolate=upsample_block_type == "interpolate",
                    shortcut=True,
                )
                up_block_list.append(upsample_block)

            for _ in range(num_layers):
                block = get_block(
                    block_type[i],
                    out_channel,
                    out_channel,
                    attention_head_dim=attention_head_dim,
                    norm_type=norm_type[i],
                    act_fn=act_fn[i],
                    qkv_multiscales=qkv_multiscales[i],
                )
                up_block_list.append(block)

            up_blocks.insert(0, nn.Sequential(*up_block_list))

        self.up_blocks = nn.ModuleList(up_blocks)

        channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1]

        self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True)
        self.conv_act = nn.ReLU()
        if layers_per_block[0] > 0:
            self.conv_out = nn.Conv2d(channels, in_channels, 3, 1, 1)
        else:
            # Mirror the encoder stem: an empty first stage means the head itself upsamples.
            self.conv_out = DCUpBlock2d(
                channels, in_channels, interpolate=upsample_block_type == "interpolate", shortcut=False
            )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.in_shortcut:
            # Duplicate latent channels so the shortcut matches the widest feature map.
            x = hidden_states.repeat_interleave(
                self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats
            )
            hidden_states = self.conv_in(hidden_states) + x
        else:
            hidden_states = self.conv_in(hidden_states)

        # up_blocks is stored shallow-to-deep, so decode from the deepest stage first.
        for up_block in reversed(self.up_blocks):
            hidden_states = up_block(hidden_states)

        hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states


if __name__ == "__main__":
    device = torch.device("cuda:0")

    # case 1: from 32 x 8 x 8 to 3 x 256 x 256
    data = torch.randn(1, 32, 8, 8).to(device)
    model = DCAEDecoder(
        in_channels=3,
        latent_channels=32,
        attention_head_dim=32,
        block_type=["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock"],
        block_out_channels=[128, 256, 512, 512, 1024, 1024],
        layers_per_block=[3, 3, 3, 3, 3, 3],
        qkv_multiscales=[[], [], [], [5], [5], [5]],
        norm_type="rms_norm",
        act_fn="silu",
        upsample_block_type="interpolate",
    ).to(device)
    output = model(data)
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of parameters: {num_params / 1e6:.2f}M")
    print(f"Output shape: {output.shape}")
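
    # Hedged round-trip check (added for illustration): the encoder config below is an
    # assumption that mirrors case 1, not taken from the original script.
    encoder = DCAEEncoder(
        in_channels=3,
        latent_channels=32,
        attention_head_dim=32,
        block_type=["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock"],
        block_out_channels=[128, 256, 512, 512, 1024, 1024],
        layers_per_block=[3, 3, 3, 3, 3, 3],
        qkv_multiscales=[[], [], [], [5], [5], [5]],
    ).to(device)
    latent = encoder(output)
    print(f"Round-trip latent shape: {latent.shape}")  # expected: 1 x 32 x 8 x 8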

    # case 2: from 2 x 2 x 2 to 32 x 8 x 8
    data = torch.randn(1, 2, 2, 2).to(device)
    model = DCAEDecoder(
        in_channels=32,
        latent_channels=2,
        attention_head_dim=32,
        block_type=["EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock"],
        block_out_channels=[256, 512, 512],
        layers_per_block=[6, 6, 6],
        qkv_multiscales=[[5], [5], [5]],
        norm_type="rms_norm",
        act_fn="silu",
        upsample_block_type="interpolate",
    ).to(device)
    output = model(data)
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of parameters: {num_params / 1e6:.2f}M")
    print(f"Output shape: {output.shape}")
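
    # Exercise the hypothetical sanity-check helpers defined above (CPU-only, tiny inputs;
    # added for illustration, not part of the original script).
    _demo_downsample_shortcut()
    _demo_encoder_downsampling()
    print("Sanity checks passed.")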