forked from junyanz/pytorch-CycleGAN-and-pix2pix
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
110 additions
and
70 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,40 +1,66 @@ | ||
"""General-purpose test script for image-to-image translation. | ||
Once you have trained your model with train.py, you can use this script to test the model. | ||
It will load a saved model from --checkpoints_dir and save the results to --results_dir. | ||
It first creates model and dataset given the option. It will hard-code some parameters. | ||
It then runs inference for --num_test images and save results to an HTML file. | ||
Example (You need to train models first or download pre-trained models from our website): | ||
Test a CycleGAN model (both sides): | ||
python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan | ||
Test a CycleGAN model (one side only): | ||
python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout | ||
The option '--model test' is used for generating CycleGAN results only for one side. | ||
This option will automatically set '--dataset_mode single', which only loads the images from one set. | ||
On the contrary, using '--model cycle_gan' requires loading and generating results in both directions, | ||
which is sometimes unnecessary. The results will be saved at ./results/. | ||
Use '--results_dir <directory_path_to_save_result>' to specify the results directory. | ||
Test a pix2pix model: | ||
python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA | ||
See options/base_options.py and options/test_options.py for more test options. | ||
See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md | ||
See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md | ||
""" | ||
import os | ||
from options.test_options import TestOptions | ||
from data import create_dataloader | ||
from data import create_dataset | ||
from models import create_model | ||
from util.visualizer import save_images | ||
from util import html | ||
|
||
|
||
if __name__ == '__main__': | ||
opt = TestOptions().parse() | ||
opt = TestOptions().parse() # get test options | ||
# hard-code some parameters for test | ||
opt.num_threads = 1 # test code only supports num_threads = 1 | ||
opt.batch_size = 1 # test code only supports batch_size = 1 | ||
opt.serial_batches = True # no shuffle | ||
opt.no_flip = True # no flip | ||
opt.display_id = -1 # no visdom display | ||
data_loader = create_dataloader(opt) | ||
dataset = data_loader.load_data() | ||
model = create_model(opt) | ||
model.setup(opt) | ||
opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. | ||
opt.no_flip = True # no flip; comment this line if results on flipped images are needed. | ||
opt.display_id = -1 # no visdom display; the test code saves the results to a HTML file. | ||
dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options | ||
model = create_model(opt) # create a model given opt.model and other options | ||
model.setup(opt) # regular setup: load and print networks; create schedulers | ||
# create a website | ||
web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.epoch)) | ||
web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.epoch)) # define the website directory | ||
webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch)) | ||
# test with eval mode. This only affects layers like batchnorm and dropout. | ||
# pix2pix: we use batchnorm and dropout in the original pix2pix. You can experiment it with and without eval() mode. | ||
# CycleGAN: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout. | ||
# For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment it with and without eval() mode. | ||
# For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout. | ||
if opt.eval: | ||
model.eval() | ||
for i, data in enumerate(dataset): | ||
if i >= opt.num_test: | ||
if i >= opt.num_test: # only apply our model to opt.num_test images. | ||
break | ||
model.set_input(data) | ||
model.test() | ||
visuals = model.get_current_visuals() | ||
img_path = model.get_image_paths() | ||
if i % 5 == 0: | ||
model.set_input(data) # unpack data from data loader | ||
model.test() # run inference | ||
visuals = model.get_current_visuals() # get image results | ||
img_path = model.get_image_paths() # get image paths | ||
if i % 5 == 0: # save images to an HTML file | ||
print('processing (%04d)-th image... %s' % (i, img_path)) | ||
save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize) | ||
# save the website | ||
webpage.save() | ||
webpage.save() # save the HTML |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,59 +1,77 @@ | ||
"""General-purpose training script for image-to-image translation. | ||
This script works for various models (with option '--model': e.g., pix2pix, cyclegan, colorization) and | ||
different datasets (with option '--dataset_mode': e.g., aligned, unaligned, `single, colorization). | ||
You need to specify the dataset ('--dataroot'), experiment name ('--name'), and model ('--model'). | ||
It first creates model, dataset, and visualizer given the option. | ||
It then does standard network training. During the training, it also visualize/save the images, print/save the loss plot, and save models. | ||
The script supports continue/resume training. Use '--continue_train' to resume your previous training. | ||
Example: | ||
Train a CycleGAN model: | ||
python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan | ||
Train a pix2pix model: | ||
python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA | ||
See options/base_options.py and options/train_options.py for more training options. | ||
See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md | ||
See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md | ||
""" | ||
import time | ||
from options.train_options import TrainOptions | ||
from data import create_dataloader | ||
from data import create_dataset | ||
from models import create_model | ||
from util.visualizer import Visualizer | ||
|
||
if __name__ == '__main__': | ||
opt = TrainOptions().parse() | ||
data_loader = create_dataloader(opt) | ||
dataset = data_loader.load_data() | ||
dataset_size = len(data_loader) | ||
print('#training images = %d' % dataset_size) | ||
|
||
model = create_model(opt) | ||
model.setup(opt) | ||
visualizer = Visualizer(opt) | ||
total_steps = 0 | ||
|
||
for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): | ||
epoch_start_time = time.time() | ||
iter_data_time = time.time() | ||
epoch_iter = 0 | ||
|
||
for i, data in enumerate(dataset): | ||
iter_start_time = time.time() | ||
if total_steps % opt.print_freq == 0: | ||
opt = TrainOptions().parse() # get training options | ||
dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options | ||
dataset_size = len(dataset) # get the number of images in the dataset. | ||
print('The number of training images = %d' % dataset_size) | ||
|
||
model = create_model(opt) # create a model given opt.model and other options | ||
model.setup(opt) # regular setup: load and print networks; create schedulers | ||
visualizer = Visualizer(opt) # create a visualizer that display/save images and plots | ||
total_iters = 0 # the total number of training iterations | ||
|
||
for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq> | ||
epoch_start_time = time.time() # timer for entire epoch | ||
iter_data_time = time.time() # timer for data loading per iteration | ||
epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch | ||
|
||
for i, data in enumerate(dataset): # inner loop within one epoch | ||
iter_start_time = time.time() # timer for computation per iteration | ||
if total_iters % opt.print_freq == 0: | ||
t_data = iter_start_time - iter_data_time | ||
visualizer.reset() | ||
total_steps += opt.batch_size | ||
total_iters += opt.batch_size | ||
epoch_iter += opt.batch_size | ||
model.set_input(data) | ||
model.optimize_parameters() | ||
model.set_input(data) # unpack data from dataset and apply preprocessing | ||
model.optimize_parameters() # calculate loss functions, get gradients, update network weights | ||
|
||
if total_steps % opt.display_freq == 0: | ||
save_result = total_steps % opt.update_html_freq == 0 | ||
if total_iters % opt.display_freq == 0: # display images on visdom and save images to a HTML file | ||
save_result = total_iters % opt.update_html_freq == 0 | ||
model.compute_visuals() | ||
visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) | ||
|
||
if total_steps % opt.print_freq == 0: | ||
if total_iters % opt.print_freq == 0: # print training losses and save logging information to the disk | ||
losses = model.get_current_losses() | ||
t = (time.time() - iter_start_time) / opt.batch_size | ||
visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data) | ||
t_comp = (time.time() - iter_start_time) / opt.batch_size | ||
visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data) | ||
if opt.display_id > 0: | ||
visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, opt, losses) | ||
|
||
if total_steps % opt.save_latest_freq == 0: | ||
print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps)) | ||
save_suffix = 'iter_%d' % total_steps if opt.save_by_iter else 'latest' | ||
if total_iters % opt.save_latest_freq == 0: # cache our latest model every <save_latest_freq> iterations | ||
print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters)) | ||
save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest' | ||
model.save_networks(save_suffix) | ||
|
||
iter_data_time = time.time() | ||
if epoch % opt.save_epoch_freq == 0: | ||
print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) | ||
if epoch % opt.save_epoch_freq == 0: # cache our model every <save_epoch_freq> epochs | ||
print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters)) | ||
model.save_networks('latest') | ||
model.save_networks(epoch) | ||
|
||
print('End of epoch %d / %d \t Time Taken: %d sec' % | ||
(epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) | ||
model.update_learning_rate() | ||
print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) | ||
model.update_learning_rate() # update learning rates at the end of every epoch. |