import gradio as gr
import torch
import cv2
import numpy as np
import mediapipe as mp
import matplotlib.pyplot as plt
from PIL import Image
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionControlNetInpaintPipeline
from transformers import AutoTokenizer
import base64
import requests
import json
from rembg import remove
from scipy import ndimage
from moviepy.editor import ImageSequenceClip
from tqdm import tqdm
import os
import shutil
import time
from huggingface_hub import snapshot_download
import subprocess
import sys
def download_liveportrait():
    """
    Clone the LivePortrait repository and prepare its dependencies.
    """
    liveportrait_path = "./LivePortrait"
    original_cwd = os.getcwd()
    try:
        if not os.path.exists(liveportrait_path):
            print("Cloning LivePortrait repository...")
            os.system(f"git clone https://github.com/KwaiVGI/LivePortrait.git {liveportrait_path}")
            os.chdir(liveportrait_path)
            print("Installing LivePortrait dependencies...")
            os.system("pip install -r requirements.txt")
            # Build the custom MultiScaleDeformableAttention op required by XPose.
            dependency_path = "src/utils/dependencies/XPose/models/UniPose/ops"
            os.chdir(dependency_path)
            print("Building MultiScaleDeformableAttention...")
            os.system("python setup.py build")
            os.system("python setup.py install")
            # Record the absolute path of the ops directory after chdir-ing into it,
            # so the entry added to sys.path actually exists.
            module_path = os.getcwd()
            if module_path not in sys.path:
                sys.path.append(module_path)
            # Restore the directory we started from instead of counting "../" segments.
            os.chdir(original_cwd)
        print("LivePortrait setup completed")
    except Exception as e:
        print("Failed to initialize LivePortrait:", e)
        raise

download_liveportrait()
def download_huggingface_resources():
    """
    Download additional necessary resources from Hugging Face using the CLI.
    """
    try:
        local_dir = "./pretrained_weights"
        os.makedirs(local_dir, exist_ok=True)
        # Use the Hugging Face CLI for downloading
        cmd = [
            "huggingface-cli", "download",
            "KwaiVGI/LivePortrait",
            "--local-dir", local_dir,
            "--exclude", "*.git*", "README.md", "docs"
        ]
        print("Executing command:", " ".join(cmd))
        subprocess.run(cmd, check=True)
        print("Resources successfully downloaded to:", local_dir)
    except subprocess.CalledProcessError as e:
        print("Error during Hugging Face CLI download:", e)
        raise
    except Exception as e:
        print("General error in downloading resources:", e)
        raise

download_huggingface_resources()
def get_project_root():
    """Get the root directory of the current project."""
    return os.path.abspath(os.path.dirname(__file__))

# Ensure working directory is project root
os.chdir(get_project_root())
# Initialize the necessary models and components
mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils

# Load ControlNet model
controlnet = ControlNetModel.from_pretrained('lllyasviel/sd-controlnet-openpose', torch_dtype=torch.float16)

# Load Stable Diffusion model with ControlNet
pipe_controlnet = StableDiffusionControlNetPipeline.from_pretrained(
    'runwayml/stable-diffusion-v1-5',
    controlnet=controlnet,
    torch_dtype=torch.float16
)

# Load Inpaint ControlNet
pipe_inpaint_controlnet = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    controlnet=controlnet,
    torch_dtype=torch.float16
)

# Move to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pipe_controlnet.to(device)
pipe_controlnet.enable_attention_slicing()
pipe_inpaint_controlnet.to(device)
pipe_inpaint_controlnet.enable_attention_slicing()
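# Optional memory optimization (a hedged suggestion, not part of the original Space):
# on GPUs with limited VRAM, diffusers pipelines can also offload sub-models to the CPU
# between forward passes. Requires the `accelerate` package to be installed.
# pipe_controlnet.enable_model_cpu_offload()
# pipe_inpaint_controlnet.enable_model_cpu_offload()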
def resize_to_multiple_of_64(width, height):
    # Round both dimensions down to the nearest multiple of 64 for the diffusion pipelines.
    return (width // 64) * 64, (height // 64) * 64

def expand_mask(mask, kernel_size):
    # Dilate the binary mask so the inpainted region extends slightly beyond the subject.
    mask_array = np.array(mask)
    structuring_element = np.ones((kernel_size, kernel_size), dtype=np.uint8)
    expanded_mask_array = ndimage.binary_dilation(
        mask_array, structure=structuring_element
    ).astype(np.uint8) * 255
    return Image.fromarray(expanded_mask_array)
def crop_face_to_square(image_rgb, padding_ratio=0.2, height_multiplier=1.2):
    """
    Detect the face and crop a rectangular region that includes more of the body below the face.
    Instead of centering around the face, we start near the face region and extend downward.
    """
    face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
    gray_image = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)
    faces = face_cascade.detectMultiScale(gray_image, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
    if len(faces) == 0:
        print("No face detected.")
        return None
    x, y, w, h = faces[0]
    face_x_center = x + w // 2
    face_y_top = y
    face_side_length = max(w, h)
    padded_side_length = int(face_side_length * (1 + padding_ratio))
    cropped_width = padded_side_length
    cropped_height = int(padded_side_length * height_multiplier)
    top_left_x = max(face_x_center - cropped_width // 2, 0)
    top_margin = int(padded_side_length * 0.1)
    top_left_y = max(face_y_top - top_margin, 0)
    bottom_right_x = min(top_left_x + cropped_width, image_rgb.shape[1])
    bottom_right_y = min(top_left_y + cropped_height, image_rgb.shape[0])
    cropped_image = image_rgb[top_left_y:bottom_right_y, top_left_x:bottom_right_x]
    return cropped_image
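# Illustrative usage sketch (assumes an "example1.jpg" next to this script): note that
# crop_face_to_square returns None when no face is found, so callers fall back to the
# original frame rather than crashing.
# img = cv2.cvtColor(cv2.imread("example1.jpg"), cv2.COLOR_BGR2RGB)
# cropped = crop_face_to_square(img, height_multiplier=1.5)
# img = cropped if cropped is not None else img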
def spirit_animal_baseline(image_path, num_images=4):
    image = cv2.imread(image_path)
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Fall back to the full frame if no face is detected.
    cropped = crop_face_to_square(image_rgb)
    if cropped is not None:
        image_rgb = cropped
    original_height, original_width, _ = image_rgb.shape
    aspect_ratio = original_width / original_height
    if aspect_ratio > 1:
        gen_width = 768
        gen_height = int(gen_width / aspect_ratio)
    else:
        gen_height = 768
        gen_width = int(gen_height * aspect_ratio)
    gen_width, gen_height = resize_to_multiple_of_64(gen_width, gen_height)

    with mp_pose.Pose(static_image_mode=True) as pose:
        results = pose.process(image_rgb)
    if results.pose_landmarks:
        annotated_image = image_rgb.copy()
        mp_drawing.draw_landmarks(
            annotated_image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS
        )
    else:
        print("No pose detected.")
        return "No pose detected.", []

    # Draw the detected skeleton on a black canvas to use as the ControlNet condition.
    pose_image = np.zeros_like(image_rgb)
    for connection in mp_pose.POSE_CONNECTIONS:
        start_idx, end_idx = connection
        start, end = results.pose_landmarks.landmark[start_idx], results.pose_landmarks.landmark[end_idx]
        if start.visibility > 0.5 and end.visibility > 0.5:
            x1, y1 = int(start.x * pose_image.shape[1]), int(start.y * pose_image.shape[0])
            x2, y2 = int(end.x * pose_image.shape[1]), int(end.y * pose_image.shape[0])
            cv2.line(pose_image, (x1, y1), (x2, y2), (255, 255, 255), 2)
    pose_pil = Image.fromarray(cv2.resize(pose_image, (gen_width, gen_height), interpolation=cv2.INTER_LANCZOS4))

    # Encode as BGR so the JPEG sent to the API has correct colors.
    base64_image = base64.b64encode(cv2.imencode('.jpg', cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR))[1]).decode()
    api_key = os.getenv("GPT_KEY")
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    payload = {
        "model": "gpt-4o-mini",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Based on the provided image, think of one spirit animal that is right for the person, and answer in the following format: An ultra-realistic, highly detailed photograph of a single {animal} with facial features characterized by {description}, standing upright in a human-like pose, looking directly at the camera, against a solid, neutral background. Generate one sentence without any other responses or numbering."},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
                ]
            }
        ],
        "max_tokens": 100
    }
    response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
    prompt = response.json()['choices'][0]['message']['content'] if 'choices' in response.json() else "A majestic animal"

    generated_images = []
    with torch.no_grad():
        with torch.autocast(device_type=device.type):
            for _ in range(num_images):
                images = pipe_controlnet(
                    prompt=prompt,
                    negative_prompt=(
                        "multiple heads, two heads, double head, triple head, extra limbs, extra arms, extra legs, "
                        "duplicate faces, multiple faces, mutated anatomy, deformed, disfigured, malformed, "
                        "extra ears, fused ears, blurred, low resolution, cartoonish, watermark, text, logo, "
                        "poorly drawn, distorted, floating limbs, out-of-frame"
                    ),
                    num_inference_steps=20,
                    image=pose_pil,
                    guidance_scale=5,
                    width=gen_width,
                    height=gen_height,
                ).images
                generated_images.append(images[0])
    return prompt, generated_images
def spirit_animal_with_background(image_path, num_images=4):
    image = cv2.imread(image_path)
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # image_rgb = crop_face_to_square(image_rgb)
    original_height, original_width, _ = image_rgb.shape
    aspect_ratio = original_width / original_height
    if aspect_ratio > 1:
        gen_width = 768
        gen_height = int(gen_width / aspect_ratio)
    else:
        gen_height = 768
        gen_width = int(gen_height * aspect_ratio)
    gen_width, gen_height = resize_to_multiple_of_64(gen_width, gen_height)

    with mp_pose.Pose(static_image_mode=True) as pose:
        results = pose.process(image_rgb)
    if results.pose_landmarks:
        annotated_image = image_rgb.copy()
        mp_drawing.draw_landmarks(
            annotated_image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS
        )
    else:
        print("No pose detected.")
        return "No pose detected.", []

    pose_image = np.zeros_like(image_rgb)
    for connection in mp_pose.POSE_CONNECTIONS:
        start_idx, end_idx = connection
        start, end = results.pose_landmarks.landmark[start_idx], results.pose_landmarks.landmark[end_idx]
        if start.visibility > 0.5 and end.visibility > 0.5:
            x1, y1 = int(start.x * pose_image.shape[1]), int(start.y * pose_image.shape[0])
            x2, y2 = int(end.x * pose_image.shape[1]), int(end.y * pose_image.shape[0])
            cv2.line(pose_image, (x1, y1), (x2, y2), (255, 255, 255), 2)
    pose_pil = Image.fromarray(cv2.resize(pose_image, (gen_width, gen_height), interpolation=cv2.INTER_LANCZOS4))

    # Encode as BGR so the JPEG sent to the API has correct colors.
    base64_image = base64.b64encode(cv2.imencode('.jpg', cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR))[1]).decode()
    api_key = os.getenv("GPT_KEY")
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    payload = {
        "model": "gpt-4o-mini",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Based on the provided image, think of one spirit animal that is right for the person, and answer in the following format: An ultra-realistic, highly detailed photograph of a single {animal} with facial features characterized by {description}, standing upright in a human-like pose, looking directly at the camera, against a solid, neutral background. Generate one sentence without any other responses or numbering."},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
                ]
            }
        ],
        "max_tokens": 100
    }
    response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
    prompt = response.json()['choices'][0]['message']['content'] if 'choices' in response.json() else "A majestic animal"

    # Segment the person with rembg and dilate the mask so only the subject is repainted.
    mask_image = remove(Image.fromarray(image_rgb))
    initial_mask = mask_image.split()[-1].convert('L')
    kernel_size = min(gen_width, gen_height) // 15
    expanded_mask = expand_mask(initial_mask, kernel_size)

    generated_images = []
    with torch.no_grad():
        with torch.autocast(device_type=device.type):
            for _ in range(num_images):
                images = pipe_inpaint_controlnet(
                    prompt=prompt,
                    negative_prompt=(
                        "multiple heads, two heads, double head, triple head, extra limbs, extra arms, extra legs, "
                        "duplicate faces, multiple faces, mutated anatomy, deformed, disfigured, malformed, "
                        "extra ears, fused ears, blurred, low resolution, cartoonish, watermark, text, logo, "
                        "poorly drawn, distorted, floating limbs, out-of-frame"
                    ),
                    num_inference_steps=20,
                    image=Image.fromarray(image_rgb),
                    mask_image=expanded_mask,
                    control_image=pose_pil,
                    width=gen_width,
                    height=gen_height,
                    guidance_scale=5,
                ).images
                generated_images.append(images[0])
    return prompt, generated_images
def generate_multiple_animals(image_path, keep_background=True, num_images=4, height_multiplier=1.5):
    image = cv2.imread(image_path)
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Fall back to the full frame if no face is detected.
    cropped = crop_face_to_square(image_rgb, height_multiplier=height_multiplier)
    if cropped is not None:
        image_rgb = cropped
    original_image = Image.fromarray(image_rgb)
    original_width, original_height = original_image.size
    aspect_ratio = original_width / original_height
    if aspect_ratio > 1:
        gen_width = 768
        gen_height = int(gen_width / aspect_ratio)
    else:
        gen_height = 768
        gen_width = int(gen_height * aspect_ratio)
    gen_width, gen_height = resize_to_multiple_of_64(gen_width, gen_height)

    # Encode as BGR so the JPEG sent to the API has correct colors.
    base64_image = base64.b64encode(cv2.imencode('.jpg', cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR))[1]).decode()
    api_key = os.getenv("GPT_KEY")
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    payload = {
        "model": "gpt-4o-mini",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Based on the provided image, think of " + str(num_images) + " different spirit animals that are right for the person, and answer in the following format for each: An ultra-realistic, highly detailed photograph of a {animal} with facial features characterized by {description}, standing upright in a human-like pose, looking directly at the camera, against a solid, neutral background. Generate these sentences without any other responses or numbering. For the animal choose between owl, bear, fox, koala, lion, dog"
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}
                    }
                ]
            }
        ],
        "max_tokens": 500
    }
    response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
    response_json = response.json()
    if 'choices' in response_json and len(response_json['choices']) > 0:
        content = response_json['choices'][0]['message']['content']
        prompts = [prompt.strip() for prompt in content.strip().split('.') if prompt.strip()]
    else:
        # Fall back to a generic prompt if the API call did not return any choices.
        prompts = ["An ultra-realistic, highly detailed photograph of a majestic animal"] * num_images
    negative_prompt = (
        "multiple heads, two heads, double head, triple head, extra limbs, extra arms, extra legs, "
        "duplicate faces, multiple faces, mutated anatomy, deformed, disfigured, malformed, "
        "extra ears, fused ears, blurred, low resolution, cartoonish, watermark, text, logo, "
        "poorly drawn, distorted, floating limbs, out-of-frame"
    )
    formatted_prompts = "\n".join(f"{i+1}. {prompt}" for i, prompt in enumerate(prompts))

    with mp_pose.Pose(static_image_mode=True) as pose:
        results = pose.process(image_rgb)
    if results.pose_landmarks:
        annotated_image = image_rgb.copy()
        mp_drawing.draw_landmarks(
            annotated_image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS
        )
    else:
        print("No pose detected.")
        return "No pose detected.", []

    pose_image = np.zeros_like(image_rgb)
    for connection in mp_pose.POSE_CONNECTIONS:
        start_idx, end_idx = connection
        start, end = results.pose_landmarks.landmark[start_idx], results.pose_landmarks.landmark[end_idx]
        if start.visibility > 0.5 and end.visibility > 0.5:
            x1, y1 = int(start.x * pose_image.shape[1]), int(start.y * pose_image.shape[0])
            x2, y2 = int(end.x * pose_image.shape[1]), int(end.y * pose_image.shape[0])
            cv2.line(pose_image, (x1, y1), (x2, y2), (255, 255, 255), 2)
    pose_pil = Image.fromarray(cv2.resize(pose_image, (gen_width, gen_height), interpolation=cv2.INTER_LANCZOS4))

    if keep_background:
        mask_image = remove(original_image)
        initial_mask = mask_image.split()[-1].convert('L')
        expanded_mask = expand_mask(initial_mask, kernel_size=min(gen_width, gen_height) // 15)
    else:
        expanded_mask = None

    generated_images = []
    if keep_background:
        with torch.no_grad():
            with torch.autocast(device_type=device.type):
                for prompt in prompts:
                    images = pipe_inpaint_controlnet(
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        num_inference_steps=20,
                        image=Image.fromarray(image_rgb),
                        mask_image=expanded_mask,
                        control_image=pose_pil,
                        width=gen_width,
                        height=gen_height,
                        guidance_scale=5,
                    ).images
                    generated_images.append(images[0])
    else:
        with torch.no_grad():
            with torch.autocast(device_type=device.type):
                for prompt in prompts:
                    images = pipe_controlnet(
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        num_inference_steps=20,
                        image=pose_pil,
                        guidance_scale=5,
                        width=gen_width,
                        height=gen_height,
                    ).images
                    generated_images.append(images[0])
    return formatted_prompts, generated_images
def wait_for_file(file_path, timeout=500):
    """
    Wait for a file to be created, with a specified timeout.

    Args:
        file_path (str): The path of the file to wait for.
        timeout (int): Maximum time to wait in seconds.

    Returns:
        bool: True if the file is created, False if timeout occurs.
    """
    start_time = time.time()
    while not os.path.exists(file_path):
        if time.time() - start_time > timeout:
            return False
        time.sleep(0.5)  # Check every 0.5 seconds
    return True
def generate_spirit_animal_video(driving_video_path):
    try:
        # Step 1: Extract the first frame
        cap = cv2.VideoCapture(driving_video_path)
        if not cap.isOpened():
            print("Error: Unable to open video.")
            return None
        ret, frame = cap.read()
        cap.release()
        if not ret:
            print("Error: Unable to read the first frame.")
            return None
        # Save the first frame
        first_frame_path = "./first_frame.jpg"
        cv2.imwrite(first_frame_path, frame)
        print(f"First frame saved to: {first_frame_path}")

        # Step 2: Generate the spirit animal image from the first frame
        _, input_image = generate_multiple_animals(first_frame_path, True, 1, height_multiplier=1)
        if not input_image:
            print("Error: Spirit animal generation failed.")
            return None
        spirit_animal_path = "./animal.jpeg"
        cv2.imwrite(spirit_animal_path, cv2.cvtColor(np.array(input_image[0]), cv2.COLOR_RGB2BGR))
        print(f"Spirit animal image saved to: {spirit_animal_path}")

        # Step 3: Run LivePortrait inference with the driving video
        output_path = "./animations/animal--uploaded_video_compressed.mp4"
        script_path = os.path.abspath("./LivePortrait/inference_animals.py")
        if not os.path.exists(script_path):
            print(f"Error: Inference script not found at {script_path}.")
            return None
        command = f"python {script_path} -s {spirit_animal_path} -d {driving_video_path} --driving_multiplier 1.75 --no_flag_stitching"
        print(f"Running command: {command}")
        result = os.system(command)
        if result != 0:
            print(f"Error: Command failed with exit code {result}.")
            return None

        # Verify output file exists
        if not os.path.exists(output_path):
            print(f"Error: Expected output video not found at {output_path}.")
            return None
        print(f"Output video generated at: {output_path}")
        return output_path
    except Exception as e:
        print(f"Error occurred: {e}")
        return None
def generate_spirit_animal(image, animal_type, background):
    if animal_type == "Single Animal":
        if background == "Preserve Background":
            prompt, generated_images = spirit_animal_with_background(image)
        else:
            prompt, generated_images = spirit_animal_baseline(image)
    elif animal_type == "Multiple Animals":
        if background == "Preserve Background":
            prompt, generated_images = generate_multiple_animals(image, keep_background=True)
        else:
            prompt, generated_images = generate_multiple_animals(image, keep_background=False)
    return prompt, generated_images
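# Illustrative usage (not executed): the dispatcher above can also be called directly,
# e.g. with one of the bundled example images, assuming it exists in the working directory.
# prompt, images = generate_spirit_animal("example1.jpg", "Single Animal", "Preserve Background")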
def compress_video(input_path, output_path, target_size_mb):
    target_size_bytes = target_size_mb * 1024 * 1024
    temp_output = "./temp_compressed.mp4"
    cap = cv2.VideoCapture(input_path)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    writer = cv2.VideoWriter(temp_output, fourcc, fps, (width, height))
    frame_count = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        writer.write(frame)
        frame_count += 1
    cap.release()
    writer.release()
    current_size = os.path.getsize(temp_output)
    if current_size > target_size_bytes:
        # Pick a video bitrate (bits per second) that fits the target size over the clip duration.
        duration_seconds = max(frame_count / fps, 1) if fps > 0 else 1
        bitrate = int(target_size_bytes * 8 / duration_seconds)
        os.system(f"ffmpeg -i {temp_output} -b:v {bitrate} -y {output_path}")
        os.remove(temp_output)
    else:
        shutil.move(temp_output, output_path)
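# Hedged note: the OpenCV pass above mainly re-encodes the frames; the actual size
# reduction comes from the ffmpeg bitrate pass. A single-pass alternative (not used here)
# would probe the source duration and call ffmpeg once, along the lines of:
#   ffmpeg -i input.mp4 -b:v <target_size_bytes * 8 / duration_seconds> -y output.mp4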
def process_video(video_file):
    compressed_path = "./uploaded_video_compressed.mp4"
    compress_video(video_file, compressed_path, target_size_mb=1)
    print(f"Compressed and moved video to: {compressed_path}")
    output_video_path = "./animations/animal--uploaded_video_compressed.mp4"
    generate_spirit_animal_video(compressed_path)
    # Wait until the output video is generated
    timeout = 1000  # Timeout in seconds
    if not wait_for_file(output_video_path, timeout=timeout):
        print("Timeout occurred while waiting for video generation.")
        return gr.update(value=None, visible=False)  # Hide output if failed
    # Return the generated video path
    print(f"Output video is ready: {output_video_path}")
    return gr.update(value=output_video_path, visible=True)  # Show video
| css = """ | |
| #title-container { | |
| font-family: 'Arial', sans-serif; | |
| color: #4a4a4a; | |
| text-align: center; | |
| margin-bottom: 20px; | |
| } | |
| #title-container h1 { | |
| font-size: 2.5em; | |
| font-weight: bold; | |
| color: #ff9900; | |
| } | |
| #title-container h2 { | |
| font-size: 1.2em; | |
| color: #6c757d; | |
| } | |
| #intro-text { | |
| font-size: 1em; | |
| color: #6c757d; | |
| margin: 50px; | |
| text-align: center; | |
| font-style: italic; | |
| } | |
| #prompt-output { | |
| font-family: 'Courier New', monospace; | |
| color: #5a5a5a; | |
| font-size: 1.1em; | |
| padding: 10px; | |
| background-color: #f9f9f9; | |
| border: 1px solid #ddd; | |
| border-radius: 5px; | |
| margin-top: 10px; | |
| } | |
| .examples-container { | |
| display: flex; | |
| flex-wrap: wrap; | |
| gap: 10px; | |
| justify-content: center; | |
| align-items: flex-start; | |
| } | |
| """ | |
# Title and description
title_html = """
<div id="title-container">
    <h1>Spirit Animal Generator</h1>
    <h2>Create your unique spirit animal with AI-assisted image generation.</h2>
</div>
"""
description_text = """
### Project Overview

Welcome to the Spirit Animal Generator! This tool leverages Stable Diffusion models to create unique visualizations of spirit animals from videos and images.

#### Key Features:
1. **Prompting**: A [GPT model](https://arxiv.org/abs/2305.10435) generates a descriptive prompt for each media input.
2. **Image Creation**: A [ControlNet model](https://arxiv.org/abs/2302.05543) generates animal images with pose control.
3. **Video Transformation**: The [LivePortrait model](https://arxiv.org/abs/2407.03168) generates an animal animation with the same facial expressions as the driving video.

---

### How It Works:
1. **Upload Your Media**:
   - Images: Use clear, high-resolution photos for better results.
   - Videos: Ensure the file is in MP4 format.
2. **Customize Options**:
   - For images, select the type of animal and background settings.
3. **View Your Results**:
   - Images will produce customized visual art along with a generated prompt.
   - Videos will be transformed into animal animations.

Discover your spirit animal and let your imagination run wild!

---
"""
with gr.Blocks(css=css) as demo:
    gr.HTML(title_html)
    gr.Markdown(description_text)
    with gr.Tabs():
        with gr.Tab("Generate Spirit Animal Image"):
            gr.Markdown("Upload an image to generate a spirit animal.")
            with gr.Row():
                with gr.Column(scale=1):
                    image_input = gr.Image(type="filepath", label="Upload an image")
                    animal_type = gr.Radio(choices=["Single Animal", "Multiple Animals"], label="Animal Type", value="Single Animal")
                    background_option = gr.Radio(choices=["Preserve Background", "Don't Preserve Background"], label="Background Option", value="Preserve Background")
                    generate_image_button = gr.Button("Generate Image")
                    gr.Examples(
                        examples=["example1.jpg", "example2.jpg", "example3.jpg"],
                        inputs=image_input,
                        label="Example Images"
                    )
                with gr.Column(scale=1):
                    generated_prompt = gr.Textbox(label="Generated Prompt")
                    generated_gallery = gr.Gallery(label="Generated Images")
            generate_image_button.click(
                fn=generate_spirit_animal,
                inputs=[image_input, animal_type, background_option],
                outputs=[generated_prompt, generated_gallery],
            )
        with gr.Tab("Generate Spirit Animal Video"):
            gr.Markdown("Upload a driving video to generate a spirit animal video.")
            with gr.Row():
                with gr.Column(scale=1):
                    video_input = gr.Video(label="Upload a driving video (MP4 format)")
                    generate_video_button = gr.Button("Generate Video")
                    gr.Examples(
                        examples=["video1.mp4", "video3.mp4", "video4.mp4"],
                        inputs=video_input,
                        label="Example Videos"
                    )
                with gr.Column(scale=1):
                    video_output = gr.Video(label="Generated Spirit Animal Video")
            generate_video_button.click(
                fn=process_video,
                inputs=video_input,
                outputs=video_output,
            )

demo.launch()