From 01f00088976fac4de8ac22dc3d538e97eab039ee Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 11 Dec 2023 01:04:55 +0100
Subject: [PATCH] add warmup for soundstream as well as all discriminators

---
 audiolm_pytorch/trainer.py | 89 +++++++++++++++++++++++++++++++++++---
 audiolm_pytorch/version.py |  2 +-
 setup.py                   |  1 +
 3 files changed, 86 insertions(+), 6 deletions(-)

diff --git a/audiolm_pytorch/trainer.py b/audiolm_pytorch/trainer.py
index d296dd5..148ca60 100644
--- a/audiolm_pytorch/trainer.py
+++ b/audiolm_pytorch/trainer.py
@@ -9,7 +9,7 @@
 from collections import Counter
 from contextlib import contextmanager, nullcontext
 
-from beartype.typing import Union, List, Optional, Tuple
+from beartype.typing import Union, List, Optional, Tuple, Type
 from typing_extensions import Annotated
 
 from beartype import beartype
@@ -19,8 +19,12 @@
 import torch
 import torchaudio
 from torch import nn
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LambdaLR, _LRScheduler
 from torch.utils.data import Dataset, DataLoader, random_split
 
+import pytorch_warmup as warmup
+
 from einops import rearrange
 
 from audiolm_pytorch.optimizer import get_optimizer
@@ -55,6 +59,8 @@
 
 DEFAULT_SAMPLE_RATE = 16000
 
+ConstantLRScheduler = partial(LambdaLR, lr_lambda = lambda step: 1.)
+
 # make sure only one trainer is instantiated
 
 ONE_TRAINER_INSTANTIATED = False
@@ -152,6 +158,53 @@ def checkpoint_num_steps(checkpoint_path):
 
     return int(results[-1])
 
+# optimizer with scheduler + warmup
+
+class OptimizerWithWarmupSchedule(nn.Module):
+    @beartype
+    def __init__(
+        self,
+        accelerator: Accelerator,
+        optimizer: Optimizer,
+        scheduler: Optional[Type[_LRScheduler]] = None,
+        scheduler_kwargs: dict = dict(),
+        warmup_steps: int = 0
+    ):
+        super().__init__()
+        self.warmup = warmup.LinearWarmup(optimizer, warmup_period = warmup_steps)
+
+        if exists(scheduler):
+            self.scheduler = scheduler(optimizer, **scheduler_kwargs)
+        else:
+            self.scheduler = ConstantLRScheduler(optimizer)
+
+        self.optimizer = optimizer
+
+        self.optimizer, self.scheduler = accelerator.prepare(self.optimizer, self.scheduler)
+        self.accelerator = accelerator
+
+    def state_dict(self):
+        return dict(
+            optimizer = self.optimizer.state_dict(),
+            scheduler = self.scheduler.state_dict(),
+            warmup = self.warmup.state_dict()
+        )
+
+    def load_state_dict(self, pkg):
+        self.optimizer.load_state_dict(pkg['optimizer'])
+        self.scheduler.load_state_dict(pkg['scheduler'])
+        self.warmup.load_state_dict(pkg['warmup'])
+
+    def zero_grad(self):
+        self.optimizer.zero_grad()
+
+    def step(self):
+        self.optimizer.step()
+
+        if not self.accelerator.optimizer_step_was_skipped:
+            with self.warmup.dampening():
+                self.scheduler.step()
+
 # main trainer class
 
 class SoundStreamTrainer(nn.Module):
@@ -172,6 +225,12 @@ def __init__(
         lr: float = 2e-4,
         grad_accum_every: int = 4,
         wd: float = 0.,
+        warmup_steps: int = 1000,
+        scheduler: Optional[Type[_LRScheduler]] = None,
+        scheduler_kwargs: dict = dict(),
+        discr_warmup_steps: Optional[int] = None,
+        discr_scheduler: Optional[Type[_LRScheduler]] = None,
+        discr_scheduler_kwargs: dict = dict(),
         max_grad_norm: float = 0.5,
         discr_max_grad_norm: float = None,
         save_results_every: int = 100,
@@ -240,13 +299,33 @@ def __init__(
 
         # optimizers
 
-        self.optim = get_optimizer(soundstream.non_discr_parameters(), lr = lr, wd = wd)
+        self.optim = OptimizerWithWarmupSchedule(
+            self.accelerator,
+            get_optimizer(soundstream.non_discr_parameters(), lr = lr, wd = wd),
+            scheduler = scheduler,
+            scheduler_kwargs = scheduler_kwargs,
+            warmup_steps = warmup_steps
+        )
+
+        discr_warmup_steps = default(discr_warmup_steps, warmup_steps)
 
         for discr_optimizer_key, discr in self.multiscale_discriminator_iter():
-            one_multiscale_discr_optimizer = get_optimizer(discr.parameters(), lr = lr, wd = wd)
+            one_multiscale_discr_optimizer = OptimizerWithWarmupSchedule(
+                self.accelerator,
+                get_optimizer(discr.parameters(), lr = lr, wd = wd),
+                scheduler = discr_scheduler,
+                scheduler_kwargs = discr_scheduler_kwargs,
+                warmup_steps = discr_warmup_steps
+            )
             setattr(self, discr_optimizer_key, one_multiscale_discr_optimizer)
 
-        self.discr_optim = get_optimizer(soundstream.stft_discriminator.parameters(), lr = lr, wd = wd)
+        self.discr_optim = OptimizerWithWarmupSchedule(
+            self.accelerator,
+            get_optimizer(soundstream.stft_discriminator.parameters(), lr = lr, wd = wd),
+            scheduler = discr_scheduler,
+            scheduler_kwargs = discr_scheduler_kwargs,
+            warmup_steps = discr_warmup_steps
+        )
 
         # max grad norm
 
@@ -596,6 +675,7 @@ def train_step(self):
 
             for model, label in models:
                 model.eval()
+                model = model.to(device)
 
                 with torch.inference_mode():
                     recons = model(wave, return_recons_only = True)
@@ -1064,7 +1144,6 @@ def load(self, path):
         # + 1 to start from the next step and avoid overwriting the last checkpoint
         self.steps = torch.tensor([checkpoint_num_steps(path) + 1], device=self.device)
 
-
     def print(self, msg):
         self.accelerator.print(msg)
 
diff --git a/audiolm_pytorch/version.py b/audiolm_pytorch/version.py
index 655be52..e5102d3 100644
--- a/audiolm_pytorch/version.py
+++ b/audiolm_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.8.7'
+__version__ = '1.9.0'
diff --git a/setup.py b/setup.py
index e70f408..94a80df 100644
--- a/setup.py
+++ b/setup.py
@@ -28,6 +28,7 @@
     'gateloop-transformer>=0.0.24',
     'joblib',
     'local-attention>=1.9.0',
+    'pytorch-warmup',
    'scikit-learn',
     'sentencepiece',
     'torch>=1.12',
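
Below is a minimal, hypothetical usage sketch of the new warmup/scheduler arguments this patch adds to SoundStreamTrainer (warmup_steps, scheduler, scheduler_kwargs and their discr_* counterparts). It is not part of the patch: the SoundStream configuration, dataset folder, batch size and step counts are illustrative assumptions, and CosineAnnealingLR stands in for any torch _LRScheduler subclass you might pass.

# Hypothetical example, not included in the patch. Hyperparameters, the dataset
# folder and the choice of CosineAnnealingLR are assumptions for illustration only.
from torch.optim.lr_scheduler import CosineAnnealingLR

from audiolm_pytorch import SoundStream, SoundStreamTrainer

soundstream = SoundStream(
    codebook_size = 1024,       # assumed codec configuration
    rq_num_quantizers = 8
)

trainer = SoundStreamTrainer(
    soundstream,
    folder = '/path/to/audio',                  # placeholder dataset path
    batch_size = 4,
    grad_accum_every = 8,
    data_max_length_seconds = 2,
    num_train_steps = 1_000_000,
    warmup_steps = 1000,                        # generator optimizer warmup (new in this patch)
    scheduler = CosineAnnealingLR,              # optional lr scheduler class (new)
    scheduler_kwargs = dict(T_max = 1_000_000),
    discr_warmup_steps = 1000,                  # falls back to warmup_steps when omitted (new)
    discr_scheduler = CosineAnnealingLR,        # scheduler applied to all discriminators (new)
    discr_scheduler_kwargs = dict(T_max = 1_000_000)
)

trainer.train()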