diff --git a/examples/dcgan.py b/examples/dcgan.py index 81d7134e..28180dc7 100644 --- a/examples/dcgan.py +++ b/examples/dcgan.py @@ -30,9 +30,9 @@ import torchvision.datasets as dset import torchvision.transforms as transforms import torchvision.utils as vutils -from opacus import PrivacyEngine from tqdm import tqdm +from opacus import PrivacyEngine parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--data-root", required=True, help="path to dataset") @@ -298,21 +298,19 @@ def forward(self, input): fake = netG(noise) label_fake = torch.full((batch_size,), FAKE_LABEL, device=device) output = netD(fake.detach()) + D_G_z1 = output.mean().item() errD_fake = criterion(output, label_fake) - errD_fake.backward() - optimizerD.step() - optimizerD.zero_grad() # train with real label_true = torch.full((batch_size,), REAL_LABEL, device=device) output = netD(real_data) - errD_real = criterion(output, label_true) - errD_real.backward() - optimizerD.step() D_x = output.mean().item() + errD_real = criterion(output, label_true) - D_G_z1 = output.mean().item() - errD = errD_real + errD_fake + # Note that we clip the gradient for not only real but also fake data. + errD = errD_fake + errD_real + errD.backward() + optimizerD.step() ############################ # (2) Update G network: maximize log(D(G(z))) @@ -324,7 +322,7 @@ def forward(self, input): output_g = netD(fake) errG = criterion(output_g, label_g) errG.backward() - D_G_z2 = output.mean().item() + D_G_z2 = output_g.mean().item() optimizerG.step() data_bar.set_description( f"epoch: {epoch}, Loss_D: {errD.item()} "