# csm-tts-endpoint / handler.py.disabled
# acambece25's picture
# Rename handler.py to handler.py.disabled
# 0925957 verified
# handler.py — Sesame CSM @ 24 kHz, trims lead-in, optional tempo boost
import os, io, wave, base64, numpy as np
from transformers import AutoProcessor, AutoConfig, AutoModel # AutoConfig is critical
# Hugging Face model repo that the processor, config and model below are loaded from.
MODEL_ID = "sesame/csm-1b"
# Hub access token: HF_TOKEN if set (and non-empty), else HUGGING_FACE_HUB_TOKEN, else None.
# NOTE(review): presumably required because the model repo is gated — confirm.
TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
# Output defaults
# Sample rate (Hz) the handler assumes for model output, and the output rate used
# when the caller does not pass "sampleRate".
DEFAULT_SR = 24000
TRIM_DBFS = -42 # trim leading silence below this level (dBFS; 10**(-42/20) ~= 0.0079 linear amplitude)
TRIM_MAX_MS = 350 # max trim at start (ms)
# Tempo multiplier read from CSM_SPEED_MULTIPLIER; the handler only applies it when > 1.01.
# float() runs at import time, so a malformed value makes the module import fail.
SPEED_MULTIPLIER = float(os.environ.get("CSM_SPEED_MULTIPLIER", "1.0")) # e.g. 1.12
# ---- Load with remote code enabled at BOTH config and model levels ----
# These calls run once, at module import, and fetch files from the Hub (or the local
# cache); a failure here prevents the handler module from loading at all.
# trust_remote_code=True executes Python code shipped inside the model repo.
processor = AutoProcessor.from_pretrained(MODEL_ID, token=TOKEN, trust_remote_code=True)
config = AutoConfig.from_pretrained(MODEL_ID, token=TOKEN, trust_remote_code=True) # <-- key fix: config resolved with remote code, then passed explicitly to AutoModel
model = AutoModel.from_pretrained(MODEL_ID, config=config, token=TOKEN, trust_remote_code=True)
def _resample_linear(x: np.ndarray, src_sr: int, dst_sr: int):
if src_sr == dst_sr or x.size == 0:
return x.astype(np.float32, copy=False)
ratio = float(dst_sr) / float(src_sr)
out_len = max(1, int(round(x.size * ratio)))
t = np.linspace(0.0, x.size - 1, num=out_len, dtype=np.float32)
i0 = np.floor(t).astype(np.int32)
i1 = np.minimum(i0 + 1, x.size - 1)
frac = t - i0
y = (1.0 - frac) * x[i0] + frac * x[i1]
return y.astype(np.float32, copy=False)
def _trim_leading_silence(x: np.ndarray, sr: int, thresh_dbfs: float, max_ms: int):
x = np.asarray(x, dtype=np.float32)
thresh = 10.0 ** (thresh_dbfs / 20.0)
max_n = int(sr * max(0, max_ms) / 1000)
cut = 0
for i in range(min(x.size, max_n)):
if abs(x[i]) > thresh:
cut = i
break
if i == min(x.size, max_n) - 1:
cut = i
return x[cut:], int(round(cut * 1000 / sr))
def _tempo_boost(x: np.ndarray, sr: int, speed: float):
    """Speed up a mono signal by ``speed`` using cheap varispeed resampling.

    The samples are treated as if they had been recorded at ``sr * speed`` and
    are resampled to ``sr``. The duration therefore shrinks by a factor of
    ``speed``, and the pitch rises by the same factor: this is not a
    pitch-preserving time stretch.

    Args:
        x: Mono samples.
        sr: Sample rate of ``x`` (and of the result) in Hz.
        speed: Speed-up factor; values <= 1.01 (or falsy) return ``x`` unchanged.

    Returns:
        The sped-up samples at rate ``sr`` (float32 from the resampler), or ``x``
        itself when no boost applies.
    """
    if not (speed and speed > 1.01):
        return x
    # Note: resampling sr -> sr*speed -> sr would be a length-preserving round
    # trip (no speed-up at all); a single sr*speed -> sr pass is what compresses time.
    fast_sr = int(round(sr * speed))
    return _resample_linear(x, fast_sr, sr)
def _float_to_wav_bytes(x: np.ndarray, sr: int) -> bytes:
x = np.clip(np.asarray(x, dtype=np.float32), -1.0, 1.0)
i16 = (x * 32767.0).astype(np.int16)
buf = io.BytesIO()
with wave.open(buf, "wb") as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(int(sr))
wf.writeframes(i16.tobytes())
return buf.getvalue()
class EndpointHandler:
    """Custom endpoint handler that synthesizes speech with Sesame CSM.

    It follows the ``EndpointHandler(path)`` / ``__call__(data)`` shape of
    Hugging Face custom handlers. Request keys read here:

      - ``"inputs"`` or ``"text"``: the text to speak; ``"inputs"`` wins when
        both are non-empty.
      - ``"parameters"``: optional dict. Its ``"sampleRate"`` key selects the
        output rate; when missing or falsy, DEFAULT_SR is used.

    Every call returns a dict with ``status_code``, ``headers`` and ``body``:
      - 400 (plain text) for invalid requests;
      - 500 (plain text, prefixed "CSM error: ") for synthesis failures;
      - 200 on success, where ``body`` is a base64-encoded mono 16-bit WAV and
        ``isBase64Encoded`` is True.
    """

    # Accepted range for the caller-supplied "sampleRate". This bounds how large a
    # resampled buffer a single request can force the handler to allocate.
    _MIN_SAMPLE_RATE = 8000
    _MAX_SAMPLE_RATE = 192000

    def __init__(self, path: str = ""):
        # Processor, config and model are loaded once at module import; ``path``
        # is accepted for interface compatibility and is not used.
        pass

    @staticmethod
    def _text_response(status_code: int, message: str) -> dict:
        """Build a plain-text response dict with the given status and body."""
        return {"status_code": status_code, "headers": {"Content-Type": "text/plain"}, "body": message}

    @staticmethod
    def _to_mono_float32(audio) -> np.ndarray:
        """Convert raw ``model.generate`` output into a flat float32 NumPy array."""
        # NOTE(review): generate(..., output_audio=True) may return one waveform per
        # batch item; a single text is always sent, so unwrap the first entry.
        # Confirm against the installed transformers / remote-code version.
        if isinstance(audio, (list, tuple)) and audio and (
            hasattr(audio[0], "cpu") or isinstance(audio[0], np.ndarray)
        ):
            audio = audio[0]
        if hasattr(audio, "cpu"):  # torch tensor, possibly on an accelerator
            audio = audio.detach().cpu().float().numpy()
        # Flatten shapes such as (1, N) so the helpers downstream see 1-D mono audio.
        return np.asarray(audio, dtype=np.float32).reshape(-1)

    def __call__(self, data: dict):
        try:
            text = data.get("inputs") or data.get("text") or ""
            if not isinstance(text, str):
                return self._text_response(400, "Text must be a string")
            text = text.strip()
            if not text:
                return self._text_response(400, "Missing text")

            params = data.get("parameters") or {}
            if not isinstance(params, dict):
                return self._text_response(400, "parameters must be an object")
            try:
                target_sr = int(params.get("sampleRate") or DEFAULT_SR)
            except (TypeError, ValueError, OverflowError):
                return self._text_response(400, "Invalid sampleRate")
            if not self._MIN_SAMPLE_RATE <= target_sr <= self._MAX_SAMPLE_RATE:
                return self._text_response(
                    400,
                    f"sampleRate must be between {self._MIN_SAMPLE_RATE} and {self._MAX_SAMPLE_RATE} Hz",
                )

            # Sesame expects a speaker tag; default to speaker [0].
            if not text.startswith("["):
                text = f"[0]{text}"

            # Tokenize/encode; the remote model code performs generation.
            inputs = processor(text, add_special_tokens=True)
            audio = self._to_mono_float32(model.generate(**inputs, output_audio=True))
            if audio.size == 0:
                return self._text_response(500, "CSM error: model returned no audio")
            # Neutralize NaN/inf so they cannot poison the peak normalization below.
            audio = np.nan_to_num(audio, nan=0.0, posinf=0.0, neginf=0.0)

            # Trim lead-in silence, then apply the optional tempo boost.
            audio, _ = _trim_leading_silence(audio, DEFAULT_SR, TRIM_DBFS, TRIM_MAX_MS)
            if SPEED_MULTIPLIER and SPEED_MULTIPLIER > 1.01:
                audio = _tempo_boost(audio, DEFAULT_SR, SPEED_MULTIPLIER)

            # Peak-normalize to 0.85 of full scale (about -1.4 dBFS of headroom);
            # all-zero audio is left as silence.
            peak = float(np.max(np.abs(audio)))
            if peak > 0:
                audio = (audio / peak) * 0.85

            # Model output is treated as DEFAULT_SR; resample only if asked for another rate.
            if target_sr != DEFAULT_SR:
                audio = _resample_linear(audio, DEFAULT_SR, target_sr)

            # Return base64 WAV (toolkit expects base64 when content-type is audio/*).
            wav_b64 = base64.b64encode(_float_to_wav_bytes(audio, target_sr)).decode("ascii")
            return {
                "status_code": 200,
                "headers": {"Content-Type": "audio/wav"},
                "body": wav_b64,
                "isBase64Encoded": True,
            }
        except Exception as e:
            # Endpoint boundary: report the failure to the caller, but keep the traceback in the logs.
            import logging

            logging.getLogger(__name__).exception("CSM synthesis failed")
            return self._text_response(500, f"CSM error: {e}")