-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_mnist.py
67 lines (56 loc) · 2.02 KB
/
run_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import mnist
import numpy as np
import pickle
import cnn
# Load the MNIST training images and their labels via the `mnist` package.
training_images = mnist.train_images()
training_labels = mnist.train_labels()

## uncomment below to train mnist images as RGB data
# NOTE(review): if this is enabled, the test images used in the evaluation
# step must be converted the same way; otherwise the model is evaluated on
# grayscale inputs whose shape differs from what it was trained on.
# import cv2
# training_images = np.array(
#     [cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) for image in training_images]
# )

# The ten digit classes (0-9). The network predicts an index into this list,
# which is mapped back to the class label.
classes = list(range(10))
# Initialize the network: either unpickle a previously saved model or build a
# fresh Conv -> MaxPool -> SoftMax network.
answer = input("Would you like to load a model? (enter 'y' to load): ")
should_load = answer == 'y'
if should_load:
    filename = input("Enter a filename (without the extension): ")
    # The context manager closes the file handle as soon as the model has
    # been read, instead of leaking it for the rest of the run.
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only load model files you created yourself or otherwise trust.
    with open(f'{filename}.pickle', 'rb') as pickle_in:
        net = pickle.load(pickle_in)
else:
    layers = [
        cnn.layers.Conv(num_kernels=16),
        cnn.layers.MaxPool(),
        # Output size follows the class list, so the two cannot drift apart.
        cnn.layers.SoftMax(num_classes=len(classes)),
    ]
    net = cnn.CNN(layers)
# Train: optionally fit the network on the full training set with fixed
# hyperparameters (5 epochs, learning rate 0.005).
answer = input("Would you like to train? (enter 'y' to train): ")
should_train = (answer == 'y')
if should_train:
    net.train(
        training_images,
        training_labels,
        classes,
        num_epochs=5,
        rate=0.005,
    )
# Predict: optionally evaluate the model on a subset of the MNIST test split
# and report the percentage of images it classifies correctly.
answer = input("Would you like to test the model? (enter 'y' to test): ")
should_test = answer == 'y'
if should_test:
    print('\n\n>>> Testing model...\n')
    # Only the first 1000 test images are used, presumably to keep
    # evaluation fast.
    test_images = mnist.test_images()[:1000]
    test_labels = mnist.test_labels()[:1000]
    # net.predict returns an index into `classes`; count the matches with
    # sum() rather than a hand-rolled accumulator.
    num_correct = sum(
        1
        for image, label in zip(test_images, test_labels)
        if classes[net.predict(image)] == label
    )
    num_tests = len(test_images)
    if num_tests:
        percent_accurate = round(((num_correct / num_tests) * 100), 3)
        print(f'Prediction accuracy ({num_tests} attempts): {percent_accurate}%\n')
    else:
        # Guard against ZeroDivisionError when no test images are available.
        print('No test images available; skipping accuracy report.\n')
# Save: optionally pickle the current network to "<filename>.pickle" in the
# working directory.
answer = input("Would you like to save the model? (enter 'y' to save): ")
should_save = (answer == 'y')
if should_save:
    filename = input("Enter a filename (without the extension): ")
    save_path = f'{filename}.pickle'
    with open(save_path, 'wb') as model_file:
        pickle.dump(net, model_file)