# --------------------------------------------------------
# What Matters When Repurposing Diffusion Models for General Dense Perception Tasks? (https://arxiv.org/abs/2403.06090)
# Github source: https://github.com/aim-uofa/GenPercept
# Copyright (c) 2024, Advanced Intelligent Machines (AIM)
# Licensed under The BSD 2-Clause License [see LICENSE for details]
# By Guangkai Xu
# Based on Marigold, diffusers codebases
# https://github.com/prs-eth/marigold
# https://github.com/huggingface/diffusers
# --------------------------------------------------------
import torch
from typing import List, Optional, Tuple, Union
import numpy as np

from diffusers import DDIMScheduler, DDPMScheduler
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_ddpm import betas_for_alpha_bar

def rescale_zero_terminal_snr(betas):
    """
    Rescales betas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).

    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
    alphas = torch.cat([alphas_bar[0:1], alphas])
    betas = 1 - alphas

    return betas
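
# A quick sanity check (illustrative sketch, not part of the original module):
# after rescaling, the cumulative product of (1 - beta) at the final timestep is
# exactly zero, i.e. the terminal SNR is zero. The beta range below is an assumed
# example, not a GenPercept setting.
#
#     betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000) ** 2
#     betas = rescale_zero_terminal_snr(betas)
#     alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
#     print(alphas_cumprod[-1].item())  # 0.0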

class DDPMSchedulerCustomized(DDPMScheduler):

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        variance_type: str = "fixed_small",
        clip_sample: bool = True,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        clip_sample_range: float = 1.0,
        sample_max_value: float = 1.0,
        timestep_spacing: str = "leading",
        steps_offset: int = 0,
        rescale_betas_zero_snr: bool = False,
        power_beta_curve: float = 1.0,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "scaled_linear_power":
            # power-curve generalization of "scaled_linear": linear in beta**(1/power_beta_curve)
            self.betas = torch.linspace(beta_start**(1 / power_beta_curve), beta_end**(1 / power_beta_curve), num_train_timesteps, dtype=torch.float32) ** power_beta_curve
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        elif beta_schedule == "sigmoid":
            # GeoDiff sigmoid schedule
            betas = torch.linspace(-6, 6, num_train_timesteps)
            self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.one = torch.tensor(1.0)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.custom_timesteps = False
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())

        self.variance_type = variance_type
    def get_velocity(
        self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
    ) -> torch.FloatTensor:
        # Make sure alphas_cumprod and timestep have same device and dtype as sample
        self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
        alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
        timesteps = timesteps.to(sample.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(sample.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        # v-prediction target: v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample
        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity

class DDIMSchedulerCustomized(DDIMScheduler):

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        clip_sample_range: float = 1.0,
        sample_max_value: float = 1.0,
        timestep_spacing: str = "leading",
        rescale_betas_zero_snr: bool = False,
        power_beta_curve: float = 1.0,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "scaled_linear_power":
            # power-curve generalization of "scaled_linear": linear in beta**(1/power_beta_curve)
            self.betas = torch.linspace(beta_start**(1 / power_beta_curve), beta_end**(1 / power_beta_curve), num_train_timesteps, dtype=torch.float32) ** power_beta_curve
            self.power_beta_curve = power_beta_curve
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))

        self.beta_schedule = beta_schedule
    def _get_variance(self, timestep, prev_timestep):
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # Product of alphas over (prev_timestep, timestep], i.e. alpha_bar_t / alpha_bar_{t_prev}
        alpha_t_prev_to_t = self.alphas[(prev_timestep + 1):(timestep + 1)]
        alpha_t_prev_to_t = torch.prod(alpha_t_prev_to_t)

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_t_prev_to_t)

        return variance