| | import os |
| | import sys |
| | import torch |
| |
|
| | from torch.nn.functional import conv1d, conv2d |
| |
|
| | sys.path.append(os.getcwd()) |
| |
|
@torch.no_grad()
def temperature_sigmoid(x, x0, temp_coeff):
    """Evaluate a sigmoid centred on ``x0`` and scaled by ``temp_coeff``.

    Computes ``sigmoid((x - x0) / temp_coeff)``. Smaller magnitudes of
    ``temp_coeff`` give a steeper transition around ``x0``.

    Args:
        x: Tensor of input values.
        x0: Value at which the output equals 0.5.
        temp_coeff: Divisor applied to the centred input before the sigmoid.

    Returns:
        Tensor with values in (0, 1). Gradients are not tracked because
        the function runs under ``torch.no_grad``.
    """
    centred = x - x0
    scaled = centred / temp_coeff
    return scaled.sigmoid()
| |
|
@torch.no_grad()
def linspace(start, stop, num=50, endpoint=True, **kwargs):
    """Return ``num`` evenly spaced samples, optionally excluding ``stop``.

    Thin wrapper around ``torch.linspace`` that adds an ``endpoint`` switch.

    Args:
        start: First value of the sequence.
        stop: Last value of the sequence (included only when ``endpoint``
            is true).
        num: Number of samples to return.
        endpoint: When false, ``stop`` is left out of the result.
        **kwargs: Forwarded unchanged to ``torch.linspace``.

    Returns:
        1-D tensor holding ``num`` samples.
    """
    if endpoint:
        return torch.linspace(start, stop, num, **kwargs)

    # Generate one extra sample and drop it: the remaining ``num`` points keep
    # the spacing (stop - start) / num and never reach ``stop``.
    with_stop = torch.linspace(start, stop, num + 1, **kwargs)
    return with_stop[:-1]
| |
|
| | @torch.no_grad() |
| | def amp_to_db(x, eps=torch.finfo(torch.float32).eps, top_db=40): |
| | x_db = 20 * (x + eps).log10() |
| |
|
| | return x_db.max( |
| | (x_db.max(-1).values - top_db).unsqueeze(-1) |
| | ) |
| |
|
class TorchGate(torch.nn.Module):
    """Spectral-gating noise reduction implemented as a torch module.

    The input is transformed with an STFT, a per-bin mask is estimated from
    the magnitude (stationary or non-stationary estimator), the mask is
    optionally smoothed with a 2-D triangular filter, applied to the
    spectrogram, and the result is transformed back to the time domain.

    Args:
        sr: Sample rate in Hz; used to convert the smoothing widths given in
            Hz / ms into numbers of STFT bins / frames.
        nonstationary: Selects the non-stationary mask estimator when true,
            the stationary one otherwise.
        n_std_thresh_stationary: Number of standard deviations above the
            per-frequency mean (in dB) a bin must exceed to count as signal
            in stationary mode.
        n_thresh_nonstationary: Centre of the sigmoid applied to the relative
            deviation from the moving average in non-stationary mode.
        temp_coeff_nonstationary: Temperature of that sigmoid.
        n_movemean_nonstationary: Length (in frames) of the moving average
            used in non-stationary mode.
        prop_decrease: Strength of the reduction, in [0, 1]. With 0 the mask
            is all ones (before smoothing), with 1 the estimated mask is used
            as is.
        n_fft: FFT size.
        win_length: Window length; defaults to ``n_fft``.
        hop_length: Hop length; defaults to ``win_length // 4``.
        freq_mask_smooth_hz: Width of the mask smoothing along frequency in
            Hz, or ``None`` to disable it on that axis.
        time_mask_smooth_ms: Width of the mask smoothing along time in ms,
            or ``None`` to disable it on that axis.

    Raises:
        AssertionError: If ``prop_decrease`` is outside [0, 1].
        ValueError: If a smoothing width is smaller than one bin / frame.
    """

    @torch.no_grad()
    def __init__(
        self,
        sr,
        nonstationary=False,
        n_std_thresh_stationary=1.5,
        n_thresh_nonstationary=1.3,
        temp_coeff_nonstationary=0.1,
        n_movemean_nonstationary=20,
        prop_decrease=1.0,
        n_fft=1024,
        win_length=None,
        hop_length=None,
        freq_mask_smooth_hz=500,
        time_mask_smooth_ms=50,
    ):
        super().__init__()
        # Kept as an assert so callers that catch AssertionError still work.
        assert 0.0 <= prop_decrease <= 1.0, "prop_decrease must be in [0, 1]"

        self.sr = sr
        self.nonstationary = nonstationary
        self.prop_decrease = prop_decrease

        # STFT geometry.
        self.n_fft = n_fft
        self.win_length = self.n_fft if win_length is None else win_length
        self.hop_length = self.win_length // 4 if hop_length is None else hop_length

        # Mask-estimation settings.
        self.n_std_thresh_stationary = n_std_thresh_stationary
        self.temp_coeff_nonstationary = temp_coeff_nonstationary
        self.n_movemean_nonstationary = n_movemean_nonstationary
        self.n_thresh_nonstationary = n_thresh_nonstationary

        # Mask-smoothing settings; the filter may legitimately be None.
        self.freq_mask_smooth_hz = freq_mask_smooth_hz
        self.time_mask_smooth_ms = time_mask_smooth_ms
        self.register_buffer("smoothing_filter", self._generate_mask_smoothing_filter())

    @staticmethod
    def _triangular_ramp(half_width):
        """Return a ramp of ``2 * half_width + 1`` points rising to 1 and falling back."""
        rising = linspace(0, 1, half_width + 1, endpoint=False)
        falling = linspace(1, 0, half_width + 2)
        # Drop the leading and trailing zeros so every tap has non-zero weight.
        return torch.cat([rising, falling])[1:-1]

    @torch.no_grad()
    def _generate_mask_smoothing_filter(self):
        """Build the normalised 2-D kernel used to smooth the mask.

        Returns:
            A tensor of shape ``(1, 1, 2 * n_freq + 1, 2 * n_time + 1)`` whose
            entries sum to 1, or ``None`` when smoothing is disabled on both
            axes or both widths round down to a single bin / frame.

        Raises:
            ValueError: If a requested smoothing width is smaller than one
                frequency bin or one STFT frame.
        """
        if self.freq_mask_smooth_hz is None and self.time_mask_smooth_ms is None:
            return None

        if self.freq_mask_smooth_hz is None:
            n_grad_freq = 1
        else:
            hz_per_step = self.sr / (self.n_fft / 2)
            n_grad_freq = int(self.freq_mask_smooth_hz / hz_per_step)
        if n_grad_freq < 1:
            raise ValueError(
                f"freq_mask_smooth_hz needs to be at least {int(self.sr / (self.n_fft / 2))} Hz"
            )

        if self.time_mask_smooth_ms is None:
            n_grad_time = 1
        else:
            ms_per_frame = (self.hop_length / self.sr) * 1000
            n_grad_time = int(self.time_mask_smooth_ms / ms_per_frame)
        if n_grad_time < 1:
            raise ValueError(
                f"time_mask_smooth_ms needs to be at least {int((self.hop_length / self.sr) * 1000)} ms"
            )

        if n_grad_time == 1 and n_grad_freq == 1:
            return None

        # Outer product of the two 1-D ramps, shaped (out_ch, in_ch, freq, time)
        # as expected by conv2d.
        smoothing_filter = torch.outer(
            self._triangular_ramp(n_grad_freq),
            self._triangular_ramp(n_grad_time),
        ).unsqueeze(0).unsqueeze(0)

        return smoothing_filter / smoothing_filter.sum()

    @torch.no_grad()
    def _stationary_mask(self, X_db):
        """Mark bins that exceed the per-row mean by a multiple of the std.

        Args:
            X_db: 3-D tensor of dB magnitudes; statistics are taken over the
                last dimension.

        Returns:
            Boolean tensor shaped like ``X_db``.
        """
        std_freq_noise, mean_freq_noise = torch.std_mean(X_db, dim=-1)
        threshold = mean_freq_noise + std_freq_noise * self.n_std_thresh_stationary
        return X_db > threshold.unsqueeze(2)

    @torch.no_grad()
    def _nonstationary_mask(self, X_abs):
        """Estimate a soft mask from the deviation around a moving average.

        Args:
            X_abs: Tensor of magnitudes; the moving average runs along the
                last dimension.

        Returns:
            Tensor shaped like ``X_abs`` with values in (0, 1).
        """
        window = self.n_movemean_nonstationary
        kernel = torch.ones(window, dtype=X_abs.dtype, device=X_abs.device).view(1, 1, -1)

        # Moving average over the last dimension, every row handled independently.
        rows = X_abs.reshape(-1, 1, X_abs.shape[-1])
        X_smoothed = conv1d(rows, kernel, padding="same").view(X_abs.shape) / window

        # Silent regions make the average exactly zero; without a floor the
        # ratio below becomes 0/0 = NaN and poisons the whole output.
        X_smoothed = X_smoothed.clamp_min(torch.finfo(X_abs.dtype).eps)

        return temperature_sigmoid(
            (X_abs - X_smoothed) / X_smoothed,
            self.n_thresh_nonstationary,
            self.temp_coeff_nonstationary,
        )

    def forward(self, x):
        """Denoise a batch of waveforms.

        Args:
            x: 2-D tensor; the last dimension must hold at least
                ``2 * win_length`` samples.

        Returns:
            The masked signal transformed back to the time domain. On the
            regular torch path the result is cast to ``x.dtype``.

        Raises:
            AssertionError: If ``x`` is not 2-D.
            ValueError: If ``x`` is shorter than ``2 * win_length`` samples.
        """
        assert x.ndim == 2, "x must be a 2-D tensor"
        if x.shape[-1] < self.win_length * 2:
            raise ValueError(f"x must be at least {self.win_length * 2} samples long")

        # Decide ONCE which transform pair is used. The inverse must match the
        # forward transform of this very call, not whatever an earlier call on
        # another device happened to cache on the module.
        use_backend_stft = str(x.device).startswith(("ocl", "privateuseone"))
        phase = None
        window = None

        if use_backend_stft:
            if not hasattr(self, "stft"):
                # Project-local STFT used for these device types; imported
                # lazily so other devices do not need it.
                from main.library.backends.utils import STFT

                self.stft = STFT(
                    filter_length=self.n_fft,
                    hop_length=self.hop_length,
                    win_length=self.win_length,
                    pad_mode="constant",
                ).to(x.device)

            X, phase = self.stft.transform(x, eps=1e-9, return_phase=True)
        else:
            window = torch.hann_window(self.win_length).to(x.device)
            X = torch.stft(
                x,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                win_length=self.win_length,
                return_complex=True,
                pad_mode="constant",
                center=True,
                window=window,
            )

        magnitude = X.abs()
        if self.nonstationary:
            sig_mask = self._nonstationary_mask(magnitude)
        else:
            sig_mask = self._stationary_mask(amp_to_db(magnitude))

        # Blend the mask towards 1 according to prop_decrease.
        sig_mask = self.prop_decrease * (sig_mask.float() - 1.0) + 1.0

        if self.smoothing_filter is not None:
            # conv2d needs a channel axis; add it and remove it again.
            sig_mask = conv2d(
                sig_mask.unsqueeze(1),
                self.smoothing_filter.to(sig_mask.dtype),
                padding="same",
            ).squeeze(1)

        Y = X * sig_mask

        if use_backend_stft:
            return self.stft.inverse(Y, phase)

        return torch.istft(
            Y,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            center=True,
            window=window,
        ).to(dtype=x.dtype)