-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 7b3e052
Showing
29 changed files
with
1,373 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2023 konas122 | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# 声纹识别 | ||
|
||
## python第三方库 | ||
|
||
``` | ||
python=3.8 | ||
tensorboardX=2.6 | ||
tensorboard=2.11.2 | ||
scipy=1.4.1 | ||
numpy=1.23.5 | ||
librosa=0.9.2 | ||
torch=1.8.1 | ||
torchaudio=0.8.1 | ||
torchvision=0.9.1 | ||
``` | ||
|
||
| ||
## 训练 | ||
运行 `train.py` 进行训练。 | ||
|
||
该网络是在 `resnet18` 或 `vgg19` 的基础上再添加 LSTM 和线性层,从而实现声纹识别。 | ||
该项目同时也保留了单用 CNN 的方法( `net_cnn.py` )来实现声纹识别,其实效果也不差。 | ||
|
||
|
||
| ||
## 训练数据 | ||
这是我所用的数据集:https://pan.baidu.com/s/1_KrjPB27AHPrBa_1AeMQSQ?pwd=0mag 提取码:0mag | ||
|
||
当然,也可以用自己的数据集。只需在 `train.py` 的相同目录下创建 `data` 文件夹,并在 `data` 下创建子文件夹 `train`,然后将自己的训练数据放到 `train` 中。目前,这代码仅支持 `.wav` 格式的训练音频。 | ||
|
||
| ||
|
||
### Acknowledge | ||
|
||
We study many useful projects in our codeing process, which includes: | ||
|
||
[clovaai/voxceleb_trainer](https://github.com/clovaai/voxceleb_trainer). | ||
|
||
[lawlict/ECAPA-TDNN](https://github.com/lawlict/ECAPA-TDNN/blob/master/ecapa_tdnn.py). | ||
|
||
[TaoRuijie/ECAPA-TDNN](https://github.com/TaoRuijie/ECAPA-TDNN) | ||
|
||
Thanks for these authors to open source their code! | ||
|
||
未完待续... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
import torch | ||
import random | ||
import librosa | ||
import numpy as np | ||
import librosa.display | ||
from scipy.signal import medfilt | ||
import matplotlib.pyplot as plt | ||
# import torchaudio.transforms as T | ||
|
||
|
||
path = '.\\voices' | ||
name = 'a001.wav' | ||
audio_filename = ".\\data\\test\\G2231\\T0055G2231S0076.wav" | ||
|
||
|
||
def noise_augmentation(samples, min_db=40, max_db=80): | ||
samples = samples.copy() # frombuffer()导致数据不可更改因此使用拷贝 | ||
data_type = samples[0].dtype | ||
db = np.random.randint(low=min_db, high=max_db) | ||
db *= 1e-6 | ||
noise = db * np.random.normal(0, 1, len(samples)) # 高斯分布 | ||
# print(db) | ||
samples = samples + noise | ||
samples = samples.astype(data_type) | ||
return samples | ||
|
||
|
||
def add_noise(x, snr, method='vectorized', axis=0): | ||
# Signal power | ||
if method == 'vectorized': | ||
N = x.size | ||
Ps = np.sum(x ** 2 / N) | ||
elif method == 'max_en': | ||
N = x.shape[axis] | ||
Ps = np.max(np.sum(x ** 2 / N, axis=axis)) | ||
elif method == 'axial': | ||
N = x.shape[axis] | ||
Ps = np.sum(x ** 2 / N, axis=axis) | ||
else: | ||
raise ValueError('method \"' + str(method) + '\" not recognized.') | ||
|
||
Psdb = 10 * np.log10(Ps) # Signal power, in dB | ||
Pn = Psdb - snr # Noise level necessary | ||
n = np.sqrt(10 ** (Pn / 10)) * np.random.normal(0, 1, x.shape) # Noise vector (or matrix) | ||
return x + n | ||
|
||
|
||
def load_spectrogram(filename): | ||
wav, fs = librosa.load(filename, sr=16000) | ||
mag = librosa.feature.melspectrogram(y=wav, sr=16000, n_fft=512, n_mels=80, | ||
win_length=400, hop_length=160) | ||
mag = librosa.power_to_db(mag, ref=1.0, amin=1e-10, top_db=None) | ||
librosa.display.specshow(mag, sr=16000, x_axis='time', y_axis='mel') # 画mel谱图 | ||
plt.show() | ||
|
||
return mag | ||
|
||
|
||
def audio_to_wav(filename, sr=16000, noise=False): | ||
wav, fs = librosa.load(filename, sr=sr) | ||
|
||
# wav1 = load_spectrogram(wav) | ||
# t = T.MelSpectrogram(sample_rate=16000, n_fft=512, win_length=400, hop_length=160, | ||
# f_min=20, f_max=7600, window_fn=torch.hamming_window, n_mels=80) | ||
# wav2 = torch.from_numpy(wav) | ||
# wav2 = t(wav2) | ||
|
||
extended_wav = np.append(wav, wav) | ||
if len(extended_wav) < 41000: | ||
extended_wav = np.append(extended_wav, wav) | ||
if noise: | ||
extended_wav = add_noise(extended_wav, fs) | ||
return extended_wav, fs | ||
|
||
|
||
def loadWAV(filename, noise=False): | ||
y, sr = audio_to_wav(filename=filename, noise=noise) | ||
assert len(y) >= 41000, f'Error: file {filename}\n' | ||
num = random.randint(0, len(y) - 41000) | ||
y = y[num:num + 41000] | ||
y = torch.from_numpy(y).float() | ||
return y | ||
|
||
|
||
def load_pure_wav(filename, frame_threshold=10, noise=False): | ||
y, sr = audio_to_wav(filename=filename, noise=noise) | ||
mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=24, win_length=1024, hop_length=512, n_fft=1024) | ||
Mfcc1 = medfilt(mfcc[0, :], 9) # 对mfcc进行中值滤波 | ||
pic = Mfcc1 | ||
start = 0 | ||
end = 0 | ||
points = [] | ||
min_data = min(pic) * 0.9 | ||
for i in range((pic.shape[0])): | ||
if pic[i] < min_data and start == 0: | ||
start = i | ||
if pic[i] < min_data and start != 0: | ||
end = i | ||
elif pic[i] > min_data and start != 0: | ||
hh = [start, end] | ||
points.append(hh) | ||
start = 0 | ||
if pic[-1] < min_data and start != 0: # 解决 文件的最后为静音 | ||
hh = [start, end] | ||
points.append(hh) | ||
distances = [] | ||
for i in range(len(points)): | ||
two_ends = points[i] | ||
distance = two_ends[1] - two_ends[0] | ||
if distance > frame_threshold: | ||
distances.append(points[i]) | ||
|
||
# out, _ = soundfile.read(filename) | ||
# out = out.astype(np.float32) | ||
if len(distances) == 0: # 无静音段 | ||
return y | ||
else: | ||
silence_data = [] | ||
for i in range(len(distances)): | ||
if i == 0: | ||
start, end = distances[i] | ||
if start == 1: | ||
internal_clean = y[0:0] | ||
else: | ||
start = (start - 1) * 512 # 求取开始帧的开头 | ||
# end = (end - 1) * 512 + 1024 | ||
internal_clean = y[0:start - 1] | ||
else: | ||
_, end = distances[i - 1] | ||
start, _ = distances[i] | ||
start = (start - 1) * 512 | ||
end = (end - 1) * 512 + 1024 # 求取结束帧的结尾 | ||
internal_clean = y[end + 1:start] | ||
# hhh = np.array(internal_clean) | ||
silence_data.extend(internal_clean) | ||
ll = len(distances) # 结尾音频处理 | ||
_, end = distances[ll - 1] | ||
end = (end - 1) * 512 + 1024 | ||
end_part_clean = y[end:len(y)] | ||
silence_data.extend(end_part_clean) | ||
y = silence_data | ||
y = torch.from_numpy(np.array(y)).float() | ||
return y | ||
|
||
|
||
if __name__ == '__main__': | ||
a = load_pure_wav(audio_filename, noise=True) | ||
print(a.shape, a.dtype) | ||
_ = load_spectrogram(audio_filename) | ||
# a = np.array([[[-11, -10, -9, -8], | ||
# [-7, -6, -5, -4], | ||
# [-3, -2, -1, 0]], | ||
# [[1, 2, 3, 4], | ||
# [5, 6, 7, 8], | ||
# [9, 10, 11, 12]]]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import torch | ||
import time | ||
from torch import nn | ||
# from d2l import torch as d2l | ||
|
||
|
||
class Timer: | ||
def __init__(self): | ||
self.times = [] | ||
self.tik = None | ||
self.start() | ||
|
||
def start(self): | ||
self.tik = time.time() | ||
|
||
def stop(self): | ||
self.times.append(time.time() - self.tik) | ||
return self.times[-1] | ||
|
||
def avg(self): | ||
return sum(self.times) / len(self.times) | ||
|
||
def sum(self): | ||
return sum(self.times) | ||
|
||
|
||
class Accumulator: | ||
def __init__(self, n): | ||
self.data = [0.0] * n | ||
|
||
def add(self, *args): | ||
self.data = [a + float(b) for a, b in zip(self.data, args)] | ||
|
||
def reset(self): | ||
self.data = [0.0] * len(self.data) | ||
|
||
def __getitem__(self, idx): | ||
return self.data[idx] | ||
|
||
|
||
def try_gpu(i=0): | ||
if torch.cuda.device_count() >= i + 1: | ||
return torch.device(f'cuda:{i}') | ||
return torch.device('cpu') | ||
|
||
|
||
def accuracy(y_hat, y): | ||
if len(y_hat.shape) > 1 and y_hat.shape[1] > 1: | ||
y_hat = y_hat.argmax(axis=1) | ||
astype = lambda x, *args, **kwargs: x.type(*args, **kwargs) | ||
cmp = astype(y_hat, y.dtype) == y | ||
reduce_sum = lambda x, *args, **kwargs: x.sum(*args, **kwargs) | ||
return float(reduce_sum(astype(cmp, y.dtype))) | ||
|
||
|
||
def evaluate_accuracy_gpu(net, data_iter, device=None): | ||
if isinstance(net, nn.Module): | ||
net.eval() | ||
if not device: | ||
device = next(iter(net.parameters())).device | ||
metric = Accumulator(2) | ||
|
||
with torch.no_grad(): | ||
for X, y in data_iter: | ||
if isinstance(X, list): | ||
X = [x.to(device) for x in X] | ||
else: | ||
X = X.to(device) | ||
y = y.to(device) | ||
size = lambda x, *args, **kwargs: x.numel(*args, **kwargs) | ||
metric.add(accuracy(net(X), y), size(y)) | ||
return metric[0] / metric[1] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import torch | ||
import loader | ||
import train as t | ||
import eval as d2l | ||
# import torch_directml | ||
from loss import AAMSoftmax | ||
# from d2l import torch as d2l | ||
from tensorboardX import SummaryWriter | ||
from torch.utils.data import DataLoader | ||
from models.tdnn_pretrain import Pretrain_TDNN | ||
|
||
|
||
def load_model(path, output_num, device, not_grad=False): | ||
load_net = torch.load(path, map_location=device) | ||
model = Pretrain_TDNN(output_num, 1024, output_embedding=False, not_grad=not_grad) | ||
model.speaker_encoder = load_net.speaker_encoder | ||
del load_net | ||
return model | ||
|
||
|
||
if __name__ == "__main__": | ||
people_num, data_per_people = 420, 10 | ||
noise, mel, reverse = False, True, False | ||
margin, scale, easy_margin = 0.2, 20, False | ||
num_epochs, learn_rate, weight_decay = 40, 0.1, 1e-3 | ||
learn_rate_period, learn_rate_decay = 10, 0.95 | ||
mode, model_name = "train", "resnet18" | ||
hidden_size, num_layers = 64, 2 | ||
|
||
# Device = torch_directml.device() | ||
# prefetch_factor, batch_size, num_works, persistent = 2, 32, 8, False | ||
|
||
Device = d2l.try_gpu() | ||
if Device.type == 'cpu': | ||
prefetch_factor, batch_size, num_works, persistent = 2, 8, 8, False | ||
elif torch.cuda.is_available(): | ||
prefetch_factor, batch_size, num_works, persistent = 8, 256, 32, True | ||
else: | ||
prefetch_factor, batch_size, num_works, persistent = 2, 32, 8, False | ||
|
||
t.init_logs() | ||
train_dict, test_dict, people_num = loader.load_files(mode=mode, folder_num=people_num, | ||
file_num=data_per_people, k=1) | ||
train_dataset = loader.MyDataset(data_dict=train_dict, people_num=people_num, train=True, | ||
mel=mel, noise=noise) | ||
test_dataset = loader.MyDataset(data_dict=test_dict, people_num=people_num, train=False, | ||
mel=mel, noise=noise) | ||
print(len(train_dataset), len(test_dataset)) | ||
train_ = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, | ||
drop_last=True, num_workers=num_works, pin_memory=True, | ||
persistent_workers=persistent, prefetch_factor=prefetch_factor) | ||
test_ = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True, | ||
drop_last=True, num_workers=num_works, pin_memory=True, | ||
persistent_workers=persistent, prefetch_factor=prefetch_factor) | ||
|
||
# pth_path = 'test.pth' | ||
# model2 = load_model(pth_path, people_num, Device, not_grad=True) | ||
|
||
model2 = Pretrain_TDNN(people_num, 1024, output_embedding=False, not_grad=False) | ||
model2.load_parameters('param.model', Device) | ||
|
||
loss = AAMSoftmax(192, people_num, margin, scale, easy_margin) | ||
writer = SummaryWriter('./logs') | ||
t.train(train_, test_, model2, loss, Device, writer, num_epochs, learn_rate, weight_decay) | ||
model2.save_parameters('param2.model') |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Oops, something went wrong.