-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
117 lines (94 loc) · 3.52 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import time
import json
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch import nn, optim
from torchvision import datasets, transforms, models
from collections import OrderedDict
from PIL import Image
from get_input_args import get_input_args, check_device
import gc
from train import load_data
def main():
    """Classify one image with a saved checkpoint and show its top-5 classes.

    Command-line options come from ``get_input_args()``:
      * ``data_dir`` is passed to ``load_data``; the resulting datasets supply
        ``class_to_idx`` to ``load_checkpoint``.
      * ``input`` is the checkpoint path.
      * ``gpu`` is passed to ``check_device`` to choose the inference device.
      * ``image_path`` is the image to classify.

    Category names are read from ``cat_to_name.json`` in the current
    working directory.
    """
    input_args = get_input_args()
    # Only the datasets are needed here; the three dataloaders are unused.
    _, _, _, image_datasets = load_data(input_args.data_dir)
    model = load_checkpoint(input_args.input, image_datasets)
    device = check_device(input_args.gpu)
    # Explicit encoding so the label file decodes the same on every platform.
    with open('cat_to_name.json', 'r', encoding='utf-8') as f:
        cat_to_name = json.load(f)
    img = process_image(input_args.image_path)
    imshow(img)
    probs, classes = predict(img, model, device, cat_to_name, topk=5)
    print(probs)
    print(classes)
    show_top_5(input_args.image_path, model, device, cat_to_name)
def load_checkpoint(filepath, image_datasets):
    """Rebuild the trained DenseNet-161 classifier from a checkpoint file.

    Args:
        filepath: Path to the checkpoint. It must contain the keys
            ``'model.classifier'`` (the classifier module itself) and
            ``'state_dict'`` (weights for the whole network).
        image_datasets: Datasets returned by ``train.load_data``. The
            ``class_to_idx`` mapping is taken from the first entry
            (presumably the training split -- confirm against ``load_data``).

    Returns:
        The model with the checkpoint weights loaded (tensors mapped to the
        CPU) and ``class_to_idx`` attached.
    """
    import inspect

    # The checkpoint pickles a whole nn.Module under 'model.classifier'.
    # torch.load refuses to unpickle that when weights_only=True, which is
    # the default in recent PyTorch releases, so opt out explicitly wherever
    # the argument exists. Older versions lack the argument; skip it there.
    # SECURITY: full unpickling can execute arbitrary code -- only load
    # checkpoints from a trusted source.
    load_kwargs = {'map_location': 'cpu'}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    checkpoint = torch.load(filepath, **load_kwargs)

    # Build the architecture without downloading pretrained weights.
    # load_state_dict below is strict, so every parameter and buffer is
    # replaced by the checkpoint anyway.
    model = models.densenet161()
    # Swap in the trained classifier before loading, so the state_dict keys
    # and shapes for the classifier match.
    model.classifier = checkpoint['model.classifier']
    model.load_state_dict(checkpoint['state_dict'])
    model.class_to_idx = image_datasets[0].class_to_idx
    return model
def process_image(image):
    """Load an image and turn it into a normalized model-input tensor.

    Processing steps:
      1. Convert the image to RGB.
      2. Resize the shorter side to 256 pixels, keeping the aspect ratio.
      3. Cut a 224x224 patch from the centre.
      4. Normalize with the standard ImageNet channel mean/std.

    Args:
        image: A path or file object accepted by ``PIL.Image.open``.

    Returns:
        A float32 tensor of shape (3, 224, 224).
    """
    transform_image = transforms.Compose([
        # Bound the *shorter* side, not the longer one. The 224 crop then
        # always lies inside the image, whatever its aspect ratio or size,
        # and never gets padded with black.
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225)),
    ])
    # The context manager closes the underlying file handle.
    # convert('RGB') drops an alpha channel or expands greyscale, so
    # Normalize always sees exactly three channels.
    with Image.open(image) as pil_image:
        return transform_image(pil_image.convert('RGB')).float()
def imshow(image, ax=None, title=None):
    """Display a normalized, channel-first image tensor with matplotlib.

    Args:
        image: Tensor laid out as (C, H, W) and normalized with the same
            mean/std that ``process_image`` applies.
        ax: Axes to draw on. A new figure and axes are created when omitted.
        title: Optional axes title.

    Returns:
        The axes the image was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()
    # PyTorch puts the colour channel first; matplotlib expects it last.
    # detach()/cpu() make this safe for GPU tensors and grad-tracking tensors.
    image = image.detach().cpu().numpy().transpose((1, 2, 0))
    # Undo the normalization applied in process_image.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = std * image + mean
    # Clip to [0, 1]; values outside that range render as noise.
    image = np.clip(image, 0, 1)
    ax.imshow(image)
    # Honour the title argument (show_top_5, for example, passes one).
    if title is not None:
        ax.set_title(title)
    return ax
def predict(img, model, device, cat_to_name, topk=5):
    """Return the ``topk`` most likely category names for one image.

    Args:
        img: Preprocessed image tensor without a batch dimension, as
            returned by ``process_image``.
        model: Classifier with a ``class_to_idx`` attribute. As a side
            effect, it is moved to ``device`` and switched to eval mode.
        device: Device to run inference on.
        cat_to_name: Mapping from the keys of ``model.class_to_idx`` to
            human-readable names.
        topk: Number of predictions to return; clamped to the class count.

    Returns:
        ``(probs, names)``: a 1-D CPU tensor of probabilities in descending
        order, and the matching list of category names.
    """
    model.to(device)
    model.eval()
    with torch.no_grad():
        # Add the batch dimension the network expects.
        batch = img.float().unsqueeze(0).to(device)
        # Call the module rather than .forward() so registered hooks run.
        output = model(batch)
        ps = F.softmax(output, dim=1)
        # Requesting more classes than exist would make topk() raise.
        k = min(topk, ps.shape[1])
        top_ps, top_classes = ps.topk(k, dim=1)
    idx_to_class = {idx: cat_to_name[cls]
                    for cls, idx in model.class_to_idx.items()}
    names = [idx_to_class[i] for i in top_classes[0].tolist()]
    # Return probabilities on the CPU so callers can print or plot them directly.
    return top_ps[0].cpu(), names
def show_top_5(path, model, device, cat_to_name):
    """Plot an image above a bar chart of the model's top-5 predictions.

    The true label is the name of the directory that contains the image.
    This assumes images live in a folder named after their category key,
    which is then looked up in ``cat_to_name``; confirm against the dataset
    layout. The label is printed and used as the image title.

    Args:
        path: Path to the image file.
        model: Classifier accepted by ``predict``.
        device: Device to run inference on.
        cat_to_name: Mapping from category keys to human-readable names.
    """
    from pathlib import Path

    # Use the parent directory's name instead of a fixed '/'-split index, so
    # absolute paths, shorter relative paths and Windows paths also work.
    class_dir = Path(path).parent.name
    # Images outside the dataset tree have no known label: show the raw name.
    name = cat_to_name.get(class_dir, class_dir)
    print(name)

    # One figure: image on top, prediction bars underneath.
    fig, (ax_img, ax_bar) = plt.subplots(2, 1, figsize=(3, 6))
    image = process_image(path)
    imshow(image, ax_img, name)
    ax_img.set_title(name)

    score, flowers_list = predict(image, model, device, cat_to_name)
    y_pos = np.arange(len(flowers_list))
    # .cpu() so plotting works even if inference ran on a GPU.
    ax_bar.barh(y_pos, score.cpu().numpy(), height=0.3, linewidth=2.0,
                align='center')
    ax_bar.set_yticks(y_pos)
    ax_bar.set_yticklabels(flowers_list)
    # barh draws index 0 at the bottom; flip so the most likely class is on top.
    ax_bar.invert_yaxis()
    fig.tight_layout()
    plt.show()
if __name__ == "__main__":
    # Run the command-line prediction only when executed as a script,
    # not when this module is imported.
    main()