-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
- Loading branch information
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,28 @@ | ||
# padac-mmasia24 | ||
Source code for "Pitch-aware generative pretraining improves multi-pitch estimation with scarce data" (MMASIA 2024) | ||
|
||
<p align="center"> | ||
<img src="./plots/framework.png" width="70%"></p> | ||
|
||
### 1. Preparation | ||
- This repo is based on the [DAC repo](https://github.com/descriptinc/descript-audio-codec), accompanying the paper ["High-Fidelity Audio Compression with Improved RVQGAN"](https://arxiv.org/pdf/2306.06546). | ||
- Please create a virtual environment and install the packages specified in `requirements.txt`. | ||
- Add [audiotools-mir](https://github.com/marypilataki/audiotools-mir) as a submodule and check out the *mpe_labels* branch. | ||
- Add the [Basic Pitch model](https://github.com/spotify/basic-pitch) as a submodule and check out the *main* branch. | ||
|
||
### 2. Stage 1: pretraining | ||
- Replace dataset paths in conf/padac/pitch_cond_padac.yml with the paths of the dataset you would like to perform pretraining on. | ||
- Run the following command to start training. Replace *./runs* with the path to the folder where you would like the model checkpoints to be saved. | ||
``` | ||
python -m scripts.train_padac --args.load conf/padac/conf_padac.yml --save_path ./runs | ||
``` | ||
|
||
### 3. Stage 2: shallow transcriber training | ||
- After freezing PA-DAC, extract and save the latent space embeddings using `latent_space = self.encoder(audio_data)`. | ||
- Prepare a config file similar to conf/transcriber.json specifying the paths to the extracted features. | ||
- To start training, run the following command: | ||
``` | ||
python -m scripts.train_transcriber --config_file ./conf/transcriber.json | ||
``` | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
batch_size: 64 | ||
val_batch_size: 64 | ||
num_workers: 0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
$include: | ||
- conf/padac/pitch_cond_padac.yml | ||
- conf/1gpu.yml |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
# Model setup | ||
DAC.sample_rate: 44100 | ||
DAC.encoder_dim: 64 | ||
DAC.encoder_rates: [2, 4, 8, 8] | ||
DAC.decoder_dim: 1536 | ||
DAC.decoder_rates: [8, 8, 4, 2] | ||
DAC.add_conditioner: True | ||
|
||
# Quantization | ||
DAC.n_codebooks: 9 | ||
DAC.codebook_size: 1024 | ||
DAC.codebook_dim: 8 | ||
DAC.quantizer_dropout: 0.5 | ||
|
||
# Discriminator | ||
Discriminator.sample_rate: 44100 | ||
Discriminator.rates: [] | ||
Discriminator.periods: [2, 3, 5, 7, 11] | ||
Discriminator.fft_sizes: [2048, 1024, 512] | ||
Discriminator.bands: | ||
- [0.0, 0.1] | ||
- [0.1, 0.25] | ||
- [0.25, 0.5] | ||
- [0.5, 0.75] | ||
- [0.75, 1.0] | ||
|
||
# Optimization | ||
AdamW.betas: [0.8, 0.99] | ||
AdamW.lr: 0.0001 | ||
ExponentialLR.gamma: 0.999996 | ||
|
||
amp: false | ||
val_batch_size: 100 | ||
resume: false | ||
tag: latest | ||
device: cuda | ||
recon_mode: true | ||
num_iters: 400000 | ||
save_iters: [5000, 10000, 20000, 30000, 50000, 100000, 150000, 200000] | ||
valid_freq: 1000 | ||
sample_freq: 10000 | ||
num_workers: 32 | ||
val_idx: [0, 1, 2, 3, 4, 5, 6, 7] | ||
seed: 0 | ||
lambdas: | ||
mel/loss: 15.0 | ||
adv/feat_loss: 2.0 | ||
adv/gen_loss: 1.0 | ||
vq/commitment_loss: 0.25 | ||
vq/codebook_loss: 1.0 | ||
pitch/loss: 150.0 | ||
|
||
VolumeNorm.db: [const, -16] | ||
|
||
# Transforms | ||
build_transform.preprocess: | ||
- Identity | ||
build_transform.augment_prob: 0.0 | ||
build_transform.augment: | ||
- Identity | ||
build_transform.postprocess: | ||
- VolumeNorm | ||
- RescaleAudio | ||
|
||
# Loss setup | ||
MultiScaleSTFTLoss.window_lengths: [2048, 512] | ||
MelSpectrogramLoss.n_mels: [5, 10, 20, 40, 80, 160, 320] | ||
MelSpectrogramLoss.window_lengths: [32, 64, 128, 256, 512, 1024, 2048] | ||
MelSpectrogramLoss.mel_fmin: [0, 0, 0, 0, 0, 0, 0] | ||
MelSpectrogramLoss.mel_fmax: [null, null, null, null, null, null, null] | ||
MelSpectrogramLoss.pow: 1.0 | ||
MelSpectrogramLoss.clamp_eps: 1.0e-5 | ||
MelSpectrogramLoss.mag_weight: 0.0 | ||
|
||
# Data | ||
batch_size: 72 | ||
train/AudioDataset.duration: 1.0 | ||
train/AudioDataset.n_examples: 10000000 | ||
|
||
val/AudioDataset.duration: 1.0 | ||
val/build_transform.augment_prob: 1.0 | ||
val/AudioDataset.n_examples: 250 | ||
|
||
AudioLoader.shuffle: true | ||
AudioLoader.num_channels: 1 | ||
AudioLoader.ext: [".wav", ".flac", ".mp4", ".au", ".mp3", ".aiff"] | ||
AudioLoader.basic_pitch: true | ||
AudioDataset.without_replacement: true | ||
|
||
train/build_dataset.folders: | ||
gtzan_pop_rock_blues: | ||
- /homes/mpm30/audio_full/Gtzan_rock_pop_country/train | ||
gtzan_classical: | ||
- /homes/mpm30/audio_full/Gtzan_classical/train | ||
gtzan_jazz: | ||
- /homes/mpm30/audio_full/Gtzan_jazz/train | ||
mazurkas: | ||
- /homes/mpm30/audio_full/Mazurkas/train | ||
violin: | ||
- /homes/mpm30/audio_full/bach-violin-dataset/train | ||
guitar: | ||
- /homes/mpm30/audio_full/guitar/train | ||
|
||
|
||
val/build_dataset.folders: | ||
gtzan_pop_rock_blues: | ||
- /homes/mpm30/audio_full/Gtzan_rock_pop_country/valid | ||
gtzan_classical: | ||
- /homes/mpm30/audio_full/Gtzan_classical/valid | ||
gtzan_jazz: | ||
- /homes/mpm30/audio_full/Gtzan_jazz/valid | ||
mazurkas: | ||
- /homes/mpm30/audio_full/Mazurkas/valid | ||
violin: | ||
- /homes/mpm30/audio_full/bach-violin-dataset/valid | ||
guitar: | ||
- /homes/mpm30/audio_full/guitar/valid |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
{ | ||
"experiment_name": "padac-shallow-transcriber", | ||
"data": { | ||
"dataset_path": "/path/to/features/", | ||
"labels_path": "/path/to/labels/" | ||
}, | ||
"model": { | ||
"name": "LinearModel", | ||
"params": { | ||
"in_features": 1024, | ||
"out_features": 128, | ||
"hidden_units": 512 | ||
}, | ||
"save_dir": "./runs/transcriber" | ||
}, | ||
"train": { | ||
"n_epochs": 1, | ||
"lr": 1e-5, | ||
"batch_size": 4, | ||
"weight_decay": 1e-4 | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
__version__ = "1.0.0" | ||
|
||
# preserved here for legacy reasons | ||
__model_version__ = "latest" | ||
|
||
import audiotools | ||
|
||
audiotools.ml.BaseModel.INTERN += ["dac.**"] | ||
audiotools.ml.BaseModel.EXTERN += ["einops"] | ||
|
||
|
||
from . import nn | ||
from . import model | ||
from . import utils | ||
from .model import DAC | ||
from .model import DACFile |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import sys | ||
|
||
import argbind | ||
|
||
from dac.utils import download | ||
from dac.utils.decode import decode | ||
from dac.utils.encode import encode | ||
|
||
STAGES = ["encode", "decode", "download"] | ||
|
||
|
||
def run(stage: str):
    """Run a single command-line stage.

    Parameters
    ----------
    stage : str
        Name of the stage to run; must be one of ``STAGES``.

    Raises
    ------
    ValueError
        If ``stage`` is not listed in ``STAGES``.
    """
    if stage not in STAGES:
        raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}")

    # The stage callable is looked up by name in this module's globals. Every
    # stage is invoked without arguments, so no per-stage branching is needed.
    stage_fn = globals()[stage]
    stage_fn()
|
||
|
||
if __name__ == "__main__":
    # The first CLI argument names the stage to run; report a missing one
    # clearly instead of failing with an IndexError from ``sys.argv.pop``.
    if len(sys.argv) < 2:
        sys.exit(f"Missing command. Allowed commands are {STAGES}")

    # Remove the stage name from argv before argbind parses what is left.
    group = sys.argv.pop(1)
    args = argbind.parse_args(group=group)

    with argbind.scope(args):
        run(group)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .base import CodecMixin | ||
from .base import DACFile | ||
from .dac import DAC | ||
from .discriminator import Discriminator |