import gc
import time
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf

from app.interfaces import IStreamingSpeechEngine


from nemo.collections.asr.models.aed_multitask_models import lens_to_mask
from nemo.collections.asr.parts.submodules.aed_decoding import (
    GreedyBatchedStreamingAEDComputer,
    return_decoder_input_ids,
)
from nemo.collections.asr.parts.submodules.multitask_decoding import (
    AEDStreamingDecodingConfig,
    MultiTaskDecodingConfig,
)
from nemo.collections.asr.parts.utils.streaming_utils import (
    ContextSize,
    StreamingBatchedAudioBuffer,
)
from nemo.collections.asr.parts.utils.transcribe_utils import (
    get_inference_device,
    get_inference_dtype,
)
from app.logger_config import (
    logger as logging,
    DEBUG,
)


@dataclass
class CanaryConfig:
    chunk_secs: float = 1.0
    left_context_secs: float = 20.0
    right_context_secs: float = 0.5
    cuda: Optional[bool] = None
    allow_mps: bool = True
    compute_dtype: Optional[str] = None
    matmul_precision: str = "high"
    batch_size: int = 1
    decoding: Optional[dict] = None
    streaming_policy: str = "alignatt"
    alignatt_thr: float = 8.0
    waitk_lagging: int = 2
    exclude_sink_frames: int = 8
    xatt_scores_layer: int = -2
    max_tokens_per_alignatt_step: int = 30
    max_generation_length: int = 512
    use_avgpool_for_alignatt: bool = False
    hallucinations_detector: bool = True
    
    prompt: Optional[dict] = None
    pnc: str = "no"
    task: str = "asr"
    source_lang: str = "fr"
    target_lang: str = "fr"
    timestamps: bool = True
    
    def __post_init__(self):
        if self.decoding is None:
            self.decoding = {
                "streaming_policy": self.streaming_policy,
                "alignatt_thr": self.alignatt_thr,
                "waitk_lagging": self.waitk_lagging,
                "exclude_sink_frames": self.exclude_sink_frames,
                "xatt_scores_layer": self.xatt_scores_layer,
                "max_tokens_per_alignatt_step": self.max_tokens_per_alignatt_step,
                "max_generation_length": self.max_generation_length,
                "use_avgpool_for_alignatt": self.use_avgpool_for_alignatt,
                "hallucinations_detector": self.hallucinations_detector
            }
        
        if self.prompt is None:
            self.prompt = {
                "pnc": self.pnc,
                "task": self.task,
                "source_lang": self.source_lang,
                "target_lang": self.target_lang,
                "timestamps": self.timestamps
            }

    def toOmegaConf(self) -> DictConfig:
        """Convert the config to an OmegaConf DictConfig."""
        config_dict = {
            "chunk_secs": self.chunk_secs,
            "left_context_secs": self.left_context_secs,
            "right_context_secs": self.right_context_secs,
            "cuda": self.cuda,
            "allow_mps": self.allow_mps,
            "compute_dtype": self.compute_dtype,
            "matmul_precision": self.matmul_precision,
            "batch_size": self.batch_size,
            "decoding": self.decoding,
            "prompt": self.prompt
        }
        
        # Drop None values so the config only carries set options;
        # downstream readers must use .get() for keys that may be absent.
        filtered_dict = {k: v for k, v in config_dict.items() if v is not None}
        
        return OmegaConf.create(filtered_dict)
    
    @classmethod
    def from_params(
        cls,
        task_type: str,
        source_lang: str,
        target_lang: str,
        chunk_secs: float = 1.0,
        left_context_secs: float = 20.0,
        right_context_secs: float = 0.5,
        streaming_policy: str = "alignatt",
        alignatt_thr: float = 8.0,
        waitk_lagging: int = 2,
        exclude_sink_frames: int = 8,
        xatt_scores_layer: int = -2,
        hallucinations_detector: bool = True
    ):
        """Create a CanaryConfig instance from parameters"""
        # Convert task type to model task
        task = "asr" if task_type == "Transcription" else "ast"
        target_lang = source_lang if task_type == "Transcription" else target_lang
        
        return cls(
            chunk_secs=chunk_secs,
            left_context_secs=left_context_secs,
            right_context_secs=right_context_secs,
            streaming_policy=streaming_policy,
            alignatt_thr=alignatt_thr,
            waitk_lagging=waitk_lagging,
            exclude_sink_frames=exclude_sink_frames,
            xatt_scores_layer=xatt_scores_layer,
            hallucinations_detector=hallucinations_detector,
            task=task,
            source_lang=source_lang,
            target_lang=target_lang
        )
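
# A minimal usage sketch (illustrative values for a French transcription stream):
#
#   cfg = CanaryConfig.from_params(
#       task_type="Transcription",
#       source_lang="fr",
#       target_lang="fr",
#   )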

def make_divisible_by(num: int, factor: int) -> int:
    """Round num down to the nearest multiple of factor (e.g. 1001, 8 -> 1000)."""
    return (num // factor) * factor


class CanarySpeechEngine(IStreamingSpeechEngine):
    """
    Encapsulates the state and logic for streaming audio transcription
    using an internally loaded Canary model.
    """
    def __init__(self, asr_model, cfg: CanaryConfig):
        """
        Initializes the speech engine and prepares the ASR model for streaming.

        Args:
            asr_model: A preloaded NeMo Canary (AED multitask) model instance.
            cfg: A CanaryConfig with chunking, decoding, and prompt settings.
        """
        logging.debug(f"Initializing CanarySpeechEngine with config: {cfg}")
        self.cfg = cfg.toOmegaConf() # Store the full config
        
        # Setup device and dtype from config
        self.map_location = get_inference_device(cuda=self.cfg.get("cuda"), allow_mps=self.cfg.allow_mps)
        self.compute_dtype = get_inference_dtype(self.cfg.get("compute_dtype"), device=self.map_location)
        logging.info(f"Inference will be on device: {self.map_location} with dtype: {self.compute_dtype}")

        # Load the model internally
        self.asr_model = self._setup_model(asr_model, self.map_location)
        
        self.full_transcription = [] # Stores finalized segments
        self._setup_streaming_params()
        
        # The initial full reset (buffer + decoder)
        self.reset()
        
        logging.info("CanarySpeechEngine initialized and ready.")
        logging.info(f"Model-adjusted chunk size: {self.context_samples.chunk} samples.")

    def _setup_model(self, asr_model, map_location: str):
        """Moves the pretrained ASR model to the target device and configures it for streaming inference."""
        logging.info("Loading model ...")
        start_time = time.time()
        try:
            asr_model = asr_model.to(map_location)
            asr_model.eval()

            # Streaming decoding below assumes greedy search
            if hasattr(asr_model, 'change_decoding_strategy'):
                multitask_decoding = MultiTaskDecodingConfig()
                multitask_decoding.strategy = "greedy"
                asr_model.change_decoding_strategy(multitask_decoding)
                logging.info("Model decoding strategy set to 'greedy'.")

            if map_location == "cuda":
                torch.cuda.synchronize()

            load_time = time.time() - start_time
            logging.info("Model loaded successfully.")
            logging.info("\n" + "=" * 30)
            logging.info(f"Total model load time: {load_time:.2f} seconds")
            logging.info("=" * 30)
            return asr_model

        except Exception as e:
            logging.error(f"Error loading model: {e}")
            logging.error("Ensure NeMo is installed (pip install nemo_toolkit['asr'])")
            raise
    
    def _setup_streaming_params(self):
        """Helper to calculate model-specific streaming parameters."""
        model_cfg = self.asr_model.cfg
        audio_sample_rate = model_cfg.preprocessor['sample_rate']
        self.feature_stride_sec = model_cfg.preprocessor['window_stride']
        features_per_sec = 1.0 / self.feature_stride_sec 
        self.encoder_subsampling_factor = self.asr_model.encoder.subsampling_factor
        
        self.features_frame2audio_samples = make_divisible_by(
            int(audio_sample_rate * self.feature_stride_sec ), factor=self.encoder_subsampling_factor
        )
        encoder_frame2audio_samples = self.features_frame2audio_samples * self.encoder_subsampling_factor

        # Streaming settings live at the top level of self.cfg
        streaming_cfg = self.cfg
        self.context_encoder_frames = ContextSize(
            left=int(streaming_cfg.left_context_secs * features_per_sec / self.encoder_subsampling_factor),
            chunk=int(streaming_cfg.chunk_secs * features_per_sec / self.encoder_subsampling_factor),
            right=int(streaming_cfg.right_context_secs * features_per_sec / self.encoder_subsampling_factor),
        )
        self.context_samples = ContextSize(
            left=self.context_encoder_frames.left * encoder_frame2audio_samples,
            chunk=self.context_encoder_frames.chunk * encoder_frame2audio_samples,
            right=self.context_encoder_frames.right * encoder_frame2audio_samples,
        )
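        # Worked example (assuming typical Canary/FastConformer values:
        # sample_rate=16000, window_stride=0.01 s, subsampling_factor=8):
        #   features_per_sec = 100
        #   features_frame2audio_samples = make_divisible_by(160, 8) = 160
        #   encoder_frame2audio_samples = 160 * 8 = 1280
        #   chunk_secs=1.0 -> context_encoder_frames.chunk = int(100 / 8) = 12
        #   context_samples.chunk = 12 * 1280 = 15360 samples (~0.96 s of audio)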

    def _reset_decoder_state(self):
        """
        Resets ONLY the decoder state, preserving the audio buffer.
        This prevents slowdowns on long audio streams.
        """
        start_time = time.perf_counter()
        logging.debug("--- Resetting decoder state (audio buffer preserved) ---")
        
        # Reset tracking for this segment
        self.last_transcription = ""
        self.chunk_count = 0
        batch_size = 1  # Single-stream engine; batched streaming is not supported here
        
        # Streaming settings live at the top level of self.cfg
        streaming_cfg = self.cfg

        # 1. Recreate the initial prompt for the decoder
        self.decoder_input_ids = return_decoder_input_ids(streaming_cfg, self.asr_model)
        
        # 2. Recreate the "computer" object that manages decoding
        self.decoding_computer = GreedyBatchedStreamingAEDComputer(
            self.asr_model,
            frame_chunk_size=self.context_encoder_frames.chunk,
            decoding_cfg=streaming_cfg.decoding,
        )

        # 3. Recreate an EMPTY STATE object (model_state)
        self.model_state = GreedyBatchedStreamingAEDComputer.initialize_aed_model_state(
            asr_model=self.asr_model,
            decoder_input_ids=self.decoder_input_ids,
            batch_size=batch_size,
            context_encoder_frames=self.context_encoder_frames,
            chunk_secs=streaming_cfg.chunk_secs,
            right_context_secs=streaming_cfg.right_context_secs,
        )

        # Clear CUDA cache if possible
        if torch.cuda.is_available():
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        end_time = time.perf_counter()
        duration_ms = (end_time - start_time) * 1000  # Convert to milliseconds
        logging.debug(f"--- Decoder reset finished in {duration_ms:.2f} ms ---")

    def reset(self):
        """
        Resets the transcriber's state completely (audio buffer + decoder state).
        Called only on initialization.
        """
        start_time = time.perf_counter()
        logging.debug("--- FULL RESET (Audio Buffer + Decoder State) ---")
        
        # Operation 1: Reset the decoder (this now includes GC)
        self._reset_decoder_state()
        
        # Operation 2: Reset the audio buffer
        self.buffer = StreamingBatchedAudioBuffer(
            batch_size=1,  # Single-stream engine
            context_samples=self.context_samples,
            dtype=torch.float32,
            device=self.map_location,
        )
        
        end_time = time.perf_counter()
        duration_ms = (end_time - start_time) * 1000
        logging.debug(f"--- RESET Complete: took {duration_ms:.2f} ms ---")

    def transcribe_chunk(self, chunk: np.ndarray, is_last_chunk: bool = False) -> Tuple[str, str]:
        """
        Processes a single audio chunk and returns the newly predicted text.

        Returns:
            Tuple[str, str]: 
                (current_transcription: The full transcription for the current segment,
                 new_text: The newly appended text since the last chunk)
        """
        start_time = time.perf_counter()
        self.chunk_count += 1
        
        # Preprocess audio: chunk is int16 PCM; scale to float32 in [-1.0, 1.0]
        signal = torch.from_numpy(chunk.astype(np.float32) / 32768.0)
        audio_batch = signal.unsqueeze(0).to(self.map_location)
        audio_batch_lengths = torch.tensor([signal.shape[0]], device=self.map_location)

        # 1. Add the chunk to the persistent buffer
        self.buffer.add_audio_batch_(
            audio_batch,
            audio_lengths=audio_batch_lengths,
            is_last_chunk=is_last_chunk,
            is_last_chunk_batch=torch.tensor([is_last_chunk], device=self.map_location)
        )

        self.model_state.is_last_chunk_batch = torch.tensor([is_last_chunk], device=self.map_location)

        # 2. Pass the buffer to the encoder
        _, encoded_len, enc_states, _ = self.asr_model(
            input_signal=self.buffer.samples, input_signal_length=self.buffer.context_size_batch.total()
        )
        encoder_context_batch = self.buffer.context_size_batch.subsample(factor=self.features_frame2audio_samples * self.encoder_subsampling_factor)
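        # On non-final chunks, drop the right-context (lookahead) frames from the
        # decoded length: they will be re-encoded with more future context on the
        # next chunk, so only left context + current chunk are decoded now.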
        encoded_len_no_rc = encoder_context_batch.left + encoder_context_batch.chunk
        encoded_length_corrected = torch.where(self.model_state.is_last_chunk_batch, encoded_len, encoded_len_no_rc)
        encoder_input_mask = lens_to_mask(encoded_length_corrected, enc_states.shape[1]).to(enc_states.dtype)

        # 3. Pass to the decoding computer
        self.model_state = self.decoding_computer(
            encoder_output=enc_states,
            encoder_output_len=encoded_length_corrected,
            encoder_input_mask=encoder_input_mask,
            prev_batched_state=self.model_state,
        )

        # 4. Calculate the new text
        current_tokens = self.model_state.pred_tokens_ids[0, self.decoder_input_ids.size(-1): self.model_state.current_context_lengths[0]]
        
        # OPTIMIZATION: Move tokens to CPU before converting to list
        current_transcription = self.asr_model.tokenizer.ids_to_text(current_tokens.cpu().tolist()).strip()
        
        # Calculate the NEW text by "subtracting" the old history
        new_text = ""
        if current_transcription.startswith(self.last_transcription):
             new_text = current_transcription[len(self.last_transcription):]
        else:
            # Model corrected itself, send the full new transcription
            new_text = current_transcription
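            # Example: last="bonjour le", current="bonjour le monde"
            # -> new_text=" monde"; a non-prefix revision re-sends the full string
            # so the caller can redraw the current segment.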
        
        # Memorize the FULL current transcription as the new history
        if new_text:
            self.last_transcription = current_transcription
        
        end_time = time.perf_counter()
        duration_ms = (end_time - start_time) * 1000
        logging.debug(f"--- transcribe_chunk: took {duration_ms:.2f} ms ---")

        # Return both the full segment transcription and the new diff
        return current_transcription, new_text
    
    def finalize_segment(self):
        """
        Finalizes the current transcription segment (e.g., on silence)
        and adds it to the full history.
        """
        if self.last_transcription:
            self.full_transcription.append(self.last_transcription)
        self.last_transcription = ""
        # We must reset the decoder state to start a new segment
        self._reset_decoder_state()

    def get_full_transcription(self) -> str:
        """
        Returns the full accumulated transcription from all finalized segments.
        Does NOT include the currently active (unfinalized) segment.
        """
        return " ".join(self.full_transcription)

    def get_current_segment_text(self) -> str:
        """Returns the text of the segment currently being transcribed."""
        return self.last_transcription
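

# A minimal driving-loop sketch (hypothetical helpers `int16_chunks` and
# `silence_detected`; assumes 16 kHz int16 mono chunks):
#
#   engine = CanarySpeechEngine(asr_model, CanaryConfig())
#   for i, chunk in enumerate(int16_chunks):
#       segment_text, new_text = engine.transcribe_chunk(
#           chunk, is_last_chunk=(i == len(int16_chunks) - 1)
#       )
#       print(new_text, end="", flush=True)
#       if silence_detected(chunk):  # hypothetical VAD hook
#           engine.finalize_segment()
#   print()
#   print(engine.get_full_transcription())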