-
Notifications
You must be signed in to change notification settings - Fork 1
/
02b_tf_mnist.py
104 lines (82 loc) · 2.33 KB
/
02b_tf_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import numpy as np
import tensorflow as tf
# Graph input: a batch of 28x28 images; the batch dimension is left dynamic.
img_placeholder = tf.placeholder(tf.float32, shape=(None, 28, 28), name="img")

# Centre and scale the raw values: (x - 128) / 128.
# NOTE(review): presumably inputs are 8-bit pixel intensities in [0, 255],
# which this maps to roughly [-1, 1) -- confirm against the data source.
centered_image = img_placeholder - 128
rescaled_image = centered_image / 128

# Collapse each 28x28 image into a single 784-element row so it can be fed
# to a matrix multiplication; -1 keeps the batch dimension dynamic.
flattened_img = tf.reshape(rescaled_image, (-1, 28 * 28))
# Hidden layer: 784 inputs -> 128 sigmoid units.
# Weights start as zero-mean Gaussian noise with variance 1 / fan_in
# (stddev = sqrt(1 / 784)); biases start at zero.
initial_weights1 = tf.random_normal(
    shape=(28 * 28, 128),
    stddev=np.sqrt(1 / (28 * 28)),
)
weight_matrix1 = tf.Variable(initial_weights1, name='weight_matrix1')

initial_bias1 = tf.zeros(shape=(128,))
bias_vector1 = tf.Variable(initial_bias1, name='bias_vector1')

# Pre-activation (z1) and activation (h1) of the hidden layer.
z1 = tf.matmul(flattened_img, weight_matrix1) + bias_vector1
h1 = tf.nn.sigmoid(z1)
# Output layer: 128 hidden activations -> 10 class scores.
# Weights start as zero-mean Gaussian noise with variance 1 / fan_in
# (stddev = sqrt(1 / 128)); biases start at zero.
initial_weights2 = tf.random_normal(
    shape=(128, 10),
    stddev=np.sqrt(1 / 128),
)
weight_matrix2 = tf.Variable(initial_weights2, name='weight_matrix2')

initial_bias2 = tf.zeros(shape=(10,))
bias_vector2 = tf.Variable(initial_bias2, name='bias_vector2')

# z2 holds the unnormalised class scores (logits); h2 is their softmax,
# i.e. one probability distribution over the 10 classes per example.
z2 = tf.matmul(h1, weight_matrix2) + bias_vector2
h2 = tf.nn.softmax(z2)
# Ground-truth labels: one integer class id (0-9) per example in the batch.
correct_y = tf.placeholder(tf.int32, shape=(None,))

# Mean cross-entropy over the batch, computed directly from the logits (z2).
# The fused op evaluates log-softmax in a numerically stable way, so no
# epsilon has to be added inside the log (which would bias the loss and
# clamp its gradient when a probability underflows), and the integer labels
# are consumed as-is without building a one-hot mask.
per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=correct_y,
    logits=z2,
)
error = tf.reduce_mean(per_example_loss)

# Predicted class = index of the largest softmax probability; int32 so it
# can be compared element-wise with correct_y.
predictions = tf.argmax(h2, axis=1, output_type=tf.int32)

# Fraction of examples in the batch whose prediction matches the label.
is_correct = tf.equal(correct_y, predictions)
accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))
# Plain gradient descent with a fixed step size of 0.01; running train_step
# applies one update that lowers `error` on whatever batch is fed in.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1e-2)
train_step = optimizer.minimize(error)
# Training configuration.
BATCH_SIZE = 128
NUM_EPOCHS = 100
LOG_EVERY_N_BATCHES = 100

from keras.datasets import mnist

# Load the data before opening a session so a failed load does not leave a
# session behind.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# The context manager closes the session when training finishes or raises,
# instead of leaving it open for the life of the process.
with tf.Session() as session:
    session.run(tf.global_variables_initializer())

    for epoch_idx in range(1, NUM_EPOCHS + 1):
        # Walk the training set in consecutive BATCH_SIZE slices; the final
        # slice may be shorter when the set size is not a multiple of it.
        batch_starts = range(0, x_train.shape[0], BATCH_SIZE)
        for batch_idx, batch_start in enumerate(batch_starts):
            batch_end = batch_start + BATCH_SIZE
            x_batch = x_train[batch_start:batch_end]
            y_batch = y_train[batch_start:batch_end]

            # One optimisation step; also fetch the loss and accuracy on
            # this batch for progress reporting.
            _, e, a = session.run(
                [train_step, error, accuracy],
                feed_dict={
                    img_placeholder: x_batch,
                    correct_y: y_batch,
                },
            )

            if batch_idx % LOG_EVERY_N_BATCHES == 0:
                print(f'epoch: {epoch_idx:04d} | batch {batch_idx:04d} | err: {e:0.1f} | acc: {a:0.2f}')

        # End of epoch: evaluate on the test split. train_step is not
        # fetched here, so this does not update the weights.
        e, a = session.run(
            [error, accuracy],
            feed_dict={
                img_placeholder: x_test,
                correct_y: y_test,
            },
        )
        print(f'>> epoch: {epoch_idx:04d} | err: {e:0.1f} | acc: {a:0.2f}')