-
Notifications
You must be signed in to change notification settings - Fork 1
/
dataset.py
66 lines (58 loc) · 2.47 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import glob
import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from skimage import io, transform
from PIL import Image
# Default preprocessing pipeline: take a random 256x256 crop of the
# (already short-side-512-resized) image, then convert it to a CHW tensor.
_default_steps = [
    transforms.RandomCrop(256),
    transforms.ToTensor(),
]
trans = transforms.Compose(_default_steps)
class PreprocessDataset(Dataset):
    """Dataset of randomly paired (content, style) images for style transfer.

    On first construction the raw images in ``content_dir`` and ``style_dir``
    are resized so their short side is 512 px and cached into sibling
    ``<dir>_resized`` directories; later constructions reuse the cache.
    Each item is one (content, style) pair drawn from independently shuffled
    file lists.
    """

    def __init__(self, content_dir, style_dir, transforms=trans):
        """Build the resized caches (if needed) and the shuffled pair list.

        Args:
            content_dir: directory of content images.
            style_dir: directory of style images.
            transforms: callable applied to each PIL image in __getitem__
                (defaults to the module-level crop+to-tensor pipeline).
        """
        content_dir_resized = content_dir + '_resized'
        style_dir_resized = style_dir + '_resized'
        # Handle each cache directory independently.  The original code only
        # resized when BOTH caches were missing and then called os.mkdir on
        # both, so having exactly one of them already present raised
        # FileExistsError (and a missing one was never populated).
        if not os.path.exists(content_dir_resized):
            os.mkdir(content_dir_resized)
            self._resize(content_dir, content_dir_resized)
        if not os.path.exists(style_dir_resized):
            os.mkdir(style_dir_resized)
            self._resize(style_dir, style_dir_resized)
        content_images = glob.glob(content_dir_resized + '/*')
        np.random.shuffle(content_images)
        style_images = glob.glob(style_dir_resized + '/*')
        np.random.shuffle(style_images)
        # zip truncates to the shorter list, so the dataset length is
        # min(#content, #style); pairing is random via the shuffles above.
        self.images_pairs = list(zip(content_images, style_images))
        self.transforms = transforms

    @staticmethod
    def _resize(source_dir, target_dir):
        """Resize every readable RGB image in source_dir to short side 512.

        Non-RGB images (grayscale, RGBA) are skipped; unreadable or corrupt
        files are skipped as well rather than aborting the whole pass.
        """
        print(f'Start resizing {source_dir} ')
        for i in tqdm(os.listdir(source_dir)):
            filename = os.path.basename(i)
            try:
                image = io.imread(os.path.join(source_dir, i))
                # Keep only H x W x 3 (RGB) images.
                if len(image.shape) == 3 and image.shape[-1] == 3:
                    H, W, _ = image.shape
                    # Scale so the SHORTER side becomes 512, preserving
                    # the aspect ratio.
                    if H < W:
                        ratio = W / H
                        H = 512
                        W = int(ratio * H)
                    else:
                        ratio = H / W
                        W = 512
                        H = int(ratio * W)
                    image = transform.resize(image, (H, W), mode='reflect', anti_aliasing=True)
                    # transform.resize returns floats in [0, 1]; convert to
                    # uint8 explicitly so imsave does not warn about (or
                    # reject) lossy float-to-integer conversion.
                    io.imsave(os.path.join(target_dir, filename),
                              (image * 255).astype(np.uint8))
            except Exception:
                # Skip files that cannot be read or saved, but do not
                # swallow KeyboardInterrupt/SystemExit as the old bare
                # ``except:`` did.
                continue

    def __len__(self):
        """Number of (content, style) pairs."""
        return len(self.images_pairs)

    def __getitem__(self, index):
        """Load and transform the pair at ``index``.

        Returns:
            (content, style) — transformed images (tensors with the default
            pipeline), or raw PIL images if ``transforms`` is falsy.
        """
        content_image, style_image = self.images_pairs[index]
        content_image = Image.open(content_image)
        style_image = Image.open(style_image)
        if self.transforms:
            content_image = self.transforms(content_image)
            style_image = self.transforms(style_image)
        return content_image, style_image