# Peter Michael Gits
# Fix d_model and Gradio language errors in moshi-v4 configuration
# 617dcd9
#!/usr/bin/env python3
"""
Kyutai STT GPU Service Moshi v4
Official moshi-server implementation with web interface and protocol bridge
"""
import asyncio
import base64
import logging
import os
import struct
import subprocess
import sys
import time
from pathlib import Path
from typing import Any, Dict, Optional

import gradio as gr
import httpx
import msgpack
import uvicorn
import websockets
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
# Logging setup: INFO and above, each record prefixed with timestamp,
# logger name and severity.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
class MoshiServerProxy:
    """Handle for talking to the official moshi-server.

    Stores the streaming WebSocket URL used by the bridge and offers health
    polling against the server's HTTP ``/health`` endpoint.

    Args:
        host: Host name moshi-server is reachable on.
        port: TCP port moshi-server is reachable on.
    """

    def __init__(self, host: str = "localhost", port: int = 8080) -> None:
        # Both URLs are derived from the same host/port so they cannot drift
        # apart; the defaults reproduce the previously hard-coded values.
        self.moshi_server_url = f"ws://{host}:{port}/api/asr-streaming"
        self.health_url = f"http://{host}:{port}/health"
        # Never assigned inside this class; kept for code that sets or reads it.
        self.server_process = None

    async def check_moshi_server_health(self) -> bool:
        """Return True when moshi-server answers ``/health`` with HTTP 200.

        Any request failure (connection refused, timeout, ...) is reported as
        "not healthy" rather than raised.
        """
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(self.health_url, timeout=5.0)
        except Exception as exc:
            # `Exception`, not a bare `except`: task cancellation and
            # KeyboardInterrupt must propagate instead of reading as "down".
            logger.debug("moshi-server health check failed: %s", exc)
            return False
        return response.status_code == 200

    async def wait_for_moshi_server(self, max_wait_seconds: int = 60) -> bool:
        """Poll the health endpoint until the server is ready.

        Makes up to ``max_wait_seconds`` attempts with a one-second pause
        after each failed attempt, so the real wall-clock limit is
        approximate (each attempt may itself take up to its request timeout).

        Args:
            max_wait_seconds: Maximum number of one-second polling rounds.

        Returns:
            True as soon as a health check succeeds, False after the last
            round fails.
        """
        logger.info("Waiting for official moshi-server to be ready...")
        for i in range(max_wait_seconds):
            if await self.check_moshi_server_health():
                logger.info("βœ… Official moshi-server is ready!")
                return True
            # Progress line every tenth round to avoid flooding the log.
            if i % 10 == 0:
                logger.info(f"Still waiting for moshi-server... ({i}s)")
            await asyncio.sleep(1)
        logger.error("❌ Timeout waiting for moshi-server to be ready")
        return False
# Global proxy instance: module-level singleton used by the /health route,
# the /ws bridge and main().
proxy = MoshiServerProxy()
# FastAPI app: the HTTP and WebSocket routes are registered on it through the
# decorators that follow, and main() mounts the Gradio UI onto it.
app = FastAPI(title="Kyutai STT GPU Service Moshi v4")
@app.get("/health")
async def health_check():
    """Report service status, including whether moshi-server is reachable.

    The ``status`` field is ``"healthy"`` when the moshi-server health probe
    succeeds and ``"starting"`` otherwise; the probe result is also exposed
    verbatim as ``official_moshi_server``.
    """
    is_up = await proxy.check_moshi_server_health()

    if is_up:
        state = "healthy"
    else:
        state = "starting"

    payload = {
        "status": state,
        "service": "kyutai-stt-moshi-v4",
        "version": "4.0.0",
        "official_moshi_server": is_up,
        "endpoints": {
            "websocket_streaming": "/ws",
            "moshi_streaming": "/api/asr-streaming",
            "health": "/health",
        },
    }
    return JSONResponse(payload)
def _audio_message_to_moshi(message: Dict[str, Any]) -> Dict[str, Any]:
    """Convert one bridge ``audio`` JSON message into a moshi-server ``Audio`` message.

    ``message["data"]`` is base64 text whose decoded bytes are read as
    little-endian 32-bit floats.

    Raises:
        KeyError: ``data`` is missing.
        TypeError, ValueError: ``data`` is not valid base64 text.
    """
    audio_bytes = base64.b64decode(message["data"])
    num_floats = len(audio_bytes) // 4
    # unpack_from reads exactly num_floats samples and tolerates a trailing
    # partial sample, where struct.unpack would raise on the size mismatch.
    audio_floats = struct.unpack_from(f"<{num_floats}f", audio_bytes)
    return {"type": "Audio", "pcm": list(audio_floats)}


def _moshi_event_to_client(data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Convert a decoded moshi-server event into the bridge's JSON message.

    Returns:
        A ``transcription`` message for ``Word`` events, a ``vad`` message for
        ``Step`` events, and None for every other event type (nothing is
        forwarded for those).
    """
    event_type = data.get("type")
    if event_type == "Word":
        return {
            "type": "transcription",
            "result": {
                "text": data.get("text", ""),
                "confidence": data.get("confidence"),
                "start_time": data.get("start_time"),
                "end_time": data.get("end_time"),
            },
        }
    if event_type == "Step":
        return {"type": "vad", "active": data.get("active", False)}
    return None


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint that bridges JSON clients to the official moshi-server.

    Client -> server: ``{"type": "audio", "data": <base64 float32 LE PCM>}``
    is re-encoded as a MessagePack ``Audio`` message; ``start`` / ``stop``
    messages are only logged.
    Server -> client: ``Word`` and ``Step`` events are forwarded as JSON
    ``transcription`` and ``vad`` messages.

    The bridge is torn down as soon as either direction ends, and the client
    socket is closed on every exit path.
    """
    await websocket.accept()
    logger.info("Client connected to bridge WebSocket")

    try:
        # Connect to official moshi-server.
        # NOTE(review): recent `websockets` releases name this keyword
        # `additional_headers` — confirm against the pinned version.
        async with websockets.connect(
            proxy.moshi_server_url,
            extra_headers={"kyutai-api-key": "bridge-key"},
        ) as moshi_ws:
            logger.info("Connected to official moshi-server")

            async def client_to_moshi():
                """Forward messages from client to moshi-server."""
                try:
                    async for message in websocket.iter_json():
                        kind = message.get("type")
                        if kind == "audio":
                            try:
                                moshi_msg = _audio_message_to_moshi(message)
                            except (KeyError, TypeError, ValueError, struct.error) as e:
                                # One bad chunk must not end the whole stream.
                                logger.warning("Skipping malformed audio message: %s", e)
                                continue
                            # Send as MessagePack
                            packed = msgpack.packb(
                                moshi_msg, use_bin_type=True, use_single_float=True
                            )
                            await moshi_ws.send(packed)
                        elif kind == "start":
                            # Official moshi-server may not need explicit start messages
                            logger.info("Starting streaming session")
                        elif kind == "stop":
                            # Official moshi-server may handle this automatically
                            logger.info("Stopping streaming session")
                except WebSocketDisconnect:
                    logger.info("Client disconnected")
                except Exception as e:
                    logger.error(f"Error in client_to_moshi: {e}")

            async def moshi_to_client():
                """Forward messages from moshi-server to client."""
                try:
                    async for message in moshi_ws:
                        try:
                            data = msgpack.unpackb(message, raw=False)
                            response = _moshi_event_to_client(data)
                        except Exception as e:
                            # Undecodable event: skip it, keep streaming.
                            logger.error(f"Error unpacking moshi message: {e}")
                            continue
                        if response is not None:
                            # A failed send means the client is gone; let it
                            # end this direction instead of retrying forever.
                            await websocket.send_json(response)
                except Exception as e:
                    logger.error(f"Error in moshi_to_client: {e}")

            # Run both directions concurrently, but stop the bridge when the
            # first one finishes so neither connection is left dangling.
            tasks = [
                asyncio.ensure_future(client_to_moshi()),
                asyncio.ensure_future(moshi_to_client()),
            ]
            try:
                await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            finally:
                for task in tasks:
                    task.cancel()
                # Collect results/cancellations so nothing is left un-awaited.
                await asyncio.gather(*tasks, return_exceptions=True)
    except Exception as e:
        logger.error(f"WebSocket bridge error: {e}")
    finally:
        try:
            await websocket.close()
        except (RuntimeError, WebSocketDisconnect) as e:
            # The socket was already closed or the client already left.
            logger.debug("Client WebSocket already closed: %s", e)
def create_gradio_interface():
    """Build and return the Gradio Blocks page describing the service.

    The page is static: a title, a status line, a row of four buttons, a
    read-only message box, and reference sections listing the WebSocket
    endpoints, a protocol example and architecture notes. No event handlers
    are attached to the buttons here.
    """
    endpoint_listing = (
        "wss://pgits-stt-gpu-service-moshi-v4.hf.space/ws (bridge to moshi-server)\n"
        "wss://pgits-stt-gpu-service-moshi-v4.hf.space/api/asr-streaming (direct moshi-server)"
    )
    # Text inside the triple-quoted literals is deliberately flush-left so the
    # rendered content carries no leading whitespace.
    protocol_example = """
// MessagePack format (official moshi-server)
chunk = {"type": "Audio", "pcm": [float_array]}
msg = msgpack.packb(chunk, use_bin_type=True, use_single_float=True)
// Response types:
{"type": "Word", "text": "transcribed text", ...}
{"type": "Step", "active": true/false} // VAD
"""
    architecture_notes = """
- **Official moshi-server**: `cargo install --features cuda moshi-server`
- **Proven performance**: 64 streams on L40S, 400 on H100
- **Processing time**: ~125ms for real-time transcription
- **Protocol**: MessagePack (not JSON) with raw PCM audio
- **Model**: kyutai/stt-1b-en_fr (~1B params, 0.5s delay)
"""

    with gr.Blocks(title="Kyutai STT Moshi v4") as interface:
        # Header
        gr.Markdown("# 🎀 Kyutai STT Server Moshi v4")
        gr.Markdown("**Official moshi-server implementation** with MessagePack protocol and proven streaming performance.")

        with gr.Row():
            gr.Markdown("**Server Status**: Running βœ…")

        # Control buttons, laid out side by side.
        with gr.Row():
            gr.Button("Connect", variant="primary")
            gr.Button("Disconnect")
            gr.Button("Start Stream")
            gr.Button("Stop Stream")

        gr.Markdown("### Messages:")
        gr.Textbox(label="Server Messages", lines=10, interactive=False)

        # Reference sections
        gr.Markdown("### WebSocket Endpoints:")
        gr.Code(endpoint_listing)

        gr.Markdown("### Official Protocol Example:")
        gr.Code(protocol_example)

        gr.Markdown("### Architecture:")
        gr.Markdown(architecture_notes)

    return interface
async def main():
    """Start the service: wait for moshi-server, then serve the web app.

    A moshi-server that never becomes ready is logged as an error but does
    not stop the web interface from starting. The Gradio UI is mounted on the
    FastAPI app at ``/`` and served by uvicorn on ``0.0.0.0:7860``.
    """
    logger.info("πŸš€ Starting Kyutai STT GPU Service Moshi v4")

    # moshi-server is expected to be starting in parallel with this process.
    moshi_ready = await proxy.wait_for_moshi_server()
    if not moshi_ready:
        logger.error("Failed to connect to moshi-server, but continuing with web interface")

    # Build the UI and attach it to the existing FastAPI application.
    gradio_app = gr.mount_gradio_app(app, create_gradio_interface(), path="/")

    server = uvicorn.Server(
        uvicorn.Config(gradio_app, host="0.0.0.0", port=7860, log_level="info")
    )
    logger.info("🌐 Starting web server on port 7860...")
    await server.serve()


if __name__ == "__main__":
    asyncio.run(main())