minimize gradient syncs for soundstream distributed training
lucidrains committed Nov 15, 2023
1 parent d6988ab commit 4805163
Showing 3 changed files with 33 additions and 26 deletions.
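What the change does: under DDP, each accelerator.backward() on the wrapped model triggers a gradient all-reduce across ranks. When gradients are accumulated over several micro-batches, that sync only needs to happen on the last one, so the commit wraps every micro-batch except the last in Accelerate's no_sync context manager and the communication fires once per optimizer step. A minimal sketch of the pattern, with placeholder names (accumulate_gradients, get_batch, model) that are assumptions and not part of audiolm-pytorch:

from functools import partial
from contextlib import nullcontext

def accumulate_gradients(accelerator, model, get_batch, grad_accum_every):
    # skip the cross-rank gradient sync on all but the final micro-batch;
    # gradients still accumulate locally, and the last backward() syncs them
    for i in range(grad_accum_every):
        is_last = i == (grad_accum_every - 1)
        context = partial(accelerator.no_sync, model) if not is_last else nullcontext

        batch = get_batch()
        with context():
            loss = model(batch)
            accelerator.backward(loss / grad_accum_every)

Accelerate also ships an accelerator.accumulate(model) context manager that can handle this bookkeeping internally; the diff below keeps the explicit no_sync / nullcontext form.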
4 changes: 1 addition & 3 deletions audiolm_pytorch/soundstream.py
@@ -1,9 +1,7 @@
-import functools
-from itertools import cycle
 from pathlib import Path

 from functools import partial, wraps
-from itertools import zip_longest
+from itertools import cycle, zip_longest
 from typing import Optional

 import torch
53 changes: 31 additions & 22 deletions audiolm_pytorch/trainer.py
@@ -5,8 +5,9 @@
 from random import choice
 from pathlib import Path
 from shutil import rmtree
+from functools import partial
 from collections import Counter
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext

 from beartype.typing import Union, List, Optional, Tuple
 from typing_extensions import Annotated
@@ -164,8 +165,8 @@ def __init__(
         data_max_length: int = None,
         data_max_length_seconds: Union[int, float] = None,
         folder: str = None,
-        train_dataloader: DataLoader = None,
-        val_dataloader: DataLoader = None,
+        train_dataloader: Optional[DataLoader] = None,
+        val_dataloader: Optional[DataLoader] = None,
         lr: float = 2e-4,
         grad_accum_every: int = 4,
         wd: float = 0.,
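Typing the dataloader arguments as Optional[DataLoader] is what actually permits the None default here, assuming the trainer's __init__ is runtime-checked by beartype (the file imports from beartype.typing above): with a bare DataLoader annotation, explicitly passing None would be rejected at call time. A tiny illustration with a throwaway function, not the trainer itself:

from typing import Optional
from beartype import beartype
from torch.utils.data import DataLoader

@beartype
def takes_loader(dl: Optional[DataLoader] = None):
    # beartype accepts None here; with `dl: DataLoader = None` an explicit
    # takes_loader(None) call would raise a type violation
    return dl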
@@ -251,14 +252,14 @@ def __init__(
         self.max_grad_norm = max_grad_norm
         self.discr_max_grad_norm = discr_max_grad_norm

-        if folder is None:
-            assert train_dataloader is not None
-            assert val_dataloader is not None
+        if not exists(folder):
+            assert exists(train_dataloader)
+            assert exists(val_dataloader)
             self.dl = train_dataloader
             self.valid_dl = val_dataloader
         else:
-            assert train_dataloader is None
-            assert val_dataloader is None
+            assert not exists(train_dataloader)
+            assert not exists(val_dataloader)

         # create dataset
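The folder check switches from "folder is None" to the repository's exists helper, matching the style used elsewhere in the trainer (for example the exists(self.discr_max_grad_norm) check further down). The helper itself is not part of this diff; in lucidrains' codebases it is conventionally the one-liner below, shown only to make the asserts concrete:

def exists(val):
    # the asserts above are plain None-checks phrased through this helper
    return val is not None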

@@ -478,13 +479,17 @@ def train_step(self):

         # update vae (generator)

-        for _ in range(self.grad_accum_every):
+        for i in range(self.grad_accum_every):
+            is_last = i == (self.grad_accum_every - 1)
+            context = partial(self.accelerator.no_sync, self.soundstream) if not is_last else nullcontext
+
             wave, = next(self.dl_iter)
             wave = wave.to(device)

-            loss, (recon_loss, multi_spectral_recon_loss, adversarial_loss, feature_loss, all_commitment_loss) = self.soundstream(wave, return_loss_breakdown = True)
+            with context():
+                loss, (recon_loss, multi_spectral_recon_loss, adversarial_loss, feature_loss, all_commitment_loss) = self.soundstream(wave, return_loss_breakdown = True)

-            self.accelerator.backward(loss / self.grad_accum_every)
+                self.accelerator.backward(loss / self.grad_accum_every)

             accum_log(logs, dict(
                 loss = loss.item() / self.grad_accum_every,
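Note that each micro-batch's loss is still divided by grad_accum_every before backward and before logging, so both the accumulated gradient and the logged value end up as averages over the accumulation window rather than sums. The accum_log helper seen above is not part of this diff; it is a small dict-accumulating utility along the lines of the sketch below (an approximation, not the exact source):

def accum_log(log, new_logs):
    # add each new value onto whatever is already logged under that key,
    # so per-micro-batch contributions sum to the per-step figure
    for key, new_value in new_logs.items():
        log[key] = log.get(key, 0.) + new_value
    return log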
@@ -512,20 +517,24 @@ def train_step(self):
         for name, multiscale_discr_optim in self.multiscale_discriminator_optim_iter():
             multiscale_discr_optim.zero_grad()

-        for _ in range(self.grad_accum_every):
+        for i in range(self.grad_accum_every):
+            is_last = i == (self.grad_accum_every - 1)
+            context = partial(self.accelerator.no_sync, self.soundstream) if not is_last else nullcontext
+
             wave, = next(self.dl_iter)
             wave = wave.to(device)

-            discr_losses = self.soundstream(
-                wave,
-                apply_grad_penalty = apply_grad_penalty,
-                return_discr_loss = True,
-                return_discr_losses_separately = True
-            )
-
-            for name, discr_loss in discr_losses:
-                self.accelerator.backward(discr_loss / self.grad_accum_every, retain_graph = True)
-                accum_log(logs, {name: discr_loss.item() / self.grad_accum_every})
+            with context():
+                discr_losses = self.soundstream(
+                    wave,
+                    apply_grad_penalty = apply_grad_penalty,
+                    return_discr_loss = True,
+                    return_discr_losses_separately = True
+                )
+
+                for name, discr_loss in discr_losses:
+                    self.accelerator.backward(discr_loss / self.grad_accum_every, retain_graph = True)
+                    accum_log(logs, {name: discr_loss.item() / self.grad_accum_every})

         if exists(self.discr_max_grad_norm):
             self.accelerator.clip_grad_norm_(self.soundstream.stft_discriminator.parameters(), self.discr_max_grad_norm)
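One detail carried over from the old code: the per-discriminator losses all come out of a single forward pass through self.soundstream, so they share one autograd graph, and each backward() in the inner loop passes retain_graph = True to keep that graph alive for the next loss. A toy illustration of why this is needed, unrelated to soundstream itself:

import torch

# two losses built from the same intermediate share one autograd graph; the
# first backward() would free it unless retain_graph=True is passed
x = torch.randn(4, requires_grad = True)
h = x * 2                                    # shared intermediate
losses = {'recon': h.sum(), 'penalty': (h ** 2).sum()}

for name, loss in losses.items():
    loss.backward(retain_graph = True)       # gradients from both losses accumulate into x.grad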
2 changes: 1 addition & 1 deletion audiolm_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.7.7'
+__version__ = '1.7.9'
