Skip to content

Commit

Permalink
Support resampling and CutSet.save_audios when torchaudio is missing (
Browse files Browse the repository at this point in the history
#1255)

* Fix in CutSet.save_audios()

* Support resampling when torchaudio is missing using scipy

* Add scipy installation to CI for missing torchaudio tests
  • Loading branch information
pzelasko authored Jan 4, 2024
1 parent 110e067 commit b3373c0
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/missing_torchaudio.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
- name: Install Python dependencies
run: |
python -m pip install --upgrade pip
pip install wheel numpy
pip install wheel numpy scipy
# Force the installation of a CPU-only PyTorch
${{ matrix.torch-install-cmd }}
# the torchaudio env var does nothing when torchaudio is installed, but doesn't require it's presence when it's not
Expand Down
3 changes: 1 addition & 2 deletions lhotse/audio/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1178,8 +1178,7 @@ def save_flac_file(
kwargs.pop("bits_per_sample", None) # ignore this arg when not using torchaudio
if torch.is_tensor(src):
src = src.numpy()
src = src.squeeze(0)
sf.write(file=dest, data=src, samplerate=sample_rate, format="FLAC")
sf.write(file=dest, data=src.T, samplerate=sample_rate, format="FLAC")


def torchaudio_save_flac_safe(
Expand Down
2 changes: 1 addition & 1 deletion lhotse/audio/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def move_to_memory(
save_flac_file(
stream, torch.from_numpy(audio), self.sampling_rate, format=format
)
channels = (ifnone(channels, self.channel_ids),)
channels = ifnone(channels, self.channel_ids)
if isinstance(channels, int):
channels = [channels]
return Recording(
Expand Down
35 changes: 26 additions & 9 deletions lhotse/augmentation/torchaudio.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings
from dataclasses import dataclass
from decimal import ROUND_HALF_UP
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
Expand All @@ -11,6 +11,7 @@
Seconds,
compute_num_samples,
during_docs_build,
is_module_available,
is_torchaudio_available,
perturb_num_samples,
)
Expand Down Expand Up @@ -181,19 +182,35 @@ class Resample(AudioTransform):
def __post_init__(self):
self.source_sampling_rate = int(self.source_sampling_rate)
self.target_sampling_rate = int(self.target_sampling_rate)
self.resampler = get_or_create_resampler(
self.source_sampling_rate, self.target_sampling_rate
)
if not is_torchaudio_available():
assert is_module_available(
"scipy"
), "In order to use resampling, either torchaudio or scipy needs to be installed."
else:
self.resampler = get_or_create_resampler(
self.source_sampling_rate, self.target_sampling_rate
)

def __call__(self, samples: np.ndarray, *args, **kwargs) -> np.ndarray:
check_for_torchaudio()
if self.source_sampling_rate == self.target_sampling_rate:
return samples

if isinstance(samples, np.ndarray):
samples = torch.from_numpy(samples)
augmented = self.resampler(samples)
return augmented.numpy()
if is_torchaudio_available():
if isinstance(samples, np.ndarray):
samples = torch.from_numpy(samples)
augmented = self.resampler(samples)
return augmented.numpy()
else:
import scipy

gcd = np.gcd(self.source_sampling_rate, self.target_sampling_rate)
augmented = scipy.signal.resample_poly(
samples,
up=self.target_sampling_rate // gcd,
down=self.source_sampling_rate // gcd,
axis=-1,
)
return augmented

def reverse_timestamps(
self, offset: Seconds, duration: Optional[Seconds], sampling_rate: int
Expand Down
2 changes: 1 addition & 1 deletion lhotse/cut/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,7 @@ def save_audio(
else:
import soundfile as sf

sf.write(str(storage_path), samples, samplerate=self.sampling_rate)
sf.write(str(storage_path), samples.T, samplerate=self.sampling_rate)
recording = Recording(
id=storage_path.stem,
sampling_rate=self.sampling_rate,
Expand Down
22 changes: 22 additions & 0 deletions test/test_missing_torchaudio.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@ def test_lhotse_load_audio():
assert isinstance(audio, np.ndarray)


@notorchaudio
@pytest.mark.parametrize("sr", [8000, 16000, 22500, 24000, 44100])
def test_lhotse_resample(sr):
import lhotse

cuts = lhotse.CutSet.from_file("test/fixtures/libri/cuts.json")
cut = cuts[0]
cut = cut.resample(sr)
audio = cut.load_audio()
assert isinstance(audio, np.ndarray)
assert audio.shape == (1, cut.num_samples)


@notorchaudio
def test_lhotse_audio_in_memory():
import lhotse
Expand All @@ -50,6 +63,15 @@ def test_lhotse_audio_in_memory():
assert isinstance(audio, np.ndarray)


@notorchaudio
@pytest.mark.parametrize("fmt", ["wav", "flac"])
def test_lhotse_save_audios(tmp_path, fmt):
import lhotse

cuts = lhotse.CutSet.from_file("test/fixtures/libri/cuts.json")
cuts.save_audios(tmp_path, format=fmt)


@notorchaudio
def test_create_dummy_recording():
from lhotse.testing.dummies import dummy_recording
Expand Down

0 comments on commit b3373c0

Please sign in to comment.