# fastapi-asr-app / main.py
# Uploaded by: ogflash
# Commit 870cc4e: "Deploying FastAPI ASR app with custom Dockerfile"
# (Header copied from the hosting site's file viewer: raw | history | blame, 5.87 kB)
import logging
import os
import threading
import uuid  # For generating unique filenames

import aiofiles  # For asynchronous file operations
import torch
import torchaudio
from fastapi import FastAPI, HTTPException, File, UploadFile
from pydub import AudioSegment  # Requires ffmpeg installed on system
from pydub.exceptions import CouldntDecodeError
from starlette.concurrency import run_in_threadpool  # For running blocking code in background thread
from transformers import AutoModel  # For the new model
# -----------------------------------------------------------
# 1. FastAPI App Instance
# The ASGI application object. The @app.on_event and @app.post
# decorators further down register the startup hook and the
# transcription route on it.
# -----------------------------------------------------------
app = FastAPI()
# -----------------------------------------------------------
# 2. Global Variables (for model and directories)
# ASR_MODEL and DEVICE stay None until startup_event() fills them in.
# The working directories default to paths relative to the current
# working directory. Each can be overridden with an environment variable
# (ASR_UPLOAD_DIR, ASR_CONVERTED_AUDIO_DIR, ASR_TRANSCRIPTION_OUTPUT_DIR),
# e.g. when the container's working directory is not writable.
# -----------------------------------------------------------
ASR_MODEL = None  # populated by startup_event()
DEVICE = None  # populated by startup_event()
UPLOAD_DIR = os.environ.get("ASR_UPLOAD_DIR", "./uploads")
CONVERTED_AUDIO_DIR = os.environ.get("ASR_CONVERTED_AUDIO_DIR", "./converted_audio_temp")
TRANSCRIPTION_OUTPUT_DIR = os.environ.get("ASR_TRANSCRIPTION_OUTPUT_DIR", "./transcriptions")
TARGET_SAMPLE_RATE = 16000  # Required sample rate for the new model
# -----------------------------------------------------------
# 3. Startup Event: Load Model and Create Directories
# Registered on the app; FastAPI invokes it once at application start.
# -----------------------------------------------------------
@app.on_event("startup")
async def startup_event():
    """Create the working directories, pick a device, and load the ASR model.

    Side effects:
        * Ensures UPLOAD_DIR, CONVERTED_AUDIO_DIR and TRANSCRIPTION_OUTPUT_DIR
          exist (no error if they already do).
        * Sets the module-level DEVICE to CUDA when torch reports it is
          available, otherwise CPU.
        * Sets the module-level ASR_MODEL, moves it to DEVICE and switches it
          to evaluation mode.
    """
    global ASR_MODEL, DEVICE

    for directory in (UPLOAD_DIR, CONVERTED_AUDIO_DIR, TRANSCRIPTION_OUTPUT_DIR):
        os.makedirs(directory, exist_ok=True)

    use_cuda = torch.cuda.is_available()
    DEVICE = torch.device("cuda" if use_cuda else "cpu")

    # trust_remote_code=True executes model code shipped with the Hub repo.
    ASR_MODEL = AutoModel.from_pretrained(
        "ai4bharat/indic-conformer-600m-multilingual",
        trust_remote_code=True,
    )
    ASR_MODEL.to(DEVICE)
    ASR_MODEL.eval()  # inference only; no training-mode behavior wanted
# -----------------------------------------------------------
# 4. Helper Function: Audio Conversion
# This function performs the actual audio conversion (blocking operation)
# -----------------------------------------------------------
def _convert_audio_sync(input_path: str, output_path: str, target_sample_rate: int = TARGET_SAMPLE_RATE, channels: int = 1) -> None:
    """Decode ``input_path`` and write it to ``output_path`` as a WAV file.

    The audio is resampled to ``target_sample_rate`` Hz and remixed to
    ``channels`` channels. This call blocks, since pydub may shell out to
    ffmpeg, so async callers should run it in a worker thread (e.g. via
    ``run_in_threadpool``).

    Args:
        input_path: Path of the audio file to decode (any format pydub/ffmpeg reads).
        output_path: Destination path for the WAV output (overwritten).
        target_sample_rate: Output sample rate in Hz.
        channels: Output channel count (1 = mono).

    Raises:
        Any decoding/encoding error from pydub or ffmpeg propagates unchanged.
    """
    audio = AudioSegment.from_file(input_path)
    audio = audio.set_frame_rate(target_sample_rate).set_channels(channels)
    # AudioSegment.export() opens output_path itself and returns that file
    # object. Close it explicitly instead of leaving it to garbage collection,
    # so the handle is released before the caller reads or deletes the file.
    exported = audio.export(output_path, format="wav")
    exported.close()
# -----------------------------------------------------------
# 5. Main API Endpoint: Handle File Upload and Transcription
# -----------------------------------------------------------
_LOGGER = logging.getLogger(__name__)

# Inference now runs in worker threads instead of on the event-loop thread.
# This lock keeps ASR_MODEL calls one-at-a-time, as they effectively were
# before, because the model's thread-safety is not established here.
_INFERENCE_LOCK = threading.Lock()

# Longest file extension (including the dot) kept from a client filename.
_MAX_SUFFIX_LEN = 10


def _safe_upload_suffix(filename) -> str:
    """Return a filesystem-safe extension taken from a client filename, or ""."""
    # Treat backslashes as separators too: some clients send full Windows paths.
    base = os.path.basename((filename or "").replace("\\", "/"))
    suffix = os.path.splitext(base)[1]
    ext = suffix[1:]
    if 0 < len(ext) and len(suffix) <= _MAX_SUFFIX_LEN and ext.isascii() and ext.isalnum():
        return suffix
    return ""


def _transcribe_sync(audio_path: str, language: str = "ml", decoding: str = "rnnt"):
    """Load ``audio_path``, prepare it for ASR_MODEL and return the model output.

    Blocking (file I/O plus model inference), so async callers should use
    ``run_in_threadpool``. ``decoding`` is passed straight to the model;
    this file uses "rnnt", and earlier (commented-out) code also used "ctc".
    """
    wav, sr = torchaudio.load(audio_path)
    wav = torch.mean(wav, dim=0, keepdim=True)  # Collapse channels to mono
    if sr != TARGET_SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=TARGET_SAMPLE_RATE)
        wav = resampler(wav)
    wav = wav.to(DEVICE)  # Move tensor to the model's device
    # no_grad is thread-local, so it must be entered in this worker thread.
    with _INFERENCE_LOCK, torch.no_grad():
        return ASR_MODEL(wav, language, decoding)


@app.post('/transcribefile/')
async def transcribe_file(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file using RNNT decoding.

    Pipeline:
        1. Stream the upload to UPLOAD_DIR in 1 MiB chunks.
        2. Convert it to a TARGET_SAMPLE_RATE mono WAV (pydub/ffmpeg) in a
           worker thread.
        3. Run ASR_MODEL in a worker thread.
        4. Save the transcription under TRANSCRIPTION_OUTPUT_DIR and return it.
    The uploaded and converted temporary files are always deleted afterwards.
    The saved transcription file is kept.

    Returns:
        ``{"rnnt_transcription": <model output>}``

    Raises:
        HTTPException(400): the upload is empty or could not be saved.
        HTTPException(422): the audio could not be decoded.
        HTTPException(500): any other failure.
    """
    unique_id = str(uuid.uuid4())
    # Never put the raw client filename into a path. Names containing "/" or
    # very long names would break the open(). Only a sanitized extension is
    # kept, because pydub looks at it.
    uploaded_file_path = os.path.join(UPLOAD_DIR, f"{unique_id}{_safe_upload_suffix(file.filename)}")
    converted_audio_path = os.path.join(CONVERTED_AUDIO_DIR, f"{unique_id}.wav")
    transcription_output_path_rnnt = os.path.join(TRANSCRIPTION_OUTPUT_DIR, f"{unique_id}_rnnt.txt")
    try:
        async with aiofiles.open(uploaded_file_path, "wb") as f:
            while content := await file.read(1024 * 1024):
                await f.write(content)
        if not os.path.exists(uploaded_file_path) or os.path.getsize(uploaded_file_path) == 0:
            raise HTTPException(status_code=400, detail="Uploaded file is empty or could not be saved.")

        # Conversion (ffmpeg) and inference both block; keep them off the event loop.
        await run_in_threadpool(_convert_audio_sync, uploaded_file_path, converted_audio_path)
        transcription_rnnt = await run_in_threadpool(_transcribe_sync, converted_audio_path)

        async with aiofiles.open(transcription_output_path_rnnt, "w", encoding="utf-8") as f:
            await f.write(transcription_rnnt)

        return {"rnnt_transcription": transcription_rnnt}
    except HTTPException:
        # Deliberate client errors (e.g. the 400 above) pass through as-is
        # instead of being rewrapped as a 500 by the generic handler below.
        raise
    except Exception as e:
        _LOGGER.exception("Error during transcription process: %s", e)
        message = str(e)
        if isinstance(e, CouldntDecodeError) or "File not found" in message or "Error parsing" in message:
            raise HTTPException(status_code=422, detail=f"Could not process audio file: {e}") from e
        raise HTTPException(status_code=500, detail=f"An internal server error occurred: {e}") from e
    finally:
        await file.close()  # Release the UploadFile's spooled temp file
        for temp_path in (uploaded_file_path, converted_audio_path):
            if os.path.exists(temp_path):
                os.remove(temp_path)