import os
import subprocess
import torch
import numpy as np
import onnxruntime
from app.interfaces import IVoiceActivityEngine
from app.logger_config import logger as logging
class VoiceActivityDetection:
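    """Streaming wrapper around the Silero VAD v5 ONNX model.

    Keeps the recurrent state and a short trailing context between calls so
    that consecutive fixed-size chunks are scored as one continuous stream.
    """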
def __init__(self, force_onnx_cpu=True):
logging.info("Initializing VoiceActivityDetection...")
path = self.download()
opts = onnxruntime.SessionOptions()
opts.log_severity_level = 3 # Suppress ONNX runtime's own logs
opts.inter_op_num_threads = 1
opts.intra_op_num_threads = 1
try:
if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
logging.info("ONNX VAD session created with CPUExecutionProvider.")
else:
self.session = onnxruntime.InferenceSession(path, providers=['CUDAExecutionProvider'], sess_options=opts)
logging.info("ONNX VAD session created with CUDAExecutionProvider.")
except Exception as e:
logging.critical(f"Failed to create ONNX InferenceSession: {e}", exc_info=True)
raise
self.reset_states()
if '16k' in path:
            logging.warning('This VAD model supports only a 16000Hz sampling rate!')
self.sample_rates = [16000]
else:
logging.info("VAD model supports 8000Hz and 16000Hz.")
self.sample_rates = [8000, 16000]
def _validate_input(self, x, sr: int):
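        """Coerce input to shape (batch, samples); decimate multiples of 16kHz down to 16kHz."""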
if x.dim() == 1:
x = x.unsqueeze(0)
if x.dim() > 2:
logging.error(f"Too many dimensions for input audio chunk: {x.dim()}")
raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
        if sr != 16000 and (sr % 16000 == 0):
            step = sr // 16000
            x = x[:, ::step]
            logging.debug(f"Downsampled input audio from {sr}Hz to 16000Hz.")
            sr = 16000
if sr not in self.sample_rates:
logging.error(f"Unsupported sampling rate: {sr}. Supported: {self.sample_rates}")
raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")
return x, sr
def reset_states(self, batch_size=1):
logging.debug(f"Resetting VAD states for batch_size: {batch_size}")
self._state = torch.zeros((2, batch_size, 128)).float()
self._context = torch.zeros(0)
self._last_sr = 0
self._last_batch_size = 0
def __call__(self, x, sr: int):
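        """Score one fixed-size chunk (512 samples at 16kHz, 256 at 8kHz) for speech probability."""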
x, sr = self._validate_input(x, sr)
num_samples = 512 if sr == 16000 else 256
if x.shape[-1] != num_samples:
logging.error(f"Invalid audio chunk size: {x.shape[-1]}. Expected {num_samples} for {sr}Hz.")
raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
batch_size = x.shape[0]
context_size = 64 if sr == 16000 else 32
if not self._last_batch_size:
logging.debug("First call, resetting states.")
self.reset_states(batch_size)
if (self._last_sr) and (self._last_sr != sr):
logging.warning(f"Sample rate changed ({self._last_sr}Hz -> {sr}Hz). Resetting states.")
self.reset_states(batch_size)
if (self._last_batch_size) and (self._last_batch_size != batch_size):
logging.warning(f"Batch size changed ({self._last_batch_size} -> {batch_size}). Resetting states.")
self.reset_states(batch_size)
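        # Prepend left context: the tail of the previous chunk (64 samples at
        # 16kHz, 32 at 8kHz), or zeros on the very first call.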
if not len(self._context):
self._context = torch.zeros(batch_size, context_size)
x = torch.cat([self._context, x], dim=1)
if sr in [8000, 16000]:
ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
ort_outs = self.session.run(None, ort_inputs)
out, state = ort_outs
self._state = torch.from_numpy(state)
else:
# This should be caught by _validate_input, but as a safeguard:
logging.critical(f"Unexpected sample rate in VAD __call__: {sr}")
            raise ValueError(f"Unexpected sample rate: {sr}")
self._context = x[..., -context_size:]
self._last_sr = sr
self._last_batch_size = batch_size
out = torch.from_numpy(out)
return out
def audio_forward(self, x, sr: int):
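        """Split a full audio tensor into fixed-size chunks, score each, and return the stacked probabilities."""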
outs = []
x, sr = self._validate_input(x, sr)
self.reset_states()
num_samples = 512 if sr == 16000 else 256
if x.shape[1] % num_samples:
pad_num = num_samples - (x.shape[1] % num_samples)
logging.debug(f"Padding audio input with {pad_num} samples.")
x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)
for i in range(0, x.shape[1], num_samples):
wavs_batch = x[:, i:i+num_samples]
out_chunk = self.__call__(wavs_batch, sr)
outs.append(out_chunk)
stacked = torch.cat(outs, dim=1)
return stacked.cpu()
@staticmethod
def download(model_url="https://github.com/snakers4/silero-vad/raw/v5.0/files/silero_vad.onnx"):
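        """Download the Silero VAD ONNX model into ~/.cache/silero_vad/ (if missing) and return its local path."""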
target_dir = os.path.expanduser("~/.cache/silero_vad/")
os.makedirs(target_dir, exist_ok=True)
model_filename = os.path.join(target_dir, "silero_vad.onnx")
if not os.path.exists(model_filename):
logging.info(f"Downloading VAD model to {model_filename}...")
try:
subprocess.run(["wget", "-O", model_filename, model_url], check=True)
logging.info("VAD model downloaded successfully.")
except subprocess.CalledProcessError as e:
logging.critical(f"Failed to download the model using wget: {e}")
raise
else:
logging.info(f"VAD model already exists at {model_filename}.")
return model_filename
class Silero_Vad_Engine(IVoiceActivityEngine):
    def __init__(self, threshold: float = 0.5, frame_rate: int = 16000):
"""
Initializes the Silero_Vad_Engine with a voice activity detection model and a threshold.
Args:
threshold (float, optional): The probability threshold for detecting voice activity. Defaults to 0.5.
"""
logging.info(f"Initializing Silero_Vad_Engine with threshold: {threshold} and frame_rate: {frame_rate}Hz.")
self.model = VoiceActivityDetection()
self.threshold = threshold
self.frame_rate = frame_rate
def __call__(self, audio_frame):
"""
Determines if the given audio frame contains speech by comparing the detected speech probability against
the threshold.
Args:
audio_frame (np.ndarray): The audio frame to be analyzed for voice activity. It is expected to be a
NumPy array of audio samples.
Returns:
bool: True if the speech probability exceeds the threshold, indicating the presence of voice activity;
False otherwise.
"""
# Convert frame to tensor
audio_tensor = torch.from_numpy(audio_frame.copy())
# Get speech probabilities
speech_probs = self.model.audio_forward(audio_tensor, self.frame_rate)[0]
# Check against threshold
is_speech = torch.any(speech_probs > self.threshold).item()
logging.debug(f"VAD check result: {is_speech} (Max prob: {torch.max(speech_probs).item():.4f})")
return is_speech
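
# Minimal usage sketch (illustrative only, not part of the app): feeds a second
# of silence and a second of low-level noise through the engine. Assumes 16kHz
# mono float32 frames, matching the engine's defaults.
if __name__ == "__main__":
    vad = Silero_Vad_Engine(threshold=0.5, frame_rate=16000)
    silence = np.zeros(16000, dtype=np.float32)                 # 1s of silence
    noise = (np.random.randn(16000) * 0.05).astype(np.float32)  # 1s of faint noise
    print("silence flagged as speech:", vad(silence))           # expected: False
    print("noise flagged as speech:", vad(noise))               # typically False; real speech should yield True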