Skip to content

Commit

Permalink
Merge pull request #212 from astonzhang/nin
Browse files Browse the repository at this point in the history
Add transform=None in utils.DataLoader
  • Loading branch information
astonzhang authored Feb 28, 2018
2 parents f111a6c + f120361 commit 8e7dff3
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class DataLoader(object):
time. But the limits are 1) all examples in dataset have the same shape, 2)
data transfomer needs to process multiple examples at each time
"""
def __init__(self, dataset, batch_size, shuffle, transform):
def __init__(self, dataset, batch_size, shuffle, transform=None):
self.dataset = dataset
self.batch_size = batch_size
self.shuffle = shuffle
Expand Down Expand Up @@ -47,7 +47,7 @@ def __len__(self):
def load_data_fashion_mnist(batch_size, resize=None, root="~/.mxnet/datasets/fashion-mnist"):
"""download the fashion mnist dataest and then load into memory"""
def transform_mnist(data, label):
# transform a batch of examples
# Transform a batch of examples.
if resize:
n = data.shape[0]
new_data = nd.zeros((n, resize, resize, data.shape[3]))
Expand All @@ -56,11 +56,12 @@ def transform_mnist(data, label):
data = new_data
# change data from batch x height x width x channel to batch x channel x height x width
return nd.transpose(data.astype('float32'), (0,3,1,2))/255, label.astype('float32')

mnist_train = gluon.data.vision.FashionMNIST(root=root, train=True, transform=None)
mnist_test = gluon.data.vision.FashionMNIST(root=root, train=False, transform=None)
train_data = DataLoader(mnist_train, batch_size, shuffle=True, transform = transform_mnist)
test_data = DataLoader(mnist_test, batch_size, shuffle=False, transform = transform_mnist)
# Transform later to avoid memory explosion.
train_data = DataLoader(mnist_train, batch_size, shuffle=True, transform=transform_mnist)
test_data = DataLoader(mnist_test, batch_size, shuffle=False, transform=transform_mnist)
return (train_data, test_data)

def try_gpu():
Expand Down

0 comments on commit 8e7dff3

Please sign in to comment.