-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataloader.py
33 lines (27 loc) · 1.34 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import torchvision
import torchvision.transforms as transforms
def get_data_loaders(args):
    """Get DataLoaders for the CIFAR-10 training and test sets via torchvision APIs.

    Args:
        args: Namespace-like object. Required attributes:
            ``data_dir`` -- root directory passed to
                ``torchvision.datasets.CIFAR10`` (``download=True``, so the
                dataset is fetched there if it is not already present);
            ``batch_size`` -- batch size of the training loader; the test
                loader uses ``batch_size * 8``.
            Optional attribute:
            ``num_workers`` -- number of DataLoader worker processes
                (defaults to 0, i.e. loading happens in the main process).

    Returns:
        ``(trainloader, testloader)``. The training loader shuffles and applies
        random-crop (32px, 4px padding) and horizontal-flip augmentation; the
        test loader keeps dataset order and only converts and normalizes.
    """
    # Per-channel mean / std for Normalize; presumably CIFAR-10 training-set
    # statistics -- confirm before changing.
    normalize = transforms.Normalize((.4914, .4822, .4465), (.2470, .2435, .2616))

    # Randomly crop or flip for data augmentation (training set only).
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    loader_kwargs = {
        # Keep the previous default of 0 when the caller does not configure it.
        'num_workers': getattr(args, 'num_workers', 0),
        # Pinned host memory only speeds up host-to-GPU copies; without CUDA
        # PyTorch ignores it and emits a warning, so enable it only when useful.
        'pin_memory': torch.cuda.is_available(),
    }

    trainset = torchvision.datasets.CIFAR10(root=args.data_dir, train=True, download=True,
                                            transform=transform_train)
    testset = torchvision.datasets.CIFAR10(root=args.data_dir, train=False, download=True,
                                           transform=transform_test)

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
                                              shuffle=True, **loader_kwargs)
    testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size * 8,
                                             shuffle=False, **loader_kwargs)
    return trainloader, testloader