Poooroooseee committed on
Commit
e3a6017
·
verified ·
1 Parent(s): d111819

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +274 -0
app.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import gradio as gr
5
+ from typing import Optional, Tuple
6
+ from funasr import AutoModel
7
+ from pathlib import Path
8
+
9
# Process-wide environment switches. They are assigned before the
# `import voxcpm` statement that follows -- presumably so the library sees
# them at import time (NOTE(review): confirm against voxcpm/torch docs).
os.environ.update(
    {
        "TORCHDYNAMO_DISABLE": "1",
        "TOKENIZERS_PARALLELISM": "false",
    }
)
# Only supply the default checkpoint repo when the caller left HF_REPO_ID
# unset or blank; an explicit value is left untouched.
if not os.environ.get("HF_REPO_ID", "").strip():
    os.environ["HF_REPO_ID"] = "openbmb/VoxCPM-0.5B"
13
+
14
+ import voxcpm
15
+
16
+
17
class VoxCPMDemo:
    """Backend for the demo UI.

    Holds an ASR model (created eagerly in ``__init__``) used to transcribe
    the prompt audio, and a VoxCPM TTS model that is created on first use by
    ``get_or_load_voxcpm``.
    """

    # Sample rate (Hz) returned together with every generated waveform.
    OUTPUT_SAMPLE_RATE = 16000

    def __init__(self) -> None:
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"🚀 Running on device: {self.device}")

        # ASR model for prompt text recognition
        self.asr_model_id = "iic/SenseVoiceSmall"
        self.asr_model: Optional[AutoModel] = AutoModel(
            model=self.asr_model_id,
            disable_update=True,
            log_level='DEBUG',
            device="cuda:0" if self.device == "cuda" else "cpu",
        )

        # TTS model (lazy init)
        self.voxcpm_model: Optional["voxcpm.VoxCPM"] = None
        self.default_local_model_dir = "./models/VoxCPM-0.5B"

    # ---------- Model helpers ----------
    def _resolve_model_dir(self) -> str:
        """
        Resolve the directory the TTS checkpoint is loaded from.

        Resolution order:
        1) ``self.default_local_model_dir`` if that directory exists
        2) If the HF_REPO_ID env var is non-blank, ``models/{repo}`` (with
           '/' replaced by '__'), downloading the repo there on first use
        3) Fallback to 'models' (also used when the download fails)

        Returns:
            The chosen directory path as a string.
        """
        if os.path.isdir(self.default_local_model_dir):
            return self.default_local_model_dir

        repo_id = os.environ.get("HF_REPO_ID", "").strip()
        if not repo_id:
            return "models"

        target_dir = os.path.join("models", repo_id.replace("/", "__"))
        if os.path.isdir(target_dir):
            return target_dir

        try:
            from huggingface_hub import snapshot_download  # type: ignore
            os.makedirs(target_dir, exist_ok=True)
            print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...")
            snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
        except Exception as e:
            # target_dir did not exist when we started, so whatever is there
            # now is an empty or partial download. Remove it; otherwise the
            # isdir() check above would accept it as a valid checkpoint on
            # the next call and the download would never be retried.
            import shutil
            shutil.rmtree(target_dir, ignore_errors=True)
            print(f"Warning: HF download failed: {e}. Falling back to 'models'.")
            return "models"
        return target_dir

    def get_or_load_voxcpm(self) -> "voxcpm.VoxCPM":
        """Return the cached VoxCPM model, creating it on the first call."""
        if self.voxcpm_model is not None:
            return self.voxcpm_model
        print("Model not loaded, initializing...")
        model_dir = self._resolve_model_dir()
        print(f"Using model dir: {model_dir}")
        self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir)
        print("Model loaded successfully.")
        return self.voxcpm_model

    # ---------- Functional endpoints ----------
    def prompt_wav_recognition(self, prompt_wav: Optional[str]) -> str:
        """
        Transcribe the prompt audio file with the ASR model.

        Args:
            prompt_wav: Path of the audio file, or None/empty when no audio
                is selected.

        Returns:
            The recognized text, or "" when there is no audio or the ASR
            model returns no result.
        """
        if not prompt_wav:
            return ""
        res = self.asr_model.generate(input=prompt_wav, language="auto", use_itn=True)
        if not res:
            # Nothing recognized: avoid an IndexError on res[0].
            return ""
        # Keep only what follows the last '|>' -- presumably this drops the
        # leading '<|...|>' tags the ASR model puts before the transcript
        # (NOTE(review): confirm against the SenseVoice output format).
        text = res[0]["text"].split('|>')[-1]
        return text

    def generate_tts_audio(
        self,
        text_input: str,
        prompt_wav_path_input: Optional[str] = None,
        prompt_text_input: Optional[str] = None,
        cfg_value_input: float = 2.0,
        inference_timesteps_input: int = 10,
        do_normalize: bool = True,
        denoise: bool = True,
    ) -> Tuple[int, np.ndarray]:
        """
        Generate speech from text using VoxCPM; optional reference audio for voice style guidance.

        Args:
            text_input: Text to synthesize; must be non-blank.
            prompt_wav_path_input: Optional path of the reference audio.
            prompt_text_input: Optional transcript of the reference audio.
            cfg_value_input: Guidance value, passed to the model as float.
            inference_timesteps_input: Step count, passed to the model as int.
            do_normalize: Forwarded as the model's ``normalize`` argument.
            denoise: Forwarded as the model's ``denoise`` argument.

        Returns:
            (sample_rate, waveform_numpy)

        Raises:
            ValueError: If ``text_input`` is empty or whitespace only.
        """
        # Validate first so a blank request never triggers the (potentially
        # slow) model download/initialization.
        text = (text_input or "").strip()
        if not text:
            raise ValueError("Please input text to synthesize.")

        current_model = self.get_or_load_voxcpm()

        # Empty strings from the UI are treated as "not provided".
        prompt_wav_path = prompt_wav_path_input if prompt_wav_path_input else None
        prompt_text = prompt_text_input if prompt_text_input else None

        print(f"Generating audio for text: '{text[:60]}...'")
        wav = current_model.generate(
            text=text,
            prompt_text=prompt_text,
            prompt_wav_path=prompt_wav_path,
            cfg_value=float(cfg_value_input),
            inference_timesteps=int(inference_timesteps_input),
            normalize=do_normalize,
            denoise=denoise,
        )
        return (self.OUTPUT_SAMPLE_RATE, wav)
113
+
114
+
115
+ # ---------- UI Builders ----------
116
+
117
def create_demo_interface(demo: VoxCPMDemo):
    """Assemble the Gradio Blocks UI and connect its events to ``demo``.

    Args:
        demo: Backend object whose ``generate_tts_audio`` and
            ``prompt_wav_recognition`` methods are used as event handlers.

    Returns:
        The constructed ``gr.Blocks`` object (not yet launched).
    """
    # Register the ./assets folder as a static path; the header logo below is
    # referenced from it.
    gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])

    soft_theme = gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="gray",
        neutral_hue="slate",
        font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"]
    )

    page_css = """
    .logo-container {
    text-align: center;
    margin: 0.5rem 0 1rem 0;
    }
    .logo-container img {
    height: 80px;
    width: auto;
    max-width: 200px;
    display: inline-block;
    }
    /* Bold accordion labels */
    #acc_quick details > summary,
    #acc_tips details > summary {
    font-weight: 600 !important;
    font-size: 1.1em !important;
    }
    /* Bold labels for specific checkboxes */
    #chk_denoise label,
    #chk_denoise span,
    #chk_normalize label,
    #chk_normalize span {
    font-weight: 600;
    }
    """

    quick_start_md = """
    ### How to Use |使用说明
    1. **(Optional) Provide a Voice Prompt** - Upload or record an audio clip to provide the desired voice characteristics for synthesis.
    **(可选)提供参考声音** - 上传或录制一段音频,为声音合成提供音色、语调和情感等个性化特征
    2. **(Optional) Enter prompt text** - If you provided a voice prompt, enter the corresponding transcript here (auto-recognition available).
    **(可选项)输入参考文本** - 如果提供了参考语音,请输入其对应的文本内容(支持自动识别)。
    3. **Enter target text** - Type the text you want the model to speak.
    **输入目标文本** - 输入您希望模型朗读的文字内容。
    4. **Generate Speech** - Click the "Generate" button to create your audio.
    **生成语音** - 点击"生成"按钮,即可为您创造出音频。
    """

    pro_tips_md = """
    ### Prompt Speech Enhancement|参考语音降噪
    - **Enable** to remove background noise for a clean, studio-like voice, with an external ZipEnhancer component.
    **启用**:通过 ZipEnhancer 组件消除背景噪音,获得更好的音质。
    - **Disable** to preserve the original audio's background atmosphere.
    **禁用**:保留原始音频的背景环境声,如果想复刻相应声学环境。

    ### Text Normalization|文本正则化
    - **Enable** to process general text with an external WeTextProcessing component.
    **启用**:使用 WeTextProcessing 组件,可处理常见文本。
    - **Disable** to use VoxCPM's native text understanding ability. For example, it supports phonemes input ({HH AH0 L OW1}), try it!
    **禁用**:将使用 VoxCPM 内置的文本理解能力。如,支持音素输入(如 {da4}{jia1}好)和公式符号合成,尝试一下!

    ### CFG Value|CFG 值
    - **Lower CFG** if the voice prompt sounds strained or expressive.
    **调低**:如果提示语音听起来不自然或过于夸张。
    - **Higher CFG** for better adherence to the prompt speech style or input text.
    **调高**:为更好地贴合提示音频的风格或输入文本。

    ### Inference Timesteps|推理时间步
    - **Lower** for faster synthesis speed.
    **调低**:合成速度更快。
    - **Higher** for better synthesis quality.
    **调高**:合成质量更佳。
    """

    with gr.Blocks(theme=soft_theme, css=page_css) as ui:
        # Header logo
        gr.HTML('<div class="logo-container"><img src="/gradio_api/file=assets/voxcpm_logo.png" alt="VoxCPM Logo"></div>')

        # Collapsed help sections
        with gr.Accordion("📋 Quick Start Guide |快速入门", open=False, elem_id="acc_quick"):
            gr.Markdown(quick_start_md)

        with gr.Accordion("💡 Pro Tips |使用建议", open=False, elem_id="acc_tips"):
            gr.Markdown(pro_tips_md)

        # Main controls: prompt inputs on the left, generation settings and
        # output on the right.
        with gr.Row():
            with gr.Column():
                prompt_audio = gr.Audio(
                    sources=["upload", 'microphone'],
                    type="filepath",
                    label="Prompt Speech (Optional, or let VoxCPM improvise)",
                    value="./examples/example.wav",
                )
                denoise_checkbox = gr.Checkbox(
                    value=False,
                    label="Prompt Speech Enhancement",
                    elem_id="chk_denoise",
                    info="We use ZipEnhancer model to denoise the prompt audio."
                )
                with gr.Row():
                    prompt_transcript = gr.Textbox(
                        value="Just by listening a few minutes a day, you'll be able to eliminate negative thoughts by conditioning your mind to be more positive.",
                        label="Prompt Text",
                        placeholder="Please enter the prompt text. Automatic recognition is supported, and you can correct the results yourself..."
                    )
                generate_button = gr.Button("Generate Speech", variant="primary")

            with gr.Column():
                cfg_slider = gr.Slider(
                    minimum=1.0,
                    maximum=3.0,
                    value=2.0,
                    step=0.1,
                    label="CFG Value (Guidance Scale)",
                    info="Higher values increase adherence to prompt, lower values allow more creativity"
                )
                timesteps_slider = gr.Slider(
                    minimum=4,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Inference Timesteps",
                    info="Number of inference timesteps for generation (higher values may improve quality but slower)"
                )
                with gr.Row():
                    target_text = gr.Textbox(
                        value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.",
                        label="Target Text",
                    )
                with gr.Row():
                    normalize_checkbox = gr.Checkbox(
                        value=False,
                        label="Text Normalization",
                        elem_id="chk_normalize",
                        info="We use wetext library to normalize the input text."
                    )
                result_audio = gr.Audio(label="Output Audio")

        # Event wiring. The input order must match the positional parameter
        # order of VoxCPMDemo.generate_tts_audio.
        tts_inputs = [
            target_text,
            prompt_audio,
            prompt_transcript,
            cfg_slider,
            timesteps_slider,
            normalize_checkbox,
            denoise_checkbox,
        ]
        generate_button.click(
            fn=demo.generate_tts_audio,
            inputs=tts_inputs,
            outputs=[result_audio],
            show_progress=True,
            api_name="generate",
        )
        # Whenever the prompt audio changes, refill the transcript box from ASR.
        prompt_audio.change(
            fn=demo.prompt_wav_recognition,
            inputs=[prompt_audio],
            outputs=[prompt_transcript],
        )

    return ui
265
+
266
+
267
+ # ---------- Launch ----------
268
if __name__ == "__main__":
    # Build the backend (creates the ASR model) and then the UI on top of it.
    demo = VoxCPMDemo()
    interface = create_demo_interface(demo)

    # Cap the request queue at 10 pending jobs.
    queued_app = interface.queue(max_size=10)

    # Listen on all interfaces; the PORT env var overrides the 7860 default.
    listen_port = int(os.environ.get("PORT", 7860))
    queued_app.launch(server_name="0.0.0.0", server_port=listen_port)