
new paper claims gradient penalty on fake images going into discriminator stabilizes adversarial networks
lucidrains committed Jan 11, 2025
1 parent 5d17c09 commit 6740714
Showing 3 changed files with 17 additions and 5 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -590,3 +590,12 @@ $ accelerate launch train.py
     url    = {https://api.semanticscholar.org/CorpusID:273532030}
 }
 ```
+
+```bibtex
+@inproceedings{Huang2025TheGI,
+    title  = {The GAN is dead; long live the GAN! A Modern GAN Baseline},
+    author = {Yiwen Huang and Aaron Gokaslan and Volodymyr Kuleshov and James Tompkin},
+    year   = {2025},
+    url    = {https://api.semanticscholar.org/CorpusID:275405495}
+}
+```
11 changes: 7 additions & 4 deletions audiolm_pytorch/soundstream.py
@@ -877,22 +877,25 @@ def forward(
 
 if self.single_channel:
     real, fake = orig_x.clone(), recon_x.detach()
-    stft_real_logits, stft_fake_logits = map(self.stft_discriminator, (real.requires_grad_(), fake))
+    stft_real_logits, stft_fake_logits = map(self.stft_discriminator, (real.requires_grad_(), fake.requires_grad_()))
     stft_discr_loss = hinge_discr_loss(stft_fake_logits, stft_real_logits)
 
     if apply_grad_penalty:
-        stft_grad_penalty = gradient_penalty(real, stft_discr_loss)
+        stft_grad_penalty = gradient_penalty(real, stft_discr_loss) + gradient_penalty(fake, stft_discr_loss)
 
 scaled_real, scaled_fake = real, fake
 for discr, downsample in zip(self.discriminators, self.downsamples):
     scaled_real, scaled_fake = map(downsample, (scaled_real, scaled_fake))
 
-    real_logits, fake_logits = map(discr, (scaled_real.requires_grad_(), scaled_fake))
+    real_logits, fake_logits = map(discr, (scaled_real.requires_grad_(), scaled_fake.requires_grad_()))
     one_discr_loss = hinge_discr_loss(fake_logits, real_logits)
 
     discr_losses.append(one_discr_loss)
     if apply_grad_penalty:
-        discr_grad_penalties.append(gradient_penalty(scaled_real, one_discr_loss))
+        discr_grad_penalties.extend([
+            gradient_penalty(scaled_real, one_discr_loss),
+            gradient_penalty(scaled_fake, one_discr_loss)
+        ])
 
 if not return_discr_losses_separately:
     all_discr_losses = torch.stack(discr_losses).mean()
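For context, the `gradient_penalty` helper called in the hunk above is defined elsewhere in the repository and is not part of this diff. A minimal sketch of what such a helper might look like follows, assuming the zero-centered gradient penalty (R1 on real samples, R2 on generated ones) advocated by the cited paper; the function name matches the call sites, but the `weight` default and internals here are illustrative assumptions, not the repository's actual implementation:

```python
import torch
from torch.autograd import grad as torch_grad

def gradient_penalty(inputs, loss, weight = 10.):
    # sketch only: differentiate the scalar discriminator loss with respect
    # to its inputs; `inputs` must have had requires_grad_() called before
    # the discriminator forward pass, as done in the diff above
    gradients, = torch_grad(
        outputs = loss,
        inputs = inputs,
        create_graph = True,   # keep the graph so the penalty itself can be backpropagated
        retain_graph = True
    )

    # penalize the squared L2 norm of the per-sample gradients
    # (zero-centered, i.e. R1 when `inputs` are real, R2 when fake)
    gradients = gradients.reshape(gradients.shape[0], -1)
    return weight * gradients.norm(2, dim = -1).pow(2).mean()
```

Under this reading, the change in the commit amounts to calling the same penalty on the fake inputs as well (after marking them `requires_grad_()`), which is why each `gradient_penalty(real, ...)` call gains a paired `gradient_penalty(fake, ...)` term.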
2 changes: 1 addition & 1 deletion audiolm_pytorch/version.py
@@ -1 +1 @@
-__version__ = '2.2.3'
+__version__ = '2.3.0'
