""" Example Usage
torchrun --nproc_per_node=1 \
    benchmark.py --output-dir $log_dir \
    --batch-size $batch_size \
    --enable-warmup \
    --split-name $split_name \
    --model-path $CKPT_DIR/$model/model_1200000.pt \
    --vocab-file $CKPT_DIR/$model/vocab.txt \
    --vocoder-trt-engine-path $vocoder_trt_engine_path \
    --backend-type $backend_type \
    --tllm-model-dir $TRTLLM_ENGINE_DIR || exit 1
"""

import argparse
import importlib
import json
import os
import sys
import time

import datasets
import tensorrt as trt
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchaudio
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from tensorrt_llm._utils import trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session, TensorInfo
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from vocos import Vocos

sys.path.append(f"{os.path.dirname(os.path.abspath(__file__))}/../../../../src/")

from f5_tts.eval.utils_eval import padded_mel_batch
from f5_tts.model.modules import get_vocos_mel_spectrogram
from f5_tts.model.utils import convert_char_to_pinyin, get_tokenizer, list_str_to_idx

F5TTS = importlib.import_module("model_repo_f5_tts.f5_tts.1.f5_tts_trtllm").F5TTS

torch.manual_seed(0)


def get_args():
    parser = argparse.ArgumentParser(description="benchmark F5-TTS decoding with the TensorRT-LLM or PyTorch backend")
    parser.add_argument(
        "--split-name",
        type=str,
        default="wenetspeech4tts",
        choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
        help="huggingface dataset split name",
    )
    parser.add_argument("--output-dir", required=True, type=str, help="dir to save results")
    parser.add_argument(
        "--vocab-file",
        required=True,
        type=str,
        help="vocab file",
    )
    parser.add_argument(
        "--model-path",
        required=True,
        type=str,
        help="model checkpoint path, used to load the text embedding",
    )
    parser.add_argument(
        "--tllm-model-dir",
        required=True,
        type=str,
        help="directory containing the TensorRT-LLM engine and config.json",
    )
    parser.add_argument(
        "--batch-size",
        required=True,
        type=int,
        help="batch size (per device) for inference",
    )
    parser.add_argument("--num-workers", type=int, default=0, help="number of dataloader workers")
    parser.add_argument("--prefetch", type=int, default=None, help="prefetch factor for the dataloader")
    parser.add_argument(
        "--vocoder",
        default="vocos",
        type=str,
        help="vocoder name",
    )
    parser.add_argument(
        "--vocoder-trt-engine-path",
        default=None,
        type=str,
        help="optional TensorRT engine path for the vocoder",
    )
    parser.add_argument("--enable-warmup", action="store_true")
    parser.add_argument("--remove-input-padding", action="store_true")
    parser.add_argument("--use-perf", action="store_true", help="use nvtx ranges to record performance")
    parser.add_argument("--backend-type", type=str, default="trt", choices=["trt", "pytorch"], help="backend type")
    args = parser.parse_args()
    return args


def data_collator(batch, vocab_char_map, device="cuda", use_perf=False):
    if use_perf:
        torch.cuda.nvtx.range_push("data_collator")
    target_sample_rate = 24000
    target_rms = 0.1
    (
        ids,
        ref_rms_list,
        ref_mel_list,
        ref_mel_len_list,
        estimated_reference_target_mel_len,
        reference_target_texts_list,
    ) = ([], [], [], [], [], [])
    for i, item in enumerate(batch):
        item_id, prompt_text, target_text = (
            item["id"],
            item["prompt_text"],
            item["target_text"],
        )
        ids.append(item_id)
        reference_target_texts_list.append(prompt_text + target_text)

        ref_audio_org, ref_sr = (
            item["prompt_audio"]["array"],
            item["prompt_audio"]["sampling_rate"],
        )
        ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
        ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
        ref_rms_list.append(ref_rms)
        # Boost quiet prompts up to the target RMS; the inverse scaling is applied
        # to the generated audio after vocoding.
        if ref_rms < target_rms:
            ref_audio_org = ref_audio_org * target_rms / ref_rms

        if ref_sr != target_sample_rate:
            resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
            ref_audio = resampler(ref_audio_org)
        else:
            ref_audio = ref_audio_org

        if use_perf:
            torch.cuda.nvtx.range_push(f"mel_spectrogram {i}")
        ref_audio = ref_audio.to(device)
        ref_mel = get_vocos_mel_spectrogram(ref_audio).squeeze(0)
        if use_perf:
            torch.cuda.nvtx.range_pop()
        ref_mel_len = ref_mel.shape[-1]
        assert ref_mel.shape[0] == 100

        ref_mel_list.append(ref_mel)
        ref_mel_len_list.append(ref_mel_len)

        # Estimate the total (reference + target) mel length by scaling the reference
        # length with the UTF-8 byte-length ratio of target text to prompt text.
        estimated_reference_target_mel_len.append(
            int(ref_mel_len * (1 + len(target_text.encode("utf-8")) / len(prompt_text.encode("utf-8"))))
        )

    ref_mel_batch = padded_mel_batch(ref_mel_list)
    ref_mel_len_batch = torch.LongTensor(ref_mel_len_list)

    pinyin_list = convert_char_to_pinyin(reference_target_texts_list, polyphone=True)
    text_pad_sequence = list_str_to_idx(pinyin_list, vocab_char_map)

    if use_perf:
        torch.cuda.nvtx.range_pop()
    return {
        "ids": ids,
        "ref_rms_list": ref_rms_list,
        "ref_mel_batch": ref_mel_batch,
        "ref_mel_len_batch": ref_mel_len_batch,
        "text_pad_sequence": text_pad_sequence,
        "estimated_reference_target_mel_len": estimated_reference_target_mel_len,
    }
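
# A worked example of the length heuristic above (illustrative numbers only):
# a 2 s prompt is roughly 188 mel frames, assuming 24 kHz audio with a 256-sample
# hop (~93.75 frames/s). If the prompt text is 20 UTF-8 bytes and the target text
# is 40 bytes, the estimated total length is
#     int(188 * (1 + 40 / 20)) = 564 frames  (~6 s of reference + generated audio).
# The same estimate is reused below as the total_mel_lens / duration argument
# passed to model.sample().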


def init_distributed():
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    rank = int(os.environ.get("RANK", 0))
    print(f"Distributed inference: local_rank {local_rank}, rank {rank}, world_size {world_size}")
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    return world_size, local_rank, rank
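
# WORLD_SIZE, RANK and LOCAL_RANK are populated by the torchrun launcher shown in the
# usage example at the top of this file; with --nproc_per_node=1 this is a one-process
# group, and dist.barrier() / destroy_process_group() at the end of main() still work.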


def load_vocoder(
    vocoder_name="vocos", is_local=False, local_path="", device="cuda", hf_cache_dir=None, vocoder_trt_engine_path=None
):
    if vocoder_name == "vocos":
        if vocoder_trt_engine_path is not None:
            vocoder = VocosTensorRT(engine_path=vocoder_trt_engine_path)
        else:
            if is_local:
                print(f"Load vocos from local path {local_path}")
                config_path = f"{local_path}/config.yaml"
                model_path = f"{local_path}/pytorch_model.bin"
            else:
                print("Download Vocos from huggingface charactr/vocos-mel-24khz")
                repo_id = "charactr/vocos-mel-24khz"
                config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
                model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
            vocoder = Vocos.from_hparams(config_path)
            state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
            from vocos.feature_extractors import EncodecFeatures

            # Encodec-based Vocos variants keep the codec weights outside the
            # checkpoint, so merge them in from the instantiated feature extractor
            # (mirrors what Vocos.from_pretrained does).
            if isinstance(vocoder.feature_extractor, EncodecFeatures):
                encodec_parameters = {
                    "feature_extractor.encodec." + key: value
                    for key, value in vocoder.feature_extractor.encodec.state_dict().items()
                }
                state_dict.update(encodec_parameters)
            vocoder.load_state_dict(state_dict)
            vocoder = vocoder.eval().to(device)
    elif vocoder_name == "bigvgan":
        raise NotImplementedError("BigVGAN is not implemented yet")
    return vocoder
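
# load_vocoder returns either a VocosTensorRT wrapper (when a TensorRT engine path is
# given) or an eager Vocos module moved to `device`; both expose decode(mel) -> waveform,
# which is what the synthesis loop in main() calls for the "vocos" vocoder.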


class VocosTensorRT:
    def __init__(self, engine_path="./vocos_vocoder.plan", stream=None):
        TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
        trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")
        logger.info(f"Loading vocoder engine from {engine_path}")
        self.engine_path = engine_path
        with open(engine_path, "rb") as f:
            engine_buffer = f.read()
        self.session = Session.from_serialized_engine(engine_buffer)
        self.stream = stream if stream is not None else torch.cuda.current_stream().cuda_stream

    def decode(self, mels):
        mels = mels.contiguous()
        inputs = {"mel": mels}
        output_info = self.session.infer_shapes([TensorInfo("mel", trt.DataType.FLOAT, mels.shape)])
        outputs = {
            t.name: torch.empty(tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device="cuda") for t in output_info
        }
        ok = self.session.run(inputs, outputs, self.stream)
        assert ok, "Runtime execution failed for the vocoder session"
        samples = outputs["waveform"]
        return samples
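
# The vocoder engine is assumed to take a float "mel" input of shape
# (batch, n_mels, frames) and to produce a "waveform" output; with the
# charactr/vocos-mel-24khz configuration that is 100 mel bins and roughly
# 256 output samples per frame.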


def main():
    args = get_args()
    os.makedirs(args.output_dir, exist_ok=True)

    assert torch.cuda.is_available()
    world_size, local_rank, rank = init_distributed()
    device = torch.device(f"cuda:{local_rank}")

    vocab_char_map, vocab_size = get_tokenizer(args.vocab_file, "custom")

    tllm_model_dir = args.tllm_model_dir
    with open(os.path.join(tllm_model_dir, "config.json")) as f:
        tllm_model_config = json.load(f)
    if args.backend_type == "trt":
        model = F5TTS(
            tllm_model_config,
            debug_mode=False,
            tllm_model_dir=tllm_model_dir,
            model_path=args.model_path,
            vocab_size=vocab_size,
        )
    elif args.backend_type == "pytorch":
        from f5_tts.infer.utils_infer import load_model
        from f5_tts.model import DiT

        # Map the TensorRT-LLM pretrained_config fields back onto the DiT
        # constructor arguments so both backends load the same architecture.
        pretrained_config = tllm_model_config["pretrained_config"]
        pt_model_config = dict(
            dim=pretrained_config["hidden_size"],
            depth=pretrained_config["num_hidden_layers"],
            heads=pretrained_config["num_attention_heads"],
            ff_mult=pretrained_config["ff_mult"],
            text_dim=pretrained_config["text_dim"],
            text_mask_padding=pretrained_config["text_mask_padding"],
            conv_layers=pretrained_config["conv_layers"],
            pe_attn_head=pretrained_config["pe_attn_head"],
        )
        model = load_model(DiT, pt_model_config, args.model_path)

    vocoder = load_vocoder(
        vocoder_name=args.vocoder, device=device, vocoder_trt_engine_path=args.vocoder_trt_engine_path
    )

    dataset = load_dataset(
        "yuekai/seed_tts",
        split=args.split_name,
        trust_remote_code=True,
    )

    def add_estimated_duration(example):
        prompt_audio_len = example["prompt_audio"]["array"].shape[0]
        scale_factor = 1 + len(example["target_text"]) / len(example["prompt_text"])
        estimated_duration = prompt_audio_len * scale_factor
        example["estimated_duration"] = estimated_duration / example["prompt_audio"]["sampling_rate"]
        return example

    dataset = dataset.map(add_estimated_duration)
    dataset = dataset.sort("estimated_duration", reverse=True)
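
    # Sorting longest-first groups utterances of similar estimated duration into the
    # same batch, which cuts padding waste and surfaces any out-of-memory issue on
    # the very first (largest) batches.
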
    if args.use_perf:
        # For profiling, benchmark on eight copies of one fixed sample so every
        # batch has identical shapes.
        dataset_list_short = [dataset.select([24]) for _ in range(8)]
        dataset = datasets.concatenate_datasets(dataset_list_short)
    if world_size > 1:
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    else:
        sampler = None

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=args.num_workers,
        prefetch_factor=args.prefetch,
        collate_fn=lambda x: data_collator(x, vocab_char_map, use_perf=args.use_perf),
    )

    total_steps = len(dataset)

    if args.enable_warmup:
        # Untimed warmup pass so engine and kernel initialization do not skew the RTF.
        for batch in dataloader:
            ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
            text_pad_seq = batch["text_pad_sequence"].to(device)
            total_mel_lens = batch["estimated_reference_target_mel_len"]
            # ref_mels is (batch, time, n_mel); pad the time axis up to the longest
            # estimated reference+target length in the batch.
            cond_pad_seq = F.pad(ref_mels, (0, 0, 0, max(total_mel_lens) - ref_mels.shape[1], 0, 0))
            if args.backend_type == "trt":
                _ = model.sample(
                    text_pad_seq,
                    cond_pad_seq,
                    ref_mel_lens,
                    total_mel_lens,
                    remove_input_padding=args.remove_input_padding,
                )
            elif args.backend_type == "pytorch":
                total_mel_lens = torch.tensor(total_mel_lens, device=device)
                with torch.inference_mode():
                    generated, _ = model.sample(
                        cond=ref_mels,
                        text=text_pad_seq,
                        duration=total_mel_lens,
                        steps=32,
                        cfg_strength=2.0,
                        sway_sampling_coef=-1,
                    )

    if rank == 0:
        progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")

    decoding_time = 0
    vocoder_time = 0
    total_duration = 0
    if args.use_perf:
        torch.cuda.cudart().cudaProfilerStart()
    total_decoding_time = time.time()
    for batch in dataloader:
        if args.use_perf:
            torch.cuda.nvtx.range_push("data sample")
        ref_mels, ref_mel_lens = batch["ref_mel_batch"].to(device), batch["ref_mel_len_batch"].to(device)
        text_pad_seq = batch["text_pad_sequence"].to(device)
        total_mel_lens = batch["estimated_reference_target_mel_len"]
        cond_pad_seq = F.pad(ref_mels, (0, 0, 0, max(total_mel_lens) - ref_mels.shape[1], 0, 0))
        if args.use_perf:
            torch.cuda.nvtx.range_pop()
        if args.backend_type == "trt":
            generated, cost_time = model.sample(
                text_pad_seq,
                cond_pad_seq,
                ref_mel_lens,
                total_mel_lens,
                remove_input_padding=args.remove_input_padding,
                use_perf=args.use_perf,
            )
        elif args.backend_type == "pytorch":
            total_mel_lens = torch.tensor(total_mel_lens, device=device)
            with torch.inference_mode():
                start_time = time.time()
                generated, _ = model.sample(
                    cond=ref_mels,
                    text=text_pad_seq,
                    duration=total_mel_lens,
                    lens=ref_mel_lens,
                    steps=32,
                    cfg_strength=2.0,
                    sway_sampling_coef=-1,
                )
                cost_time = time.time() - start_time
        decoding_time += cost_time
        vocoder_start_time = time.time()
        target_rms = 0.1
        target_sample_rate = 24000
        for i, gen in enumerate(generated):
            # Drop the reference prefix and keep only the generated frames.
            gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
            gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
            if args.vocoder == "vocos":
                if args.use_perf:
                    torch.cuda.nvtx.range_push("vocoder decode")
                generated_wave = vocoder.decode(gen_mel_spec).cpu()
                if args.use_perf:
                    torch.cuda.nvtx.range_pop()
            else:
                generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()

            # Undo the RMS boost applied to quiet prompts in data_collator.
            if batch["ref_rms_list"][i] < target_rms:
                generated_wave = generated_wave * batch["ref_rms_list"][i] / target_rms

            utt = batch["ids"][i]
            torchaudio.save(
                f"{args.output_dir}/{utt}.wav",
                generated_wave,
                target_sample_rate,
            )
            total_duration += generated_wave.shape[1] / target_sample_rate
        vocoder_time += time.time() - vocoder_start_time
        if rank == 0:
            progress_bar.update(world_size * len(batch["ids"]))
    total_decoding_time = time.time() - total_decoding_time
    if rank == 0:
        progress_bar.close()
        # Real-time factor: wall-clock decoding time divided by seconds of audio generated.
        rtf = total_decoding_time / total_duration
        s = f"RTF: {rtf:.4f}\n"
        s += f"total_duration: {total_duration:.3f} seconds\n"
        s += f"({total_duration / 3600:.2f} hours)\n"
        s += f"DiT time: {decoding_time:.3f} seconds ({decoding_time / 3600:.2f} hours)\n"
        s += f"Vocoder time: {vocoder_time:.3f} seconds ({vocoder_time / 3600:.2f} hours)\n"
        s += f"total decoding time: {total_decoding_time:.3f} seconds ({total_decoding_time / 3600:.2f} hours)\n"
        s += f"batch size: {args.batch_size}\n"
        print(s)

        with open(f"{args.output_dir}/rtf.txt", "w") as f:
            f.write(s)

    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()