update
- README.md +2 -2
- examples/dtln_mp3_to_wav/run.sh +0 -168
- examples/dtln_mp3_to_wav/step_1_prepare_data.py +0 -127
- examples/dtln_mp3_to_wav/step_2_train_model.py +0 -445
- examples/dtln_mp3_to_wav/yaml/config-1024.yaml +0 -29
- examples/dtln_mp3_to_wav/yaml/config-256.yaml +0 -29
- examples/dtln_mp3_to_wav/yaml/config-512.yaml +0 -29
- examples/frcrn_mp3_to_wav/run.sh +0 -156
- examples/frcrn_mp3_to_wav/step_1_prepare_data.py +0 -127
- examples/frcrn_mp3_to_wav/step_2_train_model.py +0 -442
- examples/frcrn_mp3_to_wav/yaml/config-10.yaml +0 -31
- examples/frcrn_mp3_to_wav/yaml/config-14.yaml +0 -31
- examples/frcrn_mp3_to_wav/yaml/config-20.yaml +0 -31
- examples/simple_linear_irm_aishell/run.sh +0 -172
- examples/simple_linear_irm_aishell/step_1_prepare_data.py +0 -196
- examples/simple_linear_irm_aishell/step_2_train_model.py +0 -348
- examples/simple_linear_irm_aishell/step_3_evaluation.py +0 -239
- examples/simple_linear_irm_aishell/yaml/config.yaml +0 -13
- examples/spectrum_dfnet_aishell/run.sh +0 -178
- examples/spectrum_dfnet_aishell/step_1_prepare_data.py +0 -197
- examples/spectrum_dfnet_aishell/step_2_train_model.py +0 -440
- examples/spectrum_dfnet_aishell/step_3_evaluation.py +0 -302
- examples/spectrum_dfnet_aishell/yaml/config.yaml +0 -53
- examples/spectrum_unet_irm_aishell/run.sh +0 -178
- examples/spectrum_unet_irm_aishell/step_1_prepare_data.py +0 -197
- examples/spectrum_unet_irm_aishell/step_2_train_model.py +0 -420
- examples/spectrum_unet_irm_aishell/step_3_evaluation.py +0 -270
- examples/spectrum_unet_irm_aishell/yaml/config.yaml +0 -38
- main.py +1 -1
- toolbox/torch/utils/data/dataset/mp3_to_wav_jsonl_dataset.py +0 -197
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title:
+title: CC Denoise
 emoji: 🐢
 colorFrom: purple
 colorTo: blue
@@ -9,7 +9,7 @@ license: apache-2.0
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
-##
+## CC Denoise


 ### datasets
examples/dtln_mp3_to_wav/run.sh DELETED
@@ -1,168 +0,0 @@
-#!/usr/bin/env bash
-
-: <<'END'
-
-sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir-256 --final_model_name dtln-256-nx-dns3 \
---config_file "yaml/config-256.yaml" \
---noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
---speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
-
-
-sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir-512 --final_model_name dtln-512-nx-dns3 \
---config_file "yaml/config-512.yaml" \
---noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
---speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
-
-
-sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name dtnl-1024-nx2 --final_model_name dtln-1024-nx2 \
---config_file "yaml/config-1024.yaml" \
---noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" \
---speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2"
-
-
-bash run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name dtln-256-nx2-dns3-mp3 --final_model_name dtln-256-nx2-dns3-mp3 \
---config_file "yaml/config-256.yaml" \
---audio_dir "/data/tianxing/HuggingDatasets/nx_noise/data" \
-
-
-END
-
-
-# params
-system_version="windows";
-verbose=true;
-stage=0 # start from 0 if you need to start from data preparation
-stop_stage=9
-
-work_dir="$(pwd)"
-file_folder_name=file_folder_name
-final_model_name=final_model_name
-config_file="yaml/config.yaml"
-limit=10
-
-audio_dir=/data/tianxing/HuggingDatasets/nx_noise/data
-
-max_count=-1
-
-nohup_name=nohup.out
-
-# model params
-batch_size=64
-max_epochs=200
-save_top_k=10
-patience=5
-
-
-# parse options
-while true; do
-  [ -z "${1:-}" ] && break;  # break if there are no arguments
-  case "$1" in
-    --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
-      eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
-      old_value="(eval echo \\$$name)";
-      if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
-        was_bool=true;
-      else
-        was_bool=false;
-      fi
-
-      # Set the variable to the right value-- the escaped quotes make it work if
-      # the option had spaces, like --cmd "queue.pl -sync y"
-      eval "${name}=\"$2\"";
-
-      # Check that Boolean-valued arguments are really Boolean.
-      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
-        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
-        exit 1;
-      fi
-      shift 2;
-      ;;
-
-    *) break;
-  esac
-done
-
-file_dir="${work_dir}/${file_folder_name}"
-final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
-evaluation_audio_dir="${file_dir}/evaluation_audio"
-
-train_dataset="${file_dir}/train.jsonl"
-valid_dataset="${file_dir}/valid.jsonl"
-
-$verbose && echo "system_version: ${system_version}"
-$verbose && echo "file_folder_name: ${file_folder_name}"
-
-if [ $system_version == "windows" ]; then
-  alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
-elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
-  #source /data/local/bin/nx_denoise/bin/activate
-  alias python3='/data/local/bin/nx_denoise/bin/python3'
-fi
-
-
-if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-  $verbose && echo "stage 1: prepare data"
-  cd "${work_dir}" || exit 1
-  python3 step_1_prepare_data.py \
-    --file_dir "${file_dir}" \
-    --audio_dir "${audio_dir}" \
-    --train_dataset "${train_dataset}" \
-    --valid_dataset "${valid_dataset}" \
-    --max_count "${max_count}" \
-
-fi
-
-
-if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
-  $verbose && echo "stage 2: train model"
-  cd "${work_dir}" || exit 1
-  python3 step_2_train_model.py \
-    --train_dataset "${train_dataset}" \
-    --valid_dataset "${valid_dataset}" \
-    --serialization_dir "${file_dir}" \
-    --config_file "${config_file}" \
-
-fi
-
-
-if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
-  $verbose && echo "stage 3: test model"
-  cd "${work_dir}" || exit 1
-  python3 step_3_evaluation.py \
-    --valid_dataset "${valid_dataset}" \
-    --model_dir "${file_dir}/best" \
-    --evaluation_audio_dir "${evaluation_audio_dir}" \
-    --limit "${limit}" \
-
-fi
-
-
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
-  $verbose && echo "stage 4: collect files"
-  cd "${work_dir}" || exit 1
-
-  mkdir -p ${final_model_dir}
-
-  cp "${file_dir}/best"/* "${final_model_dir}"
-  cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
-
-  cd "${final_model_dir}/.." || exit 1;
-
-  if [ -e "${final_model_name}.zip" ]; then
-    rm -rf "${final_model_name}_backup.zip"
-    mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
-  fi
-
-  zip -r "${final_model_name}.zip" "${final_model_name}"
-  rm -rf "${final_model_name}"
-
-fi
-
-
-if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
-  $verbose && echo "stage 5: clear file_dir"
-  cd "${work_dir}" || exit 1
-
-  rm -rf "${file_dir}";
-
-fi
examples/dtln_mp3_to_wav/step_1_prepare_data.py DELETED
@@ -1,127 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-import argparse
-import json
-import os
-from pathlib import Path
-import random
-import sys
-
-pwd = os.path.abspath(os.path.dirname(__file__))
-sys.path.append(os.path.join(pwd, "../../"))
-
-import librosa
-import numpy as np
-from tqdm import tqdm
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--file_dir", default="./", type=str)
-
-    parser.add_argument(
-        "--audio_dir",
-        default="E:/Users/tianx/HuggingDatasets/nx_noise/data/speech",
-        type=str
-    )
-
-    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
-    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
-
-    parser.add_argument("--duration", default=4.0, type=float)
-
-    parser.add_argument("--target_sample_rate", default=8000, type=int)
-
-    parser.add_argument("--max_count", default=-1, type=int)
-
-    args = parser.parse_args()
-    return args
-
-
-def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000, max_epoch: int = 1):
-    data_dir = Path(data_dir)
-    for epoch_idx in range(max_epoch):
-        for filename in data_dir.glob("**/*.wav"):
-            signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
-            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
-
-            if raw_duration < duration:
-                # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
-                continue
-            if signal.ndim != 1:
-                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
-
-            signal_length = len(signal)
-            win_size = int(duration * sample_rate)
-            for begin in range(0, signal_length - win_size, win_size):
-                if np.sum(signal[begin: begin+win_size]) == 0:
-                    continue
-                row = {
-                    "epoch_idx": epoch_idx,
-                    "filename": filename.as_posix(),
-                    "raw_duration": round(raw_duration, 4),
-                    "offset": round(begin / sample_rate, 4),
-                    "duration": round(duration, 4),
-                }
-                yield row
-
-
-def main():
-    args = get_args()
-
-    file_dir = Path(args.file_dir)
-    file_dir.mkdir(exist_ok=True)
-
-    audio_dir = Path(args.audio_dir)
-
-    audio_generator = target_second_signal_generator(
-        audio_dir.as_posix(),
-        duration=args.duration,
-        sample_rate=args.target_sample_rate,
-        max_epoch=1,
-    )
-    count = 0
-    process_bar = tqdm(desc="build dataset jsonl")
-    with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
-        for audio in audio_generator:
-            if count >= args.max_count > 0:
-                break
-
-            filename = audio["filename"]
-            raw_duration = audio["raw_duration"]
-            offset = audio["offset"]
-            duration = audio["duration"]
-
-            random1 = random.random()
-            random2 = random.random()
-
-            row = {
-                "count": count,
-
-                "filename": filename,
-                "raw_duration": raw_duration,
-                "offset": offset,
-                "duration": duration,
-
-                "random1": random1,
-            }
-            row = json.dumps(row, ensure_ascii=False)
-            if random2 < (1 / 300):
-                fvalid.write(f"{row}\n")
-            else:
-                ftrain.write(f"{row}\n")
-
-            count += 1
-            duration_seconds = count * args.duration
-            duration_hours = duration_seconds / 3600
-
-            process_bar.update(n=1)
-            process_bar.set_postfix({
-                "duration_hours": round(duration_hours, 4),
-            })
-
-    return
-
-
-if __name__ == "__main__":
-    main()
examples/dtln_mp3_to_wav/step_2_train_model.py DELETED
@@ -1,445 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-"""
-https://github.com/breizhn/DTLN
-
-"""
-import argparse
-import json
-import logging
-from logging.handlers import TimedRotatingFileHandler
-import os
-import platform
-from pathlib import Path
-import random
-import sys
-import shutil
-from typing import List
-
-pwd = os.path.abspath(os.path.dirname(__file__))
-sys.path.append(os.path.join(pwd, "../../"))
-
-import numpy as np
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from torch.utils.data.dataloader import DataLoader
-from tqdm import tqdm
-
-from toolbox.torch.utils.data.dataset.mp3_to_wav_jsonl_dataset import Mp3ToWavJsonlDataset
-from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
-from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
-from toolbox.torchaudio.metrics.pesq import run_pesq_score
-from toolbox.torchaudio.models.dtln.configuration_dtln import DTLNConfig
-from toolbox.torchaudio.models.dtln.modeling_dtln import DTLNModel, DTLNPretrainedModel
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
-    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
-
-    parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
-    parser.add_argument("--patience", default=30, type=int)
-    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
-
-    parser.add_argument("--config_file", default="config.yaml", type=str)
-
-    args = parser.parse_args()
-    return args
-
-
-def logging_config(file_dir: str):
-    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
-
-    logging.basicConfig(format=fmt,
-                        datefmt="%m/%d/%Y %H:%M:%S",
-                        level=logging.INFO)
-    file_handler = TimedRotatingFileHandler(
-        filename=os.path.join(file_dir, "main.log"),
-        encoding="utf-8",
-        when="D",
-        interval=1,
-        backupCount=7
-    )
-    file_handler.setLevel(logging.INFO)
-    file_handler.setFormatter(logging.Formatter(fmt))
-    logger = logging.getLogger(__name__)
-    logger.addHandler(file_handler)
-
-    return logger
-
-
-class CollateFunction(object):
-    def __init__(self):
-        pass
-
-    def __call__(self, batch: List[dict]):
-        mp3_waveform_list = list()
-        wav_waveform_list = list()
-
-        for sample in batch:
-            mp3_waveform: torch.Tensor = sample["mp3_waveform"]
-            wav_waveform: torch.Tensor = sample["wav_waveform"]
-
-            mp3_waveform_list.append(mp3_waveform)
-            wav_waveform_list.append(wav_waveform)
-
-        mp3_waveform_list = torch.stack(mp3_waveform_list)
-        wav_waveform_list = torch.stack(wav_waveform_list)
-
-        # assert
-        if torch.any(torch.isnan(mp3_waveform_list)) or torch.any(torch.isinf(mp3_waveform_list)):
-            raise AssertionError("nan or inf in mp3_waveform_list")
-        if torch.any(torch.isnan(wav_waveform_list)) or torch.any(torch.isinf(wav_waveform_list)):
-            raise AssertionError("nan or inf in wav_waveform_list")
-
-        return mp3_waveform_list, wav_waveform_list
-
-
-collate_fn = CollateFunction()
-
-
-def main():
-    args = get_args()
-
-    config = DTLNConfig.from_pretrained(
-        pretrained_model_name_or_path=args.config_file,
-    )
-
-    serialization_dir = Path(args.serialization_dir)
-    serialization_dir.mkdir(parents=True, exist_ok=True)
-
-    logger = logging_config(serialization_dir)
-
-    random.seed(config.seed)
-    np.random.seed(config.seed)
-    torch.manual_seed(config.seed)
-    logger.info(f"set seed: {config.seed}")
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    n_gpu = torch.cuda.device_count()
-    logger.info(f"GPU available count: {n_gpu}; device: {device}")
-
-    # datasets
-    train_dataset = Mp3ToWavJsonlDataset(
-        jsonl_file=args.train_dataset,
-        expected_sample_rate=config.sample_rate,
-        max_wave_value=32768.0,
-        # skip=225000,
-    )
-    valid_dataset = Mp3ToWavJsonlDataset(
-        jsonl_file=args.valid_dataset,
-        expected_sample_rate=config.sample_rate,
-        max_wave_value=32768.0,
-    )
-    train_data_loader = DataLoader(
-        dataset=train_dataset,
-        batch_size=config.batch_size,
-        # shuffle=True,
-        sampler=None,
-        # On Linux, data can be loaded with multiple worker subprocesses; on Windows it cannot.
-        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
-        collate_fn=collate_fn,
-        pin_memory=False,
-        prefetch_factor=None if platform.system() == "Windows" else 2,
-    )
-    valid_data_loader = DataLoader(
-        dataset=valid_dataset,
-        batch_size=config.batch_size,
-        # shuffle=True,
-        sampler=None,
-        # On Linux, data can be loaded with multiple worker subprocesses; on Windows it cannot.
-        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
-        collate_fn=collate_fn,
-        pin_memory=False,
-        prefetch_factor=None if platform.system() == "Windows" else 2,
-    )
-
-    # models
-    logger.info(f"prepare models. config_file: {args.config_file}")
-    model = DTLNPretrainedModel(config).to(device)
-    model.to(device)
-    model.train()
-
-    # optimizer
-    logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
-    optimizer = torch.optim.AdamW(model.parameters(), config.lr)
-
-    # resume training
-    last_step_idx = -1
-    last_epoch = -1
-    for step_idx_str in serialization_dir.glob("steps-*"):
-        step_idx_str = Path(step_idx_str)
-        step_idx = step_idx_str.stem.split("-")[1]
-        step_idx = int(step_idx)
-        if step_idx > last_step_idx:
-            last_step_idx = step_idx
-            # last_epoch = 1
-
-    if last_step_idx != -1:
-        logger.info(f"resume from steps-{last_step_idx}.")
-        model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
-
-        logger.info(f"load state dict for model.")
-        with open(model_pt.as_posix(), "rb") as f:
-            state_dict = torch.load(f, map_location="cpu", weights_only=True)
-        model.load_state_dict(state_dict, strict=True)
-
-    if config.lr_scheduler == "CosineAnnealingLR":
-        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
-            optimizer,
-            last_epoch=last_epoch,
-            # T_max=10 * config.eval_steps,
-            # eta_min=0.01 * config.lr,
-            **config.lr_scheduler_kwargs,
-        )
-    elif config.lr_scheduler == "MultiStepLR":
-        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
-            optimizer,
-            last_epoch=last_epoch,
-            milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
-        )
-    else:
-        raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
-
-    neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
-    mr_stft_loss_fn = MultiResolutionSTFTLoss(
-        fft_size_list=[256, 512, 1024],
-        win_size_list=[256, 512, 1024],
-        hop_size_list=[128, 256, 512],
-        factor_sc=1.5,
-        factor_mag=1.0,
-        reduction="mean"
-    ).to(device)
-    audio_l1_loss_fn = nn.L1Loss(reduction="mean")
-
-    # training loop
-
-    # state
-    average_pesq_score = 1000000000
-    average_loss = 1000000000
-    average_mr_stft_loss = 1000000000
-    average_audio_l1_loss = 1000000000
-    average_neg_si_snr_loss = 1000000000
-
-    model_list = list()
-    best_epoch_idx = None
-    best_step_idx = None
-    best_metric = None
-    patience_count = 0
-
-    step_idx = 0 if last_step_idx == -1 else last_step_idx
-
-    logger.info("training")
-    early_stop_flag = False
-    for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
-        if early_stop_flag:
-            break
-
-        # train
-        model.train()
-
-        total_pesq_score = 0.
-        total_loss = 0.
-        total_mr_stft_loss = 0.
-        total_audio_l1_loss = 0.
-        total_neg_si_snr_loss = 0.
-        total_batches = 0.
-
-        progress_bar_train = tqdm(
-            initial=step_idx,
-            desc="Training; epoch-{}".format(epoch_idx),
-        )
-        for train_batch in train_data_loader:
-            mp3_audios, wav_audios = train_batch
-            noisy_audios: torch.Tensor = mp3_audios.to(device)
-            clean_audios: torch.Tensor = wav_audios.to(device)
-
-            denoise_audios = model.forward(noisy_audios)
-            denoise_audios = torch.squeeze(denoise_audios, dim=1)
-
-            mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
-            audio_l1_loss = audio_l1_loss_fn.forward(denoise_audios, clean_audios)
-            neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
-
-            loss = 1.0 * mr_stft_loss + 1.0 * audio_l1_loss + 1.0 * neg_si_snr_loss
-            if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
-                logger.info(f"find nan or inf in loss.")
-                continue
-
-            denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
-            clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
-            pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
-
-            optimizer.zero_grad()
-            loss.backward()
-            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
-            optimizer.step()
-            lr_scheduler.step()
-
-            total_pesq_score += pesq_score
-            total_loss += loss.item()
-            total_mr_stft_loss += mr_stft_loss.item()
-            total_audio_l1_loss += audio_l1_loss.item()
-            total_neg_si_snr_loss += neg_si_snr_loss.item()
-            total_batches += 1
-
-            average_pesq_score = round(total_pesq_score / total_batches, 4)
-            average_loss = round(total_loss / total_batches, 4)
-            average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
-            average_audio_l1_loss = round(total_audio_l1_loss / total_batches, 4)
-            average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
-
-            progress_bar_train.update(1)
-            progress_bar_train.set_postfix({
-                "lr": lr_scheduler.get_last_lr()[0],
-                "pesq_score": average_pesq_score,
-                "loss": average_loss,
-                "mr_stft_loss": average_mr_stft_loss,
-                "audio_l1_loss": average_audio_l1_loss,
-                "neg_si_snr_loss": average_neg_si_snr_loss,
-            })
-
-            # evaluation
-            step_idx += 1
-            if step_idx % config.eval_steps == 0:
-                model.eval()
-                with torch.no_grad():
-                    torch.cuda.empty_cache()
-
-                    total_pesq_score = 0.
-                    total_loss = 0.
-                    total_mr_stft_loss = 0.
-                    total_audio_l1_loss = 0.
-                    total_neg_si_snr_loss = 0.
-                    total_batches = 0.
-
-                    progress_bar_train.close()
-                    progress_bar_eval = tqdm(
-                        desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
-                    )
-                    for eval_batch in valid_data_loader:
-                        mp3_audios, wav_audios = eval_batch
-                        noisy_audios: torch.Tensor = mp3_audios.to(device)
-                        clean_audios: torch.Tensor = wav_audios.to(device)
-
-                        denoise_audios = model.forward(noisy_audios)
-                        denoise_audios = torch.squeeze(denoise_audios, dim=1)
-
-                        mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
-                        audio_l1_loss = audio_l1_loss_fn.forward(denoise_audios, clean_audios)
-                        neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
-
-                        loss = 1.0 * mr_stft_loss + 1.0 * audio_l1_loss + 1.0 * neg_si_snr_loss
-                        if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
-                            logger.info(f"find nan or inf in loss.")
-                            continue
-
-                        denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
-                        clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
-                        pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
-
-                        total_pesq_score += pesq_score
-                        total_loss += loss.item()
-                        total_mr_stft_loss += mr_stft_loss.item()
-                        total_audio_l1_loss += audio_l1_loss.item()
-                        total_neg_si_snr_loss += neg_si_snr_loss.item()
-                        total_batches += 1
-
-                        average_pesq_score = round(total_pesq_score / total_batches, 4)
-                        average_loss = round(total_loss / total_batches, 4)
-                        average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
-                        average_audio_l1_loss = round(total_audio_l1_loss / total_batches, 4)
-                        average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
-
-                        progress_bar_eval.update(1)
-                        progress_bar_eval.set_postfix({
-                            "lr": lr_scheduler.get_last_lr()[0],
-                            "pesq_score": average_pesq_score,
-                            "loss": average_loss,
-                            "mr_stft_loss": average_mr_stft_loss,
-                            "audio_l1_loss": average_audio_l1_loss,
-                            "neg_si_snr_loss": average_neg_si_snr_loss,
-
-                        })
-
-                    total_pesq_score = 0.
-                    total_loss = 0.
-                    total_mr_stft_loss = 0.
-                    total_audio_l1_loss = 0.
-                    total_neg_si_snr_loss = 0.
-                    total_batches = 0.
-
-                    progress_bar_eval.close()
-                    progress_bar_train = tqdm(
-                        initial=progress_bar_train.n,
-                        postfix=progress_bar_train.postfix,
-                        desc=progress_bar_train.desc,
-                    )
-
-                    # save path
-                    save_dir = serialization_dir / "steps-{}".format(step_idx)
-                    save_dir.mkdir(parents=True, exist_ok=False)
-
-                    # save models
-                    model.save_pretrained(save_dir.as_posix())
-
-                    model_list.append(save_dir)
-                    if len(model_list) >= args.num_serialized_models_to_keep:
-                        model_to_delete: Path = model_list.pop(0)
-                        shutil.rmtree(model_to_delete.as_posix())
-
-                    # save metric
-                    if best_metric is None:
-                        best_epoch_idx = epoch_idx
-                        best_step_idx = step_idx
-                        best_metric = average_pesq_score
-                    elif average_pesq_score >= best_metric:
-                        # greater is better.
-                        best_epoch_idx = epoch_idx
-                        best_step_idx = step_idx
-                        best_metric = average_pesq_score
-                    else:
-                        pass
-
-                    metrics = {
-                        "epoch_idx": epoch_idx,
-                        "best_epoch_idx": best_epoch_idx,
-                        "best_step_idx": best_step_idx,
-                        "pesq_score": average_pesq_score,
-                        "loss": average_loss,
-                        "mr_stft_loss": average_mr_stft_loss,
-                        "audio_l1_loss": average_audio_l1_loss,
-                        "neg_si_snr_loss": average_neg_si_snr_loss,
-                    }
-                    metrics_filename = save_dir / "metrics_epoch.json"
-                    with open(metrics_filename, "w", encoding="utf-8") as f:
-                        json.dump(metrics, f, indent=4, ensure_ascii=False)
-
-                    # save best
-                    best_dir = serialization_dir / "best"
-                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
-                        if best_dir.exists():
-                            shutil.rmtree(best_dir)
-                        shutil.copytree(save_dir, best_dir)
-
-                    # early stop
-                    early_stop_flag = False
-                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
-                        patience_count = 0
-                    else:
-                        patience_count += 1
-                        if patience_count >= args.patience:
-                            early_stop_flag = True
-
-                # early stop
-                if early_stop_flag:
-                    break
-                model.train()
-
-    return


-if __name__ == "__main__":
-    main()
examples/dtln_mp3_to_wav/yaml/config-1024.yaml DELETED
@@ -1,29 +0,0 @@
-model_name: "DTLN"
-
-# spec
-sample_rate: 8000
-fft_size: 512
-hop_size: 128
-win_type: hann
-
-# data
-min_snr_db: -5
-max_snr_db: 25
-
-# model
-encoder_size: 1024
-
-# train
-lr: 0.001
-lr_scheduler: "CosineAnnealingLR"
-lr_scheduler_kwargs:
-  T_max: 250000
-  eta_min: 0.0001
-
-max_epochs: 100
-clip_grad_norm: 10.0
-seed: 1234
-
-num_workers: 4
-batch_size: 64
-eval_steps: 15000
examples/dtln_mp3_to_wav/yaml/config-256.yaml DELETED
@@ -1,29 +0,0 @@
-model_name: "DTLN"
-
-# spec
-sample_rate: 8000
-fft_size: 256
-hop_size: 128
-win_type: hann
-
-# data
-min_snr_db: -5
-max_snr_db: 25
-
-# model
-encoder_size: 256
-
-# train
-lr: 0.001
-lr_scheduler: "CosineAnnealingLR"
-lr_scheduler_kwargs:
-  T_max: 250000
-  eta_min: 0.0001
-
-max_epochs: 100
-clip_grad_norm: 10.0
-seed: 1234
-
-num_workers: 4
-batch_size: 64
-eval_steps: 15000
examples/dtln_mp3_to_wav/yaml/config-512.yaml DELETED
@@ -1,29 +0,0 @@
-model_name: "DTLN"
-
-# spec
-sample_rate: 8000
-fft_size: 512
-hop_size: 128
-win_type: hann
-
-# data
-min_snr_db: -5
-max_snr_db: 25
-
-# model
-encoder_size: 512
-
-# train
-lr: 0.001
-lr_scheduler: "CosineAnnealingLR"
-lr_scheduler_kwargs:
-  T_max: 250000
-  eta_min: 0.0001
-
-max_epochs: 100
-clip_grad_norm: 10.0
-seed: 1234
-
-num_workers: 4
-batch_size: 64
-eval_steps: 15000
examples/frcrn_mp3_to_wav/run.sh DELETED
@@ -1,156 +0,0 @@
-#!/usr/bin/env bash
-
-: <<'END'
-
-
-sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn-20-512-nx-dns3 \
---config_file "yaml/config-10.yaml" \
---noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
---speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech"
-
-
-sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn-10-nx2 \
---config_file "yaml/config-10.yaml" \
---audio_dir "/data/tianxing/HuggingDatasets/nx_noise/data" \
-
-END
-
-
-# params
-system_version="windows";
-verbose=true;
-stage=0 # start from 0 if you need to start from data preparation
-stop_stage=9
-
-work_dir="$(pwd)"
-file_folder_name=file_folder_name
-final_model_name=final_model_name
-config_file="yaml/config.yaml"
-limit=10
-
-audio_dir=/data/tianxing/HuggingDatasets/nx_noise/data
-
-max_count=10000000
-
-nohup_name=nohup.out
-
-# model params
-batch_size=64
-max_epochs=200
-save_top_k=10
-patience=5
-
-
-# parse options
-while true; do
-  [ -z "${1:-}" ] && break;  # break if there are no arguments
-  case "$1" in
-    --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
-      eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
-      old_value="(eval echo \\$$name)";
-      if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
-        was_bool=true;
-      else
-        was_bool=false;
-      fi
-
-      # Set the variable to the right value-- the escaped quotes make it work if
-      # the option had spaces, like --cmd "queue.pl -sync y"
-      eval "${name}=\"$2\"";
-
-      # Check that Boolean-valued arguments are really Boolean.
-      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
-        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
-        exit 1;
-      fi
-      shift 2;
-      ;;
-
-    *) break;
-  esac
-done
-
-file_dir="${work_dir}/${file_folder_name}"
-final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
-evaluation_audio_dir="${file_dir}/evaluation_audio"
-
-train_dataset="${file_dir}/train.jsonl"
-valid_dataset="${file_dir}/valid.jsonl"
-
-$verbose && echo "system_version: ${system_version}"
-$verbose && echo "file_folder_name: ${file_folder_name}"
-
-if [ $system_version == "windows" ]; then
-  alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
-elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
-  #source /data/local/bin/nx_denoise/bin/activate
-  alias python3='/data/local/bin/nx_denoise/bin/python3'
-fi
-
-
-if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-  $verbose && echo "stage 1: prepare data"
-  cd "${work_dir}" || exit 1
-  python3 step_1_prepare_data.py \
-    --file_dir "${file_dir}" \
-    --audio_dir "${audio_dir}" \
-    --train_dataset "${train_dataset}" \
-    --valid_dataset "${valid_dataset}" \
-    --max_count "${max_count}" \
-
-fi
-
-
-if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
-  $verbose && echo "stage 2: train model"
-  cd "${work_dir}" || exit 1
-  python3 step_2_train_model.py \
-    --train_dataset "${train_dataset}" \
-    --valid_dataset "${valid_dataset}" \
-    --serialization_dir "${file_dir}" \
-    --config_file "${config_file}" \
-
-fi
-
-
-if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
-  $verbose && echo "stage 3: test model"
-  cd "${work_dir}" || exit 1
-  python3 step_3_evaluation.py \
-    --valid_dataset "${valid_dataset}" \
-    --model_dir "${file_dir}/best" \
-    --evaluation_audio_dir "${evaluation_audio_dir}" \
-    --limit "${limit}" \
-
-fi
-
-
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
-  $verbose && echo "stage 4: collect files"
-  cd "${work_dir}" || exit 1
-
-  mkdir -p ${final_model_dir}
-
-  cp "${file_dir}/best"/* "${final_model_dir}"
-  cp -r "${file_dir}/evaluation_audio" "${final_model_dir}"
-
-  cd "${final_model_dir}/.." || exit 1;
-
-  if [ -e "${final_model_name}.zip" ]; then
-    rm -rf "${final_model_name}_backup.zip"
-    mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
-  fi
-
-  zip -r "${final_model_name}.zip" "${final_model_name}"
-  rm -rf "${final_model_name}"
-
-fi
-
-
-if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
-  $verbose && echo "stage 5: clear file_dir"
-  cd "${work_dir}" || exit 1
-
-  rm -rf "${file_dir}";
-
-fi
examples/frcrn_mp3_to_wav/step_1_prepare_data.py DELETED
@@ -1,127 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-import argparse
-import json
-import os
-from pathlib import Path
-import random
-import sys
-
-pwd = os.path.abspath(os.path.dirname(__file__))
-sys.path.append(os.path.join(pwd, "../../"))
-
-import librosa
-import numpy as np
-from tqdm import tqdm
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--file_dir", default="./", type=str)
-
-    parser.add_argument(
-        "--audio_dir",
-        default="E:/Users/tianx/HuggingDatasets/nx_noise/data/speech",
-        type=str
-    )
-
-    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
-    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
-
-    parser.add_argument("--duration", default=4.0, type=float)
-
-    parser.add_argument("--target_sample_rate", default=8000, type=int)
-
-    parser.add_argument("--max_count", default=-1, type=int)
-
-    args = parser.parse_args()
-    return args
-
-
-def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000, max_epoch: int = 1):
-    data_dir = Path(data_dir)
-    for epoch_idx in range(max_epoch):
-        for filename in data_dir.glob("**/*.wav"):
-            signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
-            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
-
-            if raw_duration < duration:
-                # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
-                continue
-            if signal.ndim != 1:
-                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
-
-            signal_length = len(signal)
-            win_size = int(duration * sample_rate)
-            for begin in range(0, signal_length - win_size, win_size):
-                if np.sum(signal[begin: begin+win_size]) == 0:
-                    continue
-                row = {
-                    "epoch_idx": epoch_idx,
-                    "filename": filename.as_posix(),
-                    "raw_duration": round(raw_duration, 4),
-                    "offset": round(begin / sample_rate, 4),
-                    "duration": round(duration, 4),
-                }
-                yield row
-
-
-def main():
-    args = get_args()
-
-    file_dir = Path(args.file_dir)
-    file_dir.mkdir(exist_ok=True)
-
-    audio_dir = Path(args.audio_dir)
-
-    audio_generator = target_second_signal_generator(
-        audio_dir.as_posix(),
-        duration=args.duration,
-        sample_rate=args.target_sample_rate,
-        max_epoch=1,
-    )
-    count = 0
-    process_bar = tqdm(desc="build dataset jsonl")
-    with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
-        for audio in audio_generator:
-            if count >= args.max_count > 0:
-                break
-
-            filename = audio["filename"]
-            raw_duration = audio["raw_duration"]
-            offset = audio["offset"]
-            duration = audio["duration"]
-
-            random1 = random.random()
-            random2 = random.random()
-
-            row = {
-                "count": count,
-
-                "filename": filename,
-                "raw_duration": raw_duration,
-                "offset": offset,
-                "duration": duration,
-
-                "random1": random1,
-            }
-            row = json.dumps(row, ensure_ascii=False)
-            if random2 < (1 / 10):
-                fvalid.write(f"{row}\n")
-            else:
-                ftrain.write(f"{row}\n")
-
-            count += 1
-            duration_seconds = count * args.duration
-            duration_hours = duration_seconds / 3600
-
-            process_bar.update(n=1)
-            process_bar.set_postfix({
-                "duration_hours": round(duration_hours, 4),
-            })
-
-    return
-
-
-if __name__ == "__main__":
-    main()
examples/frcrn_mp3_to_wav/step_2_train_model.py DELETED
@@ -1,442 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-import argparse
-import json
-import logging
-from logging.handlers import TimedRotatingFileHandler
-import os
-import platform
-from pathlib import Path
-import random
-import sys
-import shutil
-from typing import List
-
-pwd = os.path.abspath(os.path.dirname(__file__))
-sys.path.append(os.path.join(pwd, "../../"))
-
-import numpy as np
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from torch.utils.data.dataloader import DataLoader
-from tqdm import tqdm
-
-from toolbox.torch.utils.data.dataset.mp3_to_wav_jsonl_dataset import Mp3ToWavJsonlDataset
-from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
-from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
-from toolbox.torchaudio.metrics.pesq import run_pesq_score
-from toolbox.torchaudio.models.frcrn.configuration_frcrn import FRCRNConfig
-from toolbox.torchaudio.models.frcrn.modeling_frcrn import FRCRN, FRCRNPretrainedModel
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
-    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
-
-    parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
-    parser.add_argument("--patience", default=30, type=int)
-    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
-
-    parser.add_argument("--config_file", default="config.yaml", type=str)
-
-    args = parser.parse_args()
-    return args
-
-
-def logging_config(file_dir: str):
-    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
-
-    logging.basicConfig(format=fmt,
-                        datefmt="%m/%d/%Y %H:%M:%S",
-                        level=logging.INFO)
-    file_handler = TimedRotatingFileHandler(
-        filename=os.path.join(file_dir, "main.log"),
-        encoding="utf-8",
-        when="D",
-        interval=1,
-        backupCount=7
-    )
-    file_handler.setLevel(logging.INFO)
-    file_handler.setFormatter(logging.Formatter(fmt))
-    logger = logging.getLogger(__name__)
-    logger.addHandler(file_handler)
-
-    return logger
-
-
-class CollateFunction(object):
-    def __init__(self):
-        pass
-
-    def __call__(self, batch: List[dict]):
-        mp3_waveform_list = list()
-        wav_waveform_list = list()
-
-        for sample in batch:
-            mp3_waveform: torch.Tensor = sample["mp3_waveform"]
-            wav_waveform: torch.Tensor = sample["wav_waveform"]
-
-            mp3_waveform_list.append(mp3_waveform)
-            wav_waveform_list.append(wav_waveform)
-
-        mp3_waveform_list = torch.stack(mp3_waveform_list)
-        wav_waveform_list = torch.stack(wav_waveform_list)
-
-        # assert
-        if torch.any(torch.isnan(mp3_waveform_list)) or torch.any(torch.isinf(mp3_waveform_list)):
-            raise AssertionError("nan or inf in mp3_waveform_list")
-        if torch.any(torch.isnan(wav_waveform_list)) or torch.any(torch.isinf(wav_waveform_list)):
-            raise AssertionError("nan or inf in wav_waveform_list")
-
-        return mp3_waveform_list, wav_waveform_list
-
-
-collate_fn = CollateFunction()
-
-
-def main():
-    args = get_args()
-
-    config = FRCRNConfig.from_pretrained(
-        pretrained_model_name_or_path=args.config_file,
-    )
-
-    serialization_dir = Path(args.serialization_dir)
-    serialization_dir.mkdir(parents=True, exist_ok=True)
-
-    logger = logging_config(serialization_dir)
-
-    random.seed(config.seed)
-    np.random.seed(config.seed)
-    torch.manual_seed(config.seed)
-    logger.info(f"set seed: {config.seed}")
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    n_gpu = torch.cuda.device_count()
-    logger.info(f"GPU available count: {n_gpu}; device: {device}")
-
-    # datasets
-    train_dataset = Mp3ToWavJsonlDataset(
-        jsonl_file=args.train_dataset,
-        expected_sample_rate=config.sample_rate,
-        max_wave_value=32768.0,
-        # skip=225000,
-    )
-    valid_dataset = Mp3ToWavJsonlDataset(
-        jsonl_file=args.valid_dataset,
-        expected_sample_rate=config.sample_rate,
-        max_wave_value=32768.0,
-    )
-    train_data_loader = DataLoader(
-        dataset=train_dataset,
-        batch_size=config.batch_size,
-        # shuffle=True,
-        sampler=None,
-        # On Linux, data can be loaded with multiple worker subprocesses; on Windows it cannot.
-        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
-        collate_fn=collate_fn,
-        pin_memory=False,
-        prefetch_factor=2,
-    )
-    valid_data_loader = DataLoader(
-        dataset=valid_dataset,
-        batch_size=config.batch_size,
-        # shuffle=True,
-        sampler=None,
-        # On Linux, data can be loaded with multiple worker subprocesses; on Windows it cannot.
-        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
-        collate_fn=collate_fn,
-        pin_memory=False,
-        prefetch_factor=2,
-    )
-
-    # models
-    logger.info(f"prepare models. config_file: {args.config_file}")
-    model = FRCRNPretrainedModel(config).to(device)
-    model.to(device)
-    model.train()
-
-    # optimizer
-    logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
-    optimizer = torch.optim.AdamW(model.get_params(weight_decay=config.weight_decay), config.lr)
-
-    # resume training
-    last_step_idx = -1
-    last_epoch = -1
-    for step_idx_str in serialization_dir.glob("steps-*"):
-        step_idx_str = Path(step_idx_str)
-        step_idx = step_idx_str.stem.split("-")[1]
-        step_idx = int(step_idx)
-        if step_idx > last_step_idx:
-            last_step_idx = step_idx
-            # last_epoch = 0
-
-    if last_step_idx != -1:
-        logger.info(f"resume from steps-{last_step_idx}.")
-        model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
-        # optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"
-
-        logger.info(f"load state dict for model.")
-        with open(model_pt.as_posix(), "rb") as f:
-            state_dict = torch.load(f, map_location="cpu", weights_only=True)
-        model.load_state_dict(state_dict, strict=True)
-
-        # logger.info(f"load state dict for optimizer.")
-        # with open(optimizer_pth.as_posix(), "rb") as f:
-        #     state_dict = torch.load(f, map_location="cpu", weights_only=True)
-        # optimizer.load_state_dict(state_dict)
-
-    if config.lr_scheduler == "CosineAnnealingLR":
-        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
-            optimizer,
-            last_epoch=last_epoch,
-            # T_max=10 * config.eval_steps,
-            # eta_min=0.01 * config.lr,
-            **config.lr_scheduler_kwargs,
-        )
-    elif config.lr_scheduler == "MultiStepLR":
-        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
-            optimizer,
-            last_epoch=last_epoch,
-            milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
-        )
-    else:
-        raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
-
-    neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
-    mr_stft_loss_fn = MultiResolutionSTFTLoss(
-        fft_size_list=[256, 512, 1024],
-        win_size_list=[256, 512, 1024],
-        hop_size_list=[128, 256, 512],
-        factor_sc=1.5,
-        factor_mag=1.0,
-        reduction="mean"
-    ).to(device)
-
-    # training loop
-
-    # state
-    average_pesq_score = 1000000000
-    average_loss = 1000000000
-    average_neg_si_snr_loss = 1000000000
-    average_mask_loss = 1000000000
-
-    model_list = list()
-    best_epoch_idx = None
-    best_step_idx = None
-    best_metric = None
-    patience_count = 0
-
-    step_idx = 0 if last_step_idx == -1 else last_step_idx
-
-    logger.info("training")
-    early_stop_flag = False
-    for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
-        if early_stop_flag:
-            break
-
-        # train
-        model.train()
-
-        total_pesq_score = 0.
-        total_loss = 0.
-        total_mr_stft_loss = 0.
-        total_neg_si_snr_loss = 0.
-        total_mask_loss = 0.
-        total_batches = 0.
|
| 249 |
-
|
| 250 |
-
progress_bar_train = tqdm(
|
| 251 |
-
initial=step_idx,
|
| 252 |
-
desc="Training; epoch-{}".format(epoch_idx),
|
| 253 |
-
)
|
| 254 |
-
for train_batch in train_data_loader:
|
| 255 |
-
mp3_audios, wav_audios = train_batch
|
| 256 |
-
noisy_audios: torch.Tensor = mp3_audios.to(device)
|
| 257 |
-
clean_audios: torch.Tensor = wav_audios.to(device)
|
| 258 |
-
|
| 259 |
-
est_spec, est_wav, est_mask = model.forward(noisy_audios)
|
| 260 |
-
denoise_audios = est_wav
|
| 261 |
-
|
| 262 |
-
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
| 263 |
-
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
| 264 |
-
mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
|
| 265 |
-
|
| 266 |
-
loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss
|
| 267 |
-
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
| 268 |
-
logger.info(f"find nan or inf in loss.")
|
| 269 |
-
continue
|
| 270 |
-
|
| 271 |
-
denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
|
| 272 |
-
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
| 273 |
-
pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
|
| 274 |
-
|
| 275 |
-
optimizer.zero_grad()
|
| 276 |
-
loss.backward()
|
| 277 |
-
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
|
| 278 |
-
optimizer.step()
|
| 279 |
-
lr_scheduler.step()
|
| 280 |
-
|
| 281 |
-
total_pesq_score += pesq_score
|
| 282 |
-
total_loss += loss.item()
|
| 283 |
-
total_mr_stft_loss += mr_stft_loss.item()
|
| 284 |
-
total_neg_si_snr_loss += neg_si_snr_loss.item()
|
| 285 |
-
total_mask_loss += mask_loss.item()
|
| 286 |
-
total_batches += 1
|
| 287 |
-
|
| 288 |
-
average_pesq_score = round(total_pesq_score / total_batches, 4)
|
| 289 |
-
average_loss = round(total_loss / total_batches, 4)
|
| 290 |
-
average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
|
| 291 |
-
average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
|
| 292 |
-
average_mask_loss = round(total_mask_loss / total_batches, 4)
|
| 293 |
-
|
| 294 |
-
progress_bar_train.update(1)
|
| 295 |
-
progress_bar_train.set_postfix({
|
| 296 |
-
"lr": lr_scheduler.get_last_lr()[0],
|
| 297 |
-
"pesq_score": average_pesq_score,
|
| 298 |
-
"loss": average_loss,
|
| 299 |
-
"mr_stft_loss": average_mr_stft_loss,
|
| 300 |
-
"neg_si_snr_loss": average_neg_si_snr_loss,
|
| 301 |
-
"mask_loss": average_mask_loss,
|
| 302 |
-
})
|
| 303 |
-
|
| 304 |
-
# evaluation
|
| 305 |
-
step_idx += 1
|
| 306 |
-
if step_idx % config.eval_steps == 0:
|
| 307 |
-
model.eval()
|
| 308 |
-
with torch.no_grad():
|
| 309 |
-
torch.cuda.empty_cache()
|
| 310 |
-
|
| 311 |
-
total_pesq_score = 0.
|
| 312 |
-
total_loss = 0.
|
| 313 |
-
total_mr_stft_loss = 0.
|
| 314 |
-
total_neg_si_snr_loss = 0.
|
| 315 |
-
total_mask_loss = 0.
|
| 316 |
-
total_batches = 0.
|
| 317 |
-
|
| 318 |
-
progress_bar_train.close()
|
| 319 |
-
progress_bar_eval = tqdm(
|
| 320 |
-
desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
|
| 321 |
-
)
|
| 322 |
-
for eval_batch in valid_data_loader:
|
| 323 |
-
mp3_audios, wav_audios = eval_batch
|
| 324 |
-
noisy_audios: torch.Tensor = mp3_audios.to(device)
|
| 325 |
-
clean_audios: torch.Tensor = wav_audios.to(device)
|
| 326 |
-
|
| 327 |
-
est_spec, est_wav, est_mask = model.forward(noisy_audios)
|
| 328 |
-
denoise_audios = est_wav
|
| 329 |
-
|
| 330 |
-
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
| 331 |
-
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
| 332 |
-
mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
|
| 333 |
-
|
| 334 |
-
loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss
|
| 335 |
-
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
| 336 |
-
logger.info(f"find nan or inf in loss.")
|
| 337 |
-
continue
|
| 338 |
-
|
| 339 |
-
denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
|
| 340 |
-
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
| 341 |
-
pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
|
| 342 |
-
|
| 343 |
-
total_pesq_score += pesq_score
|
| 344 |
-
total_loss += loss.item()
|
| 345 |
-
total_neg_si_snr_loss += neg_si_snr_loss.item()
|
| 346 |
-
total_mask_loss += mask_loss.item()
|
| 347 |
-
total_batches += 1
|
| 348 |
-
|
| 349 |
-
average_pesq_score = round(total_pesq_score / total_batches, 4)
|
| 350 |
-
average_loss = round(total_loss / total_batches, 4)
|
| 351 |
-
average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
|
| 352 |
-
average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
|
| 353 |
-
average_mask_loss = round(total_mask_loss / total_batches, 4)
|
| 354 |
-
|
| 355 |
-
progress_bar_eval.update(1)
|
| 356 |
-
progress_bar_eval.set_postfix({
|
| 357 |
-
"lr": lr_scheduler.get_last_lr()[0],
|
| 358 |
-
"pesq_score": average_pesq_score,
|
| 359 |
-
"loss": average_loss,
|
| 360 |
-
"mr_stft_loss": average_mr_stft_loss,
|
| 361 |
-
"neg_si_snr_loss": average_neg_si_snr_loss,
|
| 362 |
-
"mask_loss": average_mask_loss,
|
| 363 |
-
})
|
| 364 |
-
|
| 365 |
-
total_pesq_score = 0.
|
| 366 |
-
total_loss = 0.
|
| 367 |
-
total_mr_stft_loss = 0.
|
| 368 |
-
total_neg_si_snr_loss = 0.
|
| 369 |
-
total_mask_loss = 0.
|
| 370 |
-
total_batches = 0.
|
| 371 |
-
|
| 372 |
-
progress_bar_eval.close()
|
| 373 |
-
progress_bar_train = tqdm(
|
| 374 |
-
initial=progress_bar_train.n,
|
| 375 |
-
postfix=progress_bar_train.postfix,
|
| 376 |
-
desc=progress_bar_train.desc,
|
| 377 |
-
)
|
| 378 |
-
|
| 379 |
-
# save path
|
| 380 |
-
save_dir = serialization_dir / "steps-{}".format(step_idx)
|
| 381 |
-
save_dir.mkdir(parents=True, exist_ok=False)
|
| 382 |
-
|
| 383 |
-
# save models
|
| 384 |
-
model.save_pretrained(save_dir.as_posix())
|
| 385 |
-
|
| 386 |
-
model_list.append(save_dir)
|
| 387 |
-
if len(model_list) >= args.num_serialized_models_to_keep:
|
| 388 |
-
model_to_delete: Path = model_list.pop(0)
|
| 389 |
-
shutil.rmtree(model_to_delete.as_posix())
|
| 390 |
-
|
| 391 |
-
# save metric
|
| 392 |
-
if best_metric is None:
|
| 393 |
-
best_epoch_idx = epoch_idx
|
| 394 |
-
best_step_idx = step_idx
|
| 395 |
-
best_metric = average_pesq_score
|
| 396 |
-
elif average_pesq_score >= best_metric:
|
| 397 |
-
# great is better.
|
| 398 |
-
best_epoch_idx = epoch_idx
|
| 399 |
-
best_step_idx = step_idx
|
| 400 |
-
best_metric = average_pesq_score
|
| 401 |
-
else:
|
| 402 |
-
pass
|
| 403 |
-
|
| 404 |
-
metrics = {
|
| 405 |
-
"epoch_idx": epoch_idx,
|
| 406 |
-
"best_epoch_idx": best_epoch_idx,
|
| 407 |
-
"best_step_idx": best_step_idx,
|
| 408 |
-
"pesq_score": average_pesq_score,
|
| 409 |
-
"loss": average_loss,
|
| 410 |
-
"neg_si_snr_loss": average_neg_si_snr_loss,
|
| 411 |
-
"mask_loss": average_mask_loss,
|
| 412 |
-
}
|
| 413 |
-
metrics_filename = save_dir / "metrics_epoch.json"
|
| 414 |
-
with open(metrics_filename, "w", encoding="utf-8") as f:
|
| 415 |
-
json.dump(metrics, f, indent=4, ensure_ascii=False)
|
| 416 |
-
|
| 417 |
-
# save best
|
| 418 |
-
best_dir = serialization_dir / "best"
|
| 419 |
-
if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
|
| 420 |
-
if best_dir.exists():
|
| 421 |
-
shutil.rmtree(best_dir)
|
| 422 |
-
shutil.copytree(save_dir, best_dir)
|
| 423 |
-
|
| 424 |
-
# early stop
|
| 425 |
-
early_stop_flag = False
|
| 426 |
-
if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
|
| 427 |
-
patience_count = 0
|
| 428 |
-
else:
|
| 429 |
-
patience_count += 1
|
| 430 |
-
if patience_count >= args.patience:
|
| 431 |
-
early_stop_flag = True
|
| 432 |
-
|
| 433 |
-
# early stop
|
| 434 |
-
if early_stop_flag:
|
| 435 |
-
break
|
| 436 |
-
model.train()
|
| 437 |
-
|
| 438 |
-
return
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
if __name__ == "__main__":
|
| 442 |
-
main()
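For orientation, the training objective above combines a multi-resolution STFT loss, a negative SI-SNR loss, and a mask loss. A minimal sketch of what a negative SI-SNR term like `NegativeSISNRLoss` conventionally computes (this is the standard scale-invariant SNR definition; the repo's own class may differ in details):

import torch

def negative_si_snr(estimate: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Zero-mean both signals, then split the estimate into a component
    # aligned with the target (signal) and a residual (noise).
    estimate = estimate - estimate.mean(dim=-1, keepdim=True)
    target = target - target.mean(dim=-1, keepdim=True)
    dot = torch.sum(estimate * target, dim=-1, keepdim=True)
    energy = torch.sum(target ** 2, dim=-1, keepdim=True) + eps
    s_target = dot / energy * target
    e_noise = estimate - s_target
    si_snr = 10 * torch.log10(
        (torch.sum(s_target ** 2, dim=-1) + eps) / (torch.sum(e_noise ** 2, dim=-1) + eps)
    )
    return -si_snr.mean()  # negate: minimizing the loss maximizes SI-SNR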
examples/frcrn_mp3_to_wav/yaml/config-10.yaml
DELETED
@@ -1,31 +0,0 @@
model_name: "frcrn"

sample_rate: 8000
segment_size: 32000
nfft: 128
win_size: 128
hop_size: 64
win_type: hann

use_complex_networks: true
model_depth: 10
model_complexity: -1

min_snr_db: -10
max_snr_db: 20

num_workers: 8
batch_size: 32
eval_steps: 20000

lr: 0.001
lr_scheduler: "CosineAnnealingLR"
lr_scheduler_kwargs:
  T_max: 250000
  eta_min: 0.0001

max_epochs: 100
weight_decay: 1.0e-05
clip_grad_norm: 10.0
seed: 1234
num_gpus: -1
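The `lr_scheduler_kwargs` block above is forwarded verbatim into the PyTorch scheduler constructor by the training script. A stand-alone sketch of that wiring, assuming PyYAML is installed and the config file sits in the working directory (the filename and the throwaway parameter are illustrative, not repo code):

import torch
import yaml

with open("config-10.yaml") as f:  # any of these config files
    config = yaml.safe_load(f)

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=config["lr"])
# CosineAnnealingLR decays the lr from `lr` toward `eta_min` over `T_max` steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **config["lr_scheduler_kwargs"])
for _ in range(3):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())  # still close to 0.001 after 3 of 250000 steps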
examples/frcrn_mp3_to_wav/yaml/config-14.yaml
DELETED
@@ -1,31 +0,0 @@
model_name: "frcrn"

sample_rate: 8000
segment_size: 32000
nfft: 640
win_size: 640
hop_size: 320
win_type: hann

use_complex_networks: true
model_depth: 14
model_complexity: -1

min_snr_db: -10
max_snr_db: 20

num_workers: 8
batch_size: 32
eval_steps: 10000

lr: 0.001
lr_scheduler: "CosineAnnealingLR"
lr_scheduler_kwargs:
  T_max: 250000
  eta_min: 0.0001

max_epochs: 100
weight_decay: 1.0e-05
clip_grad_norm: 10.0
seed: 1234
num_gpus: -1
examples/frcrn_mp3_to_wav/yaml/config-20.yaml
DELETED
@@ -1,31 +0,0 @@
model_name: "frcrn"

sample_rate: 8000
segment_size: 32000
nfft: 512
win_size: 512
hop_size: 256
win_type: hann

use_complex_networks: true
model_depth: 20
model_complexity: 45

min_snr_db: -10
max_snr_db: 20

num_workers: 8
batch_size: 32
eval_steps: 10000

lr: 0.001
lr_scheduler: "CosineAnnealingLR"
lr_scheduler_kwargs:
  T_max: 250000
  eta_min: 0.0001

max_epochs: 100
weight_decay: 1.0e-05
clip_grad_norm: 10.0
seed: 1234
num_gpus: -1
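The three FRCRN configs above differ mainly in STFT geometry: at 8000 Hz, `nfft` and `hop_size` trade frequency resolution against time resolution. A quick back-of-the-envelope check in plain Python (not repo code):

sample_rate = 8000
for nfft, hop in [(128, 64), (512, 256), (640, 320)]:  # config-10 / config-20 / config-14
    bins = nfft // 2 + 1                # one-sided frequency bins
    hz_per_bin = sample_rate / nfft     # frequency resolution
    frames_per_sec = sample_rate / hop  # time resolution
    print(f"nfft={nfft}: {bins} bins, {hz_per_bin:.1f} Hz/bin, {frames_per_sec:.0f} frames/s")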
examples/simple_linear_irm_aishell/run.sh
DELETED
@@ -1,172 +0,0 @@
#!/usr/bin/env bash

: <<'END'

sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir

sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name file_dir

sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name file_dir \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"


END


# params
system_version="windows";
verbose=true;
stage=0  # start from 0 if you need to start from data preparation
stop_stage=9

work_dir="$(pwd)"
file_folder_name=file_folder_name
final_model_name=final_model_name
config_file="yaml/config.yaml"
limit=10

noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train

nohup_name=nohup.out

# model params
batch_size=64
max_epochs=200
save_top_k=10
patience=5


# parse options
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
      eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
      old_value="$(eval echo \\$$name)";
      if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
        was_bool=true;
      else
        was_bool=false;
      fi

      # Set the variable to the right value-- the escaped quotes make it work if
      # the option had spaces, like --cmd "queue.pl -sync y"
      eval "${name}=\"$2\"";

      # Check that Boolean-valued arguments are really Boolean.
      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
        exit 1;
      fi
      shift 2;
      ;;

    *) break;
  esac
done

file_dir="${work_dir}/${file_folder_name}"
final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
evaluation_audio_dir="${file_dir}/evaluation_audio"

dataset="${file_dir}/dataset.xlsx"
train_dataset="${file_dir}/train.xlsx"
valid_dataset="${file_dir}/valid.xlsx"

$verbose && echo "system_version: ${system_version}"
$verbose && echo "file_folder_name: ${file_folder_name}"

if [ $system_version == "windows" ]; then
  alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
  #source /data/local/bin/nx_denoise/bin/activate
  alias python3='/data/local/bin/nx_denoise/bin/python3'
fi


if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  $verbose && echo "stage 1: prepare data"
  cd "${work_dir}" || exit 1
  python3 step_1_prepare_data.py \
    --file_dir "${file_dir}" \
    --noise_dir "${noise_dir}" \
    --speech_dir "${speech_dir}" \
    --train_dataset "${train_dataset}" \
    --valid_dataset "${valid_dataset}" \

fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  $verbose && echo "stage 2: train model"
  cd "${work_dir}" || exit 1
  python3 step_2_train_model.py \
    --train_dataset "${train_dataset}" \
    --valid_dataset "${valid_dataset}" \
    --serialization_dir "${file_dir}" \
    --config_file "${config_file}" \

fi


if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  $verbose && echo "stage 3: test model"
  cd "${work_dir}" || exit 1
  python3 step_3_evaluation.py \
    --valid_dataset "${valid_dataset}" \
    --model_dir "${file_dir}/best" \
    --evaluation_audio_dir "${evaluation_audio_dir}" \
    --limit "${limit}" \

fi


if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  $verbose && echo "stage 4: export model"
  cd "${work_dir}" || exit 1
  python3 step_5_export_models.py \
    --vocabulary_dir "${vocabulary_dir}" \
    --model_dir "${file_dir}/best" \
    --serialization_dir "${file_dir}" \

fi


if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  $verbose && echo "stage 5: collect files"
  cd "${work_dir}" || exit 1

  mkdir -p ${final_model_dir}

  cp "${file_dir}/best"/* "${final_model_dir}"
  cp -r "${file_dir}/vocabulary" "${final_model_dir}"

  cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx"

  cp "${file_dir}/trace_model.zip" "${final_model_dir}/trace_model.zip"
  cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip"
  cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip"
  cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip"

  cd "${final_model_dir}/.." || exit 1;

  if [ -e "${final_model_name}.zip" ]; then
    rm -rf "${final_model_name}_backup.zip"
    mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
  fi

  zip -r "${final_model_name}.zip" "${final_model_name}"
  rm -rf "${final_model_name}"

fi


if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
  $verbose && echo "stage 6: clear file_dir"
  cd "${work_dir}" || exit 1

  rm -rf "${file_dir}";

fi
examples/simple_linear_irm_aishell/step_1_prepare_data.py
DELETED
@@ -1,196 +0,0 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import os
from pathlib import Path
import random
import sys
import shutil

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import pandas as pd
from scipy.io import wavfile
from tqdm import tqdm
import librosa

from project_settings import project_path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--file_dir", default="./", type=str)

    parser.add_argument(
        "--noise_dir",
        default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
        type=str
    )
    parser.add_argument(
        "--speech_dir",
        default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
        type=str
    )

    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)

    parser.add_argument("--duration", default=2.0, type=float)
    parser.add_argument("--min_nsr_db", default=-20, type=float)
    parser.add_argument("--max_nsr_db", default=5, type=float)

    parser.add_argument("--target_sample_rate", default=8000, type=int)

    args = parser.parse_args()
    return args


def filename_generator(data_dir: str):
    data_dir = Path(data_dir)
    for filename in data_dir.glob("**/*.wav"):
        yield filename.as_posix()


def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000):
    data_dir = Path(data_dir)
    for filename in data_dir.glob("**/*.wav"):
        signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
        raw_duration = librosa.get_duration(y=signal, sr=sample_rate)

        if raw_duration < duration:
            # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
            continue
        if signal.ndim != 1:
            raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")

        signal_length = len(signal)
        win_size = int(duration * sample_rate)
        for begin in range(0, signal_length - win_size, win_size):
            row = {
                "filename": filename.as_posix(),
                "raw_duration": round(raw_duration, 4),
                "offset": round(begin / sample_rate, 4),
                "duration": round(duration, 4),
            }
            yield row


def get_dataset(args):
    file_dir = Path(args.file_dir)
    file_dir.mkdir(exist_ok=True)

    noise_dir = Path(args.noise_dir)
    speech_dir = Path(args.speech_dir)

    noise_generator = target_second_signal_generator(
        noise_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate
    )
    speech_generator = target_second_signal_generator(
        speech_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate
    )

    dataset = list()

    count = 0
    process_bar = tqdm(desc="build dataset excel")
    for noise, speech in zip(noise_generator, speech_generator):

        noise_filename = noise["filename"]
        noise_raw_duration = noise["raw_duration"]
        noise_offset = noise["offset"]
        noise_duration = noise["duration"]

        speech_filename = speech["filename"]
        speech_raw_duration = speech["raw_duration"]
        speech_offset = speech["offset"]
        speech_duration = speech["duration"]

        random1 = random.random()
        random2 = random.random()

        row = {
            "noise_filename": noise_filename,
            "noise_raw_duration": noise_raw_duration,
            "noise_offset": noise_offset,
            "noise_duration": noise_duration,

            "speech_filename": speech_filename,
            "speech_raw_duration": speech_raw_duration,
            "speech_offset": speech_offset,
            "speech_duration": speech_duration,

            "snr_db": random.uniform(args.min_nsr_db, args.max_nsr_db),

            "random1": random1,
            "random2": random2,
            "flag": "TRAIN" if random2 < 0.8 else "TEST",
        }
        dataset.append(row)
        count += 1
        duration_seconds = count * args.duration
        duration_hours = duration_seconds / 3600

        process_bar.update(n=1)
        process_bar.set_postfix({
            # "duration_seconds": round(duration_seconds, 4),
            "duration_hours": round(duration_hours, 4),
        })

    dataset = pd.DataFrame(dataset)
    dataset = dataset.sort_values(by=["random1"], ascending=False)
    dataset.to_excel(
        file_dir / "dataset.xlsx",
        index=False,
    )
    return


def split_dataset(args):
    """Split the dataset into train and test sets."""
    file_dir = Path(args.file_dir)
    file_dir.mkdir(exist_ok=True)

    df = pd.read_excel(file_dir / "dataset.xlsx")

    train = list()
    test = list()

    for i, row in df.iterrows():
        flag = row["flag"]
        if flag == "TRAIN":
            train.append(row)
        else:
            test.append(row)

    train = pd.DataFrame(train)
    train.to_excel(
        args.train_dataset,
        index=False,
        # encoding="utf_8_sig"
    )
    test = pd.DataFrame(test)
    test.to_excel(
        args.valid_dataset,
        index=False,
        # encoding="utf_8_sig"
    )

    return


def main():
    args = get_args()

    get_dataset(args)
    split_dataset(args)
    return


if __name__ == "__main__":
    main()
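To make the windowing in `target_second_signal_generator` concrete, a tiny worked example in plain Python (hypothetical numbers): a 7.3 s file at 8 kHz yields three non-overlapping 2 s segments, and the 1.3 s tail past the last full window is dropped.

duration, sample_rate = 2.0, 8000
raw_duration = 7.3
win_size = int(duration * sample_rate)           # 16000 samples per segment
signal_length = int(raw_duration * sample_rate)  # 58400 samples
offsets = list(range(0, signal_length - win_size, win_size))
print([round(o / sample_rate, 1) for o in offsets])  # [0.0, 2.0, 4.0]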
examples/simple_linear_irm_aishell/step_2_train_model.py
DELETED
@@ -1,348 +0,0 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/WenzheLiu-Speech/awesome-speech-enhancement
"""
import argparse
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
import platform
from pathlib import Path
import random
import sys
import shutil
from typing import List

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
import torchaudio
from tqdm import tqdm

from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
from toolbox.torchaudio.models.simple_linear_irm.configuration_simple_linear_irm import SimpleLinearIRMConfig
from toolbox.torchaudio.models.simple_linear_irm.modeling_simple_linear_irm import SimpleLinearIRMPretrainedModel


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)

    parser.add_argument("--max_epochs", default=100, type=int)

    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--learning_rate", default=1e-3, type=float)
    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
    parser.add_argument("--patience", default=5, type=int)
    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
    parser.add_argument("--seed", default=0, type=int)

    parser.add_argument("--config_file", default="config.yaml", type=str)

    args = parser.parse_args()
    return args


def logging_config(file_dir: str):
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.INFO)
    file_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(fmt))
    logger = logging.getLogger(__name__)
    logger.addHandler(file_handler)

    return logger


class CollateFunction(object):
    def __init__(self,
                 n_fft: int = 512,
                 win_length: int = 200,
                 hop_length: int = 80,
                 window_fn: str = "hamming",
                 irm_beta: float = 1.0,
                 epsilon: float = 1e-8,
                 ):
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.window_fn = window_fn
        self.irm_beta = irm_beta
        self.epsilon = epsilon

        self.transform = torchaudio.transforms.Spectrogram(
            n_fft=self.n_fft,
            win_length=self.win_length,
            hop_length=self.hop_length,
            power=2.0,
            window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
        )

    def __call__(self, batch: List[dict]):
        mix_spec_list = list()
        speech_irm_list = list()
        snr_db_list = list()
        for sample in batch:
            noise_wave: torch.Tensor = sample["noise_wave"]
            speech_wave: torch.Tensor = sample["speech_wave"]
            mix_wave: torch.Tensor = sample["mix_wave"]
            snr_db: float = sample["snr_db"]

            noise_spec = self.transform.forward(noise_wave)
            speech_spec = self.transform.forward(speech_wave)
            mix_spec = self.transform.forward(mix_wave)

            # noise_irm = noise_spec / (noise_spec + speech_spec)
            speech_irm = speech_spec / (noise_spec + speech_spec + self.epsilon)
            speech_irm = torch.pow(speech_irm, self.irm_beta)

            mix_spec_list.append(mix_spec)
            speech_irm_list.append(speech_irm)
            snr_db_list.append(torch.tensor(snr_db, dtype=torch.float32))

        mix_spec_list = torch.stack(mix_spec_list)
        speech_irm_list = torch.stack(speech_irm_list)
        snr_db_list = torch.stack(snr_db_list)  # shape: (batch_size,)

        # assert
        if torch.any(torch.isnan(mix_spec_list)):
            raise AssertionError("nan in mix_spec Tensor")
        if torch.any(torch.isnan(speech_irm_list)):
            raise AssertionError("nan in speech_irm Tensor")
        if torch.any(torch.isnan(snr_db_list)):
            raise AssertionError("nan in snr_db Tensor")

        return mix_spec_list, speech_irm_list, snr_db_list


collate_fn = CollateFunction()


def main():
    args = get_args()

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    logger.info("set seed: {}".format(args.seed))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("GPU available count: {}; device: {}".format(n_gpu, device))

    # datasets
    logger.info("prepare datasets")
    train_dataset = DenoiseExcelDataset(
        excel_file=args.train_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    valid_dataset = DenoiseExcelDataset(
        excel_file=args.valid_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        # On Linux the data can be loaded by multiple worker subprocesses; on Windows it cannot.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        # prefetch_factor=64,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        # On Linux the data can be loaded by multiple worker subprocesses; on Windows it cannot.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        # prefetch_factor=64,
    )

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    config = SimpleLinearIRMConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
        # num_labels=vocabulary.get_vocab_size(namespace="labels")
    )
    model = SimpleLinearIRMPretrainedModel(
        config=config,
    )
    model.to(device)
    model.train()

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
    param_optimizer = model.parameters()
    optimizer = torch.optim.Adam(
        param_optimizer,
        lr=args.learning_rate,
    )
    # lr_scheduler = torch.optim.lr_scheduler.StepLR(
    #     optimizer,
    #     step_size=2000
    # )
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
    )
    mse_loss = nn.MSELoss(
        reduction="mean",
    )

    # training loop
    logger.info("training")

    training_loss = 10000000000
    evaluation_loss = 10000000000

    model_list = list()
    best_idx_epoch = None
    best_metric = None
    patience_count = 0

    for idx_epoch in range(args.max_epochs):
        total_loss = 0.
        total_examples = 0.
        progress_bar = tqdm(
            total=len(train_data_loader),
            desc="Training; epoch: {}".format(idx_epoch),
        )

        for batch in train_data_loader:
            mix_spec, speech_irm, snr_db = batch
            mix_spec = mix_spec.to(device)
            speech_irm_target = speech_irm.to(device)
            snr_db_target = snr_db.to(device)

            speech_irm_prediction = model.forward(mix_spec)
            loss = mse_loss.forward(speech_irm_prediction, speech_irm_target)

            total_loss += loss.item()
            total_examples += mix_spec.size(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            training_loss = total_loss / total_examples
            training_loss = round(training_loss, 4)

            progress_bar.update(1)
            progress_bar.set_postfix({
                "training_loss": training_loss,
            })

        total_loss = 0.
        total_examples = 0.
        progress_bar = tqdm(
            total=len(valid_data_loader),
            desc="Evaluation; epoch: {}".format(idx_epoch),
        )
        for batch in valid_data_loader:
            mix_spec, speech_irm, snr_db = batch
            mix_spec = mix_spec.to(device)
            speech_irm_target = speech_irm.to(device)
            snr_db_target = snr_db.to(device)

            with torch.no_grad():
                speech_irm_prediction = model.forward(mix_spec)
                loss = mse_loss.forward(speech_irm_prediction, speech_irm_target)

            total_loss += loss.item()
            total_examples += mix_spec.size(0)

            evaluation_loss = total_loss / total_examples
            evaluation_loss = round(evaluation_loss, 4)

            progress_bar.update(1)
            progress_bar.set_postfix({
                "evaluation_loss": evaluation_loss,
            })

        # save path
        epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
        epoch_dir.mkdir(parents=True, exist_ok=False)

        # save models
        model.save_pretrained(epoch_dir.as_posix())

        model_list.append(epoch_dir)
        if len(model_list) >= args.num_serialized_models_to_keep:
            model_to_delete: Path = model_list.pop(0)
            shutil.rmtree(model_to_delete.as_posix())

        # save metric
        if best_metric is None:
            best_idx_epoch = idx_epoch
            best_metric = evaluation_loss
        elif evaluation_loss < best_metric:
            best_idx_epoch = idx_epoch
            best_metric = evaluation_loss
        else:
            pass

        metrics = {
            "idx_epoch": idx_epoch,
            "best_idx_epoch": best_idx_epoch,
            "training_loss": training_loss,
            "evaluation_loss": evaluation_loss,
            "learning_rate": optimizer.param_groups[0]["lr"],
        }
        metrics_filename = epoch_dir / "metrics_epoch.json"
        with open(metrics_filename, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=4, ensure_ascii=False)

        # save best
        best_dir = serialization_dir / "best"
        if best_idx_epoch == idx_epoch:
            if best_dir.exists():
                shutil.rmtree(best_dir)
            shutil.copytree(epoch_dir, best_dir)

        # early stop
        early_stop_flag = False
        if best_idx_epoch == idx_epoch:
            patience_count = 0
        else:
            patience_count += 1
            if patience_count >= args.patience:
                early_stop_flag = True

        # early stop
        if early_stop_flag:
            break
    return


if __name__ == '__main__':
    main()
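The collate function above trains the model to regress an ideal ratio mask (IRM). A small self-contained sketch of the target quantity and how a mask scales the mixture, with random tensors standing in for real spectrograms:

import torch

epsilon = 1e-8
speech_spec = torch.rand(257, 100)   # power spectrogram of clean speech
noise_spec = torch.rand(257, 100)    # power spectrogram of noise
mix_spec = speech_spec + noise_spec  # additive-power approximation of the mixture

# IRM in [0, 1]: the fraction of each time-frequency bin's power that is speech.
speech_irm = speech_spec / (noise_spec + speech_spec + epsilon)

enhanced = mix_spec * speech_irm     # masking can only attenuate the mixture
assert torch.all(enhanced <= mix_spec + 1e-6)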
examples/simple_linear_irm_aishell/step_3_evaluation.py
DELETED
|
@@ -1,239 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/python3
|
| 2 |
-
# -*- coding: utf-8 -*-
|
| 3 |
-
import argparse
|
| 4 |
-
import logging
|
| 5 |
-
import os
|
| 6 |
-
from pathlib import Path
|
| 7 |
-
import sys
|
| 8 |
-
import uuid
|
| 9 |
-
|
| 10 |
-
pwd = os.path.abspath(os.path.dirname(__file__))
|
| 11 |
-
sys.path.append(os.path.join(pwd, "../../"))
|
| 12 |
-
|
| 13 |
-
import librosa
|
| 14 |
-
import numpy as np
|
| 15 |
-
import pandas as pd
|
| 16 |
-
from scipy.io import wavfile
|
| 17 |
-
import torch
|
| 18 |
-
import torch.nn as nn
|
| 19 |
-
import torchaudio
|
| 20 |
-
from tqdm import tqdm
|
| 21 |
-
|
| 22 |
-
from toolbox.torchaudio.models.simple_linear_irm.modeling_simple_linear_irm import SimpleLinearIRMPretrainedModel
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
def get_args():
|
| 26 |
-
parser = argparse.ArgumentParser()
|
| 27 |
-
parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
|
| 28 |
-
parser.add_argument("--model_dir", default="serialization_dir/best", type=str)
|
| 29 |
-
parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str)
|
| 30 |
-
|
| 31 |
-
parser.add_argument("--limit", default=10, type=int)
|
| 32 |
-
|
| 33 |
-
args = parser.parse_args()
|
| 34 |
-
return args
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def logging_config():
|
| 38 |
-
fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
|
| 39 |
-
|
| 40 |
-
logging.basicConfig(format=fmt,
|
| 41 |
-
datefmt="%m/%d/%Y %H:%M:%S",
|
| 42 |
-
level=logging.INFO)
|
| 43 |
-
stream_handler = logging.StreamHandler()
|
| 44 |
-
stream_handler.setLevel(logging.INFO)
|
| 45 |
-
stream_handler.setFormatter(logging.Formatter(fmt))
|
| 46 |
-
|
| 47 |
-
logger = logging.getLogger(__name__)
|
| 48 |
-
|
| 49 |
-
return logger
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float):
|
| 53 |
-
l1 = len(speech)
|
| 54 |
-
l2 = len(noise)
|
| 55 |
-
l = min(l1, l2)
|
| 56 |
-
speech = speech[:l]
|
| 57 |
-
noise = noise[:l]
|
| 58 |
-
|
| 59 |
-
# np.float32, value between (-1, 1).
|
| 60 |
-
|
| 61 |
-
speech_power = np.mean(np.square(speech))
|
| 62 |
-
noise_power = speech_power / (10 ** (snr_db / 10))
|
| 63 |
-
|
| 64 |
-
noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2))
|
| 65 |
-
|
| 66 |
-
noisy_signal = speech + noise_adjusted
|
| 67 |
-
|
| 68 |
-
return noisy_signal
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
stft_power = torchaudio.transforms.Spectrogram(
|
| 72 |
-
n_fft=512,
|
| 73 |
-
win_length=200,
|
| 74 |
-
hop_length=80,
|
| 75 |
-
power=2.0,
|
| 76 |
-
window_fn=torch.hamming_window,
|
| 77 |
-
)
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
stft_complex = torchaudio.transforms.Spectrogram(
|
| 81 |
-
n_fft=512,
|
| 82 |
-
win_length=200,
|
| 83 |
-
hop_length=80,
|
| 84 |
-
power=None,
|
| 85 |
-
window_fn=torch.hamming_window,
|
| 86 |
-
)
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
istft = torchaudio.transforms.InverseSpectrogram(
|
| 90 |
-
n_fft=512,
|
| 91 |
-
win_length=200,
|
| 92 |
-
hop_length=80,
|
| 93 |
-
window_fn=torch.hamming_window,
|
| 94 |
-
)
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
def enhance(mix_spec_complex: torch.Tensor, speech_irm_prediction: torch.Tensor):
|
| 98 |
-
mix_spec_complex = mix_spec_complex.detach().cpu()
|
| 99 |
-
speech_irm_prediction = speech_irm_prediction.detach().cpu()
|
| 100 |
-
|
| 101 |
-
mask_speech = speech_irm_prediction
|
| 102 |
-
mask_noise = 1.0 - speech_irm_prediction
|
| 103 |
-
|
| 104 |
-
speech_spec = mix_spec_complex * mask_speech
|
| 105 |
-
noise_spec = mix_spec_complex * mask_noise
|
| 106 |
-
|
| 107 |
-
speech_wave = istft.forward(speech_spec)
|
| 108 |
-
noise_wave = istft.forward(noise_spec)
|
| 109 |
-
|
| 110 |
-
return speech_wave, noise_wave
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
def save_audios(noise_wave: torch.Tensor,
|
| 114 |
-
speech_wave: torch.Tensor,
|
| 115 |
-
mix_wave: torch.Tensor,
|
| 116 |
-
speech_wave_enhanced: torch.Tensor,
|
| 117 |
-
noise_wave_enhanced: torch.Tensor,
|
| 118 |
-
output_dir: str,
|
| 119 |
-
sample_rate: int = 8000,
|
| 120 |
-
):
|
| 121 |
-
basename = uuid.uuid4().__str__()
|
| 122 |
-
output_dir = Path(output_dir) / basename
|
| 123 |
-
output_dir.mkdir(parents=True, exist_ok=True)
|
| 124 |
-
|
| 125 |
-
filename = output_dir / "noise_wave.wav"
|
| 126 |
-
torchaudio.save(filename, noise_wave, sample_rate)
|
| 127 |
-
filename = output_dir / "speech_wave.wav"
|
| 128 |
-
torchaudio.save(filename, speech_wave, sample_rate)
|
| 129 |
-
filename = output_dir / "mix_wave.wav"
|
| 130 |
-
torchaudio.save(filename, mix_wave, sample_rate)
|
| 131 |
-
|
| 132 |
-
filename = output_dir / "speech_wave_enhanced.wav"
|
| 133 |
-
torchaudio.save(filename, speech_wave_enhanced, sample_rate)
|
| 134 |
-
filename = output_dir / "noise_wave_enhanced.wav"
|
| 135 |
-
torchaudio.save(filename, noise_wave_enhanced, sample_rate)
|
| 136 |
-
|
| 137 |
-
return output_dir.as_posix()
|
| 138 |
-
|
| 139 |
-
|
-def main():
-    args = get_args()
-
-    logger = logging_config()
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    n_gpu = torch.cuda.device_count()
-    logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
-
-    logger.info("prepare model")
-    model = SimpleLinearIRMPretrainedModel.from_pretrained(
-        pretrained_model_name_or_path=args.model_dir,
-    )
-    model.to(device)
-    model.eval()
-
-    # optimizer
-    logger.info("prepare loss_fn")
-    mse_loss = nn.MSELoss(
-        reduction="mean",
-    )
-
-    logger.info("read excel")
-    df = pd.read_excel(args.valid_dataset)
-
-    total_loss = 0.
-    total_examples = 0.
-    progress_bar = tqdm(total=len(df), desc="Evaluation")
-    for idx, row in df.iterrows():
-        noise_filename = row["noise_filename"]
-        noise_offset = row["noise_offset"]
-        noise_duration = row["noise_duration"]
-
-        speech_filename = row["speech_filename"]
-        speech_offset = row["speech_offset"]
-        speech_duration = row["speech_duration"]
-
-        snr_db = row["snr_db"]
-
-        noise_wave, _ = librosa.load(
-            noise_filename,
-            sr=8000,
-            offset=noise_offset,
-            duration=noise_duration,
-        )
-        speech_wave, _ = librosa.load(
-            speech_filename,
-            sr=8000,
-            offset=speech_offset,
-            duration=speech_duration,
-        )
-        mix_wave: np.ndarray = mix_speech_and_noise(
-            speech=speech_wave,
-            noise=noise_wave,
-            snr_db=snr_db,
-        )
-        noise_wave = torch.tensor(noise_wave, dtype=torch.float32)
-        speech_wave = torch.tensor(speech_wave, dtype=torch.float32)
-        mix_wave: torch.Tensor = torch.tensor(mix_wave, dtype=torch.float32)
-
-        noise_wave = noise_wave.unsqueeze(dim=0)
-        speech_wave = speech_wave.unsqueeze(dim=0)
-        mix_wave = mix_wave.unsqueeze(dim=0)
-
-        noise_spec: torch.Tensor = stft_power.forward(noise_wave)
-        speech_spec: torch.Tensor = stft_power.forward(speech_wave)
-        mix_spec: torch.Tensor = stft_power.forward(mix_wave)
-        mix_spec_complex: torch.Tensor = stft_complex.forward(mix_wave)
-
-        speech_irm = speech_spec / (noise_spec + speech_spec)
-        speech_irm = torch.pow(speech_irm, 1.0)
-
-        mix_spec = mix_spec.to(device)
-        speech_irm_target = speech_irm.to(device)
-        with torch.no_grad():
-            speech_irm_prediction = model.forward(mix_spec)
-            loss = mse_loss.forward(speech_irm_prediction, speech_irm_target)
-
-        speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_irm_prediction)
-        save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir)
-
-        total_loss += loss.item()
-        total_examples += mix_spec.size(0)
-
-        evaluation_loss = total_loss / total_examples
-        evaluation_loss = round(evaluation_loss, 4)
-
-        progress_bar.update(1)
-        progress_bar.set_postfix({
-            "evaluation_loss": evaluation_loss,
-        })
-
-        if idx > args.limit:
-            break
-
-    return
-
-
-if __name__ == '__main__':
-    main()
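For orientation, the loop above recovers waveforms by masking the complex mixture spectrogram with the predicted IRM (the `enhance` call). Below is a minimal self-contained sketch of that masking step, assuming a recent torchaudio where `power=None` yields a complex spectrogram; the window settings mirror the deleted script (8 kHz, n_fft=512, 25 ms Hamming window, 10 ms hop):

import torch
import torchaudio

# STFT pair used throughout these recipes.
stft_complex = torchaudio.transforms.Spectrogram(
    n_fft=512, win_length=200, hop_length=80, power=None,
    window_fn=torch.hamming_window,
)
istft = torchaudio.transforms.InverseSpectrogram(
    n_fft=512, win_length=200, hop_length=80,
    window_fn=torch.hamming_window,
)

def apply_irm(mix_wave: torch.Tensor, speech_irm: torch.Tensor):
    """Split a mixture into speech/noise estimates with a real-valued mask.

    mix_wave: [batch, num_samples]; speech_irm: [batch, freq, time] in [0, 1].
    """
    mix_spec_complex = stft_complex(mix_wave)           # complex [batch, freq, time]
    speech_spec = mix_spec_complex * speech_irm         # keep speech energy
    noise_spec = mix_spec_complex * (1.0 - speech_irm)  # complementary noise mask
    return istft(speech_spec), istft(noise_spec)

# Toy usage: a random mask over one second of audio (257 = 512 // 2 + 1 bins).
wave = torch.randn(1, 8000)
irm = torch.rand(1, 257, stft_complex(wave).shape[-1])
speech_est, noise_est = apply_irm(wave, irm)
print(speech_est.shape, noise_est.shape)

Masking with `speech_irm` and its complement `1.0 - speech_irm` is what lets the script write out matched `speech_wave_enhanced` / `noise_wave_enhanced` pairs from a single mixture.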
examples/simple_linear_irm_aishell/yaml/config.yaml
DELETED
@@ -1,13 +0,0 @@
-model_name: "simple_linear_irm"
-
-# spec
-sample_rate: 8000
-n_fft: 512
-win_length: 200
-hop_length: 80
-
-# model
-num_bins: 257
-hidden_size: 2048
-lookback: 3
-lookahead: 3
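Note that `num_bins: 257` is not a free choice: it follows from `n_fft: 512` as the number of one-sided STFT bins, and `win_length: 200` / `hop_length: 80` are 25 ms windows with a 10 ms hop at `sample_rate: 8000`; `lookback` / `lookahead` suggest a 7-frame (3 + 1 + 3) context window per prediction. A quick consistency check over the values above:

# Sanity-check the spectral settings from this config.
sample_rate, n_fft, win_length, hop_length = 8000, 512, 200, 80

num_bins = n_fft // 2 + 1                  # one-sided STFT bins -> 257
win_ms = 1000 * win_length / sample_rate   # 25.0 ms analysis window
hop_ms = 1000 * hop_length / sample_rate   # 10.0 ms frame hop

assert num_bins == 257
print(f"bins={num_bins}, window={win_ms} ms, hop={hop_ms} ms")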
examples/spectrum_dfnet_aishell/run.sh
DELETED
@@ -1,178 +0,0 @@
-#!/usr/bin/env bash
-
-: <<'END'
-
-
-sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir \
---noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
---speech_dir "E:/programmer/asr_datasets/aishell/data_aishell/wav/train"
-
-
-sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir \
---noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
---speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
-
-sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name file_dir \
---noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
---speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
-
-
-END
-
-
-# params
-system_version="windows";
-verbose=true;
-stage=0 # start from 0 if you need to start from data preparation
-stop_stage=9
-
-work_dir="$(pwd)"
-file_folder_name=file_folder_name
-final_model_name=final_model_name
-config_file="yaml/config.yaml"
-limit=10
-
-noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
-speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
-
-nohup_name=nohup.out
-
-# model params
-batch_size=64
-max_epochs=200
-save_top_k=10
-patience=5
-
-
-# parse options
-while true; do
-  [ -z "${1:-}" ] && break;  # break if there are no arguments
-  case "$1" in
-    --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
-      eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
-      old_value="$(eval echo \$$name)";
-      if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
-        was_bool=true;
-      else
-        was_bool=false;
-      fi
-
-      # Set the variable to the right value-- the escaped quotes make it work if
-      # the option had spaces, like --cmd "queue.pl -sync y"
-      eval "${name}=\"$2\"";
-
-      # Check that Boolean-valued arguments are really Boolean.
-      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
-        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
-        exit 1;
-      fi
-      shift 2;
-      ;;
-
-    *) break;
-  esac
-done
-
-file_dir="${work_dir}/${file_folder_name}"
-final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
-evaluation_audio_dir="${file_dir}/evaluation_audio"
-
-dataset="${file_dir}/dataset.xlsx"
-train_dataset="${file_dir}/train.xlsx"
-valid_dataset="${file_dir}/valid.xlsx"
-
-$verbose && echo "system_version: ${system_version}"
-$verbose && echo "file_folder_name: ${file_folder_name}"
-
-if [ $system_version == "windows" ]; then
-  alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
-elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
-  #source /data/local/bin/nx_denoise/bin/activate
-  alias python3='/data/local/bin/nx_denoise/bin/python3'
-fi
-
-
-if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-  $verbose && echo "stage 1: prepare data"
-  cd "${work_dir}" || exit 1
-  python3 step_1_prepare_data.py \
-    --file_dir "${file_dir}" \
-    --noise_dir "${noise_dir}" \
-    --speech_dir "${speech_dir}" \
-    --train_dataset "${train_dataset}" \
-    --valid_dataset "${valid_dataset}" \
-
-fi
-
-
-if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
-  $verbose && echo "stage 2: train model"
-  cd "${work_dir}" || exit 1
-  python3 step_2_train_model.py \
-    --train_dataset "${train_dataset}" \
-    --valid_dataset "${valid_dataset}" \
-    --serialization_dir "${file_dir}" \
-    --config_file "${config_file}" \
-
-fi
-
-
-if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
-  $verbose && echo "stage 3: test model"
-  cd "${work_dir}" || exit 1
-  python3 step_3_evaluation.py \
-    --valid_dataset "${valid_dataset}" \
-    --model_dir "${file_dir}/best" \
-    --evaluation_audio_dir "${evaluation_audio_dir}" \
-    --limit "${limit}" \
-
-fi
-
-
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
-  $verbose && echo "stage 4: export model"
-  cd "${work_dir}" || exit 1
-  python3 step_5_export_models.py \
-    --vocabulary_dir "${vocabulary_dir}" \
-    --model_dir "${file_dir}/best" \
-    --serialization_dir "${file_dir}" \
-
-fi
-
-
-if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
-  $verbose && echo "stage 5: collect files"
-  cd "${work_dir}" || exit 1
-
-  mkdir -p ${final_model_dir}
-
-  cp "${file_dir}/best"/* "${final_model_dir}"
-  cp -r "${file_dir}/vocabulary" "${final_model_dir}"
-
-  cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx"
-
-  cp "${file_dir}/trace_model.zip" "${final_model_dir}/trace_model.zip"
-  cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip"
-  cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip"
-  cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip"
-
-  cd "${final_model_dir}/.." || exit 1;
-
-  if [ -e "${final_model_name}.zip" ]; then
-    rm -rf "${final_model_name}_backup.zip"
-    mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
-  fi
-
-  zip -r "${final_model_name}.zip" "${final_model_name}"
-  rm -rf "${final_model_name}"
-
-fi
-
-
-if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
-  $verbose && echo "stage 6: clear file_dir"
-  cd "${work_dir}" || exit 1
-
-  rm -rf "${file_dir}";
-
-fi
examples/spectrum_dfnet_aishell/step_1_prepare_data.py
DELETED
@@ -1,197 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-import argparse
-import os
-from pathlib import Path
-import random
-import sys
-import shutil
-
-pwd = os.path.abspath(os.path.dirname(__file__))
-sys.path.append(os.path.join(pwd, "../../"))
-
-import pandas as pd
-from scipy.io import wavfile
-from tqdm import tqdm
-import librosa
-
-from project_settings import project_path
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--file_dir", default="./", type=str)
-
-    parser.add_argument(
-        "--noise_dir",
-        default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
-        type=str
-    )
-    parser.add_argument(
-        "--speech_dir",
-        default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
-        type=str
-    )
-
-    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
-    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
-
-    parser.add_argument("--duration", default=2.0, type=float)
-    parser.add_argument("--min_snr_db", default=-10, type=float)
-    parser.add_argument("--max_snr_db", default=20, type=float)
-
-    parser.add_argument("--target_sample_rate", default=8000, type=int)
-
-    args = parser.parse_args()
-    return args
-
-
-def filename_generator(data_dir: str):
-    data_dir = Path(data_dir)
-    for filename in data_dir.glob("**/*.wav"):
-        yield filename.as_posix()
-
-
-def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000):
-    data_dir = Path(data_dir)
-    for filename in data_dir.glob("**/*.wav"):
-        signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
-        raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
-
-        if raw_duration < duration:
-            # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
-            continue
-        if signal.ndim != 1:
-            raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
-
-        signal_length = len(signal)
-        win_size = int(duration * sample_rate)
-        for begin in range(0, signal_length - win_size, win_size):
-            row = {
-                "filename": filename.as_posix(),
-                "raw_duration": round(raw_duration, 4),
-                "offset": round(begin / sample_rate, 4),
-                "duration": round(duration, 4),
-            }
-            yield row
-
-
-def get_dataset(args):
-    file_dir = Path(args.file_dir)
-    file_dir.mkdir(exist_ok=True)
-
-    noise_dir = Path(args.noise_dir)
-    speech_dir = Path(args.speech_dir)
-
-    noise_generator = target_second_signal_generator(
-        noise_dir.as_posix(),
-        duration=args.duration,
-        sample_rate=args.target_sample_rate
-    )
-    speech_generator = target_second_signal_generator(
-        speech_dir.as_posix(),
-        duration=args.duration,
-        sample_rate=args.target_sample_rate
-    )
-
-    dataset = list()
-
-    count = 0
-    process_bar = tqdm(desc="build dataset excel")
-    for noise, speech in zip(noise_generator, speech_generator):
-
-        noise_filename = noise["filename"]
-        noise_raw_duration = noise["raw_duration"]
-        noise_offset = noise["offset"]
-        noise_duration = noise["duration"]
-
-        speech_filename = speech["filename"]
-        speech_raw_duration = speech["raw_duration"]
-        speech_offset = speech["offset"]
-        speech_duration = speech["duration"]
-
-        random1 = random.random()
-        random2 = random.random()
-
-        row = {
-            "noise_filename": noise_filename,
-            "noise_raw_duration": noise_raw_duration,
-            "noise_offset": noise_offset,
-            "noise_duration": noise_duration,
-
-            "speech_filename": speech_filename,
-            "speech_raw_duration": speech_raw_duration,
-            "speech_offset": speech_offset,
-            "speech_duration": speech_duration,
-
-            "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),
-
-            "random1": random1,
-            "random2": random2,
-            "flag": "TRAIN" if random2 < 0.8 else "TEST",
-        }
-        dataset.append(row)
-        count += 1
-        duration_seconds = count * args.duration
-        duration_hours = duration_seconds / 3600
-
-        process_bar.update(n=1)
-        process_bar.set_postfix({
-            # "duration_seconds": round(duration_seconds, 4),
-            "duration_hours": round(duration_hours, 4),
-        })
-
-    dataset = pd.DataFrame(dataset)
-    dataset = dataset.sort_values(by=["random1"], ascending=False)
-    dataset.to_excel(
-        file_dir / "dataset.xlsx",
-        index=False,
-    )
-    return
-
-
-def split_dataset(args):
-    """Split the dataset into train and test sets."""
-    file_dir = Path(args.file_dir)
-    file_dir.mkdir(exist_ok=True)
-
-    df = pd.read_excel(file_dir / "dataset.xlsx")
-
-    train = list()
-    test = list()
-
-    for i, row in df.iterrows():
-        flag = row["flag"]
-        if flag == "TRAIN":
-            train.append(row)
-        else:
-            test.append(row)
-
-    train = pd.DataFrame(train)
-    train.to_excel(
-        args.train_dataset,
-        index=False,
-        # encoding="utf_8_sig"
-    )
-    test = pd.DataFrame(test)
-    test.to_excel(
-        args.valid_dataset,
-        index=False,
-        # encoding="utf_8_sig"
-    )
-
-    return
-
-
-def main():
-    args = get_args()
-
-    get_dataset(args)
-    split_dataset(args)
-    return
-
-
-if __name__ == "__main__":
-    main()
examples/spectrum_dfnet_aishell/step_2_train_model.py
DELETED
@@ -1,440 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-"""
-https://github.com/WenzheLiu-Speech/awesome-speech-enhancement
-"""
-import argparse
-import json
-import logging
-from logging.handlers import TimedRotatingFileHandler
-import os
-import platform
-from pathlib import Path
-import random
-import sys
-import shutil
-from typing import List
-
-pwd = os.path.abspath(os.path.dirname(__file__))
-sys.path.append(os.path.join(pwd, "../../"))
-
-import numpy as np
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from torch.utils.data.dataloader import DataLoader
-import torchaudio
-from tqdm import tqdm
-
-from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
-from toolbox.torchaudio.models.spectrum_dfnet.configuration_spectrum_dfnet import SpectrumDfNetConfig
-from toolbox.torchaudio.models.spectrum_dfnet.modeling_spectrum_dfnet import SpectrumDfNetPretrainedModel
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
-    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
-
-    parser.add_argument("--max_epochs", default=100, type=int)
-
-    parser.add_argument("--batch_size", default=16, type=int)
-    parser.add_argument("--learning_rate", default=1e-4, type=float)
-    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
-    parser.add_argument("--patience", default=5, type=int)
-    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
-    parser.add_argument("--seed", default=0, type=int)
-
-    parser.add_argument("--config_file", default="config.yaml", type=str)
-
-    args = parser.parse_args()
-    return args
-
-
-def logging_config(file_dir: str):
-    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
-
-    logging.basicConfig(format=fmt,
-                        datefmt="%m/%d/%Y %H:%M:%S",
-                        level=logging.INFO)
-    file_handler = TimedRotatingFileHandler(
-        filename=os.path.join(file_dir, "main.log"),
-        encoding="utf-8",
-        when="D",
-        interval=1,
-        backupCount=7
-    )
-    file_handler.setLevel(logging.INFO)
-    file_handler.setFormatter(logging.Formatter(fmt))
-    logger = logging.getLogger(__name__)
-    logger.addHandler(file_handler)
-
-    return logger
-
-
-class CollateFunction(object):
-    def __init__(self,
-                 n_fft: int = 512,
-                 win_length: int = 200,
-                 hop_length: int = 80,
-                 window_fn: str = "hamming",
-                 irm_beta: float = 1.0,
-                 epsilon: float = 1e-8,
-                 ):
-        self.n_fft = n_fft
-        self.win_length = win_length
-        self.hop_length = hop_length
-        self.window_fn = window_fn
-        self.irm_beta = irm_beta
-        self.epsilon = epsilon
-
-        self.complex_transform = torchaudio.transforms.Spectrogram(
-            n_fft=self.n_fft,
-            win_length=self.win_length,
-            hop_length=self.hop_length,
-            power=None,
-            window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
-        )
-        self.transform = torchaudio.transforms.Spectrogram(
-            n_fft=self.n_fft,
-            win_length=self.win_length,
-            hop_length=self.hop_length,
-            power=2.0,
-            window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
-        )
-
-    @staticmethod
-    def make_unfold_snr_db(x: torch.Tensor, n_time_steps: int = 3):
-        batch_size, channels, freq_dim, time_steps = x.shape
-
-        # kernel: [freq_dim, n_time_step]
-        kernel_size = (freq_dim, n_time_steps)
-
-        # pad
-        pad = n_time_steps // 2
-        x = torch.concat(tensors=[
-            x[:, :, :, :pad],
-            x,
-            x[:, :, :, -pad:],
-        ], dim=-1)
-
-        x = F.unfold(
-            input=x,
-            kernel_size=kernel_size,
-        )
-        # x shape: [batch_size, fold, time_steps]
-        return x
-
-    def __call__(self, batch: List[dict]):
-        speech_complex_spec_list = list()
-        mix_complex_spec_list = list()
-        speech_irm_list = list()
-        snr_db_list = list()
-        for sample in batch:
-            noise_wave: torch.Tensor = sample["noise_wave"]
-            speech_wave: torch.Tensor = sample["speech_wave"]
-            mix_wave: torch.Tensor = sample["mix_wave"]
-            # snr_db: float = sample["snr_db"]
-
-            noise_spec = self.transform.forward(noise_wave)
-            speech_spec = self.transform.forward(speech_wave)
-
-            speech_complex_spec = self.complex_transform.forward(speech_wave)
-            mix_complex_spec = self.complex_transform.forward(mix_wave)
-
-            # noise_irm = noise_spec / (noise_spec + speech_spec)
-            speech_irm = speech_spec / (noise_spec + speech_spec + self.epsilon)
-            speech_irm = torch.pow(speech_irm, self.irm_beta)
-
-            # noise_spec, speech_spec, mix_spec, speech_irm
-            # shape: [freq_dim, time_steps]
-
-            snr_db: torch.Tensor = 10 * torch.log10(
-                speech_spec / (noise_spec + self.epsilon)
-            )
-            snr_db = torch.clamp(snr_db, min=self.epsilon)
-
-            snr_db_ = torch.unsqueeze(snr_db, dim=0)
-            snr_db_ = torch.unsqueeze(snr_db_, dim=0)
-            snr_db_ = self.make_unfold_snr_db(snr_db_, n_time_steps=3)
-            snr_db_ = torch.squeeze(snr_db_, dim=0)
-            # snr_db_ shape: [fold, time_steps]
-
-            snr_db = torch.mean(snr_db_, dim=0, keepdim=True)
-            # snr_db shape: [1, time_steps]
-
-            speech_complex_spec_list.append(speech_complex_spec)
-            mix_complex_spec_list.append(mix_complex_spec)
-            speech_irm_list.append(speech_irm)
-            snr_db_list.append(snr_db)
-
-        speech_complex_spec_list = torch.stack(speech_complex_spec_list)
-        mix_complex_spec_list = torch.stack(mix_complex_spec_list)
-        speech_irm_list = torch.stack(speech_irm_list)
-        snr_db_list = torch.stack(snr_db_list)  # shape: (batch_size, time_steps, 1)
-
-        speech_complex_spec_list = speech_complex_spec_list[:, :-1, :]
-        mix_complex_spec_list = mix_complex_spec_list[:, :-1, :]
-        speech_irm_list = speech_irm_list[:, :-1, :]
-
-        # speech_complex_spec_list shape: [batch_size, freq_dim, time_steps]
-        # mix_complex_spec_list shape: [batch_size, freq_dim, time_steps]
-        # speech_irm_list shape: [batch_size, freq_dim, time_steps]
-        # snr_db shape: [batch_size, 1, time_steps]
-
-        # assert
-        if torch.any(torch.isnan(speech_complex_spec_list)) or torch.any(torch.isinf(speech_complex_spec_list)):
-            raise AssertionError("nan or inf in speech_complex_spec_list")
-        if torch.any(torch.isnan(mix_complex_spec_list)) or torch.any(torch.isinf(mix_complex_spec_list)):
-            raise AssertionError("nan or inf in mix_complex_spec_list")
-        if torch.any(torch.isnan(speech_irm_list)) or torch.any(torch.isinf(speech_irm_list)):
-            raise AssertionError("nan or inf in speech_irm_list")
-        if torch.any(torch.isnan(snr_db_list)) or torch.any(torch.isinf(snr_db_list)):
-            raise AssertionError("nan or inf in snr_db_list")
-
-        return speech_complex_spec_list, mix_complex_spec_list, speech_irm_list, snr_db_list
-
-
-collate_fn = CollateFunction()
-
-
-def main():
-    args = get_args()
-
-    serialization_dir = Path(args.serialization_dir)
-    serialization_dir.mkdir(parents=True, exist_ok=True)
-
-    logger = logging_config(serialization_dir)
-
-    random.seed(args.seed)
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    logger.info("set seed: {}".format(args.seed))
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    n_gpu = torch.cuda.device_count()
-    logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
-
-    # datasets
-    logger.info("prepare datasets")
-    train_dataset = DenoiseExcelDataset(
-        excel_file=args.train_dataset,
-        expected_sample_rate=8000,
-        max_wave_value=32768.0,
-    )
-    valid_dataset = DenoiseExcelDataset(
-        excel_file=args.valid_dataset,
-        expected_sample_rate=8000,
-        max_wave_value=32768.0,
-    )
-    train_data_loader = DataLoader(
-        dataset=train_dataset,
-        batch_size=args.batch_size,
-        shuffle=True,
-        # On Linux, data loading can use multiple worker subprocesses; on Windows it cannot.
-        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
-        collate_fn=collate_fn,
-        pin_memory=False,
-        # prefetch_factor=64,
-    )
-    valid_data_loader = DataLoader(
-        dataset=valid_dataset,
-        batch_size=args.batch_size,
-        shuffle=True,
-        # On Linux, data loading can use multiple worker subprocesses; on Windows it cannot.
-        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
-        collate_fn=collate_fn,
-        pin_memory=False,
-        # prefetch_factor=64,
-    )
-
-    # models
-    logger.info(f"prepare models. config_file: {args.config_file}")
-    config = SpectrumDfNetConfig.from_pretrained(
-        pretrained_model_name_or_path=args.config_file,
-        # num_labels=vocabulary.get_vocab_size(namespace="labels")
-    )
-    model = SpectrumDfNetPretrainedModel(
-        config=config,
-    )
-    model.to(device)
-    model.train()
-
-    # optimizer
-    logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
-    param_optimizer = model.parameters()
-    optimizer = torch.optim.Adam(
-        param_optimizer,
-        lr=args.learning_rate,
-    )
-    # lr_scheduler = torch.optim.lr_scheduler.StepLR(
-    #     optimizer,
-    #     step_size=2000
-    # )
-    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
-        optimizer,
-        milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
-    )
-
-    speech_mse_loss = nn.MSELoss(
-        reduction="mean",
-    )
-    irm_mse_loss = nn.MSELoss(
-        reduction="mean",
-    )
-    snr_mse_loss = nn.MSELoss(
-        reduction="mean",
-    )
-
-    # training loop
-    logger.info("training")
-
-    training_loss = 10000000000
-    evaluation_loss = 10000000000
-
-    model_list = list()
-    best_idx_epoch = None
-    best_metric = None
-    patience_count = 0
-
-    for idx_epoch in range(args.max_epochs):
-        total_loss = 0.
-        total_examples = 0.
-        progress_bar = tqdm(
-            total=len(train_data_loader),
-            desc="Training; epoch: {}".format(idx_epoch),
-        )
-
-        for batch in train_data_loader:
-            speech_complex_spec, mix_complex_spec, speech_irm, snr_db = batch
-            speech_complex_spec = speech_complex_spec.to(device)
-            mix_complex_spec = mix_complex_spec.to(device)
-            speech_irm_target = speech_irm.to(device)
-            snr_db_target = snr_db.to(device)
-
-            speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
-            if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
-                raise AssertionError("nan or inf in speech_spec_prediction")
-            if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
-                raise AssertionError("nan or inf in speech_irm_prediction")
-            if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
-                raise AssertionError("nan or inf in lsnr_prediction")
-
-            speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
-            irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
-            snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
-
-            loss = speech_loss + irm_loss + snr_loss
-
-            total_loss += loss.item()
-            total_examples += mix_complex_spec.size(0)
-
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-            lr_scheduler.step()
-
-            training_loss = total_loss / total_examples
-            training_loss = round(training_loss, 4)
-
-            progress_bar.update(1)
-            progress_bar.set_postfix({
-                "training_loss": training_loss,
-            })
-
-        total_loss = 0.
-        total_examples = 0.
-        progress_bar = tqdm(
-            total=len(valid_data_loader),
-            desc="Evaluation; epoch: {}".format(idx_epoch),
-        )
-        for batch in valid_data_loader:
-            speech_complex_spec, mix_complex_spec, speech_irm, snr_db = batch
-            speech_complex_spec = speech_complex_spec.to(device)
-            mix_complex_spec = mix_complex_spec.to(device)
-            speech_irm_target = speech_irm.to(device)
-            snr_db_target = snr_db.to(device)
-
-            with torch.no_grad():
-                speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
-                if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
-                    raise AssertionError("nan or inf in speech_spec_prediction")
-                if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
-                    raise AssertionError("nan or inf in speech_irm_prediction")
-                if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
-                    raise AssertionError("nan or inf in lsnr_prediction")
-
-                speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
-                irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
-                snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
-
-                loss = speech_loss + irm_loss + snr_loss
-
-            total_loss += loss.item()
-            total_examples += mix_complex_spec.size(0)
-
-            evaluation_loss = total_loss / total_examples
-            evaluation_loss = round(evaluation_loss, 4)
-
-            progress_bar.update(1)
-            progress_bar.set_postfix({
-                "evaluation_loss": evaluation_loss,
-            })
-
-        # save path
-        epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
-        epoch_dir.mkdir(parents=True, exist_ok=False)
-
-        # save models
-        model.save_pretrained(epoch_dir.as_posix())
-
-        model_list.append(epoch_dir)
-        if len(model_list) >= args.num_serialized_models_to_keep:
-            model_to_delete: Path = model_list.pop(0)
-            shutil.rmtree(model_to_delete.as_posix())
-
-        # save metric
-        if best_metric is None:
-            best_idx_epoch = idx_epoch
-            best_metric = evaluation_loss
-        elif evaluation_loss < best_metric:
-            best_idx_epoch = idx_epoch
-            best_metric = evaluation_loss
-        else:
-            pass
-
-        metrics = {
-            "idx_epoch": idx_epoch,
-            "best_idx_epoch": best_idx_epoch,
-            "training_loss": training_loss,
-            "evaluation_loss": evaluation_loss,
-            "learning_rate": optimizer.param_groups[0]["lr"],
-        }
-        metrics_filename = epoch_dir / "metrics_epoch.json"
-        with open(metrics_filename, "w", encoding="utf-8") as f:
-            json.dump(metrics, f, indent=4, ensure_ascii=False)
-
-        # save best
-        best_dir = serialization_dir / "best"
-        if best_idx_epoch == idx_epoch:
-            if best_dir.exists():
-                shutil.rmtree(best_dir)
-            shutil.copytree(epoch_dir, best_dir)
-
-        # early stop
-        early_stop_flag = False
-        if best_idx_epoch == idx_epoch:
-            patience_count = 0
-        else:
-            patience_count += 1
-            if patience_count >= args.patience:
-                early_stop_flag = True
-
-        # early stop
-        if early_stop_flag:
-            break
-    return
-
-
-if __name__ == '__main__':
-    main()
examples/spectrum_dfnet_aishell/step_3_evaluation.py
DELETED
@@ -1,302 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-import argparse
-import logging
-import os
-from pathlib import Path
-import sys
-import uuid
-
-pwd = os.path.abspath(os.path.dirname(__file__))
-sys.path.append(os.path.join(pwd, "../../"))
-
-import librosa
-import numpy as np
-import pandas as pd
-from scipy.io import wavfile
-import torch
-import torch.nn as nn
-import torchaudio
-from tqdm import tqdm
-
-from toolbox.torchaudio.models.spectrum_dfnet.modeling_spectrum_dfnet import SpectrumDfNetPretrainedModel
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
-    parser.add_argument("--model_dir", default="serialization_dir/best", type=str)
-    parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str)
-
-    parser.add_argument("--limit", default=10, type=int)
-
-    args = parser.parse_args()
-    return args
-
-
-def logging_config():
-    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
-
-    logging.basicConfig(format=fmt,
-                        datefmt="%m/%d/%Y %H:%M:%S",
-                        level=logging.INFO)
-    stream_handler = logging.StreamHandler()
-    stream_handler.setLevel(logging.INFO)
-    stream_handler.setFormatter(logging.Formatter(fmt))
-
-    logger = logging.getLogger(__name__)
-
-    return logger
-
-
-def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float):
-    l1 = len(speech)
-    l2 = len(noise)
-    l = min(l1, l2)
-    speech = speech[:l]
-    noise = noise[:l]
-
-    # np.float32, value between (-1, 1).
-
-    speech_power = np.mean(np.square(speech))
-    noise_power = speech_power / (10 ** (snr_db / 10))
-
-    noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2))
-
-    noisy_signal = speech + noise_adjusted
-
-    return noisy_signal
-
-
-stft_power = torchaudio.transforms.Spectrogram(
-    n_fft=512,
-    win_length=200,
-    hop_length=80,
-    power=2.0,
-    window_fn=torch.hamming_window,
-)
-
-
-stft_complex = torchaudio.transforms.Spectrogram(
-    n_fft=512,
-    win_length=200,
-    hop_length=80,
-    power=None,
-    window_fn=torch.hamming_window,
-)
-
-
-istft = torchaudio.transforms.InverseSpectrogram(
-    n_fft=512,
-    win_length=200,
-    hop_length=80,
-    window_fn=torch.hamming_window,
-)
-
-
-def enhance(mix_spec_complex: torch.Tensor,
-            speech_spec_prediction: torch.Tensor,
-            speech_irm_prediction: torch.Tensor,
-            ):
-    mix_spec_complex = mix_spec_complex.detach().cpu()
-    speech_spec_prediction = speech_spec_prediction.detach().cpu()
-    speech_irm_prediction = speech_irm_prediction.detach().cpu()
-
-    mask_speech = speech_irm_prediction
-    mask_noise = 1.0 - speech_irm_prediction
-
-    speech_spec = mix_spec_complex * mask_speech
-    noise_spec = mix_spec_complex * mask_noise
-
-    # print(f"speech_spec_prediction: {speech_spec_prediction.shape}")
-    # print(f"noise_spec: {noise_spec.shape}")
-
-    speech_wave = istft.forward(speech_spec_prediction)
-    # speech_wave = istft.forward(speech_spec)
-    noise_wave = istft.forward(noise_spec)
-
-    return speech_wave, noise_wave
-
-
-def save_audios(noise_wave: torch.Tensor,
-                speech_wave: torch.Tensor,
-                mix_wave: torch.Tensor,
-                speech_wave_enhanced: torch.Tensor,
-                noise_wave_enhanced: torch.Tensor,
-                output_dir: str,
-                sample_rate: int = 8000,
-                ):
-    basename = uuid.uuid4().__str__()
-    output_dir = Path(output_dir) / basename
-    output_dir.mkdir(parents=True, exist_ok=True)
-
-    filename = output_dir / "noise_wave.wav"
-    torchaudio.save(filename, noise_wave, sample_rate)
-    filename = output_dir / "speech_wave.wav"
-    torchaudio.save(filename, speech_wave, sample_rate)
-    filename = output_dir / "mix_wave.wav"
-    torchaudio.save(filename, mix_wave, sample_rate)
-
-    filename = output_dir / "speech_wave_enhanced.wav"
-    torchaudio.save(filename, speech_wave_enhanced, sample_rate)
-    filename = output_dir / "noise_wave_enhanced.wav"
-    torchaudio.save(filename, noise_wave_enhanced, sample_rate)
-
-    return output_dir.as_posix()
-
-
-def main():
-    args = get_args()
-
-    logger = logging_config()
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    n_gpu = torch.cuda.device_count()
-    logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
-
-    logger.info("prepare model")
-    model = SpectrumDfNetPretrainedModel.from_pretrained(
-        pretrained_model_name_or_path=args.model_dir,
-    )
-    model.to(device)
-    model.eval()
-
-    # optimizer
-    logger.info("prepare loss_fn")
-    irm_mse_loss = nn.MSELoss(
-        reduction="mean",
-    )
-    snr_mse_loss = nn.MSELoss(
-        reduction="mean",
-    )
-
-    logger.info("read excel")
-    df = pd.read_excel(args.valid_dataset)
-
-    total_loss = 0.
-    total_examples = 0.
-    progress_bar = tqdm(total=len(df), desc="Evaluation")
-    for idx, row in df.iterrows():
-        noise_filename = row["noise_filename"]
-        noise_offset = row["noise_offset"]
-        noise_duration = row["noise_duration"]
-
-        speech_filename = row["speech_filename"]
-        speech_offset = row["speech_offset"]
-        speech_duration = row["speech_duration"]
-
-        snr_db = row["snr_db"]
-
-        noise_wave, _ = librosa.load(
-            noise_filename,
-            sr=8000,
-            offset=noise_offset,
-            duration=noise_duration,
-        )
-        speech_wave, _ = librosa.load(
-            speech_filename,
-            sr=8000,
-            offset=speech_offset,
-            duration=speech_duration,
-        )
-        mix_wave: np.ndarray = mix_speech_and_noise(
-            speech=speech_wave,
-            noise=noise_wave,
-            snr_db=snr_db,
-        )
-        noise_wave = torch.tensor(noise_wave, dtype=torch.float32)
-        speech_wave = torch.tensor(speech_wave, dtype=torch.float32)
-        mix_wave: torch.Tensor = torch.tensor(mix_wave, dtype=torch.float32)
-
-        noise_wave = noise_wave.unsqueeze(dim=0)
-        speech_wave = speech_wave.unsqueeze(dim=0)
-        mix_wave = mix_wave.unsqueeze(dim=0)
-
-        noise_spec: torch.Tensor = stft_power.forward(noise_wave)
-        speech_spec: torch.Tensor = stft_power.forward(speech_wave)
-        mix_spec: torch.Tensor = stft_power.forward(mix_wave)
-
-        speech_spec_complex: torch.Tensor = stft_complex.forward(speech_wave)
-        mix_spec_complex: torch.Tensor = stft_complex.forward(mix_wave)
-        # mix_spec_complex shape: [batch_size, freq_dim (257), time_steps, 2]
-
-        noise_spec = noise_spec[:, :-1, :]
-        speech_spec = speech_spec[:, :-1, :]
-        mix_spec = mix_spec[:, :-1, :]
-        speech_spec_complex = speech_spec_complex[:, :-1, :]
-        mix_spec_complex = mix_spec_complex[:, :-1, :]
-
-        speech_irm = speech_spec / (noise_spec + speech_spec)
-        speech_irm = torch.pow(speech_irm, 1.0)
-
-        snr_db: torch.Tensor = 10 * torch.log10(
-            speech_spec / (noise_spec + 1e-8)
-        )
-        snr_db = torch.clamp(snr_db, min=1e-8)
-        snr_db = torch.mean(snr_db, dim=1, keepdim=True)
-        # snr_db shape: [batch_size, 1, time_steps]
-
-        speech_spec_complex = speech_spec_complex.to(device)
-        mix_spec_complex = mix_spec_complex.to(device)
-        mix_spec = mix_spec.to(device)
-        speech_irm_target = speech_irm.to(device)
-        snr_db_target = snr_db.to(device)
-
-        with torch.no_grad():
-            speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_spec_complex)
-            speech_spec_prediction = torch.view_as_complex(speech_spec_prediction)
-
-            irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
-            # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
-            # loss = irm_loss + 0.1 * snr_loss
-            loss = irm_loss
-
-        # mix_spec_complex shape: [batch_size, freq_dim (257), time_steps, 2]
-        # speech_irm_prediction shape: [batch_size, freq_dim (256), time_steps]
-        batch_size, _, time_steps = speech_irm_prediction.shape
-
-        mix_spec_complex = torch.concat(
-            [
-                mix_spec_complex,
-                torch.zeros(size=(batch_size, 1, time_steps), dtype=mix_spec_complex.dtype).to(device)
-            ],
-            dim=1,
-        )
-        speech_spec_prediction = torch.concat(
-            [
-                speech_spec_prediction,
-                torch.zeros(size=(batch_size, 1, time_steps), dtype=speech_spec_prediction.dtype).to(device)
-            ],
-            dim=1,
-        )
-        speech_irm_prediction = torch.concat(
-            [
-                speech_irm_prediction,
-                0.5 * torch.ones(size=(batch_size, 1, time_steps), dtype=speech_irm_prediction.dtype).to(device)
-            ],
-            dim=1,
-        )
-
-        # speech_irm_prediction shape: [batch_size, freq_dim (257), time_steps]
-        speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_spec_prediction, speech_irm_prediction)
-        save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir)
-
-        total_loss += loss.item()
-        total_examples += mix_spec.size(0)
-
-        evaluation_loss = total_loss / total_examples
-        evaluation_loss = round(evaluation_loss, 4)
-
-        progress_bar.update(1)
-        progress_bar.set_postfix({
-            "evaluation_loss": evaluation_loss,
-        })
-
-        if idx > args.limit:
-            break
-
-    return
-
-
-if __name__ == '__main__':
-    main()
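The dfnet pipeline trains on 256 of the 257 one-sided bins (the top bin is sliced off in the collate function), so before inversion this evaluation script pads a 257th row back: zeros for the complex spectra and a neutral 0.5 for the mask. A minimal shape-level sketch of that padding, assuming only torch:

import torch

batch_size, freq_dim, time_steps = 1, 256, 10

# Model-side tensors live on 256 bins; the iSTFT expects n_fft // 2 + 1 = 257.
spec = torch.randn(batch_size, freq_dim, time_steps, dtype=torch.complex64)
mask = torch.rand(batch_size, freq_dim, time_steps)

zero_row = torch.zeros(batch_size, 1, time_steps, dtype=spec.dtype)
half_row = 0.5 * torch.ones(batch_size, 1, time_steps, dtype=mask.dtype)

spec_257 = torch.cat([spec, zero_row], dim=1)   # silent Nyquist bin
mask_257 = torch.cat([mask, half_row], dim=1)   # non-committal mask value

print(spec_257.shape, mask_257.shape)  # [1, 257, 10] each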
examples/spectrum_dfnet_aishell/yaml/config.yaml
DELETED
@@ -1,53 +0,0 @@
-model_name: "spectrum_unet_irm"
-
-# spec
-sample_rate: 8000
-n_fft: 512
-win_length: 200
-hop_length: 80
-
-spec_bins: 256
-
-# model
-conv_channels: 64
-conv_kernel_size_input:
-  - 3
-  - 3
-conv_kernel_size_inner:
-  - 1
-  - 3
-conv_lookahead: 0
-
-convt_kernel_size_inner:
-  - 1
-  - 3
-
-embedding_hidden_size: 256
-encoder_combine_op: "concat"
-
-encoder_emb_skip_op: "none"
-encoder_emb_linear_groups: 16
-encoder_emb_hidden_size: 256
-
-encoder_linear_groups: 32
-
-lsnr_max: 30
-lsnr_min: -15
-norm_tau: 1.
-
-decoder_emb_num_layers: 3
-decoder_emb_skip_op: "none"
-decoder_emb_linear_groups: 16
-decoder_emb_hidden_size: 256
-
-df_decoder_hidden_size: 256
-df_num_layers: 2
-df_order: 5
-df_bins: 96
-df_gru_skip: "grouped_linear"
-df_decoder_linear_groups: 16
-df_pathway_kernel_size_t: 5
-df_lookahead: 2
-
-# runtime
-use_post_filter: true
examples/spectrum_unet_irm_aishell/run.sh
DELETED
@@ -1,178 +0,0 @@

#!/usr/bin/env bash

: <<'END'


sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir \
--noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
--speech_dir "E:/programmer/asr_datasets/aishell/data_aishell/wav/train"


sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name file_dir \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"

sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name file_dir \
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"


END


# params
system_version="windows";
verbose=true;
stage=0 # start from 0 if you need to start from data preparation
stop_stage=9

work_dir="$(pwd)"
file_folder_name=file_folder_name
final_model_name=final_model_name
config_file="yaml/config.yaml"
limit=10

noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train

nohup_name=nohup.out

# model params
batch_size=64
max_epochs=200
save_top_k=10
patience=5


# parse options
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
      eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
      old_value="`eval echo \\$$name`";
      if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
        was_bool=true;
      else
        was_bool=false;
      fi

      # Set the variable to the right value-- the escaped quotes make it work if
      # the option had spaces, like --cmd "queue.pl -sync y"
      eval "${name}=\"$2\"";

      # Check that Boolean-valued arguments are really Boolean.
      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
        exit 1;
      fi
      shift 2;
      ;;

    *) break;
  esac
done

file_dir="${work_dir}/${file_folder_name}"
final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
evaluation_audio_dir="${file_dir}/evaluation_audio"

dataset="${file_dir}/dataset.xlsx"
train_dataset="${file_dir}/train.xlsx"
valid_dataset="${file_dir}/valid.xlsx"

$verbose && echo "system_version: ${system_version}"
$verbose && echo "file_folder_name: ${file_folder_name}"

if [ $system_version == "windows" ]; then
  alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
  #source /data/local/bin/nx_denoise/bin/activate
  alias python3='/data/local/bin/nx_denoise/bin/python3'
fi


if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  $verbose && echo "stage 1: prepare data"
  cd "${work_dir}" || exit 1
  python3 step_1_prepare_data.py \
    --file_dir "${file_dir}" \
    --noise_dir "${noise_dir}" \
    --speech_dir "${speech_dir}" \
    --train_dataset "${train_dataset}" \
    --valid_dataset "${valid_dataset}" \

fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  $verbose && echo "stage 2: train model"
  cd "${work_dir}" || exit 1
  python3 step_2_train_model.py \
    --train_dataset "${train_dataset}" \
    --valid_dataset "${valid_dataset}" \
    --serialization_dir "${file_dir}" \
    --config_file "${config_file}" \

fi


if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  $verbose && echo "stage 3: test model"
  cd "${work_dir}" || exit 1
  python3 step_3_evaluation.py \
    --valid_dataset "${valid_dataset}" \
    --model_dir "${file_dir}/best" \
    --evaluation_audio_dir "${evaluation_audio_dir}" \
    --limit "${limit}" \

fi


if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  $verbose && echo "stage 4: export model"
  cd "${work_dir}" || exit 1
  python3 step_5_export_models.py \
    --vocabulary_dir "${vocabulary_dir}" \
    --model_dir "${file_dir}/best" \
    --serialization_dir "${file_dir}" \

fi


if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  $verbose && echo "stage 5: collect files"
  cd "${work_dir}" || exit 1

  mkdir -p ${final_model_dir}

  cp "${file_dir}/best"/* "${final_model_dir}"
  cp -r "${file_dir}/vocabulary" "${final_model_dir}"

  cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx"

  cp "${file_dir}/trace_model.zip" "${final_model_dir}/trace_model.zip"
  cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip"
  cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip"
  cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip"

  cd "${final_model_dir}/.." || exit 1;

  if [ -e "${final_model_name}.zip" ]; then
    rm -rf "${final_model_name}_backup.zip"
    mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
  fi

  zip -r "${final_model_name}.zip" "${final_model_name}"
  rm -rf "${final_model_name}"

fi


if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
  $verbose && echo "stage 6: clear file_dir"
  cd "${work_dir}" || exit 1

  rm -rf "${file_dir}";

fi
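The while/case loop above is the Kaldi-style option parser shared by all the deleted run.sh scripts: any --foo-bar VALUE pair overrides a pre-declared variable foo_bar, unknown options abort, and boolean variables only accept "true" or "false". A rough Python equivalent, for illustration only:

import sys

def parse_options(defaults: dict, argv: list) -> dict:
    # defaults: the pre-declared variables, e.g. {"stage": 0, "verbose": True}
    opts = dict(defaults)
    i = 0
    while i + 1 < len(argv) and argv[i].startswith("--"):
        name = argv[i][2:].replace("-", "_")
        if name not in opts:
            sys.exit(f"invalid option {argv[i]}")
        value = argv[i + 1]
        if isinstance(opts[name], bool):
            # Boolean-valued arguments must really be Boolean.
            if value not in ("true", "false"):
                sys.exit(f'expected "true" or "false": {argv[i]} {value}')
            value = (value == "true")
        opts[name] = value
        i += 2
    return opts

print(parse_options({"stage": 0, "stop_stage": 9}, ["--stage", "1", "--stop-stage", "3"]))
# {'stage': '1', 'stop_stage': '3'}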
examples/spectrum_unet_irm_aishell/step_1_prepare_data.py
DELETED
@@ -1,197 +0,0 @@

#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import os
from pathlib import Path
import random
import sys
import shutil

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import pandas as pd
from scipy.io import wavfile
from tqdm import tqdm
import librosa

from project_settings import project_path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--file_dir", default="./", type=str)

    parser.add_argument(
        "--noise_dir",
        default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
        type=str
    )
    parser.add_argument(
        "--speech_dir",
        default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
        type=str
    )

    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)

    parser.add_argument("--duration", default=2.0, type=float)
    parser.add_argument("--min_snr_db", default=-10, type=float)
    parser.add_argument("--max_snr_db", default=20, type=float)

    parser.add_argument("--target_sample_rate", default=8000, type=int)

    args = parser.parse_args()
    return args


def filename_generator(data_dir: str):
    data_dir = Path(data_dir)
    for filename in data_dir.glob("**/*.wav"):
        yield filename.as_posix()


def target_second_signal_generator(data_dir: str, duration: float = 2.0, sample_rate: int = 8000):
    data_dir = Path(data_dir)
    for filename in data_dir.glob("**/*.wav"):
        signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
        raw_duration = librosa.get_duration(y=signal, sr=sample_rate)

        if raw_duration < duration:
            # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
            continue
        if signal.ndim != 1:
            raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")

        signal_length = len(signal)
        win_size = int(duration * sample_rate)
        for begin in range(0, signal_length - win_size, win_size):
            row = {
                "filename": filename.as_posix(),
                "raw_duration": round(raw_duration, 4),
                "offset": round(begin / sample_rate, 4),
                "duration": round(duration, 4),
            }
            yield row


def get_dataset(args):
    file_dir = Path(args.file_dir)
    file_dir.mkdir(exist_ok=True)

    noise_dir = Path(args.noise_dir)
    speech_dir = Path(args.speech_dir)

    noise_generator = target_second_signal_generator(
        noise_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate
    )
    speech_generator = target_second_signal_generator(
        speech_dir.as_posix(),
        duration=args.duration,
        sample_rate=args.target_sample_rate
    )

    dataset = list()

    count = 0
    process_bar = tqdm(desc="build dataset excel")
    for noise, speech in zip(noise_generator, speech_generator):

        noise_filename = noise["filename"]
        noise_raw_duration = noise["raw_duration"]
        noise_offset = noise["offset"]
        noise_duration = noise["duration"]

        speech_filename = speech["filename"]
        speech_raw_duration = speech["raw_duration"]
        speech_offset = speech["offset"]
        speech_duration = speech["duration"]

        random1 = random.random()
        random2 = random.random()

        row = {
            "noise_filename": noise_filename,
            "noise_raw_duration": noise_raw_duration,
            "noise_offset": noise_offset,
            "noise_duration": noise_duration,

            "speech_filename": speech_filename,
            "speech_raw_duration": speech_raw_duration,
            "speech_offset": speech_offset,
            "speech_duration": speech_duration,

            "snr_db": random.uniform(args.min_snr_db, args.max_snr_db),

            "random1": random1,
            "random2": random2,
            "flag": "TRAIN" if random2 < 0.8 else "TEST",
        }
        dataset.append(row)
        count += 1
        duration_seconds = count * args.duration
        duration_hours = duration_seconds / 3600

        process_bar.update(n=1)
        process_bar.set_postfix({
            # "duration_seconds": round(duration_seconds, 4),
            "duration_hours": round(duration_hours, 4),
        })

    dataset = pd.DataFrame(dataset)
    dataset = dataset.sort_values(by=["random1"], ascending=False)
    dataset.to_excel(
        file_dir / "dataset.xlsx",
        index=False,
    )
    return


def split_dataset(args):
    """Split the dataset into a train set and a test set."""
    file_dir = Path(args.file_dir)
    file_dir.mkdir(exist_ok=True)

    df = pd.read_excel(file_dir / "dataset.xlsx")

    train = list()
    test = list()

    for i, row in df.iterrows():
        flag = row["flag"]
        if flag == "TRAIN":
            train.append(row)
        else:
            test.append(row)

    train = pd.DataFrame(train)
    train.to_excel(
        args.train_dataset,
        index=False,
        # encoding="utf_8_sig"
    )
    test = pd.DataFrame(test)
    test.to_excel(
        args.valid_dataset,
        index=False,
        # encoding="utf_8_sig"
    )

    return


def main():
    args = get_args()

    get_dataset(args)
    split_dataset(args)
    return


if __name__ == "__main__":
    main()
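A small worked example (not part of the deleted file) of what target_second_signal_generator above yields. Note that range() stops strictly before signal_length - win_size, so the final full window of each recording is dropped:

sample_rate = 8000
duration = 2.0
signal_length = 10 * sample_rate          # a 10 s recording
win_size = int(duration * sample_rate)    # 16000 samples per 2 s window

offsets = [round(begin / sample_rate, 4)
           for begin in range(0, signal_length - win_size, win_size)]
print(offsets)  # [0.0, 2.0, 4.0, 6.0] -- the full window at 8.0 s is excluded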
examples/spectrum_unet_irm_aishell/step_2_train_model.py
DELETED
@@ -1,420 +0,0 @@

#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/WenzheLiu-Speech/awesome-speech-enhancement
"""
import argparse
import json
import logging
from logging.handlers import TimedRotatingFileHandler
import os
import platform
from pathlib import Path
import random
import sys
import shutil
from typing import List

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data.dataloader import DataLoader
import torchaudio
from tqdm import tqdm

from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
from toolbox.torchaudio.models.spectrum_unet_irm.configuration_specturm_unet_irm import SpectrumUnetIRMConfig
from toolbox.torchaudio.models.spectrum_unet_irm.modeling_spectrum_unet_irm import SpectrumUnetIRMPretrainedModel


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)

    parser.add_argument("--max_epochs", default=100, type=int)

    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--learning_rate", default=1e-4, type=float)
    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
    parser.add_argument("--patience", default=5, type=int)
    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
    parser.add_argument("--seed", default=0, type=int)

    parser.add_argument("--config_file", default="config.yaml", type=str)

    args = parser.parse_args()
    return args


def logging_config(file_dir: str):
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.INFO)
    file_handler = TimedRotatingFileHandler(
        filename=os.path.join(file_dir, "main.log"),
        encoding="utf-8",
        when="D",
        interval=1,
        backupCount=7
    )
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(fmt))
    logger = logging.getLogger(__name__)
    logger.addHandler(file_handler)

    return logger


class CollateFunction(object):
    def __init__(self,
                 n_fft: int = 512,
                 win_length: int = 200,
                 hop_length: int = 80,
                 window_fn: str = "hamming",
                 irm_beta: float = 1.0,
                 epsilon: float = 1e-8,
                 ):
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.window_fn = window_fn
        self.irm_beta = irm_beta
        self.epsilon = epsilon

        self.transform = torchaudio.transforms.Spectrogram(
            n_fft=self.n_fft,
            win_length=self.win_length,
            hop_length=self.hop_length,
            power=2.0,
            window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
        )

    @staticmethod
    def make_unfold_snr_db(x: torch.Tensor, n_time_steps: int = 3):
        batch_size, channels, freq_dim, time_steps = x.shape

        # kernel: [freq_dim, n_time_step]
        kernel_size = (freq_dim, n_time_steps)

        # pad
        pad = n_time_steps // 2
        x = torch.concat(tensors=[
            x[:, :, :, :pad],
            x,
            x[:, :, :, -pad:],
        ], dim=-1)

        x = F.unfold(
            input=x,
            kernel_size=kernel_size,
        )
        # x shape: [batch_size, fold, time_steps]
        return x

    def __call__(self, batch: List[dict]):
        mix_spec_list = list()
        speech_irm_list = list()
        snr_db_list = list()
        for sample in batch:
            noise_wave: torch.Tensor = sample["noise_wave"]
            speech_wave: torch.Tensor = sample["speech_wave"]
            mix_wave: torch.Tensor = sample["mix_wave"]
            # snr_db: float = sample["snr_db"]

            noise_spec = self.transform.forward(noise_wave)
            speech_spec = self.transform.forward(speech_wave)
            mix_spec = self.transform.forward(mix_wave)

            # noise_irm = noise_spec / (noise_spec + speech_spec)
            speech_irm = speech_spec / (noise_spec + speech_spec + self.epsilon)
            speech_irm = torch.pow(speech_irm, self.irm_beta)

            # noise_spec, speech_spec, mix_spec, speech_irm
            # shape: [freq_dim, time_steps]

            snr_db: torch.Tensor = 10 * torch.log10(
                speech_spec / (noise_spec + self.epsilon)
            )
            snr_db = torch.clamp(snr_db, min=self.epsilon)

            snr_db_ = torch.unsqueeze(snr_db, dim=0)
            snr_db_ = torch.unsqueeze(snr_db_, dim=0)
            snr_db_ = self.make_unfold_snr_db(snr_db_, n_time_steps=3)
            snr_db_ = torch.squeeze(snr_db_, dim=0)
            # snr_db_ shape: [fold, time_steps]

            snr_db = torch.mean(snr_db_, dim=0, keepdim=True)
            # snr_db shape: [1, time_steps]

            mix_spec_list.append(mix_spec)
            speech_irm_list.append(speech_irm)
            snr_db_list.append(snr_db)

        mix_spec_list = torch.stack(mix_spec_list)
        speech_irm_list = torch.stack(speech_irm_list)
        snr_db_list = torch.stack(snr_db_list)  # shape: [batch_size, 1, time_steps]

        mix_spec_list = mix_spec_list[:, :-1, :]
        speech_irm_list = speech_irm_list[:, :-1, :]

        # mix_spec_list shape: [batch_size, freq_dim, time_steps]
        # speech_irm_list shape: [batch_size, freq_dim, time_steps]
        # snr_db shape: [batch_size, 1, time_steps]

        # assert
        if torch.any(torch.isnan(mix_spec_list)) or torch.any(torch.isinf(mix_spec_list)):
            raise AssertionError("nan or inf in mix_spec_list")
        if torch.any(torch.isnan(speech_irm_list)) or torch.any(torch.isinf(speech_irm_list)):
            raise AssertionError("nan or inf in speech_irm_list")
        if torch.any(torch.isnan(snr_db_list)) or torch.any(torch.isinf(snr_db_list)):
            raise AssertionError("nan or inf in snr_db_list")

        return mix_spec_list, speech_irm_list, snr_db_list


collate_fn = CollateFunction()


def main():
    args = get_args()

    serialization_dir = Path(args.serialization_dir)
    serialization_dir.mkdir(parents=True, exist_ok=True)

    logger = logging_config(serialization_dir)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    logger.info("set seed: {}".format(args.seed))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("GPU available count: {}; device: {}".format(n_gpu, device))

    # datasets
    logger.info("prepare datasets")
    train_dataset = DenoiseExcelDataset(
        excel_file=args.train_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    valid_dataset = DenoiseExcelDataset(
        excel_file=args.valid_dataset,
        expected_sample_rate=8000,
        max_wave_value=32768.0,
    )
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        # Multiple subprocesses can be used to load data on Linux, but not on Windows.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        # prefetch_factor=64,
    )
    valid_data_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        # Multiple subprocesses can be used to load data on Linux, but not on Windows.
        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
        collate_fn=collate_fn,
        pin_memory=False,
        # prefetch_factor=64,
    )

    # models
    logger.info(f"prepare models. config_file: {args.config_file}")
    config = SpectrumUnetIRMConfig.from_pretrained(
        pretrained_model_name_or_path=args.config_file,
        # num_labels=vocabulary.get_vocab_size(namespace="labels")
    )
    model = SpectrumUnetIRMPretrainedModel(
        config=config,
    )
    model.to(device)
    model.train()

    # optimizer
    logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
    param_optimizer = model.parameters()
    optimizer = torch.optim.Adam(
        param_optimizer,
        lr=args.learning_rate,
    )
    # lr_scheduler = torch.optim.lr_scheduler.StepLR(
    #     optimizer,
    #     step_size=2000
    # )
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
    )
    irm_mse_loss = nn.MSELoss(
        reduction="mean",
    )
    snr_mse_loss = nn.MSELoss(
        reduction="mean",
    )

    # training loop
    logger.info("training")

    training_loss = 10000000000
    evaluation_loss = 10000000000

    model_list = list()
    best_idx_epoch = None
    best_metric = None
    patience_count = 0

    for idx_epoch in range(args.max_epochs):
        total_loss = 0.
        total_examples = 0.
        progress_bar = tqdm(
            total=len(train_data_loader),
            desc="Training; epoch: {}".format(idx_epoch),
        )

        for batch in train_data_loader:
            mix_spec, speech_irm, snr_db = batch
            mix_spec = mix_spec.to(device)
            speech_irm_target = speech_irm.to(device)
            snr_db_target = snr_db.to(device)

            speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
            if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
                raise AssertionError("nan or inf in speech_irm_prediction")
            if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
                raise AssertionError("nan or inf in lsnr_prediction")
            irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
            lsnr_prediction = (lsnr_prediction - config.lsnr_min) / (config.lsnr_max - config.lsnr_min)
            if torch.max(lsnr_prediction) > 1 or torch.min(lsnr_prediction) < 0:
                raise AssertionError(f"expected lsnr_prediction between 0 and 1.")
            snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
            if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
                raise AssertionError("nan or inf in snr_loss")
            # loss = irm_loss + 0.1 * snr_loss
            loss = 10.0 * irm_loss + 0.05 * snr_loss
            # loss = irm_loss

            total_loss += loss.item()
            total_examples += mix_spec.size(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            training_loss = total_loss / total_examples
            training_loss = round(training_loss, 4)

            progress_bar.update(1)
            progress_bar.set_postfix({
                "training_loss": training_loss,
            })

        total_loss = 0.
        total_examples = 0.
        progress_bar = tqdm(
            total=len(valid_data_loader),
            desc="Evaluation; epoch: {}".format(idx_epoch),
        )
        for batch in valid_data_loader:
            mix_spec, speech_irm, snr_db = batch
            mix_spec = mix_spec.to(device)
            speech_irm_target = speech_irm.to(device)
            snr_db_target = snr_db.to(device)

            with torch.no_grad():
                speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
                if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
                    raise AssertionError("nan or inf in speech_irm_prediction")
                if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
                    raise AssertionError("nan or inf in lsnr_prediction")
                irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
                lsnr_prediction = (lsnr_prediction - config.lsnr_min) / (config.lsnr_max - config.lsnr_min)
                if torch.max(lsnr_prediction) > 1 or torch.min(lsnr_prediction) < 0:
                    raise AssertionError(f"expected lsnr_prediction between 0 and 1.")
                snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
                # loss = irm_loss + 0.1 * snr_loss
                loss = 10.0 * irm_loss + 0.05 * snr_loss
                # loss = irm_loss

            total_loss += loss.item()
            total_examples += mix_spec.size(0)

            evaluation_loss = total_loss / total_examples
            evaluation_loss = round(evaluation_loss, 4)

            progress_bar.update(1)
            progress_bar.set_postfix({
                "evaluation_loss": evaluation_loss,
            })

        # save path
        epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
        epoch_dir.mkdir(parents=True, exist_ok=False)

        # save models
        model.save_pretrained(epoch_dir.as_posix())

        model_list.append(epoch_dir)
        if len(model_list) >= args.num_serialized_models_to_keep:
            model_to_delete: Path = model_list.pop(0)
            shutil.rmtree(model_to_delete.as_posix())

        # save metric
        if best_metric is None:
            best_idx_epoch = idx_epoch
            best_metric = evaluation_loss
        elif evaluation_loss < best_metric:
            best_idx_epoch = idx_epoch
            best_metric = evaluation_loss
        else:
            pass

        metrics = {
            "idx_epoch": idx_epoch,
            "best_idx_epoch": best_idx_epoch,
            "training_loss": training_loss,
            "evaluation_loss": evaluation_loss,
            "learning_rate": optimizer.param_groups[0]["lr"],
        }
        metrics_filename = epoch_dir / "metrics_epoch.json"
        with open(metrics_filename, "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=4, ensure_ascii=False)

        # save best
        best_dir = serialization_dir / "best"
        if best_idx_epoch == idx_epoch:
            if best_dir.exists():
                shutil.rmtree(best_dir)
            shutil.copytree(epoch_dir, best_dir)

        # early stop
        early_stop_flag = False
        if best_idx_epoch == idx_epoch:
            patience_count = 0
        else:
            patience_count += 1
            if patience_count >= args.patience:
                early_stop_flag = True

        # early stop
        if early_stop_flag:
            break
    return


if __name__ == '__main__':
    main()
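For clarity, a toy recap (illustrative values, not part of the deleted file) of the IRM target computed in CollateFunction above: each time-frequency bin of the mask is the speech share of the total power, optionally compressed by irm_beta.

import torch

speech_spec = torch.tensor([4.0, 1.0, 0.25])  # |S|^2 per bin (toy values)
noise_spec = torch.tensor([1.0, 1.0, 1.0])    # |N|^2 per bin
epsilon, irm_beta = 1e-8, 1.0

speech_irm = torch.pow(speech_spec / (noise_spec + speech_spec + epsilon), irm_beta)
print(speech_irm)  # tensor([0.8000, 0.5000, 0.2000])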
examples/spectrum_unet_irm_aishell/step_3_evaluation.py
DELETED
@@ -1,270 +0,0 @@

#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import logging
import os
from pathlib import Path
import sys
import uuid

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import librosa
import numpy as np
import pandas as pd
from scipy.io import wavfile
import torch
import torch.nn as nn
import torchaudio
from tqdm import tqdm

from toolbox.torchaudio.models.spectrum_unet_irm.modeling_spectrum_unet_irm import SpectrumUnetIRMPretrainedModel


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
    parser.add_argument("--model_dir", default="serialization_dir/best", type=str)
    parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str)

    parser.add_argument("--limit", default=10, type=int)

    args = parser.parse_args()
    return args


def logging_config():
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.INFO)
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(logging.Formatter(fmt))

    logger = logging.getLogger(__name__)

    return logger


def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float):
    l1 = len(speech)
    l2 = len(noise)
    l = min(l1, l2)
    speech = speech[:l]
    noise = noise[:l]

    # np.float32, value between (-1, 1).

    speech_power = np.mean(np.square(speech))
    noise_power = speech_power / (10 ** (snr_db / 10))

    noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2))

    noisy_signal = speech + noise_adjusted

    return noisy_signal


stft_power = torchaudio.transforms.Spectrogram(
    n_fft=512,
    win_length=200,
    hop_length=80,
    power=2.0,
    window_fn=torch.hamming_window,
)


stft_complex = torchaudio.transforms.Spectrogram(
    n_fft=512,
    win_length=200,
    hop_length=80,
    power=None,
    window_fn=torch.hamming_window,
)


istft = torchaudio.transforms.InverseSpectrogram(
    n_fft=512,
    win_length=200,
    hop_length=80,
    window_fn=torch.hamming_window,
)


def enhance(mix_spec_complex: torch.Tensor, speech_irm_prediction: torch.Tensor):
    mix_spec_complex = mix_spec_complex.detach().cpu()
    speech_irm_prediction = speech_irm_prediction.detach().cpu()

    mask_speech = speech_irm_prediction
    mask_noise = 1.0 - speech_irm_prediction

    speech_spec = mix_spec_complex * mask_speech
    noise_spec = mix_spec_complex * mask_noise

    speech_wave = istft.forward(speech_spec)
    noise_wave = istft.forward(noise_spec)

    return speech_wave, noise_wave


def save_audios(noise_wave: torch.Tensor,
                speech_wave: torch.Tensor,
                mix_wave: torch.Tensor,
                speech_wave_enhanced: torch.Tensor,
                noise_wave_enhanced: torch.Tensor,
                output_dir: str,
                sample_rate: int = 8000,
                ):
    basename = uuid.uuid4().__str__()
    output_dir = Path(output_dir) / basename
    output_dir.mkdir(parents=True, exist_ok=True)

    filename = output_dir / "noise_wave.wav"
    torchaudio.save(filename, noise_wave, sample_rate)
    filename = output_dir / "speech_wave.wav"
    torchaudio.save(filename, speech_wave, sample_rate)
    filename = output_dir / "mix_wave.wav"
    torchaudio.save(filename, mix_wave, sample_rate)

    filename = output_dir / "speech_wave_enhanced.wav"
    torchaudio.save(filename, speech_wave_enhanced, sample_rate)
    filename = output_dir / "noise_wave_enhanced.wav"
    torchaudio.save(filename, noise_wave_enhanced, sample_rate)

    return output_dir.as_posix()


def main():
    args = get_args()

    logger = logging_config()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("GPU available count: {}; device: {}".format(n_gpu, device))

    logger.info("prepare model")
    model = SpectrumUnetIRMPretrainedModel.from_pretrained(
        pretrained_model_name_or_path=args.model_dir,
    )
    model.to(device)
    model.eval()

    # optimizer
    logger.info("prepare loss_fn")
    irm_mse_loss = nn.MSELoss(
        reduction="mean",
    )
    snr_mse_loss = nn.MSELoss(
        reduction="mean",
    )

    logger.info("read excel")
    df = pd.read_excel(args.valid_dataset)

    total_loss = 0.
    total_examples = 0.
    progress_bar = tqdm(total=len(df), desc="Evaluation")
    for idx, row in df.iterrows():
        noise_filename = row["noise_filename"]
        noise_offset = row["noise_offset"]
        noise_duration = row["noise_duration"]

        speech_filename = row["speech_filename"]
        speech_offset = row["speech_offset"]
        speech_duration = row["speech_duration"]

        snr_db = row["snr_db"]

        noise_wave, _ = librosa.load(
            noise_filename,
            sr=8000,
            offset=noise_offset,
            duration=noise_duration,
        )
        speech_wave, _ = librosa.load(
            speech_filename,
            sr=8000,
            offset=speech_offset,
            duration=speech_duration,
        )
        mix_wave: np.ndarray = mix_speech_and_noise(
            speech=speech_wave,
            noise=noise_wave,
            snr_db=snr_db,
        )
        noise_wave = torch.tensor(noise_wave, dtype=torch.float32)
        speech_wave = torch.tensor(speech_wave, dtype=torch.float32)
        mix_wave: torch.Tensor = torch.tensor(mix_wave, dtype=torch.float32)

        noise_wave = noise_wave.unsqueeze(dim=0)
        speech_wave = speech_wave.unsqueeze(dim=0)
        mix_wave = mix_wave.unsqueeze(dim=0)

        noise_spec: torch.Tensor = stft_power.forward(noise_wave)
        speech_spec: torch.Tensor = stft_power.forward(speech_wave)
        mix_spec: torch.Tensor = stft_power.forward(mix_wave)

        noise_spec = noise_spec[:, :-1, :]
        speech_spec = speech_spec[:, :-1, :]
        mix_spec = mix_spec[:, :-1, :]

        mix_spec_complex: torch.Tensor = stft_complex.forward(mix_wave)
        # mix_spec_complex shape: [batch_size, freq_dim (257), time_steps, 2]

        speech_irm = speech_spec / (noise_spec + speech_spec)
        speech_irm = torch.pow(speech_irm, 1.0)

        snr_db: torch.Tensor = 10 * torch.log10(
            speech_spec / (noise_spec + 1e-8)
        )
        snr_db = torch.mean(snr_db, dim=1, keepdim=True)
        # snr_db shape: [batch_size, 1, time_steps]

        mix_spec = mix_spec.to(device)
        speech_irm_target = speech_irm.to(device)
        snr_db_target = snr_db.to(device)

        with torch.no_grad():
            speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
            irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
            # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
            # loss = irm_loss + 0.1 * snr_loss
            loss = irm_loss

            # mix_spec_complex shape: [batch_size, freq_dim (257), time_steps, 2]
            # speech_irm_prediction shape: [batch_size, freq_dim (256), time_steps]
            batch_size, _, time_steps = speech_irm_prediction.shape
            speech_irm_prediction = torch.concat(
                [
                    speech_irm_prediction,
                    0.5*torch.ones(size=(batch_size, 1, time_steps), dtype=speech_irm_prediction.dtype).to(device)
                ],
                dim=1,
            )
            # speech_irm_prediction shape: [batch_size, freq_dim (257), time_steps]
            speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_irm_prediction)
            save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir)

        total_loss += loss.item()
        total_examples += mix_spec.size(0)

        evaluation_loss = total_loss / total_examples
        evaluation_loss = round(evaluation_loss, 4)

        progress_bar.update(1)
        progress_bar.set_postfix({
            "evaluation_loss": evaluation_loss,
        })

        if idx > args.limit:
            break

    return


if __name__ == '__main__':
    main()
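A quick numerical check (illustrative, not part of the deleted file) that the scaling in mix_speech_and_noise above hits the requested SNR: the noise is rescaled so that speech_power / noise_power = 10^(snr_db / 10).

import numpy as np

rng = np.random.default_rng(0)
speech = 0.1 * rng.standard_normal(8000).astype(np.float32)
noise = rng.standard_normal(8000).astype(np.float32)
snr_db = 10.0

speech_power = np.mean(np.square(speech))
noise_power = speech_power / (10 ** (snr_db / 10))
noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2))

achieved = 10 * np.log10(speech_power / np.mean(np.square(noise_adjusted)))
print(round(float(achieved), 2))  # 10.0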
examples/spectrum_unet_irm_aishell/yaml/config.yaml
DELETED
@@ -1,38 +0,0 @@

model_name: "spectrum_unet_irm"

# spec
sample_rate: 8000
n_fft: 512
win_length: 200
hop_length: 80

spec_bins: 256

# model
conv_channels: 64
conv_kernel_size_input:
  - 3
  - 3
conv_kernel_size_inner:
  - 1
  - 3
conv_lookahead: 0

convt_kernel_size_inner:
  - 1
  - 3

encoder_emb_skip_op: "none"
encoder_emb_linear_groups: 16
encoder_emb_hidden_size: 256

lsnr_max: 30
lsnr_min: -15

decoder_emb_num_layers: 3
decoder_emb_skip_op: "none"
decoder_emb_linear_groups: 16
decoder_emb_hidden_size: 256

# runtime
use_post_filter: true
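In the deleted training script these fields are consumed through SpectrumUnetIRMConfig.from_pretrained, whose implementation is not part of this commit. As a minimal sketch (assuming PyYAML) of what the STFT-related keys imply:

import yaml

with open("yaml/config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# 512-point FFT at 8 kHz -> 15.625 Hz per bin; hop of 80 samples -> 10 ms frames
freq_resolution_hz = config["sample_rate"] / config["n_fft"]
frame_shift_ms = 1000.0 * config["hop_length"] / config["sample_rate"]
print(freq_resolution_hz, frame_shift_ms)  # 15.625 10.0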
main.py
CHANGED
@@ -1,7 +1,7 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
 """
-docker build -t denoise:
+docker build -t denoise:v20250626_1616 .
 docker stop denoise_7865 && docker rm denoise_7865
 docker run -itd \
     --name denoise_7865 \
toolbox/torch/utils/data/dataset/mp3_to_wav_jsonl_dataset.py
DELETED
@@ -1,197 +0,0 @@

#!/usr/bin/python3
# -*- coding: utf-8 -*-
import json
import os
import random
from typing import List
from pathlib import Path
import tempfile
import uuid

from pydub import AudioSegment
from scipy.io import wavfile
import librosa
import numpy as np
import torch
from torch.utils.data import Dataset, IterableDataset


class Mp3ToWavJsonlDataset(IterableDataset):
    def __init__(self,
                 jsonl_file: str,
                 expected_sample_rate: int,
                 resample: bool = False,
                 max_wave_value: float = 1.0,
                 buffer_size: int = 1000,
                 eps: float = 1e-8,
                 skip: int = 0,
                 ):
        self.jsonl_file = jsonl_file
        self.expected_sample_rate = expected_sample_rate
        self.resample = resample
        self.max_wave_value = max_wave_value
        self.eps = eps
        self.skip = skip

        self.buffer_size = buffer_size
        self.buffer_samples: List[dict] = list()

    def __iter__(self):
        self.buffer_samples = list()

        iterable_source = self.iterable_source()

        try:
            for _ in range(self.skip):
                next(iterable_source)
        except StopIteration:
            pass

        # Fill the buffer initially.
        try:
            for _ in range(self.buffer_size):
                self.buffer_samples.append(next(iterable_source))
        except StopIteration:
            pass

        # Dynamic replacement logic.
        while True:
            try:
                item = next(iterable_source)
                # Randomly replace a buffer element.
                replace_idx = random.randint(0, len(self.buffer_samples) - 1)
                sample = self.buffer_samples[replace_idx]
                self.buffer_samples[replace_idx] = item
                yield self.convert_sample(sample)
            except StopIteration:
                break

        # Drain the remaining elements.
        random.shuffle(self.buffer_samples)
        for sample in self.buffer_samples:
            yield self.convert_sample(sample)

    def iterable_source(self):
        last_sample = None
        with open(self.jsonl_file, "r", encoding="utf-8") as f:
            for row in f:
                row = json.loads(row)
                filename = row["filename"]
                raw_duration = row["raw_duration"]
                offset = row["offset"]
                duration = row["duration"]

                sample = {
                    "filename": filename,
                    "raw_duration": raw_duration,
                    "offset": offset,
                    "duration": duration,
                }
                if last_sample is None:
                    last_sample = sample
                    continue
                yield sample
        yield last_sample

    def convert_sample(self, sample: dict):
        filename = sample["filename"]
        offset = sample["offset"]
        duration = sample["duration"]

        wav_waveform = self.filename_to_waveform(filename, offset, duration)
        mp3_waveform = self.filename_to_mp3_waveform(filename, offset, duration)

        if wav_waveform.shape != mp3_waveform.shape:
            raise AssertionError(f"wav_waveform: {wav_waveform.shape}, mp3_waveform: {mp3_waveform.shape}")

        result = {
            "mp3_waveform": mp3_waveform,
            "wav_waveform": wav_waveform,
        }
        return result

    @staticmethod
    def filename_to_waveform(filename: str, offset: float, duration: float, expected_sample_rate: int = 8000):
        try:
            waveform, sample_rate = librosa.load(
                filename,
                sr=expected_sample_rate,
                offset=offset,
                duration=duration,
            )
        except ValueError as e:
            print(f"load failed. error type: {type(e)}, error text: {str(e)}, filename: {filename}")
            raise e
        waveform = torch.tensor(waveform, dtype=torch.float32)
        return waveform

    @staticmethod
    def get_temporary_file(suffix: str = ".wav"):
        temp_audio_dir = Path(tempfile.gettempdir()) / "mp3_to_wav_jsonl_dataset"
        temp_audio_dir.mkdir(parents=True, exist_ok=True)
        filename = temp_audio_dir / f"{uuid.uuid4()}{suffix}"
        filename = filename.as_posix()
        return filename

    @staticmethod
    def filename_to_mp3_waveform(filename: str, offset: float, duration: float, expected_sample_rate: int = 8000):
        try:
            waveform, sample_rate = librosa.load(
                filename,
                sr=expected_sample_rate,
                offset=offset,
                duration=duration,
            )
            waveform = np.array(waveform * (1 << 15), dtype=np.int16)
        except ValueError as e:
            print(f"load failed. error type: {type(e)}, error text: {str(e)}, filename: {filename}")
            raise e

        wav_temporary_file = Mp3ToWavJsonlDataset.get_temporary_file(suffix=".wav")
        wavfile.write(
            wav_temporary_file,
            rate=sample_rate,
            data=waveform,
        )

        mp3_temporary_file = Mp3ToWavJsonlDataset.get_temporary_file(suffix=".mp3")

        audio = AudioSegment.from_wav(wav_temporary_file)
        audio.export(mp3_temporary_file,
                     format="mp3",
                     bitrate="64k",  # 64 kbps is recommended for 8 kHz audio
                     # parameters=["-ar", "8000"]
                     parameters=["-ar", f"{expected_sample_rate}"]
                     )

        try:
            waveform, sample_rate = librosa.load(mp3_temporary_file, sr=expected_sample_rate)
        except ValueError as e:
            print(f"load failed. error type: {type(e)}, error text: {str(e)}, filename: {filename}")
            raise e

        os.remove(wav_temporary_file)
        os.remove(mp3_temporary_file)

        waveform = torch.tensor(waveform, dtype=torch.float32)
        return waveform


def main():
    filename = r"E:\Users\tianx\HuggingDatasets\nx_noise\data\speech\nx-speech\en-PH\2025-06-13\active_media_r_2e6e6303-4a2e-4bc9-b814-98ceddc59e9d_23.wav"

    waveform = Mp3ToWavJsonlDataset.filename_to_mp3_waveform(filename, offset=0, duration=15)
    print(waveform.shape)

    signal = np.array(waveform.numpy() * (1 << 15), dtype=np.int16)

    wavfile.write(
        "temp.wav",
        8000,
        signal,
    )
    return


if __name__ == "__main__":
    main()
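The buffer logic in __iter__ above is a streaming shuffle: keep a fixed-size buffer and, for each incoming item, emit a randomly chosen buffered item in its place, which approximately shuffles an arbitrarily long stream in O(buffer_size) memory. A standalone sketch of the same pattern (illustrative, not part of the deleted file):

import random

def shuffled_stream(source, buffer_size: int = 4, seed: int = 0):
    rng = random.Random(seed)
    buffer = []
    for item in source:
        if len(buffer) < buffer_size:
            buffer.append(item)                    # fill the buffer first
            continue
        idx = rng.randrange(len(buffer))
        buffer[idx], item = item, buffer[idx]      # swap in the new item
        yield item                                 # emit the displaced one
    rng.shuffle(buffer)                            # drain whatever is left
    yield from buffer

print(list(shuffled_stream(range(10))))  # a permutation of 0..9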