"""
Gradio interface for WAN-VACE video generation
"""
import gradio as gr
import torch
# -----------------------------------------------------------------------------
# XPU shim for CPU‑only environments
#
# `diffusers` attempts to access `torch.xpu.empty_cache()` when cleaning up
# device memory. On CPU‑only builds of PyTorch (or builds without Intel
# extension support), the `xpu` attribute does not exist on the `torch`
# module. Defining a dummy `torch.xpu` prevents AttributeError during
# import.
# -----------------------------------------------------------------------------
if not hasattr(torch, "xpu"):
    class _DummyXPU:
        @staticmethod
        def empty_cache():
            return None

        @staticmethod
        def manual_seed(_seed: int):
            return None

        @staticmethod
        def is_available():
            return False

        @staticmethod
        def device_count():
            return 0

        @staticmethod
        def current_device():
            return 0

        @staticmethod
        def set_device(_idx: int):
            return None

    torch.xpu = _DummyXPU()  # type: ignore
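
# With the shim in place, code that probes the XPU backend degrades to safe
# no-ops on CPU-only builds, for example:
#   torch.xpu.empty_cache()   # returns None instead of raising AttributeError
#   torch.xpu.is_available()  # -> False
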
from typing import Optional
# Import the simple planner
from planning import plan_from_topic
from config import UI_CONFIG, DEFAULT_PARAMS, SERVER_CONFIG
from model_handler import model_handler
from utils import cleanup_temp_files
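
# The imported config module is expected to expose (keys as used in this file):
#   UI_CONFIG      -> {"title", "description", "theme"}
#   DEFAULT_PARAMS -> {"width", "height", "num_frames",
#                      "num_inference_steps", "guidance_scale"}
#   SERVER_CONFIG  -> {"host", "port", "share"}
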
def load_model_interface(progress=gr.Progress()):
    """Interface function for loading the model"""
    def progress_callback(value, message):
        progress(value, desc=message)

    success, message = model_handler.load_model(progress_callback)
    if success:
        return (
            gr.update(visible=False),                # Hide load button
            gr.update(visible=True),                 # Show generation interface
            gr.update(value=message, visible=True),  # Show success message
            gr.update(visible=False)                 # Hide error message
        )
    else:
        return (
            gr.update(visible=True),                 # Keep load button visible
            gr.update(visible=False),                # Keep generation interface hidden
            gr.update(visible=False),                # Hide success message
            gr.update(value=message, visible=True)   # Show error message
        )
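
# For reference, `model_handler.load_model` is expected to return a
# `(success: bool, message: str)` tuple and to report progress through the
# callback with values in [0, 1]. A hypothetical stand-in for testing the UI
# without the real weights might look like:
#
#     def load_model(progress_callback=None):
#         if progress_callback:
#             progress_callback(0.1, "Downloading weights...")
#             progress_callback(1.0, "Model ready")
#         return True, "✅ Model loaded successfully"
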
def generate_video_interface(
    prompt: str,
    negative_prompt: str,
    width: int,
    height: int,
    num_frames: int,
    num_inference_steps: int,
    guidance_scale: float,
    seed: Optional[int],
    progress=gr.Progress()
):
    """Interface function for video generation"""
    def progress_callback(value, message):
        progress(value, desc=message)

    # Plan the prompt: treat the user input as a high-level concept and let the
    # agent craft a refined prompt and recommended negative prompt. If the user
    # supplies a negative prompt, it overrides the recommended negative prompt.
    plan = plan_from_topic(prompt)
    # Use the refined prompt from the plan.
    effective_prompt = plan.prompt
    # If the user provided a negative prompt, use it; otherwise use the recommended one.
    effective_negative = negative_prompt.strip() if negative_prompt and negative_prompt.strip() else plan.negative_prompt
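    # `plan_from_topic` (from planning.py) returns an object exposing `.prompt`
    # and `.negative_prompt`, as used above; the values below are illustrative only:
    #   plan = plan_from_topic("a pig in a winter forest")
    #   plan.prompt          -> "a pig trotting through a snowy winter forest, ..."
    #   plan.negative_prompt -> "blurry, low quality, distorted, ..."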
    success, video_path, error_msg, gen_info = model_handler.generate_video(
        prompt=effective_prompt,
        negative_prompt=effective_negative,
        width=width,
        height=height,
        num_frames=num_frames,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        seed=seed,
        progress_callback=progress_callback
    )
    if success:
        return (
            gr.update(value=video_path, visible=True),  # Video output
            gr.update(value=gen_info, visible=True),    # Generation info
            gr.update(visible=False)                    # Hide error message
        )
    else:
        return (
            gr.update(value=None, visible=False),       # Hide video output
            gr.update(visible=False),                   # Hide generation info
            gr.update(value=error_msg, visible=True)    # Show error message
        )


def create_interface():
    """Create the Gradio interface"""
    with gr.Blocks(
        title=UI_CONFIG["title"],
        theme=UI_CONFIG["theme"]
    ) as demo:
        # Header
        gr.Markdown(f"# {UI_CONFIG['title']}")
        gr.Markdown(UI_CONFIG["description"])

        # Model loading section
        with gr.Row():
            with gr.Column():
                load_btn = gr.Button(
                    "🚀 Load Video Generation Model",
                    variant="primary",
                    size="lg"
                )
                load_success_msg = gr.Markdown(visible=False)
                load_error_msg = gr.Markdown(visible=False)

        # Main generation interface (initially hidden)
        with gr.Column(visible=False) as generation_interface:
            # Input section
            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Group():
                        gr.Markdown("### 📝 Concept & Prompts")
                        # The user supplies a high-level concept or topic. The agent
                        # will refine this into a detailed prompt automatically.
                        prompt_input = gr.Textbox(
                            label="Video Concept",
                            placeholder="Describe the concept you want to generate, e.g. 'a pig in a winter forest'...",
                            lines=3,
                            value="a pig moving quickly in a beautiful winter scenery nature trees sunset tracking camera"
                        )
                        # Optional negative prompt: overrides the agent's recommended negative prompt.
                        negative_prompt_input = gr.Textbox(
                            label="Negative Prompt (Optional)",
                            placeholder="Things you don't want in the video; leave empty to use the agent's recommendation...",
                            lines=2,
                            value=""
                        )
                with gr.Column(scale=1):
                    with gr.Group():
                        gr.Markdown("### ⚙️ Generation Parameters")
                        with gr.Row():
                            width_slider = gr.Slider(
                                label="Width",
                                minimum=64,
                                maximum=1920,
                                step=8,
                                value=DEFAULT_PARAMS["width"]
                            )
                            height_slider = gr.Slider(
                                label="Height",
                                minimum=64,
                                maximum=1080,
                                step=8,
                                value=DEFAULT_PARAMS["height"]
                            )
                        num_frames_slider = gr.Slider(
                            label="Number of Frames",
                            minimum=1,
                            maximum=200,
                            step=1,
                            value=DEFAULT_PARAMS["num_frames"]
                        )
                        inference_steps_slider = gr.Slider(
                            label="Inference Steps",
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=DEFAULT_PARAMS["num_inference_steps"]
                        )
                        guidance_scale_slider = gr.Slider(
                            label="Guidance Scale",
                            minimum=0.0,
                            maximum=20.0,
                            step=0.1,
                            value=DEFAULT_PARAMS["guidance_scale"]
                        )
                        seed_input = gr.Number(
                            label="Seed (Optional)",
                            value=0,
                            precision=0
                        )
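                        # Note: the field defaults to 0, which is passed through
                        # as an explicit seed; reruns with identical parameters
                        # should then be reproducible, assuming model_handler
                        # seeds its RNG from this value.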

            # Generation button
            with gr.Row():
                generate_btn = gr.Button(
                    "🎬 Generate Video",
                    variant="primary",
                    size="lg"
                )

            # Output section
            with gr.Row():
                with gr.Column():
                    video_output = gr.Video(
                        label="Generated Video",
                        visible=False
                    )
                    generation_info = gr.Markdown(
                        label="Generation Information",
                        visible=False
                    )
                    generation_error = gr.Markdown(
                        visible=False
                    )

            # Additional controls
            with gr.Row():
                with gr.Column():
                    gr.Markdown("""
                    ### 💡 Tips:
                    - Enter a short **concept** (e.g. “a busy city street at dawn”). The agent will expand it into a detailed prompt.
                    - Adjust the **guidance scale**: higher values make the video adhere more closely to the refined prompt.
                    - Increase **inference steps** to improve quality at the cost of generation time.
                    - Use the optional **Negative Prompt** field only if you want to override the agent's recommended terms.
                    - Keep width and height at multiples of 8 (the sliders step by 8) for optimal performance.
                    """)
                with gr.Column():
                    if torch.cuda.is_available():
                        gpu_info = f"🎮 GPU: {torch.cuda.get_device_name()}"
                    else:
                        gpu_info = "💻 Running on CPU"
                    gr.Markdown(f"""
                    ### 🖥️ System Information:
                    {gpu_info}

                    ### 📊 Model Information:
                    - **Model:** WAN-VACE 1.3B (Q4_0 Quantized)
                    - **Text Encoder:** UMT5-XXL
                    - **Scheduler:** UniPC Multistep

                    ### 🤖 Agent Details:
                    - **Planning:** The agent automatically crafts a detailed prompt and a recommended negative prompt from your concept.
                    - **Override:** Supply your own negative prompt to replace the recommendation if desired.
                    """)

        # Event handlers
        load_btn.click(
            fn=load_model_interface,
            outputs=[
                load_btn,
                generation_interface,
                load_success_msg,
                load_error_msg
            ]
        )
        generate_btn.click(
            fn=generate_video_interface,
            inputs=[
                prompt_input,
                negative_prompt_input,
                width_slider,
                height_slider,
                num_frames_slider,
                inference_steps_slider,
                guidance_scale_slider,
                seed_input
            ],
            outputs=[
                video_output,
                generation_info,
                generation_error
            ]
        )

    return demo


def main():
    """Main function to launch the application"""
    print(f"🚀 Starting {UI_CONFIG['title']}...")
    print(f"🔧 Server configuration: {SERVER_CONFIG['host']}:{SERVER_CONFIG['port']}")

    # Check GPU availability
    if torch.cuda.is_available():
        print(f"🎮 GPU detected: {torch.cuda.get_device_name()}")
        print(f"💾 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB")
    else:
        print("💻 Running on CPU (GPU recommended for better performance)")

    # Create the interface and enable the event queue to support multiple users.
    demo = create_interface()
    # Hugging Face Spaces expect `.queue()` to be called so that requests are
    # handled through the event queue rather than all at once.
    demo = demo.queue()
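    # On CPU-only hardware it can help to cap queue concurrency explicitly,
    # e.g. (Gradio 4.x keywords; adjust to the installed version):
    #     demo = demo.queue(max_size=8, default_concurrency_limit=1)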
    # Launch the interface.
    demo.launch(
        server_name=SERVER_CONFIG["host"],
        server_port=SERVER_CONFIG["port"],
        share=SERVER_CONFIG["share"],
        show_error=True,
    )


if __name__ == "__main__":
    main()