-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathtest_pix2pix.py
61 lines (54 loc) · 2.3 KB
/
test_pix2pix.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
import argparse
import pandas as pd
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms as T
from pathgan.data import MPRDataset
from pathgan.models import Generator
def _parse_args(argv=None):
    """Parse command-line options for the Pix2Pix test script.

    Args:
        argv: Optional list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``checkpoint_path``, ``batch_size``,
        ``save_dir``, ``device`` and ``data_dir``.

    Exits via ``parser.error`` when ``--checkpoint_path`` is missing (the
    generator cannot be evaluated without weights) or ``--batch_size`` < 1.
    """
    parser = argparse.ArgumentParser(prog='top', description='Testing Pix2Pix GAN (our GAN)')
    parser.add_argument('--checkpoint_path', default=None, help='Path to the Generator state_dict to evaluate (required)')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size used during inference (default: 1)')
    parser.add_argument('--save_dir', default='results/pix2pix', help='Save directory (default: "results/pix2pix")')
    parser.add_argument('--device', type=str, default='cuda:0', help='Device (default: "cuda:0")')
    parser.add_argument('--data_dir', default='dataset', help='Dataset root containing test.csv, maps/ and tasks/ (default: "dataset")')
    args = parser.parse_args(argv)
    # torch.load(None) would otherwise fail with an opaque error much later.
    if args.checkpoint_path is None:
        parser.error('--checkpoint_path is required for testing')
    if args.batch_size < 1:
        parser.error('--batch_size must be >= 1')
    return args


def _to_mask(pred):
    """Binarise one generator output into an array PIL can save as an image.

    Args:
        pred: CPU tensor laid out as (channels, height, width), the layout
            implied by ``permute(1, 2, 0)``.

    Returns:
        uint8 numpy array with values in {0, 255}: shape (H, W) for a
        single-channel output, otherwise (H, W, C).
    """
    arr = pred.permute(1, 2, 0).numpy()
    # Threshold at 0: presumably the generator output lives in [-1, 1] to match
    # the Normalize(0.5, 0.5) input range -- TODO confirm in pathgan.models.
    mask = (arr > 0).astype(np.uint8) * 255
    # PIL.Image.fromarray cannot handle an (H, W, 1) uint8 array; drop the axis.
    if mask.shape[-1] == 1:
        mask = mask[..., 0]
    return mask


def main(argv=None):
    """Run the trained Generator over the test split and save binary ROI masks.

    Each prediction is written to ``<save_dir>/roi_<k>.png``, where ``k`` is the
    sample's position in the dataset (presumably its row in
    ``<data_dir>/test.csv``, depending on how MPRDataset indexes -- verify).

    Args:
        argv: Optional argument list forwarded to the CLI parser.
    """
    args = _parse_args(argv)
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize(
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5),
        ),
    ])
    df = pd.read_csv(os.path.join(args.data_dir, 'test.csv'))
    dataset = MPRDataset(
        map_dir=os.path.join(args.data_dir, 'maps'),
        point_dir=os.path.join(args.data_dir, 'tasks'),
        roi_dir=os.path.join(args.data_dir, 'tasks'),
        csv_file=df,
        transform=transform,
    )
    # shuffle=False keeps the output file index aligned with the dataset order.
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    generator = Generator()
    print('=========== Loading weights for Generator ===========')
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    generator.load_state_dict(torch.load(args.checkpoint_path, map_location="cpu"))
    generator = generator.to(device).eval()

    print('============== Testing Started ==============')
    os.makedirs(args.save_dir, exist_ok=True)
    sample_idx = 0
    with torch.no_grad():
        for maps, points, _rois in tqdm(dataloader):
            preds = generator(maps.to(device), points.to(device)).cpu()
            # Save every prediction in the batch, not just the first one.
            for pred in preds:
                roi_img = Image.fromarray(_to_mask(pred))
                roi_img.save(os.path.join(args.save_dir, f"roi_{sample_idx}.png"))
                sample_idx += 1
    print('============== Testing Finished! ==============')


if __name__ == '__main__':
    main()