-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathpredict.py
75 lines (55 loc) · 2.42 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os, sys
import argparse
import numpy as np
import torch
import torchvision.transforms as t
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets.folder import default_loader
from tqdm import tqdm
from alexnet import KitModel as AlexNet
from vgg19 import KitModel as VGG19
class ImageListDataset(Dataset):
    """Dataset of images whose paths are listed one per line in a text file.

    Args:
        list_filename: path to a text file with one image path per line.
            Trailing whitespace is stripped from every line, and lines that
            end up empty are skipped.
        root: optional directory joined in front of every listed path.
        transform: optional callable applied to each loaded image.
    """

    def __init__(self, list_filename, root=None, transform=None):
        # Zero-argument super() binds to this instance. The previous
        # super(ImageListDataset) form built an unbound proxy, so the parent
        # Dataset.__init__ was never actually called.
        super().__init__()
        with open(list_filename, 'r') as list_file:
            stripped = (line.rstrip() for line in list_file)
            # Drop empty entries (e.g. blank lines at the end of the file);
            # they would otherwise become '' paths that fail to load.
            self.list = [path for path in stripped if path]
        self.root = root
        self.transform = transform

    def __getitem__(self, index):
        """Load image ``index`` (prefixed with ``root`` if set) and apply ``transform`` if set."""
        path = self.list[index]
        if self.root:
            path = os.path.join(self.root, path)
        x = default_loader(path)
        if self.transform:
            x = self.transform(x)
        return x

    def __len__(self):
        """Return the number of listed image paths."""
        return len(self.list)
def main(args):
transform = t.Compose([
t.Resize((224, 224)),
t.ToTensor(),
t.Lambda(lambda x: x[[2,1,0], ...] * 255), # RGB -> BGR and [0,1] -> [0,255]
t.Normalize(mean=[116.8007, 121.2751, 130.4602], std=[1,1,1]), # mean subtraction
])
data = ImageListDataset(args.image_list, root=args.root, transform=transform)
dataloader = DataLoader(data, batch_size=args.batch_size, num_workers=8, pin_memory=True)
model = AlexNet if 'hybrid' in args.model else VGG19
model = model('converted-models/{}.pth'.format(args.model)).to('cuda')
model.eval()
with torch.no_grad():
for x in tqdm(dataloader):
p = model(x.to('cuda')).cpu().numpy() # order is (NEG, NEU, POS)
np.savetxt(sys.stdout.buffer, p, delimiter=',')
if __name__ == '__main__':
    # Names of the pretrained checkpoints that may be selected with -m/--model.
    model_names = (
        'hybrid_finetuned_fc6+',
        'hybrid_finetuned_all',
        'vgg19_finetuned_fc6+',
        'vgg19_finetuned_all',
    )

    cli = argparse.ArgumentParser(description='Predict Visual Sentiment')
    cli.add_argument('image_list', type=str,
                     help='Image list (txt, one path per line)')
    cli.add_argument('-r', '--root', default=None,
                     help='Root path to prepend to image list')
    cli.add_argument('-m', '--model', type=str, choices=model_names,
                     default='vgg19_finetuned_all', help='Pretrained model')
    cli.add_argument('-b', '--batch-size', type=int, default=48,
                     help='Batch size')

    main(cli.parse_args())