import torch
from diffusers import DiffusionPipeline, DDPMScheduler
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from diffusers.image_processor import VaeImageProcessor
from huggingface_hub import PyTorchModelHubMixin
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers.models.attention_processor import (
    AttnProcessor2_0,
    FusedAttnProcessor2_0,
    XFormersAttnProcessor,
)


class CombinedStableDiffusionXL(
    DiffusionPipeline,
    PyTorchModelHubMixin,
):
    """
    A Stable Diffusion XL wrapper that pairs an original UNet with a fine-tuned one:
    the original UNet denoises the first timesteps and the fine-tuned UNet finishes
    the remaining ones. Provides text-to-image synthesis, noise scheduling,
    latent-space manipulation, and image decoding.
    """

    def __init__(
        self,
        original_unet: torch.nn.Module,
        fine_tuned_unet: torch.nn.Module,
        scheduler: DDPMScheduler,
        vae: torch.nn.Module,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        text_encoder: CLIPTextModel,
        text_encoder_2: CLIPTextModelWithProjection,
    ) -> None:
        super().__init__()

        self.register_modules(
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            original_unet=original_unet,
            fine_tuned_unet=fine_tuned_unet,
            scheduler=scheduler,
            vae=vae,
        )

        # Spatial downsampling factor of the VAE (8 for the standard SDXL VAE).
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor
        )
        self.resolution = 1024
        # Which UNet `_get_unet_prediction` uses; toggled during `__call__`.
        self._use_original_unet = True

    def _get_negative_prompts(
        self, batch_size: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Tokenize empty prompts for the unconditional branch of classifier-free guidance."""
        input_ids_1 = self.tokenizer(
            [""] * batch_size,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids

        input_ids_2 = self.tokenizer_2(
            [""] * batch_size,
            max_length=self.tokenizer_2.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids
        return input_ids_1, input_ids_2

    def _get_encoder_hidden_states(
        self,
        tokenized_prompts_1: torch.Tensor,
        tokenized_prompts_2: torch.Tensor,
        do_classifier_free_guidance: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        text_input_ids_list = [
            tokenized_prompts_1,
            tokenized_prompts_2,
        ]
        batch_size = text_input_ids_list[0].size(0)

        if do_classifier_free_guidance:
            negative_prompts = [
                embed.to(text_input_ids_list[0].device)
                for embed in self._get_negative_prompts(batch_size)
            ]
            # Prepend the unconditional (empty-prompt) ids so a single UNet
            # forward pass produces both halves of the guidance pair.
            text_input_ids_list = [
                torch.cat(
                    [
                        negative_prompt,
                        text_input,
                    ]
                )
                for text_input, negative_prompt in zip(
                    text_input_ids_list, negative_prompts
                )
            ]
        prompt_embeds_list = []

        text_encoders = [self.text_encoder, self.text_encoder_2]
        for text_encoder, text_input_ids in zip(text_encoders, text_input_ids_list):
            prompt_embeds = text_encoder(
                text_input_ids.to(text_encoder.device),
                output_hidden_states=True,
                return_dict=False,
            )
            pooled_prompt_embeds = prompt_embeds[0]
            # Use the penultimate hidden state, as in the reference SDXL pipeline.
            prompt_embeds = prompt_embeds[-1][-2]
            bs_embed, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
            prompt_embeds_list.append(prompt_embeds)

        # Concatenate both encoders' sequences along the feature dimension; the
        # pooled embedding is taken from the second (projection) encoder.
        prompt_embeds = torch.cat(prompt_embeds_list, dim=-1)
        pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
        return prompt_embeds, pooled_prompt_embeds

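    # Shape reference (a note, assuming the standard SDXL base text encoders):
    # `_get_encoder_hidden_states` returns `prompt_embeds` of shape
    # [batch, 77, 768 + 1280 = 2048] (concatenated penultimate hidden states)
    # and `pooled_prompt_embeds` of shape [batch, 1280]. With classifier-free
    # guidance the batch dimension doubles, unconditional embeddings first.
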
    def _get_unet_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: int,
        encoder_hidden_states: tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """
        Return the UNet noise prediction.

        Args:
            latent_model_input (torch.Tensor): UNet latent input
            timestep (int): noise scheduler timestep
            encoder_hidden_states (tuple[torch.Tensor, torch.Tensor]): prompt
                embeddings and pooled prompt embeddings from the text encoders

        Returns:
            torch.Tensor: noise prediction
        """
        unet = self.original_unet if self._use_original_unet else self.fine_tuned_unet

        prompt_embeds, pooled_prompt_embeds = encoder_hidden_states
        target_size = torch.tensor(
            [
                [self.resolution, self.resolution]
                for _ in range(latent_model_input.size(0))
            ],
            device=latent_model_input.device,
            dtype=torch.float32,
        )
        # SDXL micro-conditioning: (original_size, crop_top_left, target_size),
        # here an uncropped generation at the target resolution.
        add_time_ids = torch.cat(
            [target_size, torch.zeros_like(target_size), target_size], dim=1
        )

        unet_added_conditions = {
            "time_ids": add_time_ids,
            "text_embeds": pooled_prompt_embeds,
        }

        return unet(
            latent_model_input,
            timestep,
            encoder_hidden_states=prompt_embeds,
            added_cond_kwargs=unet_added_conditions,
        ).sample

    def get_noise_prediction(
        self,
        latents: torch.Tensor,
        timestep_index: int,
        encoder_hidden_states: tuple[torch.Tensor, torch.Tensor],
        do_classifier_free_guidance: bool = False,
        detach_main_path: bool = False,
    ) -> torch.Tensor:
        """
        Return the noise prediction, optionally with classifier-free guidance.

        Args:
            latents (torch.Tensor): image latents
            timestep_index (int): noise scheduler timestep index
            encoder_hidden_states (tuple[torch.Tensor, torch.Tensor]): text encoder hidden states
            do_classifier_free_guidance (bool): whether to apply classifier-free guidance
            detach_main_path (bool): whether to detach the text-conditioned branch from the graph

        Returns:
            torch.Tensor: noise prediction
        """
        timestep = self.scheduler.timesteps[timestep_index]

        latent_model_input = self.scheduler.scale_model_input(
            sample=torch.cat([latents] * 2) if do_classifier_free_guidance else latents,
            timestep=timestep,
        )

        noise_pred = self._get_unet_prediction(
            latent_model_input=latent_model_input,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
        )

        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            if detach_main_path:
                noise_pred_text = noise_pred_text.detach()

            noise_pred = noise_pred_uncond + self.guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
        return noise_pred

    def sample_next_latents(
        self,
        latents: torch.Tensor,
        timestep_index: int,
        noise_pred: torch.Tensor,
        return_pred_original: bool = False,
    ) -> torch.Tensor:
        """
        Return the next latents prediction.

        Args:
            latents (torch.Tensor): image latents
            timestep_index (int): noise scheduler timestep index
            noise_pred (torch.Tensor): noise prediction
            return_pred_original (bool): whether to return the predicted original
                sample instead of the previous-step latents

        Returns:
            torch.Tensor: latent prediction
        """
        timestep = self.scheduler.timesteps[timestep_index]
        sample = self.scheduler.step(
            model_output=noise_pred, timestep=timestep, sample=latents
        )
        return (
            sample.pred_original_sample if return_pred_original else sample.prev_sample
        )

    def predict_next_latents(
        self,
        latents: torch.Tensor,
        timestep_index: int,
        encoder_hidden_states: tuple[torch.Tensor, torch.Tensor],
        return_pred_original: bool = False,
        do_classifier_free_guidance: bool = False,
        detach_main_path: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Predict the next latent states during the diffusion process.

        Args:
            latents (torch.Tensor): current latent states
            timestep_index (int): index of the current timestep
            encoder_hidden_states (tuple[torch.Tensor, torch.Tensor]): encoder
                hidden states from the text encoders
            return_pred_original (bool): whether to return the predicted original sample
            do_classifier_free_guidance (bool): whether to apply classifier-free guidance
            detach_main_path (bool): whether to detach the text-conditioned branch

        Returns:
            tuple: next latents and predicted noise tensor
        """
        noise_pred = self.get_noise_prediction(
            latents=latents,
            timestep_index=timestep_index,
            encoder_hidden_states=encoder_hidden_states,
            do_classifier_free_guidance=do_classifier_free_guidance,
            detach_main_path=detach_main_path,
        )

        latents = self.sample_next_latents(
            latents=latents,
            noise_pred=noise_pred,
            timestep_index=timestep_index,
            return_pred_original=return_pred_original,
        )

        return latents, noise_pred

    def get_latents(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Draw initial Gaussian latents at the current resolution."""
        latent_resolution = int(self.resolution) // self.vae_scale_factor
        return torch.randn(
            (
                batch_size,
                self.original_unet.config.in_channels,
                latent_resolution,
                latent_resolution,
            ),
            device=device,
        )

    def do_k_diffusion_steps(
        self,
        start_timestep_index: int,
        end_timestep_index: int,
        latents: torch.Tensor,
        encoder_hidden_states: tuple[torch.Tensor, torch.Tensor],
        return_pred_original: bool = False,
        do_classifier_free_guidance: bool = False,
        detach_main_path: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Perform multiple diffusion steps over the half-open index range
        [start_timestep_index, end_timestep_index).

        Args:
            start_timestep_index (int): starting timestep index (inclusive)
            end_timestep_index (int): ending timestep index (exclusive)
            latents (torch.Tensor): initial latents
            encoder_hidden_states (tuple[torch.Tensor, torch.Tensor]): encoder hidden states
            return_pred_original (bool): whether to return the predicted original sample
            do_classifier_free_guidance (bool): whether to apply classifier-free guidance
            detach_main_path (bool): whether to detach the text-conditioned branch

        Returns:
            tuple: resulting latents and encoder hidden states
        """
        # At least one step is required: the final step below uses
        # `end_timestep_index - 1`.
        assert start_timestep_index < end_timestep_index

        for timestep_index in range(start_timestep_index, end_timestep_index - 1):
            latents, _ = self.predict_next_latents(
                latents=latents,
                timestep_index=timestep_index,
                encoder_hidden_states=encoder_hidden_states,
                return_pred_original=False,
                do_classifier_free_guidance=do_classifier_free_guidance,
                detach_main_path=detach_main_path,
            )
        # Only the last step may return the predicted original sample.
        res, _ = self.predict_next_latents(
            latents=latents,
            timestep_index=end_timestep_index - 1,
            encoder_hidden_states=encoder_hidden_states,
            return_pred_original=return_pred_original,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )
        return res, encoder_hidden_states

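    # Indexing note for `do_k_diffusion_steps` (illustrative, using the defaults
    # of `__call__`): with num_inference_steps=40 and original_unet_steps=35, a
    # first call over [0, 35) runs scheduler.timesteps[0..34] and a second call
    # over [35, 40) runs scheduler.timesteps[35..39], so the two phases together
    # traverse every timestep exactly once.
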
    def upcast_vae(self):
        """Upcast the VAE to float32, keeping a few blocks in the original dtype
        when a memory-efficient attention processor makes that safe."""
        dtype = self.vae.dtype
        self.vae.to(dtype=torch.float32)
        use_torch_2_0_or_xformers = isinstance(
            self.vae.decoder.mid_block.attentions[0].processor,
            (
                AttnProcessor2_0,
                XFormersAttnProcessor,
                FusedAttnProcessor2_0,
            ),
        )
        if use_torch_2_0_or_xformers:
            # These modules operate safely in reduced precision with the
            # processors above; keeping them there saves memory.
            self.vae.post_quant_conv.to(dtype)
            self.vae.decoder.conv_in.to(dtype)
            self.vae.decoder.mid_block.to(dtype)

    @torch.no_grad()
    def __call__(
        self,
        prompt: str | list[str],
        num_inference_steps: int = 40,
        original_unet_steps: int = 35,
        resolution: int = 1024,
        guidance_scale: float = 5.0,
        output_type: str = "pil",
        return_dict: bool = True,
    ):
        self.guidance_scale = guidance_scale
        self.resolution = resolution
        batch_size = 1 if isinstance(prompt, str) else len(prompt)

        tokenized_prompts_1 = self.tokenizer(
            prompt,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids

        tokenized_prompts_2 = self.tokenizer_2(
            prompt,
            max_length=self.tokenizer_2.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids

        # The original UNet runs with classifier-free guidance; the fine-tuned
        # UNet runs unguided, so it only needs conditional embeddings.
        original_encoder_hidden_states = self._get_encoder_hidden_states(
            tokenized_prompts_1=tokenized_prompts_1,
            tokenized_prompts_2=tokenized_prompts_2,
            do_classifier_free_guidance=True,
        )
        fine_tuned_encoder_hidden_states = self._get_encoder_hidden_states(
            tokenized_prompts_1=tokenized_prompts_1,
            tokenized_prompts_2=tokenized_prompts_2,
            do_classifier_free_guidance=False,
        )

        latents = self.get_latents(batch_size=batch_size, device=self.device)

        self.scheduler.set_timesteps(
            num_inference_steps,
            device=self.device,
        )

        # Phase 1: denoise with the original UNet under guidance.
        self._use_original_unet = True
        latents, _ = self.do_k_diffusion_steps(
            start_timestep_index=0,
            end_timestep_index=original_unet_steps,
            latents=latents,
            encoder_hidden_states=original_encoder_hidden_states,
            return_pred_original=False,
            do_classifier_free_guidance=True,
        )

        # Phase 2: finish denoising with the fine-tuned UNet, unguided.
        self._use_original_unet = False
        latents, _ = self.do_k_diffusion_steps(
            start_timestep_index=original_unet_steps,
            end_timestep_index=num_inference_steps,
            latents=latents,
            encoder_hidden_states=fine_tuned_encoder_hidden_states,
            return_pred_original=False,
            do_classifier_free_guidance=False,
        )

        if output_type != "latent":
            # SDXL's fp16 VAE is numerically unstable, so decode in float32
            # when the checkpoint requests it via `force_upcast`.
            needs_upcasting = (
                self.vae.dtype == torch.float16 and self.vae.config.force_upcast
            )

            if needs_upcasting:
                self.upcast_vae()
                latents = latents.to(
                    next(iter(self.vae.post_quant_conv.parameters())).dtype
                )
            elif latents.dtype != self.vae.dtype:
                if torch.backends.mps.is_available():
                    # Some platforms (e.g. Apple MPS) require matching dtypes.
                    self.vae = self.vae.to(latents.dtype)

            latents = latents / self.vae.config.scaling_factor
            image = self.vae.decode(latents).sample

            # Cast the VAE back to fp16 to free memory.
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)

            image = self.image_processor.postprocess(
                image,
                output_type=output_type,
                do_denormalize=[True] * image.shape[0],
            )
        else:
            image = latents

        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)
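

# Minimal usage sketch (not part of the pipeline itself). It assumes the
# standard SDXL base checkpoint layout on the Hugging Face Hub; the
# fine-tuned UNet path ("path/to/fine_tuned_unet") is a placeholder for
# your own weights.
if __name__ == "__main__":
    from diffusers import AutoencoderKL, UNet2DConditionModel

    base = "stabilityai/stable-diffusion-xl-base-1.0"

    pipe = CombinedStableDiffusionXL(
        original_unet=UNet2DConditionModel.from_pretrained(base, subfolder="unet"),
        fine_tuned_unet=UNet2DConditionModel.from_pretrained(
            "path/to/fine_tuned_unet"  # placeholder: your fine-tuned UNet
        ),
        scheduler=DDPMScheduler.from_pretrained(base, subfolder="scheduler"),
        vae=AutoencoderKL.from_pretrained(base, subfolder="vae"),
        tokenizer=CLIPTokenizer.from_pretrained(base, subfolder="tokenizer"),
        tokenizer_2=CLIPTokenizer.from_pretrained(base, subfolder="tokenizer_2"),
        text_encoder=CLIPTextModel.from_pretrained(base, subfolder="text_encoder"),
        text_encoder_2=CLIPTextModelWithProjection.from_pretrained(
            base, subfolder="text_encoder_2"
        ),
    ).to("cuda")

    images = pipe("a photo of an astronaut riding a horse").images
    images[0].save("sample.png")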