"""
Model handler for WAN-VACE video generation
"""
import torch
# -----------------------------------------------------------------------------
# XPU shim for CPU-only environments
#
# Newer versions of `diffusers` attempt to call `torch.xpu.empty_cache()` for
# Intel GPU support. If the installed PyTorch build does not include XPU
# support (as is the case in CPU-only environments), accessing `torch.xpu`
# raises an AttributeError. To avoid this, we define a dummy `xpu` namespace
# on the `torch` module when it is missing. The namespace implements the
# methods `diffusers` touches (`empty_cache`, `is_available`, `device_count`)
# plus a few related no-ops for safety.
#
# Intel's `intel-extension-for-pytorch` provides real XPU support, but even
# when it is installed, some CPU builds of PyTorch may not expose `torch.xpu`.
# This shim ensures the application runs regardless of whether XPU support is
# present.
# -----------------------------------------------------------------------------
if not hasattr(torch, "xpu"):
    class _DummyXPU:
        @staticmethod
        def empty_cache():
            return None
        @staticmethod
        def manual_seed(_seed: int):
            return None
        @staticmethod
        def is_available():
            return False
        @staticmethod
        def device_count():
            return 0
        @staticmethod
        def current_device():
            return 0
        @staticmethod
        def set_device(_idx: int):
            return None
    torch.xpu = _DummyXPU()  # type: ignore
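# With the shim in place, the XPU calls that `diffusers` may issue become
# harmless no-ops on CPU-only builds, e.g. (illustrative only):
#   torch.xpu.is_available()   # -> False
#   torch.xpu.empty_cache()    # -> None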
import time
from typing import Optional, Tuple
from transformers import UMT5EncoderModel
from diffusers import AutoencoderKLWan, WanVACEPipeline, WanVACETransformer3DModel, GGUFQuantizationConfig
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from diffusers.utils import export_to_video
from huggingface_hub import login
import gradio as gr
from config import MODEL_CONFIG, DEFAULT_PARAMS, HF_TOKEN
import os
from utils import create_temp_video_path, validate_generation_params, validate_prompt, format_generation_info
class WanVACEModelHandler:
    """Handler for WAN-VACE model loading and video generation"""

    def __init__(self):
        self.pipe = None
        self.is_loaded = False
        self.loading_progress = 0

    def login_hf(self) -> bool:
        """Login to Hugging Face"""
        try:
            login(token=HF_TOKEN)
            return True
        except Exception as e:
            print(f"Warning: Could not login to Hugging Face: {e}")
            return False
    def load_model(self, progress_callback=None) -> Tuple[bool, str]:
        """Load the WAN-VACE model components"""
        try:
            # Login to HF
            self.login_hf()
            if progress_callback:
                progress_callback(0.1, "Loading transformer model...")
            # Determine the dtype for CPU/GPU execution.
            # Hugging Face Spaces often run on CPU, where bfloat16 may be slow or
            # unsupported, so the dtype is configurable via the WAN_DTYPE
            # environment variable. Supported values: "bfloat16" (default) or "float32".
            dtype_str = os.getenv("WAN_DTYPE", "bfloat16").lower()
            # Use bfloat16 only when explicitly requested; otherwise fall back to float32.
            compute_dtype = torch.bfloat16 if dtype_str == "bfloat16" else torch.float32
            # The same dtype is used when loading the weights.
            torch_dtype = compute_dtype
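            # Example (hypothetical invocation; adjust to your launch script):
            #   WAN_DTYPE=float32 python app.py
            # or, from Python before this handler loads the model:
            #   os.environ["WAN_DTYPE"] = "float32"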
            # Load transformer
            transformer = WanVACETransformer3DModel.from_single_file(
                MODEL_CONFIG["transformer_path"],
                quantization_config=GGUFQuantizationConfig(compute_dtype=compute_dtype),
                torch_dtype=torch_dtype,
            )
            if progress_callback:
                progress_callback(0.4, "Loading text encoder...")
            # Load text encoder
            text_encoder = UMT5EncoderModel.from_pretrained(
                MODEL_CONFIG["text_encoder_path"],
                gguf_file=MODEL_CONFIG["text_encoder_file"],
                torch_dtype=torch_dtype,
            )
            if progress_callback:
                progress_callback(0.7, "Loading VAE...")
            # Load VAE
            vae = AutoencoderKLWan.from_pretrained(
                MODEL_CONFIG["vae_path"],
                subfolder="vae",
                torch_dtype=torch.float32
            )
            if progress_callback:
                progress_callback(0.9, "Assembling pipeline...")
            # Create pipeline
            self.pipe = WanVACEPipeline.from_pretrained(
                MODEL_CONFIG["pipeline_path"],
                transformer=transformer,
                text_encoder=text_encoder,
                vae=vae,
                torch_dtype=torch_dtype
            )
            # Configure scheduler
            flow_shift = DEFAULT_PARAMS["flow_shift"]
            self.pipe.scheduler = UniPCMultistepScheduler.from_config(
                self.pipe.scheduler.config,
                flow_shift=flow_shift
            )
            # Enable optimizations
            self.pipe.enable_model_cpu_offload()
            self.pipe.vae.enable_tiling()
            self.is_loaded = True
            if progress_callback:
                progress_callback(1.0, "Model loaded successfully!")
            return True, "Model loaded successfully!"
        except Exception as e:
            error_msg = f"Error loading model: {str(e)}"
            if progress_callback:
                progress_callback(0, error_msg)
            return False, error_msg
    def generate_video(
        self,
        prompt: str,
        negative_prompt: str = "",
        width: int = DEFAULT_PARAMS["width"],
        height: int = DEFAULT_PARAMS["height"],
        num_frames: int = DEFAULT_PARAMS["num_frames"],
        num_inference_steps: int = DEFAULT_PARAMS["num_inference_steps"],
        guidance_scale: float = DEFAULT_PARAMS["guidance_scale"],
        seed: Optional[int] = None,
        progress_callback=None
    ) -> Tuple[bool, str, str, str]:
        """
        Generate video from text prompt
        Returns: (success, video_path, error_message, generation_info)
        """
        if not self.is_loaded:
            return False, "", "Model not loaded. Please load the model first.", ""
        # Validate inputs
        prompt_valid, prompt_error = validate_prompt(prompt)
        if not prompt_valid:
            return False, "", prompt_error or "Invalid prompt", ""
        params_valid, params_error = validate_generation_params(
            width, height, num_frames, num_inference_steps, guidance_scale
        )
        if not params_valid:
            return False, "", params_error or "Invalid parameters", ""
        try:
            if progress_callback:
                progress_callback(0.1, "Preparing generation...")
            # Check if pipeline is loaded
            if self.pipe is None:
                return False, "", "Pipeline not initialized. Please load the model first.", ""
            # Set up generator with seed
            generator = torch.Generator()
            if seed is not None:
                generator.manual_seed(seed)
            else:
                generator.manual_seed(0)  # Default seed
            if progress_callback:
                progress_callback(0.2, "Starting video generation...")
            start_time = time.time()
            # Generate video
            output = self.pipe(
                prompt=prompt,
                negative_prompt=negative_prompt if negative_prompt else None,
                width=width,
                height=height,
                num_frames=num_frames,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                conditioning_scale=DEFAULT_PARAMS["conditioning_scale"],
                generator=generator,
            ).frames[0]
            if progress_callback:
                progress_callback(0.8, "Exporting video...")
            # Export to video file
            output_path = create_temp_video_path()
            export_to_video(output, output_path, fps=DEFAULT_PARAMS["fps"])
            generation_time = time.time() - start_time
            if progress_callback:
                progress_callback(1.0, "Video generation complete!")
            # Format generation info
            gen_info = format_generation_info(
                prompt, negative_prompt, width, height, num_frames,
                num_inference_steps, guidance_scale, generation_time
            )
            return True, output_path, "", gen_info
        except Exception as e:
            error_msg = f"Error during video generation: {str(e)}"
            if progress_callback:
                progress_callback(0, error_msg)
            return False, "", error_msg, ""
# Global model handler instance
model_handler = WanVACEModelHandler()
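
# Minimal usage sketch, runnable as a script. The prompt and parameter values
# below are illustrative assumptions; the handler itself relies on config.py
# (MODEL_CONFIG, DEFAULT_PARAMS, HF_TOKEN) being valid for the environment.
if __name__ == "__main__":
    def _print_progress(fraction: float, message: str) -> None:
        # progress_callback receives a 0..1 fraction and a status message
        print(f"[{fraction:.0%}] {message}")

    ok, status = model_handler.load_model(progress_callback=_print_progress)
    print(status)
    if ok:
        success, video_path, error, info = model_handler.generate_video(
            prompt="A paper boat drifting down a rain-soaked street",  # illustrative prompt
            seed=42,
            progress_callback=_print_progress,
        )
        print(info if success else error)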