File size: 6,444 Bytes
870cc4e
 
 
d3a18f1
 
 
 
870cc4e
 
d3a18f1
 
 
870cc4e
 
 
 
 
 
 
 
d3a18f1
870cc4e
 
 
 
 
 
 
 
 
 
d3a18f1
870cc4e
 
 
 
 
 
 
 
 
 
 
 
 
d3a18f1
870cc4e
d3a18f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
870cc4e
 
d3a18f1
 
870cc4e
 
 
 
 
 
 
 
d3a18f1
870cc4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3a18f1
870cc4e
 
 
 
 
 
 
 
 
 
 
 
 
 
d3a18f1
870cc4e
 
 
d3a18f1
870cc4e
 
 
d3a18f1
870cc4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os
import torch
import torchaudio
from transformers import AutoModel
from pydub import AudioSegment
import aiofiles
import uuid

from fastapi import FastAPI, HTTPException, File, UploadFile
from starlette.concurrency import run_in_threadpool
from starlette.staticfiles import StaticFiles # <-- NEW IMPORT
from starlette.responses import HTMLResponse, RedirectResponse # <-- NEW IMPORT

# -----------------------------------------------------------
# 1. FastAPI App Instance
# -----------------------------------------------------------
app = FastAPI()

# -----------------------------------------------------------
# 2. Global Variables (for model and directories)
# These will be initialized during startup
# -----------------------------------------------------------
ASR_MODEL = None  # HF AutoModel instance; populated by startup_event
DEVICE = None  # torch.device ("cuda" or "cpu"); populated by startup_event
UPLOAD_DIR: str = "./uploads"  # raw uploaded files (deleted after processing)
CONVERTED_AUDIO_DIR: str = "./converted_audio_temp"  # 16 kHz WAVs produced by pydub (deleted after processing)
TRANSCRIPTION_OUTPUT_DIR: str = "./transcriptions"  # persisted transcription .txt files
TARGET_SAMPLE_RATE: int = 16000 # Required sample rate for the new model

# -----------------------------------------------------------
# 3. Startup Event: Load Model and Create Directories
# This runs once when the FastAPI application starts
# -----------------------------------------------------------
@app.on_event("startup")
async def startup_event():
    """Create the service's working directories and load the ASR model once."""
    global ASR_MODEL, DEVICE

    # Make sure every directory the service writes to exists.
    for directory in (UPLOAD_DIR, CONVERTED_AUDIO_DIR, TRANSCRIPTION_OUTPUT_DIR):
        os.makedirs(directory, exist_ok=True)

    # Prefer the GPU when one is available, then load the multilingual
    # conformer model and put it in inference mode.
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    ASR_MODEL = AutoModel.from_pretrained(
        "ai4bharat/indic-conformer-600m-multilingual", trust_remote_code=True
    )
    ASR_MODEL.to(DEVICE)
    ASR_MODEL.eval()

# -----------------------------------------------------------
# 4. Mount Static Files and Define Root Endpoint (NEW)
# -----------------------------------------------------------
# Mount the 'static' directory to serve HTML, CSS, JS files.
# This makes files like 'static/index.html' accessible at /static/index.html.
# NOTE(review): this requires a 'static' directory to exist at startup,
# otherwise StaticFiles raises at import time — confirm deployment copies it.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Define a root endpoint that serves your main HTML page
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve the single-page UI from static/index.html at the site root."""
    try:
        # FastAPI will serve this index.html when users visit the root URL of your Space
        with open("static/index.html", "r", encoding="utf-8") as page_file:
            page = page_file.read()
    except FileNotFoundError:
        # This fallback should ideally not be hit if your Dockerfile copies files correctly
        return HTMLResponse("<h1>Error: index.html not found!</h1><p>Please ensure 'static/index.html' exists in your project.</p>", status_code=404)
    return HTMLResponse(content=page)

# -----------------------------------------------------------
# 5. Helper Function: Audio Conversion (Existing Code)
# This function performs the actual audio conversion (blocking operation)
# -----------------------------------------------------------
def _convert_audio_sync(input_path: str, output_path: str, target_sample_rate: int = TARGET_SAMPLE_RATE, channels: int = 1):
    """Decode *input_path* via pydub/ffmpeg and export it to *output_path* as a
    WAV resampled to *target_sample_rate* Hz with *channels* channel(s).

    Blocking — callers on the event loop should run this in a thread pool.
    """
    segment = AudioSegment.from_file(input_path)
    normalized = segment.set_frame_rate(target_sample_rate).set_channels(channels)
    normalized.export(output_path, format="wav")


# -----------------------------------------------------------
# 6. Main API Endpoint: Handle File Upload and Transcription (Existing Code)
# -----------------------------------------------------------
@app.post('/transcribefile/')
async def transcribe_file(file: UploadFile = File(...)):
    """Accept an uploaded audio file, convert it to a 16 kHz WAV, run the ASR
    model with RNNT decoding (language code "ml"), persist the transcription,
    and return it as JSON.

    Returns:
        dict with key "rnnt_transcription" holding the model output.

    Raises:
        HTTPException 400: the uploaded file was empty or could not be saved.
        HTTPException 422: the audio could not be parsed/converted.
        HTTPException 500: any other failure during processing.
    """
    # 5.1. Generate unique filenames for uploaded and converted files
    unique_id = str(uuid.uuid4())
    uploaded_file_path = os.path.join(UPLOAD_DIR, f"{unique_id}_{file.filename}")
    converted_audio_path = os.path.join(CONVERTED_AUDIO_DIR, f"{unique_id}.wav")
    transcription_output_path_rnnt = os.path.join(TRANSCRIPTION_OUTPUT_DIR, f"{unique_id}_rnnt.txt")

    try:
        # 5.2. Asynchronously save the uploaded file in 1 MiB chunks
        async with aiofiles.open(uploaded_file_path, "wb") as f:
            while content := await file.read(1024 * 1024):
                await f.write(content)

        # 5.3. Handle potential file upload errors (e.g., empty file)
        if not os.path.exists(uploaded_file_path) or os.path.getsize(uploaded_file_path) == 0:
            raise HTTPException(status_code=400, detail="Uploaded file is empty or could not be saved.")

        # 5.4. Convert audio (run blocking pydub/ffmpeg work in a thread pool
        # so the event loop stays responsive)
        await run_in_threadpool(
            _convert_audio_sync, uploaded_file_path, converted_audio_path
        )

        # 5.5. Load and preprocess the converted audio for the model
        wav, sr = torchaudio.load(converted_audio_path)
        wav = torch.mean(wav, dim=0, keepdim=True)  # Collapse to mono if stereo

        # Defensive resample: conversion should already have produced
        # TARGET_SAMPLE_RATE audio, but guard against ffmpeg surprises.
        if sr != TARGET_SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=TARGET_SAMPLE_RATE)
            wav = resampler(wav)

        wav = wav.to(DEVICE)  # Move tensor to the model's device

        # 5.6. Perform transcription using RNNT decoding
        with torch.no_grad():  # Disable gradient calculation for inference
            transcription_rnnt = ASR_MODEL(wav, "ml", "rnnt")

        # 5.7. Save transcription (optional)
        async with aiofiles.open(transcription_output_path_rnnt, "w", encoding="utf-8") as f:
            await f.write(transcription_rnnt)

        # 5.8. Return the transcription
        return {
            "rnnt_transcription": transcription_rnnt
        }

    except HTTPException:
        # BUGFIX: deliberate HTTP errors (e.g. the 400 raised above) were
        # previously swallowed by the generic handler below and re-raised as
        # 500 "internal server error". Let them propagate with their intended
        # status code and detail.
        raise

    except Exception as e:
        # 5.9. Centralized error handling
        print(f"Error during transcription process: {e}")
        # Specific error for file not found or corrupted during conversion
        if "File not found" in str(e) or "Error parsing" in str(e):
            raise HTTPException(status_code=422, detail=f"Could not process audio file: {e}")
        # General server error
        raise HTTPException(status_code=500, detail=f"An internal server error occurred: {e}")

    finally:
        # 5.10. Clean up temporary files
        await file.close()  # Close the UploadFile's underlying file handle
        if os.path.exists(uploaded_file_path):
            os.remove(uploaded_file_path)
        if os.path.exists(converted_audio_path):
            os.remove(converted_audio_path)