From 5d17c096affeff2ad51f875c18b42ca2778116c7 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Fri, 8 Nov 2024 12:59:14 -0800
Subject: [PATCH] address https://github.com/lucidrains/audiolm-pytorch/issues/279 again

---
 audiolm_pytorch/trainer.py | 9 +++++----
 audiolm_pytorch/version.py | 2 +-
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/audiolm_pytorch/trainer.py b/audiolm_pytorch/trainer.py
index 9eec39e..ba8e66d 100644
--- a/audiolm_pytorch/trainer.py
+++ b/audiolm_pytorch/trainer.py
@@ -667,7 +667,7 @@ def train_step(self):
 
         self.accelerator.wait_for_everyone()
 
-        if self.is_main and not (steps % self.save_results_every):
+        if not (steps % self.save_results_every):
             models = [(self.unwrapped_soundstream, str(steps))]
             if self.use_ema:
                 models.append((self.ema_soundstream.ema_model if self.use_ema else self.unwrapped_soundstream, f'{steps}.ema'))
@@ -682,9 +682,10 @@ def train_step(self):
                 with torch.inference_mode():
                     recons = model(wave, return_recons_only = True)
 
-                for ind, recon in enumerate(recons.unbind(dim = 0)):
-                    filename = str(self.results_folder / f'sample_{label}.flac')
-                    torchaudio.save(filename, recon.cpu().detach(), self.unwrapped_soundstream.target_sample_hz)
+                if self.is_main:
+                    for ind, recon in enumerate(recons.unbind(dim = 0)):
+                        filename = str(self.results_folder / f'sample_{label}.flac')
+                        torchaudio.save(filename, recon.cpu().detach(), self.unwrapped_soundstream.target_sample_hz)
 
         self.print(f'{steps}: saving to {str(self.results_folder)}')
 
diff --git a/audiolm_pytorch/version.py b/audiolm_pytorch/version.py
index f1edb19..05633ca 100644
--- a/audiolm_pytorch/version.py
+++ b/audiolm_pytorch/version.py
@@ -1 +1 @@
-__version__ = '2.2.2'
+__version__ = '2.2.3'
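
Note: the trainer change boils down to a distributed-safe sampling pattern: after `wait_for_everyone()`, every rank enters the periodic-results branch and runs the reconstruction forward pass, and only the main process writes audio to disk. The sketch below is a minimal illustration of that pattern under an 🤗 Accelerate-style setup; the function name, arguments, and `results_folder` handling are placeholders, not the library's actual trainer internals.

```python
import torch
import torchaudio
from accelerate import Accelerator

def save_validation_samples(
    accelerator: Accelerator,
    model,               # unwrapped SoundStream-like model (placeholder)
    wave,                # validation batch already moved to accelerator.device
    steps,
    save_results_every,
    results_folder,      # pathlib.Path to write .flac samples into
    target_sample_hz,
):
    # make sure every rank reaches this point before sampling
    accelerator.wait_for_everyone()

    # every process takes this branch; the is_main gate is applied
    # only around the file writes further down, mirroring the patch
    if steps % save_results_every:
        return

    # all ranks run inference, not just the main one
    with torch.inference_mode():
        recons = model(wave, return_recons_only = True)

    # only the main process touches the filesystem, so ranks do not
    # write over each other's output
    if accelerator.is_main_process:
        for ind, recon in enumerate(recons.unbind(dim = 0)):
            filename = str(results_folder / f'sample_{ind}.flac')
            torchaudio.save(filename, recon.cpu().detach(), target_sample_hz)
```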