From 2154a74a9b097d0a6f5a935b15a9d10b4a05f6c8 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sat, 11 Jan 2025 16:20:56 -0800 Subject: [PATCH] needs to be zero centered GP --- audiolm_pytorch/soundstream.py | 4 ++-- audiolm_pytorch/version.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/audiolm_pytorch/soundstream.py b/audiolm_pytorch/soundstream.py index 7db7285..2b9ef99 100644 --- a/audiolm_pytorch/soundstream.py +++ b/audiolm_pytorch/soundstream.py @@ -67,7 +67,7 @@ def hinge_gen_loss(fake): def leaky_relu(p = 0.1): return nn.LeakyReLU(p) -def gradient_penalty(wave, output, weight = 10): +def gradient_penalty(wave, output, weight = 10, center = 0.): batch_size, device = wave.shape[0], wave.device gradients = torch_grad( @@ -80,7 +80,7 @@ def gradient_penalty(wave, output, weight = 10): )[0] gradients = rearrange(gradients, 'b ... -> b (...)') - return weight * ((vector_norm(gradients, dim = 1) - 1) ** 2).mean() + return weight * ((vector_norm(gradients, dim = 1) - center) ** 2).mean() # better sequential diff --git a/audiolm_pytorch/version.py b/audiolm_pytorch/version.py index 8219039..1c4ddd3 100644 --- a/audiolm_pytorch/version.py +++ b/audiolm_pytorch/version.py @@ -1 +1 @@ -__version__ = '2.3.0' +__version__ = '2.3.1'