Spaces:
Runtime error
Runtime error
File size: 6,444 Bytes
870cc4e d3a18f1 870cc4e d3a18f1 870cc4e d3a18f1 870cc4e d3a18f1 870cc4e d3a18f1 870cc4e d3a18f1 870cc4e d3a18f1 870cc4e d3a18f1 870cc4e d3a18f1 870cc4e d3a18f1 870cc4e d3a18f1 870cc4e d3a18f1 870cc4e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import os
import torch
import torchaudio
from transformers import AutoModel
from pydub import AudioSegment
import aiofiles
import uuid
from fastapi import FastAPI, HTTPException, File, UploadFile
from starlette.concurrency import run_in_threadpool
from starlette.staticfiles import StaticFiles # <-- NEW IMPORT
from starlette.responses import HTMLResponse, RedirectResponse # <-- NEW IMPORT
# -----------------------------------------------------------
# 1. FastAPI App Instance
# -----------------------------------------------------------
# ASGI application object; the startup hook and route decorators in this file register on it.
app = FastAPI()
# -----------------------------------------------------------
# 2. Global Variables (for model and directories)
# ASR_MODEL and DEVICE are placeholders (None) that startup_event() assigns.
# The directory paths and the sample rate are fixed constants.
# -----------------------------------------------------------
ASR_MODEL = None  # set to the loaded model by startup_event()
DEVICE = None  # set to a torch.device by startup_event()
# Relative paths: they resolve against the process working directory.
UPLOAD_DIR = "./uploads"  # raw uploads, removed after each request
CONVERTED_AUDIO_DIR = "./converted_audio_temp"  # converted WAV files, removed after each request
TRANSCRIPTION_OUTPUT_DIR = "./transcriptions"  # saved transcription text files
TARGET_SAMPLE_RATE = 16000 # Required sample rate for the new model
# -----------------------------------------------------------
# 3. Startup Event: Load Model and Create Directories
# This runs once when the FastAPI application starts
# -----------------------------------------------------------
@app.on_event("startup")
async def startup_event():
    """Create the working directories and load the ASR model.

    Registered as the application's "startup" handler. Assigns the
    module-level ``DEVICE`` and ``ASR_MODEL`` globals that the transcription
    endpoint reads.
    """
    global ASR_MODEL, DEVICE

    # Working folders for uploads, converted audio and saved transcriptions.
    for directory in (UPLOAD_DIR, CONVERTED_AUDIO_DIR, TRANSCRIPTION_OUTPUT_DIR):
        os.makedirs(directory, exist_ok=True)

    # Use the GPU when CUDA is available, otherwise fall back to the CPU.
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    DEVICE = torch.device(device_name)

    # trust_remote_code=True lets transformers execute code shipped with the
    # model repository.
    ASR_MODEL = AutoModel.from_pretrained(
        "ai4bharat/indic-conformer-600m-multilingual",
        trust_remote_code=True,
    )
    ASR_MODEL.to(DEVICE)
    ASR_MODEL.eval()  # inference mode
# -----------------------------------------------------------
# 4. Mount Static Files and Define Root Endpoint (NEW)
# -----------------------------------------------------------
# Mount the 'static' directory to serve HTML, CSS, JS files
# This makes files like 'static/index.html' accessible at /static/index.html
# NOTE(review): this statement runs at import time and "static" is a relative
# path, so it is resolved against the process working directory. Starlette's
# StaticFiles is documented to check that the directory exists when it is
# instantiated (check_dir defaults to True) — confirm the directory is present
# in the deployed image, otherwise the module presumably fails to import.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Define a root endpoint that serves your main HTML page
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve the main HTML page at the root URL.

    Returns:
        An ``HTMLResponse`` with the contents of ``static/index.html``, or a
        404 ``HTMLResponse`` with an explanatory message when that file does
        not exist.
    """
    try:
        with open("static/index.html", "r", encoding="utf-8") as index_file:
            page_html = index_file.read()
    except FileNotFoundError:
        # Fallback for a deployment that did not ship the static assets.
        return HTMLResponse(
            "<h1>Error: index.html not found!</h1><p>Please ensure 'static/index.html' exists in your project.</p>",
            status_code=404,
        )
    return HTMLResponse(content=page_html)
# -----------------------------------------------------------
# 5. Helper Function: Audio Conversion (Existing Code)
# This function performs the actual audio conversion (blocking operation)
# -----------------------------------------------------------
def _convert_audio_sync(
    input_path: str,
    output_path: str,
    target_sample_rate: int = TARGET_SAMPLE_RATE,
    channels: int = 1,
):
    """Re-encode the audio file at *input_path* as a WAV file at *output_path*.

    This is a synchronous (blocking) helper; the endpoint calls it through a
    thread pool.

    Args:
        input_path: Path of the source audio file to decode.
        output_path: Path where the WAV file is written.
        target_sample_rate: Frame rate applied to the audio before export.
        channels: Number of channels applied to the audio before export.
    """
    decoded = AudioSegment.from_file(input_path)
    resampled = decoded.set_frame_rate(target_sample_rate)
    remixed = resampled.set_channels(channels)
    remixed.export(output_path, format="wav")
# -----------------------------------------------------------
# 6. Main API Endpoint: Handle File Upload and Transcription (Existing Code)
# -----------------------------------------------------------
@app.post('/transcribefile/')
async def transcribe_file(file: UploadFile = File(...)):
    """Save an uploaded audio file, convert it to WAV and transcribe it.

    The upload is streamed to ``UPLOAD_DIR``, converted by
    ``_convert_audio_sync`` into ``CONVERTED_AUDIO_DIR``, loaded with
    torchaudio and passed to ``ASR_MODEL`` with the arguments ``"ml"`` and
    ``"rnnt"``. The transcription is also written to
    ``TRANSCRIPTION_OUTPUT_DIR``. The uploaded and converted files are removed
    before returning, whether or not the request succeeded.

    Args:
        file: The uploaded audio file (multipart form field ``file``).

    Returns:
        A dict ``{"rnnt_transcription": <model output>}``.

    Raises:
        HTTPException: 400 when the saved upload is empty; 422 when the
            failure message contains "File not found" or "Error parsing";
            500 for any other failure.
    """
    # 6.1. Generate unique filenames for uploaded and converted files
    unique_id = str(uuid.uuid4())
    # Keep only the last path component of the client-supplied name: a name
    # containing separators would otherwise point at a sub-directory of
    # UPLOAD_DIR that does not exist. Fall back to a fixed name when the
    # client sent none. The extension is preserved.
    client_name = (file.filename or "").replace("\\", "/")
    safe_name = os.path.basename(client_name) or "upload"
    uploaded_file_path = os.path.join(UPLOAD_DIR, f"{unique_id}_{safe_name}")
    converted_audio_path = os.path.join(CONVERTED_AUDIO_DIR, f"{unique_id}.wav")
    transcription_output_path_rnnt = os.path.join(TRANSCRIPTION_OUTPUT_DIR, f"{unique_id}_rnnt.txt")
    try:
        # 6.2. Asynchronously save the uploaded file in 1 MiB chunks
        async with aiofiles.open(uploaded_file_path, "wb") as f:
            while content := await file.read(1024 * 1024):
                await f.write(content)

        # 6.3. Handle potential file upload errors (e.g., empty file)
        if not os.path.exists(uploaded_file_path) or os.path.getsize(uploaded_file_path) == 0:
            raise HTTPException(status_code=400, detail="Uploaded file is empty or could not be saved.")

        # 6.4. Convert audio (run blocking operation in a thread pool)
        # This is where pydub uses ffmpeg
        await run_in_threadpool(
            _convert_audio_sync, uploaded_file_path, converted_audio_path
        )

        # 6.5. Load and preprocess the converted audio for the model
        # NOTE(review): loading and inference below are called directly from
        # this coroutine, so they run on the event loop thread.
        wav, sr = torchaudio.load(converted_audio_path)
        wav = torch.mean(wav, dim=0, keepdim=True)  # Convert to mono if stereo
        if sr != TARGET_SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=TARGET_SAMPLE_RATE)
            wav = resampler(wav)
        wav = wav.to(DEVICE)  # Move tensor to the correct device

        # 6.6. Perform transcription using RNNT decoding
        with torch.no_grad():  # Disable gradient calculation for inference
            transcription_rnnt = ASR_MODEL(wav, "ml", "rnnt")

        # 6.7. Save transcription (optional)
        async with aiofiles.open(transcription_output_path_rnnt, "w", encoding="utf-8") as f:
            await f.write(transcription_rnnt)

        # 6.8. Return the transcription
        return {
            "rnnt_transcription": transcription_rnnt
        }
    except HTTPException:
        # Deliberate HTTP errors (such as the 400 above) must reach the client
        # unchanged instead of being rewrapped as a 500 by the handler below.
        raise
    except Exception as e:
        # 6.9. Centralized error handling
        print(f"Error during transcription process: {e}")
        # Specific error for file not found or corrupted during conversion
        if "File not found" in str(e) or "Error parsing" in str(e):
            raise HTTPException(status_code=422, detail=f"Could not process audio file: {e}") from e
        # General server error
        raise HTTPException(status_code=500, detail=f"An internal server error occurred: {e}") from e
    finally:
        # 6.10. Clean up temporary files
        await file.close()  # Close the UploadFile's underlying file handle
        for temp_path in (uploaded_file_path, converted_audio_path):
            if os.path.exists(temp_path):
                os.remove(temp_path)