data.py
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

# todo: set or infer the output size

class Loader(object):
    def __init__(self, dataset_ident, file_path, download, shuffle, batch_size,
                 data_transform, target_transform, use_cuda):
        # Use worker processes and pinned memory only when loading onto a GPU
        kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}

        # Set the dataset
        # NOTE: will need a refactor once we load datasets that require custom classes
        loader_map = {
            'mnist': datasets.MNIST,
            'MNIST': datasets.MNIST,
            'FashionMNIST': datasets.FashionMNIST,
            'fashion': datasets.FashionMNIST
        }
        num_class = {
            'mnist': 10,
            'MNIST': 10,
            'fashion': 10,
            'FashionMNIST': 10
        }
        # Get the datasets
        train_dataset, test_dataset = self.get_dataset(loader_map[dataset_ident], file_path, download,
                                                       data_transform, target_transform)

        # Set the loaders
        self.train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)
        self.test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)

        # Infer and set the image shape from a sample batch, idea from:
        # https://github.com/jramapuram/helpers/
        tmp_batch, _ = next(iter(self.train_loader))
        self.img_shape = list(tmp_batch.size())[1:]
        self.num_class = num_class[dataset_ident]
        self.batch_size = batch_size
    @staticmethod
    def get_dataset(dataset, file_path, download, data_transform, target_transform):
        # Check for data_transform being None, a single transform, or a list
        # None -> default to [transforms.ToTensor()]
        # single transform -> wrap it in a list
        if not data_transform:
            data_transform = [transforms.ToTensor()]
        elif not isinstance(data_transform, list):
            data_transform = [data_transform]

        # Training and test datasets
        train_dataset = dataset(file_path, train=True, download=download,
                                transform=transforms.Compose(data_transform),
                                target_transform=target_transform)
        test_dataset = dataset(file_path, train=False, download=download,
                               transform=transforms.Compose(data_transform),
                               target_transform=target_transform)
        return train_dataset, test_dataset
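

# Minimal usage sketch (not part of the original module): shows how Loader might be
# instantiated given the constructor signature above. The data directory and the
# hyper-parameter values below are illustrative assumptions, not values from the repo.
if __name__ == '__main__':
    loader = Loader(dataset_ident='mnist',
                    file_path='./data',      # assumed download/cache directory
                    download=True,
                    shuffle=True,
                    batch_size=64,
                    data_transform=None,     # falls back to [transforms.ToTensor()]
                    target_transform=None,
                    use_cuda=False)
    print(loader.img_shape)   # e.g. [1, 28, 28] for MNIST
    print(loader.num_class)   # 10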