# inference_engine.py
import os
import copy
import pickle

import numpy as np
import torch
import decord
import imageio
from PIL import Image
from torchvision.transforms import ToPILImage, transforms, InterpolationMode, functional as F
from huggingface_hub import hf_hub_download

from models import MTVCrafterPipeline, Encoder, VectorQuantizer, Decoder, SMPL_VQVAE
from draw_pose import get_pose_images
from utils import concat_images_grid, sample_video, get_sample_indexes, get_new_height_width
def run_inference(device, motion_data_path, ref_image_path='', dst_width=512, dst_height=512, num_inference_steps=50, guidance_scale=3.0, seed=6666):
    num_frames = 49
    to_pil = ToPILImage()
    normalize = transforms.Normalize([0.5], [0.5])
    pretrained_model_path = "THUDM/CogVideoX-5b"
    transformer_path = "yanboding/MTVCrafter/MV-DiT/CogVideoX"
    tokenizer_path = "4DMoT/mp_rank_00_model_states.pt"
    with open(motion_data_path, 'rb') as f:
        data_list = pickle.load(f)
    if not isinstance(data_list, list):
        data_list = [data_list]
    pe_mean = np.load('data/mean.npy')
    pe_std = np.load('data/std.npy')
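    # pe_mean / pe_std are per-joint statistics used to normalize the 3D joints before they
    # enter the motion VQVAE, and to de-normalize its reconstructions further below.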
    pipe = MTVCrafterPipeline.from_pretrained(
        model_path=pretrained_model_path,
        transformer_model_path=transformer_path,
        torch_dtype=torch.bfloat16,
        scheduler_type='dpm',
    ).to(device)
    pipe.vae.enable_tiling()
    pipe.vae.enable_slicing()
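    # note: VAE tiling/slicing decode the video in chunks, lowering peak GPU memory at some speed cost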
    # load the 4D motion tokenizer (VQVAE) checkpoint from the Hugging Face Hub
    vqvae_model_path = hf_hub_download(
        repo_id="yanboding/MTVCrafter",
        filename=tokenizer_path,
    )
    state_dict = torch.load(vqvae_model_path, map_location="cpu")
    motion_encoder = Encoder(in_channels=3, mid_channels=[128, 512], out_channels=3072, downsample_time=[2, 2], downsample_joint=[1, 1])
    motion_quant = VectorQuantizer(nb_code=8192, code_dim=3072, is_train=False)
    motion_decoder = Decoder(in_channels=3072, mid_channels=[512, 128], out_channels=3, upsample_rate=2.0, frame_upsample_rate=[2.0, 2.0], joint_upsample_rate=[1.0, 1.0])
    vqvae = SMPL_VQVAE(motion_encoder, motion_decoder, motion_quant).to(device)
    vqvae.load_state_dict(state_dict['module'], strict=True)
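    # the VQVAE is used twice below: once with return_vq=True to get the discrete motion tokens
    # that condition the DiT, and once end-to-end to visualize the reconstructed motion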
    # only run the first sample here
    data = data_list[0]
    new_height, new_width = get_new_height_width(data, dst_height, dst_width)
    x1 = (new_width - dst_width) // 2
    y1 = (new_height - dst_height) // 2
    # sample num_frames frames from the source video, then resize and center-crop to the target size
    sample_indexes = get_sample_indexes(data['video_length'], num_frames, stride=1)
    input_images = sample_video(decord.VideoReader(data['video_path']), sample_indexes)
    input_images = torch.from_numpy(input_images).permute(0, 3, 1, 2).contiguous()
    input_images = F.resize(input_images, (new_height, new_width), InterpolationMode.BILINEAR)
    input_images = F.crop(input_images, y1, x1, dst_height, dst_width)
    if ref_image_path != '':
        ref_image = Image.open(ref_image_path).convert("RGB")
        ref_image = torch.from_numpy(np.array(ref_image)).permute(2, 0, 1).contiguous()
        ref_images = torch.stack([ref_image.clone() for _ in range(num_frames)])
        ref_images = F.resize(ref_images, (new_height, new_width), InterpolationMode.BILINEAR)
        ref_images = F.crop(ref_images, y1, x1, dst_height, dst_width)
    else:
        ref_images = copy.deepcopy(input_images)
        frame0 = input_images[0]
        ref_images[:, :, :, :] = frame0
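    # note: without an explicit reference image, the first sampled video frame is repeated as the
    # reference for every frame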
    # prefer SMPL 3D joints when available; otherwise fall back to the raw pose array
    try:
        smpl_poses = np.array([pose[0][0].cpu().numpy() for pose in data['pose']['joints3d_nonparam']])
        poses = smpl_poses[sample_indexes]
    except Exception:
        poses = data['pose'][sample_indexes]
    norm_poses = torch.tensor((poses - pe_mean) / pe_std)
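    # render pose skeletons twice for comparison: from the raw joints ("before") and from the
    # joints reconstructed by the VQVAE ("after"), de-normalized with pe_std / pe_mean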
    offset = [data['video_height'], data['video_width'], 0]
    pose_images_before = get_pose_images(copy.deepcopy(poses), offset)
    pose_images_before = [image.resize((new_width, new_height)).crop((x1, y1, x1 + dst_width, y1 + dst_height)) for image in pose_images_before]
    input_smpl_joints = norm_poses.unsqueeze(0).to(device)
    motion_tokens, vq_loss = vqvae(input_smpl_joints, return_vq=True)
    output_motion, _ = vqvae(input_smpl_joints)
    pose_images_after = get_pose_images(output_motion[0].cpu().detach() * pe_std + pe_mean, offset)
    pose_images_after = [image.resize((new_width, new_height)).crop((x1, y1, x1 + dst_width, y1 + dst_height)) for image in pose_images_after]
    # normalize images
    input_images = input_images / 255.0
    ref_images = ref_images / 255.0
    input_images = normalize(input_images)
    ref_images = normalize(ref_images)
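    # pixel values are now scaled to [-1, 1], the range the diffusion pipeline expects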
    # run the diffusion pipeline conditioned on the reference frames and the motion tokens
    output_images = pipe(
        height=dst_height,
        width=dst_width,
        num_frames=num_frames,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        seed=seed,
        ref_images=ref_images,
        motion_embeds=motion_tokens,
        joint_mean=pe_mean,
        joint_std=pe_std,
    ).frames[0]
    # save the result: each output frame is placed next to the input frame and both pose renderings
    vis_images = []
    for k in range(len(output_images)):
        vis_image = [to_pil(((input_images[k] + 1) * 127.5).clamp(0, 255).to(torch.uint8)), pose_images_before[k], pose_images_after[k], output_images[k]]
        vis_image = concat_images_grid(vis_image, cols=len(vis_image), pad=2)
        vis_images.append(vis_image)
    output_path = "output.mp4"
    imageio.mimsave(output_path, vis_images, fps=15)
    return output_path
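
# Minimal usage sketch (an assumption, not part of the original script): the motion-data path
# below is a hypothetical placeholder; point it at a pickle produced by the MTVCrafter
# preprocessing step, and pass a reference image path if you have one.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    result_path = run_inference(
        device=device,
        motion_data_path="data/example_motion.pkl",  # hypothetical path
        ref_image_path="",                           # empty string: reuse the first video frame
        dst_width=512,
        dst_height=512,
        num_inference_steps=50,
        guidance_scale=3.0,
        seed=6666,
    )
    print(f"Saved visualization to {result_path}")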