-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
107 lines (89 loc) · 3.38 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import copy
import ntpath
import os
import sys
import warnings
import numpy as np
import torch
import tqdm
from data import create_dataloader
from metric import create_metric_models
from metric import get_cityscapes_mIoU
from metric import get_fid
from models import create_model
from options.test_options import TestOptions
from utils import html, util
def save_images(webpage, visuals, image_path, opt):
    """Save one sample's visuals as PNGs and register them on an HTML page.

    Each entry of ``visuals`` is converted to a numpy image (the ``'labels'``
    entry via ``util.tensor2label``, every other entry via
    ``util.tensor2im``), written to ``<image_dir>/<label>/<name>.png`` and
    added to ``webpage`` under a header named after the first input path.

    Args:
        webpage: HTML page object providing ``get_image_dir``,
            ``add_header`` and ``add_images``.
        visuals: mapping from visual label to image tensor. It is not
            modified; converted images are collected in a new dict.
        image_path: sequence of input file paths; only ``image_path[0]`` is
            used to name the saved files.
        opt: options namespace; reads ``batch_size``, ``input_nc`` and
            ``display_winsize``.
    """

    def convert_visuals_to_numpy(visuals):
        # Tiling depends only on the batch size, so decide it once.
        tile = opt.batch_size > 8
        converted = {}
        for key, t in visuals.items():
            if key == 'labels':
                # NOTE(review): the '+ 2' class-count offset is kept as-is;
                # its meaning is not visible from this file.
                converted[key] = util.tensor2label(t, opt.input_nc + 2, tile=tile)
            else:
                converted[key] = util.tensor2im(t, tile=tile)
        return converted

    visuals = convert_visuals_to_numpy(visuals)
    image_dir = webpage.get_image_dir()
    # ntpath.basename splits on both '/' and '\\', so this works for paths
    # produced on either POSIX or Windows.
    name = os.path.splitext(ntpath.basename(image_path[0]))[0]
    webpage.add_header(name)

    ims, txts, links = [], [], []
    for label, image_numpy in visuals.items():
        # Path relative to the page's image dir, reused as image src and link.
        image_name = os.path.join(label, '%s.png' % (name))
        save_path = os.path.join(image_dir, image_name)
        util.save_image(image_numpy, save_path, create_dir=True)
        ims.append(image_name)
        txts.append(label)
        links.append(image_name)
    webpage.add_images(ims, txts, links, width=opt.display_winsize)
def check(opt):
    """Validate that ``opt`` is configured for deterministic one-image testing.

    Args:
        opt: options namespace; reads ``serial_batches``, ``no_flip``,
            ``load_size``, ``crop_size``, ``preprocess``, ``batch_size``,
            ``no_fid``, ``phase`` and (only when FID is enabled)
            ``real_stat_path``.

    Raises:
        AssertionError: if any requirement is violated. The errors are raised
            explicitly rather than with ``assert`` so the checks still run
            under ``python -O``.

    Warns:
        UserWarning: when ``opt.phase == 'train'``.
    """
    # Checks run in the original order so the first violation is reported.
    if not opt.serial_batches:
        raise AssertionError('opt.serial_batches must be set')
    if not opt.no_flip:
        raise AssertionError('opt.no_flip must be set')
    if opt.load_size != opt.crop_size:
        raise AssertionError(
            'opt.load_size (%r) must equal opt.crop_size (%r)'
            % (opt.load_size, opt.crop_size))
    if opt.preprocess != 'resize_and_crop':
        raise AssertionError(
            "opt.preprocess must be 'resize_and_crop', got %r" % (opt.preprocess,))
    if opt.batch_size != 1:
        raise AssertionError('opt.batch_size must be 1, got %r' % (opt.batch_size,))
    if not opt.no_fid and opt.real_stat_path is None:
        raise AssertionError('opt.real_stat_path is required when FID is enabled')
    if opt.phase == 'train':
        warnings.warn('You are using training set for inference.', stacklevel=2)
if __name__ == '__main__':
    # Script entry point: run the model over the dataloader, save the first
    # ``opt.num_test`` batches to an HTML page, then report FID and/or mIoU
    # depending on which metric models create_metric_models returns.
    opt = TestOptions().parse()
    print(' '.join(sys.argv))
    dataloader = create_dataloader(opt)
    model = create_model(opt)
    model.setup(opt)
    web_dir = opt.results_dir  # define the website directory
    webpage = html.HTML(web_dir, 'restore_G_path: %s' % (opt.restore_G_path))

    # fakes: one CPU tensor of generated images per batch.
    # names: one file stem per input path, used by the mIoU evaluation.
    fakes, names = [], []
    for i, data in enumerate(tqdm.tqdm(dataloader)):
        model.set_input(data)  # unpack data from data loader
        if i == 0 and opt.need_profile:
            model.profile()
        model.test()  # run inference
        visuals = model.get_current_visuals()  # get image results
        fakes.append(visuals['fake_B'].cpu())
        image_paths = model.get_image_paths()
        names.extend(os.path.splitext(ntpath.basename(path))[0]
                     for path in image_paths)
        if i < opt.num_test:
            save_images(webpage, visuals, image_paths, opt)
    webpage.save()  # save the HTML

    # Remember the device, then drop the model and release cached CUDA
    # memory before the metric models are created.
    device = copy.deepcopy(model.device)
    del model
    torch.cuda.empty_cache()

    inception_model, drn_model = create_metric_models(opt, device)
    if inception_model is not None:
        npz = np.load(opt.real_stat_path)
        try:
            fid = get_fid(fakes, inception_model, npz, device, opt.batch_size)
        finally:
            # For .npz archives np.load returns an NpzFile that keeps the file
            # open; close it deterministically instead of leaving it to GC.
            if hasattr(npz, 'close'):
                npz.close()
        print('fid score: %.2f' % fid, flush=True)

    if drn_model is not None:
        mIoU = get_cityscapes_mIoU(fakes, names, drn_model, device,
                                   data_dir=opt.cityscapes_path,
                                   batch_size=opt.batch_size,
                                   num_workers=opt.num_threads)
        print('mIoU: %.2f' % mIoU)