# File size: 5,871 Bytes
# 870cc4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
import uuid  # For generating unique filenames

import aiofiles  # For asynchronous file operations
import torch
import torchaudio
from fastapi import FastAPI, HTTPException, File, UploadFile
from pydub import AudioSegment  # Requires ffmpeg installed on system
from pydub.exceptions import CouldntDecodeError  # Raised when ffmpeg cannot decode an upload
from starlette.concurrency import run_in_threadpool  # For running blocking code in background thread
from transformers import AutoModel  # For the new model

# -----------------------------------------------------------
# 1. FastAPI application object
#    The startup hook and the /transcribefile/ route register on it.
# -----------------------------------------------------------
app: FastAPI = FastAPI()

# -----------------------------------------------------------
# 2. Module-level configuration and shared state
#    ASR_MODEL and DEVICE remain None until startup_event assigns them.
# -----------------------------------------------------------
# Shared inference state, filled in once by startup_event.
ASR_MODEL = None  # object returned by AutoModel.from_pretrained(...)
DEVICE: "torch.device | None" = None  # cuda if available, otherwise cpu

# Working directories, relative to the process working directory.
UPLOAD_DIR: str = "./uploads"  # raw uploads (temporary, per request)
CONVERTED_AUDIO_DIR: str = "./converted_audio_temp"  # converted WAVs (temporary, per request)
TRANSCRIPTION_OUTPUT_DIR: str = "./transcriptions"  # transcription text files (kept)

# Sample rate in Hz that audio is converted/resampled to before inference.
TARGET_SAMPLE_RATE: int = 16000

# -----------------------------------------------------------
# 3. Startup hook: prepare directories and load the ASR model
#    Executed once, when the FastAPI application starts up.
# -----------------------------------------------------------
@app.on_event("startup")
async def startup_event():
    """Create the working directories and load the ASR model.

    Makes sure the upload, conversion and transcription directories exist.
    Chooses CUDA when ``torch.cuda.is_available()`` reports a GPU, and CPU
    otherwise. Then loads the ``ai4bharat/indic-conformer-600m-multilingual``
    checkpoint (remote code trusted), moves it to that device and puts it in
    evaluation mode. The results are stored in the module globals
    ``ASR_MODEL`` and ``DEVICE``.
    """
    global ASR_MODEL, DEVICE

    for directory in (UPLOAD_DIR, CONVERTED_AUDIO_DIR, TRANSCRIPTION_OUTPUT_DIR):
        os.makedirs(directory, exist_ok=True)

    has_gpu = torch.cuda.is_available()
    DEVICE = torch.device("cuda") if has_gpu else torch.device("cpu")

    ASR_MODEL = AutoModel.from_pretrained(
        "ai4bharat/indic-conformer-600m-multilingual",
        trust_remote_code=True,
    )
    ASR_MODEL.to(DEVICE)
    ASR_MODEL.eval()  # inference only


# -----------------------------------------------------------
# 4. Helper Function: Audio Conversion
#    This function performs the actual audio conversion (blocking operation)
# -----------------------------------------------------------
def _convert_audio_sync(input_path: str, output_path: str, target_sample_rate: int = TARGET_SAMPLE_RATE, channels: int = 1) -> None:
    """Decode ``input_path`` with pydub/ffmpeg and write it out as a WAV file.

    This call blocks (pydub shells out to ffmpeg). Async callers should
    dispatch it with ``run_in_threadpool``.

    Args:
        input_path: Source file in any format ffmpeg can read.
        output_path: Destination path of the WAV file (overwritten if present).
        target_sample_rate: Output frame rate in Hz. Defaults to
            ``TARGET_SAMPLE_RATE``, whose value is captured at definition time.
        channels: Output channel count; the default 1 produces mono.

    Raises:
        Any pydub/ffmpeg error propagates unchanged to the caller.
    """
    audio = AudioSegment.from_file(input_path)
    audio = audio.set_frame_rate(target_sample_rate).set_channels(channels)
    # export() opens ``output_path`` itself and returns that open file object.
    # Close it explicitly so the WAV is flushed and the handle released before
    # the caller reads or deletes the file, rather than relying on GC.
    exported = audio.export(output_path, format="wav")
    exported.close()


# -----------------------------------------------------------
# 5. Main API Endpoint: Handle File Upload and Transcription
# -----------------------------------------------------------
def _transcribe_sync(audio_path: str):
    """Load ``audio_path``, normalise it for the ASR model and run RNNT decoding.

    Blocking (file I/O plus model inference). The endpoint runs it in a
    worker thread via ``run_in_threadpool`` so the event loop stays responsive.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        The result of ``ASR_MODEL(wav, "ml", "rnnt")``. The endpoint writes it
        to a text file, so it is presumably a ``str`` (TODO confirm).
    """
    wav, sr = torchaudio.load(audio_path)
    # Average across the channel dimension: multi-channel input becomes mono.
    wav = torch.mean(wav, dim=0, keepdim=True)

    # The converter already targets TARGET_SAMPLE_RATE; this is a safety net.
    if sr != TARGET_SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=TARGET_SAMPLE_RATE)
        wav = resampler(wav)

    wav = wav.to(DEVICE)

    # Grad mode is thread-local, so no_grad must be entered in this worker thread.
    with torch.no_grad():
        # "ml" is presumably the language code (confirm against the model card).
        # "rnnt" selects RNNT decoding; an earlier version of this endpoint also
        # ran the "ctc" mode.
        return ASR_MODEL(wav, "ml", "rnnt")


@app.post('/transcribefile/')
async def transcribe_file(file: UploadFile = File(...)):
    """Save an uploaded audio file, convert it, transcribe it and return the text.

    The upload is streamed to UPLOAD_DIR and converted to a WAV in
    CONVERTED_AUDIO_DIR; both temporary files are deleted afterwards. The RNNT
    transcription is also written to ``<uuid>_rnnt.txt`` in
    TRANSCRIPTION_OUTPUT_DIR.

    NOTE(review): conversion and inference now run in worker threads, so
    concurrent requests may call ASR_MODEL in parallel. If the model turns out
    not to be thread-safe, or GPU memory becomes an issue, serialise the
    inference with a lock.

    Returns:
        ``{"rnnt_transcription": <model output>}``.

    Raises:
        HTTPException: 400 if the upload is empty or could not be saved; 422
            if the audio cannot be decoded; 500 for any other failure.
    """
    # Generate unique, collision-free paths for this request's files.
    unique_id = str(uuid.uuid4())
    # The client-supplied filename is untrusted and may be None. Keep only its
    # final path component so it cannot introduce directories.
    safe_name = os.path.basename(file.filename or "") or "upload"
    uploaded_file_path = os.path.join(UPLOAD_DIR, f"{unique_id}_{safe_name}")
    converted_audio_path = os.path.join(CONVERTED_AUDIO_DIR, f"{unique_id}.wav")
    transcription_output_path_rnnt = os.path.join(TRANSCRIPTION_OUTPUT_DIR, f"{unique_id}_rnnt.txt")

    try:
        # Stream the upload to disk in 1 MiB chunks.
        async with aiofiles.open(uploaded_file_path, "wb") as f:
            while content := await file.read(1024 * 1024):
                await f.write(content)

        if not os.path.exists(uploaded_file_path) or os.path.getsize(uploaded_file_path) == 0:
            raise HTTPException(status_code=400, detail="Uploaded file is empty or could not be saved.")

        # Blocking ffmpeg conversion and model inference run off the event loop.
        await run_in_threadpool(
            _convert_audio_sync, uploaded_file_path, converted_audio_path
        )
        transcription_rnnt = await run_in_threadpool(_transcribe_sync, converted_audio_path)

        # Persist the transcription alongside returning it.
        async with aiofiles.open(transcription_output_path_rnnt, "w", encoding="utf-8") as f:
            await f.write(transcription_rnnt)

        return {
            "rnnt_transcription": transcription_rnnt
        }

    except HTTPException:
        # Already carries the intended status (e.g. the 400 above). Without this
        # clause the generic handler below would turn it into a 500.
        raise

    except Exception as e:
        print(f"Error during transcription process: {e}")
        message = str(e)
        # Undecodable or missing audio is a client problem, not a server fault.
        if isinstance(e, CouldntDecodeError) or "File not found" in message or "Error parsing" in message:
            raise HTTPException(status_code=422, detail=f"Could not process audio file: {e}") from e
        raise HTTPException(status_code=500, detail=f"An internal server error occurred: {e}") from e

    finally:
        # Always release the upload handle and remove temporary files.
        await file.close()
        if os.path.exists(uploaded_file_path):
            os.remove(uploaded_file_path)
        if os.path.exists(converted_audio_path):
            os.remove(converted_audio_path)