Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| """ | |
| Kyutai STT GPU Service Moshi v4 | |
| Official moshi-server implementation with web interface and protocol bridge | |
| """ | |
| import os | |
| import sys | |
| import time | |
| import asyncio | |
| import logging | |
| import subprocess | |
| from pathlib import Path | |
| from typing import Optional, Dict, Any | |
| import gradio as gr | |
| import websockets | |
| import msgpack | |
| import httpx | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
| from fastapi.responses import JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| import uvicorn | |
# Root logging configuration: INFO and above, each record stamped with time,
# logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger used by the proxy, the endpoints and main() below.
logger = logging.getLogger(__name__)
class MoshiServerProxy:
    """Proxy to communicate with the official moshi-server.

    Holds the moshi-server endpoint URLs and offers async helpers that probe
    its HTTP ``/health`` endpoint.

    Args:
        host: Host on which moshi-server listens (default ``"localhost"``).
        port: Port on which moshi-server listens (default ``8080``).
    """

    def __init__(self, host: str = "localhost", port: int = 8080):
        # WebSocket streaming endpoint the bridge connects to.
        self.moshi_server_url = f"ws://{host}:{port}/api/asr-streaming"
        # HTTP endpoint polled by check_moshi_server_health().
        self.health_url = f"http://{host}:{port}/health"
        # Placeholder for a moshi-server process handle; nothing in this class sets it.
        self.server_process = None

    async def check_moshi_server_health(self, timeout: float = 5.0) -> bool:
        """Return True if moshi-server's health endpoint answers HTTP 200.

        The probe is best-effort: any ordinary error (connection refused,
        timeout, bad response, ...) means "not ready" and yields False.
        Unlike the previous bare ``except:``, ``asyncio.CancelledError``,
        ``KeyboardInterrupt`` and ``SystemExit`` are no longer swallowed, so
        callers can cancel a pending probe.

        Args:
            timeout: Per-request timeout in seconds handed to httpx.
        """
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(self.health_url, timeout=timeout)
        except Exception as exc:  # best-effort probe; BaseException still propagates
            logger.debug("moshi-server health probe failed: %s", exc)
            return False
        return response.status_code == 200

    async def wait_for_moshi_server(self, max_wait_seconds: int = 60) -> bool:
        """Poll moshi-server until it is healthy or *max_wait_seconds* elapse.

        The budget is measured on the event-loop clock. Slow probes (each may
        take up to its timeout) therefore count against it, instead of
        stretching the wait far beyond ``max_wait_seconds`` as a fixed
        iteration count would.

        Returns:
            True as soon as a health check succeeds, False on timeout.
        """
        logger.info("Waiting for official moshi-server to be ready...")
        loop = asyncio.get_running_loop()
        started = loop.time()
        next_progress_log = 0
        while True:
            elapsed = loop.time() - started
            remaining = max_wait_seconds - elapsed
            if remaining <= 0:
                break
            # Never let a single probe overrun the overall budget.
            if await self.check_moshi_server_health(timeout=min(5.0, remaining)):
                logger.info("β Official moshi-server is ready!")
                return True
            if elapsed >= next_progress_log:
                logger.info("Still waiting for moshi-server... (%ds)", int(elapsed))
                # Next progress line at the following multiple of 10 seconds.
                next_progress_log = (int(elapsed) // 10 + 1) * 10
            remaining = max_wait_seconds - (loop.time() - started)
            await asyncio.sleep(max(0.0, min(1.0, remaining)))
        logger.error("β Timeout waiting for moshi-server to be ready")
        return False
# Single shared proxy through which the handlers below reach moshi-server.
proxy: MoshiServerProxy = MoshiServerProxy()
# ASGI application; main() mounts the Gradio UI on it and serves it with uvicorn.
app: FastAPI = FastAPI(title="Kyutai STT GPU Service Moshi v4")
@app.get("/health")
async def health_check():
    """Health check endpoint, registered as ``GET /health``.

    Without the route decorator this coroutine was never attached to ``app``,
    so the ``/health`` path it advertises did not exist.

    Returns:
        JSONResponse whose ``status`` is ``"healthy"`` when moshi-server's own
        health check passes and ``"starting"`` otherwise. It also carries the
        service name/version and the endpoint paths this service advertises.
    """
    moshi_ready = await proxy.check_moshi_server_health()
    return JSONResponse({
        "status": "healthy" if moshi_ready else "starting",
        "service": "kyutai-stt-moshi-v4",
        "version": "4.0.0",
        "official_moshi_server": moshi_ready,
        "endpoints": {
            "websocket_streaming": "/ws",
            "moshi_streaming": "/api/asr-streaming",
            "health": "/health"
        }
    })
def _moshi_connect_kwargs(headers: Dict[str, str]) -> Dict[str, Any]:
    """Return the keyword argument that passes *headers* to ``websockets.connect``.

    The legacy websockets client takes ``extra_headers``. The newer asyncio
    client (the default ``websockets.connect`` since websockets 14) takes
    ``additional_headers`` instead; it forwards unknown keywords to the event
    loop, where ``extra_headers`` fails with a TypeError. Inspect the
    installed signature so the bridge works with either client.
    """
    import inspect

    try:
        params = inspect.signature(websockets.connect).parameters
    except (TypeError, ValueError):
        params = {}
    if "additional_headers" in params:
        return {"additional_headers": headers}
    return {"extra_headers": headers}


def _decode_pcm_f32le(b64_audio: str) -> list:
    """Decode base64 little-endian float32 PCM into a list of floats.

    Trailing bytes that do not form a whole 4-byte sample are ignored;
    ``struct.unpack`` with an exact-size format would raise on them.

    Raises:
        ValueError: *b64_audio* is not valid base64 (``binascii.Error``).
        TypeError: *b64_audio* is not a str/bytes-like object.
    """
    import base64
    import struct

    raw = base64.b64decode(b64_audio)
    sample_count = len(raw) // 4
    return list(struct.unpack_from(f'<{sample_count}f', raw))


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint (``/ws``) that bridges clients to the official moshi-server.

    Client -> moshi-server:
        ``{"type": "audio", "data": <base64 float32 LE PCM>}`` is forwarded as
        a MessagePack ``{"type": "Audio", "pcm": [...]}`` frame. ``start`` and
        ``stop`` messages are only logged. Malformed audio chunks and
        non-object messages are dropped without ending the session.
    moshi-server -> client:
        ``Word`` frames become ``{"type": "transcription", "result": {...}}``,
        ``Step`` frames become ``{"type": "vad", "active": ...}``, and other
        frame types are ignored.

    The session ends as soon as either direction stops. The other direction
    is then cancelled, so neither side is left waiting forever on a peer that
    has gone away.
    """
    await websocket.accept()
    logger.info("Client connected to bridge WebSocket")
    try:
        # Connect to official moshi-server
        async with websockets.connect(
            proxy.moshi_server_url,
            **_moshi_connect_kwargs({"kyutai-api-key": "bridge-key"}),
        ) as moshi_ws:
            logger.info("Connected to official moshi-server")

            async def client_to_moshi():
                """Forward messages from client to moshi-server"""
                try:
                    async for message in websocket.iter_json():
                        if not isinstance(message, dict):
                            logger.warning(
                                "Ignoring non-object client message (%s)",
                                type(message).__name__,
                            )
                            continue
                        msg_type = message.get("type")
                        if msg_type == "audio":
                            try:
                                pcm = _decode_pcm_f32le(message["data"])
                            except (KeyError, TypeError, ValueError) as e:
                                # Drop one bad chunk rather than ending the session.
                                logger.warning("Dropping malformed audio chunk: %s", e)
                                continue
                            if not pcm:
                                continue
                            # moshi-server expects MessagePack with single-precision floats.
                            packed = msgpack.packb(
                                {"type": "Audio", "pcm": pcm},
                                use_bin_type=True,
                                use_single_float=True,
                            )
                            await moshi_ws.send(packed)
                        elif msg_type == "start":
                            logger.info("Starting streaming session")
                            # Official moshi-server may not need explicit start messages
                        elif msg_type == "stop":
                            logger.info("Stopping streaming session")
                            # Official moshi-server may handle this automatically
                except WebSocketDisconnect:
                    logger.info("Client disconnected")
                except Exception as e:
                    logger.error(f"Error in client_to_moshi: {e}")

            async def moshi_to_client():
                """Forward messages from moshi-server to client"""
                try:
                    async for message in moshi_ws:
                        try:
                            data = msgpack.unpackb(message, raw=False)
                        except Exception as e:
                            logger.error(f"Error unpacking moshi message: {e}")
                            continue
                        if not isinstance(data, dict):
                            continue
                        frame_type = data.get("type")
                        if frame_type == "Word":
                            response = {
                                "type": "transcription",
                                "result": {
                                    "text": data.get("text", ""),
                                    "confidence": data.get("confidence"),
                                    "start_time": data.get("start_time"),
                                    "end_time": data.get("end_time")
                                }
                            }
                        elif frame_type == "Step":
                            # Voice Activity Detection
                            response = {
                                "type": "vad",
                                "active": data.get("active", False)
                            }
                        else:
                            continue
                        # A failed send means the client is gone: let it end this direction.
                        await websocket.send_json(response)
                except Exception as e:
                    logger.error(f"Error in moshi_to_client: {e}")

            # Run both directions concurrently. Stop when either finishes and
            # cancel the other, so neither waits forever on a vanished peer.
            tasks = [
                asyncio.create_task(client_to_moshi()),
                asyncio.create_task(moshi_to_client()),
            ]
            try:
                await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            finally:
                for task in tasks:
                    task.cancel()
                await asyncio.gather(*tasks, return_exceptions=True)
    except Exception as e:
        logger.exception(f"WebSocket bridge error: {e}")
    finally:
        try:
            await websocket.close()
        except (RuntimeError, OSError, WebSocketDisconnect):
            # Client already disconnected or a close frame was already sent.
            pass
def create_gradio_interface():
    """Build the Gradio Blocks page describing this service.

    The page holds a static status line, four buttons, a read-only message
    box, the public WebSocket endpoints, a MessagePack protocol example and
    an architecture summary. No component is wired to a callback here.

    Returns:
        The constructed ``gr.Blocks`` object (not launched).
    """
    endpoint_listing = (
        "wss://pgits-stt-gpu-service-moshi-v4.hf.space/ws (bridge to moshi-server)\n"
        "wss://pgits-stt-gpu-service-moshi-v4.hf.space/api/asr-streaming (direct moshi-server)"
    )
    protocol_example = """
// MessagePack format (official moshi-server)
chunk = {"type": "Audio", "pcm": [float_array]}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
// Response types:
{"type": "Word", "text": "transcribed text", ...}
{"type": "Step", "active": true/false} // VAD
"""
    architecture_notes = """
- **Official moshi-server**: `cargo install --features cuda moshi-server`
- **Proven performance**: 64 streams on L40S, 400 on H100
- **Processing time**: ~125ms for real-time transcription
- **Protocol**: MessagePack (not JSON) with raw PCM audio
- **Model**: kyutai/stt-1b-en_fr (~1B params, 0.5s delay)
"""

    with gr.Blocks(title="Kyutai STT Moshi v4") as interface:
        # Title and tagline.
        for heading in (
            "# π€ Kyutai STT Server Moshi v4",
            "**Official moshi-server implementation** with MessagePack protocol and proven streaming performance.",
        ):
            gr.Markdown(heading)

        with gr.Row():
            gr.Markdown("**Server Status**: Running β ")

        # Control buttons, created left to right in one row.
        with gr.Row():
            connect_btn, disconnect_btn, start_btn, stop_btn = (
                gr.Button("Connect", variant="primary"),
                gr.Button("Disconnect"),
                gr.Button("Start Stream"),
                gr.Button("Stop Stream"),
            )

        gr.Markdown("### Messages:")
        messages = gr.Textbox(label="Server Messages", lines=10, interactive=False)

        gr.Markdown("### WebSocket Endpoints:")
        gr.Code(endpoint_listing)

        gr.Markdown("### Official Protocol Example:")
        gr.Code(protocol_example)

        gr.Markdown("### Architecture:")
        gr.Markdown(architecture_notes)
    return interface
async def main():
    """Main application entry point.

    Builds the Gradio UI, mounts it on the FastAPI ``app`` and serves
    everything with uvicorn on port 7860. Readiness of moshi-server is polled
    in a background task, not before startup. The web server, including
    ``/health``, which reports ``"starting"`` until moshi-server answers, is
    therefore reachable immediately instead of after up to a minute of
    blocking.
    """
    logger.info("π Starting Kyutai STT GPU Service Moshi v4")

    async def _report_moshi_readiness():
        # moshi-server is expected to be started in parallel by something outside this file.
        try:
            if not await proxy.wait_for_moshi_server():
                logger.error("Failed to connect to moshi-server, but continuing with web interface")
        except Exception:
            logger.exception("Error while waiting for moshi-server")

    # Keep a reference so the task is not garbage-collected while pending.
    readiness_task = asyncio.create_task(_report_moshi_readiness())

    # Create Gradio interface
    interface = create_gradio_interface()
    # Mount Gradio on FastAPI
    gradio_app = gr.mount_gradio_app(app, interface, path="/")
    # Start server
    config = uvicorn.Config(
        gradio_app,
        host="0.0.0.0",
        port=7860,
        log_level="info"
    )
    server = uvicorn.Server(config)
    logger.info("π Starting web server on port 7860...")
    try:
        await server.serve()
    finally:
        # Stop polling once the web server has shut down.
        readiness_task.cancel()
if __name__ == "__main__":
    # Script entry point: build the coroutine, then let asyncio create an
    # event loop, run it to completion and close the loop.
    entry_coro = main()
    asyncio.run(entry_coro)