|
|
|
|
|
import os, io, wave, base64, numpy as np |
|
|
|
|
|
from transformers import AutoProcessor, AutoConfig, AutoModel |
|
|
|
|
|
# Hugging Face model repo for the Sesame CSM text-to-speech model.
MODEL_ID = "sesame/csm-1b"

# Auth token for the Hub download; either env var name is accepted.
TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")

# Sample rate (Hz) the model's audio output is assumed to use below
# (trimming, tempo boost and resampling all treat generated audio as this
# rate -- TODO confirm against the CSM model card).
DEFAULT_SR = 24000

# Leading-silence threshold in dBFS; samples quieter than this count as silence.
TRIM_DBFS = -42

# Upper bound (milliseconds) on how much leading silence may be trimmed.
TRIM_MAX_MS = 350

# Optional playback speed-up factor; values > 1.01 enable the tempo-boost path.
SPEED_MULTIPLIER = float(os.environ.get("CSM_SPEED_MULTIPLIER", "1.0"))
|
|
|
|
|
|
|
|
# Load processor/config/model once at import time so they are shared across
# requests. trust_remote_code is enabled, presumably because the CSM repo
# ships custom modeling code -- NOTE(review): this executes code from the
# model repository; confirm the repo is trusted/pinned.
processor = AutoProcessor.from_pretrained(MODEL_ID, token=TOKEN, trust_remote_code=True)

config = AutoConfig.from_pretrained(MODEL_ID, token=TOKEN, trust_remote_code=True)

model = AutoModel.from_pretrained(MODEL_ID, config=config, token=TOKEN, trust_remote_code=True)
|
|
|
|
|
def _resample_linear(x: np.ndarray, src_sr: int, dst_sr: int): |
|
|
if src_sr == dst_sr or x.size == 0: |
|
|
return x.astype(np.float32, copy=False) |
|
|
ratio = float(dst_sr) / float(src_sr) |
|
|
out_len = max(1, int(round(x.size * ratio))) |
|
|
t = np.linspace(0.0, x.size - 1, num=out_len, dtype=np.float32) |
|
|
i0 = np.floor(t).astype(np.int32) |
|
|
i1 = np.minimum(i0 + 1, x.size - 1) |
|
|
frac = t - i0 |
|
|
y = (1.0 - frac) * x[i0] + frac * x[i1] |
|
|
return y.astype(np.float32, copy=False) |
|
|
|
|
|
def _trim_leading_silence(x: np.ndarray, sr: int, thresh_dbfs: float, max_ms: int): |
|
|
x = np.asarray(x, dtype=np.float32) |
|
|
thresh = 10.0 ** (thresh_dbfs / 20.0) |
|
|
max_n = int(sr * max(0, max_ms) / 1000) |
|
|
cut = 0 |
|
|
for i in range(min(x.size, max_n)): |
|
|
if abs(x[i]) > thresh: |
|
|
cut = i |
|
|
break |
|
|
if i == min(x.size, max_n) - 1: |
|
|
cut = i |
|
|
return x[cut:], int(round(cut * 1000 / sr)) |
|
|
|
|
|
def _tempo_boost(x: np.ndarray, sr: int, speed: float):
    """Speed up playback of *x* by *speed* via resampling.

    Note this is plain resampling, so pitch rises with the speed-up.
    Returns x unchanged for missing or near-unity speeds (<= 1.01, a small
    epsilon that avoids pointless work).
    """
    if not (speed and speed > 1.01):
        return x
    up_sr = int(round(sr * speed))
    # Reinterpret the samples as having been recorded at the higher rate and
    # resample down to sr: the output has ~len(x)/speed samples, i.e. it
    # plays `speed` times faster at the original sample rate.
    # (BUG FIX: the previous up-then-down round trip
    # `_resample_linear(_resample_linear(x, sr, up_sr), up_sr, sr)` returned
    # ~the original length, so the function changed tempo by nothing.)
    return _resample_linear(x, up_sr, sr)
|
|
|
|
|
def _float_to_wav_bytes(x: np.ndarray, sr: int) -> bytes: |
|
|
x = np.clip(np.asarray(x, dtype=np.float32), -1.0, 1.0) |
|
|
i16 = (x * 32767.0).astype(np.int16) |
|
|
buf = io.BytesIO() |
|
|
with wave.open(buf, "wb") as wf: |
|
|
wf.setnchannels(1) |
|
|
wf.setsampwidth(2) |
|
|
wf.setframerate(int(sr)) |
|
|
wf.writeframes(i16.tobytes()) |
|
|
return buf.getvalue() |
|
|
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoint handler: text in, base64 WAV out."""

    def __init__(self, path: str = ""):
        # Model/processor are module-level singletons; nothing per-instance.
        pass

    def __call__(self, data: dict):
        try:
            text = (data.get("inputs") or data.get("text") or "").strip()
            if not text:
                return {"status_code": 400, "headers": {"Content-Type": "text/plain"}, "body": "Missing text"}

            # Parse parameters only after the text check -- no point doing it
            # for requests we reject anyway.
            params = data.get("parameters") or {}
            target_sr = int(params.get("sampleRate") or DEFAULT_SR)

            # CSM expects a leading speaker tag like "[0]"; default to speaker 0.
            if not text.startswith("["):
                text = f"[0]{text}"

            inputs = processor(text, add_special_tokens=True)
            audio = model.generate(**inputs, output_audio=True)

            # Torch tensor -> float32 numpy. Assumes generate() returns a 1-D
            # waveform at DEFAULT_SR -- TODO confirm against the CSM API.
            if hasattr(audio, "cpu"):
                audio = audio.detach().cpu().float().numpy()
            audio = np.asarray(audio, dtype=np.float32)

            audio, _ = _trim_leading_silence(audio, DEFAULT_SR, TRIM_DBFS, TRIM_MAX_MS)
            if SPEED_MULTIPLIER and SPEED_MULTIPLIER > 1.01:
                audio = _tempo_boost(audio, DEFAULT_SR, SPEED_MULTIPLIER)

            # Peak-normalize to 0.85 of full scale. Guard the degenerate
            # cases explicitly: np.max on an empty array raises ValueError,
            # and a zero peak must not be used as a divisor. (The previous
            # `peak = ... or 1.0` followed by `if peak > 0` was contradictory
            # -- peak could never be 0 after the `or` -- and still crashed on
            # empty audio.)
            peak = float(np.max(np.abs(audio))) if audio.size else 0.0
            if peak > 0.0:
                audio = (audio / peak) * 0.85

            # Resample only when the caller asked for a rate other than the
            # model's native one. (Replaces the tautological
            # `target_sr if target_sr != DEFAULT_SR else DEFAULT_SR`.)
            out_sr = target_sr
            if out_sr != DEFAULT_SR:
                audio = _resample_linear(audio, DEFAULT_SR, out_sr)

            wav_b64 = base64.b64encode(_float_to_wav_bytes(audio, out_sr)).decode("ascii")
            return {
                "status_code": 200,
                "headers": {"Content-Type": "audio/wav"},
                "body": wav_b64,
                "isBase64Encoded": True
            }

        except Exception as e:
            # Endpoint boundary: surface any failure as a plain-text 500.
            return {"status_code": 500, "headers": {"Content-Type": "text/plain"}, "body": f"CSM error: {e}"}
|
|
|