| | import random |
| | from tqdm import tqdm |
| | from typing import Callable, Dict, List, Optional |
| |
|
| | import torch |
| | from diffusers import DiffusionPipeline |
| | from diffusers.configuration_utils import ConfigMixin |
| |
|
| |
|
class SuperDiffPipeline(DiffusionPipeline, ConfigMixin):
    """Text-to-image pipeline that combines two prompts during sampling.

    At every sampling step the UNet is evaluated three times on the same
    latents: conditioned on ``prompt_1``, on ``prompt_2`` and on the empty
    prompt. The three predictions are mixed with a per-sample weight
    (``kappa``) that is recomputed at each step, and the latents are advanced
    with a noisy update.

    ``preprocess`` stores the run configuration (``batch_size``,
    ``num_inference_steps``, ``guidance_scale``, ``lift``, ``seed``) on the
    instance; ``_forward`` and ``postprocess`` read it back, so they must be
    called after ``preprocess`` (``__call__`` does this in order).
    """

    def __init__(
        self,
        unet: Callable,
        vae: Callable,
        text_encoder: Callable,
        scheduler: Callable,
        tokenizer: Callable,
    ) -> None:
        """Move the models to the default device and register all components.

        Parameters
        ----------
        unet : Callable
            Denoising network; called with latents, a timestep and
            ``encoder_hidden_states`` and expected to return an object with a
            ``.sample`` attribute.
        vae : Callable
            Autoencoder whose ``decode`` method and ``config.scaling_factor``
            are used to turn latents into images.
        text_encoder : Callable
            Model that maps token ids to text embeddings.
        scheduler : Callable
            Scheduler providing ``set_timesteps``, ``timesteps``, ``sigmas``
            and ``init_noise_sigma``.
        tokenizer : Callable
            Tokenizer used to turn prompts into token ids.

        Returns
        -------
        None

        """
        super().__init__()

        # CUDA is used whenever it is available, otherwise everything stays on CPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"

        vae.to(device)
        unet.to(device)
        text_encoder.to(device)
        self.register_modules(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
        )

    @torch.no_grad()
    def get_batch(self, latents: torch.Tensor, nrow: int, ncol: int) -> torch.Tensor:
        """Decode latents into a batch of uint8 images.

        Parameters
        ----------
        latents : torch.Tensor
            Latents to decode; they are divided by the VAE scaling factor
            before being passed to ``vae.decode``.
        nrow : int
            Unused; kept so existing callers keep working.
        ncol : int
            Unused; kept so existing callers keep working.

        Returns
        -------
        torch.Tensor
            ``uint8`` tensor with the channel axis moved last
            (axis order ``0, 2, 3, 1`` of the decoded output).

        """
        decoded = self.vae.decode(
            latents / self.vae.config.scaling_factor, return_dict=False
        )[0]
        # Map the decoder output from [-1, 1] to [0, 1].
        image = (decoded / 2 + 0.5).clamp(0, 1)
        # Only restore a missing batch axis. A blanket ``squeeze()`` would also
        # drop any other size-1 axis and break the permute below.
        if image.dim() < 4:
            image = image.unsqueeze(0)
        return (image.permute(0, 2, 3, 1) * 255).to(torch.uint8)

    @torch.no_grad()
    def get_text_embedding(self, prompt: str) -> torch.Tensor:
        """Tokenize ``prompt`` and run it through the text encoder.

        Parameters
        ----------
        prompt : str
            Prompt text. ``preprocess`` passes a list of prompts (one per
            batch element), which is forwarded to the tokenizer unchanged.

        Returns
        -------
        torch.Tensor
            First element of the text encoder output.

        """
        # Pad/truncate to the tokenizer's maximum length so all prompts share a shape.
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        token_ids = text_input.input_ids.to(self.device)
        return self.text_encoder(token_ids)[0]

    @torch.no_grad()
    def get_vel(
        self,
        t: float,
        sigma: float,
        latents: torch.Tensor,
        embeddings: List[torch.Tensor],
    ) -> torch.Tensor:
        """Evaluate the UNet for one conditioning at the current step.

        Parameters
        ----------
        t : float
            Current scheduler timestep.
        sigma : float
            Current scheduler sigma; the latents are divided by
            ``sqrt(sigma**2 + 1)`` before being fed to the UNet.
        latents : torch.Tensor
            Current latents.
        embeddings : List[torch.Tensor]
            Text embeddings; the list is concatenated along the first axis.

        Returns
        -------
        torch.Tensor
            The ``.sample`` field of the UNet output.

        """
        embeds = torch.cat(embeddings)
        scaled_latents = latents / ((sigma**2 + 1) ** 0.5)
        return self.unet(scaled_latents, t, encoder_hidden_states=embeds).sample

    def preprocess(
        self,
        prompt_1: str,
        prompt_2: str,
        seed: Optional[int] = None,
        num_inference_steps: int = 1000,
        batch_size: int = 1,
        lift: float = 0.0,
        height: int = 512,
        width: int = 512,
        guidance_scale: float = 7.5,
    ) -> Dict:
        """Store the run configuration and build the inputs for ``_forward``.

        Parameters
        ----------
        prompt_1 : str
            First prompt (returned as ``"obj_embeddings"``).
        prompt_2 : str
            Second prompt (returned as ``"bg_embeddings"``).
        seed : Optional[int]
            Seed for the random number generators. When ``None`` a random
            seed in ``[0, 2**32 - 1]`` is drawn and stored in ``self.seed``.
        num_inference_steps : int
            Number of steps passed to ``scheduler.set_timesteps``.
        batch_size : int
            Number of images to generate.
        lift : float
            Offset term added to the numerator of ``kappa`` in ``_forward``.
        height : int
            Image height in pixels; the latents use ``height // 8``.
        width : int
            Image width in pixels; the latents use ``width // 8``.
        guidance_scale : float
            Scale applied to the conditional direction in ``_forward``.

        Returns
        -------
        Dict
            Dictionary with the keys ``"latents"``, ``"obj_embeddings"``,
            ``"uncond_embeddings"`` and ``"bg_embeddings"``.

        """
        # _forward and postprocess read these attributes back from the instance.
        self.batch_size = batch_size
        self.num_inference_steps = num_inference_steps
        self.guidance_scale = guidance_scale
        self.lift = lift
        self.seed = seed
        if self.seed is None:
            self.seed = random.randint(0, 2**32 - 1)

        obj_embeddings = self.get_text_embedding([prompt_1] * batch_size)
        bg_embeddings = self.get_text_embedding([prompt_2] * batch_size)
        uncond_embeddings = self.get_text_embedding([""] * batch_size)

        # torch.cuda.manual_seed() returns None and is ignored without CUDA, so
        # it can neither serve as a ``generator`` nor seed CPU runs.
        # torch.manual_seed() seeds the default generators of the CPU and of the
        # CUDA devices, which also covers the noise drawn later in _forward.
        torch.manual_seed(self.seed)
        latents = torch.randn(
            (batch_size, self.unet.config.in_channels, height // 8, width // 8),
            device=self.device,
        )

        self.scheduler.set_timesteps(num_inference_steps)
        latents = latents * self.scheduler.init_noise_sigma

        return {
            "latents": latents,
            "obj_embeddings": obj_embeddings,
            "uncond_embeddings": uncond_embeddings,
            "bg_embeddings": bg_embeddings,
        }

    def _forward(self, model_inputs: Dict) -> torch.Tensor:
        """Run the sampling loop and return the final latents.

        Parameters
        ----------
        model_inputs : Dict
            Output of ``preprocess``: the initial latents and the three text
            embeddings.

        Returns
        -------
        torch.Tensor
            Latents after the last scheduler timestep. The tensor stored under
            ``model_inputs["latents"]`` is updated in place.

        """
        latents = model_inputs["latents"]
        obj_embeddings = model_inputs["obj_embeddings"]
        uncond_embeddings = model_inputs["uncond_embeddings"]
        bg_embeddings = model_inputs["bg_embeddings"]

        # One row of per-sample weights per step; row 0 keeps its initial 0.5.
        kappa = 0.5 * torch.ones(
            (self.num_inference_steps + 1, self.batch_size), device=self.device
        )

        with torch.no_grad():
            # Wrapping the timesteps (not the enumerate object) lets tqdm see
            # their length and display a total.
            for i, t in enumerate(tqdm(self.scheduler.timesteps)):
                sigma = self.scheduler.sigmas[i]
                dsigma = self.scheduler.sigmas[i + 1] - sigma

                vel_obj = self.get_vel(t, sigma, latents, [obj_embeddings])
                vel_uncond = self.get_vel(t, sigma, latents, [uncond_embeddings])
                vel_bg = self.get_vel(t, sigma, latents, [bg_embeddings])

                noise = torch.sqrt(2 * torch.abs(dsigma) * sigma) * torch.randn_like(
                    latents
                )

                # Update that would result from guiding towards prompt_2 only.
                dx_ind = (
                    2
                    * dsigma
                    * (vel_uncond + self.guidance_scale * (vel_bg - vel_uncond))
                    + noise
                )

                # Per-sample weight for the (vel_obj - vel_bg) direction.
                # NOTE(review): the denominator is zero when vel_obj equals
                # vel_bg (e.g. identical prompts), which yields inf/nan.
                kappa[i + 1] = (
                    (torch.abs(dsigma) * (vel_bg - vel_obj) * (vel_bg + vel_obj)).sum(
                        (1, 2, 3)
                    )
                    - (dx_ind * (vel_obj - vel_bg)).sum((1, 2, 3))
                    + sigma * self.lift / self.num_inference_steps
                )
                kappa[i + 1] /= (
                    2
                    * dsigma
                    * self.guidance_scale
                    * ((vel_obj - vel_bg) ** 2).sum((1, 2, 3))
                )

                vf = vel_uncond + self.guidance_scale * (
                    (vel_bg - vel_uncond)
                    + kappa[i + 1][:, None, None, None] * (vel_obj - vel_bg)
                )
                dx = 2 * dsigma * vf + noise
                latents += dx

        return latents

    def postprocess(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode the final latents into images.

        Parameters
        ----------
        latents : torch.Tensor
            Latents returned by ``_forward``.

        Returns
        -------
        torch.Tensor
            ``uint8`` image batch with three channels on the last axis.

        """
        image = self.get_batch(latents, 1, self.batch_size)

        # Sanity check on the decoder output; get_batch already returns uint8.
        assert image.shape[-1] == 3

        return image

    def __call__(
        self,
        prompt_1: str,
        prompt_2: str,
        seed: Optional[int] = None,
        num_inference_steps: int = 1000,
        batch_size: int = 1,
        lift: float = 0.0,
        height: int = 512,
        width: int = 512,
        guidance_scale: float = 7.5,
    ) -> torch.Tensor:
        """Generate images by running preprocess, _forward and postprocess.

        Parameters
        ----------
        prompt_1 : str
            First prompt.
        prompt_2 : str
            Second prompt.
        seed : Optional[int]
            Seed for the random number generators; drawn at random when
            ``None``.
        num_inference_steps : int
            Number of sampling steps.
        batch_size : int
            Number of images to generate.
        lift : float
            Offset term added to the numerator of the per-step weight.
        height : int
            Image height in pixels.
        width : int
            Image width in pixels.
        guidance_scale : float
            Scale applied to the conditional direction.

        Returns
        -------
        torch.Tensor
            ``uint8`` image batch with three channels on the last axis.

        """
        model_inputs = self.preprocess(
            prompt_1,
            prompt_2,
            seed,
            num_inference_steps,
            batch_size,
            lift,
            height,
            width,
            guidance_scale,
        )
        latents = self._forward(model_inputs)
        return self.postprocess(latents)
| |
|
| |
|