| """ | |
| Model handler for WAN-VACE video generation | |
| """ | |
| import torch | |
| # ----------------------------------------------------------------------------- | |
| # XPU shim for CPU‑only environments | |
| # | |
| # Newer versions of `diffusers` attempt to call `torch.xpu.empty_cache()` for | |
| # Intel GPU support. If the installed PyTorch build does not include XPU | |
| # support (as is the case on CPU‑only environments), accessing `torch.xpu` | |
| # results in an AttributeError. To avoid this, we define a dummy `xpu` | |
| # namespace on the `torch` module when it is missing. This namespace | |
| # implements the minimal methods used by `diffusers` (`empty_cache`, | |
| # `is_available`, and `device_count`). | |
| # | |
| # Intel’s `intel-extension-for-pytorch` provides XPU support, but even when | |
| # installed, some CPU builds of PyTorch may not expose `torch.xpu`. This | |
| # shim ensures that the application runs regardless of whether XPU support is | |
| # present. | |
| # ----------------------------------------------------------------------------- | |
| if not hasattr(torch, "xpu"): | |
| class _DummyXPU: | |
| def empty_cache(): | |
| return None | |
| def manual_seed(_seed: int): | |
| return None | |
| def is_available(): | |
| return False | |
| def device_count(): | |
| return 0 | |
| def current_device(): | |
| return 0 | |
| def set_device(_idx: int): | |
| return None | |
| torch.xpu = _DummyXPU() # type: ignore | |
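# With the shim in place, CPU-only code paths that touch `torch.xpu` become
# harmless no-ops instead of raising AttributeError, e.g. (illustrative only):
#   torch.xpu.is_available()  # -> False
#   torch.xpu.empty_cache()   # -> None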
import os
import time
from typing import Optional, Tuple, Any

import gradio as gr
from diffusers import AutoencoderKLWan, WanVACEPipeline, WanVACETransformer3DModel, GGUFQuantizationConfig
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from diffusers.utils import export_to_video
from huggingface_hub import login
from transformers import UMT5EncoderModel

from config import MODEL_CONFIG, DEFAULT_PARAMS, HF_TOKEN
from utils import create_temp_video_path, validate_generation_params, validate_prompt, format_generation_info

class WanVACEModelHandler:
    """Handler for WAN-VACE model loading and video generation"""

    def __init__(self):
        self.pipe = None
        self.is_loaded = False
        self.loading_progress = 0

    def login_hf(self) -> bool:
        """Login to Hugging Face"""
        try:
            login(token=HF_TOKEN)
            return True
        except Exception as e:
            print(f"Warning: Could not login to Hugging Face: {e}")
            return False
    def load_model(self, progress_callback=None) -> Tuple[bool, str]:
        """Load the WAN-VACE model components"""
        try:
            # Login to Hugging Face
            self.login_hf()

            if progress_callback:
                progress_callback(0.1, "Loading transformer model...")

            # Determine the dtype for CPU/GPU execution. Hugging Face Spaces
            # often run on CPU, where bfloat16 may not be supported, so the
            # dtype is configurable via the WAN_DTYPE environment variable.
            # Supported values: "bfloat16" (default) or "float32".
            dtype_str = os.getenv("WAN_DTYPE", "bfloat16").lower()
            # Use bfloat16 only if explicitly requested; fall back to float32
            # for any other value.
            compute_dtype = torch.bfloat16 if dtype_str == "bfloat16" else torch.float32
            # Use the same dtype when loading weights.
            torch_dtype = compute_dtype

            # Load transformer from the GGUF single-file checkpoint
            transformer = WanVACETransformer3DModel.from_single_file(
                MODEL_CONFIG["transformer_path"],
                quantization_config=GGUFQuantizationConfig(compute_dtype=compute_dtype),
                torch_dtype=torch_dtype,
            )

            if progress_callback:
                progress_callback(0.4, "Loading text encoder...")

            # Load text encoder
            text_encoder = UMT5EncoderModel.from_pretrained(
                MODEL_CONFIG["text_encoder_path"],
                gguf_file=MODEL_CONFIG["text_encoder_file"],
                torch_dtype=torch_dtype,
            )

            if progress_callback:
                progress_callback(0.7, "Loading VAE...")

            # Load VAE (always in float32)
            vae = AutoencoderKLWan.from_pretrained(
                MODEL_CONFIG["vae_path"],
                subfolder="vae",
                torch_dtype=torch.float32,
            )

            if progress_callback:
                progress_callback(0.9, "Assembling pipeline...")

            # Create pipeline
            self.pipe = WanVACEPipeline.from_pretrained(
                MODEL_CONFIG["pipeline_path"],
                transformer=transformer,
                text_encoder=text_encoder,
                vae=vae,
                torch_dtype=torch_dtype,
            )

            # Configure scheduler
            flow_shift = DEFAULT_PARAMS["flow_shift"]
            self.pipe.scheduler = UniPCMultistepScheduler.from_config(
                self.pipe.scheduler.config,
                flow_shift=flow_shift,
            )

            # Enable memory optimizations
            self.pipe.enable_model_cpu_offload()
            self.pipe.vae.enable_tiling()

            self.is_loaded = True

            if progress_callback:
                progress_callback(1.0, "Model loaded successfully!")

            return True, "Model loaded successfully!"

        except Exception as e:
            error_msg = f"Error loading model: {str(e)}"
            if progress_callback:
                progress_callback(0, error_msg)
            return False, error_msg
    def generate_video(
        self,
        prompt: str,
        negative_prompt: str = "",
        width: int = DEFAULT_PARAMS["width"],
        height: int = DEFAULT_PARAMS["height"],
        num_frames: int = DEFAULT_PARAMS["num_frames"],
        num_inference_steps: int = DEFAULT_PARAMS["num_inference_steps"],
        guidance_scale: float = DEFAULT_PARAMS["guidance_scale"],
        seed: Optional[int] = None,
        progress_callback=None,
    ) -> Tuple[bool, str, str, str]:
        """
        Generate video from text prompt

        Returns: (success, video_path, error_message, generation_info)
        """
        if not self.is_loaded:
            return False, "", "Model not loaded. Please load the model first.", ""

        # Validate inputs
        prompt_valid, prompt_error = validate_prompt(prompt)
        if not prompt_valid:
            return False, "", prompt_error or "Invalid prompt", ""

        params_valid, params_error = validate_generation_params(
            width, height, num_frames, num_inference_steps, guidance_scale
        )
        if not params_valid:
            return False, "", params_error or "Invalid parameters", ""

        try:
            if progress_callback:
                progress_callback(0.1, "Preparing generation...")

            # Check if pipeline is loaded
            if self.pipe is None:
                return False, "", "Pipeline not initialized. Please load the model first.", ""

            # Set up generator with seed
            generator = torch.Generator()
            if seed is not None:
                generator.manual_seed(seed)
            else:
                generator.manual_seed(0)  # Default seed

            if progress_callback:
                progress_callback(0.2, "Starting video generation...")

            start_time = time.time()

            # Generate video
            output = self.pipe(
                prompt=prompt,
                negative_prompt=negative_prompt if negative_prompt else None,
                width=width,
                height=height,
                num_frames=num_frames,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                conditioning_scale=DEFAULT_PARAMS["conditioning_scale"],
                generator=generator,
            ).frames[0]

            if progress_callback:
                progress_callback(0.8, "Exporting video...")

            # Export to video file
            output_path = create_temp_video_path()
            export_to_video(output, output_path, fps=DEFAULT_PARAMS["fps"])

            generation_time = time.time() - start_time

            if progress_callback:
                progress_callback(1.0, "Video generation complete!")

            # Format generation info
            gen_info = format_generation_info(
                prompt, negative_prompt, width, height, num_frames,
                num_inference_steps, guidance_scale, generation_time
            )

            return True, output_path, "", gen_info

        except Exception as e:
            error_msg = f"Error during video generation: {str(e)}"
            if progress_callback:
                progress_callback(0, error_msg)
            return False, "", error_msg, ""

# Global model handler instance
model_handler = WanVACEModelHandler()
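
# -----------------------------------------------------------------------------
# Illustrative usage sketch, not part of the Space UI. It assumes the paths and
# defaults in `config.py` point at downloadable checkpoints and that the host
# has enough memory to load them; the example prompt and seed are arbitrary.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    ok, msg = model_handler.load_model(
        progress_callback=lambda frac, text: print(f"[{frac:.0%}] {text}")
    )
    print(msg)
    if ok:
        success, video_path, error, info = model_handler.generate_video(
            prompt="A red fox running through a snowy forest at sunrise",
            seed=42,
        )
        print(video_path if success else error)
        print(info)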