-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmnist.py
26 lines (23 loc) · 798 Bytes
/
mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import os
import gzip
import pickle
def _download(url, path):
    """Download *url* to *path* without leaving a truncated file behind.

    The payload is first written to ``path + ".part"`` and only renamed to
    *path* once the transfer completes.  Otherwise an interrupted download
    would leave a partial file at *path* that later calls would mistake for
    a complete dataset (they only check ``os.path.exists``).
    """
    # Pick the right urlretrieve by import rather than by making a probe
    # request: the old approach fetched an unrelated URL just to detect the API.
    try:
        from urllib.request import urlretrieve  # Python 3
    except ImportError:
        from urllib import urlretrieve  # Python 2

    tmp_path = path + ".part"
    try:
        urlretrieve(url, tmp_path)
        os.rename(tmp_path, path)
    except BaseException:
        # Clean up on any failure (including Ctrl-C) and propagate the error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def get_data():
    """Load the pickled MNIST splits, downloading the file if it is missing.

    The file location is read from the ``MNIST_PKL_GZ`` environment variable.
    If no file exists there, it is downloaded from the URL below first.

    Returns:
        A dict mapping ``"train"``, ``"valid"`` and ``"test"`` to dicts with
        ``"features"`` (cast to ``float32``) and ``"targets"`` (cast to
        ``int32``).  The pickle is assumed to hold a sequence of
        ``(features, targets)`` array pairs in that split order -- TODO
        confirm array shapes against the dataset.

    Raises:
        KeyError: if ``MNIST_PKL_GZ`` is not set.
    """
    path = os.environ["MNIST_PKL_GZ"]
    if not os.path.exists(path):
        url = 'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz'
        _download(url, path)

    # NOTE(security): pickle.load can execute arbitrary code, and this file may
    # have been fetched over plain HTTP.  Only load files from a trusted source.
    with gzip.open(path, 'rb') as f:
        try:
            # Python 3: latin1 decodes 8-bit strings pickled by Python 2
            # (needed for NumPy array payloads in Python 2 pickles).
            split = pickle.load(f, encoding="latin1")
        except TypeError:
            # Python 2: pickle.load does not accept an `encoding` argument.
            split = pickle.load(f)

    which_sets = ("train", "valid", "test")
    return {
        which_set: dict(features=x.astype("float32"),
                        targets=y.astype("int32"))
        for which_set, (x, y) in zip(which_sets, split)
    }