-
Notifications
You must be signed in to change notification settings - Fork 0
/
datasource.py
45 lines (36 loc) · 1.06 KB
/
datasource.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import os
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, Sampler
from torchvision import transforms
from torchvision.datasets import MNIST
# Reproducible setup for testing: seed every RNG this module relies on.
# ``RandomSampler`` (used by ``get_loader``) draws its shuffle seed from
# torch's global generator, so torch must be seeded alongside ``random``
# and NumPy for the sample order to repeat across runs.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
def _dataloader_init_fn(worker_id: int = 0) -> None:
    """
    Seed NumPy's global RNG inside a ``DataLoader`` worker process.

    ``DataLoader`` invokes ``worker_init_fn(worker_id)`` with the worker's
    index, so this callable must accept that argument; a zero-argument
    function raises ``TypeError`` as soon as ``num_workers > 0``.
    Offsetting the module-level ``seed`` by ``worker_id`` keeps runs
    reproducible while giving each worker its own NumPy random stream
    instead of identical copies.

    :param worker_id: index of the calling worker; defaults to ``0`` so
        existing zero-argument calls keep working.
    """
    # np.random.seed only accepts values in [0, 2**32).
    np.random.seed((seed + worker_id) % 2**32)
def get_mnist_dataset(is_train_dataset: bool = True) -> MNIST:
    """
    Build the MNIST dataset rooted at the current working directory.

    ``download=True`` fetches the files into that directory when they are
    not already present, and every image is converted to a tensor with
    ``transforms.ToTensor()``.

    :param is_train_dataset: ``True`` for the training split, ``False``
        for the test split.
    :return: the ``MNIST`` dataset instance.
    """
    data_root = os.getcwd()
    to_tensor = transforms.ToTensor()
    return MNIST(
        data_root,
        train=is_train_dataset,
        download=True,
        transform=to_tensor,
    )
def get_loader(is_train_set: bool = True) -> DataLoader:
    """
    Prepare a shuffled MNIST data loader for either the train or test split.

    :param is_train_set: ``True`` selects the training split, ``False`` the
        test split (forwarded to ``get_mnist_dataset``).
    :return: a ``DataLoader`` whose sample order is drawn by a
        ``RandomSampler`` over the chosen split.
    """
    _dataset = get_mnist_dataset(is_train_dataset=is_train_set)
    # Both splits are shuffled. Without an explicit generator, RandomSampler
    # derives its order from torch's global RNG, so the order repeats across
    # runs only if that RNG was seeded beforehand.
    _dataset_sampler = RandomSampler(_dataset)
    return _get_loader(dataset=_dataset, dataset_sampler=_dataset_sampler)
def _get_loader(
    dataset: MNIST,
    dataset_sampler: Sampler,
    *,
    batch_size: int = 10,
) -> DataLoader:
    """
    Wrap *dataset* in a ``DataLoader`` that draws indices from *dataset_sampler*.

    :param dataset: the MNIST dataset to iterate over.
    :param dataset_sampler: sampler that decides the order in which samples
        are drawn.
    :param batch_size: number of samples per batch; defaults to 10.
    :return: the configured ``DataLoader``.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=dataset_sampler,
        # Re-seeds NumPy in each worker process; DataLoader only calls this
        # when it is run with worker processes (num_workers > 0).
        worker_init_fn=_dataloader_init_fn,
    )