import argparse
import os
from datetime import datetime
from pathlib import Path
from typing import List

import av
import numpy as np
import torch
import torchvision
from diffusers import AutoencoderKL, DDIMScheduler
from omegaconf import OmegaConf
from PIL import Image
from transformers import CLIPVisionModelWithProjection

from src.models.pose_guider import PoseGuider
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d_edit_bkfill import UNet3DConditionModel
from src.pipelines.pipeline_pose2vid_long_edit_bkfill_roiclip import Pose2VideoPipeline
from src.utils.util import get_fps, read_frames

import cv2
from tools.human_segmenter import human_segmenter
import imageio
from tools.util import all_file, load_mask_list, crop_img, pad_img, crop_human_clip_auto_context, get_mask, \
    refine_img_prepross, init_bk
import gradio as gr
import json

MOTION_TRIGGER_WORD = {
    'sports_basketball_gym': [],
    'sports_nba_pass': [],
    'sports_nba_dunk': [],
    'movie_BruceLee1': [],
    'shorts_kungfu_match1': [],
    'shorts_kungfu_desert1': [],
    'parkour_climbing': [],
    'dance_indoor_1': [],
    'syn_basketball_06_13': [],
    'syn_dancing2_00093_irish_dance': [],
    'syn_football_10_05': [],
}

css_style = "#fixed_size_img {height: 500px;}"

seg_path = './assets/matting_human.pb'
try:
    if os.path.exists(seg_path):
        segmenter = human_segmenter(model_path=seg_path)
        print("✅ Human segmenter loaded successfully")
    else:
        segmenter = None
        print("⚠️ Segmenter model not found, using fallback segmentation")
except Exception as e:
    segmenter = None
    print(f"⚠️ Failed to load segmenter: {e}, using fallback")


def process_seg(img):
    """Process image segmentation with fallback"""
    if segmenter is not None:
        try:
            rgba = segmenter.run(img)
            mask = rgba[:, :, 3]
            color = rgba[:, :, :3]
            alpha = mask / 255
            bk = np.ones_like(color) * 255
            color = color * alpha[:, :, np.newaxis] + bk * (1 - alpha[:, :, np.newaxis])
            color = color.astype(np.uint8)
            return color, mask
        except Exception as e:
            print(f"⚠️ Segmentation failed: {e}, using simple crop")
    # Fallback: return original image with simple center crop
    h, w = img.shape[:2]
    margin = min(h, w) // 10
    mask = np.zeros((h, w), dtype=np.uint8)
    mask[margin:-margin, margin:-margin] = 255
    return img, mask


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default='./configs/prompts/animation_edit.yaml')
    parser.add_argument("-W", type=int, default=512)
    parser.add_argument("-H", type=int, default=512)
    parser.add_argument("-L", type=int, default=64)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--cfg", type=float, default=3.5)
    parser.add_argument("--steps", type=int, default=10)
    parser.add_argument("--fps", type=int)
    parser.add_argument("--assets_dir", type=str, default='./assets')
    parser.add_argument("--ref_pad", type=int, default=1)
    parser.add_argument("--use_bk", type=int, default=1)
    parser.add_argument("--clip_length", type=int, default=16)
    parser.add_argument("--MAX_FRAME_NUM", type=int, default=150)
    args = parser.parse_args()
    return args
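
# A minimal sketch of the YAML config consumed by MIMO.__init__ below
# (./configs/prompts/animation_edit.yaml). Only the key names are taken from the
# code; the values here are hypothetical placeholders, not the project's actual paths.
#
#   pretrained_base_model_path: <path to the diffusion base model>
#   pretrained_vae_path: <path to the VAE>
#   image_encoder_path: <path to the CLIP vision encoder>
#   denoising_unet_path: <path to denoising UNet weights>
#   reference_unet_path: <path to reference UNet weights>
#   pose_guider_path: <path to pose guider weights>
#   motion_module_path: <path to motion module weights>
#   inference_config: <path to YAML with unet_additional_kwargs and noise_scheduler_kwargs>
#   weight_dtype: fp16
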

class MIMO():
    def __init__(self, debug_mode=False):
        try:
            args = parse_args()
            config = OmegaConf.load(args.config)

            # Check if running on CPU or GPU
            device = "cuda" if torch.cuda.is_available() else "cpu"
            if device == "cpu":
                print("⚠️ CUDA not available, running on CPU (will be slow)")
                weight_dtype = torch.float32
            else:
                if config.weight_dtype == "fp16":
                    weight_dtype = torch.float16
                else:
                    weight_dtype = torch.float32
            print(f"✅ Using device: {device} with dtype: {weight_dtype}")

            vae = AutoencoderKL.from_pretrained(
                config.pretrained_vae_path,
            ).to(device, dtype=weight_dtype)

            reference_unet = UNet2DConditionModel.from_pretrained(
                config.pretrained_base_model_path,
                subfolder="unet",
            ).to(dtype=weight_dtype, device=device)

            inference_config_path = config.inference_config
            infer_config = OmegaConf.load(inference_config_path)
            denoising_unet = UNet3DConditionModel.from_pretrained_2d(
                config.pretrained_base_model_path,
                config.motion_module_path,
                subfolder="unet",
                unet_additional_kwargs=infer_config.unet_additional_kwargs,
            ).to(dtype=weight_dtype, device=device)

            pose_guider = PoseGuider(320, conditioning_channels=3, block_out_channels=(16, 32, 96, 256)).to(
                dtype=weight_dtype, device=device
            )

            image_enc = CLIPVisionModelWithProjection.from_pretrained(
                config.image_encoder_path
            ).to(dtype=weight_dtype, device=device)

            sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
            scheduler = DDIMScheduler(**sched_kwargs)

            self.generator = torch.manual_seed(args.seed)
            self.width, self.height = args.W, args.H

            # load pretrained weights with error handling
            try:
                if os.path.exists(config.denoising_unet_path):
                    denoising_unet.load_state_dict(
                        torch.load(config.denoising_unet_path, map_location="cpu"),
                        strict=False,
                    )
                    print("✅ Denoising UNet weights loaded")
                else:
                    print(f"❌ Denoising UNet weights not found: {config.denoising_unet_path}")

                if os.path.exists(config.reference_unet_path):
                    reference_unet.load_state_dict(
                        torch.load(config.reference_unet_path, map_location="cpu"),
                    )
                    print("✅ Reference UNet weights loaded")
                else:
                    print(f"❌ Reference UNet weights not found: {config.reference_unet_path}")

                if os.path.exists(config.pose_guider_path):
                    pose_guider.load_state_dict(
                        torch.load(config.pose_guider_path, map_location="cpu"),
                    )
                    print("✅ Pose guider weights loaded")
                else:
                    print(f"❌ Pose guider weights not found: {config.pose_guider_path}")
            except Exception as e:
                print(f"⚠️ Error loading model weights: {e}")
                raise

            self.pipe = Pose2VideoPipeline(
                vae=vae,
                image_encoder=image_enc,
                reference_unet=reference_unet,
                denoising_unet=denoising_unet,
                pose_guider=pose_guider,
                scheduler=scheduler,
            )
            self.pipe = self.pipe.to(device, dtype=weight_dtype)

            self.args = args

            # load mask with error handling
            mask_path = os.path.join(self.args.assets_dir, 'masks', 'alpha2.png')
            try:
                if os.path.exists(mask_path):
                    self.mask_list = load_mask_list(mask_path)
                    print("✅ Mask list loaded")
                else:
                    self.mask_list = None
                    print("⚠️ Mask file not found, using fallback masking")
            except Exception as e:
                self.mask_list = None
                print(f"⚠️ Failed to load mask: {e}")

            print("✅ MIMO model initialized successfully")
        except Exception as e:
            print(f"❌ Failed to initialize MIMO model: {e}")
            raise

    def load_template(self, template_path):
        """Load template with error handling"""
        if not os.path.exists(template_path):
            raise FileNotFoundError(f"Template path does not exist: {template_path}")

        video_path = os.path.join(template_path, 'vid.mp4')
        pose_video_path = os.path.join(template_path, 'sdc.mp4')
        bk_video_path = os.path.join(template_path, 'bk.mp4')
        occ_video_path = os.path.join(template_path, 'occ.mp4')

        # Check essential files
        if not os.path.exists(video_path):
            raise FileNotFoundError(f"Required video file missing: {video_path}")
        if not os.path.exists(pose_video_path):
            raise FileNotFoundError(f"Required pose video missing: {pose_video_path}")
        if not os.path.exists(occ_video_path):
            occ_video_path = None
        if not os.path.exists(bk_video_path):
            print(f"⚠️ Background video not found: {bk_video_path}, will generate white background")
            bk_video_path = None

        config_file = os.path.join(template_path, 'config.json')
        if not os.path.exists(config_file):
            print(f"⚠️ Config file missing: {config_file}, using default settings")
            template_data = {
                'fps': 30,
                'time_crop': {'start_idx': 0, 'end_idx': 1000},
                'frame_crop': {'start_idx': 0, 'end_idx': 1000},
                'layer_recover': True
            }
        else:
            with open(config_file) as f:
                template_data = json.load(f)

        template_info = {}
        template_info['video_path'] = video_path
        template_info['pose_video_path'] = pose_video_path
        template_info['bk_video_path'] = bk_video_path
        template_info['occ_video_path'] = occ_video_path
        template_info['target_fps'] = template_data.get('fps', 30)
        template_info['time_crop'] = template_data.get('time_crop', {'start_idx': 0, 'end_idx': 1000})
        template_info['frame_crop'] = template_data.get('frame_crop', {'start_idx': 0, 'end_idx': 1000})
        template_info['layer_recover'] = template_data.get('layer_recover', True)
        return template_info
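
    # A minimal sketch of the motion-template layout that load_template() above expects
    # under ./assets/video_template/<template_name>/. File roles are inferred from the
    # code; the config.json values shown are only the illustrative defaults used above.
    #
    #   vid.mp4       original template video (required)
    #   sdc.mp4       pose / skeleton video (required)
    #   bk.mp4        background video (optional; a white background is synthesized if absent)
    #   occ.mp4       occlusion mask video (optional)
    #   config.json   e.g. {"fps": 30,
    #                       "time_crop": {"start_idx": 0, "end_idx": 1000},
    #                       "frame_crop": {"start_idx": 0, "end_idx": 1000},
    #                       "layer_recover": true}
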
    def run(self, ref_image_pil, template_name):
        template_dir = os.path.join(self.args.assets_dir, 'video_template')
        template_path = os.path.join(template_dir, template_name)

        template_info = self.load_template(template_path)
        target_fps = template_info['target_fps']
        video_path = template_info['video_path']
        pose_video_path = template_info['pose_video_path']
        bk_video_path = template_info['bk_video_path']
        occ_video_path = template_info['occ_video_path']

        # segment and crop the reference character, then pad it onto a white canvas
        # ref_image_pil = Image.open(ref_img_path).convert('RGB')
        source_image = np.array(ref_image_pil)
        source_image, mask = process_seg(source_image[..., ::-1])
        source_image = source_image[..., ::-1]
        source_image = crop_img(source_image, mask)
        source_image, _ = pad_img(source_image, [255, 255, 255])
        ref_image_pil = Image.fromarray(source_image)

        # load target template frames
        vid_images = read_frames(video_path)
        if bk_video_path is None:
            n_frame = len(vid_images)
            tw, th = vid_images[0].size
            bk_images = init_bk(n_frame, th, tw)  # Fixed parameter order: n_frame, height, width
        else:
            bk_images = read_frames(bk_video_path)

        if occ_video_path is not None:
            occ_mask_images = read_frames(occ_video_path)
            print('load occ from %s' % occ_video_path)
        else:
            occ_mask_images = None
            print('no occ masks')

        pose_images = read_frames(pose_video_path)
        src_fps = get_fps(pose_video_path)

        start_idx, end_idx = template_info['time_crop']['start_idx'], template_info['time_crop']['end_idx']
        start_idx = max(0, start_idx)
        end_idx = min(len(pose_images), end_idx)

        pose_images = pose_images[start_idx:end_idx]
        vid_images = vid_images[start_idx:end_idx]
        bk_images = bk_images[start_idx:end_idx]
        if occ_mask_images is not None:
            occ_mask_images = occ_mask_images[start_idx:end_idx]

        self.args.L = len(pose_images)
        max_n_frames = self.args.clip_length  # Use clip_length instead of MAX_FRAME_NUM for faster inference
        if self.args.L > max_n_frames:
            pose_images = pose_images[:max_n_frames]
            vid_images = vid_images[:max_n_frames]
            bk_images = bk_images[:max_n_frames]
            if occ_mask_images is not None:
                occ_mask_images = occ_mask_images[:max_n_frames]
            self.args.L = len(pose_images)

        bk_images_ori = bk_images.copy()
        vid_images_ori = vid_images.copy()

        # Split the sequence into human-centered clips with a small temporal overlap
        # so neighbouring clips can be cross-faded during post-processing.
        overlay = 4
        pose_images, vid_images, bk_images, bbox_clip, context_list, bbox_clip_list = crop_human_clip_auto_context(
            pose_images, vid_images, bk_images, overlay)

        clip_pad_list_context = []
        clip_padv_list_context = []
        pose_list_context = []
        vid_bk_list_context = []
        for frame_idx in range(len(pose_images)):
            pose_image_pil = pose_images[frame_idx]
            pose_image = np.array(pose_image_pil)
            pose_image, _ = pad_img(pose_image, color=[0, 0, 0])
            pose_image_pil = Image.fromarray(pose_image)
            pose_list_context.append(pose_image_pil)

            vid_bk = bk_images[frame_idx]
            vid_bk = np.array(vid_bk)
            vid_bk, padding_v = pad_img(vid_bk, color=[255, 255, 255])
            pad_h, pad_w, _ = vid_bk.shape
            clip_pad_list_context.append([pad_h, pad_w])
            clip_padv_list_context.append(padding_v)
            vid_bk_list_context.append(Image.fromarray(vid_bk))

        print('start to infer...')
        print(f'📊 Inference params: frames={len(pose_list_context)}, size={self.width}x{self.height}, steps={self.args.steps}')
        try:
            video = self.pipe(
                ref_image_pil,
                pose_list_context,
                vid_bk_list_context,
                self.width,
                self.height,
                len(pose_list_context),
                self.args.steps,
                self.args.cfg,
                generator=self.generator,
            ).videos[0]
            print('✅ Inference completed successfully')
        except Exception as e:
            print(f'❌ Inference failed: {e}')
            import traceback
            traceback.print_exc()
            return None

        # post-process video: paste each generated clip frame back into the full
        # background, then blend overlapping frames between consecutive clips.
        video_idx = 0
        res_images = [None for _ in range(self.args.L)]
        for k, context in enumerate(context_list):
            start_i = context[0]
            bbox = bbox_clip_list[k]
            for i in context:
                bk_image_pil_ori = bk_images_ori[i]
                vid_image_pil_ori = vid_images_ori[i]
                if occ_mask_images is not None:
                    occ_mask = occ_mask_images[i]
                else:
                    occ_mask = None

                canvas = Image.new("RGB", bk_image_pil_ori.size, "white")

                pad_h, pad_w = clip_pad_list_context[video_idx]
                padding_v = clip_padv_list_context[video_idx]

                image = video[:, video_idx, :, :].permute(1, 2, 0).cpu().numpy()
                res_image_pil = Image.fromarray((image * 255).astype(np.uint8))
                res_image_pil = res_image_pil.resize((pad_w, pad_h))

                top, bottom, left, right = padding_v
                res_image_pil = res_image_pil.crop((left, top, pad_w - right, pad_h - bottom))

                w_min, w_max, h_min, h_max = bbox
                canvas.paste(res_image_pil, (w_min, h_min))

                mask_full = np.zeros((bk_image_pil_ori.size[1], bk_image_pil_ori.size[0]), dtype=np.float32)
                res_image = np.array(canvas)
                bk_image = np.array(bk_image_pil_ori)

                mask = get_mask(self.mask_list, bbox, bk_image_pil_ori)
                mask = cv2.resize(mask, res_image_pil.size, interpolation=cv2.INTER_AREA)
                mask_full[h_min:h_min + mask.shape[0], w_min:w_min + mask.shape[1]] = mask

                res_image = res_image * mask_full[:, :, np.newaxis] + bk_image * (1 - mask_full[:, :, np.newaxis])

                if occ_mask is not None:
                    vid_image = np.array(vid_image_pil_ori)
                    occ_mask = np.array(occ_mask)[:, :, 0].astype(np.uint8)  # [0,255]
                    occ_mask = occ_mask / 255.0
                    res_image = res_image * (1 - occ_mask[:, :, np.newaxis]) + vid_image * occ_mask[:, :, np.newaxis]

                if res_images[i] is None:
                    res_images[i] = res_image
                else:
                    # Linear cross-fade over the `overlay` frames shared by consecutive clips.
                    factor = (i - start_i + 1) / (overlay + 1)
                    res_images[i] = res_images[i] * (1 - factor) + res_image * factor
                res_images[i] = res_images[i].astype(np.uint8)
                video_idx = video_idx + 1
        return res_images
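
# A minimal usage sketch for running MIMO without the Gradio UI, assuming the default
# CLI arguments and assets are in place. The image path and template name reuse entries
# listed elsewhere in this file; the output path is a hypothetical placeholder.
#
#   model = MIMO()
#   ref = Image.open('./assets/test_image/sugar.jpg').convert('RGB')
#   frames = model.run(ref, 'sports_basketball_gym')
#   if frames:
#       imageio.mimsave('output/demo.mp4', frames, fps=30, quality=8, macro_block_size=1)
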

class WebApp():
    def __init__(self, debug_mode=False):
        self.args_base = {
            "device": "cuda",
            "output_dir": "output_demo",
            "img": None,
            "pos_prompt": '',
            "motion": "sports_basketball_gym",
            "motion_dir": "./assets/test_video_trunc",
        }

        self.args_input = {}  # for gr.components only
        self.gr_motion = list(MOTION_TRIGGER_WORD.keys())

        # fun fact: google analytics doesn't work in this space currently
        self.gtag = os.environ.get('GTag')
        self.ga_script = f""" """
        self.ga_load = f"""
        function() {{
            window.dataLayer = window.dataLayer || [];
            function gtag(){{dataLayer.push(arguments);}}
            gtag('js', new Date());
            gtag('config', '{self.gtag}');
        }}
        """

        # pre-download base model for better user experience
        try:
            self.model = MIMO()
            print("✅ MIMO model loaded successfully")
        except Exception as e:
            print(f"❌ Failed to load MIMO model: {e}")
            self.model = None
        self.debug_mode = debug_mode  # turn off clip interrogator when debugging for faster building speed

    def title(self):
        gr.HTML(
            """
            """
        )

    def get_template(self, num_cols=3):
        self.args_input['motion'] = gr.State('sports_basketball_gym')
        num_cols = 2

        # Use thumbnails instead of videos for gallery display
        thumb_dir = "./assets/thumbnails"
        gallery_items = []
        for motion in self.gr_motion:
            thumb_path = os.path.join(thumb_dir, f"{motion}.jpg")
            if os.path.exists(thumb_path):
                gallery_items.append((thumb_path, motion))
            else:
                # Fallback to a placeholder or skip
                print(f"⚠️ Thumbnail not found: {thumb_path}")

        lora_gallery = gr.Gallery(label='Motion Templates', columns=num_cols, height=500,
                                  value=gallery_items, show_label=True)

        lora_gallery.select(self._update_selection, inputs=[], outputs=[self.args_input['motion']])
        print(self.args_input['motion'])

    def _update_selection(self, selected_state: gr.SelectData):
        return self.gr_motion[selected_state.index]

    def run_process(self, *values):
        if self.model is None:
            print("❌ MIMO model not loaded. Please check dependencies and model weights.")
            return None

        try:
            gr_args = self.args_base.copy()
            print(self.args_input.keys())
            for k, v in zip(list(self.args_input.keys()), values):
                gr_args[k] = v

            ref_image_pil = gr_args['img']  # PIL image
            if ref_image_pil is None:
                print("⚠️ Please upload an image first.")
                return None

            template_name = gr_args['motion']
            print('template_name:', template_name)

            save_dir = 'output'
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)

            # name the output file with a timestamp
            case = datetime.now().strftime("%Y%m%d%H%M%S")
            outpath = f"{save_dir}/{case}.mp4"

            res = self.model.run(ref_image_pil, template_name)
            if not res:
                print("❌ Video generation failed. Please check template and try again.")
                return None

            imageio.mimsave(outpath, res, fps=30, quality=8, macro_block_size=1)
            print('save to %s' % outpath)
            return outpath
        except Exception as e:
            print(f"❌ Error during processing: {e}")
            # Don't return an error string - Gradio Video expects a file path or None
            return None

    def preset_library(self):
        with gr.Blocks() as demo:
            with gr.Accordion(label="🧭 Guidance:", open=True, elem_id="accordion"):
                with gr.Row(equal_height=True):
                    gr.Markdown("""
                    - ⭐️ Step 1: Upload a character image or select one from the examples
                    - ⭐️ Step 2: Choose a motion template from the gallery
                    - ⭐️ Step 3: Click "Run" to generate the animation
                    - Note: The input character image should be full-body, front-facing, with no occlusion and no handheld objects
                    """)

            with gr.Row():
                img_input = gr.Image(label='Input image', type="pil", elem_id="fixed_size_img")
                self.args_input['img'] = img_input

                with gr.Column():
                    self.get_template(num_cols=3)
                    submit_btn_load3d = gr.Button("Run", variant='primary')

                with gr.Column(scale=1):
                    res_vid = gr.Video(format="mp4", label="Generated Result", autoplay=True, elem_id="fixed_size_img")

            submit_btn_load3d.click(self.run_process,
                                    inputs=list(self.args_input.values()),
                                    outputs=[res_vid],
                                    scroll_to_output=True,
                                    )

            # Create examples list with only existing files
            example_images = []
            possible_examples = [
                './assets/test_image/sugar.jpg',
                './assets/test_image/ouwen1.png',
                './assets/test_image/actorhq_A1S1.png',
                './assets/test_image/actorhq_A7S1.png',
                './assets/test_image/cartoon1.png',
                './assets/test_image/cartoon2.png',
                './assets/test_image/sakura.png',
                './assets/test_image/kakashi.png',
                './assets/test_image/sasuke.png',
                './assets/test_image/avatar.jpg',
            ]
            for img_path in possible_examples:
                if os.path.exists(img_path):
                    example_images.append([img_path])

            if example_images:
                gr.Examples(examples=example_images,
                            inputs=[img_input],
                            examples_per_page=20,
                            label="Examples",
                            elem_id="examples",
                            )
            else:
                gr.Markdown("⚠️ No example images found. Please upload your own image.")

    def ui(self):
        with gr.Blocks(css=css_style) as demo:
            self.title()
            self.preset_library()
            demo.load(None, js=self.ga_load)
        return demo


app = WebApp(debug_mode=False)
demo = app.ui()

if __name__ == "__main__":
    demo.queue(max_size=100)
    # For Hugging Face Spaces
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)