-
Notifications
You must be signed in to change notification settings - Fork 55
/
Copy patheval.py
69 lines (51 loc) · 2.1 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
from torch.autograd import Variable
import torch.nn as nn
import skimage
import skimage.io as io
from torchvision import transforms
import numpy as np
import scipy.io as scio
from modelNetM import EncoderNet, DecoderNet, ClassNet, EPELoss
# Preprocessing applied to every input image: ToTensor followed by a
# per-channel normalization (x - 0.5) / 0.5 over three channels.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Build the three networks. The list arguments are forwarded to modelNetM
# unchanged; their meaning is defined there.
model_en = EncoderNet([1, 1, 1, 1, 2])
model_de = DecoderNet([1, 1, 1, 1, 2])
model_class = ClassNet()

# Wrap the networks for multi-GPU execution when several devices are visible.
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model_en = nn.DataParallel(model_en)
    model_de = nn.DataParallel(model_de)
    model_class = nn.DataParallel(model_class)

# Move the networks to the GPU when one is available; otherwise stay on CPU.
if torch.cuda.is_available():
    model_en = model_en.cuda()
    model_de = model_de.cuda()
    model_class = model_class.cuda()

# Deserialize the checkpoints onto the CPU so that files saved from CUDA
# tensors can still be read on a CPU-only machine; load_state_dict() then
# copies each tensor onto whatever device the model parameters live on.
# NOTE(review): nn.DataParallel prefixes parameter names with "module.", so
# the checkpoints must have been saved from a model wrapped the same way as
# the one built above -- confirm against the training script.
model_en.load_state_dict(torch.load('model_en.pkl', map_location='cpu'))
model_de.load_state_dict(torch.load('model_de.pkl', map_location='cpu'))
model_class.load_state_dict(torch.load('model_class.pkl', map_location='cpu'))

# Put every network in evaluation mode before running inference.
model_en.eval()
model_de.eval()
model_class.eval()
# Evaluation: for every distortion type and every test index, run the
# encoder once, feed its output to both the decoder (flow prediction) and the
# classifier (distortion type), count correct classifications, and save the
# two channels of the predicted flow to a .mat file as 'u' and 'v'.
testImgPath = '/home/xliea/Dataset256/Dataset256/test/distorted'
saveFlowPath = '/home/xliea/test/flow_256/flow_cla'

# Device availability cannot change while the script runs, so query it once
# instead of once per image.
use_GPU = torch.cuda.is_available()

correct = 0
total = 0

# Inference only: disabling autograd avoids building a graph for each of the
# forward passes and lets the outputs be converted to NumPy directly.
with torch.no_grad():
    # The position of each name in this list is the class label that the
    # classifier's prediction is compared against.
    for index, types in enumerate(['barrel', 'pincushion', 'rotation', 'shear', 'projective', 'wave']):
        for k in range(50000, 55000):
            # File names are "<type>_<zero-padded 6-digit index>.jpg".
            imgPath = '%s/%s_%s.jpg' % (testImgPath, types, str(k).zfill(6))
            disimgs = transform(io.imread(imgPath))
            if use_GPU:
                disimgs = disimgs.cuda()

            # Add the batch dimension; view() also rejects any input that
            # does not hold exactly 3 x 256 x 256 values.
            disimgs = disimgs.view(1, 3, 256, 256)

            middle = model_en(disimgs)
            flow_output = model_de(middle)
            clas = model_class(middle)

            # Predicted class = index of the largest score along dim 1.
            _, predicted = torch.max(clas, 1)
            if predicted.item() == index:
                correct += 1
            total += 1

            # Transfer the flow to host memory once and split its first two
            # channels into the saved 'u' and 'v' arrays.
            flow = flow_output.cpu().numpy()[0]
            u = flow[0]
            v = flow[1]

            saveMatPath = '%s/%s_%s.mat' % (saveFlowPath, types, str(k).zfill(6))
            scio.savemat(saveMatPath, {'u': u, 'v': v})

# Report the classification result; the counter was previously accumulated
# but never shown.
print('Classification accuracy: %d/%d (%.2f%%)' % (correct, total, 100.0 * correct / total))