File size: 4,667 Bytes
ff41861
99dfb5a
60c0495
ff41861
ae649aa
 
b7b57b9
ab5158a
ff41861
 
 
 
a464634
99dfb5a
60c0495
99dfb5a
ff41861
60c0495
b7b57b9
99dfb5a
60c0495
 
99dfb5a
 
 
 
 
 
 
 
b7b57b9
 
 
99dfb5a
 
b7b57b9
99dfb5a
a464634
 
 
 
 
b7b57b9
 
 
ff41861
60c0495
 
b7b57b9
99dfb5a
b7b57b9
 
 
 
ae649aa
 
a464634
 
 
ae649aa
 
 
 
a464634
 
ffa3ef8
ae649aa
 
9da5d99
a464634
60c0495
99dfb5a
a464634
99dfb5a
ff41861
99dfb5a
ae649aa
776eedc
ff41861
99dfb5a
b7b57b9
a464634
ff41861
 
b7b57b9
99dfb5a
b7b57b9
ff41861
60c0495
99dfb5a
60c0495
b7b57b9
ff41861
99dfb5a
a464634
 
 
60c0495
ff41861
 
 
776eedc
60c0495
ff41861
776eedc
b7b57b9
776eedc
99dfb5a
 
776eedc
60c0495
ae649aa
a464634
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# handler.py — Sesame CSM @ 24 kHz, trims lead-in, optional tempo boost
import os, io, wave, base64, numpy as np

from transformers import AutoProcessor, AutoConfig, AutoModel  # AutoConfig is critical

# Hugging Face Hub repository of the Sesame CSM 1B checkpoint.
MODEL_ID = "sesame/csm-1b"
# Hub access token passed to every from_pretrained call below; HF_TOKEN takes
# precedence over HUGGING_FACE_HUB_TOKEN, and the value is None if neither is set.
TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")

# Output defaults
# Rate the handler assumes for the raw generated waveform (24 kHz per the file
# header); also the output rate when a request does not ask for another one.
DEFAULT_SR = 24000
TRIM_DBFS = -42        # trim leading silence below this level (dB relative to full scale 1.0)
TRIM_MAX_MS = 350      # max trim at start (ms)
# Speed-up factor read once at import; the handler only applies it when > 1.01.
# NOTE(review): a non-numeric CSM_SPEED_MULTIPLIER makes float() raise here,
# which fails the whole module import rather than a single request.
SPEED_MULTIPLIER = float(os.environ.get("CSM_SPEED_MULTIPLIER", "1.0"))  # e.g. 1.12

# ---- Load with remote code enabled at BOTH config and model levels ----
# Runs once at import time, so every request reuses the same processor/model.
# trust_remote_code=True lets Python code shipped in the model repository run
# locally -- only point MODEL_ID at a repository you trust.
processor = AutoProcessor.from_pretrained(MODEL_ID, token=TOKEN, trust_remote_code=True)
# The config is loaded separately and handed to the model explicitly.
# NOTE(review): an earlier comment called this the "key fix"; presumably it makes
# the remote config class resolve before the model class -- confirm before removing.
config = AutoConfig.from_pretrained(MODEL_ID, token=TOKEN, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_ID, config=config, token=TOKEN, trust_remote_code=True)

def _resample_linear(x: np.ndarray, src_sr: int, dst_sr: int):
    """Resample a 1-D signal from ``src_sr`` to ``dst_sr`` by linear interpolation.

    The output has ``round(len(x) * dst_sr / src_sr)`` samples (at least one).
    Its first and last samples coincide with the first and last input samples.

    No anti-aliasing filter is applied. When downsampling, content above the
    new Nyquist frequency folds back (aliases) into the output.

    Args:
        x: 1-D waveform (anything ``np.asarray`` accepts).
        src_sr: sample rate of ``x`` in Hz.
        dst_sr: desired sample rate in Hz.

    Returns:
        float32 ndarray. When the rates match or ``x`` is empty, this is ``x``
        itself converted to float32 (no copy if it already is float32).
    """
    x = np.asarray(x, dtype=np.float32)
    if src_sr == dst_sr or x.size == 0:
        return x
    ratio = float(dst_sr) / float(src_sr)
    out_len = max(1, int(round(x.size * ratio)))
    # Positions are kept in float64: float32 can only resolve ~1/16 of a sample
    # around index 1e6, which quantizes the interpolation weights on long clips.
    positions = np.linspace(0.0, x.size - 1, num=out_len)
    y = np.interp(positions, np.arange(x.size), x)
    return y.astype(np.float32, copy=False)

def _trim_leading_silence(x: np.ndarray, sr: int, thresh_dbfs: float, max_ms: int):
    """Drop quiet samples from the start of a mono waveform.

    Only the first ``max_ms`` milliseconds are scanned. Everything before the
    first sample whose magnitude exceeds the threshold is removed. If no sample
    in that window is louder than the threshold, the cut lands on the window's
    last sample. As a result, a non-empty input never becomes empty.

    Args:
        x: mono waveform. Singleton channel/batch axes such as ``(1, N)`` are
            flattened, so it is treated as 1-D.
        sr: sample rate of ``x`` in Hz.
        thresh_dbfs: threshold in dB relative to full scale (amplitude 1.0).
        max_ms: maximum amount of audio to remove, in milliseconds. Negative
            values are treated as 0.

    Returns:
        ``(trimmed, cut_ms)``: the trimmed float32 1-D waveform and the amount
        removed, rounded to whole milliseconds.
    """
    samples = np.asarray(x, dtype=np.float32).reshape(-1)
    thresh = 10.0 ** (thresh_dbfs / 20.0)
    window = min(samples.size, int(sr * max(0, max_ms) / 1000))
    if window <= 0:
        # Nothing to scan (empty input, max_ms <= 0, or sr <= 0): keep everything.
        return samples, 0
    loud = np.flatnonzero(np.abs(samples[:window]) > thresh)
    cut = int(loud[0]) if loud.size else window - 1
    return samples[cut:], int(round(cut * 1000 / sr))

def _tempo_boost(x: np.ndarray, sr: int, speed: float):
    """Make ``x`` play ``speed`` times faster at the same sample rate ``sr``.

    This is the cheap "faster tape" approach. The waveform is treated as if it
    had been recorded at ``sr * speed`` and is resampled once down to ``sr``,
    so it ends up ``speed`` times shorter. The pitch rises by the same factor;
    this is not a pitch-preserving time stretch.

    Note: resampling ``sr -> sr*speed -> sr`` would be a round trip that
    restores the original length, i.e. no tempo change at all.

    Args:
        x: mono waveform sampled at ``sr``.
        sr: sample rate in Hz, used both for input and output.
        speed: speed-up factor. Falsy values or values <= 1.01 leave ``x``
            unchanged.

    Returns:
        The sped-up waveform (float32), or ``x`` itself when disabled.
    """
    if not (speed and speed > 1.01):
        return x
    fast_sr = int(round(sr * speed))
    return _resample_linear(x, fast_sr, sr)

def _float_to_wav_bytes(x: np.ndarray, sr: int) -> bytes:
    """Encode a float waveform as an in-memory mono 16-bit PCM WAV file.

    Samples are expected in [-1.0, 1.0]. Values outside that range are clipped,
    and NaN is written as silence (0). Inputs with singleton extra axes (e.g.
    ``(1, N)``) are flattened.

    Args:
        x: mono waveform (anything ``np.asarray`` accepts).
        sr: sample rate written into the WAV header, in Hz.

    Returns:
        The complete WAV file as bytes.
    """
    samples = np.asarray(x, dtype=np.float32).reshape(-1)
    # np.clip leaves NaN untouched and NaN -> int16 is an undefined cast, so
    # sanitize non-finite values first (+/-inf map to full scale).
    samples = np.nan_to_num(samples, nan=0.0, posinf=1.0, neginf=-1.0)
    samples = np.clip(samples, -1.0, 1.0)
    # Round to nearest instead of truncating toward zero (which biases every
    # sample's quantization error toward silence).
    pcm = np.rint(samples * 32767.0).astype(np.int16)
    buf = io.BytesIO()
    with wave.open(buf, "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(int(sr))
        # wave expects native-order frames and handles WAV's little-endian layout.
        wf.writeframes(pcm.tobytes())
    return buf.getvalue()

class EndpointHandler:
    """Request handler that turns text into a base64-encoded WAV using Sesame CSM.

    The processor and model are loaded once at module import. Instances hold no
    state of their own.
    """

    # Upper bound on the requested output rate. The resampled buffer grows
    # linearly with the rate, so an unbounded client value could exhaust memory.
    MAX_SAMPLE_RATE = 192_000

    def __init__(self, path: str = ""):
        # ``path`` is accepted for compatibility with the endpoint toolkit's
        # constructor call; everything needed is already loaded at import.
        pass

    @staticmethod
    def _text_response(status_code: int, body: str) -> dict:
        """Build a plain-text response in the toolkit's dict format."""
        return {"status_code": status_code, "headers": {"Content-Type": "text/plain"}, "body": body}

    @staticmethod
    def _to_mono_float32(audio) -> np.ndarray:
        """Convert ``model.generate`` output into a flat float32 numpy array.

        Accepts a torch tensor, a numpy array or plain sequence, or a list/tuple
        of tensors/arrays. In the list case only the first item is used, since
        one prompt is sent per request. Singleton batch/channel axes are
        flattened away.
        """
        # generate(output_audio=True) may return a list of per-item waveforms.
        if isinstance(audio, (list, tuple)) and audio and hasattr(audio[0], "shape"):
            audio = audio[0]
        if hasattr(audio, "cpu"):
            # torch tensor (possibly on GPU / low precision): detach, move to host, upcast.
            audio = audio.detach().cpu().float().numpy()
        return np.asarray(audio, dtype=np.float32).reshape(-1)

    def __call__(self, data: dict):
        """Synthesize speech for one request.

        Reads the text from ``data["inputs"]`` (falling back to ``data["text"]``)
        and an optional output rate from ``data["parameters"]["sampleRate"]``
        (defaults to DEFAULT_SR).

        Returns a response dict with one of these statuses:
            200: the body is a base64-encoded mono 16-bit WAV.
            400: the text is missing or not a string, or sampleRate is invalid.
            500: any other failure, with a "CSM error: ..." body.
        """
        try:
            raw_text = data.get("inputs") or data.get("text") or ""
            if not isinstance(raw_text, str):
                return self._text_response(400, "Text must be a string")
            text = raw_text.strip()
            if not text:
                return self._text_response(400, "Missing text")

            params = data.get("parameters") or {}
            try:
                target_sr = int(params.get("sampleRate") or DEFAULT_SR)
            except (TypeError, ValueError):
                return self._text_response(400, "Invalid sampleRate")
            if not 0 < target_sr <= self.MAX_SAMPLE_RATE:
                return self._text_response(400, "Invalid sampleRate")

            # Sesame expects a speaker tag; default to [0]
            if not text.startswith("["):
                text = f"[0]{text}"

            # Tokenize/encode; remote code handles generation
            inputs = processor(text, add_special_tokens=True)
            generated = model.generate(**inputs, output_audio=True)
            audio = self._to_mono_float32(generated)
            if audio.size == 0:
                raise ValueError("model returned empty audio")
            # Non-finite samples would poison the peak normalization below.
            audio = np.nan_to_num(audio, nan=0.0, posinf=0.0, neginf=0.0)

            # Trim lead-in + optional tempo boost
            audio, _ = _trim_leading_silence(audio, DEFAULT_SR, TRIM_DBFS, TRIM_MAX_MS)
            if SPEED_MULTIPLIER and SPEED_MULTIPLIER > 1.01:
                audio = _tempo_boost(audio, DEFAULT_SR, SPEED_MULTIPLIER)

            # Normalize the peak to 0.85; all-zero audio is left as-is.
            peak = float(np.max(np.abs(audio)))
            if peak > 0:
                audio = (audio / peak) * 0.85

            # Resample if caller asked for a different rate
            if target_sr != DEFAULT_SR:
                audio = _resample_linear(audio, DEFAULT_SR, target_sr)

            # Return base64 WAV (toolkit expects base64 when content-type is audio/*)
            wav_b64 = base64.b64encode(_float_to_wav_bytes(audio, target_sr)).decode("ascii")
            return {
                "status_code": 200,
                "headers": {"Content-Type": "audio/wav"},
                "body": wav_b64,
                "isBase64Encoded": True
            }

        except Exception as e:
            # Top-level boundary: keep the traceback in the logs, return a 500.
            import logging
            logging.getLogger(__name__).exception("CSM synthesis failed")
            return self._text_response(500, f"CSM error: {e}")