import gc
import time
from dataclasses import dataclass
from typing import Literal, Optional, Tuple

import numpy as np
import torch
from omegaconf import OmegaConf

import nemo.collections.asr as nemo_asr
from nemo.collections.asr.models.aed_multitask_models import lens_to_mask
from nemo.collections.asr.parts.submodules.aed_decoding import (
    GreedyBatchedStreamingAEDComputer,
    return_decoder_input_ids,
)
from nemo.collections.asr.parts.submodules.multitask_decoding import (
    AEDStreamingDecodingConfig,
    MultiTaskDecodingConfig,
)
# from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer  # Not used
from nemo.collections.asr.parts.utils.streaming_utils import (
    ContextSize,
    StreamingBatchedAudioBuffer,
)
from nemo.collections.asr.parts.utils.transcribe_utils import (
    get_inference_device,
    get_inference_dtype,
)

from app.interfaces import IStreamingSpeechEngine
from app.logger_config import (
    logger as logging,
    DEBUG,
)

@dataclass
class CanaryConfig:
    chunk_secs: float = 1.0
    left_context_secs: float = 20.0
    right_context_secs: float = 0.5
    cuda: Optional[bool] = None
    allow_mps: bool = True
    compute_dtype: Optional[str] = None
    matmul_precision: str = "high"
    batch_size: int = 1
    decoding: Optional[dict] = None
    streaming_policy: str = "alignatt"
    alignatt_thr: float = 8.0
    waitk_lagging: int = 2
    exclude_sink_frames: int = 8
    xatt_scores_layer: int = -2
    max_tokens_per_alignatt_step: int = 30
    max_generation_length: int = 512
    use_avgpool_for_alignatt: bool = False
    hallucinations_detector: bool = True
    prompt: Optional[dict] = None
    pnc: str = "no"
    task: str = "asr"
    source_lang: str = "fr"
    target_lang: str = "fr"
    timestamps: bool = True

    def __post_init__(self):
        if self.decoding is None:
            self.decoding = {
                "streaming_policy": self.streaming_policy,
                "alignatt_thr": self.alignatt_thr,
                "waitk_lagging": self.waitk_lagging,
                "exclude_sink_frames": self.exclude_sink_frames,
                "xatt_scores_layer": self.xatt_scores_layer,
                "max_tokens_per_alignatt_step": self.max_tokens_per_alignatt_step,
                "max_generation_length": self.max_generation_length,
                "use_avgpool_for_alignatt": self.use_avgpool_for_alignatt,
                "hallucinations_detector": self.hallucinations_detector,
            }
        if self.prompt is None:
            self.prompt = {
                "pnc": self.pnc,
                "task": self.task,
                "source_lang": self.source_lang,
                "target_lang": self.target_lang,
                "timestamps": self.timestamps,
            }

    def toOmegaConf(self) -> OmegaConf:
        """Convert the config to OmegaConf format."""
        config_dict = {
            "chunk_secs": self.chunk_secs,
            "left_context_secs": self.left_context_secs,
            "right_context_secs": self.right_context_secs,
            "cuda": self.cuda,
            "allow_mps": self.allow_mps,
            "compute_dtype": self.compute_dtype,
            "matmul_precision": self.matmul_precision,
            "batch_size": self.batch_size,
            "decoding": self.decoding,
            "prompt": self.prompt,
        }
        # Remove None values
        filtered_dict = {k: v for k, v in config_dict.items() if v is not None}
        return OmegaConf.create(filtered_dict)
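
    # Illustrative note (not part of the original module): OmegaConf.create on the
    # dict above yields a DictConfig, so downstream code can read nested keys with
    # attribute access, e.g.:
    #
    #   cfg = CanaryConfig().toOmegaConf()
    #   cfg.decoding.streaming_policy   # "alignatt"
    #   cfg.prompt.source_lang          # "fr"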

    @classmethod
    def from_params(
        cls,
        task_type: str,
        source_lang: str,
        target_lang: str,
        chunk_secs: float = 1.0,
        left_context_secs: float = 20.0,
        right_context_secs: float = 0.5,
        streaming_policy: str = "alignatt",
        alignatt_thr: float = 8.0,
        waitk_lagging: int = 2,
        exclude_sink_frames: int = 8,
        xatt_scores_layer: int = -2,
        hallucinations_detector: bool = True,
    ):
        """Create a CanaryConfig instance from parameters."""
        # Convert the task type to the model task; transcription keeps the source language
        task = "asr" if task_type == "Transcription" else "ast"
        target_lang = source_lang if task_type == "Transcription" else target_lang
        return cls(
            chunk_secs=chunk_secs,
            left_context_secs=left_context_secs,
            right_context_secs=right_context_secs,
            streaming_policy=streaming_policy,
            alignatt_thr=alignatt_thr,
            waitk_lagging=waitk_lagging,
            exclude_sink_frames=exclude_sink_frames,
            xatt_scores_layer=xatt_scores_layer,
            hallucinations_detector=hallucinations_detector,
            task=task,
            source_lang=source_lang,
            target_lang=target_lang,
        )
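
# Usage sketch (assumption, not in the original module): any task_type other than
# "Transcription" is mapped to the "ast" (translation) task and keeps the requested
# target language.
#
#   cfg = CanaryConfig.from_params(task_type="Translation", source_lang="fr", target_lang="en")
#   assert cfg.task == "ast" and cfg.target_lang == "en"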

def make_divisible_by(num: int, factor: int) -> int:
    """Round num down to the nearest multiple of factor."""
    return (num // factor) * factor
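
# Worked example (illustrative): make_divisible_by(165, 8) == 160, i.e. the largest
# multiple of 8 that does not exceed 165. This is used below to align the
# feature-frame-to-sample ratio with the encoder subsampling factor.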

class CanarySpeechEngine(IStreamingSpeechEngine):
    """
    Encapsulates the state and logic for streaming audio transcription
    using an externally provided Canary model.
    """

    def __init__(self, asr_model, cfg: CanaryConfig):
        """
        Initializes the speech engine around an already loaded ASR model.

        Args:
            asr_model: A loaded NeMo Canary (AED multitask) model.
            cfg: A CanaryConfig holding streaming, decoding and prompt parameters.
        """
        logging.debug(f"Initializing CanarySpeechEngine with config: {cfg}")
        self.cfg = cfg.toOmegaConf()  # Store the full config

        # Setup device and dtype from config
        self.map_location = get_inference_device(cuda=None, allow_mps=self.cfg.allow_mps)
        self.compute_dtype = get_inference_dtype(None, device=self.map_location)
        logging.info(f"Inference will be on device: {self.map_location} with dtype: {self.compute_dtype}")

        # Move the provided model to the inference device and configure it
        asr_model, _ = self._setup_model(asr_model, self.cfg, self.map_location)
        self.asr_model = asr_model

        self.full_transcription = []  # Stores finalized segments
        self._setup_streaming_params()

        # The initial full reset (buffer + decoder)
        self.reset()
        logging.info("CanarySpeechEngine initialized and ready.")
        logging.info(f"Model-adjusted chunk size: {self.context_samples.chunk} samples.")

    def _setup_model(self, asr_model, model_cfg: OmegaConf, map_location: str):
        """Moves the ASR model to the inference device and configures it for streaming."""
        logging.info("Preparing model for inference ...")
        start_time = time.time()
        try:
            asr_model = asr_model.to(map_location)
            asr_model.eval()

            # Change decoding strategy to greedy for streaming
            if hasattr(asr_model, 'change_decoding_strategy'):
                multitask_decoding = MultiTaskDecodingConfig()
                multitask_decoding.strategy = "greedy"
                asr_model.change_decoding_strategy(multitask_decoding)
                logging.info("Model decoding strategy set to 'greedy'.")

            if map_location == "cuda":
                torch.cuda.synchronize()
            end_time = time.time()
            logging.info("Model prepared successfully.")
            setup_time = end_time - start_time
            logging.info("\n" + "=" * 30)
            logging.info(f"Total model setup time: {setup_time:.2f} seconds")
            logging.info("=" * 30)
            return asr_model, None
        except Exception as e:
            logging.error(f"Error preparing model: {e}")
            logging.error("Ensure NeMo is installed (pip install nemo_toolkit['asr'])")
            return None, None

    def _setup_streaming_params(self):
        """Helper to calculate model-specific streaming parameters."""
        model_cfg = self.asr_model.cfg
        audio_sample_rate = model_cfg.preprocessor['sample_rate']
        self.feature_stride_sec = model_cfg.preprocessor['window_stride']
        features_per_sec = 1.0 / self.feature_stride_sec
        self.encoder_subsampling_factor = self.asr_model.encoder.subsampling_factor
        self.features_frame2audio_samples = make_divisible_by(
            int(audio_sample_rate * self.feature_stride_sec), factor=self.encoder_subsampling_factor
        )
        encoder_frame2audio_samples = self.features_frame2audio_samples * self.encoder_subsampling_factor

        # Streaming parameters live directly on self.cfg
        streaming_cfg = self.cfg
        self.context_encoder_frames = ContextSize(
            left=int(streaming_cfg.left_context_secs * features_per_sec / self.encoder_subsampling_factor),
            chunk=int(streaming_cfg.chunk_secs * features_per_sec / self.encoder_subsampling_factor),
            right=int(streaming_cfg.right_context_secs * features_per_sec / self.encoder_subsampling_factor),
        )
        self.context_samples = ContextSize(
            left=self.context_encoder_frames.left * encoder_frame2audio_samples,
            chunk=self.context_encoder_frames.chunk * encoder_frame2audio_samples,
            right=self.context_encoder_frames.right * encoder_frame2audio_samples,
        )
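
    # Worked example (illustrative numbers, not asserted by this module): for a
    # 16 kHz model with a 10 ms window stride and 8x encoder subsampling,
    # features_frame2audio_samples = make_divisible_by(160, 8) = 160 and one
    # encoder frame covers 160 * 8 = 1280 samples. A 1.0 s chunk then maps to
    # int(1.0 * 100 / 8) = 12 encoder frames = 15360 samples, and the 20 s left
    # context to int(20.0 * 100 / 8) = 250 frames.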

    def _reset_decoder_state(self):
        """
        Resets ONLY the decoder state, preserving the audio buffer.
        This prevents slowdowns on long audio streams.
        """
        start_time = time.perf_counter()
        logging.debug("--- Resetting decoder state (audio buffer preserved) ---")

        # Reset tracking for this segment
        self.last_transcription = ""
        self.chunk_count = 0
        batch_size = 1  # The engine processes a single stream

        # Streaming parameters live directly on self.cfg
        streaming_cfg = self.cfg

        # 1. Recreate the initial prompt for the decoder
        self.decoder_input_ids = return_decoder_input_ids(streaming_cfg, self.asr_model)

        # 2. Recreate the "computer" object that manages decoding
        self.decoding_computer = GreedyBatchedStreamingAEDComputer(
            self.asr_model,
            frame_chunk_size=self.context_encoder_frames.chunk,
            decoding_cfg=streaming_cfg.decoding,
        )

        # 3. Recreate an EMPTY state object (model_state)
        self.model_state = GreedyBatchedStreamingAEDComputer.initialize_aed_model_state(
            asr_model=self.asr_model,
            decoder_input_ids=self.decoder_input_ids,
            batch_size=batch_size,
            context_encoder_frames=self.context_encoder_frames,
            chunk_secs=streaming_cfg.chunk_secs,
            right_context_secs=streaming_cfg.right_context_secs,
        )

        # Clear CUDA cache if possible
        if torch.cuda.is_available():
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        end_time = time.perf_counter()
        duration_ms = (end_time - start_time) * 1000  # Convert to milliseconds
        logging.debug(f"--- Decoder reset finished in {duration_ms:.2f} ms ---")

    def reset(self):
        """
        Resets the transcriber's state completely (audio buffer + decoder state).
        Called only on initialization.
        """
        start_time = time.perf_counter()
        logging.debug("--- FULL RESET (Audio Buffer + Decoder State) ---")

        # Operation 1: Reset the decoder (this now includes GC)
        self._reset_decoder_state()

        # Operation 2: Reset the audio buffer
        self.buffer = StreamingBatchedAudioBuffer(
            batch_size=1,  # The engine processes a single stream
            context_samples=self.context_samples,
            dtype=torch.float32,
            device=self.map_location,
        )

        end_time = time.perf_counter()
        duration_ms = (end_time - start_time) * 1000
        logging.debug(f"--- RESET Complete: took {duration_ms:.2f} ms ---")

    def transcribe_chunk(self, chunk: np.ndarray, is_last_chunk: bool = False) -> Tuple[str, str]:
        """
        Processes a single audio chunk and returns the newly predicted text.

        Returns:
            Tuple[str, str]:
                (current_transcription: the full transcription for the current segment,
                 new_text: the newly appended text since the last chunk)
        """
        start_time = time.perf_counter()
        self.chunk_count += 1

        # Preprocess audio: int16 PCM -> float32 in [-1, 1]
        signal = torch.from_numpy(chunk.astype(np.float32) / 32768.0)
        audio_batch = signal.unsqueeze(0).to(self.map_location)
        audio_batch_lengths = torch.tensor([signal.shape[0]], device=self.map_location)

        # 1. Add the chunk to the persistent buffer
        self.buffer.add_audio_batch_(
            audio_batch,
            audio_lengths=audio_batch_lengths,
            is_last_chunk=is_last_chunk,
            is_last_chunk_batch=torch.tensor([is_last_chunk], device=self.map_location),
        )
        self.model_state.is_last_chunk_batch = torch.tensor([is_last_chunk], device=self.map_location)

        # 2. Pass the buffer to the encoder
        _, encoded_len, enc_states, _ = self.asr_model(
            input_signal=self.buffer.samples, input_signal_length=self.buffer.context_size_batch.total()
        )
        encoder_context_batch = self.buffer.context_size_batch.subsample(
            factor=self.features_frame2audio_samples * self.encoder_subsampling_factor
        )
        encoded_len_no_rc = encoder_context_batch.left + encoder_context_batch.chunk
        encoded_length_corrected = torch.where(self.model_state.is_last_chunk_batch, encoded_len, encoded_len_no_rc)
        encoder_input_mask = lens_to_mask(encoded_length_corrected, enc_states.shape[1]).to(enc_states.dtype)

        # 3. Pass to the decoding computer
        self.model_state = self.decoding_computer(
            encoder_output=enc_states,
            encoder_output_len=encoded_length_corrected,
            encoder_input_mask=encoder_input_mask,
            prev_batched_state=self.model_state,
        )

        # 4. Calculate the new text
        current_tokens = self.model_state.pred_tokens_ids[
            0, self.decoder_input_ids.size(-1): self.model_state.current_context_lengths[0]
        ]
        # OPTIMIZATION: Move tokens to CPU before converting to list
        current_transcription = self.asr_model.tokenizer.ids_to_text(current_tokens.cpu().tolist()).strip()

        # Calculate the NEW text by "subtracting" the old history
        new_text = ""
        if current_transcription.startswith(self.last_transcription):
            new_text = current_transcription[len(self.last_transcription):]
        else:
            # Model corrected itself, send the full new transcription
            new_text = current_transcription

        # Memorize the FULL current transcription as the new history
        if new_text:
            self.last_transcription = current_transcription

        end_time = time.perf_counter()
        duration_ms = (end_time - start_time) * 1000
        # logging.info(f"--- transcribe_chunk: took {duration_ms:.2f} ms ---")

        # Return both the full segment transcription and the new diff
        return current_transcription, new_text

    def finalize_segment(self):
        """
        Finalizes the current transcription segment (e.g., on silence)
        and adds it to the full history.
        """
        if self.last_transcription:
            self.full_transcription.append(self.last_transcription)
            self.last_transcription = ""
        # We must reset the decoder state to start a new segment
        self._reset_decoder_state()

    def get_full_transcription(self) -> str:
        """
        Returns the full accumulated transcription from all finalized segments.
        Does NOT include the currently active (unfinalized) segment.
        """
        return " ".join(self.full_transcription)

    def get_current_segment_text(self) -> str:
        """Returns the text of the segment currently being transcribed."""
        return self.last_transcription
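
# Minimal end-to-end sketch (assumptions: the checkpoint name, the 16 kHz sample
# rate, and the silent input are illustrative and not part of this module; the model
# must be a Canary-style AED multitask model compatible with the streaming
# utilities imported above).
if __name__ == "__main__":
    model = nemo_asr.models.ASRModel.from_pretrained("nvidia/canary-1b-flash")  # illustrative checkpoint
    engine = CanarySpeechEngine(model, CanaryConfig.from_params("Transcription", "fr", "fr"))

    # Feed one second of int16 PCM silence, chunked to the model-adjusted size.
    pcm16 = np.zeros(16000, dtype=np.int16)
    step = engine.context_samples.chunk
    for start in range(0, len(pcm16), step):
        piece = pcm16[start:start + step]
        full_text, new_text = engine.transcribe_chunk(piece, is_last_chunk=start + step >= len(pcm16))
        print(new_text, end="", flush=True)

    engine.finalize_segment()
    print("\nFull transcription:", engine.get_full_transcription())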