-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathRead_MNIST.py
93 lines (72 loc) · 2.67 KB
/
Read_MNIST.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
import struct
from array import array
class MNIST(object):
    """Loader for the MNIST handwritten-digit dataset stored as IDX files.

    The four uncompressed files are expected to sit in one directory under
    their standard names.  Images are returned as a list of flat per-image
    pixel lists (``rows * cols`` ints each, in file order) and labels as an
    ``array('B')`` of unsigned bytes.
    """

    # IDX header magic numbers checked by the readers below: 2049 (0x801)
    # for the label files, 2051 (0x803) for the image files.
    _LABEL_MAGIC = 2049
    _IMAGE_MAGIC = 2051

    def __init__(self, path='.'):
        """Remember the dataset directory; nothing is read until a load_* call.

        :param path: directory containing the four MNIST IDX files.
        """
        self.path = path
        self.test_img_fname = 't10k-images-idx3-ubyte'
        self.test_lbl_fname = 't10k-labels-idx1-ubyte'
        self.train_img_fname = 'train-images-idx3-ubyte'
        self.train_lbl_fname = 'train-labels-idx1-ubyte'
        # Filled in by load_testing() / load_training().
        self.test_images = []
        self.test_labels = []
        self.train_images = []
        self.train_labels = []

    def load_testing(self):
        """Load the test split (t10k-* files), cache it on self and return it.

        :returns: ``(images, labels)`` as produced by :meth:`load`.
        """
        ims, labels = self.load(os.path.join(self.path, self.test_img_fname),
                                os.path.join(self.path, self.test_lbl_fname))
        self.test_images = ims
        self.test_labels = labels
        return ims, labels

    def load_training(self):
        """Load the training split (train-* files), cache it on self and return it.

        :returns: ``(images, labels)`` as produced by :meth:`load`.
        """
        ims, labels = self.load(os.path.join(self.path, self.train_img_fname),
                                os.path.join(self.path, self.train_lbl_fname))
        self.train_images = ims
        self.train_labels = labels
        return ims, labels

    @classmethod
    def load(cls, path_img, path_lbl):
        """Parse an IDX image file together with its matching IDX label file.

        :param path_img: path to an ``idx3-ubyte`` image file.
        :param path_lbl: path to an ``idx1-ubyte`` label file.
        :returns: ``(images, labels)``; ``images`` is a list of flat pixel
            lists (``rows * cols`` ints each), ``labels`` an ``array('B')``.
        :raises ValueError: on a wrong magic number, a truncated payload, or
            when the two files disagree on the number of samples.
        """
        labels = cls._read_labels(path_lbl)
        images = cls._read_images(path_img)
        # Pairing images with labels is meaningless if the counts differ.
        if len(images) != len(labels):
            raise ValueError('Sample count mismatch: %d images but %d labels'
                             % (len(images), len(labels)))
        return images, labels

    @classmethod
    def _read_labels(cls, path_lbl):
        """Return the labels of an idx1-ubyte file as an ``array('B')``."""
        with open(path_lbl, 'rb') as fh:
            magic, size = struct.unpack('>II', fh.read(8))
            if magic != cls._LABEL_MAGIC:
                raise ValueError('Magic number mismatch, expected %d, got %d'
                                 % (cls._LABEL_MAGIC, magic))
            # Read exactly the advertised count: trailing bytes are ignored
            # and a short file is detected below.
            labels = array('B', fh.read(size))
        if len(labels) != size:
            raise ValueError('Truncated label file: expected %d labels, got %d'
                             % (size, len(labels)))
        return labels

    @classmethod
    def _read_images(cls, path_img):
        """Return the images of an idx3-ubyte file as flat pixel lists."""
        with open(path_img, 'rb') as fh:
            magic, size, rows, cols = struct.unpack('>IIII', fh.read(16))
            if magic != cls._IMAGE_MAGIC:
                raise ValueError('Magic number mismatch, expected %d, got %d'
                                 % (cls._IMAGE_MAGIC, magic))
            n_pixels = rows * cols
            expected = size * n_pixels
            image_data = array('B', fh.read(expected))
        if len(image_data) != expected:
            raise ValueError('Truncated image file: expected %d bytes, got %d'
                             % (expected, len(image_data)))
        # One slice per image; tolist() yields plain Python ints, giving the
        # same list-of-lists shape as before without pre-allocating zeros.
        return [image_data[i * n_pixels:(i + 1) * n_pixels].tolist()
                for i in range(size)]

    def test(self):
        """Smoke-test: load both splits, check their sizes, render one digit.

        The checks use ``assert`` and are therefore skipped under ``python -O``.

        :returns: ``True`` when every check passes.
        """
        test_img, test_label = self.load_testing()
        train_img, train_label = self.load_training()
        assert len(test_img) == len(test_label)
        assert len(test_img) == 10000
        assert len(train_img) == len(train_label)
        assert len(train_img) == 60000
        print('Showing num:%d' % train_label[0])
        print(self.display(train_img[0]))
        # A bare ``print`` is a no-op expression in Python 3; call it so the
        # intended blank line is actually emitted.
        print()
        return True

    @classmethod
    def display(cls, img, width=28):
        """Render a flat image as ASCII art: '1' for pixels > 200, else '0'.

        A newline is emitted before every row of ``width`` pixels, so a
        non-empty result starts with a newline.
        """
        return ''.join(('\n' if i % width == 0 else '')
                       + ('1' if pixel > 200 else '0')
                       for i, pixel in enumerate(img))
if __name__ == "__main__":
    # Script entry point: run the smoke test against MNIST files in the
    # current working directory and report success.
    print('Testing')
    loader = MNIST('.')
    passed = loader.test()
    if passed:
        print('Passed')