Skip to content

Commit

Permalink
Examples dcgan (pytorch#464)
Browse files Browse the repository at this point in the history
* mnist added dcgan

* mnist added
  • Loading branch information
surgan12 authored and soumith committed Dec 9, 2018
1 parent 6d08877 commit 64f829c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
Binary file added dcgan/.swp
Binary file not shown.
18 changes: 16 additions & 2 deletions dcgan/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=True, help='cifar10 | lsun | imagenet | folder | lfw | fake')
parser.add_argument('--dataset', required=True, help='cifar10 | lsun | mnist |imagenet | folder | lfw | fake')
parser.add_argument('--dataroot', required=True, help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
Expand Down Expand Up @@ -60,6 +60,7 @@
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]))
nc=3
elif opt.dataset == 'lsun':
dataset = dset.LSUN(root=opt.dataroot, classes=['bedroom_train'],
transform=transforms.Compose([
Expand All @@ -68,16 +69,30 @@
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]))
nc=3
elif opt.dataset == 'cifar10':
dataset = dset.CIFAR10(root=opt.dataroot, download=True,
transform=transforms.Compose([
transforms.Resize(opt.imageSize),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]))
nc=3

elif opt.dataset == 'mnist':
dataset = dset.MNIST(root=opt.dataroot, download=True,
transform=transforms.Compose([
transforms.Resize(opt.imageSize),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)),
]))
nc=1

elif opt.dataset == 'fake':
dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),
transform=transforms.ToTensor())
nc=3

assert dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
shuffle=True, num_workers=int(opt.workers))
Expand All @@ -87,7 +102,6 @@
nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)
nc = 3


# custom weights initialization called on netG and netD
Expand Down

0 comments on commit 64f829c

Please sign in to comment.