Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
marypilataki committed Nov 5, 2024
1 parent 6fd958c commit 897c75a
Show file tree
Hide file tree
Showing 34 changed files with 3,451 additions and 0 deletions.
8 changes: 8 additions & 0 deletions .idea/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/inspectionProfiles/profiles_settings.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions .idea/modules.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions .idea/padac-mmasia24.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,28 @@
# padac-mmasia24
Source code for "Pitch-aware generative pretraining improves multi-pitch estimation with scarce data" (MMASIA 2024)

<p align="center">
<img src="./plots/framework.png"width=70%></p>

### 1. Preparation
- This repo is based on the [DAC repo](https://github.com/descriptinc/descript-audio-codec), accompanying the paper ["High-Fidelity Audio Compression with Improved RVQGAN"](https://arxiv.org/pdf/2306.06546).
- Please create a virtual environment and install the packages specified in requirements.txt
- Add [audiotools-mir](https://github.com/marypilataki/audiotools-mir) as a submodule and checkout the *mpe_labels* branch.
- Add [Basic Pitch model](https://github.com/spotify/basic-pitch) as a submodule and checkout the *main* branch.

### 2. Stage 1: pretraining
- Replace dataset paths in conf/padac/pitch_cond_padac.yml with the paths of the dataset you would like to perform pretraining on.
- Run the following command to start training. Replace *./runs* with the path to the folder where you would like the model checkpoints to be saved.
```
python -m scripts.train_padac --args.load conf/padac/conf_padac.yml --save_path ./runs
```

### 3. Stage 2: shallow transcriber training
- After freezing PA-DAC, extract and save latent space embeddings using ```latent_space = self.encoder(audio_data)```.
- Prepare a config file similar to conf/transcriber.json specifying the paths to the extracted features.
- To start training, run the following command:
```
python -m scripts.train_transcriber --config_file ./conf/transcriber.json
```


3 changes: 3 additions & 0 deletions conf/1gpu.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
batch_size: 64
val_batch_size: 64
num_workers: 0
3 changes: 3 additions & 0 deletions conf/padac/conf_padac.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
$include:
- conf/padac/pitch_cond_padac.yml
- conf/1gpu.yml
117 changes: 117 additions & 0 deletions conf/padac/pitch_cond_padac.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Model setup
DAC.sample_rate: 44100
DAC.encoder_dim: 64
DAC.encoder_rates: [2, 4, 8, 8]
DAC.decoder_dim: 1536
DAC.decoder_rates: [8, 8, 4, 2]
DAC.add_conditioner: True

# Quantization
DAC.n_codebooks: 9
DAC.codebook_size: 1024
DAC.codebook_dim: 8
DAC.quantizer_dropout: 0.5

# Discriminator
Discriminator.sample_rate: 44100
Discriminator.rates: []
Discriminator.periods: [2, 3, 5, 7, 11]
Discriminator.fft_sizes: [2048, 1024, 512]
Discriminator.bands:
- [0.0, 0.1]
- [0.1, 0.25]
- [0.25, 0.5]
- [0.5, 0.75]
- [0.75, 1.0]

# Optimization
AdamW.betas: [0.8, 0.99]
AdamW.lr: 0.0001
ExponentialLR.gamma: 0.999996

amp: false
val_batch_size: 100
resume: false
tag: latest
device: cuda
recon_mode: true
num_iters: 400000
save_iters: [5000, 10000, 20000, 30000, 50000, 100000, 150000, 200000]
valid_freq: 1000
sample_freq: 10000
num_workers: 32
val_idx: [0, 1, 2, 3, 4, 5, 6, 7]
seed: 0
lambdas:
mel/loss: 15.0
adv/feat_loss: 2.0
adv/gen_loss: 1.0
vq/commitment_loss: 0.25
vq/codebook_loss: 1.0
pitch/loss: 150.0

VolumeNorm.db: [const, -16]

# Transforms
build_transform.preprocess:
- Identity
build_transform.augment_prob: 0.0
build_transform.augment:
- Identity
build_transform.postprocess:
- VolumeNorm
- RescaleAudio

# Loss setup
MultiScaleSTFTLoss.window_lengths: [2048, 512]
MelSpectrogramLoss.n_mels: [5, 10, 20, 40, 80, 160, 320]
MelSpectrogramLoss.window_lengths: [32, 64, 128, 256, 512, 1024, 2048]
MelSpectrogramLoss.mel_fmin: [0, 0, 0, 0, 0, 0, 0]
MelSpectrogramLoss.mel_fmax: [null, null, null, null, null, null, null]
MelSpectrogramLoss.pow: 1.0
MelSpectrogramLoss.clamp_eps: 1.0e-5
MelSpectrogramLoss.mag_weight: 0.0

# Data
batch_size: 72
train/AudioDataset.duration: 1.0
train/AudioDataset.n_examples: 10000000

val/AudioDataset.duration: 1.0
val/build_transform.augment_prob: 1.0
val/AudioDataset.n_examples: 250

AudioLoader.shuffle: true
AudioLoader.num_channels: 1
AudioLoader.ext: [".wav", ".flac", ".mp4", ".au", ".mp3", ".aiff"]
AudioLoader.basic_pitch: true
AudioDataset.without_replacement: true

train/build_dataset.folders:
gtzan_pop_rock_blues:
- /homes/mpm30/audio_full/Gtzan_rock_pop_country/train
gtzan_classical:
- /homes/mpm30/audio_full/Gtzan_classical/train
gtzan_jazz:
- /homes/mpm30/audio_full/Gtzan_jazz/train
mazurkas:
- /homes/mpm30/audio_full/Mazurkas/train
violin:
- /homes/mpm30/audio_full/bach-violin-dataset/train
guitar:
- /homes/mpm30/audio_full/guitar/train


val/build_dataset.folders:
gtzan_pop_rock_blues:
- /homes/mpm30/audio_full/Gtzan_rock_pop_country/valid
gtzan_classical:
- /homes/mpm30/audio_full/Gtzan_classical/valid
gtzan_jazz:
- /homes/mpm30/audio_full/Gtzan_jazz/valid
mazurkas:
- /homes/mpm30/audio_full/Mazurkas/valid
violin:
- /homes/mpm30/audio_full/bach-violin-dataset/valid
guitar:
- /homes/mpm30/audio_full/guitar/valid
22 changes: 22 additions & 0 deletions conf/transcriber.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
{
"experiment_name": "padac-shallow-transcriber",
"data": {
"dataset_path": "/path/to/features/",
"labels_path": "/path/to/labels/"
},
"model": {
"name": "LinearModel",
"params": {
"in_features": 1024,
"out_features": 128,
"hidden_units": 512
},
"save_dir": "./runs/transcriber"
},
"train": {
"n_epochs": 1,
"lr": 1e-5,
"batch_size": 4,
"weight_decay": 1e-4
}
}
16 changes: 16 additions & 0 deletions dac/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
__version__ = "1.0.0"

# preserved here for legacy reasons
__model_version__ = "latest"

import audiotools

audiotools.ml.BaseModel.INTERN += ["dac.**"]
audiotools.ml.BaseModel.EXTERN += ["einops"]


from . import nn
from . import model
from . import utils
from .model import DAC
from .model import DACFile
36 changes: 36 additions & 0 deletions dac/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import sys

import argbind

from dac.utils import download
from dac.utils.decode import decode
from dac.utils.encode import encode

STAGES = ["encode", "decode", "download"]


def run(stage: str):
"""Run stages.
Parameters
----------
stage : str
Stage to run
"""
if stage not in STAGES:
raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}")
stage_fn = globals()[stage]

if stage == "download":
stage_fn()
return

stage_fn()


if __name__ == "__main__":
group = sys.argv.pop(1)
args = argbind.parse_args(group=group)

with argbind.scope(args):
run(group)
4 changes: 4 additions & 0 deletions dac/model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .base import CodecMixin
from .base import DACFile
from .dac import DAC
from .discriminator import Discriminator
Loading

0 comments on commit 897c75a

Please sign in to comment.