# fastapi-asr-app / main.py
# Last commit by ogflash (d3a18f1): "Fix: added webui for asr api"
import os
import torch
import torchaudio
from transformers import AutoModel
from pydub import AudioSegment
import aiofiles
import uuid
from fastapi import FastAPI, HTTPException, File, UploadFile
from starlette.concurrency import run_in_threadpool
from starlette.staticfiles import StaticFiles # <-- NEW IMPORT
from starlette.responses import HTMLResponse, RedirectResponse # <-- NEW IMPORT
# -----------------------------------------------------------
# 1. FastAPI App Instance
# -----------------------------------------------------------
# Application object; the startup hook and the route handlers below register on it via decorators.
app = FastAPI()
# -----------------------------------------------------------
# 2. Global Variables (for model and directories)
# These will be initialized during startup
# -----------------------------------------------------------
# Loaded ASR model; stays None until startup_event() assigns it.
ASR_MODEL = None
# torch.device the model and input tensors are moved to; stays None until startup_event() assigns it.
DEVICE = None
# Raw uploads are written here and removed again when the request finishes.
UPLOAD_DIR = "./uploads"
# Intermediate WAV files produced by the conversion step; removed when the request finishes.
CONVERTED_AUDIO_DIR = "./converted_audio_temp"
# A "<uuid>_rnnt.txt" transcription file is written here for each successful request.
TRANSCRIPTION_OUTPUT_DIR = "./transcriptions"
TARGET_SAMPLE_RATE = 16000 # Sample rate uploads are converted to before inference (required by the new model)
# -----------------------------------------------------------
# 3. Startup Event: Load Model and Create Directories
# This runs once when the FastAPI application starts
# -----------------------------------------------------------
@app.on_event("startup")
async def startup_event():
    """Create the working directories and load the ASR model at application startup.

    Populates the module-level ``DEVICE`` and ``ASR_MODEL`` globals that the
    transcription endpoint reads.
    """
    global ASR_MODEL, DEVICE

    # Working directories used by the transcription endpoint.
    for directory in (UPLOAD_DIR, CONVERTED_AUDIO_DIR, TRANSCRIPTION_OUTPUT_DIR):
        os.makedirs(directory, exist_ok=True)

    # Use CUDA when it is available, otherwise fall back to the CPU.
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    DEVICE = torch.device(device_name)

    # Load the pretrained model, move it to the chosen device and switch it to eval mode.
    ASR_MODEL = AutoModel.from_pretrained(
        "ai4bharat/indic-conformer-600m-multilingual",
        trust_remote_code=True,
    )
    ASR_MODEL.to(DEVICE)
    ASR_MODEL.eval()
# -----------------------------------------------------------
# 4. Mount Static Files and Define Root Endpoint (NEW)
# -----------------------------------------------------------
# Mount the 'static' directory to serve HTML, CSS, JS files
# This makes files like 'static/index.html' accessible at /static/index.html
# NOTE(review): "static" is a relative path, so it depends on the process working directory;
# presumably StaticFiles validates the directory when constructed here at import time — confirm.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Define a root endpoint that serves your main HTML page
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve ``static/index.html`` at the root URL.

    Returns:
        An ``HTMLResponse`` with the page contents, or a 404 ``HTMLResponse``
        carrying an error message when the file does not exist.
    """
    index_path = "static/index.html"
    try:
        with open(index_path, "r", encoding="utf-8") as page:
            html = page.read()
    except FileNotFoundError:
        # Fallback for deployments where the static assets were not copied alongside the app.
        return HTMLResponse("<h1>Error: index.html not found!</h1><p>Please ensure 'static/index.html' exists in your project.</p>", status_code=404)
    return HTMLResponse(content=html)
# -----------------------------------------------------------
# 5. Helper Function: Audio Conversion (Existing Code)
# This function performs the actual audio conversion (blocking operation)
# -----------------------------------------------------------
def _convert_audio_sync(input_path: str, output_path: str, target_sample_rate: int = TARGET_SAMPLE_RATE, channels: int = 1):
    """Re-encode the audio at ``input_path`` as a WAV file at ``output_path``.

    Blocking helper; the transcription endpoint runs it through
    ``run_in_threadpool``.

    Args:
        input_path: Path of the source audio file.
        output_path: Path the converted WAV file is written to.
        target_sample_rate: Frame rate applied to the audio before export.
        channels: Channel count applied to the audio before export.
    """
    decoded = AudioSegment.from_file(input_path)
    resampled = decoded.set_frame_rate(target_sample_rate)
    remixed = resampled.set_channels(channels)
    remixed.export(output_path, format="wav")
# -----------------------------------------------------------
# 6. Main API Endpoint: Handle File Upload and Transcription (Existing Code)
# -----------------------------------------------------------
@app.post('/transcribefile/')
async def transcribe_file(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file and return the RNNT transcription.

    The upload is streamed to ``UPLOAD_DIR``, converted to WAV at
    ``TARGET_SAMPLE_RATE``, run through ``ASR_MODEL``, and the resulting text
    is written to ``TRANSCRIPTION_OUTPUT_DIR`` and returned. The uploaded and
    converted files are removed when the request finishes.

    Args:
        file: The uploaded audio file.

    Returns:
        A dict of the form ``{"rnnt_transcription": <model output>}``.

    Raises:
        HTTPException: 400 when the upload is empty or could not be saved,
            422 when the error text indicates the audio could not be parsed,
            500 for any other failure.
    """
    # 6.1. Generate unique filenames for uploaded and converted files
    unique_id = str(uuid.uuid4())
    # The client-supplied filename is untrusted: keep only its last path component so it
    # cannot point outside UPLOAD_DIR, and fall back to a fixed name when it is missing.
    client_name = (file.filename or "").replace("\\", "/")
    safe_name = os.path.basename(client_name) or "upload"
    uploaded_file_path = os.path.join(UPLOAD_DIR, f"{unique_id}_{safe_name}")
    converted_audio_path = os.path.join(CONVERTED_AUDIO_DIR, f"{unique_id}.wav")
    transcription_output_path_rnnt = os.path.join(TRANSCRIPTION_OUTPUT_DIR, f"{unique_id}_rnnt.txt")
    try:
        # 6.2. Asynchronously save the uploaded file in 1 MiB chunks
        async with aiofiles.open(uploaded_file_path, "wb") as f:
            while content := await file.read(1024 * 1024):
                await f.write(content)

        # 6.3. Handle potential file upload errors (e.g., empty file)
        if not os.path.exists(uploaded_file_path) or os.path.getsize(uploaded_file_path) == 0:
            raise HTTPException(status_code=400, detail="Uploaded file is empty or could not be saved.")

        # 6.4. Convert audio (run blocking operation in a thread pool)
        # This is where pydub uses ffmpeg
        await run_in_threadpool(
            _convert_audio_sync, uploaded_file_path, converted_audio_path
        )

        # 6.5. Load and preprocess the converted audio for the model
        wav, sr = torchaudio.load(converted_audio_path)
        wav = torch.mean(wav, dim=0, keepdim=True)  # Convert to mono if stereo
        if sr != TARGET_SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=TARGET_SAMPLE_RATE)
            wav = resampler(wav)
        wav = wav.to(DEVICE)  # Move tensor to the same device as the model

        # 6.6. Perform transcription using RNNT decoding
        # NOTE(review): "ml" is presumably the language code expected by the model — confirm.
        with torch.no_grad():  # Disable gradient calculation for inference
            transcription_rnnt = ASR_MODEL(wav, "ml", "rnnt")

        # 6.7. Save transcription (optional)
        async with aiofiles.open(transcription_output_path_rnnt, "w", encoding="utf-8") as f:
            await f.write(transcription_rnnt)

        # 6.8. Return the transcription
        return {
            "rnnt_transcription": transcription_rnnt
        }
    except HTTPException:
        # Deliberate HTTP errors (such as the 400 above) must reach the client unchanged
        # instead of being rewrapped as a 500 by the generic handler below.
        raise
    except Exception as e:
        # 6.9. Centralized error handling
        print(f"Error during transcription process: {e}")
        # Specific error for file not found or corrupted during conversion
        if "File not found" in str(e) or "Error parsing" in str(e):
            raise HTTPException(status_code=422, detail=f"Could not process audio file: {e}") from e
        # General server error
        raise HTTPException(status_code=500, detail=f"An internal server error occurred: {e}") from e
    finally:
        # 6.10. Clean up temporary files
        await file.close()  # Close the UploadFile's underlying file handle
        if os.path.exists(uploaded_file_path):
            os.remove(uploaded_file_path)
        if os.path.exists(converted_audio_path):
            os.remove(converted_audio_path)