Skip to content

Commit

Permalink
Merge pull request #99 from THasthika/master
Browse files Browse the repository at this point in the history
Fixed torch.rfft compatibility issue
  • Loading branch information
KinWaiCheuk authored Jun 29, 2021
2 parents bf7a639 + d63f0f5 commit 61e03b2
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 6 deletions.
30 changes: 30 additions & 0 deletions .github/workflows/nnAudio.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
name: nnAudio
on:
push:
branches:
- main
- master
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.6, 3.7, 3.8]
torch-version: [1.6.0, 1.7.0, 1.8.0]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies pytorch version == ${{ matrix.torch-version }}
run: |
sudo apt-get install libsndfile1-dev
python -m pip install --upgrade pip
pip install pytest
pip install librosa
pip install torch==${{matrix.torch-version}}
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest
run: |
cd Installation/ && pytest
29 changes: 24 additions & 5 deletions Installation/nnAudio/Spectrogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

# 0.2.0
import sys

import torch
import torch.nn as nn
Expand All @@ -21,6 +22,24 @@
sz_float = 4 # size of a float
epsilon = 10e-8 # fudge factor for normalization

# Acquires and parses the PyTorch version
__TORCH_GTE_1_7 = False
split_version = torch.__version__.split('.')
major_version = int(split_version[0])
minor_version = int(split_version[1])
if major_version > 1 or (major_version == 1 and minor_version >= 7):
__TORCH_GTE_1_7 = True
import torch.fft
if "torch.fft" not in sys.modules:
raise RuntimeError("torch.fft module available but not imported")

def rfft_fn(x, n=None, onesided=False):
if __TORCH_GTE_1_7:
y = torch.fft.fft(x)
return torch.view_as_real(y)
else:
return torch.rfft(x, n, onesided=onesided)

### --------------------------- Spectrogram Classes ---------------------------###
class STFT(torch.nn.Module):
"""This function is to calculate the short-time Fourier transform (STFT) of the input signal.
Expand Down Expand Up @@ -563,7 +582,7 @@ def _dct(self, x, norm=None):
N = x_shape[-1]

v = torch.cat([x[:, :, ::2], x[:, :, 1::2].flip([2])], dim=2)
Vc = torch.rfft(v, 1, onesided=False)
Vc = rfft_fn(v, 1, onesided=False)

# TODO: Can make the W_r and W_i trainable here
k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
Expand Down Expand Up @@ -2251,10 +2270,10 @@ def _CFP(self, spec):
if self.NumofLayer >= 2:
for gc in range(1, self.NumofLayer):
if np.remainder(gc, 2) == 1:
ceps = torch.rfft(spec, 1, onesided=False)[:,:,:,0]/np.sqrt(self.N)
ceps = rfft_fn(spec, 1, onesided=False)[:,:,:,0]/np.sqrt(self.N)
ceps = self.nonlinear_func(ceps, self.g[gc], self.tc_idx)
else:
spec = torch.rfft(ceps, 1, onesided=False)[:,:,:,0]/np.sqrt(self.N)
spec = rfft_fn(ceps, 1, onesided=False)[:,:,:,0]/np.sqrt(self.N)
spec = self.nonlinear_func(spec, self.g[gc], self.fc_idx)

return spec, ceps
Expand Down Expand Up @@ -2440,10 +2459,10 @@ def _CFP(self, spec):
if self.NumofLayer >= 2:
for gc in range(1, self.NumofLayer):
if np.remainder(gc, 2) == 1:
ceps = torch.rfft(spec, 1, onesided=False)[:,:,:,0]/np.sqrt(self.N)
ceps = rfft_fn(spec, 1, onesided=False)[:,:,:,0]/np.sqrt(self.N)
ceps = self.nonlinear_func(ceps, self.g[gc], self.tc_idx)
else:
spec = torch.rfft(ceps, 1, onesided=False)[:,:,:,0]/np.sqrt(self.N)
spec = rfft_fn(ceps, 1, onesided=False)[:,:,:,0]/np.sqrt(self.N)
spec = self.nonlinear_func(spec, self.g[gc], self.fc_idx)

return spec, ceps
Expand Down
1 change: 0 additions & 1 deletion Installation/tests/test_spectrogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
# librosa example audio for testing
example_y, example_sr = librosa.load(librosa.util.example_audio_file())


@pytest.mark.parametrize("n_fft, hop_length, window", stft_parameters)
@pytest.mark.parametrize("device", [*device_args])
def test_inverse2(n_fft, hop_length, window, device):
Expand Down

0 comments on commit 61e03b2

Please sign in to comment.