-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathpreprocess.py
119 lines (102 loc) · 5.04 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
from torchaudio.functional import magphase
from functools import partial
import torch
class OnlinePreprocessor(torch.nn.Module):
    """On-the-fly STFT feature extractor with a matching inverse STFT.

    ``forward`` turns multi-channel waveforms into a list of per-channel
    spectral features ('complx', 'linear' magnitude or 'phase'), each
    optionally log-scaled. ``istft`` turns such features back into waveforms.

    Args:
        sample_rate: Length (in samples) of the random pseudo waveform stored
            as a buffer. It is used when ``forward``/``test_istft`` get no input.
        win_len: Hann window length in samples.
        hop_len: Hop size between frames in samples.
        n_freq: Number of frequency bins; ``n_fft = (n_freq - 1) * 2``.
        feat_list: Default list of feature configs (see ``get_feat_config``),
            used when ``forward`` is called without one.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, sample_rate=16000, win_len=512, hop_len=256, n_freq=257,
                 feat_list=None, **kwargs):
        super().__init__()
        n_fft = (n_freq - 1) * 2
        self._win_args = {'n_fft': n_fft, 'hop_length': hop_len, 'win_length': win_len}
        self.register_buffer('_window', torch.hann_window(win_len))
        # Same defaults as torchaudio.transforms.Spectrogram & librosa _spectrogram.
        # return_complex=True: omitting it for real input is an error in
        # PyTorch >= 2.0 (real-valued STFT output was deprecated in 1.8).
        self._stft_args = {'center': True, 'pad_mode': 'reflect',
                           'normalized': False, 'onesided': True,
                           'return_complex': True}
        self._istft_args = {'center': True, 'normalized': False, 'onesided': True}
        self._stft = partial(torch.stft, **self._win_args, **self._stft_args)
        self._istft = partial(torch.istft, **self._win_args, **self._istft_args)
        self.feat_list = feat_list
        # batch_size=2, channel_size=2
        self.register_buffer('_pseudo_wav', torch.randn(2, 2, sample_rate))

    def _check_list(self, feat_list):
        """Return ``feat_list``, or the default ``self.feat_list`` when it is None."""
        if feat_list is None:
            feat_list = self.feat_list
        assert isinstance(feat_list, list)
        return feat_list

    def _transpose_list(self, feats):
        """Swap the last two dims of every tensor in ``feats``; pass other items through."""
        return [feat.transpose(-1, -2).contiguous() if isinstance(feat, torch.Tensor) else feat
                for feat in feats]

    def _pseudo_wavs(self, n_channels):
        """Return the stored pseudo waveform with exactly ``n_channels`` channels."""
        pseudo = self._pseudo_wav
        if n_channels > pseudo.size(1):
            # Tile the stored channels until there are enough of them.
            repeats = -(-n_channels // pseudo.size(1))
            pseudo = pseudo.repeat(1, repeats, 1)
        return pseudo[:, :n_channels]

    @classmethod
    def get_feat_config(cls, feat_type, channel=0, log=False):
        """Build one feature config dict for ``forward``'s ``feat_list``.

        Args:
            feat_type: One of 'complx', 'linear', 'phase'.
            channel: Index of the input channel to extract.
            log: Whether to apply ``log(x + 1e-10)`` to the feature.
        """
        assert feat_type in ['complx', 'linear', 'phase']
        assert type(channel) is int
        assert type(log) is bool
        return {
            'feat_type': feat_type,
            'channel': channel,
            'log': log,
        }

    def forward(self, wavs=None, feat_list=None):
        """Extract the spectral features described by ``feat_list``.

        Args:
            wavs: Waveforms shaped ``(*, channel_size, max_len)`` (at least 3-D).
                When None, the stored pseudo waveform is used, with enough
                channels for the highest ``channel`` requested.
            feat_list: List of feature configs (see ``get_feat_config``);
                defaults to ``self.feat_list``.

        Returns:
            One tensor per config, shaped ``(*, max_len, feat_dim)``. Here
            ``max_len`` is the number of frames and ``feat_dim`` is ``n_freq``.
            For 'complx', ``feat_dim`` is ``2 * n_freq``, with real and
            imaginary parts interleaved per bin.
        """
        feat_list = self._check_list(feat_list)
        if wavs is None:
            max_channel_id = max((int(args.get('channel', 0)) for args in feat_list), default=0)
            wavs = self._pseudo_wavs(max_channel_id + 1)
        assert wavs.dim() >= 3
        shape = wavs.size()
        spec = self._stft(wavs.reshape(-1, shape[-1]), window=self._window)
        spec = spec.reshape(shape[:-1] + spec.shape[-2:])
        # spec: (*, channel_size, feat_dim, max_len), complex-valued
        linear = spec.abs()
        phase = spec.angle()
        # (*, channel_size, feat_dim, max_len, 2) -> (*, channel_size, feat_dim * 2, max_len)
        complx = torch.view_as_real(spec).transpose(-1, -2).reshape(
            *linear.shape[:-2], -1, linear.size(-1))
        variables = {'complx': complx, 'linear': linear, 'phase': phase}

        def select_feat(feat_type, channel=0, log=False):
            # -> (*, feat_dim, max_len)
            raw_feat = variables[feat_type].select(dim=-3, index=channel)
            if bool(log):
                # Note: 'complx' and 'phase' can be negative, which gives NaN here.
                raw_feat = (raw_feat + 1e-10).log()
            return raw_feat.contiguous()

        return self._transpose_list([select_feat(**args) for args in feat_list])

    def istft(self, linears=None, phases=None, linear_power=2, complxs=None, length=None):
        """Reconstruct waveforms from features produced by ``forward``.

        Pass either ``complxs`` or both ``linears`` and ``phases``.

        Args:
            linears: Magnitudes shaped ``(*, max_feat_len, n_freq)``. They are
                squared and then raised to ``1 / linear_power``, so the default
                ``linear_power=2`` uses them unchanged as magnitudes.
            phases: Phases in radians, shaped ``(*, max_feat_len, n_freq)``.
            linear_power: See ``linears``.
            complxs: Shaped ``(*, n_freq, max_feat_len, 2)`` or
                ``(*, max_feat_len, n_freq * 2)``. The second form is the
                interleaved layout returned by ``forward``.
            length: Optional output length, forwarded to ``torch.istft``.

        Returns:
            Waveforms shaped ``(*, max_wav_len)``.
        """
        assert complxs is not None or (linears is not None and phases is not None)
        if complxs is None:
            linears = linears.pow(2)
            linears, phases = self._transpose_list([linears, phases])
            complxs = linears.pow(1 / linear_power).unsqueeze(-1) * \
                torch.stack([phases.cos(), phases.sin()], dim=-1)
        if complxs.size(-1) != 2:
            # treat complxs as: (*, max_feat_len, n_freq * 2)
            shape = complxs.size()
            complxs = complxs.reshape(*shape[:-1], -1, 2).transpose(-2, -3)
        # complxs: (*, n_freq, max_feat_len, 2); torch.istft needs a complex tensor.
        spec = torch.view_as_complex(complxs.contiguous())
        lead_shape = spec.shape[:-2]
        # torch.istft accepts at most one batch dim, so flatten the leading dims.
        wavs = self._istft(spec.reshape(-1, *spec.shape[-2:]), window=self._window, length=length)
        return wavs.reshape(*lead_shape, wavs.size(-1))

    def test_istft(self, wavs=None, epsilon=1e-6):
        """Check that stft -> istft reconstructs ``wavs`` within ``epsilon``.

        Channel 0 round-trips through the 'complx' features. Channel 1
        round-trips through 'linear' + 'phase'. ``wavs`` is shaped
        ``(*, channel_size, max_wav_len)`` and defaults to the stored pseudo
        waveform.
        """
        if wavs is None:
            wavs = self._pseudo_wav
        channel1, channel2 = 0, 1
        feat_list = [
            {'feat_type': 'complx', 'channel': channel1},
            {'feat_type': 'linear', 'channel': channel2},
            {'feat_type': 'phase', 'channel': channel2}
        ]
        complxs, linears, phases = self.forward(wavs, feat_list)
        # Reconstruct to the actual input length, not a hard-coded 16000.
        length = wavs.size(-1)
        assert torch.allclose(wavs.select(dim=-2, index=channel1),
                              self.istft(complxs=complxs, length=length), atol=epsilon)
        assert torch.allclose(wavs.select(dim=-2, index=channel2),
                              self.istft(linears=linears, phases=phases, length=length),
                              atol=epsilon)
        print('[Test passed] stft -> istft')