ogflash committed
Commit d3a18f1 · 1 Parent(s): 69e578b

Fix: added web UI for ASR API

Files changed (2)
  1. main.py +34 -21
  2. static/index.html +210 -0
main.py CHANGED
@@ -1,13 +1,15 @@
 import os
 import torch
 import torchaudio
-from transformers import AutoModel # For the new model
-from pydub import AudioSegment # Requires ffmpeg installed on system
-import aiofiles # For asynchronous file operations
-import uuid # For generating unique filenames
+from transformers import AutoModel
+from pydub import AudioSegment
+import aiofiles
+import uuid

 from fastapi import FastAPI, HTTPException, File, UploadFile
-from starlette.concurrency import run_in_threadpool # For running blocking code in background thread
+from starlette.concurrency import run_in_threadpool
+from starlette.staticfiles import StaticFiles # <-- NEW IMPORT
+from starlette.responses import HTMLResponse, RedirectResponse # <-- NEW IMPORT

 # -----------------------------------------------------------
 # 1. FastAPI App Instance
@@ -16,7 +18,7 @@ app = FastAPI()

 # -----------------------------------------------------------
 # 2. Global Variables (for model and directories)
-# These will be initialized during startup
+# These will be initialized during startup
 # -----------------------------------------------------------
 ASR_MODEL = None
 DEVICE = None
@@ -27,7 +29,7 @@ TARGET_SAMPLE_RATE = 16000 # Required sample rate for the new model

 # -----------------------------------------------------------
 # 3. Startup Event: Load Model and Create Directories
-# This runs once when the FastAPI application starts
+# This runs once when the FastAPI application starts
 # -----------------------------------------------------------
 @app.on_event("startup")
 async def startup_event():
@@ -41,12 +43,29 @@ async def startup_event():
     DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     ASR_MODEL = AutoModel.from_pretrained("ai4bharat/indic-conformer-600m-multilingual", trust_remote_code=True)
     ASR_MODEL.to(DEVICE)
-    ASR_MODEL.eval() # Set model to evaluation mode
+    ASR_MODEL.eval()

+# -----------------------------------------------------------
+# 4. Mount Static Files and Define Root Endpoint (NEW)
+# -----------------------------------------------------------
+# Mount the 'static' directory to serve HTML, CSS, JS files
+# This makes files like 'static/index.html' accessible at /static/index.html
+app.mount("/static", StaticFiles(directory="static"), name="static")
+
+# Define a root endpoint that serves your main HTML page
+@app.get("/", response_class=HTMLResponse)
+async def read_root():
+    try:
+        # FastAPI will serve this index.html when users visit the root URL of your Space
+        with open("static/index.html", "r", encoding="utf-8") as f:
+            return HTMLResponse(content=f.read())
+    except FileNotFoundError:
+        # This fallback should ideally not be hit if your Dockerfile copies files correctly
+        return HTMLResponse("<h1>Error: index.html not found!</h1><p>Please ensure 'static/index.html' exists in your project.</p>", status_code=404)

 # -----------------------------------------------------------
-# 4. Helper Function: Audio Conversion
-# This function performs the actual audio conversion (blocking operation)
+# 5. Helper Function: Audio Conversion (Existing Code)
+# This function performs the actual audio conversion (blocking operation)
 # -----------------------------------------------------------
 def _convert_audio_sync(input_path: str, output_path: str, target_sample_rate: int = TARGET_SAMPLE_RATE, channels: int = 1):
     audio = AudioSegment.from_file(input_path)
@@ -55,7 +74,7 @@ def _convert_audio_sync(input_path: str, output_path: str, target_sample_rate: i


 # -----------------------------------------------------------
-# 5. Main API Endpoint: Handle File Upload and Transcription
+# 6. Main API Endpoint: Handle File Upload and Transcription (Existing Code)
 # -----------------------------------------------------------
 @app.post('/transcribefile/')
 async def transcribe_file(file: UploadFile = File(...)):
@@ -63,7 +82,6 @@ async def transcribe_file(file: UploadFile = File(...)):
     unique_id = str(uuid.uuid4())
     uploaded_file_path = os.path.join(UPLOAD_DIR, f"{unique_id}_{file.filename}")
     converted_audio_path = os.path.join(CONVERTED_AUDIO_DIR, f"{unique_id}.wav")
-    #transcription_output_path_ctc = os.path.join(TRANSCRIPTION_OUTPUT_DIR, f"{unique_id}_ctc.txt")
     transcription_output_path_rnnt = os.path.join(TRANSCRIPTION_OUTPUT_DIR, f"{unique_id}_rnnt.txt")

     try:
@@ -77,7 +95,7 @@ async def transcribe_file(file: UploadFile = File(...)):
             raise HTTPException(status_code=400, detail="Uploaded file is empty or could not be saved.")

         # 5.4. Convert audio (run blocking operation in a thread pool)
-        # This is where pydub uses ffmpeg
+        # This is where pydub uses ffmpeg
         await run_in_threadpool(
             _convert_audio_sync, uploaded_file_path, converted_audio_path
         )
@@ -92,21 +110,16 @@ async def transcribe_file(file: UploadFile = File(...)):

         wav = wav.to(DEVICE) # Move tensor to the correct device

-        # 5.6. Perform transcription using both CTC and RNNT decoding
+        # 5.6. Perform transcription using RNNT decoding
         with torch.no_grad(): # Disable gradient calculation for inference
-            #transcription_ctc = ASR_MODEL(wav, "ml", "ctc")
             transcription_rnnt = ASR_MODEL(wav, "ml", "rnnt")

-        # 5.7. Save transcriptions (optional)
-        #async with aiofiles.open(transcription_output_path_ctc, "w", encoding="utf-8") as f:
-        #    await f.write(transcription_ctc)
-
+        # 5.7. Save transcription (optional)
         async with aiofiles.open(transcription_output_path_rnnt, "w", encoding="utf-8") as f:
             await f.write(transcription_rnnt)

-        # 5.8. Return the transcriptions
+        # 5.8. Return the transcription
         return {
-            # "ctc_transcription": transcription_ctc,
             "rnnt_transcription": transcription_rnnt
         }
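For a quick sanity check of the endpoint above, here is a minimal Python client sketch. The base URL, port, and sample file name are assumptions, not part of this commit; only the route and the "file" field name come from main.py.

import requests  # third-party HTTP client, assumed installed

BASE_URL = "http://localhost:7860"  # assumed port; adjust to your deployment

# The multipart field name must be "file" to match the UploadFile parameter.
with open("sample.wav", "rb") as audio:  # hypothetical sample file
    resp = requests.post(f"{BASE_URL}/transcribefile/",
                         files={"file": ("sample.wav", audio, "audio/wav")})
resp.raise_for_status()
print(resp.json()["rnnt_transcription"])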
static/index.html ADDED
@@ -0,0 +1,210 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>ASR Transcription App</title>
+    <style>
+        body {
+            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
+            margin: 0;
+            padding: 20px;
+            background-color: #011227;
+            color: #333;
+            display: flex;
+            justify-content: center;
+            align-items: center;
+            min-height: 100vh;
+            box-sizing: border-box;
+        }
+        .container {
+            max-width: 650px;
+            width: 100%;
+            margin: auto;
+            background: linear-gradient(135deg, #ffffff, #f0f8ff);
+            padding: 40px;
+            border-radius: 12px;
+            box-shadow: 0 5px 20px rgba(0,0,0,0.1);
+            border: 1px solid #d0e0f0;
+        }
+        h1 {
+            text-align: center;
+            color: #0056b3;
+            margin-bottom: 30px;
+            font-size: 2em;
+        }
+        .form-group {
+            margin-bottom: 25px;
+        }
+        label {
+            display: block;
+            margin-bottom: 8px;
+            font-weight: bold;
+            color: #555;
+        }
+        input[type="file"] {
+            display: block;
+            width: 100%;
+            padding: 12px;
+            border: 1px solid #a7d0e0;
+            border-radius: 6px;
+            box-sizing: border-box;
+            background-color: #fcfdff;
+            cursor: pointer;
+        }
+        input[type="file"]::-webkit-file-upload-button {
+            background-color: #007bff;
+            color: white;
+            padding: 8px 15px;
+            border: none;
+            border-radius: 4px;
+            cursor: pointer;
+            margin-right: 15px;
+            transition: background-color 0.2s ease;
+        }
+        input[type="file"]::-webkit-file-upload-button:hover {
+            background-color: #0056b3;
+        }
+        button {
+            background-color: #28a745;
+            color: white;
+            padding: 15px 25px;
+            border: none;
+            border-radius: 6px;
+            cursor: pointer;
+            font-size: 1.1em;
+            width: 100%;
+            transition: background-color 0.2s ease, transform 0.1s ease;
+        }
+        button:hover {
+            background-color: #218838;
+            transform: translateY(-2px);
+        }
+        button:disabled {
+            background-color: #cccccc;
+            cursor: not-allowed;
+        }
+        #loading {
+            text-align: center;
+            margin-top: 30px;
+            font-weight: bold;
+            color: #007bff;
+            font-size: 1.1em;
+            display: none; /* Hidden by default */
+        }
+        #response-card {
+            margin-top: 30px;
+            padding: 20px;
+            background-color: #f8fafd;
+            border: 1px solid #d0e0f0;
+            border-radius: 8px;
+            min-height: 80px;
+            box-shadow: inset 0 1px 3px rgba(0,0,0,0.05);
+        }
+        #response-card strong {
+            color: #0056b3;
+            display: block;
+            margin-bottom: 10px;
+            font-size: 1.1em;
+        }
+        #transcriptionOutput {
+            white-space: pre-wrap; /* Preserve whitespace and line breaks */
+            word-wrap: break-word; /* Break long words */
+            font-size: 1.05em;
+            color: #333;
+        }
+        .error {
+            color: #dc3545;
+            font-weight: bold;
+        }
+    </style>
+</head>
+<body>
+    <div class="container">
+        <h1>Audio Transcription</h1>
+        <form id="uploadForm">
+            <div class="form-group">
+                <label for="audioFile">Select an audio or video file:</label>
+                <input type="file" id="audioFile" name="file" accept="audio/*,video/*">
+            </div>
+            <button type="submit" id="submitButton">Transcribe Audio</button>
+        </form>
+
+        <div id="loading">Processing... Please wait, this might take a moment.</div>
+
+        <div id="response-card">
+            <strong>Transcription Output:</strong>
+            <span id="transcriptionOutput"></span>
+        </div>
+    </div>
+
+    <script>
+        const uploadForm = document.getElementById('uploadForm');
+        const audioFile = document.getElementById('audioFile');
+        const loadingDiv = document.getElementById('loading');
+        const transcriptionOutput = document.getElementById('transcriptionOutput');
+        const submitButton = document.getElementById('submitButton');
+
+        uploadForm.addEventListener('submit', async (event) => {
+            event.preventDefault(); // Prevent default form submission
+
+            transcriptionOutput.textContent = ''; // Clear previous output
+            transcriptionOutput.classList.remove('error'); // Remove error styling
+            loadingDiv.style.display = 'block'; // Show loading text
+            submitButton.disabled = true; // Disable button during processing
+
+            const file = audioFile.files[0];
+            if (!file) {
+                transcriptionOutput.textContent = 'Please select an audio or video file.';
+                transcriptionOutput.classList.add('error');
+                loadingDiv.style.display = 'none';
+                submitButton.disabled = false;
+                return;
+            }
+
+            const formData = new FormData();
+            formData.append('file', file); // 'file' must match the parameter name in your FastAPI endpoint
+
+            try {
+                // Use a relative path to the API endpoint
+                const response = await fetch('/transcribefile/', {
+                    method: 'POST',
+                    body: formData,
+                    // fetch will automatically set the 'Content-Type' header correctly for FormData
+                });
+
+                if (response.ok) { // Check if HTTP status is 2xx (e.g., 200 OK)
+                    const data = await response.json();
+                    transcriptionOutput.textContent = data.rnnt_transcription || 'No transcription found.';
+                } else {
+                    // Handle API errors (e.g., 400 Bad Request, 500 Internal Server Error)
+                    let errorMessage = `Error: ${response.status} - ${response.statusText}`;
+                    const rawText = await response.text(); // Read the body once; json() followed by text() on the same response would fail
+                    try {
+                        const errorData = JSON.parse(rawText); // FastAPI often returns JSON for errors
+                        if (errorData.detail) {
+                            errorMessage = `Error: ${response.status} - ${errorData.detail}`;
+                        } else {
+                            errorMessage = `Error: ${response.status} - ${JSON.stringify(errorData)}`;
+                        }
+                    } catch (e) {
+                        // If the response is not JSON, fall back to the raw text
+                        errorMessage = `Error: ${response.status} - ${rawText.substring(0, 200)}...`; // Limit length
+                    }
+                    transcriptionOutput.textContent = errorMessage;
+                    transcriptionOutput.classList.add('error');
+                    console.error('API Error:', errorMessage);
+                }
+            } catch (error) {
+                // Handle network errors (e.g., server unreachable)
+                transcriptionOutput.textContent = `Network error: ${error.message}. Please check your connection or try again.`;
+                transcriptionOutput.classList.add('error');
+                console.error('Fetch error:', error);
+            } finally {
+                loadingDiv.style.display = 'none'; // Hide loading text
+                submitButton.disabled = false; // Re-enable button
+            }
+        });
+    </script>
+</body>
+</html>
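To verify the new web UI routes without loading the ASR model, here is a minimal smoke-test sketch using FastAPI's TestClient. It assumes main.py is importable as main, the static/ directory is present in the working directory, and httpx is installed; none of this is part of the commit.

from fastapi.testclient import TestClient

from main import app

# Requests made without entering the client as a context manager do not fire
# startup events, so the heavy model load in startup_event() is skipped here.
client = TestClient(app)

def test_root_serves_index():
    resp = client.get("/")
    assert resp.status_code == 200
    assert "Audio Transcription" in resp.text  # <h1> text from static/index.html

def test_static_mount_serves_file():
    resp = client.get("/static/index.html")
    assert resp.status_code == 200

if __name__ == "__main__":
    test_root_serves_index()
    test_static_mount_serves_file()
    print("Web UI routes OK")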