import torch
from torch import nn
from torch.nn import functional as F
from functools import partial
import numpy as np
import typing as tp

from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes
from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config
from .dit import DiffusionTransformer
from .factory import create_pretransform_from_config
from .pretransforms import Pretransform
from ..inference.generation import generate_diffusion_cond

from .adp import UNetCFG1d, UNet1d

from time import time
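
# Minimal wall-clock profiler: record named ticks during a forward pass and
# print the elapsed time between consecutive ticks when the object is repr()'d.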
class Profiler:

    def __init__(self):
        self.ticks = [[time(), None]]

    def tick(self, msg):
        self.ticks.append([time(), msg])

    def __repr__(self):
        rep = 80 * "=" + "\n"
        for i in range(1, len(self.ticks)):
            msg = self.ticks[i][1]
            elapsed = self.ticks[i][0] - self.ticks[i - 1][0]
            rep += msg + f": {elapsed*1000:.2f}ms\n"
        rep += 80 * "=" + "\n\n\n"
        return rep

class DiffusionModel(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x, t, **kwargs):
        raise NotImplementedError()
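
# Bundles an unconditional DiffusionModel with the sample-format metadata
# (I/O channels, sample size/rate, minimum input length) and an optional
# pretransform so inference code has everything it needs in one object.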
class DiffusionModelWrapper(nn.Module):
    def __init__(
        self,
        model: DiffusionModel,
        io_channels,
        sample_size,
        sample_rate,
        min_input_length,
        pretransform: tp.Optional[Pretransform] = None,
    ):
        super().__init__()
        self.io_channels = io_channels
        self.sample_size = sample_size
        self.sample_rate = sample_rate
        self.min_input_length = min_input_length

        self.model = model
        self.pretransform = pretransform

    def forward(self, x, t, **kwargs):
        return self.model(x, t, **kwargs)
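
# Abstract base class for diffusion models that accept conditioning. The
# supports_* flags advertise which conditioning pathways (cross-attention,
# input concatenation, global embedding, prepended tokens) a subclass handles.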
class ConditionedDiffusionModel(nn.Module):
    def __init__(self,
                 *args,
                 supports_cross_attention: bool = False,
                 supports_input_concat: bool = False,
                 supports_global_cond: bool = False,
                 supports_prepend_cond: bool = False,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.supports_cross_attention = supports_cross_attention
        self.supports_input_concat = supports_input_concat
        self.supports_global_cond = supports_global_cond
        self.supports_prepend_cond = supports_prepend_cond

    def forward(self,
                x: torch.Tensor,
                t: torch.Tensor,
                cross_attn_cond: torch.Tensor = None,
                cross_attn_mask: torch.Tensor = None,
                input_concat_cond: torch.Tensor = None,
                global_embed: torch.Tensor = None,
                prepend_cond: torch.Tensor = None,
                prepend_cond_mask: torch.Tensor = None,
                cfg_scale: float = 1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = False,
                rescale_cfg: bool = False,
                **kwargs):
        raise NotImplementedError()

class ConditionedDiffusionModelWrapper(nn.Module):
    """
    A diffusion model that takes in conditioning
    """
    def __init__(
        self,
        model: ConditionedDiffusionModel,
        conditioner: MultiConditioner,
        io_channels,
        sample_rate,
        min_input_length: int,
        diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
        pretransform: tp.Optional[Pretransform] = None,
        cross_attn_cond_ids: tp.List[str] = [],
        global_cond_ids: tp.List[str] = [],
        input_concat_ids: tp.List[str] = [],
        prepend_cond_ids: tp.List[str] = [],
    ):
        super().__init__()

        self.model = model
        self.conditioner = conditioner
        self.io_channels = io_channels
        self.sample_rate = sample_rate
        self.diffusion_objective = diffusion_objective
        self.pretransform = pretransform
        self.cross_attn_cond_ids = cross_attn_cond_ids
        self.global_cond_ids = global_cond_ids
        self.input_concat_ids = input_concat_ids
        self.prepend_cond_ids = prepend_cond_ids
        self.min_input_length = min_input_length
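
    # Translate the conditioner's {name: tensor(s)} output into the keyword
    # arguments the inner model expects, concatenating multiple sources per
    # pathway. With negative=True, the same tensors are returned under the
    # negative_* keys used for negative prompting.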
    def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False):
        cross_attention_input = None
        cross_attention_masks = None
        global_cond = None
        input_concat_cond = None
        prepend_cond = None
        prepend_cond_mask = None

        if len(self.cross_attn_cond_ids) > 0:
            # Concatenate all cross-attention inputs over the sequence dimension
            # Assumes the cross-attention inputs are of shape (batch, seq, channels)
            cross_attention_input = []
            cross_attention_masks = []

            for key in self.cross_attn_cond_ids:
                cross_attn_in, cross_attn_mask = conditioning_tensors[key]

                # Add a sequence dimension if it's not there
                if len(cross_attn_in.shape) == 2:
                    cross_attn_in = cross_attn_in.unsqueeze(1)
                    cross_attn_mask = cross_attn_mask.unsqueeze(1)

                cross_attention_input.append(cross_attn_in)
                cross_attention_masks.append(cross_attn_mask)

            cross_attention_input = torch.cat(cross_attention_input, dim=1)
            cross_attention_masks = torch.cat(cross_attention_masks, dim=1)

        if len(self.global_cond_ids) > 0:
            # Concatenate all global conditioning inputs over the channel dimension
            # Assumes the global conditioning inputs are of shape (batch, channels)
            global_conds = []
            for key in self.global_cond_ids:
                global_cond_input = conditioning_tensors[key][0]

                global_conds.append(global_cond_input)

            # Concatenate over the channel dimension
            global_cond = torch.cat(global_conds, dim=-1)

            if len(global_cond.shape) == 3:
                global_cond = global_cond.squeeze(1)

        if len(self.input_concat_ids) > 0:
            # Concatenate all input-concat conditioning inputs over the channel dimension
            # Assumes the input-concat conditioning inputs are of shape (batch, channels, seq)
            input_concat_cond = torch.cat([conditioning_tensors[key][0] for key in self.input_concat_ids], dim=1)

        if len(self.prepend_cond_ids) > 0:
            # Concatenate all prepend conditioning inputs over the sequence dimension
            # Assumes the prepend conditioning inputs are of shape (batch, seq, channels)
            prepend_conds = []
            prepend_cond_masks = []

            for key in self.prepend_cond_ids:
                prepend_cond_input, prepend_cond_mask = conditioning_tensors[key]
                prepend_conds.append(prepend_cond_input)
                prepend_cond_masks.append(prepend_cond_mask)

            prepend_cond = torch.cat(prepend_conds, dim=1)
            prepend_cond_mask = torch.cat(prepend_cond_masks, dim=1)

        if negative:
            return {
                "negative_cross_attn_cond": cross_attention_input,
                "negative_cross_attn_mask": cross_attention_masks,
                "negative_global_cond": global_cond,
                "negative_input_concat_cond": input_concat_cond
            }
        else:
            return {
                "cross_attn_cond": cross_attention_input,
                "cross_attn_mask": cross_attention_masks,
                "global_cond": global_cond,
                "input_concat_cond": input_concat_cond,
                "prepend_cond": prepend_cond,
                "prepend_cond_mask": prepend_cond_mask
            }

    def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
        return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs)

    def generate(self, *args, **kwargs):
        return generate_diffusion_cond(self, *args, **kwargs)
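
# adp UNetCFG1d adapted to the ConditionedDiffusionModel interface.
# Classifier-free guidance is handled inside the adp model itself via
# embedding_scale and embedding_mask_proba.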
class UNetCFG1DWrapper(ConditionedDiffusionModel):
    def __init__(
        self,
        *args,
        **kwargs
    ):
        super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True)

        self.model = UNetCFG1d(*args, **kwargs)

        # Halve all initial parameter values (init scaling used throughout this file)
        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(self,
                x,
                t,
                cross_attn_cond=None,
                cross_attn_mask=None,
                input_concat_cond=None,
                global_cond=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = False,
                rescale_cfg: bool = False,
                negative_cross_attn_cond=None,
                negative_cross_attn_mask=None,
                negative_global_cond=None,
                negative_input_concat_cond=None,
                prepend_cond=None,
                prepend_cond_mask=None,
                **kwargs):
        p = Profiler()

        p.tick("start")

        channels_list = None
        if input_concat_cond is not None:
            channels_list = [input_concat_cond]

        outputs = self.model(
            x,
            t,
            embedding=cross_attn_cond,
            embedding_mask=cross_attn_mask,
            features=global_cond,
            channels_list=channels_list,
            embedding_scale=cfg_scale,
            embedding_mask_proba=cfg_dropout_prob,
            batch_cfg=batch_cfg,
            rescale_cfg=rescale_cfg,
            negative_embedding=negative_cross_attn_cond,
            negative_embedding_mask=negative_cross_attn_mask,
            **kwargs)

        p.tick("UNetCFG1D forward")

        return outputs
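
# adp UNet1d with global-feature and input-concat conditioning but no
# cross-attention; the cross-attention and CFG arguments are accepted for
# interface compatibility and ignored.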
class UNet1DCondWrapper(ConditionedDiffusionModel):
    def __init__(
        self,
        *args,
        **kwargs
    ):
        super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True)

        self.model = UNet1d(*args, **kwargs)

        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(self,
                x,
                t,
                input_concat_cond=None,
                global_cond=None,
                cross_attn_cond=None,
                cross_attn_mask=None,
                prepend_cond=None,
                prepend_cond_mask=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = False,
                rescale_cfg: bool = False,
                negative_cross_attn_cond=None,
                negative_cross_attn_mask=None,
                negative_global_cond=None,
                negative_input_concat_cond=None,
                **kwargs):

        channels_list = None
        if input_concat_cond is not None:

            # Interpolate input_concat_cond to the same length as x
            if input_concat_cond.shape[2] != x.shape[2]:
                input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')

            channels_list = [input_concat_cond]

        outputs = self.model(
            x,
            t,
            features=global_cond,
            channels_list=channels_list,
            **kwargs)

        return outputs

class UNet1DUncondWrapper(DiffusionModel):
    def __init__(
        self,
        in_channels,
        *args,
        **kwargs
    ):
        super().__init__()

        self.model = UNet1d(in_channels=in_channels, *args, **kwargs)

        self.io_channels = in_channels

        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(self, x, t, **kwargs):
        return self.model(x, t, **kwargs)
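
# DiffusionAttnUnet1D adapted to the conditioned-model interface; only
# input_concat_cond is used, all other conditioning arguments are ignored.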
class DAU1DCondWrapper(ConditionedDiffusionModel):
    def __init__(
        self,
        *args,
        **kwargs
    ):
        super().__init__(supports_cross_attention=False, supports_global_cond=False, supports_input_concat=True)

        self.model = DiffusionAttnUnet1D(*args, **kwargs)

        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(self,
                x,
                t,
                input_concat_cond=None,
                cross_attn_cond=None,
                cross_attn_mask=None,
                global_cond=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = False,
                rescale_cfg: bool = False,
                negative_cross_attn_cond=None,
                negative_cross_attn_mask=None,
                negative_global_cond=None,
                negative_input_concat_cond=None,
                prepend_cond=None,
                **kwargs):

        return self.model(x, t, cond=input_concat_cond)
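
# U-Net built recursively from the innermost block outward: each loop
# iteration wraps the previous `block` in a SkipBlock with down/upsampling,
# residual conv blocks, and self-attention at the deepest n_attn_layers levels.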
class DiffusionAttnUnet1D(nn.Module):
    def __init__(
        self,
        io_channels=2,
        depth=14,
        n_attn_layers=6,
        channels=[128, 128, 256, 256] + [512] * 10,
        cond_dim=0,
        cond_noise_aug=False,
        kernel_size=5,
        learned_resample=False,
        strides=[2] * 13,
        conv_bias=True,
        use_snake=False
    ):
        super().__init__()

        self.cond_noise_aug = cond_noise_aug

        self.io_channels = io_channels

        if self.cond_noise_aug:
            self.rng = torch.quasirandom.SobolEngine(1, scramble=True)

        self.timestep_embed = FourierFeatures(1, 16)

        # Self-attention is enabled for the deepest n_attn_layers levels
        attn_layer = depth - n_attn_layers

        strides = [1] + strides

        block = nn.Identity()

        conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias=conv_bias, use_snake=use_snake)

        for i in range(depth, 0, -1):
            c = channels[i - 1]
            stride = strides[i - 1]
            if stride > 2 and not learned_resample:
                raise ValueError("Strides greater than 2 require learned_resample=True")

            if i > 1:
                c_prev = channels[i - 2]
                add_attn = i >= attn_layer and n_attn_layers > 0
                block = SkipBlock(
                    Downsample1d_2(c_prev, c_prev, stride) if (learned_resample or stride == 1) else Downsample1d("cubic"),
                    conv_block(c_prev, c, c),
                    SelfAttention1d(c, c // 32) if add_attn else nn.Identity(),
                    conv_block(c, c, c),
                    SelfAttention1d(c, c // 32) if add_attn else nn.Identity(),
                    conv_block(c, c, c),
                    SelfAttention1d(c, c // 32) if add_attn else nn.Identity(),
                    block,
                    conv_block(c * 2 if i != depth else c, c, c),
                    SelfAttention1d(c, c // 32) if add_attn else nn.Identity(),
                    conv_block(c, c, c),
                    SelfAttention1d(c, c // 32) if add_attn else nn.Identity(),
                    conv_block(c, c, c_prev),
                    SelfAttention1d(c_prev, c_prev // 32) if add_attn else nn.Identity(),
                    Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic")
                )
            else:
                cond_embed_dim = 16 if not self.cond_noise_aug else 32
                block = nn.Sequential(
                    conv_block((io_channels + cond_dim) + cond_embed_dim, c, c),
                    conv_block(c, c, c),
                    conv_block(c, c, c),
                    block,
                    conv_block(c * 2, c, c),
                    conv_block(c, c, c),
                    conv_block(c, c, io_channels, is_last=True),
                )
        self.net = block

        with torch.no_grad():
            for param in self.net.parameters():
                param *= 0.5

    def forward(self, x, t, cond=None, cond_aug_scale=None):

        timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape)

        inputs = [x, timestep_embed]

        if cond is not None:
            if cond.shape[2] != x.shape[2]:
                cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False)

            if self.cond_noise_aug:
                # Draw a noise-augmentation level in [0, 1) if one isn't given
                if cond_aug_scale is None:
                    aug_level = self.rng.draw(cond.shape[0])[:, 0].to(cond)
                else:
                    aug_level = torch.tensor([cond_aug_scale]).repeat([cond.shape[0]]).to(cond)

                # Add noise to the conditioning signal
                cond = cond + torch.randn_like(cond) * aug_level[:, None, None]

                # Embed the noise level, reusing the timestep embedder
                aug_level_embed = expand_to_planes(self.timestep_embed(aug_level[:, None]), x.shape)

                inputs.append(aug_level_embed)

            inputs.append(cond)

        outputs = self.net(torch.cat(inputs, dim=1))

        return outputs
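
# DiffusionTransformer adapted to the ConditionedDiffusionModel interface.
# CFG is applied inside the transformer, which runs the conditional and
# unconditional passes as one batch, hence batch_cfg must be True.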
class DiTWrapper(ConditionedDiffusionModel):
    def __init__(
        self,
        *args,
        **kwargs
    ):
        super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)

        self.model = DiffusionTransformer(*args, **kwargs)

        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(self,
                x,
                t,
                cross_attn_cond=None,
                cross_attn_mask=None,
                negative_cross_attn_cond=None,
                negative_cross_attn_mask=None,
                input_concat_cond=None,
                negative_input_concat_cond=None,
                global_cond=None,
                negative_global_cond=None,
                prepend_cond=None,
                prepend_cond_mask=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = True,
                rescale_cfg: bool = False,
                scale_phi: float = 0.0,
                **kwargs):

        assert batch_cfg, "batch_cfg must be True for DiTWrapper"

        return self.model(
            x,
            t,
            cross_attn_cond=cross_attn_cond,
            cross_attn_cond_mask=cross_attn_mask,
            negative_cross_attn_cond=negative_cross_attn_cond,
            negative_cross_attn_mask=negative_cross_attn_mask,
            input_concat_cond=input_concat_cond,
            prepend_cond=prepend_cond,
            prepend_cond_mask=prepend_cond_mask,
            cfg_scale=cfg_scale,
            cfg_dropout_prob=cfg_dropout_prob,
            scale_phi=scale_phi,
            global_embed=global_cond,
            **kwargs)

class DiTUncondWrapper(DiffusionModel):
    def __init__(
        self,
        in_channels,
        *args,
        **kwargs
    ):
        super().__init__()

        self.model = DiffusionTransformer(io_channels=in_channels, *args, **kwargs)

        self.io_channels = in_channels

        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(self, x, t, **kwargs):
        return self.model(x, t, **kwargs)
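
# Build an unconditional diffusion model (plus optional pretransform) from a
# parsed model config dictionary.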
def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]):
    diffusion_uncond_config = config["model"]

    model_type = diffusion_uncond_config.get('type', None)

    diffusion_config = diffusion_uncond_config.get('config', {})

    assert model_type is not None, "Must specify model type in config"

    pretransform = diffusion_uncond_config.get("pretransform", None)

    sample_size = config.get("sample_size", None)
    assert sample_size is not None, "Must specify sample size in config"

    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "Must specify sample rate in config"

    if pretransform is not None:
        pretransform = create_pretransform_from_config(pretransform, sample_rate)
        min_input_length = pretransform.downsampling_ratio
    else:
        min_input_length = 1

    if model_type == 'DAU1d':
        model = DiffusionAttnUnet1D(
            **diffusion_config
        )
    elif model_type == "adp_uncond_1d":
        model = UNet1DUncondWrapper(
            **diffusion_config
        )
    elif model_type == "dit":
        model = DiTUncondWrapper(
            **diffusion_config
        )
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')

    return DiffusionModelWrapper(model,
                                 io_channels=model.io_channels,
                                 sample_size=sample_size,
                                 sample_rate=sample_rate,
                                 pretransform=pretransform,
                                 min_input_length=min_input_length)
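
# Build a conditioned diffusion model from a parsed config. A minimal sketch
# of the expected config shape, with field names inferred from the lookups
# below; the values are illustrative only, and real configs may carry more
# options:
#
#   {
#       "model_type": "diffusion_cond",
#       "sample_rate": 44100,
#       "model": {
#           "io_channels": 64,
#           "pretransform": {...},
#           "conditioning": {...},
#           "diffusion": {
#               "type": "dit",
#               "diffusion_objective": "v",
#               "cross_attention_cond_ids": ["prompt"],
#               "global_cond_ids": [],
#               "input_concat_ids": [],
#               "prepend_cond_ids": [],
#               "config": {...}
#           }
#       }
#   }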
def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):

    model_config = config["model"]

    model_type = config["model_type"]

    diffusion_config = model_config.get('diffusion', None)
    assert diffusion_config is not None, "Must specify diffusion config"

    diffusion_model_type = diffusion_config.get('type', None)
    assert diffusion_model_type is not None, "Must specify diffusion model type"

    diffusion_model_config = diffusion_config.get('config', None)
    assert diffusion_model_config is not None, "Must specify diffusion model config"

    if diffusion_model_config.get('video_fps', None) is not None:
        diffusion_model_config.pop('video_fps')

    if diffusion_model_type == 'adp_cfg_1d':
        diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
    elif diffusion_model_type == 'adp_1d':
        diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
    elif diffusion_model_type == 'dit':
        diffusion_model = DiTWrapper(**diffusion_model_config)
    else:
        raise NotImplementedError(f'Unknown diffusion model type: {diffusion_model_type}')

    io_channels = model_config.get('io_channels', None)
    assert io_channels is not None, "Must specify io_channels in model config"

    sample_rate = config.get('sample_rate', None)
    assert sample_rate is not None, "Must specify sample_rate in config"

    diffusion_objective = diffusion_config.get('diffusion_objective', 'v')

    conditioning_config = model_config.get('conditioning', None)

    conditioner = None
    if conditioning_config is not None:
        conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)

    cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
    global_cond_ids = diffusion_config.get('global_cond_ids', [])
    input_concat_ids = diffusion_config.get('input_concat_ids', [])
    prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])

    pretransform = model_config.get("pretransform", None)

    if pretransform is not None:
        pretransform = create_pretransform_from_config(pretransform, sample_rate)
        min_input_length = pretransform.downsampling_ratio
    else:
        min_input_length = 1

    # The minimum input length must also account for the model's own internal
    # downsampling: U-Net strides or the DiT patch size
    if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
        min_input_length *= np.prod(diffusion_model_config["factors"])
    elif diffusion_model_type == "dit":
        min_input_length *= diffusion_model.model.patch_size

    extra_kwargs = {}

    # Pick the wrapper class for the top-level model type
    if model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint":
        wrapper_fn = ConditionedDiffusionModelWrapper

        extra_kwargs["diffusion_objective"] = diffusion_objective

    elif model_type == "diffusion_prior":
        prior_type = model_config.get("prior_type", None)
        assert prior_type is not None, "Must specify prior_type in diffusion prior model config"

        if prior_type == "mono_stereo":
            from .diffusion_prior import MonoToStereoDiffusionPrior
            wrapper_fn = MonoToStereoDiffusionPrior
        else:
            raise NotImplementedError(f'Unknown prior type: {prior_type}')
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')

    return wrapper_fn(
        diffusion_model,
        conditioner,
        min_input_length=min_input_length,
        sample_rate=sample_rate,
        cross_attn_cond_ids=cross_attention_ids,
        global_cond_ids=global_cond_ids,
        input_concat_ids=input_concat_ids,
        prepend_cond_ids=prepend_cond_ids,
        pretransform=pretransform,
        io_channels=io_channels,
        **extra_kwargs
    )