acambece25 committed on
Commit
99dfb5a
·
verified ·
1 Parent(s): b7b57b9

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +52 -75
handler.py CHANGED
@@ -1,114 +1,91 @@
1
- # handler.py — Sesame CSM endpoint: 24kHz output + leading-silence trim
2
- import os, io, wave, base64, math
3
- import numpy as np
4
- import torch
5
- from transformers import AutoProcessor, CsmForConditionalGeneration
6
 
7
  MODEL_ID = "sesame/csm-1b"
8
  TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
9
 
10
- TARGET_SR = 24000 # <- hard target sample rate for telephony bridge
11
- TRIM_DBFS = -42 # leading silence threshold
12
- TRIM_MAX_MS = 350 # max leading trim
13
- SPEED_MULTIPLIER = float(os.environ.get("CSM_SPEED_MULTIPLIER", "1.0")) # optional 1.10–1.25
 
 
 
 
14
 
15
- device = "cuda" if torch.cuda.is_available() else "cpu"
16
- processor = AutoProcessor.from_pretrained(MODEL_ID, token=TOKEN)
17
- model = CsmForConditionalGeneration.from_pretrained(
18
- MODEL_ID, token=TOKEN,
19
- torch_dtype=torch.float16 if device == "cuda" else None
20
- ).to(device)
 
 
 
 
21
 
22
  def _trim_leading_silence(x: np.ndarray, sr: int, thresh_dbfs: float, max_ms: int):
23
  x = np.asarray(x, dtype=np.float32)
24
- thresh = (10.0 ** (thresh_dbfs / 20.0))
25
- max_samples = int(sr * max(0, max_ms) / 1000)
26
  cut = 0
27
- for i in range(min(len(x), max_samples)):
28
- if abs(x[i]) > thresh:
29
- cut = i
30
- break
31
- if i == min(len(x), max_samples) - 1:
32
- cut = i
33
  return x[cut:], int(round(cut * 1000 / sr))
34
 
35
  def _tempo_boost(x: np.ndarray, sr: int, speed: float):
36
- if not (speed and speed > 1.01):
37
- return x
38
- # crude tempo increase via resample to higher SR then back to original SR (raises pitch a bit)
39
  up_sr = int(round(sr * speed))
40
- x_up = _resample_linear(x, sr, up_sr)
41
- return _resample_linear(x_up, up_sr, sr)
42
-
43
- def _resample_linear(x: np.ndarray, src_sr: int, dst_sr: int):
44
- if src_sr == dst_sr or len(x) == 0:
45
- return x.astype(np.float32, copy=False)
46
- # linear interpolation in float32
47
- ratio = float(dst_sr) / float(src_sr)
48
- out_len = max(1, int(round(len(x) * ratio)))
49
- t = np.linspace(0.0, len(x) - 1, num=out_len, dtype=np.float32)
50
- i0 = np.floor(t).astype(np.int32)
51
- i1 = np.minimum(i0 + 1, len(x) - 1)
52
- frac = t - i0
53
- y = (1.0 - frac) * x[i0] + frac * x[i1]
54
- return y.astype(np.float32, copy=False)
55
 
56
  def _float_to_wav_bytes(x: np.ndarray, sr: int) -> bytes:
57
- # clamp -> int16
58
  x = np.clip(np.asarray(x, dtype=np.float32), -1.0, 1.0)
59
  i16 = (x * 32767.0).astype(np.int16)
60
  buf = io.BytesIO()
61
  with wave.open(buf, "wb") as wf:
62
- wf.setnchannels(1)
63
- wf.setsampwidth(2)
64
- wf.setframerate(int(sr))
65
  wf.writeframes(i16.tobytes())
66
  return buf.getvalue()
67
 
68
  class EndpointHandler:
69
- def __init__(self, path: str = ""):
70
- pass
71
 
72
  def __call__(self, data: dict):
73
  try:
74
  text = (data.get("inputs") or data.get("text") or "").strip()
75
- # ensure Sesame speaker prefix exists
76
- if text and not text.startswith("["):
 
 
 
77
  text = f"[0]{text}"
78
 
79
- # 1) generate audio (model-native rate)
80
- inputs = processor(text, add_special_tokens=True).to(model.device)
 
81
  audio = model.generate(**inputs, output_audio=True)
82
- if isinstance(audio, torch.Tensor):
83
  audio = audio.detach().cpu().float().numpy()
 
84
 
85
- # 2) trim leading silence (model often leaves a big gap)
86
- audio, _ = _trim_leading_silence(audio, sr=TARGET_SR, thresh_dbfs=TRIM_DBFS, max_ms=TRIM_MAX_MS)
87
-
88
- # 3) tempo boost if requested (optional)
89
- audio = _tempo_boost(audio, TARGET_SR, SPEED_MULTIPLIER)
90
-
91
- # 4) upsample/downsample to TARGET_SR
92
- audio_24k = _resample_linear(audio, src_sr=TARGET_SR, dst_sr=TARGET_SR)
93
 
94
- # 5) to WAV (24k mono 16-bit)
95
- wav_bytes = _float_to_wav_bytes(audio_24k, TARGET_SR)
96
- b64 = base64.b64encode(wav_bytes).decode("ascii")
97
 
 
98
  return {
99
  "status_code": 200,
100
- "statusCode": 200,
101
  "headers": {"Content-Type": "audio/wav"},
102
- "body": b64,
103
- "isBase64Encoded": True,
104
- "is_base64_encoded": True,
105
  }
106
  except Exception as e:
107
- return {
108
- "status_code": 500,
109
- "statusCode": 500,
110
- "headers": {"Content-Type": "text/plain"},
111
- "body": f"CSM error: {e}",
112
- "isBase64Encoded": False,
113
- "is_base64_encoded": False,
114
- }
 
1
+ # handler.py — Sesame CSM @ 24kHz + trim + optional tempo
2
+ import os, io, wave, base64, numpy as np
3
+
4
+ from transformers import AutoProcessor, AutoModel
 
5
 
6
  MODEL_ID = "sesame/csm-1b"
7
  TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
8
 
9
+ TARGET_SR = 24000 # force 24 kHz out
10
+ TRIM_DBFS = -42 # leading silence cutoff (≈ quiet room)
11
+ TRIM_MAX_MS = 350 # cap leading trim
12
+ SPEED_MULTIPLIER = float(os.environ.get("CSM_SPEED_MULTIPLIER", "1.0")) # e.g. 1.12..1.22
13
+
14
+ # ---- load via remote code (avoids missing Csm* import) ----
15
+ processor = AutoProcessor.from_pretrained(MODEL_ID, token=TOKEN, trust_remote_code=True)
16
+ model = AutoModel.from_pretrained(MODEL_ID, token=TOKEN, trust_remote_code=True)
17
 
18
+ def _resample_linear(x: np.ndarray, src_sr: int, dst_sr: int):
19
+ if src_sr == dst_sr or x.size == 0: return x.astype(np.float32, copy=False)
20
+ ratio = float(dst_sr) / float(src_sr)
21
+ out_len = max(1, int(round(x.size * ratio)))
22
+ t = np.linspace(0.0, x.size - 1, num=out_len, dtype=np.float32)
23
+ i0 = np.floor(t).astype(np.int32)
24
+ i1 = np.minimum(i0 + 1, x.size - 1)
25
+ frac = t - i0
26
+ y = (1.0 - frac) * x[i0] + frac * x[i1]
27
+ return y.astype(np.float32, copy=False)
28
 
29
  def _trim_leading_silence(x: np.ndarray, sr: int, thresh_dbfs: float, max_ms: int):
30
  x = np.asarray(x, dtype=np.float32)
31
+ thresh = 10.0 ** (thresh_dbfs / 20.0)
32
+ max_n = int(sr * max(0, max_ms) / 1000)
33
  cut = 0
34
+ for i in range(min(x.size, max_n)):
35
+ if abs(x[i]) > thresh: cut = i; break
36
+ if i == min(x.size, max_n) - 1: cut = i
 
 
 
37
  return x[cut:], int(round(cut * 1000 / sr))
38
 
39
  def _tempo_boost(x: np.ndarray, sr: int, speed: float):
40
+ if not (speed and speed > 1.01): return x
 
 
41
  up_sr = int(round(sr * speed))
42
+ return _resample_linear(_resample_linear(x, sr, up_sr), up_sr, sr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  def _float_to_wav_bytes(x: np.ndarray, sr: int) -> bytes:
 
45
  x = np.clip(np.asarray(x, dtype=np.float32), -1.0, 1.0)
46
  i16 = (x * 32767.0).astype(np.int16)
47
  buf = io.BytesIO()
48
  with wave.open(buf, "wb") as wf:
49
+ wf.setnchannels(1); wf.setsampwidth(2); wf.setframerate(int(sr))
 
 
50
  wf.writeframes(i16.tobytes())
51
  return buf.getvalue()
52
 
53
  class EndpointHandler:
54
+ def __init__(self, path: str = ""): pass
 
55
 
56
  def __call__(self, data: dict):
57
  try:
58
  text = (data.get("inputs") or data.get("text") or "").strip()
59
+ if not text:
60
+ return {"status_code":400,"headers":{"Content-Type":"text/plain"},"body":"Missing text"}
61
+
62
+ # CSM speaker prefix if absent
63
+ if not text.startswith("["):
64
  text = f"[0]{text}"
65
 
66
+ # generate (model defines its own rate internally)
67
+ inputs = processor(text, add_special_tokens=True)
68
+ # sesame remote code supports output_audio=True
69
  audio = model.generate(**inputs, output_audio=True)
70
+ if hasattr(audio, "cpu"): # torch tensor
71
  audio = audio.detach().cpu().float().numpy()
72
+ audio = np.asarray(audio, dtype=np.float32)
73
 
74
+ # trim + (optional) tempo boost
75
+ audio, _ = _trim_leading_silence(audio, TARGET_SR, TRIM_DBFS, TRIM_MAX_MS)
76
+ if SPEED_MULTIPLIER and SPEED_MULTIPLIER > 1.01:
77
+ audio = _tempo_boost(audio, TARGET_SR, SPEED_MULTIPLIER)
 
 
 
 
78
 
79
+ # normalize gentle
80
+ peak = float(np.max(np.abs(audio))) or 1.0
81
+ if peak > 0: audio = (audio / peak) * 0.85
82
 
83
+ wav_b64 = base64.b64encode(_float_to_wav_bytes(audio, TARGET_SR)).decode("ascii")
84
  return {
85
  "status_code": 200,
 
86
  "headers": {"Content-Type": "audio/wav"},
87
+ "body": wav_b64,
88
+ "isBase64Encoded": True
 
89
  }
90
  except Exception as e:
91
+ return {"status_code":500,"headers":{"Content-Type":"text/plain"},"body":f"CSM error: {e}"}