Spanicin committed on
Commit 6dfeac9 · verified · parent: 527c539

Update app.py
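
The only functional difference visible in this diff is in the run_inference(...) call inside generate_video: temp_video_path and temp_audio_path already hold plain path strings (both come from tempfile.NamedTemporaryFile(...).name), so the previous .name attribute access would fail with an AttributeError on every /run request. A minimal sketch of the distinction (illustrative only, not part of the committed file):

    import tempfile

    path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
    print(type(path))             # <class 'str'>: already a usable path string
    print(hasattr(path, "name"))  # False, so the old temp_video_path.name raised AttributeError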

Files changed (1)
  1. app.py +159 -159
app.py CHANGED
@@ -1,160 +1,160 @@
 import argparse
 import tempfile
 import os
 
 from flask import Flask, request, jsonify
 from omegaconf import OmegaConf
 import torch
 from diffusers import AutoencoderKL, DDIMScheduler
 from latentsync.models.unet import UNet3DConditionModel
 from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
 from diffusers.utils.import_utils import is_xformers_available
 from accelerate.utils import set_seed
 from latentsync.whisper.audio2feature import Audio2Feature
 from openai import OpenAI
 from elevenlabs import set_api_key, generate, play, clone, Voice, VoiceSettings
 
 # Initialize the Flask app
 app = Flask(__name__)
 TEMP_DIR = None
 
 def run_inference(video_path, audio_path, video_out_path,
                   inference_ckpt_path, unet_config_path="configs/unet/second_stage.yaml",
                   inference_steps=20, guidance_scale=1.0, seed=1247):
     # Load configuration
     config = OmegaConf.load(unet_config_path)
 
     # Determine proper dtype based on GPU capabilities
     is_fp16_supported = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] > 7
     dtype = torch.float16 if is_fp16_supported else torch.float32
 
     # Setup scheduler
     scheduler = DDIMScheduler.from_pretrained("configs")
 
     # Choose whisper model based on config settings
     if config.model.cross_attention_dim == 768:
         whisper_model_path = "checkpoints/whisper/small.pt"
     elif config.model.cross_attention_dim == 384:
         whisper_model_path = "checkpoints/whisper/tiny.pt"
     else:
         raise NotImplementedError("cross_attention_dim must be 768 or 384")
 
     # Initialize the audio encoder
     audio_encoder = Audio2Feature(model_path=whisper_model_path,
                                   device="cuda", num_frames=config.data.num_frames)
 
     # Load VAE
     vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype)
     vae.config.scaling_factor = 0.18215
     vae.config.shift_factor = 0
 
     # Load UNet model from the checkpoint
     unet, _ = UNet3DConditionModel.from_pretrained(
         OmegaConf.to_container(config.model),
         inference_ckpt_path,  # load checkpoint
         device="cpu",
     )
     unet = unet.to(dtype=dtype)
 
     # Optionally enable memory-efficient attention if available
     if is_xformers_available():
         unet.enable_xformers_memory_efficient_attention()
 
     # Initialize the pipeline and move to GPU
     pipeline = LipsyncPipeline(
         vae=vae,
         audio_encoder=audio_encoder,
         unet=unet,
         scheduler=scheduler,
     ).to("cuda")
 
     # Set seed
     if seed != -1:
         set_seed(seed)
     else:
         torch.seed()
 
     # Run the pipeline
     pipeline(
         video_path=video_path,
         audio_path=audio_path,
         video_out_path=video_out_path,
         video_mask_path=video_out_path.replace(".mp4", "_mask.mp4"),
         num_frames=config.data.num_frames,
         num_inference_steps=inference_steps,
         guidance_scale=guidance_scale,
         weight_dtype=dtype,
         width=config.data.resolution,
         height=config.data.resolution,
     )
 
 def create_temp_dir():
     return tempfile.TemporaryDirectory()
 
 def generate_audio(voice_cloning, text_prompt):
     if voice_cloning == 'yes':
         set_api_key('92e149985ea2732b4359c74346c3daee')
         voice = Voice(voice_id="VJpttplXHolgV2leGe5V", name="Marc", settings=VoiceSettings(
             stability=0.71, similarity_boost=0.9, style=0.0, use_speaker_boost=True),)
 
         audio = generate(text=text_prompt, voice=voice, model="eleven_multilingual_v2", stream=True, latency=4)
         with tempfile.NamedTemporaryFile(suffix=".mp3", prefix="cloned_audio_", dir=TEMP_DIR.name, delete=False) as temp_file:
             for chunk in audio:
                 temp_file.write(chunk)
             driven_audio_path = temp_file.name
         print('driven_audio_path', driven_audio_path)
 
         return driven_audio_path
 
 
 @app.route('/run', methods=['POST'])
 def generate_video():
     global TEMP_DIR
     TEMP_DIR = create_temp_dir()
 
     if 'video' not in request.files:
         return jsonify({'error': 'Video file is required.'}), 400
 
     video_file = request.files['video']
     text_prompt = request.form['text_prompt']
     print('Input text prompt: ', text_prompt)
     text_prompt = text_prompt.strip()
     if not text_prompt:
         return jsonify({'error': 'Input text prompt cannot be blank'}), 400
 
     voice_cloning = 'yes'
     temp_audio_path = generate_audio(voice_cloning, text_prompt)
     with tempfile.NamedTemporaryFile(suffix=".mp4", prefix="input_", dir=TEMP_DIR.name, delete=False) as temp_file:
         temp_video_path = temp_file.name
         video_file.save(temp_video_path)
         print('temp_video_path', temp_video_path)
 
     output_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name
 
     # You can pass additional parameters via form data if needed (e.g., checkpoint path)
     inference_ckpt_path = request.form.get('inference_ckpt_path', 'checkpoints/latentsync_unet.pt')
     unet_config_path = request.form.get('unet_config_path', 'configs/unet/second_stage.yaml')
 
     try:
         run_inference(
-            video_path=temp_video_path.name,
-            audio_path=temp_audio_path.name,
+            video_path=temp_video_path,
+            audio_path=temp_audio_path,
             video_out_path=output_video,
             inference_ckpt_path=inference_ckpt_path,
             unet_config_path=unet_config_path,
             inference_steps=int(request.form.get('inference_steps', 20)),
             guidance_scale=float(request.form.get('guidance_scale', 1.0)),
             seed=int(request.form.get('seed', 1247))
         )
         # Return the output video path or further process the file for download
         return jsonify({'output_video': output_video}), 200
     except Exception as e:
         return jsonify({'error': str(e)}), 500
 
 @app.route("/health", methods=["GET"])
 def health_status():
     response = {"online": "true"}
     return jsonify(response)
 
 if __name__ == '__main__':
     app.run(debug=True)
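
For reference, a minimal sketch of exercising the updated /run endpoint once the server is up. The endpoint path and form field names are taken from app.py above; the server address assumes the Flask development server defaults (127.0.0.1:5000), and input.mp4 and the prompt text are placeholders rather than files from this repository:

    # Illustrative client call; the host, port, and test clip are assumptions.
    import requests

    with open("input.mp4", "rb") as f:
        resp = requests.post(
            "http://127.0.0.1:5000/run",
            files={"video": f},
            data={
                "text_prompt": "Hello from the lip-sync demo.",
                "inference_steps": "20",
                "guidance_scale": "1.0",
                "seed": "1247",
            },
        )

    print(resp.status_code, resp.json())

On success the route returns {'output_video': <server-side path of the generated .mp4>}; failures come back as JSON with a 400 or 500 status. A GET to /health should return {"online": "true"}.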