diff --git a/utils.py b/utils.py index 2086affc1..5c9eb83f1 100644 --- a/utils.py +++ b/utils.py @@ -59,8 +59,8 @@ def transform_mnist(data, label): mnist_train = gluon.data.vision.FashionMNIST(root=root, train=True, transform=None) mnist_test = gluon.data.vision.FashionMNIST(root=root, train=False, transform=None) - # Transform later to avoid memory explosion. - train_data = DataLoader(mnist_train, batch_size, shuffle=True, transform=transform_mnist) + # Transform later to avoid memory explosion. + train_data = DataLoader(mnist_train, batch_size, shuffle=True, transform=transform_mnist) test_data = DataLoader(mnist_test, batch_size, shuffle=False, transform=transform_mnist) return (train_data, test_data)