Skip to content

Commit

Permalink
Create release 1.1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
gwichern authored and kieranparsons committed Feb 16, 2023
1 parent 84c5605 commit 6c5256e
Show file tree
Hide file tree
Showing 15 changed files with 317 additions and 42 deletions.
38 changes: 38 additions & 0 deletions .github/workflows/build_and_test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (C) 2023 Mitsubishi Electric Research Laboratories (MERL)
#
# SPDX-License-Identifier: MIT

name: Build and Test

on:
pull_request:
push:
branches:
- '**'
tags-ignore:
- '**'

jobs:
build:
runs-on: ubuntu-latest

steps:
- name: Checkout repo
uses: actions/checkout@v3

- name: Set up Python 3.9
uses: actions/setup-python@v4
with:
python-version: 3.9
cache: 'pip'
cache-dependency-path: 'requirements.txt'

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install pytest
- name: Run unit tests
run: |
python -m pytest test
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ repos:
- --line-length=119

- repo: https://github.com/pycqa/isort
rev: 5.10.1
rev: 5.12.0
hooks:
- id: isort
args: ["--profile", "black", "--filter-files", "--line-length", "119", "--skip-gitignore"]
Expand Down
4 changes: 3 additions & 1 deletion .reuse/dep5
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/

Files: checkpoints/default_mrx_pre_trained_weights.pth
Files:
checkpoints/default_mrx_pre_trained_weights.pth
checkpoints/paper_mrx_pre_trained_weights.pth
Copyright: 2023 Mitsubishi Electric Research Laboratories (MERL)
License: MIT
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ The code has been tested using `python 3.9`. Necessary dependencies can be insta
pip install -r requirements.txt
```

If you prefer to use the [torchaudio soundfile backend](https://pytorch.org/audio/stable/backend.html) (required on windows) please refer to the [SoundFile documentation](https://pysoundfile.readthedocs.io/en/latest/) for installation instructions.

Please modify pytorch installation depending on your particular CUDA version if necessary.

## Using a pre-trained model
Expand Down Expand Up @@ -70,6 +72,11 @@ my_model.load_state_dict(state_dict)
enhanced_dict = separate.separate_soundtrack(audio_tensor, separation_model=my_model, ...)
```

We include two pre-trained models in the `checkpoints` directory:
1. `default_mrx_pre_trained_weights.pth`: This is the model trained using the default arguments from [`lightning_train.py`](./lightning_train.py), except the training loss is SNR (`--loss snr`). This ensures that the level of the output signals matches the mixture.
2. `paper_mrx_pre_trained_weights.pth`: This is the model trained using the default arguments from [`lightning_train.py`](./lightning_train.py) including scale-invariant SNR loss function, which reproduces the results from our ICASSP paper.
However, due to the scale-invariant training the level of the output signals will not match the mixture.

## Training a model on the Divide and Remaster Dataset

If you haven't already, you will first need to download the [Divide and Remaster (DnR) Dataset.](https://zenodo.org/record/6949108#.Y861fOLMKrN)
Expand Down
4 changes: 2 additions & 2 deletions checkpoints/default_mrx_pre_trained_weights.pth
Git LFS file not shown
3 changes: 3 additions & 0 deletions checkpoints/paper_mrx_pre_trained_weights.pth
Git LFS file not shown
63 changes: 63 additions & 0 deletions consistency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (C) 2023 Mitsubishi Electric Research Laboratories (MERL)
#
# SPDX-License-Identifier: MIT

from typing import List, Optional

import torch

from dnr_dataset import SOURCE_NAMES


def mixture_consistency(
mixture: torch.Tensor,
estimated_sources: torch.Tensor,
source_dim: int,
source_weights: Optional[List[float]] = None,
) -> torch.Tensor:
"""
Postprocessing for adding residual between mixture and estimated sources back to estimated sources.
:param mixture (torch.tensor): audio mixture signal
:param estimated_sources (torch.tensor): estimated separated source signals with added source dimension
:param source_dim (int): dimension of sources in estimated source tensor
:param source_weights (list): the weights for each source. Length must match the source_dim of estimated sources
:return:
"""
if source_weights is None:
n_sources = estimated_sources.shape[source_dim]
source_weights = [1 / n_sources] * n_sources
source_weights = torch.tensor(source_weights).to(estimated_sources)
source_weights = source_weights / source_weights.sum()
n_trailing_dims = len(estimated_sources.shape[source_dim + 1 :])
source_weights = source_weights.reshape(source_weights.shape + (1,) * n_trailing_dims)
res = mixture - estimated_sources.sum(source_dim)
new_source_signals = estimated_sources + source_weights * res.unsqueeze(source_dim)
return new_source_signals


def dnr_consistency(mixture: torch.Tensor, estimated_sources: torch.Tensor, mode: str = "pass") -> torch.Tensor:
"""
Postprocessing for adding residual between mixture and estimated sources back to estimated sources.
:param mixture (torch.tensor): 3D Tensor with shape [batch, channels, samples]
or 2D Tensor of shape [channels, samples]
:param estimated_sources (torch.tensor): 4D Tensor with shape [batch, num_sources, channels, samples]
or 3D Tensor of shape [num_sources, channels, samples]
:param mode (str): choices=["all", "pass", "music_sfx"],
Whether to add the residual to estimates, 'pass' doesn't add residual, 'all' splits residual
among all sources, 'music_sfx' splits residual among music and sfx sources . (default: pass)"
:return: Tensor of same shape as estimated sources
"""
input_ndim = mixture.ndim
if input_ndim > 2: # we have a batch dimension
source_dim = 1
else:
source_dim = 0
if mode == "all":
return mixture_consistency(mixture, estimated_sources, source_dim)
elif mode == "music_sfx":
source_weights = [0 if src == "speech" else 0.5 for src in SOURCE_NAMES]
return mixture_consistency(mixture, estimated_sources, source_dim, source_weights)
else:
return estimated_sources
12 changes: 12 additions & 0 deletions dnr_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
self.chunk_size = -1
self.random_start = random_start
self.track_list = self._get_tracklist()
self._check_subset_lengths(subset)

def _get_tracklist(self) -> List[str]:
path = Path(self.path)
Expand All @@ -51,6 +52,17 @@ def _get_tracklist(self) -> List[str]:
names.append(name)
return sorted(names)

def _check_subset_lengths(self, subset: str):
"""
Assert if the number of files is incorrect, to ensure we are using DnR v2 not an old version
"""
if subset == "tr":
assert len(self.track_list) == 3406, "Expected 3406 mix in training set"
elif subset == "cv":
assert len(self.track_list) == 487, "Expected 487 mix in validation set"
elif subset == "tt":
assert len(self.track_list) == 973, "Expected 973 mix in testing set"

def _get_audio_path(self, track_name: str, source_name: str) -> Path:
return Path(self.path) / track_name / f"{source_name}{EXT}"

Expand Down
17 changes: 13 additions & 4 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@
from separate import DEFAULT_PRE_TRAINED_MODEL_PATH


def _read_checkpoint(checkpoint_path):
def _read_checkpoint(checkpoint_path, **kwargs):
ckpt = torch.load(checkpoint_path, map_location="cpu")
if "state_dict" in ckpt.keys():
# lightning module checkpoint
model = CocktailForkModule.load_from_checkpoint(checkpoint_path)
model = CocktailForkModule.load_from_checkpoint(checkpoint_path, **kwargs)
else:
# only network weights
model = MRX()
model.load_state_dict(ckpt)
model = CocktailForkModule(model=model)
model = CocktailForkModule(model=model, **kwargs)
return model


Expand All @@ -42,6 +42,14 @@ def _lightning_eval():
help="Path to trained model weights. Can be a pytorch_lightning checkpoint or pytorch state_dict",
)
parser.add_argument("--gpu-device", default=-1, type=int, help="The gpu device for model inference. (default: -1)")
parser.add_argument(
"--mixture-residual",
default="pass",
type=str,
choices=["all", "pass", "music_sfx"],
help="Whether to add the residual to estimates, 'pass' doesn't add residual, 'all' splits residual among "
"all sources, 'music_sfx' splits residual among only music and sfx sources . (default: pass)",
)
args = parser.parse_args()

test_dataset = DivideAndRemaster(args.root_dir, "tt")
Expand All @@ -64,7 +72,8 @@ def _lightning_eval():
accelerator=accelerator,
enable_progress_bar=True, # this will print the results to the command line
)
model = _read_checkpoint(args.checkpoint)

model = _read_checkpoint(args.checkpoint, mixture_residual=args.mixture_residual)
trainer.test(model, test_loader)


Expand Down
48 changes: 35 additions & 13 deletions lightning_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,27 @@
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader

from consistency import dnr_consistency
from dnr_dataset import SOURCE_NAMES, DivideAndRemaster
from mrx import MRX
from si_snr import si_snr
from snr import snr_loss


class CocktailForkModule(LightningModule):
def __init__(self, model=None):
def __init__(self, model=None, si_loss=True, mixture_residual="pass"):
super().__init__()
if model is None:
self.model = MRX()
else:
self.model = model
self.si_loss = si_loss
self.mixture_residual = mixture_residual

def _step(self, batch, batch_idx, split):
x, y, filenames = batch
y_hat = self.model(x)
loss = si_snr(y_hat, y).mean()
y_hat = dnr_consistency(x, y_hat, mode=self.mixture_residual)
loss = snr_loss(y_hat, y, scale_invariant=self.si_loss).mean()
self.log(f"{split}_loss", loss, on_step=True, on_epoch=True)
return loss

Expand All @@ -40,20 +44,23 @@ def validation_step(self, batch, batch_idx):
def test_step(self, batch, batch_idx):
x, y, filenames = batch
y_hat = self.model(x)
est_sdr = -si_snr(y_hat, y)
est_sdr = est_sdr.mean(-1).mean(0) # average of batch and channel
y_hat = dnr_consistency(x, y_hat, mode=self.mixture_residual)
est_sisdr = -snr_loss(y_hat, y, scale_invariant=True).mean(-1).mean(0) # average of batch and channel
est_snr = -snr_loss(y_hat, y, scale_invariant=False).mean(-1).mean(0)
# expand mixture to shape of isolated sources for noisy SDR
repeat_shape = len(y.shape) * [1]
repeat_shape[1] = y.shape[1]
x = x.unsqueeze(1).repeat(repeat_shape)
noisy_sdr = -si_snr(x, y)
noisy_sdr = noisy_sdr.mean(-1).mean(0) # average of batch and channel
noisy_sisdr = -snr_loss(x, y, scale_invariant=True).mean(-1).mean(0)
noisy_snr = -snr_loss(x, y, scale_invariant=False).mean(-1).mean(0)
result_dict = {}
for i, src in enumerate(SOURCE_NAMES):
result_dict[f"noisy_{src}"] = noisy_sdr[i].item()
result_dict[f"est_{src}"] = est_sdr[i].item()
result_dict[f"noisy_sisdr_{src}"] = noisy_sisdr[i].item()
result_dict[f"est_sisdr_{src}"] = est_sisdr[i].item()
result_dict[f"noisy_snr_{src}"] = noisy_snr[i].item()
result_dict[f"est_snr_{src}"] = est_snr[i].item()
self.log_dict(result_dict, on_epoch=True)
return est_sdr.mean()
return est_sisdr.mean()

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
Expand Down Expand Up @@ -102,7 +109,7 @@ def cli_main():
parser.add_argument(
"--root-dir",
type=Path,
help="The path to the directory where the directory ``Libri2Mix`` or ``Libri3Mix`` is stored.",
help="The path to the DnR directory containing ``tr`` ``cv`` and ``tt`` directories.",
)
parser.add_argument(
"--exp-dir", default=Path("./exp"), type=Path, help="The directory to save checkpoints and logs."
Expand Down Expand Up @@ -132,10 +139,25 @@ def cli_main():
type=int,
help="The number of workers for dataloader. (default: 4)",
)
parser.add_argument(
"--loss",
default="si_snr",
type=str,
choices=["si_snr", "snr"],
help="The loss function for network training, either snr or si_snr. (default: si_snr)",
)
parser.add_argument(
"--mixture-residual",
default="pass",
type=str,
choices=["all", "pass", "music_sfx"],
help="Whether to add the residual to estimates, 'pass' doesn't add residual, 'all' splits residual among "
"all sources, 'music_sfx' splits residual among only music and sfx sources . (default: pass)",
)

args = parser.parse_args()

model = CocktailForkModule()
si_loss = True if args.loss == "si_snr" else False
model = CocktailForkModule(si_loss=si_loss, mixture_residual=args.mixture_residual)
train_loader, valid_loader, test_loader = _get_dataloaders(
args.root_dir,
args.train_batch_size,
Expand Down
7 changes: 4 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: MIT

torch==1.10.0
torchaudio==0.10.0
pytorch_lightning==1.8.6
pyloudnorm==0.1.1
pytorch_lightning==1.9.0
torch==1.13.1
torchaudio==0.13.1
Loading

0 comments on commit 6c5256e

Please sign in to comment.