File size: 12,851 Bytes
95315db
 
 
 
 
7406fce
 
 
 
 
 
 
 
 
 
 
 
 
1738f9f
7406fce
 
1738f9f
 
 
 
7406fce
 
1738f9f
 
 
 
7406fce
1738f9f
 
 
 
95315db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0aa024f
95315db
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
"""
Gradio interface for WAN-VACE video generation
"""
import gradio as gr
import torch

# -----------------------------------------------------------------------------
# XPU shim for PyTorch builds without Intel XPU support
#
# `diffusers` may reach for `torch.xpu.empty_cache()` when it frees device
# memory. CPU-only PyTorch builds (and builds without Intel extension support)
# have no `torch.xpu` attribute, so that lookup would raise AttributeError.
# When the attribute is missing we install an inert stand-in whose methods all
# report "no XPU device". When a real `torch.xpu` exists, nothing is changed.
# -----------------------------------------------------------------------------
if not hasattr(torch, "xpu"):

    class _DummyXPU:
        """Inert replacement for ``torch.xpu``; every call is a harmless no-op."""

        # --- device discovery: always reports zero devices -----------------
        @staticmethod
        def is_available():
            return False

        @staticmethod
        def device_count():
            return 0

        @staticmethod
        def current_device():
            return 0

        # --- state-changing calls: accepted and ignored -----------------------
        @staticmethod
        def set_device(_idx: int):
            return None

        @staticmethod
        def manual_seed(_seed: int):
            return None

        @staticmethod
        def empty_cache():
            return None

    # An instance (not the class) is installed, so attribute access mirrors
    # calling functions on a module object.
    torch.xpu = _DummyXPU()  # type: ignore
import time
from typing import Optional

# Import the simple planner
from planning import plan_from_topic

from config import UI_CONFIG, DEFAULT_PARAMS, SERVER_CONFIG
from model_handler import model_handler
from utils import cleanup_temp_files

def load_model_interface(progress=gr.Progress()):
    """Load the model through ``model_handler`` and reflect the result in the UI.

    Args:
        progress: Gradio progress tracker. Progress reported by
            ``model_handler.load_model`` is forwarded to it as
            ``progress(value, desc=message)``.

    Returns:
        A 4-tuple of ``gr.update`` objects for, in order: the load button,
        the generation interface, the success message and the error message.
        On success the button is hidden and the interface and success message
        are shown. On failure the button stays visible and only the error
        message is shown.
    """
    loaded, status_text = model_handler.load_model(
        lambda value, message: progress(value, desc=message)
    )
    is_loaded = bool(loaded)

    # Exactly one of the two message panels is shown and carries the text.
    if is_loaded:
        success_update = gr.update(value=status_text, visible=True)
        error_update = gr.update(visible=False)
    else:
        success_update = gr.update(visible=False)
        error_update = gr.update(value=status_text, visible=True)

    return (
        gr.update(visible=not is_loaded),  # load button: hidden once loaded
        gr.update(visible=is_loaded),      # generation interface
        success_update,
        error_update,
    )

def generate_video_interface(
    prompt: str,
    negative_prompt: str,
    width: int,
    height: int,
    num_frames: int,
    num_inference_steps: int,
    guidance_scale: float,
    seed: Optional[int],
    progress=gr.Progress()
):
    """Plan a prompt from the user's concept, then generate a video from it.

    The raw ``prompt`` is treated as a high-level concept. ``plan_from_topic``
    expands it into a refined prompt plus a recommended negative prompt. A
    non-blank user ``negative_prompt`` overrides that recommendation.

    Args:
        prompt: Concept text from the "Video Concept" textbox.
        negative_prompt: Optional user override for the negative prompt.
            Blank or whitespace-only means "use the planner's recommendation".
        width: Output frame width in pixels.
        height: Output frame height in pixels.
        num_frames: Number of frames to generate.
        num_inference_steps: Number of inference steps.
        guidance_scale: Guidance scale forwarded to the model handler.
        seed: Seed forwarded unchanged to ``model_handler.generate_video``.
        progress: Gradio progress tracker injected by the framework.

    Returns:
        A 3-tuple of ``gr.update`` objects for the video output, the
        generation-info panel and the error message, in that order.
    """

    def progress_callback(value, message):
        progress(value, desc=message)

    concept = (prompt or "").strip()
    if not concept:
        # Planning and generation are expensive. Reject an empty concept up
        # front instead of asking the planner to expand nothing.
        return (
            gr.update(value=None, visible=False),
            gr.update(visible=False),
            gr.update(
                value="⚠️ Please enter a video concept before generating.",
                visible=True,
            ),
        )

    plan = plan_from_topic(concept)
    effective_prompt = plan.prompt
    # A user-supplied (non-blank) negative prompt wins over the recommendation.
    user_negative = (negative_prompt or "").strip()
    effective_negative = user_negative or plan.negative_prompt

    success, video_path, error_msg, gen_info = model_handler.generate_video(
        prompt=effective_prompt,
        negative_prompt=effective_negative,
        width=width,
        height=height,
        num_frames=num_frames,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        seed=seed,
        progress_callback=progress_callback
    )

    if success:
        return (
            gr.update(value=video_path, visible=True),  # Video output
            gr.update(value=gen_info, visible=True),    # Generation info
            gr.update(visible=False)                    # Hide error message
        )
    return (
        gr.update(value=None, visible=False),       # Hide video output
        gr.update(visible=False),                   # Hide generation info
        gr.update(value=error_msg, visible=True)    # Show error message
    )

def create_interface():
    """Build and return the Gradio Blocks app for WAN-VACE video generation.

    Layout, top to bottom:

    - a header;
    - a "Load Video Generation Model" button with success/error messages;
    - a generation panel that starts hidden and is revealed by
      ``load_model_interface`` once the model loads.

    The generation panel holds the concept and negative-prompt inputs, the
    generation-parameter sliders, a generate button, the output components,
    and static tips / system-information text.

    Components appear in the order they are created inside the ``with``
    blocks, so the statement order below defines the on-screen layout.

    Returns:
        The configured ``gr.Blocks`` instance (not yet queued or launched).
    """
    
    with gr.Blocks(
        title=UI_CONFIG["title"],
        theme=UI_CONFIG["theme"]
    ) as demo:
        
        # Header
        gr.Markdown(f"# {UI_CONFIG['title']}")
        gr.Markdown(UI_CONFIG["description"])
        
        # Model loading section. The button and both message panels are
        # toggled by load_model_interface (see the load_btn.click wiring below).
        with gr.Row():
            with gr.Column():
                load_btn = gr.Button(
                    "🚀 Load Video Generation Model", 
                    variant="primary", 
                    size="lg"
                )
                load_success_msg = gr.Markdown(visible=False)
                load_error_msg = gr.Markdown(visible=False)
        
        # Main generation interface (initially hidden until the model loads)
        with gr.Column(visible=False) as generation_interface:
            
            # Input section
            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Group():
                        gr.Markdown("### 📝 Concept & Prompts")
                        # The user supplies a high-level concept or topic. The agent
                        # refines it into a detailed prompt automatically.
                        prompt_input = gr.Textbox(
                            label="Video Concept",
                            placeholder="Describe the concept you want to generate, e.g. 'a pig in a winter forest'...",
                            lines=3,
                            value="a pig moving quickly in a beautiful winter scenery nature trees sunset tracking camera"
                        )
                        # Optional negative prompt: overrides the agent's recommended negative prompt.
                        negative_prompt_input = gr.Textbox(
                            label="Negative Prompt (Optional)",
                            placeholder="Things you don't want in the video; leave empty to use the agent's recommendation...",
                            lines=2,
                            value=""
                        )
                
                with gr.Column(scale=1):
                    with gr.Group():
                        gr.Markdown("### ⚙️ Generation Parameters")
                        
                        # step=8 (from a minimum of 64) keeps both dimensions
                        # multiples of 8, matching the tip shown further down.
                        with gr.Row():
                            width_slider = gr.Slider(
                                label="Width",
                                minimum=64,
                                maximum=1920,
                                step=8,
                                value=DEFAULT_PARAMS["width"]
                            )
                            height_slider = gr.Slider(
                                label="Height",
                                minimum=64,
                                maximum=1080,
                                step=8,
                                value=DEFAULT_PARAMS["height"]
                            )
                        
                        num_frames_slider = gr.Slider(
                            label="Number of Frames",
                            minimum=1,
                            maximum=200,
                            step=1,
                            value=DEFAULT_PARAMS["num_frames"]
                        )
                        
                        inference_steps_slider = gr.Slider(
                            label="Inference Steps",
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=DEFAULT_PARAMS["num_inference_steps"]
                        )
                        
                        guidance_scale_slider = gr.Slider(
                            label="Guidance Scale",
                            minimum=0.0,
                            maximum=20.0,
                            step=0.1,
                            value=DEFAULT_PARAMS["guidance_scale"]
                        )
                        
                        # NOTE(review): labelled "Optional" but defaults to 0, and
                        # precision=0 makes it an integer field. Confirm how
                        # model_handler.generate_video treats 0 versus None (empty).
                        seed_input = gr.Number(
                            label="Seed (Optional)",
                            value=0,
                            precision=0
                        )
            
            # Generation button
            with gr.Row():
                generate_btn = gr.Button(
                    "🎬 Generate Video",
                    variant="primary",
                    size="lg"
                )
            
            # Output section. All three start hidden; generate_video_interface
            # shows either the video + info or the error message.
            with gr.Row():
                with gr.Column():
                    video_output = gr.Video(
                        label="Generated Video",
                        visible=False
                    )
                    
                    generation_info = gr.Markdown(
                        label="Generation Information",
                        visible=False
                    )
                    
                    generation_error = gr.Markdown(
                        visible=False
                    )
            
            # Additional controls
            with gr.Row():
                with gr.Column():
                    # NOTE(review): this text is indented along with the code. It
                    # relies on gr.Markdown dedenting its value; otherwise Markdown
                    # would render it as a code block. Confirm for the installed
                    # Gradio version.
                    gr.Markdown("""
                    ### 💡 Tips:
                    - Enter a short **concept** (e.g. “a busy city street at dawn”). The agent will expand it into a detailed prompt.
                    - Adjust the **guidance scale**: higher values make the video adhere more closely to the refined prompt.
                    - Increasing **inference steps** improves quality at the cost of generation time.
                    - Use the optional **Negative Prompt** field only if you want to override the agent's recommended terms.
                    - Keep width and height multiples of 8 for optimal performance.
                    """)
                
                with gr.Column():
                    # Device info is computed once, when the interface is built,
                    # not per request.
                    if torch.cuda.is_available():
                        gpu_info = f"🎮 GPU: {torch.cuda.get_device_name()}"
                    else:
                        gpu_info = "💻 Running on CPU"
                    
                    gr.Markdown(f"""
                    ### 🖥️ System Information:
                    {gpu_info}
                    
                    ### 📊 Model Information:
                    - **Model:** WAN‑VACE 1.3B (Q4_0 Quantized)
                    - **Text Encoder:** UMT5‑XXL
                    - **Scheduler:** UniPC Multistep
                    
                    ### 🤖 Agent Details:
                    - **Planning:** The agent automatically crafts a detailed prompt and a recommended negative prompt based on your concept.
                    - **Override:** Supply your own negative prompt to override the recommendation if desired.
                    """)
        
        # Event handlers.
        # load_model_interface returns four updates; their order must match
        # this outputs list.
        load_btn.click(
            fn=load_model_interface,
            outputs=[
                load_btn,
                generation_interface,
                load_success_msg,
                load_error_msg
            ]
        )
        
        # inputs are passed positionally, so their order must match
        # generate_video_interface's parameters. outputs must match its
        # 3-tuple of updates.
        generate_btn.click(
            fn=generate_video_interface,
            inputs=[
                prompt_input,
                negative_prompt_input,
                width_slider,
                height_slider,
                num_frames_slider,
                inference_steps_slider,
                guidance_scale_slider,
                seed_input
            ],
            outputs=[
                video_output,
                generation_info,
                generation_error
            ]
        )
    
    return demo

def main():
    """Build the Gradio app and serve it on the configured host and port.

    Prints a short startup banner (title, server address, detected compute
    device). It then creates the interface, enables the request queue and
    calls ``launch`` with the settings from ``SERVER_CONFIG``.
    """
    print(f"🚀 Starting {UI_CONFIG['title']}...")
    print(f"🔧 Server configuration: {SERVER_CONFIG['host']}:{SERVER_CONFIG['port']}")

    # Check GPU availability. The name and memory are queried for the same
    # device so the two banner lines describe the same GPU.
    if torch.cuda.is_available():
        device_index = torch.cuda.current_device()
        print(f"🎮 GPU detected: {torch.cuda.get_device_name(device_index)}")
        total_gb = torch.cuda.get_device_properties(device_index).total_memory / 1024**3
        print(f"💾 GPU Memory: {total_gb:.1f}GB")
    else:
        print("💻 Running on CPU (GPU recommended for better performance)")

    demo = create_interface()
    # Enable Gradio's request queue (Hugging Face Spaces expect `.queue()` for
    # long-running jobs). No concurrency keyword is passed on purpose: it is
    # `concurrency_count` in Gradio 3.x but `default_concurrency_limit` in 4.x.
    # Per the Gradio docs, both default to one concurrent job, which keeps
    # memory use bounded on CPU-only hardware.
    demo = demo.queue()

    # Launch the interface.
    demo.launch(
        server_name=SERVER_CONFIG["host"],
        server_port=SERVER_CONFIG["port"],
        share=SERVER_CONFIG["share"],
        show_error=True,
    )

if __name__ == "__main__":
    main()