Spaces:
Runtime error
| import os | |
| import torch | |
| import torchaudio | |
| from transformers import AutoModel | |
| from pydub import AudioSegment | |
| import aiofiles | |
| import uuid | |
| from fastapi import FastAPI, HTTPException, File, UploadFile | |
| from starlette.concurrency import run_in_threadpool | |
| from starlette.staticfiles import StaticFiles # <-- NEW IMPORT | |
| from starlette.responses import HTMLResponse, RedirectResponse # <-- NEW IMPORT | |
# -----------------------------------------------------------
# 1. FastAPI application
# -----------------------------------------------------------
app: FastAPI = FastAPI()

# -----------------------------------------------------------
# 2. Module-level configuration and shared state
# ASR_MODEL and DEVICE start out empty and are filled in by startup_event();
# the remaining names are fixed settings.
# -----------------------------------------------------------
ASR_MODEL = None  # speech-recognition model, assigned on startup
DEVICE = None  # torch.device the model is moved to, assigned on startup

UPLOAD_DIR = "./uploads"  # raw client uploads (deleted at the end of each request)
CONVERTED_AUDIO_DIR = "./converted_audio_temp"  # intermediate WAVs (deleted at the end of each request)
TRANSCRIPTION_OUTPUT_DIR = "./transcriptions"  # transcription text files that are kept
TARGET_SAMPLE_RATE = 16000  # sample rate (Hz) audio is converted to before transcription
# -----------------------------------------------------------
# 3. Startup Event: Load Model and Create Directories
# This runs once when the FastAPI application starts
# -----------------------------------------------------------
@app.on_event("startup")
async def startup_event():
    """Create the working directories and load the ASR model into module globals.

    Must be registered as a FastAPI startup handler: without the registration
    nothing calls it, ASR_MODEL stays None and every transcription fails.

    Side effects:
        Creates UPLOAD_DIR, CONVERTED_AUDIO_DIR and TRANSCRIPTION_OUTPUT_DIR
        (no error if they already exist) and assigns the globals DEVICE
        (CUDA when available, else CPU) and ASR_MODEL (moved to DEVICE and
        switched to eval mode).
    """
    global ASR_MODEL, DEVICE

    for directory in (UPLOAD_DIR, CONVERTED_AUDIO_DIR, TRANSCRIPTION_OUTPUT_DIR):
        os.makedirs(directory, exist_ok=True)

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # trust_remote_code: the model repository ships its own modelling code.
    ASR_MODEL = AutoModel.from_pretrained("ai4bharat/indic-conformer-600m-multilingual", trust_remote_code=True)
    ASR_MODEL.to(DEVICE)
    ASR_MODEL.eval()
# -----------------------------------------------------------
# 4. Mount Static Files and Define Root Endpoint
# -----------------------------------------------------------
# Serve HTML/CSS/JS from the 'static' directory, so e.g. 'static/index.html'
# is reachable at /static/index.html.
# StaticFiles(directory=...) raises RuntimeError at construction time when the
# directory does not exist, which would abort the whole app at import. Mount it
# only when present; the API endpoints and the fallback page below keep working.
if os.path.isdir("static"):
    app.mount("/static", StaticFiles(directory="static"), name="static")
else:
    print("Warning: 'static' directory not found; /static will not be served.")


@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve static/index.html at the root URL.

    Returns:
        HTMLResponse with the file's contents, or an HTML error page with
        status 404 when static/index.html does not exist.
    """
    try:
        with open("static/index.html", "r", encoding="utf-8") as f:
            return HTMLResponse(content=f.read())
    except FileNotFoundError:
        # Expected only if the deployment did not copy the static files.
        return HTMLResponse("<h1>Error: index.html not found!</h1><p>Please ensure 'static/index.html' exists in your project.</p>", status_code=404)
# -----------------------------------------------------------
# 5. Helper Function: Audio Conversion
# This function performs the actual audio conversion (blocking operation)
# -----------------------------------------------------------
def _convert_audio_sync(input_path: str, output_path: str, target_sample_rate: int = TARGET_SAMPLE_RATE, channels: int = 1):
    """Decode *input_path* with pydub and write it to *output_path* as WAV.

    The audio is resampled to *target_sample_rate* Hz and mixed to
    *channels* channels. This is blocking work (pydub may invoke ffmpeg), so
    async callers should run it through run_in_threadpool.

    Args:
        input_path: Path of the source audio file.
        output_path: Destination path for the WAV file (overwritten).
        target_sample_rate: Output sample rate in Hz.
        channels: Output channel count (1 = mono).
    """
    audio = AudioSegment.from_file(input_path)
    audio = audio.set_frame_rate(target_sample_rate).set_channels(channels)
    # export() returns the file object it opened for output_path. Close it
    # explicitly instead of leaking it to garbage collection, so the WAV is
    # fully released before the caller reads or deletes it.
    audio.export(output_path, format="wav").close()
# -----------------------------------------------------------
# 6. Main API Endpoint: Handle File Upload and Transcription
# -----------------------------------------------------------
def _transcribe_wav_sync(wav_path: str):
    """Load *wav_path*, force mono at TARGET_SAMPLE_RATE, and run RNNT decoding.

    Blocking (file I/O plus model inference): call via run_in_threadpool so
    the event loop stays responsive. Returns whatever
    ASR_MODEL(wav, "ml", "rnnt") returns; the endpoint writes it to a text
    file, so presumably a str.
    """
    wav, sr = torchaudio.load(wav_path)
    wav = torch.mean(wav, dim=0, keepdim=True)  # average channels -> mono
    if sr != TARGET_SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=TARGET_SAMPLE_RATE)
        wav = resampler(wav)
    wav = wav.to(DEVICE)  # move tensor to the model's device
    # no_grad is thread-local, so it is entered here, inside the worker thread.
    with torch.no_grad():
        return ASR_MODEL(wav, "ml", "rnnt")


# NOTE(review): no route decorator for this handler is visible in this file,
# so FastAPI does not expose it. Register it (e.g. `@app.post(...)`) with the
# path the front-end posts to -- TODO confirm the path.
async def transcribe_file(file: UploadFile = File(...)):
    """Save an uploaded audio file, normalise it, and transcribe it with RNNT decoding.

    Steps:
      1. Stream the upload into UPLOAD_DIR.
      2. Convert it to a TARGET_SAMPLE_RATE mono WAV in CONVERTED_AUDIO_DIR.
      3. Run ASR_MODEL on it with language code "ml" (presumably Malayalam)
         and the "rnnt" decoder.
      4. Save the text in TRANSCRIPTION_OUTPUT_DIR and return it.

    The uploaded and converted temporary files are always removed. The
    transcription text file is kept.

    Returns:
        {"rnnt_transcription": <model output>}

    Raises:
        HTTPException(503): the ASR model has not been loaded.
        HTTPException(400): the upload is empty or could not be saved.
        HTTPException(422): the audio could not be decoded/converted.
        HTTPException(500): any other failure.
    """
    from pydub.exceptions import CouldntDecodeError

    # 6.1. Unique names so concurrent uploads never collide.
    unique_id = str(uuid.uuid4())
    # The client-supplied filename is untrusted: keep only its final path
    # component so a name like "../../app.py" cannot escape UPLOAD_DIR.
    # The original extension is preserved.
    safe_name = os.path.basename(file.filename or "upload")
    uploaded_file_path = os.path.join(UPLOAD_DIR, f"{unique_id}_{safe_name}")
    converted_audio_path = os.path.join(CONVERTED_AUDIO_DIR, f"{unique_id}.wav")
    transcription_output_path_rnnt = os.path.join(TRANSCRIPTION_OUTPUT_DIR, f"{unique_id}_rnnt.txt")

    try:
        if ASR_MODEL is None:
            raise HTTPException(status_code=503, detail="ASR model is not loaded.")

        # 6.2. Stream the upload to disk in 1 MiB chunks.
        async with aiofiles.open(uploaded_file_path, "wb") as f:
            while content := await file.read(1024 * 1024):
                await f.write(content)

        # 6.3. Reject empty uploads.
        if not os.path.exists(uploaded_file_path) or os.path.getsize(uploaded_file_path) == 0:
            raise HTTPException(status_code=400, detail="Uploaded file is empty or could not be saved.")

        # 6.4. Convert with pydub/ffmpeg (blocking -> thread pool).
        await run_in_threadpool(
            _convert_audio_sync, uploaded_file_path, converted_audio_path
        )

        # 6.5. Load, preprocess and transcribe (blocking -> thread pool).
        # Concurrent requests can now run inference in parallel threads.
        transcription_rnnt = await run_in_threadpool(_transcribe_wav_sync, converted_audio_path)

        # 6.6. Persist the transcription.
        async with aiofiles.open(transcription_output_path_rnnt, "w", encoding="utf-8") as f:
            await f.write(transcription_rnnt)

        return {
            "rnnt_transcription": transcription_rnnt
        }
    except HTTPException:
        # Already carries the intended status (e.g. 400 for an empty upload);
        # without this clause the handler below would re-wrap it as a 500.
        raise
    except Exception as e:
        print(f"Error during transcription process: {e}")
        # Undecodable / unparsable audio is a client-side problem.
        if isinstance(e, CouldntDecodeError) or "File not found" in str(e) or "Error parsing" in str(e):
            raise HTTPException(status_code=422, detail=f"Could not process audio file: {e}") from e
        raise HTTPException(status_code=500, detail=f"An internal server error occurred: {e}") from e
    finally:
        # Always release the upload handle and delete temporary files.
        await file.close()
        for temp_path in (uploaded_file_path, converted_audio_path):
            if os.path.exists(temp_path):
                os.remove(temp_path)