-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathimage_inpainting.py
67 lines (53 loc) · 4.12 KB
/
image_inpainting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from inpainter.model import Inpainter
import torchvision.transforms as transforms
from pathlib import Path
import argparse
from libs.data_retriever import InpaintDataset
from libs.utils import str2bool
# Command-line interface for the script. `main` receives the namespace returned
# by `parser.parse_known_args()`; argparse derives each attribute name from the
# long option by replacing dashes with underscores
# (e.g. --pretrained-outside -> FLAGS.pretrained_outside).
parser = argparse.ArgumentParser()
parser.add_argument('-b', '--batch-size', default=8, type=int, help='Size of the batch')
parser.add_argument('-e', '--epochs', default=100, type=int, help='Number of epochs')
parser.add_argument('-lr', '--learning-rate', default=5e-5, type=float, help='Initial learning rate')
parser.add_argument('-w', '--write-per-epoch', default=10, type=int, help='Times to write per epoch')
parser.add_argument('-train', '--training-dir', default='Datasets/Paris/paris_train_original', type=str, help='Path for training samples')
parser.add_argument('-test', '--testing-dir', default='Datasets/Paris/paris_eval_gt', type=str, help='Path for testing samples')
# NOTE(review): `test_samples` is not read by main() in this file — confirm
# whether anything else consumes it before relying on this option.
parser.add_argument('-test-samples', '--test-samples', default=2, type=int, help='Number of generated samples for testing')
parser.add_argument('-s', '--image-size', default=400, type=int, help='Size of the images (squared)')
# Only honored in train mode; main() always restores the checkpoint when testing.
parser.add_argument('-r', '--restore-check', default=True, type=str2bool, help='Restore last checkpoint in folder --checkpoint-dir')
parser.add_argument('-c', '--checkpoint-dir', default='inpainter/weights', help='Checkpoint directory')
# main() treats the value 'train' as training; any other value runs evaluation.
parser.add_argument('-m', '--mode', default='test', help='Mode: train or test')
parser.add_argument('-ie', '--initial-epoch', default=0, type=int, help='Initial epoch')
parser.add_argument('-config', '--config_file', default='config.yml', type=str, help='Path for config file')
# nargs='?' with const=True: a bare `-f` / `-po` sets True, and an explicit value
# is parsed by str2bool.
parser.add_argument('-f', '--freeze-bn', type=str2bool, nargs='?', const=True, default=False, help='Freeze BN while training')
parser.add_argument('-po', '--pretrained-outside', type=str2bool, nargs='?', const=True, default=False, help='Take the pretrained model from github')
def _resolve_data_dir(directory):
    """Return *directory* resolved against the folder containing this script.

    ``Path.joinpath`` discards the base when *directory* is absolute, so
    absolute paths pass through unchanged.
    """
    return Path(__file__).parent.joinpath(directory)


def _build_eval_transform(image_size, mean, std):
    """Return the deterministic transform applied to testing samples.

    Resizes, converts to a tensor and normalizes with *mean* / *std*.
    NOTE(review): ``transforms.Resize`` given a single int scales the shorter
    image edge to that size, so non-square test images stay non-square.
    """
    return transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])


def main(FLAGS):
    """Build the datasets, then train or evaluate the inpainting model.

    Args:
        FLAGS: namespace produced by the module-level ``parser``.
            ``FLAGS.mode == 'train'`` builds train and test sets and calls
            ``Inpainter.fit()``. Any other mode builds only the test set and
            calls ``Inpainter.test_model()``, always restoring the checkpoint.
    """
    print('Parameters\n')
    print(f'Image = {FLAGS.image_size}x{FLAGS.image_size}\n')
    # NOTE(review): these are the commonly used ImageNet channel statistics,
    # presumably because the network expects ImageNet-normalized input — confirm.
    normalization_mean = [0.485, 0.456, 0.406]
    normalization_std = [0.229, 0.224, 0.225]
    transform_test = _build_eval_transform(FLAGS.image_size, normalization_mean, normalization_std)

    if FLAGS.mode == 'train':
        # Augmentation: random flip, then a fixed 500x500 random crop before the
        # resize. torchvision's RandomCrop raises if an image is smaller than the
        # crop (no pad_if_needed here).
        transform_train = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop((500, 500)),
            transforms.Resize(FLAGS.image_size),
            transforms.ToTensor(),
            transforms.Normalize(normalization_mean, normalization_std),
        ])
        trainset = InpaintDataset(_resolve_data_dir(FLAGS.training_dir), transform_image=transform_train)
        testset = InpaintDataset(_resolve_data_dir(FLAGS.testing_dir), transform_image=transform_test)
        inpainter = Inpainter(FLAGS.mode, trainset, testset, checkpoint_dir=FLAGS.checkpoint_dir,
                              restore_parameters=FLAGS.restore_check, epochs=FLAGS.epochs,
                              lr=FLAGS.learning_rate, batch_size=FLAGS.batch_size,
                              initial_epoch=FLAGS.initial_epoch, writing_per_epoch=FLAGS.write_per_epoch,
                              freeze_bn=FLAGS.freeze_bn, config_path=FLAGS.config_file,
                              # argparse stores --pretrained-outside as `pretrained_outside`;
                              # the old `FLAGS.pretrained-outside` was a subtraction that
                              # raised AttributeError.
                              outside_pretrain=FLAGS.pretrained_outside)
        inpainter.fit()
    else:
        testset = InpaintDataset(_resolve_data_dir(FLAGS.testing_dir), transform_image=transform_test)
        # Evaluation needs trained weights, so restoring is forced regardless of
        # --restore-check; no training set is required.
        inpainter = Inpainter(FLAGS.mode, [], testset, checkpoint_dir=FLAGS.checkpoint_dir,
                              restore_parameters=True, batch_size=FLAGS.batch_size,
                              config_path=FLAGS.config_file)
        inpainter.test_model()
if __name__ == '__main__':
    # parse_known_args tolerates arguments the parser does not define (presumably
    # deliberate). Dropping them silently hides typos such as `--epoch 5`, so they
    # are reported on stderr without aborting the run.
    FLAGS, unparsed = parser.parse_known_args()
    if unparsed:
        import sys
        print(f"Warning: ignoring unrecognized arguments: {' '.join(unparsed)}", file=sys.stderr)
    main(FLAGS)