diff --git a/.github/workflows/build_and_test.yaml b/.github/workflows/build_and_test.yaml
new file mode 100644
index 0000000..66609be
--- /dev/null
+++ b/.github/workflows/build_and_test.yaml
@@ -0,0 +1,38 @@
+# Copyright (C) 2023 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: MIT
+
+name: Build and Test
+
+on:
+  pull_request:
+  push:
+    branches:
+      - '**'
+    tags-ignore:
+      - '**'
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout repo
+        uses: actions/checkout@v3
+
+      - name: Set up Python 3.9
+        uses: actions/setup-python@v4
+        with:
+          python-version: 3.9
+          cache: 'pip'
+          cache-dependency-path: 'requirements.txt'
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install pytest
+
+      - name: Run unit tests
+        run: |
+          python -m pytest test
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 50c9ab2..a324d98 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -35,7 +35,7 @@ repos:
           - --line-length=119
 
   - repo: https://github.com/pycqa/isort
-    rev: 5.10.1
+    rev: 5.12.0
    hooks:
      - id: isort
        args: ["--profile", "black", "--filter-files", "--line-length", "119", "--skip-gitignore"]
diff --git a/.reuse/dep5 b/.reuse/dep5
index 28b2b0a..bd49e06 100644
--- a/.reuse/dep5
+++ b/.reuse/dep5
@@ -1,5 +1,7 @@
 Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/
 
-Files: checkpoints/default_mrx_pre_trained_weights.pth
+Files:
+  checkpoints/default_mrx_pre_trained_weights.pth
+  checkpoints/paper_mrx_pre_trained_weights.pth
 Copyright: 2023 Mitsubishi Electric Research Laboratories (MERL)
 License: MIT
diff --git a/README.md b/README.md
index 311e150..083ec10 100644
--- a/README.md
+++ b/README.md
@@ -38,6 +38,8 @@ The code has been tested using `python 3.9`. Necessary dependencies can be insta
 pip install -r requirements.txt
 ```
 
+If you prefer to use the [torchaudio soundfile backend](https://pytorch.org/audio/stable/backend.html) (required on Windows), please refer to the [SoundFile documentation](https://pysoundfile.readthedocs.io/en/latest/) for installation instructions.
+
 Please modify pytorch installation depending on your particular CUDA version if necessary.
 
 ## Using a pre-trained model
@@ -70,6 +72,26 @@ my_model.load_state_dict(state_dict)
 enhanced_dict = separate.separate_soundtrack(audio_tensor, separation_model=my_model, ...)
 ```
 
+We include two pre-trained models in the `checkpoints` directory:
+1. `default_mrx_pre_trained_weights.pth`: This is the model trained using the default arguments from [`lightning_train.py`](./lightning_train.py), except the training loss is SNR (`--loss snr`). This ensures that the level of the output signals matches the mixture.
+2. `paper_mrx_pre_trained_weights.pth`: This is the model trained using the default arguments from [`lightning_train.py`](./lightning_train.py), including the scale-invariant SNR loss function, which reproduces the results from our ICASSP paper.
+However, due to the scale-invariant training, the level of the output signals will not match the mixture.
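+
+For example, here is a minimal sketch of loading the paper weights (assuming `audio_tensor` is a `[channels, samples]` tensor at 44.1 kHz that you have already loaded):
+```python
+import torch
+
+import separate
+from mrx import MRX
+
+# Illustrative: because of the scale-invariant training, the output levels of
+# this checkpoint are not guaranteed to match the input mixture.
+model = MRX()
+state_dict = torch.load("checkpoints/paper_mrx_pre_trained_weights.pth", map_location="cpu")
+model.load_state_dict(state_dict)
+enhanced_dict = separate.separate_soundtrack(audio_tensor, separation_model=model)
+```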
+
 ## Training a model on the Divide and Remaster Dataset
 
 If you haven't already, you will first need to download the [Divide and Remaster (DnR) Dataset.](https://zenodo.org/record/6949108#.Y861fOLMKrN)
diff --git a/checkpoints/default_mrx_pre_trained_weights.pth b/checkpoints/default_mrx_pre_trained_weights.pth
index 378c08b..4db61d4 100644
--- a/checkpoints/default_mrx_pre_trained_weights.pth
+++ b/checkpoints/default_mrx_pre_trained_weights.pth
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:e206fe97739a1fd5ba66d2a4a1e3f68c16f0e3eb8cb7b25ec93d53e7b429f58f
-size 122316769
+oid sha256:9cecc071c9e58afda34d1a98609a1540f61320e718bafecc4c9c1c5e3ed20a0c
+size 122317463
diff --git a/checkpoints/paper_mrx_pre_trained_weights.pth b/checkpoints/paper_mrx_pre_trained_weights.pth
new file mode 100644
index 0000000..378c08b
--- /dev/null
+++ b/checkpoints/paper_mrx_pre_trained_weights.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e206fe97739a1fd5ba66d2a4a1e3f68c16f0e3eb8cb7b25ec93d53e7b429f58f
+size 122316769
diff --git a/consistency.py b/consistency.py
new file mode 100644
index 0000000..1c3efe2
--- /dev/null
+++ b/consistency.py
@@ -0,0 +1,63 @@
+# Copyright (C) 2023 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: MIT
+
+from typing import List, Optional
+
+import torch
+
+from dnr_dataset import SOURCE_NAMES
+
+
+def mixture_consistency(
+    mixture: torch.Tensor,
+    estimated_sources: torch.Tensor,
+    source_dim: int,
+    source_weights: Optional[List[float]] = None,
+) -> torch.Tensor:
+    """
+    Postprocessing for adding the residual between the mixture and the estimated sources back to the estimates.
+
+    :param mixture (torch.tensor): audio mixture signal
+    :param estimated_sources (torch.tensor): estimated separated source signals with added source dimension
+    :param source_dim (int): dimension of sources in estimated source tensor
+    :param source_weights (list): the weights for each source. Length must match the source_dim of estimated sources
+    :return: (torch.tensor) estimated sources with the weighted residual added back
+    """
+    if source_weights is None:
+        n_sources = estimated_sources.shape[source_dim]
+        source_weights = [1 / n_sources] * n_sources
+    source_weights = torch.tensor(source_weights).to(estimated_sources)
+    source_weights = source_weights / source_weights.sum()
+    n_trailing_dims = len(estimated_sources.shape[source_dim + 1 :])
+    source_weights = source_weights.reshape(source_weights.shape + (1,) * n_trailing_dims)
+    res = mixture - estimated_sources.sum(source_dim)
+    new_source_signals = estimated_sources + source_weights * res.unsqueeze(source_dim)
+    return new_source_signals
+
+
+def dnr_consistency(mixture: torch.Tensor, estimated_sources: torch.Tensor, mode: str = "pass") -> torch.Tensor:
+    """
+    Postprocessing for adding residual between mixture and estimated sources back to estimated sources.
+
+    :param mixture (torch.tensor): 3D Tensor with shape [batch, channels, samples]
+        or 2D Tensor of shape [channels, samples]
+    :param estimated_sources (torch.tensor): 4D Tensor with shape [batch, num_sources, channels, samples]
+        or 3D Tensor of shape [num_sources, channels, samples]
+    :param mode (str): choices=["all", "pass", "music_sfx"],
+        Whether to add the residual to estimates, 'pass' doesn't add residual, 'all' splits residual
+        among all sources, 'music_sfx' splits residual among music and sfx sources. (default: pass)
+    :return: Tensor of same shape as estimated sources
+    """
+    input_ndim = mixture.ndim
+    if input_ndim > 2:  # we have a batch dimension
+        source_dim = 1
+    else:
+        source_dim = 0
+    if mode == "all":
+        return mixture_consistency(mixture, estimated_sources, source_dim)
+    elif mode == "music_sfx":
+        source_weights = [0 if src == "speech" else 0.5 for src in SOURCE_NAMES]
+        return mixture_consistency(mixture, estimated_sources, source_dim, source_weights)
+    else:
+        return estimated_sources
diff --git a/dnr_dataset.py b/dnr_dataset.py
index 7ee012a..9c4da3d 100644
--- a/dnr_dataset.py
+++ b/dnr_dataset.py
@@ -39,6 +39,7 @@ def __init__(
         self.chunk_size = -1
         self.random_start = random_start
         self.track_list = self._get_tracklist()
+        self._check_subset_lengths(subset)
 
     def _get_tracklist(self) -> List[str]:
         path = Path(self.path)
@@ -51,6 +52,17 @@ def _get_tracklist(self) -> List[str]:
                 names.append(name)
         return sorted(names)
 
+    def _check_subset_lengths(self, subset: str):
+        """
+        Assert that the number of files is correct, to ensure we are using DnR v2 rather than an older version.
+        """
+        if subset == "tr":
+            assert len(self.track_list) == 3406, "Expected 3406 mixtures in training set"
+        elif subset == "cv":
+            assert len(self.track_list) == 487, "Expected 487 mixtures in validation set"
+        elif subset == "tt":
+            assert len(self.track_list) == 973, "Expected 973 mixtures in testing set"
+
     def _get_audio_path(self, track_name: str, source_name: str) -> Path:
         return Path(self.path) / track_name / f"{source_name}{EXT}"
 
diff --git a/eval.py b/eval.py
index 36c3a98..7d2713e 100644
--- a/eval.py
+++ b/eval.py
@@ -15,16 +15,16 @@
 from separate import DEFAULT_PRE_TRAINED_MODEL_PATH
 
 
-def _read_checkpoint(checkpoint_path):
+def _read_checkpoint(checkpoint_path, **kwargs):
     ckpt = torch.load(checkpoint_path, map_location="cpu")
     if "state_dict" in ckpt.keys():
         # lightning module checkpoint
-        model = CocktailForkModule.load_from_checkpoint(checkpoint_path)
+        model = CocktailForkModule.load_from_checkpoint(checkpoint_path, **kwargs)
     else:
         # only network weights
         model = MRX()
         model.load_state_dict(ckpt)
-        model = CocktailForkModule(model=model)
+        model = CocktailForkModule(model=model, **kwargs)
     return model
 
 
@@ -42,6 +42,14 @@ def _lightning_eval():
         help="Path to trained model weights. Can be a pytorch_lightning checkpoint or pytorch state_dict",
     )
     parser.add_argument("--gpu-device", default=-1, type=int, help="The gpu device for model inference. (default: -1)")
+    parser.add_argument(
+        "--mixture-residual",
+        default="pass",
+        type=str,
+        choices=["all", "pass", "music_sfx"],
+        help="Whether to add the residual to estimates, 'pass' doesn't add residual, 'all' splits residual among "
+        "all sources, 'music_sfx' splits residual among only music and sfx sources. (default: pass)",
+    )
     args = parser.parse_args()
 
     test_dataset = DivideAndRemaster(args.root_dir, "tt")
@@ -64,7 +72,8 @@ def _lightning_eval():
         accelerator=accelerator,
         enable_progress_bar=True,  # this will print the results to the command line
     )
-    model = _read_checkpoint(args.checkpoint)
+
+    model = _read_checkpoint(args.checkpoint, mixture_residual=args.mixture_residual)
     trainer.test(model, test_loader)
 
 
diff --git a/lightning_train.py b/lightning_train.py
index 9522552..f79e0e2 100644
--- a/lightning_train.py
+++ b/lightning_train.py
@@ -11,23 +11,27 @@
 from pytorch_lightning.callbacks import ModelCheckpoint
 from torch.utils.data import DataLoader
 
+from consistency import dnr_consistency
 from dnr_dataset import SOURCE_NAMES, DivideAndRemaster
 from mrx import MRX
-from si_snr import si_snr
+from snr import snr_loss
 
 
 class CocktailForkModule(LightningModule):
-    def __init__(self, model=None):
+    def __init__(self, model=None, si_loss=True, mixture_residual="pass"):
         super().__init__()
         if model is None:
             self.model = MRX()
         else:
             self.model = model
+        self.si_loss = si_loss
+        self.mixture_residual = mixture_residual
 
     def _step(self, batch, batch_idx, split):
         x, y, filenames = batch
         y_hat = self.model(x)
-        loss = si_snr(y_hat, y).mean()
+        y_hat = dnr_consistency(x, y_hat, mode=self.mixture_residual)
+        loss = snr_loss(y_hat, y, scale_invariant=self.si_loss).mean()
         self.log(f"{split}_loss", loss, on_step=True, on_epoch=True)
         return loss
 
@@ -40,20 +44,23 @@ def validation_step(self, batch, batch_idx):
     def test_step(self, batch, batch_idx):
         x, y, filenames = batch
         y_hat = self.model(x)
-        est_sdr = -si_snr(y_hat, y)
-        est_sdr = est_sdr.mean(-1).mean(0)  # average of batch and channel
+        y_hat = dnr_consistency(x, y_hat, mode=self.mixture_residual)
+        est_sisdr = -snr_loss(y_hat, y, scale_invariant=True).mean(-1).mean(0)  # average of batch and channel
+        est_snr = -snr_loss(y_hat, y, scale_invariant=False).mean(-1).mean(0)
         # expand mixture to shape of isolated sources for noisy SDR
         repeat_shape = len(y.shape) * [1]
         repeat_shape[1] = y.shape[1]
         x = x.unsqueeze(1).repeat(repeat_shape)
-        noisy_sdr = -si_snr(x, y)
-        noisy_sdr = noisy_sdr.mean(-1).mean(0)  # average of batch and channel
+        noisy_sisdr = -snr_loss(x, y, scale_invariant=True).mean(-1).mean(0)
+        noisy_snr = -snr_loss(x, y, scale_invariant=False).mean(-1).mean(0)
         result_dict = {}
         for i, src in enumerate(SOURCE_NAMES):
-            result_dict[f"noisy_{src}"] = noisy_sdr[i].item()
-            result_dict[f"est_{src}"] = est_sdr[i].item()
+            result_dict[f"noisy_sisdr_{src}"] = noisy_sisdr[i].item()
+            result_dict[f"est_sisdr_{src}"] = est_sisdr[i].item()
+            result_dict[f"noisy_snr_{src}"] = noisy_snr[i].item()
+            result_dict[f"est_snr_{src}"] = est_snr[i].item()
         self.log_dict(result_dict, on_epoch=True)
-        return est_sdr.mean()
+        return est_sisdr.mean()
 
     def configure_optimizers(self):
         optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
@@ -102,7 +109,7 @@ def cli_main():
     parser.add_argument(
         "--root-dir",
         type=Path,
-        help="The path to the directory where the directory ``Libri2Mix`` or ``Libri3Mix`` is stored.",
+        help="The path to the DnR directory containing the ``tr``, ``cv``, and ``tt`` directories.",
     )
     parser.add_argument(
         "--exp-dir", default=Path("./exp"), type=Path, help="The directory to save checkpoints and logs."
@@ -132,10 +139,25 @@ def cli_main():
         type=int,
         help="The number of workers for dataloader. (default: 4)",
     )
 
+    parser.add_argument(
+        "--loss",
+        default="si_snr",
+        type=str,
+        choices=["si_snr", "snr"],
+        help="The loss function for network training, either snr or si_snr. (default: si_snr)",
+    )
+    parser.add_argument(
+        "--mixture-residual",
+        default="pass",
+        type=str,
+        choices=["all", "pass", "music_sfx"],
+        help="Whether to add the residual to estimates, 'pass' doesn't add residual, 'all' splits residual among "
+        "all sources, 'music_sfx' splits residual among only music and sfx sources. (default: pass)",
+    )
     args = parser.parse_args()
-
-    model = CocktailForkModule()
+    si_loss = True if args.loss == "si_snr" else False
+    model = CocktailForkModule(si_loss=si_loss, mixture_residual=args.mixture_residual)
     train_loader, valid_loader, test_loader = _get_dataloaders(
         args.root_dir,
         args.train_batch_size,
diff --git a/requirements.txt b/requirements.txt
index d6b67ac..54a3bb9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,6 +2,7 @@
 #
 # SPDX-License-Identifier: MIT
 
-torch==1.10.0
-torchaudio==0.10.0
-pytorch_lightning==1.8.6
+pyloudnorm==0.1.1
+pytorch_lightning==1.9.0
+torch==1.13.1
+torchaudio==0.13.1
diff --git a/separate.py b/separate.py
index 3bf2a6d..0d55ea7 100644
--- a/separate.py
+++ b/separate.py
@@ -4,10 +4,13 @@
 
 from argparse import ArgumentParser
 from pathlib import Path
+from typing import Optional, Union
 
+import pyloudnorm
 import torch
 import torchaudio
 
+from consistency import dnr_consistency
 from dnr_dataset import EXT, SAMPLE_RATE, SOURCE_NAMES
 from mrx import MRX
 
@@ -21,7 +24,7 @@ def load_default_pre_trained():
     return model
 
 
-def _mrx_output_to_dict(output):
+def _mrx_output_to_dict(output: torch.tensor) -> dict:
     """
     Convert MRX() to dictionary with one key per output source.
 
@@ -35,15 +38,43 @@ def _mrx_output_to_dict(output):
 
     return output_dict
 
 
-def separate_soundtrack(audio_tensor, separation_model=None, device=None):
+def _compute_gain(audio_tensor: torch.tensor, target_lufs: float) -> float:
+    """
+    Compute the gain required to achieve a target integrated loudness.
+
+    :param audio_tensor (torch.tensor): 2D Tensor of shape [channels, samples].
+    :param target_lufs (float): Target level in loudness units full scale.
+    :return gain (float): Gain that when multiplied by audio_tensor will achieve target_lufs
+    """
+    meter = pyloudnorm.Meter(SAMPLE_RATE)
+    loudness = meter.integrated_loudness(audio_tensor.cpu().numpy().T)
+    gain_lufs = target_lufs - loudness
+    gain = 10 ** (gain_lufs / 20.0)
+    return gain
+
+
+def separate_soundtrack(
+    audio_tensor: torch.tensor,
+    separation_model: Optional[MRX] = None,
+    device: Optional[int] = None,
+    consistency_mode: Optional[str] = "pass",
+    input_lufs: Optional[float] = -27.0,
+):
     """
     Separates a torch.Tensor into three stems. If a separation_model is provided, it will be used, otherwise
     the included pre-trained weights will be used.
 
-    :param audio_tensor (torch.tensor): 2D Tensor of shape [channels, samples]
+    :param audio_tensor (torch.tensor): 2D Tensor of shape [channels, samples]. Assumed samplerate of 44.1 kHz.
     :param separation_model (MRX, optional): a preloaded MRX model, or none to use included pre-trained model.
     :param device (int, optional): The gpu device for model inference. (default: -1) [cpu]
+    :param consistency_mode (str, optional): choices=["all", "pass", "music_sfx"],
+        Whether to add the residual to estimates, 'pass' doesn't add residual,
+        'all' splits residual among all sources, 'music_sfx' splits residual among
+        only music and sfx sources. (default: pass)
+    :param input_lufs (float, optional): Add gain to input and normalize output, so audio input level matches average
+        of Divide and Remaster dataset in loudness units full scale.
+        Pass None to skip. (default: -27)
     :return: (dictionary): {'music': music_samples, 'speech': speech_samples, 'sfx': sfx_samples} where each of the
         x_samples are 2D Tensor of shape [channels, samples]
     """
@@ -53,11 +84,24 @@ def separate_soundtrack(audio_tensor, separation_model=None, device=None):
         separation_model = separation_model.to(device)
     audio_tensor = audio_tensor.to(device)
     with torch.no_grad():
+        if input_lufs is not None:
+            gain = _compute_gain(audio_tensor, input_lufs)
+            audio_tensor *= gain
         output_tensor = separation_model(audio_tensor)
+        output_tensor = dnr_consistency(audio_tensor, output_tensor, mode=consistency_mode)
+        if input_lufs is not None:
+            output_tensor /= gain
     return _mrx_output_to_dict(output_tensor)
 
 
-def separate_soundtrack_file(audio_filepath, output_directory, separation_model=None, device=None):
+def separate_soundtrack_file(
+    audio_filepath: Union[str, Path],
+    output_directory: Union[str, Path],
+    separation_model: Optional[MRX] = None,
+    device: Optional[int] = None,
+    consistency_mode: Optional[str] = "pass",
+    input_lufs: Optional[float] = -27.0,
+) -> None:
     """
     Takes the path to a wav file, separates it, and saves the results in speech.wav, music.wav, and sfx.wav.
     Wraps seperate_soundtrack(). Audio will be resampled if it's not at the correct samplerate.
@@ -67,11 +111,19 @@ def separate_soundtrack_file(audio_filepath, output_directory, separation_model=
     :param separation_model (MRX, optional): a preloaded MRX model, or none to use included pre-trained model.
     :param device (int, optional): The gpu device for model inference. (default: -1) [cpu]
+    :param consistency_mode (str, optional): choices=["all", "pass", "music_sfx"],
+        Whether to add the residual to estimates, 'pass' doesn't add residual,
+        'all' splits residual among all sources, 'music_sfx' splits residual among
+        only music and sfx sources. (default: pass)
+    :param input_lufs (float, optional): Add gain to input and normalize output, so audio input level matches average
+        of Divide and Remaster dataset in loudness units full scale. (default: -27)
     """
     audio_tensor, fs = torchaudio.load(audio_filepath)
     if fs != SAMPLE_RATE:
         audio_tensor = torchaudio.functional.resample(audio_tensor, fs, SAMPLE_RATE)
-    output_dict = separate_soundtrack(audio_tensor, separation_model, device)
+    output_dict = separate_soundtrack(
+        audio_tensor, separation_model, device, consistency_mode=consistency_mode, input_lufs=input_lufs
+    )
     for k, v in output_dict.items():
         output_path = Path(output_directory) / f"{k}{EXT}"
         torchaudio.save(output_path, v.cpu(), SAMPLE_RATE)
 
@@ -91,6 +143,14 @@ def cli_main():
         help="Path to directory for saving output files.",
     )
     parser.add_argument("--gpu-device", default=-1, type=int, help="The gpu device for model inference. (default: -1)")
+    parser.add_argument(
+        "--mixture-residual",
+        default="pass",
+        type=str,
+        choices=["all", "pass", "music_sfx"],
+        help="Whether to add the residual to estimates, 'pass' doesn't add residual, 'all' splits residual among "
+        "all sources, 'music_sfx' splits residual among only music and sfx sources. (default: pass)",
+    )
     args = parser.parse_args()
     if args.gpu_device != -1:
         device = torch.device("cuda:" + str(args.gpu_device))
@@ -98,7 +158,7 @@ def cli_main():
         device = torch.device("cpu")
     output_dir = args.out_dir
     output_dir.mkdir(parents=True, exist_ok=True)
-    separate_soundtrack_file(args.audio_path, output_dir, device=device)
+    separate_soundtrack_file(args.audio_path, output_dir, device=device, consistency_mode=args.mixture_residual)
 
 
 if __name__ == "__main__":
diff --git a/si_snr.py b/snr.py
similarity index 55%
rename from si_snr.py
rename to snr.py
index c87c101..ba20d74 100644
--- a/si_snr.py
+++ b/snr.py
@@ -9,19 +9,25 @@
 EPSILON = 1e-8
 
 
-def si_snr(estimates: torch.Tensor, targets: torch.Tensor, dim: Optional[int] = -1) -> torch.Tensor:
+def snr_loss(
+    estimates: torch.Tensor, targets: torch.Tensor, dim: Optional[int] = -1, scale_invariant: Optional[bool] = True
+) -> torch.Tensor:
     """
-    Computes the negative scale-invariant signal (source) to noise (distortion) ratio.
+    Computes the negative [scale-invariant] signal (source) to noise (distortion) ratio.
 
     :param estimates (torch.Tensor): estimated source signals, tensor of shape [..., n_samples, ....]
     :param targets (torch.Tensor): ground truth signals, tensor of shape [...., n_samples, ....]
     :param dim (int): time (sample) dimension
-    :return (torch.Tensor): estimated SI-SNR with one value for each non-sample dimension
+    :param scale_invariant (bool): use SI-SNR when true, regular SNR when false
+    :return (torch.Tensor): estimated [SI-]SNR with one value for each non-sample dimension
     """
-    estimates = _mean_center(estimates, dim=dim)
-    targets = _mean_center(targets, dim=dim)
-    sig_power = _l2_square(targets, dim=dim, keepdim=True)  # [n_batch, 1, n_srcs]
-    dot_ = torch.sum(estimates * targets, dim=dim, keepdim=True)
-    scale = dot_ / (sig_power + 1e-12)
+    if scale_invariant:
+        estimates = _mean_center(estimates, dim=dim)
+        targets = _mean_center(targets, dim=dim)
+        sig_power = _l2_square(targets, dim=dim, keepdim=True)  # [n_batch, 1, n_srcs]
+        dot_ = torch.sum(estimates * targets, dim=dim, keepdim=True)
+        scale = dot_ / (sig_power + 1e-12)
+    else:  # Regular SNR with no processing
+        scale = 1
     s_target = scale * targets
     e_noise = estimates - s_target
     si_snr_array = _l2_square(s_target, dim=dim) / (_l2_square(e_noise, dim=dim) + EPSILON)
diff --git a/test/consistency_test.py b/test/consistency_test.py
new file mode 100644
index 0000000..3e69dbe
--- /dev/null
+++ b/test/consistency_test.py
@@ -0,0 +1,52 @@
+# Copyright (C) 2023 Mitsubishi Electric Research Laboratories (MERL)
+#
+# SPDX-License-Identifier: MIT
+
+import torch
+
+from consistency import mixture_consistency
+
+
+def test_mixture_consistency_no_weights():
+    n_sources = 4
+    n_channels = 2
+    n_samples = 20000
+    n_batch = 5
+    source_dim = 1
+    true_sources = torch.rand((n_batch, n_sources, n_channels, n_samples))
+    mixture = true_sources.sum(dim=source_dim)
+    res = torch.rand(mixture.shape)
+    mixture += res
+    new_est = mixture_consistency(mixture, true_sources, source_dim)
+    # if no source weights are given add the residual equally to all sources
+    expected_output = true_sources + (1 / n_sources) * res.unsqueeze(source_dim)
+    torch.testing.assert_close(new_est, expected_output)
+
+
+def test_mixture_consistency_source_weights():
+    n_sources = 4
+    n_channels = 2
+    n_samples = 20000
+    n_batch = 5
+    source_dim = 1
+    source_weights = [0.1, 0, 0.9, 0]
+    true_sources = torch.rand((n_batch, n_sources, n_channels, n_samples))
+    mixture = true_sources.sum(dim=source_dim)
+    res = torch.rand(mixture.shape)
+    mixture += res
+    new_est = mixture_consistency(mixture, true_sources, source_dim, source_weights)
+    for i_src in range(n_sources):
+        expected_output = true_sources[:, i_src, :, :] + source_weights[i_src] * res
+        torch.testing.assert_close(new_est[:, i_src, :, :], expected_output)
+
+
+def test_mixture_consistency_no_residual_unchanged():
+    n_sources = 4
+    n_channels = 2
+    n_samples = 20000
+    n_batch = 5
+    source_dim = 1
+    true_sources = torch.rand((n_batch, n_sources, n_channels, n_samples))
+    mixture = true_sources.sum(dim=source_dim)
+    new_est = mixture_consistency(mixture, true_sources, source_dim)
+    torch.testing.assert_close(new_est, true_sources)
diff --git a/test/sisnr_test.py b/test/snr_test.py
similarity index 69%
rename from test/sisnr_test.py
rename to test/snr_test.py
index 18711b0..67512dd 100644
--- a/test/sisnr_test.py
+++ b/test/snr_test.py
@@ -4,7 +4,7 @@
 
 import torch
 
-from si_snr import si_snr
+from snr import snr_loss
 
 
 def test_si_snr_scale_invariant():
@@ -13,6 +13,16 @@ def test_si_snr_scale_invariant():
     estimation_errors = 0.1 * torch.rand(targets.shape)
     estimates = targets + estimation_errors
     estimates_scale = 0.5 * estimates
-    loss_no_scale = si_snr(estimates, targets)
-    loss_scale = si_snr(estimates_scale, targets)
-    torch.testing.assert_allclose(loss_scale, loss_no_scale)
+    loss_no_scale = snr_loss(estimates, targets)
+    loss_scale = snr_loss(estimates_scale, targets)
+    torch.testing.assert_close(loss_scale, loss_no_scale)
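+
+
+def test_snr_not_scale_invariant():
+    # Illustrative companion test: unlike SI-SNR above, plain SNR
+    # (scale_invariant=False) should penalize a rescaled estimate.
+    targets = torch.rand((2, 3, 5, 1000))
+    estimates = targets + 0.1 * torch.rand(targets.shape)
+    loss_no_scale = snr_loss(estimates, targets, scale_invariant=False)
+    loss_scale = snr_loss(0.5 * estimates, targets, scale_invariant=False)
+    assert not torch.allclose(loss_scale, loss_no_scale)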