test with global model
Browse files
- sparktts/modules/speaker/perceiver_encoder.py +9 -18
- webui.py +11 -6
sparktts/modules/speaker/perceiver_encoder.py
CHANGED
@@ -15,6 +15,7 @@
 
 # Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
 
+from collections import namedtuple
 from functools import wraps
 
 import torch
@@ -45,21 +46,6 @@ def once(fn):
 
 print_once = once(print)
 
-# Define config class at module level
-class EfficientAttentionConfig:
-    def __init__(self, enable_flash, enable_math, enable_mem_efficient):
-        self.enable_flash = enable_flash
-        self.enable_math = enable_math
-        self.enable_mem_efficient = enable_mem_efficient
-
-    def _asdict(self):
-        return {
-            'enable_flash': self.enable_flash,
-            'enable_math': self.enable_math,
-            'enable_mem_efficient': self.enable_mem_efficient
-        }
-
-
 # main class
 
 
@@ -77,7 +63,12 @@ class Attend(nn.Module):
             use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
         ), "in order to use flash attention, you must be using pytorch 2.0 or above"
 
-        self.cpu_config = EfficientAttentionConfig(True, True, True)
+        # determine efficient attention configs for cuda and cpu
+        self.config = namedtuple(
+            "EfficientAttentionConfig",
+            ["enable_flash", "enable_math", "enable_mem_efficient"],
+        )
+        self.cpu_config = self.config(True, True, True)
         self.cuda_config = None
 
         if not torch.cuda.is_available() or not use_flash:
@@ -89,12 +80,12 @@ class Attend(nn.Module):
             print_once(
                 "A100 GPU detected, using flash attention if input tensor is on cuda"
             )
-            self.cuda_config = EfficientAttentionConfig(True, False, False)
+            self.cuda_config = self.config(True, False, False)
         else:
             print_once(
                 "Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda"
             )
-            self.cuda_config = EfficientAttentionConfig(False, True, True)
+            self.cuda_config = self.config(False, True, True)
 
     def get_mask(self, n, device):
         if exists(self.mask) and self.mask.shape[-1] >= n:
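The net effect in perceiver_encoder.py: the hand-rolled module-level EfficientAttentionConfig class is dropped and the config is rebuilt as a namedtuple inside Attend.__init__, matching the upstream naturalspeech2-pytorch code this file is adapted from. The one method the old class carried, _asdict(), is exactly what a namedtuple provides for free; upstream, the attention forward pass splats it into PyTorch's kernel-selection context manager. A minimal sketch of that consumption pattern, assuming the PyTorch 2.0-era torch.backends.cuda.sdp_kernel API (deprecated in newer releases in favor of torch.nn.attention.sdpa_kernel); tensor shapes here are illustrative:

from collections import namedtuple

import torch
import torch.nn.functional as F

# same three flags as the config now built in Attend.__init__
EfficientAttentionConfig = namedtuple(
    "EfficientAttentionConfig",
    ["enable_flash", "enable_math", "enable_mem_efficient"],
)

# the non-A100 cuda config from the diff above
config = EfficientAttentionConfig(
    enable_flash=False, enable_math=True, enable_mem_efficient=True
)

q = k = v = torch.randn(1, 8, 32, 64)  # (batch, heads, seq_len, dim_head)

# namedtuple supplies _asdict() natively, which is what the deleted
# hand-written _asdict() method existed to imitate
with torch.backends.cuda.sdp_kernel(**config._asdict()):
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 8, 32, 64])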
webui.py
CHANGED
@@ -25,6 +25,8 @@ from sparktts.utils.token_parser import LEVELS_MAP_UI
 from huggingface_hub import snapshot_download
 import spaces
 
+MODEL = None
+
 def initialize_model(model_dir=None, device="cpu"):
     """Load the model once at the beginning."""
 
@@ -38,8 +40,7 @@ def initialize_model(model_dir=None, device="cpu"):
     return model
 
 @spaces.GPU
-def generate(model,
-    text,
+def generate(text,
     prompt_speech,
     prompt_text,
     gender,
@@ -47,6 +48,10 @@ def generate(model,
     speed,
 ):
     """Generate audio from text."""
+
+    global MODEL
+    model = MODEL
+
     # if gpu available, move model to gpu
     if torch.cuda.is_available():
         model = model.to("cuda")
@@ -66,7 +71,6 @@ def generate(model,
 
 def run_tts(
     text,
-    model,
     prompt_text=None,
     prompt_speech=None,
     gender=None,
@@ -90,7 +94,7 @@ def run_tts(
     logging.info("Starting inference...")
 
     # Perform inference and save the output audio
-    wav = generate(model, text,
+    wav = generate(text,
         prompt_speech,
         prompt_text,
         gender,
@@ -109,6 +113,9 @@ def build_ui(model_dir, device=0):
 
     # Initialize model
    model = initialize_model(model_dir, device=device)
+
+    global MODEL
+    MODEL = model
 
     # Define callback function for voice cloning
     def voice_clone(text, prompt_text, prompt_wav_upload, prompt_wav_record):
@@ -123,7 +130,6 @@ def build_ui(model_dir, device=0):
 
         audio_output_path = run_tts(
             text,
-            model,
             prompt_text=prompt_text_clean,
             prompt_speech=prompt_speech
         )
@@ -141,7 +147,6 @@ def build_ui(model_dir, device=0):
         speed_val = LEVELS_MAP_UI[int(speed)]
         audio_output_path = run_tts(
             text,
-            model,
             gender=gender,
             pitch=pitch_val,
             speed=speed_val
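The webui.py half matches the commit title: rather than threading the loaded model through run_tts and into the @spaces.GPU-decorated generate as an argument, build_ui now stashes it in a module-level MODEL global that generate reads back. On ZeroGPU Spaces, a @spaces.GPU function executes in a GPU worker process and its arguments are serialized across that boundary, so a plausible motivation is keeping the large model object out of the pickled argument list; since the worker is forked from the main process, globals set at startup remain visible inside the call. A self-contained sketch of the pattern under those assumptions; HeavyModel is a hypothetical stand-in for the real model, not part of this repo:

import spaces
import torch

MODEL = None  # populated once at startup, read inside GPU calls


class HeavyModel(torch.nn.Module):
    """Hypothetical stand-in for the real model; not part of the repo."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)


def initialize_model():
    global MODEL
    MODEL = HeavyModel()
    return MODEL


@spaces.GPU
def generate(x):
    # read the model from the module global, so it is never part of the
    # serialized argument list sent to the GPU worker
    model = MODEL
    if torch.cuda.is_available():
        model = model.to("cuda")
        x = x.to("cuda")
    with torch.no_grad():
        return model(x).cpu()


initialize_model()
print(generate(torch.randn(2, 16)).shape)  # torch.Size([2, 16])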