import os, asyncio, json, base64, time, tempfile, io from typing import Optional, Dict, Any from contextlib import asynccontextmanager import torch, numpy as np, uvicorn from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Query, File, UploadFile, Form, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from loguru import logger import librosa from pydantic import BaseModel # --- Moshi Streaming Imports --- from moshi.models import loaders, MimiModel, LMModel, LMGen import sentencepiece # --- OpenAI Whisper API Compatible Response Models --- class TranscriptionWord(BaseModel): word: str start: float end: float class TranscriptionSegment(BaseModel): id: int seek: float start: float end: float text: str tokens: list[int] = [] temperature: float = 0.0 avg_logprob: float = 0.0 compression_ratio: float = 0.0 no_speech_prob: float = 0.0 words: Optional[list[TranscriptionWord]] = None class TranscriptionResponse(BaseModel): text: str task: str = "transcribe" language: str = "en" duration: float segments: Optional[list[TranscriptionSegment]] = None # --- Core Streaming Engine --- class StreamingKyutaiEngine: def __init__(self, device: str): self.device = device logger.info("🚀 Loading Moshi streaming model components...") checkpoint_info = loaders.CheckpointInfo.from_hf_repo("kyutai/stt-1b-en_fr") self.mimi: MimiModel = checkpoint_info.get_mimi(device=device) self.text_tokenizer: sentencepiece.SentencePieceProcessor = checkpoint_info.get_text_tokenizer() self.lm_model: LMModel = checkpoint_info.get_moshi(device=device) self.frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate) self.sample_rate = self.mimi.sample_rate self._model_loaded = True # --- Lock to protect the stateful model --- self.lock = asyncio.Lock() logger.info(f"🎉 Moshi streaming engine loaded on {self.device}") logger.info(f"📊 Sample rate: {self.sample_rate}Hz, Frame size: {self.frame_size}") async def transcribe_audio_file(self, audio_data: np.ndarray, sample_rate: int = None) -> tuple[str, float]: """Transcribe audio file and return (text, duration)""" async with self.lock: try: # Resample if necessary if sample_rate and sample_rate != self.sample_rate: audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=self.sample_rate) duration = len(audio_data) / self.sample_rate # Create a new generator and set up the streaming context lm_gen = LMGen(self.lm_model, temp=0, temp_text=0, use_sampling=False) transcription_text = "" with self.mimi.streaming(batch_size=1), lm_gen.streaming(batch_size=1): first_frame = True # Process audio in chunks for i in range(0, len(audio_data), self.frame_size): chunk = audio_data[i:i + self.frame_size] if len(chunk) == self.frame_size: writable_chunk = chunk.copy() in_pcms = torch.from_numpy(writable_chunk).to(self.device).unsqueeze(0).unsqueeze(0) codes = self.mimi.encode(in_pcms) if first_frame: lm_gen.step(codes) first_frame = False tokens = lm_gen.step(codes) if tokens is None: continue text_id = tokens[0, 0].cpu().item() if text_id not in [0, 3]: text_fragment = self.text_tokenizer.id_to_piece(text_id) clean_fragment = text_fragment.replace("▁", " ") transcription_text += clean_fragment return transcription_text.strip(), duration except Exception as e: logger.error(f"Error transcribing audio: {e}") return "", 0.0 # Global engine instance stt_engine: Optional[StreamingKyutaiEngine] = None @asynccontextmanager async def lifespan(app: FastAPI): """Modern FastAPI lifespan management""" # Startup global stt_engine device = "cuda" if torch.cuda.is_available() else "cpu" stt_engine = StreamingKyutaiEngine(device=device) logger.info("✅ Kyutai OpenAI Whisper API Compatible service is ready.") yield # Shutdown (if needed) logger.info("🔄 Shutting down Kyutai service...") # --- FastAPI App Setup with modern lifespan --- app = FastAPI( title="Kyutai OpenAI Whisper API Compatible STT", version="3.0.0", lifespan=lifespan ) app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]) @app.get("/health") async def health_check(): is_ready = stt_engine and stt_engine._model_loaded if is_ready: return {"status": "healthy", "model_loaded": True, "api_format": "openai_whisper_compatible"} else: return {"status": "unhealthy", "model_loaded": False}, 503 # --- OpenAI Whisper API Compatible Endpoints --- @app.post("/v1/audio/transcriptions", response_model=TranscriptionResponse) async def create_transcription( file: UploadFile = File(...), model: str = Form("whisper-1"), language: Optional[str] = Form(None), prompt: Optional[str] = Form(None), response_format: str = Form("json"), temperature: float = Form(0.0), timestamp_granularities: Optional[str] = Form(None) ): """ OpenAI Whisper API Compatible transcription endpoint Compatible with: - OpenAI official clients - Groq API clients - Any Whisper API client Just change the base_url to point here! """ if not stt_engine: raise HTTPException(status_code=503, detail="STT engine not ready") try: # Read the uploaded file audio_content = await file.read() # Load audio using librosa (supports many formats) audio_data, original_sr = librosa.load(io.BytesIO(audio_content), sr=None, mono=True) logger.info(f"Processing audio file: {file.filename}, duration: {len(audio_data)/original_sr:.2f}s") # Transcribe using Kyutai engine transcription_text, duration = await stt_engine.transcribe_audio_file(audio_data, original_sr) # Create OpenAI-compatible response if response_format == "text": from fastapi.responses import PlainTextResponse return PlainTextResponse(content=transcription_text, media_type="text/plain") elif response_format == "srt": # Simple SRT format srt_content = f"1\n00:00:00,000 --> {int(duration//60):02d}:{int(duration%60):02d},{int((duration%1)*1000):03d}\n{transcription_text}\n" from fastapi.responses import PlainTextResponse return PlainTextResponse(content=srt_content, media_type="text/plain") elif response_format == "vtt": # Simple VTT format vtt_content = f"WEBVTT\n\n00:00:00.000 --> {int(duration//60):02d}:{int(duration%60):02d}.{int((duration%1)*1000):03d}\n{transcription_text}\n" from fastapi.responses import PlainTextResponse return PlainTextResponse(content=vtt_content, media_type="text/plain") else: # Default JSON response (OpenAI format) segments = [] if transcription_text: segments = [ TranscriptionSegment( id=0, seek=0.0, start=0.0, end=duration, text=transcription_text, tokens=[], temperature=temperature, avg_logprob=0.0, compression_ratio=1.0, no_speech_prob=0.0 ) ] return TranscriptionResponse( text=transcription_text, task="transcribe", language=language or "en", duration=duration, segments=segments if timestamp_granularities else None ) except Exception as e: logger.error(f"Transcription error: {e}") raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}") @app.post("/v1/audio/translations", response_model=TranscriptionResponse) async def create_translation( file: UploadFile = File(...), model: str = Form("whisper-1"), prompt: Optional[str] = Form(None), response_format: str = Form("json"), temperature: float = Form(0.0) ): """ OpenAI Whisper API Compatible translation endpoint Note: Kyutai model outputs English, so this behaves the same as transcription """ # For now, treat translation the same as transcription since Kyutai outputs English return await create_transcription( file=file, model=model, language="en", # Force English for translation prompt=prompt, response_format=response_format, temperature=temperature ) # --- STREAMING WEBSOCKET ENDPOINTS (FULLY IMPLEMENTED) --- @app.websocket("/v1/audio/stream") async def streaming_websocket(websocket: WebSocket): """ Real-time audio streaming endpoint Protocol: 1. Client connects 2. Server sends {"type": "ready", "sample_rate": 24000} 3. Client sends binary audio chunks (PCM float32, 24kHz, mono) 4. Server sends {"type": "transcription", "text": "...", "accumulated": "...", "is_final": false} 5. Client sends {"type": "finalize"} to get final transcription 6. Server sends {"type": "transcription", "text": "...", "is_final": true} Commands: - {"type": "finalize"} - Get final transcription - {"type": "reset"} - Clear transcription buffer - {"type": "stop"} - Close connection """ await websocket.accept() if not stt_engine: await websocket.close(code=1011, reason="STT engine not ready") return logger.info("🔌 New streaming connection established") async with stt_engine.lock: try: # Initialize streaming context lm_gen = LMGen(stt_engine.lm_model, temp=0, temp_text=0, use_sampling=False) transcription_buffer = "" audio_buffer = np.array([], dtype=np.float32) first_frame = True frames_processed = 0 with stt_engine.mimi.streaming(batch_size=1), lm_gen.streaming(batch_size=1): # Send ready signal await websocket.send_json({ "type": "ready", "sample_rate": stt_engine.sample_rate, "frame_size": stt_engine.frame_size }) logger.info(f"✅ Sent ready signal (sample_rate: {stt_engine.sample_rate}Hz)") while True: try: message = await asyncio.wait_for(websocket.receive(), timeout=30.0) except asyncio.TimeoutError: logger.warning("⏱️ WebSocket timeout - no data received for 30s") break if message["type"] == "websocket.disconnect": logger.info("👋 Client disconnected") break # Handle binary audio data if message["type"] == "websocket.receive" and "bytes" in message: audio_bytes = message["bytes"] # Convert to float32 numpy array audio_chunk = np.frombuffer(audio_bytes, dtype=np.float32) # Add to buffer audio_buffer = np.concatenate([audio_buffer, audio_chunk]) # Process complete frames while len(audio_buffer) >= stt_engine.frame_size: # Extract one frame frame = audio_buffer[:stt_engine.frame_size] audio_buffer = audio_buffer[stt_engine.frame_size:] # Convert to torch tensor in_pcms = torch.from_numpy(frame.copy()).to(stt_engine.device).unsqueeze(0).unsqueeze(0) # Encode audio codes = stt_engine.mimi.encode(in_pcms) # Generate tokens if first_frame: lm_gen.step(codes) first_frame = False frames_processed += 1 continue tokens = lm_gen.step(codes) frames_processed += 1 if tokens is not None: text_id = tokens[0, 0].cpu().item() # Filter special tokens if text_id not in [0, 3]: text_fragment = stt_engine.text_tokenizer.id_to_piece(text_id) clean_fragment = text_fragment.replace("▁", " ") transcription_buffer += clean_fragment # Send progressive transcription await websocket.send_json({ "type": "transcription", "text": clean_fragment, "accumulated": transcription_buffer.strip(), "is_final": False, "frames_processed": frames_processed }) logger.debug(f"📝 Sent fragment: '{clean_fragment}'") # Handle text commands elif message["type"] == "websocket.receive" and "text" in message: try: data = json.loads(message["text"]) if data.get("type") == "finalize": # Send final transcription final_text = transcription_buffer.strip() await websocket.send_json({ "type": "transcription", "text": final_text, "is_final": True, "frames_processed": frames_processed }) logger.info(f"✅ Finalized transcription ({len(final_text)} chars, {frames_processed} frames)") elif data.get("type") == "reset": # Reset transcription buffer transcription_buffer = "" audio_buffer = np.array([], dtype=np.float32) frames_processed = 0 await websocket.send_json({"type": "reset_confirmed"}) logger.info("🔄 Transcription reset") elif data.get("type") == "stop": logger.info("🛑 Client requested stop") break except json.JSONDecodeError: logger.error("❌ Invalid JSON received from client") await websocket.send_json({"type": "error", "message": "Invalid JSON"}) except WebSocketDisconnect: logger.info("🔌 WebSocket disconnected") except Exception as e: logger.error(f"❌ Streaming error: {e}", exc_info=True) try: await websocket.send_json({"type": "error", "message": str(e)}) except: pass finally: try: await websocket.close() except: pass logger.info("🔒 Streaming connection closed") @app.websocket("/v1/realtime") async def openai_realtime_websocket( websocket: WebSocket, model: str = Query(default="kyutai/stt-1b-en_fr") ): """ OpenAI Realtime API Compatible WebSocket endpoint Protocol follows OpenAI's realtime API structure with session management Events: - session.created: Sent on connection - input_audio_buffer.append: Client sends audio (base64 PCM16) - conversation.item.input_audio_transcription.delta: Server sends partial text - input_audio_buffer.commit: Client requests final transcription - conversation.item.input_audio_transcription.completed: Server sends final text - input_audio_buffer.clear: Clear buffers """ await websocket.accept() if not stt_engine: await websocket.close(code=1011, reason="STT engine not ready") return session_id = f"sess_{int(time.time())}_{id(websocket)}" logger.info(f"🔌 New realtime session: {session_id}") # Send session created event await websocket.send_text(json.dumps({ "type": "session.created", "session": { "id": session_id, "model": model, "modalities": ["text", "audio"], "instructions": "Real-time speech-to-text transcription using Kyutai Moshi model", "voice": "kyutai", "input_audio_format": "pcm16", "output_audio_format": "pcm16", "input_audio_transcription": { "model": "kyutai-stt-1b" }, "turn_detection": None, "tools": [], "tool_choice": "auto", "temperature": 0.0, "max_output_tokens": None } })) async with stt_engine.lock: try: lm_gen = LMGen(stt_engine.lm_model, temp=0, temp_text=0, use_sampling=False) transcription_buffer = "" audio_buffer = np.array([], dtype=np.float32) first_frame = True item_id = f"item_{int(time.time())}" with stt_engine.mimi.streaming(batch_size=1), lm_gen.streaming(batch_size=1): while True: try: message = await asyncio.wait_for(websocket.receive(), timeout=30.0) except asyncio.TimeoutError: logger.warning(f"⏱️ Session {session_id} timeout") break if message["type"] == "websocket.disconnect": break # Handle text events (OpenAI format) if message["type"] == "websocket.receive" and "text" in message: try: event = json.loads(message["text"]) if event.get("type") == "input_audio_buffer.append": # Decode base64 audio (PCM16) audio_b64 = event.get("audio", "") audio_bytes = base64.b64decode(audio_b64) # Convert PCM16 to float32 (-1.0 to 1.0) audio_chunk = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0 audio_buffer = np.concatenate([audio_buffer, audio_chunk]) # Process frames while len(audio_buffer) >= stt_engine.frame_size: frame = audio_buffer[:stt_engine.frame_size] audio_buffer = audio_buffer[stt_engine.frame_size:] in_pcms = torch.from_numpy(frame.copy()).to(stt_engine.device).unsqueeze(0).unsqueeze(0) codes = stt_engine.mimi.encode(in_pcms) if first_frame: lm_gen.step(codes) first_frame = False continue tokens = lm_gen.step(codes) if tokens is not None: text_id = tokens[0, 0].cpu().item() if text_id not in [0, 3]: text_fragment = stt_engine.text_tokenizer.id_to_piece(text_id) clean_fragment = text_fragment.replace("▁", " ") transcription_buffer += clean_fragment # Send delta (partial transcription) await websocket.send_text(json.dumps({ "type": "conversation.item.input_audio_transcription.delta", "item_id": item_id, "content_index": 0, "delta": clean_fragment })) logger.debug(f"📝 Sent delta: '{clean_fragment}'") elif event.get("type") == "input_audio_buffer.commit": # Send final transcription final_text = transcription_buffer.strip() await websocket.send_text(json.dumps({ "type": "conversation.item.input_audio_transcription.completed", "item_id": item_id, "content_index": 0, "transcript": final_text })) logger.info(f"✅ Committed transcription: '{final_text}'") transcription_buffer = "" item_id = f"item_{int(time.time())}" # New item for next transcription elif event.get("type") == "input_audio_buffer.clear": # Clear buffers audio_buffer = np.array([], dtype=np.float32) transcription_buffer = "" await websocket.send_text(json.dumps({ "type": "input_audio_buffer.cleared" })) logger.info("🔄 Buffers cleared") elif event.get("type") == "session.update": # Acknowledge session update await websocket.send_text(json.dumps({ "type": "session.updated", "session": event.get("session", {}) })) except json.JSONDecodeError: logger.error("❌ Invalid JSON in realtime event") except Exception as e: logger.error(f"❌ Error processing event: {e}", exc_info=True) await websocket.send_text(json.dumps({ "type": "error", "error": { "type": "processing_error", "message": str(e) } })) except WebSocketDisconnect: logger.info(f"🔌 Realtime session {session_id} disconnected") except Exception as e: logger.error(f"❌ Realtime session error: {e}", exc_info=True) finally: try: await websocket.close() except: pass logger.info(f"🔒 Realtime session {session_id} closed") # --- Models endpoint (OpenAI compatible) --- @app.get("/v1/models") async def list_models(): """OpenAI compatible models endpoint""" return { "object": "list", "data": [ { "id": "whisper-1", "object": "model", "created": 1677532384, "owned_by": "kyutai", "permission": [], "root": "whisper-1", "parent": None }, { "id": "kyutai/stt-1b-en_fr", "object": "model", "created": 1677532384, "owned_by": "kyutai", "permission": [], "root": "kyutai/stt-1b-en_fr", "parent": None } ] } # --- Main Execution --- if __name__ == "__main__": port = int(os.getenv("PORT", 7860)) host = os.getenv("HOST", "0.0.0.0") logger.info(f"🚀 Starting Kyutai OpenAI Whisper API Compatible service on {host}:{port}") logger.info(f"📋 Endpoints:") logger.info(f" - POST http://{host}:{port}/v1/audio/transcriptions (File upload)") logger.info(f" - WS ws://{host}:{port}/v1/audio/stream (Real-time streaming)") logger.info(f" - WS ws://{host}:{port}/v1/realtime (OpenAI Realtime API)") logger.info(f" - GET http://{host}:{port}/health (Health check)") uvicorn.run(app, host=host, port=port, log_level="info")