Skip to content

Commit

Permalink
update dataset mode
Browse files Browse the repository at this point in the history
  • Loading branch information
junyanz committed Jun 13, 2017
1 parent 3b72a65 commit e6858e3
Show file tree
Hide file tree
Showing 21 changed files with 259 additions and 245 deletions.
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ In CVPR 2017.
## Getting Started
### Installation
- Install PyTorch and dependencies from http://pytorch.org/
- Install Torch vision from the source.
```bash
git clone https://github.com/pytorch/vision
cd vision
python setup.py install
```
- Install python libraries [visdom](https://github.com/facebookresearch/visdom) and [dominate](https://github.com/Knio/dominate).
```bash
pip install visdom
Expand Down Expand Up @@ -81,13 +87,13 @@ bash ./datasets/download_pix2pix_dataset.sh facades
- Train a model:
```bash
#!./scripts/train_pix2pix.sh
python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --lambda_A 100 --align_data --use_dropout --no_lsgan
python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --lambda_A 100 --dataset_mode aligned --use_dropout --no_lsgan
```
- To view training results and loss plots, run `python -m visdom.server` and click the URL http://localhost:8097. To see more intermediate results, check out `./checkpoints/facades_pix2pix/web/index.html`
- Test the model (`bash ./scripts/test_pix2pix.sh`):
```bash
#!./scripts/test_pix2pix.sh
python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --align_data
python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --dataset_mode aligned
```
The test results will be saved to a html file here: `./results/facades_pix2pix/latest_val/index.html`.

Expand Down
87 changes: 0 additions & 87 deletions data/aligned_data_loader.py

This file was deleted.

56 changes: 56 additions & 0 deletions data/aligned_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import os.path
import random
import torchvision.transforms as transforms
import torch
from data.base_dataset import BaseDataset
from data.image_folder import make_dataset
from PIL import Image


class AlignedDataset(BaseDataset):
def initialize(self, opt):
self.opt = opt
self.root = opt.dataroot
self.dir_AB = os.path.join(opt.dataroot, opt.phase)

self.AB_paths = sorted(make_dataset(self.dir_AB))

assert(opt.resize_or_crop == 'resize_and_crop')

transform_list = [transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]

self.transform = transforms.Compose(transform_list)

def __getitem__(self, index):
AB_path = self.AB_paths[index]
AB = Image.open(AB_path).convert('RGB')
AB = AB.resize((self.opt.loadSize * 2, self.opt.loadSize), Image.BICUBIC)
AB = self.transform(AB)

w_total = AB.size(2)
w = int(w_total / 2)
h = AB.size(1)
w_offset = random.randint(0, max(0, w - self.opt.fineSize - 1))
h_offset = random.randint(0, max(0, h - self.opt.fineSize - 1))

A = AB[:, h_offset:h_offset + self.opt.fineSize,
w_offset:w_offset + self.opt.fineSize]
B = AB[:, h_offset:h_offset + self.opt.fineSize,
w + w_offset:w + w_offset + self.opt.fineSize]

if (not self.opt.no_flip) and random.random() < 0.5:
idx = [i for i in range(A.size(2) - 1, -1, -1)]
idx = torch.LongTensor(idx)
A = A.index_select(2, idx)
B = B.index_select(2, idx)

return {'A': A, 'B': B,
'A_paths': AB_path, 'B_paths': AB_path}

def __len__(self):
return len(self.AB_paths)

def name(self):
return 'AlignedDataset'
14 changes: 0 additions & 14 deletions data/base_data_loader.py

This file was deleted.

12 changes: 12 additions & 0 deletions data/base_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import torch.utils.data as data

class BaseDataset(data.Dataset):
def __init__(self):
super(BaseDataset, self).__init__()

def name(self):
return 'BaseDataset'

def initialize(self, opt):
pass

41 changes: 41 additions & 0 deletions data/custom_dataset_data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import torch.utils.data
from data.base_data_loader import BaseDataLoader


def CreateDataset(opt):
dataset = None
if opt.dataset_mode == 'aligned':
from data.aligned_dataset import AlignedDataset
dataset = AlignedDataset()
elif opt.dataset_mode == 'unaligned':
from data.unaligned_dataset import UnalignedDataset
dataset = UnalignedDataset()
elif opt.dataset_mode == 'single':
from data.single_dataset import SingleDataset
dataset = SingleDataset()
else:
raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)

print("dataset [%s] was created" % (dataset.name()))
dataset.initialize(opt)
return dataset


class CustomDatasetDataLoader(BaseDataLoader):
def name(self):
return 'CustomDatasetDataLoader'

def initialize(self, opt):
BaseDataLoader.initialize(self, opt)
self.dataset = CreateDataset(opt)
self.dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=opt.batchSize,
shuffle=not opt.serial_batches,
num_workers=int(opt.nThreads))

def load_data(self):
return self.dataloader

def __len__(self):
return min(len(self.dataset), self.opt.max_dataset_size)
9 changes: 2 additions & 7 deletions data/data_loader.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@

def CreateDataLoader(opt):
data_loader = None
if opt.align_data > 0:
from data.aligned_data_loader import AlignedDataLoader
data_loader = AlignedDataLoader()
else:
from data.unaligned_data_loader import UnalignedDataLoader
data_loader = UnalignedDataLoader()
from data.custom_dataset_data_loader import CustomDatasetDataLoader
data_loader = CustomDatasetDataLoader()
print(data_loader.name())
data_loader.initialize(opt)
return data_loader
7 changes: 4 additions & 3 deletions data/image_folder.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
################################################################################
###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as the subdirectories
################################################################################
###############################################################################

import torch.utils.data as data

Expand Down Expand Up @@ -45,7 +45,8 @@ def __init__(self, root, transform=None, return_paths=False,
imgs = make_dataset(root)
if len(imgs) == 0:
raise(RuntimeError("Found 0 images in: " + root + "\n"
"Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
"Supported image extensions are: " +
",".join(IMG_EXTENSIONS)))

self.root = root
self.imgs = imgs
Expand Down
47 changes: 47 additions & 0 deletions data/single_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import os.path
import torchvision.transforms as transforms
from data.base_dataset import BaseDataset
from data.image_folder import make_dataset
from PIL import Image


class SingleDataset(BaseDataset):
def initialize(self, opt):
self.opt = opt
self.root = opt.dataroot
self.dir_A = os.path.join(opt.dataroot)

self.A_paths = make_dataset(self.dir_A)

self.A_paths = sorted(self.A_paths)

transform_list = []
if opt.resize_or_crop == 'resize_and_crop':
transform_list.append(transforms.Scale(opt.loadSize))

if opt.isTrain and not opt.no_flip:
transform_list.append(transforms.RandomHorizontalFlip())

if opt.resize_or_crop != 'no_resize':
transform_list.append(transforms.RandomCrop(opt.fineSize))

transform_list += [transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]

self.transform = transforms.Compose(transform_list)

def __getitem__(self, index):
A_path = self.A_paths[index]

A_img = Image.open(A_path).convert('RGB')

A_img = self.transform(A_img)

return {'A': A_img, 'A_paths': A_path}

def __len__(self):
return len(self.A_paths)

def name(self):
return 'SingleImageDataset'
Loading

0 comments on commit e6858e3

Please sign in to comment.