forked from jclh/image-classifier-PyTorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
77 lines (61 loc) · 2.43 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# PROGRAMMER: JC Lopez
# DATE CREATED: 08/09/2018
# REVISED DATE: 05/21/2019
# PURPOSE: Predict flower name from image along with the inferred
# probability of that category.
#
# BASIC USAGE:
#   python predict.py <path to image> <checkpoint>
#          --top_k <number of most likely classes>
#          --category_names <mapping of categories to real names>
#          --gpu
#   Passing the literal 'random_test' as <path to image> picks a random
#   image from flowers/test/ instead of a user-supplied file.
# Example basic usage:
#   python predict.py in_image checkpoint --top_k 3 --gpu
#          --category_names cat_to_name.json
# Imports python modules
from time import time, sleep
from utility_fs_predict import *
from model_functions import *
from PIL import Image
def main():
    """Predict the most likely flower categories for a single image.

    Parses command-line arguments via ``get_input_args()``, loads a trained
    model from the given checkpoint, preprocesses the image, runs inference,
    maps the predicted class labels to flower names, and prints the results
    together with the total runtime.

    If the ``input`` argument is the literal string ``'random_test'``, a
    random image is picked from ``flowers/test/`` via ``random_test_img``.

    Returns:
        tuple: ``(classes, top_probs)`` exactly as returned by ``predict`` —
        the top-k class labels and their probabilities. (The concrete
        container types are determined by ``predict``, which is defined
        elsewhere.)
    """
    # Collect start time
    start_time = time()

    # Parse the command-line arguments and echo them back to the user
    in_args = get_input_args()
    print_input_args(in_args)

    # The special value 'random_test' selects a random image from the test set
    if in_args.input == 'random_test':
        image_path = random_test_img(test_dir='flowers/test/')
    else:
        image_path = in_args.input

    # Load model from checkpoint
    model = load_checkpoint(in_args.checkpoint)

    # Open the image inside a context manager so the underlying file handle
    # is released once preprocessing is done; Image.open() opens the file
    # lazily and otherwise keeps it open until the image is closed.
    with Image.open(image_path) as pil_image:
        # Process PIL image to normalized Numpy array
        np_image = process_image(pil_image)

    # Convert the array to a torch tensor shaped like the dataloader output
    # (per the original intent of image_to_tensor)
    img_tensor = image_to_tensor(np_image)

    # Run inference and keep the top_k most likely classes with their
    # probabilities; the gpu flag is passed straight through to predict()
    top_probs, classes = predict(model, img_tensor,
                                 in_args.top_k, in_args.gpu)

    # Dictionary of keys = class number (as in data folders)
    # and values = flower names (in words)
    class_name_dict = class_to_name(filename=in_args.category_names)
    flower_names = [class_name_dict[key] for key in classes]

    print('\n Filepath to image: ', image_path, '\n',
          '\n Classes: ', classes,
          '\n Flower names: ', flower_names,
          '\n Probabilities: ', top_probs)

    # Measure and report total program runtime (seconds, from time.time())
    end_time = time()
    tot_time = end_time - start_time
    print('\n** Total Elapsed Runtime:', tot_time, '\n')

    # Return the predicted class labels and their probabilities
    return classes, top_probs
# Script entry point: run the prediction pipeline only when this file is
# executed directly (e.g. `python predict.py ...`), not when it is imported.
if __name__ == "__main__":
    main()