# canary_aed_streaming/app/canary_speech_engine.py
import time
from typing import Optional, Tuple
from app.interfaces import IStreamingSpeechEngine
import numpy as np
import torch
import gc
from omegaconf import OmegaConf
from nemo.collections.asr.models.aed_multitask_models import lens_to_mask
from nemo.collections.asr.parts.submodules.aed_decoding import (
GreedyBatchedStreamingAEDComputer,
return_decoder_input_ids,
)
from nemo.collections.asr.parts.submodules.multitask_decoding import (
AEDStreamingDecodingConfig,
MultiTaskDecodingConfig,
)
# from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer # Not used
from nemo.collections.asr.parts.utils.streaming_utils import (
ContextSize,
StreamingBatchedAudioBuffer,
)
import nemo.collections.asr as nemo_asr
from nemo.collections.asr.parts.utils.transcribe_utils import (
get_inference_device,
get_inference_dtype,
)
from app.logger_config import (
logger as logging,
DEBUG
)
from dataclasses import dataclass
@dataclass
class CanaryConfig:
chunk_secs: float = 1.0
left_context_secs: float = 20.0
right_context_secs: float = 0.5
cuda: Optional[bool] = None
allow_mps: bool = True
compute_dtype: Optional[str] = None
matmul_precision: str = "high"
batch_size: int = 1
decoding: Optional[dict] = None
streaming_policy: str = "alignatt"
alignatt_thr: float = 8.0
waitk_lagging: int = 2
exclude_sink_frames: int = 8
xatt_scores_layer: int = -2
max_tokens_per_alignatt_step: int = 30
max_generation_length: int = 512
use_avgpool_for_alignatt: bool = False
hallucinations_detector: bool = True
prompt: Optional[dict] = None
pnc: str = "no"
task: str = "asr"
source_lang: str = "fr"
target_lang: str = "fr"
timestamps: bool = True
def __post_init__(self):
if self.decoding is None:
self.decoding = {
"streaming_policy": self.streaming_policy,
"alignatt_thr": self.alignatt_thr,
"waitk_lagging": self.waitk_lagging,
"exclude_sink_frames": self.exclude_sink_frames,
"xatt_scores_layer": self.xatt_scores_layer,
"max_tokens_per_alignatt_step": self.max_tokens_per_alignatt_step,
"max_generation_length": self.max_generation_length,
"use_avgpool_for_alignatt": self.use_avgpool_for_alignatt,
"hallucinations_detector": self.hallucinations_detector
}
if self.prompt is None:
self.prompt = {
"pnc": self.pnc,
"task": self.task,
"source_lang": self.source_lang,
"target_lang": self.target_lang,
"timestamps": self.timestamps
}
def toOmegaConf(self) -> OmegaConf:
"""Convert the config to OmegaConf format"""
config_dict = {
"chunk_secs": self.chunk_secs,
"left_context_secs": self.left_context_secs,
"right_context_secs": self.right_context_secs,
"cuda": self.cuda,
"allow_mps": self.allow_mps,
"compute_dtype": self.compute_dtype,
"matmul_precision": self.matmul_precision,
"batch_size": self.batch_size,
"decoding": self.decoding,
"prompt": self.prompt
}
# Remove None values
filtered_dict = {k: v for k, v in config_dict.items() if v is not None}
return OmegaConf.create(filtered_dict)
@classmethod
def from_params(
cls,
task_type: str,
source_lang: str,
target_lang: str,
chunk_secs: float = 1.0,
left_context_secs: float = 20.0,
right_context_secs: float = 0.5,
streaming_policy: str = "alignatt",
alignatt_thr: float = 8.0,
waitk_lagging: int = 2,
exclude_sink_frames: int = 8,
xatt_scores_layer: int = -2,
hallucinations_detector: bool = True
):
"""Create a CanaryConfig instance from parameters"""
# Convert task type to model task
task = "asr" if task_type == "Transcription" else "ast"
target_lang = source_lang if task_type == "Transcription" else target_lang
return cls(
chunk_secs=chunk_secs,
left_context_secs=left_context_secs,
right_context_secs=right_context_secs,
streaming_policy=streaming_policy,
alignatt_thr=alignatt_thr,
waitk_lagging=waitk_lagging,
exclude_sink_frames=exclude_sink_frames,
xatt_scores_layer=xatt_scores_layer,
hallucinations_detector=hallucinations_detector,
task=task,
source_lang=source_lang,
target_lang=target_lang
)
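# Illustrative usage sketch for CanaryConfig (the real call site lives elsewhere
# in the app; the parameter values below are placeholders):
#
#   cfg = CanaryConfig.from_params(
#       task_type="Transcription",  # maps to task="asr" and target_lang=source_lang
#       source_lang="fr",
#       target_lang="fr",
#       chunk_secs=1.0,
#   )
#   omega_cfg = cfg.toOmegaConf()  # None-valued keys (e.g. cuda) are dropped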
def make_divisible_by(num: int, factor: int) -> int:
"""Round num down to the nearest multiple of factor (e.g. 165 -> 160 for factor=8)."""
return (num // factor) * factor
class CanarySpeechEngine(IStreamingSpeechEngine):
"""
Encapsulates the state and logic for streaming audio transcription
using an internally loaded Canary model.
"""
def __init__(self, asr_model, cfg: CanaryConfig):
"""
Initializes the speech engine and prepares the provided ASR model for streaming.
Args:
asr_model: A pre-loaded NeMo Canary (AED multitask) model instance.
cfg: A CanaryConfig holding streaming, decoding and prompt settings.
"""
logging.debug(f"Initializing CanarySpeechEngine with config: {cfg}")
self.cfg = cfg.toOmegaConf() # Store the full config
# Setup device and dtype from config
self.map_location = get_inference_device(cuda=self.cfg.get("cuda", None), allow_mps=self.cfg.allow_mps)
self.compute_dtype = get_inference_dtype(self.cfg.get("compute_dtype", None), device=self.map_location)
logging.info(f"Inference will be on device: {self.map_location} with dtype: {self.compute_dtype}")
# Prepare the provided model for streaming inference
asr_model, _ = self._setup_model(asr_model, self.cfg, self.map_location)
self.asr_model = asr_model
self.full_transcription = [] # Stores finalized segments
self._setup_streaming_params()
# The initial full reset (buffer + decoder)
self.reset()
logging.info("CanarySpeechEngine initialized and ready.")
logging.info(f"Model-adjusted chunk size: {self.context_samples.chunk} samples.")
def _setup_model(self, asr_model, model_cfg: OmegaConf, map_location: str):
"""Moves the pre-loaded ASR model to the inference device and configures it for streaming decoding."""
logging.info("Loading model ...")
start_time = time.time()
try:
asr_model = asr_model.to(map_location)
asr_model.eval()
# Change decoding strategy to greedy for streaming
if hasattr(asr_model, 'change_decoding_strategy'):
multitask_decoding = MultiTaskDecodingConfig()
multitask_decoding.strategy = "greedy"
asr_model.change_decoding_strategy(multitask_decoding)
logging.info("Model decoding strategy set to 'greedy'.")
if str(map_location).startswith("cuda"):
torch.cuda.synchronize()
end_time = time.time()
logging.info("Model loaded successfully.")
load_time = end_time - start_time
logging.info("\n" + "="*30)
logging.info(f"Total model load time: {load_time:.2f} seconds")
logging.info("="*30)
return asr_model, None
except Exception as e:
logging.error(f"Error loading model: {e}")
logging.error("Ensure NeMo is installed (pip install nemo_toolkit['asr'])")
raise
def _setup_streaming_params(self):
"""Helper to calculate model-specific streaming parameters."""
model_cfg = self.asr_model.cfg
audio_sample_rate = model_cfg.preprocessor['sample_rate']
self.feature_stride_sec = model_cfg.preprocessor['window_stride']
features_per_sec = 1.0 / self.feature_stride_sec
self.encoder_subsampling_factor = self.asr_model.encoder.subsampling_factor
self.features_frame2audio_samples = make_divisible_by(
int(audio_sample_rate * self.feature_stride_sec ), factor=self.encoder_subsampling_factor
)
encoder_frame2audio_samples = self.features_frame2audio_samples * self.encoder_subsampling_factor
# Streaming parameters live directly on self.cfg
streaming_cfg = self.cfg
self.context_encoder_frames = ContextSize(
left=int(streaming_cfg.left_context_secs * features_per_sec / self.encoder_subsampling_factor),
chunk=int(streaming_cfg.chunk_secs * features_per_sec / self.encoder_subsampling_factor),
right=int(streaming_cfg.right_context_secs * features_per_sec / self.encoder_subsampling_factor),
)
self.context_samples = ContextSize(
left=self.context_encoder_frames.left * encoder_frame2audio_samples,
chunk=self.context_encoder_frames.chunk * encoder_frame2audio_samples,
right=self.context_encoder_frames.right * encoder_frame2audio_samples,
)
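# Worked example of the sizes above, assuming typical Canary/FastConformer
# preprocessing (16 kHz audio, 0.01 s window stride, 8x encoder subsampling)
# and the defaults chunk_secs=1.0, left_context_secs=20.0, right_context_secs=0.5:
#   features_per_sec             = 100
#   features_frame2audio_samples = make_divisible_by(160, 8) = 160
#   encoder_frame2audio_samples  = 160 * 8 = 1280
#   context_encoder_frames       = ContextSize(left=250, chunk=12, right=6)
#   context_samples              = ContextSize(left=320000, chunk=15360, right=7680)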
def _reset_decoder_state(self):
"""
Resets ONLY the decoder state, preserving the audio buffer.
This prevents slowdowns on long audio streams.
"""
start_time = time.perf_counter()
logging.debug("--- Resetting decoder state (audio buffer preserved) ---")
# Reset tracking for this segment
self.last_transcription = ""
self.chunk_count = 0
batch_size = 1 # Single-stream inference only
# Streaming parameters live directly on self.cfg
streaming_cfg = self.cfg
# 1. Recreate the initial prompt for the decoder
self.decoder_input_ids = return_decoder_input_ids(streaming_cfg, self.asr_model)
# 2. Recreate the "computer" object that manages decoding
self.decoding_computer = GreedyBatchedStreamingAEDComputer(
self.asr_model,
frame_chunk_size=self.context_encoder_frames.chunk,
decoding_cfg=streaming_cfg.decoding,
)
# 3. Recreate an EMPTY STATE object (model_state)
self.model_state = GreedyBatchedStreamingAEDComputer.initialize_aed_model_state(
asr_model=self.asr_model,
decoder_input_ids=self.decoder_input_ids,
batch_size=batch_size,
context_encoder_frames=self.context_encoder_frames,
chunk_secs=streaming_cfg.chunk_secs,
right_context_secs=streaming_cfg.right_context_secs,
)
# Clear CUDA cache if possible
if torch.cuda.is_available():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
end_time = time.perf_counter()
duration_ms = (end_time - start_time) * 1000 # Convert to milliseconds
logging.debug(f"--- Decoder reset finished in {duration_ms:.2f} ms ---")
def reset(self):
"""
Resets the transcriber's state completely (audio buffer + decoder state).
Called only on initialization.
"""
start_time = time.perf_counter()
logging.debug("--- FULL RESET (Audio Buffer + Decoder State) ---")
# Operation 1: Reset the decoder (this now includes GC)
self._reset_decoder_state()
# Operation 2: Reset the audio buffer
self.buffer = StreamingBatchedAudioBuffer(
batch_size=1, # Hardcoded for this script
context_samples=self.context_samples,
dtype=torch.float32,
device=self.map_location,
)
end_time = time.perf_counter()
duration_ms = (end_time - start_time) * 1000
logging.debug(f"--- RESET Complete: took {duration_ms:.2f} ms ---")
def transcribe_chunk(self, chunk: np.ndarray, is_last_chunk: bool = False) -> Tuple[str, str]:
"""
Processes a single audio chunk and returns the newly predicted text.
Returns:
Tuple[str, str]:
(current_transcription: The full transcription for the current segment,
new_text: The newly appended text since the last chunk)
"""
start_time = time.perf_counter()
self.chunk_count += 1
# Preprocess audio
signal = torch.from_numpy(chunk.astype(np.float32) / 32768.0)
audio_batch = signal.unsqueeze(0).to(self.map_location)
audio_batch_lengths = torch.tensor([signal.shape[0]], device=self.map_location)
# 1. Add the chunk to the persistent buffer
self.buffer.add_audio_batch_(
audio_batch,
audio_lengths=audio_batch_lengths,
is_last_chunk=is_last_chunk,
is_last_chunk_batch=torch.tensor([is_last_chunk], device=self.map_location)
)
self.model_state.is_last_chunk_batch = torch.tensor([is_last_chunk], device=self.map_location)
# 2. Pass the buffer to the encoder
_, encoded_len, enc_states, _ = self.asr_model(
input_signal=self.buffer.samples, input_signal_length=self.buffer.context_size_batch.total()
)
encoder_context_batch = self.buffer.context_size_batch.subsample(factor=self.features_frame2audio_samples * self.encoder_subsampling_factor)
encoded_len_no_rc = encoder_context_batch.left + encoder_context_batch.chunk
encoded_length_corrected = torch.where(self.model_state.is_last_chunk_batch, encoded_len, encoded_len_no_rc)
encoder_input_mask = lens_to_mask(encoded_length_corrected, enc_states.shape[1]).to(enc_states.dtype)
# 3. Pass to the decoding computer
self.model_state = self.decoding_computer(
encoder_output=enc_states,
encoder_output_len=encoded_length_corrected,
encoder_input_mask=encoder_input_mask,
prev_batched_state=self.model_state,
)
# 4. Calculate the new text
current_tokens = self.model_state.pred_tokens_ids[0, self.decoder_input_ids.size(-1): self.model_state.current_context_lengths[0]]
# OPTIMIZATION: Move tokens to CPU before converting to list
current_transcription = self.asr_model.tokenizer.ids_to_text(current_tokens.cpu().tolist()).strip()
# Calculate the NEW text by "subtracting" the old history
new_text = ""
if current_transcription.startswith(self.last_transcription):
new_text = current_transcription[len(self.last_transcription):]
else:
# Model corrected itself, send the full new transcription
new_text = current_transcription
# Memorize the FULL current transcription as the new history
if new_text:
self.last_transcription = current_transcription
end_time = time.perf_counter()
duration_ms = (end_time - start_time) * 1000
# logging.info(f"--- transcribe_chunk: took {duration_ms:.2f} ms ---")
# Return both the full segment transcription and the new diff
return current_transcription, new_text
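# Illustrative example of the diff logic above:
#   last_transcription    = "bonjour tout"
#   current_transcription = "bonjour tout le monde"  -> new_text = " le monde"
# If the model revises earlier words (e.g. "bonjour a tous"), the prefix check
# fails and the full revised transcription is emitted as new_text instead.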
def finalize_segment(self):
"""
Finalizes the current transcription segment (e.g., on silence)
and adds it to the full history.
"""
if self.last_transcription:
self.full_transcription.append(self.last_transcription)
self.last_transcription = ""
# We must reset the decoder state to start a new segment
self._reset_decoder_state()
def get_full_transcription(self) -> str:
"""
Returns the full accumulated transcription from all finalized segments.
Does NOT include the currently active (unfinalized) segment.
"""
return " ".join(self.full_transcription)
def get_current_segment_text(self) -> str:
"""Returns the text of the segment currently being transcribed."""
return self.last_transcription
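# --- Illustrative end-to-end sketch (only runs when this module is executed directly) ---
# Assumptions: the checkpoint name, the 16 kHz int16 audio format and the silent
# placeholder signal below are stand-ins; the real app wires the model and the
# audio source elsewhere.
if __name__ == "__main__":
    model = nemo_asr.models.ASRModel.from_pretrained("nvidia/canary-1b-flash")  # placeholder checkpoint name
    cfg = CanaryConfig.from_params(task_type="Transcription", source_lang="fr", target_lang="fr")
    engine = CanarySpeechEngine(model, cfg)
    sample_rate = 16000  # assumed model sample rate
    chunk_samples = int(cfg.chunk_secs * sample_rate)
    audio = np.zeros(10 * sample_rate, dtype=np.int16)  # stand-in for real int16 audio
    for start in range(0, len(audio), chunk_samples):
        chunk = audio[start:start + chunk_samples]
        is_last = start + chunk_samples >= len(audio)
        full_segment, new_text = engine.transcribe_chunk(chunk, is_last_chunk=is_last)
        if new_text:
            print(new_text, end="", flush=True)
    engine.finalize_segment()
    print("\nFULL TRANSCRIPTION:", engine.get_full_transcription())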