| | import torch |
| | from infer.lib.rmvpe import STFT |
| | from torch.nn.functional import conv1d, conv2d |
| | from typing import Union, Optional |
| | from .utils import linspace, temperature_sigmoid, amp_to_db |
| |
|
| |
|
class TorchGate(torch.nn.Module):
    """
    A PyTorch module that applies a spectral gate to an input signal.

    Arguments:
        sr {int} -- Sample rate of the input signal.
        nonstationary {bool} -- Whether to use non-stationary or stationary masking (default: {False}).
        n_std_thresh_stationary {float} -- Number of standard deviations above the mean used as the noise
            threshold for stationary masking (default: {1.5}).
        n_thresh_nonstationary {float} -- Threshold on the relative excess of the magnitude spectrogram over
            its moving average, for non-stationary masking (default: {1.3}).
        temp_coeff_nonstationary {float} -- Temperature coefficient for non-stationary masking (default: {0.1}).
        n_movemean_nonstationary {int} -- Number of STFT frames in the moving-average window used for
            non-stationary masking (default: {20}).
        prop_decrease {float} -- Proportion to decrease the signal by where the mask is zero; must lie in
            [0, 1] (default: {1.0}).
        n_fft {int} -- Size of FFT for STFT (default: {1024}).
        win_length {[int]} -- Window length for STFT. If None, defaults to `n_fft` (default: {None}).
        hop_length {[int]} -- Hop length for STFT. If None, defaults to `win_length` // 4 (default: {None}).
        freq_mask_smooth_hz {[float]} -- Frequency smoothing width for the mask, in Hz (default: {500}).
        time_mask_smooth_ms {[float]} -- Time smoothing width for the mask, in ms (default: {50}).
            If both smoothing widths are None, the mask is not smoothed at all.

    Raises:
        ValueError -- If a smoothing width is given but is too small to cover one frequency step / one hop.
    """

    @torch.no_grad()
    def __init__(
        self,
        sr: int,
        nonstationary: bool = False,
        n_std_thresh_stationary: float = 1.5,
        n_thresh_nonstationary: float = 1.3,
        temp_coeff_nonstationary: float = 0.1,
        n_movemean_nonstationary: int = 20,
        prop_decrease: float = 1.0,
        n_fft: int = 1024,
        win_length: Optional[int] = None,
        hop_length: Optional[int] = None,
        freq_mask_smooth_hz: Optional[float] = 500,
        time_mask_smooth_ms: Optional[float] = 50,
    ):
        super().__init__()

        # General parameters
        self.sr = sr
        self.nonstationary = nonstationary
        assert 0.0 <= prop_decrease <= 1.0
        self.prop_decrease = prop_decrease

        # STFT parameters
        self.n_fft = n_fft
        self.win_length = self.n_fft if win_length is None else win_length
        self.hop_length = self.win_length // 4 if hop_length is None else hop_length

        # Stationary masking parameters
        self.n_std_thresh_stationary = n_std_thresh_stationary

        # Non-stationary masking parameters
        self.temp_coeff_nonstationary = temp_coeff_nonstationary
        self.n_movemean_nonstationary = n_movemean_nonstationary
        self.n_thresh_nonstationary = n_thresh_nonstationary

        # Mask smoothing parameters. The filter is registered as a buffer (possibly
        # None) so that it follows the module across .to(device) calls.
        self.freq_mask_smooth_hz = freq_mask_smooth_hz
        self.time_mask_smooth_ms = time_mask_smooth_ms
        self.register_buffer("smoothing_filter", self._generate_mask_smoothing_filter())

    @staticmethod
    def _is_privateuseone(device) -> bool:
        """Return True for devices of PyTorch's "privateuseone" backend (e.g. DirectML), which use the custom STFT."""
        return "privateuseone" in str(device)

    def _get_custom_stft(self, device) -> torch.nn.Module:
        """Lazily create (once) and return the custom `STFT` module used on privateuseone devices."""
        if not hasattr(self, "stft"):
            self.stft = STFT(
                filter_length=self.n_fft,
                hop_length=self.hop_length,
                win_length=self.win_length,
                window="hann",
            ).to(device)
        return self.stft

    def _hann_window(self, device) -> torch.Tensor:
        """Hann window of length `win_length` on `device`, shared by the forward and inverse STFT."""
        return torch.hann_window(self.win_length, device=device)

    def _torch_stft(self, signal: torch.Tensor) -> torch.Tensor:
        """Complex STFT of `signal` using this gate's n_fft/hop/window settings (centered, zero padded)."""
        return torch.stft(
            signal,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            return_complex=True,
            pad_mode="constant",
            center=True,
            window=self._hann_window(signal.device),
        )

    @staticmethod
    def _triangular_ramp(n_grad: int) -> torch.Tensor:
        """Symmetric rise-then-fall weights around a peak of 1, used as one axis of the smoothing kernel.

        Assuming the project's `linspace` mirrors numpy.linspace, the result has
        2 * n_grad + 1 elements: 1/(n_grad+1), ..., 1, ..., 1/(n_grad+1).
        """
        return torch.cat(
            [
                linspace(0, 1, n_grad + 1, endpoint=False),
                linspace(1, 0, n_grad + 2),
            ]
        )[1:-1]  # drop the leading and trailing zeros

    @torch.no_grad()
    def _generate_mask_smoothing_filter(self) -> Optional[torch.Tensor]:
        """
        Build the normalized 2D kernel used to smooth the gating mask.

        Returns:
            smoothing_filter (torch.Tensor or None): a tensor of shape (1, 1, len_freq, len_time),
            ready to be used as a conv2d weight over (frequency, time), whose entries sum to 1.
            Returns None when both smoothing widths are None, or when both reduce to a
            single step (no smoothing needed).

        Raises:
            ValueError: if a given smoothing width is smaller than one frequency step or one hop.
        """
        import math

        if self.freq_mask_smooth_hz is None and self.time_mask_smooth_ms is None:
            return None

        # Frequency step (Hz) that freq_mask_smooth_hz is measured in.
        hz_per_step = self.sr / (self.n_fft / 2)
        n_grad_freq = (
            1
            if self.freq_mask_smooth_hz is None
            else int(self.freq_mask_smooth_hz / hz_per_step)
        )
        if n_grad_freq < 1:
            # ceil: the smallest whole-Hz value that actually passes the check above.
            raise ValueError(
                f"freq_mask_smooth_hz needs to be at least {math.ceil(hz_per_step)} Hz"
            )

        # Duration of one STFT hop in milliseconds.
        ms_per_hop = (self.hop_length / self.sr) * 1000
        n_grad_time = (
            1
            if self.time_mask_smooth_ms is None
            else int(self.time_mask_smooth_ms / ms_per_hop)
        )
        if n_grad_time < 1:
            raise ValueError(
                f"time_mask_smooth_ms needs to be at least {math.ceil(ms_per_hop)} ms"
            )

        if n_grad_time == 1 and n_grad_freq == 1:
            return None

        v_f = self._triangular_ramp(n_grad_freq)
        v_t = self._triangular_ramp(n_grad_time)
        # Outer product -> (freq, time) kernel; add (out_channels, in_channels) dims for conv2d.
        smoothing_filter = torch.outer(v_f, v_t).unsqueeze(0).unsqueeze(0)

        return smoothing_filter / smoothing_filter.sum()

    @torch.no_grad()
    def _stationary_mask(
        self, X_db: torch.Tensor, xn: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Compute a stationary binary mask from a dB spectrogram.

        Per-frequency noise statistics (mean and standard deviation along the last,
        time axis) are taken from `xn` when given, otherwise from `X_db` itself.
        Entries of `X_db` above mean + n_std_thresh_stationary * std are kept.

        Arguments:
            X_db (torch.Tensor): dB spectrogram with time on the last axis; `forward`
                passes shape (batch, freq_bins, frames).
            xn (Optional[torch.Tensor]): time-domain noise signal, analysed with the same
                STFT settings as the input. If None, `X_db` is its own noise estimate.

        Returns:
            sig_mask (torch.Tensor): boolean mask, True where X_db exceeds the threshold.
        """
        if xn is None:
            XN_db = X_db
        else:
            if self._is_privateuseone(xn.device):
                # presumably returns a magnitude spectrogram when return_phase is not requested
                XN = self._get_custom_stft(xn.device).transform(xn)
            else:
                XN = self._torch_stft(xn)
            XN_db = amp_to_db(XN).to(dtype=X_db.dtype)

        # Noise statistics per frequency bin, reduced over the time axis.
        std_freq_noise, mean_freq_noise = torch.std_mean(XN_db, dim=-1)

        noise_thresh = mean_freq_noise + std_freq_noise * self.n_std_thresh_stationary

        # Re-add the reduced time axis so the threshold broadcasts over every frame.
        sig_mask = X_db > noise_thresh.unsqueeze(-1)
        return sig_mask

    @torch.no_grad()
    def _nonstationary_mask(self, X_abs: torch.Tensor) -> torch.Tensor:
        """
        Compute a non-stationary (soft) mask from a magnitude spectrogram.

        Each bin is compared with a moving average of its own magnitude over the last
        (time) axis, using a window of `n_movemean_nonstationary` frames.

        Arguments:
            X_abs (torch.Tensor): magnitude spectrogram with time on the last axis;
                `forward` passes shape (batch, freq_bins, frames).

        Returns:
            sig_mask (torch.Tensor): mask of the same shape as X_abs, produced by
                `temperature_sigmoid` (presumably values in [0, 1]).
        """
        # Moving average along time: flatten all leading axes into the conv batch axis.
        X_smoothed = (
            conv1d(
                X_abs.reshape(-1, 1, X_abs.shape[-1]),
                torch.ones(
                    self.n_movemean_nonstationary,
                    dtype=X_abs.dtype,
                    device=X_abs.device,
                ).view(1, 1, -1),
                padding="same",
            ).view(X_abs.shape)
            / self.n_movemean_nonstationary
        )

        # Relative excess over the local average; 1e-6 guards against division by zero.
        slowness_ratio = (X_abs - X_smoothed) / (X_smoothed + 1e-6)
        sig_mask = temperature_sigmoid(
            slowness_ratio, self.n_thresh_nonstationary, self.temp_coeff_nonstationary
        )

        return sig_mask

    def forward(
        self, x: torch.Tensor, xn: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Apply the spectral gate to the input signal.

        Arguments:
            x (torch.Tensor): The input audio signal, with shape (batch_size, signal_length).
            xn (Optional[torch.Tensor]): The noise signal used for stationary noise reduction. If `None`, the input
                signal is used as the noise signal. Default: `None`.

        Returns:
            torch.Tensor: The denoised audio signal, cast to the dtype of `x`. On the torch.stft
                path its length equals the input length.
        """
        use_custom_stft = self._is_privateuseone(x.device)

        # Forward transform
        if use_custom_stft:
            X, phase = self._get_custom_stft(x.device).transform(x, return_phase=True)
        else:
            X = self._torch_stft(x)

        # Signal mask
        if self.nonstationary:
            sig_mask = self._nonstationary_mask(X.abs())
        else:
            sig_mask = self._stationary_mask(amp_to_db(X), xn)

        # Map mask m in [0, 1] to gain 1 - prop_decrease * (1 - m).
        sig_mask = self.prop_decrease * (sig_mask.float() - 1.0) + 1.0

        # Optional 2D smoothing over (frequency, time); the channel axis added for
        # conv2d is removed immediately so the mask shape matches X either way.
        if self.smoothing_filter is not None:
            sig_mask = conv2d(
                sig_mask.unsqueeze(1),
                self.smoothing_filter.to(sig_mask.dtype),
                padding="same",
            ).squeeze(1)

        Y = X * sig_mask

        # Inverse transform
        if use_custom_stft:
            y = self.stft.inverse(Y, phase)
        else:
            # length= keeps the output as long as the input; without it, centered istft
            # truncates to a multiple of hop_length.
            y = torch.istft(
                Y,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                win_length=self.win_length,
                center=True,
                window=self._hann_window(Y.device),
                length=x.shape[-1],
            )

        return y.to(dtype=x.dtype)
| |
|