Prediction GAN

Introduction

This is a PyTorch implementation of the prediction method presented in the following paper:

Abhay Yadav, Sohil Shah, Zheng Xu, David Jacobs and Tom Goldstein, Stabilizing Adversarial Nets With Prediction Methods, ICLR 2018.
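For orientation, here is the prediction step as we summarize it from the paper, for a saddle-point problem \min_u \max_v \mathcal{L}(u, v) with step sizes \alpha_k and \beta_k. The variable u takes an ordinary gradient step, its next value is extrapolated from the last two iterates, and v is then updated against that extrapolated (predicted) point. This is a summary, not a substitute for the paper. The mapping of \bar{u} onto the weights produced by the optimizer's look-ahead call is our assumption.

```latex
\begin{aligned}
u^{k+1}       &= u^{k} - \alpha_k \nabla_u \mathcal{L}(u^{k}, v^{k})            && \text{(gradient descent step for } u\text{)} \\
\bar{u}^{k+1} &= u^{k+1} + \left(u^{k+1} - u^{k}\right)                         && \text{(prediction)} \\
v^{k+1}       &= v^{k} + \beta_k \nabla_v \mathcal{L}(\bar{u}^{k+1}, v^{k})      && \text{(gradient ascent step for } v \text{ at the predicted point)}
\end{aligned}
```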

If you use this code in your research, please cite our paper.

@article{yadav2018stabilizing,
  title={Stabilizing Adversarial Nets With Prediction Methods},
  author={Yadav, Abhay and Shah, Sohil and Xu, Zheng and Jacobs, David and Goldstein, Tom},
  journal={International Conference on Learning Representations},
  year={2018}
}

The source code and dataset are published under the MIT license; see LICENSE for details. In general, you may use the code for any purpose with proper attribution. If you do something interesting with the code, we would be happy to hear about it, so feel free to contact us.

Requirements

Usage

Place the adamPre file in the local folder of your project. You can then use the AdamPre optimizer by importing it and initializing it like any other PyTorch optimizer.

from adamPre import AdamPre
optG = AdamPre(netG.parameters(), lr=0.001, betas=(0.9, 0.999))

Once the optimizer is initialized, call optG.stepLookAhead() at the appropriate point in your training loop to update the network weights using the prediction method, and optG.restoreStepLookAhead() afterwards to restore the non-prediction weights.
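To make the pairing of these two calls concrete, the following is a minimal sketch of the look-ahead bookkeeping on a toy bilinear game \min_x \max_y xy, whose saddle point is (0, 0). PredictionSGD, train_bilinear and their method names are hypothetical. They are loosely modeled on stepLookAhead() and restoreStepLookAhead(), take explicit gradients, and use plain gradient descent instead of Adam, so this illustrates the idea rather than reproducing AdamPre.

```python
from math import hypot
from typing import List


class PredictionSGD:
    """Hypothetical gradient-descent optimizer with a prediction (look-ahead) step.

    Loosely modeled on AdamPre's stepLookAhead()/restoreStepLookAhead();
    plain gradient descent replaces Adam to keep the arithmetic simple.
    """

    def __init__(self, params: List[float], lr: float) -> None:
        self.params = params        # mutated in place, like network weights
        self.lr = lr
        self.prev = list(params)    # weights before the most recent step()
        self.saved = list(params)   # weights before the most recent look-ahead

    def step(self, grads: List[float]) -> None:
        """Ordinary descent step: p <- p - lr * grad."""
        self.prev = list(self.params)
        for i, g in enumerate(grads):
            self.params[i] -= self.lr * g

    def step_look_ahead(self) -> None:
        """Replace the weights by their prediction p + (p - p_prev)."""
        self.saved = list(self.params)
        for i, (p, q) in enumerate(zip(self.params, self.prev)):
            self.params[i] = 2.0 * p - q

    def restore_step_look_ahead(self) -> None:
        """Undo step_look_ahead(), returning to the non-prediction weights."""
        self.params[:] = self.saved


def train_bilinear(use_prediction: bool, lr: float = 0.1, iters: int = 2000) -> float:
    """Run alternating updates on min_x max_y x*y; return the distance to (0, 0)."""
    x, y = [1.0], [1.0]
    opt_x = PredictionSGD(x, lr)    # minimizing player (plays the role of optG)
    opt_y = PredictionSGD(y, lr)    # maximizing player
    for _ in range(iters):
        opt_x.step([y[0]])          # d(xy)/dx = y
        if use_prediction:
            opt_x.step_look_ahead() # the y update below sees the predicted x
        opt_y.step([-x[0]])         # ascent on xy == descent along -d(xy)/dy = -x
        if use_prediction:
            opt_x.restore_step_look_ahead()
    return hypot(x[0], y[0])


if __name__ == "__main__":
    print("with prediction:   ", train_bilinear(use_prediction=True))
    print("without prediction:", train_bilinear(use_prediction=False))
```

In this toy setting, the iterates with prediction spiral into the saddle point, while the plain alternating updates keep circling it at a roughly constant distance. Restoring the weights right after the other player's update means the predicted value is only used to compute that update; the stored trajectory of x is still the ordinary gradient-descent trajectory.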

See the sample code for the mixture-of-Gaussians (MoG) experiment for the finer details. This code also reproduces the results in Figure 8 of the paper. To start training from the console, run:

$ python main.py --pdhgGLookAhead --cuda --outf results/ --manualSeed 6162 --plotLoss --plotRealData

[Figure: Gaussian]

Other Implementations