import time
from typing import Optional, Tuple
from app.interfaces import IStreamingSpeechEngine
import numpy as np
import torch
import gc
from omegaconf import OmegaConf
from nemo.collections.asr.models.aed_multitask_models import lens_to_mask
from nemo.collections.asr.parts.submodules.aed_decoding import (
GreedyBatchedStreamingAEDComputer,
return_decoder_input_ids,
)
from nemo.collections.asr.parts.submodules.multitask_decoding import (
AEDStreamingDecodingConfig,
MultiTaskDecodingConfig,
)
# from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer # Not used
from nemo.collections.asr.parts.utils.streaming_utils import (
ContextSize,
StreamingBatchedAudioBuffer,
)
import nemo.collections.asr as nemo_asr
from nemo.collections.asr.parts.utils.transcribe_utils import (
get_inference_device,
get_inference_dtype,
)
from app.logger_config import (
logger as logging,
DEBUG
)
from dataclasses import dataclass
@dataclass
class CanaryConfig:
chunk_secs: float = 1.0
left_context_secs: float = 20.0
right_context_secs: float = 0.5
cuda: Optional[bool] = None
allow_mps: bool = True
compute_dtype: Optional[str] = None
matmul_precision: str = "high"
    batch_size: int = 1
    decoding: Optional[dict] = None
streaming_policy: str = "alignatt"
alignatt_thr: float = 8.0
waitk_lagging: int = 2
exclude_sink_frames: int = 8
xatt_scores_layer: int = -2
max_tokens_per_alignatt_step: int = 30
max_generation_length: int = 512
use_avgpool_for_alignatt: bool = False
hallucinations_detector: bool = True
    prompt: Optional[dict] = None
pnc: str = "no"
task: str = "asr"
source_lang: str = "fr"
target_lang: str = "fr"
timestamps: bool = True
def __post_init__(self):
if self.decoding is None:
self.decoding = {
"streaming_policy": self.streaming_policy,
"alignatt_thr": self.alignatt_thr,
"waitk_lagging": self.waitk_lagging,
"exclude_sink_frames": self.exclude_sink_frames,
"xatt_scores_layer": self.xatt_scores_layer,
"max_tokens_per_alignatt_step": self.max_tokens_per_alignatt_step,
"max_generation_length": self.max_generation_length,
"use_avgpool_for_alignatt": self.use_avgpool_for_alignatt,
"hallucinations_detector": self.hallucinations_detector
}
if self.prompt is None:
self.prompt = {
"pnc": self.pnc,
"task": self.task,
"source_lang": self.source_lang,
"target_lang": self.target_lang,
"timestamps": self.timestamps
}
def toOmegaConf(self) -> OmegaConf:
"""Convert the config to OmegaConf format"""
config_dict = {
"chunk_secs": self.chunk_secs,
"left_context_secs": self.left_context_secs,
"right_context_secs": self.right_context_secs,
"cuda": self.cuda,
"allow_mps": self.allow_mps,
"compute_dtype": self.compute_dtype,
"matmul_precision": self.matmul_precision,
"batch_size": self.batch_size,
"decoding": self.decoding,
"prompt": self.prompt
}
# Remove None values
filtered_dict = {k: v for k, v in config_dict.items() if v is not None}
return OmegaConf.create(filtered_dict)
@classmethod
def from_params(
cls,
task_type: str,
source_lang: str,
target_lang: str,
chunk_secs: float = 1.0,
left_context_secs: float = 20.0,
right_context_secs: float = 0.5,
streaming_policy: str = "alignatt",
alignatt_thr: float = 8.0,
waitk_lagging: int = 2,
exclude_sink_frames: int = 8,
xatt_scores_layer: int = -2,
hallucinations_detector: bool = True
):
"""Create a CanaryConfig instance from parameters"""
# Convert task type to model task
task = "asr" if task_type == "Transcription" else "ast"
target_lang = source_lang if task_type == "Transcription" else target_lang
return cls(
chunk_secs=chunk_secs,
left_context_secs=left_context_secs,
right_context_secs=right_context_secs,
streaming_policy=streaming_policy,
alignatt_thr=alignatt_thr,
waitk_lagging=waitk_lagging,
exclude_sink_frames=exclude_sink_frames,
xatt_scores_layer=xatt_scores_layer,
hallucinations_detector=hallucinations_detector,
task=task,
source_lang=source_lang,
target_lang=target_lang
)
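

# Illustrative sketch (not part of the engine): how a caller might build a streaming
# config. The task_type string "Transcription" and the language codes come from
# CanaryConfig.from_params above; everything else uses the dataclass defaults.
def _example_build_streaming_config() -> OmegaConf:
    """Builds a French ASR streaming config and converts it to OmegaConf."""
    cfg = CanaryConfig.from_params(
        task_type="Transcription",  # mapped to task="asr" inside from_params
        source_lang="fr",
        target_lang="fr",           # ignored for "Transcription": forced to source_lang
        chunk_secs=1.0,             # audio fed to the encoder once per second
        left_context_secs=20.0,     # history kept on the encoder side
        right_context_secs=0.5,     # lookahead waited for before decoding a chunk
    )
    return cfg.toOmegaConf()
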
def make_divisible_by(num: int, factor: int) -> int:
    """Round num down to the nearest multiple of factor."""
    return (num // factor) * factor
class CanarySpeechEngine(IStreamingSpeechEngine):
"""
Encapsulates the state and logic for streaming audio transcription
using an internally loaded Canary model.
"""
    def __init__(self, asr_model, cfg: CanaryConfig):
        """
        Initializes the speech engine around an already-instantiated ASR model.

        Args:
            asr_model: A loaded NeMo Canary (AED multitask) model instance.
            cfg: A CanaryConfig describing chunking, context and decoding options.
        """
logging.debug(f"Initializing CanarySpeechEngine with config: {cfg}")
self.cfg = cfg.toOmegaConf() # Store the full config
# Setup device and dtype from config
        self.map_location = get_inference_device(cuda=self.cfg.get("cuda"), allow_mps=self.cfg.allow_mps)
        self.compute_dtype = get_inference_dtype(self.cfg.get("compute_dtype"), device=self.map_location)
logging.info(f"Inference will be on device: {self.map_location} with dtype: {self.compute_dtype}")
        # Configure the provided model for inference (device placement + decoding strategy)
        asr_model, _ = self._setup_model(asr_model, self.cfg, self.map_location)
        self.asr_model = asr_model
self.full_transcription = [] # Stores finalized segments
self._setup_streaming_params()
# The initial full reset (buffer + decoder)
self.reset()
logging.info("CanarySpeechEngine initialized and ready.")
logging.info(f"Model-adjusted chunk size: {self.context_samples.chunk} samples.")
    def _setup_model(self, asr_model, model_cfg: OmegaConf, map_location: str):
        """Moves the provided ASR model to the inference device and configures it for greedy streaming decoding."""
        logging.info("Loading model ...")
start_time = time.time()
try:
asr_model = asr_model.to(map_location)
asr_model.eval()
# Change decoding strategy to greedy for streaming
if hasattr(asr_model, 'change_decoding_strategy'):
multitask_decoding = MultiTaskDecodingConfig()
multitask_decoding.strategy = "greedy"
asr_model.change_decoding_strategy(multitask_decoding)
logging.info("Model decoding strategy set to 'greedy'.")
if map_location == "cuda":
torch.cuda.synchronize()
end_time = time.time()
logging.info("Model loaded successfully.")
load_time = end_time - start_time
logging.info("\n" + "="*30)
logging.info(f"Total model load time: {load_time:.2f} seconds")
logging.info("="*30)
return asr_model, None
        except Exception as e:
            logging.error(f"Error loading model: {e}")
            logging.error('Ensure NeMo is installed (pip install "nemo_toolkit[asr]")')
            raise
def _setup_streaming_params(self):
"""Helper to calculate model-specific streaming parameters."""
model_cfg = self.asr_model.cfg
audio_sample_rate = model_cfg.preprocessor['sample_rate']
self.feature_stride_sec = model_cfg.preprocessor['window_stride']
features_per_sec = 1.0 / self.feature_stride_sec
self.encoder_subsampling_factor = self.asr_model.encoder.subsampling_factor
self.features_frame2audio_samples = make_divisible_by(
int(audio_sample_rate * self.feature_stride_sec ), factor=self.encoder_subsampling_factor
)
encoder_frame2audio_samples = self.features_frame2audio_samples * self.encoder_subsampling_factor
        # Streaming parameters are carried directly on self.cfg
        streaming_cfg = self.cfg
self.context_encoder_frames = ContextSize(
left=int(streaming_cfg.left_context_secs * features_per_sec / self.encoder_subsampling_factor),
chunk=int(streaming_cfg.chunk_secs * features_per_sec / self.encoder_subsampling_factor),
right=int(streaming_cfg.right_context_secs * features_per_sec / self.encoder_subsampling_factor),
)
self.context_samples = ContextSize(
left=self.context_encoder_frames.left * encoder_frame2audio_samples,
chunk=self.context_encoder_frames.chunk * encoder_frame2audio_samples,
right=self.context_encoder_frames.right * encoder_frame2audio_samples,
)
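
    # Worked example of the arithmetic above (numbers are assumptions based on a typical
    # Canary/FastConformer preprocessor: 16 kHz audio, 0.01 s window stride, 8x encoder
    # subsampling; the real values come from self.asr_model.cfg):
    #   features_per_sec             = 1 / 0.01 = 100
    #   features_frame2audio_samples = make_divisible_by(int(16000 * 0.01), 8) = 160
    #   encoder_frame2audio_samples  = 160 * 8 = 1280
    #   chunk_secs = 1.0  -> context_encoder_frames.chunk = int(1.0 * 100 / 8) = 12
    #                     -> context_samples.chunk = 12 * 1280 = 15360 samples (~0.96 s)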
def _reset_decoder_state(self):
"""
Resets ONLY the decoder state, preserving the audio buffer.
This prevents slowdowns on long audio streams.
"""
start_time = time.perf_counter()
logging.debug("--- Resetting decoder state (audio buffer preserved) ---")
# Reset tracking for this segment
self.last_transcription = ""
self.chunk_count = 0
batch_size = 1 # Hardcoded for this script
        # Streaming parameters are carried directly on self.cfg
        streaming_cfg = self.cfg
# 1. Recreate the initial prompt for the decoder
self.decoder_input_ids = return_decoder_input_ids(streaming_cfg, self.asr_model)
# 2. Recreate the "computer" object that manages decoding
self.decoding_computer = GreedyBatchedStreamingAEDComputer(
self.asr_model,
frame_chunk_size=self.context_encoder_frames.chunk,
decoding_cfg=streaming_cfg.decoding,
)
# 3. Recreate an EMPTY STATE object (model_state)
self.model_state = GreedyBatchedStreamingAEDComputer.initialize_aed_model_state(
asr_model=self.asr_model,
decoder_input_ids=self.decoder_input_ids,
batch_size=batch_size,
context_encoder_frames=self.context_encoder_frames,
chunk_secs=streaming_cfg.chunk_secs,
right_context_secs=streaming_cfg.right_context_secs,
)
# Clear CUDA cache if possible
if torch.cuda.is_available():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
end_time = time.perf_counter()
duration_ms = (end_time - start_time) * 1000 # Convert to milliseconds
logging.debug(f"--- Decoder reset finished in {duration_ms:.2f} ms ---")
def reset(self):
"""
Resets the transcriber's state completely (audio buffer + decoder state).
Called only on initialization.
"""
start_time = time.perf_counter()
logging.debug("--- FULL RESET (Audio Buffer + Decoder State) ---")
# Operation 1: Reset the decoder (this now includes GC)
self._reset_decoder_state()
# Operation 2: Reset the audio buffer
self.buffer = StreamingBatchedAudioBuffer(
batch_size=1, # Hardcoded for this script
context_samples=self.context_samples,
dtype=torch.float32,
device=self.map_location,
)
end_time = time.perf_counter()
        duration_ms = (end_time - start_time) * 1000
logging.debug(f"--- RESET Complete: took {duration_ms:.2f} ms ---")
def transcribe_chunk(self, chunk: np.ndarray, is_last_chunk: bool = False) -> Tuple[str, str]:
"""
Processes a single audio chunk and returns the newly predicted text.
Returns:
Tuple[str, str]:
(current_transcription: The full transcription for the current segment,
new_text: The newly appended text since the last chunk)
"""
start_time = time.perf_counter()
self.chunk_count += 1
        # Preprocess audio: int16 PCM samples -> normalized float32 tensor
        signal = torch.from_numpy(chunk.astype(np.float32) / 32768.0)
audio_batch = signal.unsqueeze(0).to(self.map_location)
audio_batch_lengths = torch.tensor([signal.shape[0]], device=self.map_location)
# 1. Add the chunk to the persistent buffer
self.buffer.add_audio_batch_(
audio_batch,
audio_lengths=audio_batch_lengths,
is_last_chunk=is_last_chunk,
is_last_chunk_batch=torch.tensor([is_last_chunk], device=self.map_location)
)
self.model_state.is_last_chunk_batch = torch.tensor([is_last_chunk], device=self.map_location)
# 2. Pass the buffer to the encoder
_, encoded_len, enc_states, _ = self.asr_model(
input_signal=self.buffer.samples, input_signal_length=self.buffer.context_size_batch.total()
)
encoder_context_batch = self.buffer.context_size_batch.subsample(factor=self.features_frame2audio_samples * self.encoder_subsampling_factor)
encoded_len_no_rc = encoder_context_batch.left + encoder_context_batch.chunk
encoded_length_corrected = torch.where(self.model_state.is_last_chunk_batch, encoded_len, encoded_len_no_rc)
encoder_input_mask = lens_to_mask(encoded_length_corrected, enc_states.shape[1]).to(enc_states.dtype)
# 3. Pass to the decoding computer
self.model_state = self.decoding_computer(
encoder_output=enc_states,
encoder_output_len=encoded_length_corrected,
encoder_input_mask=encoder_input_mask,
prev_batched_state=self.model_state,
)
# 4. Calculate the new text
current_tokens = self.model_state.pred_tokens_ids[0, self.decoder_input_ids.size(-1): self.model_state.current_context_lengths[0]]
# OPTIMIZATION: Move tokens to CPU before converting to list
current_transcription = self.asr_model.tokenizer.ids_to_text(current_tokens.cpu().tolist()).strip()
# Calculate the NEW text by "subtracting" the old history
new_text = ""
if current_transcription.startswith(self.last_transcription):
new_text = current_transcription[len(self.last_transcription):]
else:
# Model corrected itself, send the full new transcription
new_text = current_transcription
# Memorize the FULL current transcription as the new history
if new_text:
self.last_transcription = current_transcription
end_time = time.perf_counter()
duration_ms = (end_time - start_time) * 1000
# logging.info(f"--- transcribe_chunk: took {duration_ms:.2f} ms ---")
# Return both the full segment transcription and the new diff
        return current_transcription, new_text
def finalize_segment(self):
"""
Finalizes the current transcription segment (e.g., on silence)
and adds it to the full history.
"""
if self.last_transcription:
self.full_transcription.append(self.last_transcription)
self.last_transcription = ""
# We must reset the decoder state to start a new segment
self._reset_decoder_state()
def get_full_transcription(self) -> str:
"""
Returns the full accumulated transcription from all finalized segments.
Does NOT include the currently active (unfinalized) segment.
"""
return " ".join(self.full_transcription)
def get_current_segment_text(self) -> str:
"""Returns the text of the segment currently being transcribed."""
return self.last_transcription
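

# Illustrative usage sketch (not exercised by the app itself): feeding an int16 mono
# waveform through the engine chunk by chunk. The checkpoint name "nvidia/canary-1b-flash"
# and the use of nemo_asr.models.ASRModel.from_pretrained are assumptions about how a
# caller might obtain the model; the surrounding application may load it differently.
def _example_streaming_loop(audio: np.ndarray) -> str:
    """Transcribes a mono int16 waveform (at the model's sample rate) and returns the text."""
    asr_model = nemo_asr.models.ASRModel.from_pretrained("nvidia/canary-1b-flash")  # assumed checkpoint
    cfg = CanaryConfig.from_params(task_type="Transcription", source_lang="fr", target_lang="fr")
    engine = CanarySpeechEngine(asr_model, cfg)

    chunk_samples = engine.context_samples.chunk  # model-adjusted chunk size in samples
    for start in range(0, len(audio), chunk_samples):
        chunk = audio[start:start + chunk_samples]
        is_last = start + chunk_samples >= len(audio)
        _, new_text = engine.transcribe_chunk(chunk, is_last_chunk=is_last)
        if new_text:
            print(new_text, end="", flush=True)

    engine.finalize_segment()  # push the active segment into the accumulated history
    return engine.get_full_transcription()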