from diffusers import DiffusionPipeline import torch from diffusers.utils import BaseOutput from dataclasses import dataclass from typing import List, Union, Optional, Tuple from PIL import Image import numpy as np from tqdm import tqdm @dataclass class SdxsPipelineOutput(BaseOutput): images: Union[List[Image.Image], np.ndarray] class SdxsPipeline(DiffusionPipeline): def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, max_length: int = 192): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler ) self.vae_scale_factor = 16 self.max_length = max_length def encode_prompt(self, prompt=None, negative_prompt=None, device=None, dtype=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: device = device or self.device dtype = dtype or next(self.unet.parameters()).dtype # Преобразуем в списки if isinstance(prompt, str): prompt = [prompt] if isinstance(negative_prompt, str): negative_prompt = [negative_prompt] # Если промпты не заданы, используем пустые эмбеддинги if prompt is None and negative_prompt is None: hidden_dim = 1024 # Размерность эмбеддинга seq_len = self.max_length batch_size = 1 # ИЗМЕНЕНО: Возвращаем три элемента: embeds, mask, pooled empty_embeds = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device) empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device) empty_pooled = torch.zeros((batch_size, hidden_dim), dtype=dtype, device=device) return empty_embeds, empty_mask, empty_pooled # Токенизация с фиксированным max_length и padding="max_length" def encode_texts(texts, max_length=self.max_length): with torch.no_grad(): if isinstance(texts, str): texts = [texts] for i, prompt_item in enumerate(texts): messages = [ {"role": "user", "content": prompt_item}, ] prompt_item = self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, enable_thinking=True, ) texts[i] = prompt_item toks = self.tokenizer( texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length ).to(device) outs = self.text_encoder(**toks, output_hidden_states=True, return_dict=True) # Токен-эмбеддинги (для Cross-Attention) hidden = outs.hidden_states[-2] # Используем last hidden state -2??? # Маска внимания (для Cross-Attention) attention_mask = toks["attention_mask"] # Пулинг-эмбеддинг (для Class/Time Conditioning). Берем эмбеддинг последнего токена без padding. sequence_lengths = attention_mask.sum(dim=1) - 1 batch_size = hidden.shape[0] pooled = hidden[torch.arange(batch_size, device=hidden.device), sequence_lengths] # --- НОВАЯ ЛОГИКА: ОБЪЕДИНЕНИЕ ДЛЯ КРОСС-ВНИМАНИЯ --- # 1. Расширяем пулинг-вектор до последовательности [B, 1, 1024] pooled_expanded = pooled.unsqueeze(1) # 2. Объединяем последовательность токенов и пулинг-вектор # !!! ИЗМЕНЕНИЕ ЗДЕСЬ !!!: Пулинг идет ПЕРВЫМ # Теперь: [B, 1 + L, 1024]. Пулинг стал токеном в НАЧАЛЕ. new_encoder_hidden_states = torch.cat([pooled_expanded, hidden], dim=1) # 3. Обновляем маску внимания для нового токена # Маска внимания: [B, 1 + L]. Добавляем 1 в НАЧАЛО. # torch.ones((batch_size, 1), device=device) создает маску [B, 1] со значениями 1. new_attention_mask = torch.cat([torch.ones((batch_size, 1), device=device), attention_mask], dim=1) return new_encoder_hidden_states, new_attention_mask, pooled # Кодируем позитивные и негативные промпты # ИСПРАВЛЕНИЕ: Теперь возвращаем (None, None, None), чтобы избежать UnboundLocalError pos_result = encode_texts(prompt) if prompt is not None else (None, None, None) neg_result = encode_texts(negative_prompt) if negative_prompt is not None else (None, None, None) pos_embeddings, pos_mask, pos_pooled = pos_result neg_embeddings, neg_mask, neg_pooled = neg_result # Выравниваем размеры batch_size batch_size = max( pos_embeddings.shape[0] if pos_embeddings is not None else 0, neg_embeddings.shape[0] if neg_embeddings is not None else 0 ) # Повторяем эмбеддинги, маски и пулинг по batch_size if pos_embeddings is not None and pos_embeddings.shape[0] < batch_size: pos_embeddings = pos_embeddings.repeat(batch_size, 1, 1) pos_mask = pos_mask.repeat(batch_size, 1) pos_pooled = pos_pooled.repeat(batch_size, 1) # ИСПРАВЛЕНИЕ: Проверяем, существует ли neg_embeddings, прежде чем обращаться к его shape[0] if neg_embeddings is not None and neg_embeddings.shape[0] < batch_size: neg_embeddings = neg_embeddings.repeat(batch_size, 1, 1) neg_mask = neg_mask.repeat(batch_size, 1) neg_pooled = neg_pooled.repeat(batch_size, 1) # Конкатенируем для guidance (эмбеддинги и маски) # Убеждаемся, что все три компонента существуют перед конкатенацией if pos_embeddings is not None and neg_embeddings is not None: text_embeddings = torch.cat([neg_embeddings, pos_embeddings], dim=0) attention_mask = torch.cat([neg_mask, pos_mask], dim=0) pooled_embeddings = torch.cat([neg_pooled, pos_pooled], dim=0) elif pos_embeddings is not None: text_embeddings = pos_embeddings attention_mask = pos_mask pooled_embeddings = pos_pooled else: # Только neg_embeddings text_embeddings = neg_embeddings attention_mask = neg_mask pooled_embeddings = neg_pooled # Возвращаем кортеж return ( text_embeddings.to(device=device, dtype=dtype), attention_mask.to(device=device, dtype=torch.int64), pooled_embeddings.to(device=device, dtype=dtype) ) @torch.no_grad() def generate_latents( self, text_embeddings, attention_mask, pooled_embeddings, height: int = 1280, width: int = 1024, num_inference_steps: int = 40, guidance_scale: float = 4.0, latent_channels: int = 16, batch_size: int = 1, generator=None, ): device = self.device dtype = next(self.unet.parameters()).dtype self.scheduler.set_timesteps(num_inference_steps, device=device) # Разделяем эмбеддинги и маски на условные и безусловные if guidance_scale > 1: neg_embeds, pos_embeds = text_embeddings.chunk(2) neg_mask, pos_mask = attention_mask.chunk(2) neg_pooled, pos_pooled = pooled_embeddings.chunk(2) # Повторяем, если batch_size больше if batch_size > pos_embeds.shape[0]: pos_embeds = pos_embeds.repeat(batch_size, 1, 1) neg_embeds = neg_embeds.repeat(batch_size, 1, 1) pos_mask = pos_mask.repeat(batch_size, 1) neg_mask = neg_mask.repeat(batch_size, 1) pos_pooled = pos_pooled.repeat(batch_size, 1) neg_pooled = neg_pooled.repeat(batch_size, 1) text_embeddings = torch.cat([neg_embeds, pos_embeds], dim=0) unet_attention_mask = torch.cat([neg_mask, pos_mask], dim=0) unet_pooled_embeddings = torch.cat([neg_pooled, pos_pooled], dim=0) else: text_embeddings = text_embeddings.repeat(batch_size, 1, 1) unet_attention_mask = attention_mask.repeat(batch_size, 1) unet_pooled_embeddings = pooled_embeddings.repeat(batch_size, 1) # Инициализация латентов latent_shape = ( batch_size, latent_channels, height // self.vae_scale_factor, width // self.vae_scale_factor ) latents = torch.randn(latent_shape, device=device, dtype=dtype, generator=generator) # Процесс диффузии for t in tqdm(self.scheduler.timesteps, desc="Генерация"): latent_input = torch.cat([latents, latents], dim=0) if guidance_scale > 1 else latents noise_pred = self.unet( latent_input, t, encoder_hidden_states=text_embeddings, encoder_attention_mask=unet_attention_mask, #added_cond_kwargs={'text_embeds': unet_pooled_embeddings} ).sample if guidance_scale > 1: noise_uncond, noise_text = noise_pred.chunk(2) noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond) latents = self.scheduler.step(noise_pred, t, latents).prev_sample return latents def decode_latents(self, latents, output_type="pil"): """Декодирование латентов в изображения.""" latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor with torch.no_grad(): images = self.vae.decode(latents).sample images = (images / 2 + 0.5).clamp(0, 1) if output_type == "pil": images = images.cpu().permute(0, 2, 3, 1).float().numpy() images = (images * 255).round().astype("uint8") return [Image.fromarray(image) for image in images] return images.cpu().permute(0, 2, 3, 1).float().numpy() @torch.no_grad() def __call__( self, prompt: Optional[Union[str, List[str]]] = None, height: int = 1280, width: int = 1024, num_inference_steps: int = 40, guidance_scale: float = 4.0, latent_channels: int = 16, output_type: str = "pil", return_dict: bool = True, batch_size: int = 1, seed: Optional[int] = None, negative_prompt: Optional[Union[str, List[str]]] = None, text_embeddings: Optional[torch.FloatTensor] = None, ): device = self.device generator = torch.Generator(device=device).manual_seed(seed) if seed is not None else None if text_embeddings is None: if prompt is None and negative_prompt is None: raise ValueError("Необходимо указать prompt, negative_prompt или text_embeddings") text_embeddings, attention_mask, pooled_embeddings = self.encode_prompt( prompt, negative_prompt, device=device, dtype=next(self.unet.parameters()).dtype ) else: # Требуется, чтобы внешний text_embeddings содержал объединенные cond/uncond, # но мы не можем получить attention_mask и pooled_embeddings. # Для простоты лучше требовать prompt/negative_prompt. raise NotImplementedError("Передача text_embeddings напрямую пока не поддерживает передачу маски и пулинга. Используйте prompt/negative_prompt.") latents = self.generate_latents( text_embeddings=text_embeddings, attention_mask=attention_mask, pooled_embeddings=pooled_embeddings, height=height, width=width, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, latent_channels=latent_channels, batch_size=batch_size, generator=generator ) images = self.decode_latents(latents, output_type=output_type) if not return_dict: return images return SdxsPipelineOutput(images=images)