add models documentation
junyanz committed Jan 2, 2019
1 parent 92d4afb commit dc25d96
Showing 7 changed files with 34 additions and 32 deletions.
2 changes: 1 addition & 1 deletion data/__init__.py
@@ -40,7 +40,7 @@ def create_dataset(opt):
return instance


def CreateDataLoader(opt):
def create_dataloader(opt):
"""Create dataloader given the option.
This function wraps the function create_dataset.
2 changes: 1 addition & 1 deletion data/single_dataset.py
@@ -10,7 +10,7 @@ def modify_commandline_options(parser, is_train):

def __init__(self, opt):
BaseDataset.__init__(self, opt)
self.A_paths = sorted(make_dataset(self.root))
self.A_paths = sorted(make_dataset(opt.dataroot))
input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
self.transform = get_transform(opt, input_nc == 1)

46 changes: 24 additions & 22 deletions docs/overview.md
@@ -1,31 +1,33 @@
## Overview of Code Structure
We give a brief overview of each directory and each file. Please see the documentation in each file for more details. If you have questions, you may find useful information in [training/test tips](tips.md) and [frequently asked questions](qa.md).
We give a brief overview of the functionality and implementation of each package and module. Please see the documentation in each file for more details. If you have questions, you may find useful information in [training/test tips](tips.md) and [frequently asked questions](qa.md).

[train.py](../train.py) is a general-purpose training script. It works for various models (with option `--model`: e.g., `pix2pix`, `cyclegan`, `colorization`) and different datasets (with option `--dataset_mode`: e.g., `aligned`, `unaligned`, `single`, `colorization`). See the main [README](.../README.md) and Training/test [tips](tips.md) for more details.
[train.py](../train.py) is a general-purpose training script. It works for various models (with option `--model`: e.g., `pix2pix`, `cyclegan`, `colorization`) and different datasets (with option `--dataset_mode`: e.g., `aligned`, `unaligned`, `single`, `colorization`). See the main [README](../README.md) and [training/test tips](tips.md) for more details.

[test.py](../test.py) is a general-purpose test script. Once you have trained your model with `train.py`, you can use this script to test the model. It will load a saved model from `--checkpoints_dir` and save the results to `--results_dir`. See the main [README](.../README.md) and Training/test [tips](tips.md) for more details.
[test.py](../test.py) is a general-purpose test script. Once you have trained your model with `train.py`, you can use this script to test the model. It will load a saved model from `--checkpoints_dir` and save the results to `--results_dir`. See the main [README](../README.md) and [training/test tips](tips.md) for more details.


[data](../data) directory contains all the modules related to data loading and data preprocessing.
* [\_\_init\_\_.py](../data/__init__.py) implements the interface between this package and training/test script. In the `train.py` and `test.py`, we call `from data import CreateDataLoader` and `data_loader = CreateDataLoader(opt)` to create a dataloader given the option `opt`.
* [base_dataset.py](../data/base_dataset.py) implements an abstract base class for datasets. It also includes common transformation functions `get_transform` and `get_simple_transform` which can be used in subclasses. To add a custom dataset class called `dummy`, you need to add a file called `dummy_dataset.py` and define a subclass `DummyDataset` inherited from `BaseDataset`. You need to implement four functions: `name`, `__len__`, `__getitem__`, and optionally `modify_commandline_options`. You can then use this dataset class by specifying flag `--dataset_mode dummy`.
[data](../data) directory contains all the modules related to data loading and preprocessing. To add a custom dataset class called `dummy`, you need to add a file called `dummy_dataset.py` and define a subclass `DummyDataset` that inherits from `BaseDataset`. You need to implement three functions: `__init__` (initialize the class; you need to first call `BaseDataset.__init__(self, opt)`), `__len__` (return the size of the dataset), and `__getitem__` (get a data point). Optionally, you can implement `modify_commandline_options` (add dataset-specific options and set default options). Now you can use the dataset class by specifying the flag `--dataset_mode dummy`. Below, we explain each file in detail.
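Before going file by file, here is a minimal sketch of such a `dummy_dataset.py`, assuming the helpers described below (`BaseDataset`, `get_transform`, `make_dataset`); the class name, the extra option, and the exact helper signatures are illustrative assumptions, not code from this commit:

```python
# Hypothetical data/dummy_dataset.py -- a sketch only; exact helper signatures
# (e.g., get_transform) may differ slightly in this commit.
from PIL import Image

from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset


class DummyDataset(BaseDataset):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        # optionally add dataset-specific options or change existing defaults
        parser.add_argument('--dummy_option', type=float, default=1.0, help='hypothetical dataset-specific option')
        return parser

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)                     # always call the base initializer first
        self.A_paths = sorted(make_dataset(opt.dataroot))   # collect image paths under --dataroot
        self.transform = get_transform(opt)                 # common resize/crop/normalize pipeline

    def __getitem__(self, index):
        path = self.A_paths[index]
        image = Image.open(path).convert('RGB')
        return {'A': self.transform(image), 'A_paths': path}

    def __len__(self):
        return len(self.A_paths)
```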

* [\_\_init\_\_.py](../data/__init__.py) implements the interface between this package and the training and test scripts. `train.py` and `test.py` call `from data import create_dataloader` and `data_loader = create_dataloader(opt)` to create a data loader given the option `opt` (see the usage sketch after this list).
* [base_dataset.py](../data/base_dataset.py) implements an abstract base class ([ABC](https://docs.python.org/3/library/abc.html)) for datasets. It also includes common transformation functions (e.g., `get_transform`, `__scale_width`), which can later be used in subclasses.
* [image_folder.py](../data/image_folder.py) implements an image folder class. We modify the official PyTorch image folder [code](https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) so that this class can load images from both the current directory and its subdirectories.
* [template_dataset.py](../data/template_dataset.py) provides a dataset class template with detailed documentation. Check out this file if you plan to implement your own dataset class.
* [aligned_dataset.py](../data/aligned_dataset.py) includes a dataset class that can load aligned image pairs. It assumes a single image directory `/path/to/data/train`, which contains image pairs in the form of {A,B}. See [here](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md#prepare-your-own-datasets-for-pix2pix) on how to prepare aligned datasets. During test time, you need to prepare a directory `/path/to/data/test` for test data.
* [unaligned_dataset.py](../data/unaligned_dataset.py) includes a dataset class that can load unaligned/unpaired datasets. It assumes that two directories to host training images from domain A `/path/to/data/trainA` and from domain B `/path/to/data/trainB` separately. Then you can train the model with the dataset flag `--dataroot /path/to/data`. Similarly, you need to prepare two directories `/path/to/data/testA` and `/path/to/data/testB` during test time.
* [single_dataset.py](../data/single_dataset.py) includes a dataset class that can load a set of single images. It is used in `test.py` when only model in one direction is being tested. The option `--model test` is used for generating CycleGAN results only for one side. This option will automatically set `--dataset_mode single`.
* [colorization_dataset.py](../data/colorization_dataset.py) implements a dataset class that can load a set of nature images in RGB, and convert RGB format into (L, ab) pairs. It is required by pix2pix-based colorization model (`--model colorization`).


[models](../models) directory contains core modules related to objective functions, optimizations, and network architectures.
* [\_\_init\_\_.py](../models/__init__.py)
* [base_model.py](../models/base_model.py)
* [template_model.py](../models/template_model.py)
* [pix2pix_model.py](../models/pix2pix_model.py)
* [colorization_model.py](../models/colorization_model.py)
* [cycle_gan_model.py](../models/cycle_gan_model.py)
* [networks.py](../models/networks.py) module implements network architectures (both generators and discriminators), as well as normalization layers, initialization, optimization scheduler (learning rate policy), and GAN loss function.
* [test_model.py](../models/test_model.py)
* [aligned_dataset.py](../data/aligned_dataset.py) includes a dataset class that can load image pairs. It assumes a single image directory `/path/to/data/train`, which contains image pairs in the form of {A,B}. See [here](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md#prepare-your-own-datasets-for-pix2pix) on how to prepare aligned datasets. During test time, you need to prepare a directory `/path/to/data/test` as test data.
* [unaligned_dataset.py](../data/unaligned_dataset.py) includes a dataset class that can load unaligned/unpaired datasets. It requires two directories that host training images from domain A (`/path/to/data/trainA`) and from domain B (`/path/to/data/trainB`) respectively. Then you can train the model with the dataset flag `--dataroot /path/to/data`. Similarly, you need to prepare two directories `/path/to/data/testA` and `/path/to/data/testB` during test time.
* [single_dataset.py](../data/single_dataset.py) includes a dataset class that can load a set of single images specified by the path `--dataroot /path/to/data`. It can be used for generating CycleGAN results for only one side with the model option `--model test`.
* [colorization_dataset.py](../data/colorization_dataset.py) implements a dataset class that can load a set of natural images in RGB and convert them into (L, ab) pairs in [Lab](https://en.wikipedia.org/wiki/CIELAB_color_space) color space. It is required by the pix2pix-based colorization model (`--model colorization`).
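The sketch below shows how these dataset classes are typically consumed, mirroring the `create_dataloader`/`load_data` calls in the `train.py` excerpt later in this diff; the batch keys (`'A'`, `'B'`) are an assumption that holds for the aligned dataset:

```python
# Minimal data-loading sketch; mirrors the calls used by train.py in this commit.
from options.train_options import TrainOptions
from data import create_dataloader

opt = TrainOptions().parse()             # e.g. --dataroot /path/to/data --dataset_mode aligned
data_loader = create_dataloader(opt)     # instantiates data/<dataset_mode>_dataset.py
dataset = data_loader.load_data()        # iterable over preprocessed batches
print('#training images = %d' % len(data_loader))

for data in dataset:
    # each batch is the dict returned by the dataset's __getitem__,
    # e.g. data['A'] and data['B'] for an aligned dataset
    pass
```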


[models](../models) directory contains modules related to objective functions, optimizations, and network architectures. To add a custom model class called `dummy`, you need to add a file called `dummy_model.py` and define a subclass `DummyModel` that inherits from `BaseModel`. You need to implement four functions: `__init__` (initialize the class; you need to first call `BaseModel.__init__(self, opt)`), `set_input` (unpack data from the data loader and apply preprocessing), `forward` (generate intermediate results), and `optimize_parameters` (calculate losses, gradients, and update network weights). Optionally, you can implement `modify_commandline_options` (add model-specific options and set default options). Now you can use the model class by specifying the flag `--model dummy`. Below, we explain each file in detail.
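Before going file by file, here is a minimal sketch of such a `dummy_model.py`, modeled on the template model described below; the class name, the L1 objective, and the `networks.define_G` call are illustrative assumptions, not code from this commit:

```python
# Hypothetical models/dummy_model.py -- a sketch only; network and loss choices
# are illustrative, and helper signatures (e.g., networks.define_G) may differ.
import torch

from models.base_model import BaseModel
from models import networks


class DummyModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        # optionally add model-specific options or change existing defaults
        if is_train:
            parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the L1 loss')
        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)           # always call the base initializer first
        self.loss_names = ['G']                 # losses reported during training
        self.model_names = ['G']                # networks saved/loaded by BaseModel helpers
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
                                      gpu_ids=self.gpu_ids)
        if self.isTrain:
            self.criterionL1 = torch.nn.L1Loss()
            self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr)
            self.optimizers = [self.optimizer]  # setup() attaches schedulers to these

    def set_input(self, input):
        # unpack a batch from the data loader and move it to the active device
        self.data_A = input['A'].to(self.device)
        self.data_B = input['B'].to(self.device)

    def forward(self):
        self.output = self.netG(self.data_A)    # generate intermediate results

    def optimize_parameters(self):
        self.forward()
        self.loss_G = self.criterionL1(self.output, self.data_B) * self.opt.lambda_regression
        self.optimizer.zero_grad()
        self.loss_G.backward()
        self.optimizer.step()
```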

* [\_\_init\_\_.py](../models/__init__.py) implements the interface between this package and the training and test scripts. `train.py` and `test.py` call `from models import create_model` and `model = create_model(opt)` to create a model given the option `opt`. You also need to call `model.setup(opt)` to initialize the model (see the usage sketch after this list).
* [base_model.py](../models/base_model.py) implements an abstract base class ([ABC](https://docs.python.org/3/library/abc.html)) for models. It also includes helper functions (e.g., `setup`, `test`, `update_learning_rate`), which can later be used in subclasses.
* [template_model.py](../models/template_model.py) provides a model class template with detailed documentation. Check out this file if you plan to implement your own model class.
* [pix2pix_model.py](../models/pix2pix_model.py) implements the pix2pix [model](https://phillipi.github.io/pix2pix/), for learning a mapping from input images to output images given paired data. The model training requires the `--dataset_mode aligned` dataset. By default, it uses a `--netG unet_256` [U-Net](https://arxiv.org/pdf/1505.04597.pdf) generator, a `--netD basic` discriminator (PatchGAN), and a `--gan_mode vanilla` GAN loss (the standard cross-entropy objective).
* [colorization_model.py](../models/colorization_model.py) implements a subclass of `Pix2PixModel` for image colorization. The model training requires the `--dataset_mode colorization` dataset. It trains a pix2pix model that maps the L channel to the ab channels in Lab color space. By default, it uses `--input_nc 1` and `--output_nc 2`.
* [cycle_gan_model.py](../models/cycle_gan_model.py) implements the CycleGAN [model](https://junyanz.github.io/CycleGAN/), for learning image-to-image translation without paired data. The model training requires the `--dataset_mode unaligned` dataset. By default, it uses a `--netG resnet_9blocks` ResNet generator, a `--netD basic` discriminator (the PatchGAN introduced by pix2pix), and a least-squares GAN [objective](https://arxiv.org/abs/1611.04076) (`--gan_mode lsgan`).
* [networks.py](../models/networks.py) module implements network architectures (both generators and discriminators), as well as normalization layers, initialization methods, optimization scheduler (i.e., learning rate policy), and GAN loss function.
* [test_model.py](../models/test_model.py) implements a model that can be used to generate CycleGAN results in only one direction. Using `--model test` automatically sets `--dataset_mode single`, which only loads images from one set. See the test [instruction](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix#apply-a-pre-trained-model-cyclegan) for more details.
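Putting the data and model packages together, a condensed driver sketch (modeled on the `train.py` excerpt later in this diff; the real script adds epoch bookkeeping, the `Visualizer`, and checkpoint saving):

```python
# Condensed training-driver sketch; epochs, visualization, and checkpointing omitted.
from options.train_options import TrainOptions
from data import create_dataloader
from models import create_model

opt = TrainOptions().parse()              # e.g. --model pix2pix --dataset_mode aligned
dataset = create_dataloader(opt).load_data()
model = create_model(opt)                 # instantiates models/<model>_model.py
model.setup(opt)                          # print/load networks and create schedulers

for data in dataset:                      # one pass over the training set
    model.set_input(data)                 # unpack the batch
    model.optimize_parameters()           # forward, losses, backward, optimizer steps
model.update_learning_rate()              # step the learning-rate schedulers (once per epoch)
```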

[options](../options) directory includes our option modules: training options, test options and basic options (used in both training and test).
* [\_\_init\_\_.py](../options/__init__.py) an empty file to make the `options` directory a package.
@@ -34,7 +36,7 @@
* [test_options.py](../options/test_options.py) includes options that are only used in test time.


[util](../util) directory includes a misc collection of useful utility functions.
[util](../util) directory includes a miscellaneous collection of useful helper functions.
* [\_\_init\_\_.py](../util/__init__.py): an empty file to make the `util` directory a package.
* [get_data.py](../util/get_data.py)
* [html.py](../util/html.py)
4 changes: 2 additions & 2 deletions models/template_model.py
@@ -9,7 +9,7 @@
Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss:
min_<netG> ||netG(data_A) - data_B||_1
You need to implement the following functions:
<modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
<modify_commandline_options>: Add model-specific options and rewrite default values for existing options.
<__init__>: Initialize this model class.
<set_input>: Unpack input data and perform data pre-processing.
<forward>: Run forward pass. This will be called by both <optimize_parameters> and <test>.
@@ -23,7 +23,7 @@
class TemplateModel(BaseModel):
@staticmethod
def modify_commandline_options(parser, is_train=True):
"""Add new dataset-specific options and rewrite default values for existing options.
"""Add new model-specific options and rewrite default values for existing options.
Parameters:
parser -- the option parser
4 changes: 2 additions & 2 deletions options/base_options.py
@@ -20,8 +20,8 @@ def initialize(self, parser):
parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
parser.add_argument('--netD', type=str, default='basic', help='selects model to use for netD [basic | n_layers | pixel]')
parser.add_argument('--netG', type=str, default='resnet_9blocks', help='selects model to use for netG [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')
parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')
parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
4 changes: 2 additions & 2 deletions test.py
@@ -1,6 +1,6 @@
import os
from options.test_options import TestOptions
from data import CreateDataLoader
from data import create_dataloader
from models import create_model
from util.visualizer import save_images
from util import html
@@ -14,7 +14,7 @@
opt.serial_batches = True # no shuffle
opt.no_flip = True # no flip
opt.display_id = -1 # no visdom display
data_loader = CreateDataLoader(opt)
data_loader = create_dataloader(opt)
dataset = data_loader.load_data()
model = create_model(opt)
model.setup(opt)
4 changes: 2 additions & 2 deletions train.py
@@ -1,12 +1,12 @@
import time
from options.train_options import TrainOptions
from data import CreateDataLoader
from data import create_dataloader
from models import create_model
from util.visualizer import Visualizer

if __name__ == '__main__':
opt = TrainOptions().parse()
data_loader = CreateDataLoader(opt)
data_loader = create_dataloader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
print('#training images = %d' % dataset_size)