forked from junyanz/pytorch-CycleGAN-and-pix2pix
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
junyanz
committed
Jun 13, 2017
1 parent
3b72a65
commit e6858e3
Showing
21 changed files
with
259 additions
and
245 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import os.path | ||
import random | ||
import torchvision.transforms as transforms | ||
import torch | ||
from data.base_dataset import BaseDataset | ||
from data.image_folder import make_dataset | ||
from PIL import Image | ||
|
||
|
||
class AlignedDataset(BaseDataset): | ||
def initialize(self, opt): | ||
self.opt = opt | ||
self.root = opt.dataroot | ||
self.dir_AB = os.path.join(opt.dataroot, opt.phase) | ||
|
||
self.AB_paths = sorted(make_dataset(self.dir_AB)) | ||
|
||
assert(opt.resize_or_crop == 'resize_and_crop') | ||
|
||
transform_list = [transforms.ToTensor(), | ||
transforms.Normalize((0.5, 0.5, 0.5), | ||
(0.5, 0.5, 0.5))] | ||
|
||
self.transform = transforms.Compose(transform_list) | ||
|
||
def __getitem__(self, index): | ||
AB_path = self.AB_paths[index] | ||
AB = Image.open(AB_path).convert('RGB') | ||
AB = AB.resize((self.opt.loadSize * 2, self.opt.loadSize), Image.BICUBIC) | ||
AB = self.transform(AB) | ||
|
||
w_total = AB.size(2) | ||
w = int(w_total / 2) | ||
h = AB.size(1) | ||
w_offset = random.randint(0, max(0, w - self.opt.fineSize - 1)) | ||
h_offset = random.randint(0, max(0, h - self.opt.fineSize - 1)) | ||
|
||
A = AB[:, h_offset:h_offset + self.opt.fineSize, | ||
w_offset:w_offset + self.opt.fineSize] | ||
B = AB[:, h_offset:h_offset + self.opt.fineSize, | ||
w + w_offset:w + w_offset + self.opt.fineSize] | ||
|
||
if (not self.opt.no_flip) and random.random() < 0.5: | ||
idx = [i for i in range(A.size(2) - 1, -1, -1)] | ||
idx = torch.LongTensor(idx) | ||
A = A.index_select(2, idx) | ||
B = B.index_select(2, idx) | ||
|
||
return {'A': A, 'B': B, | ||
'A_paths': AB_path, 'B_paths': AB_path} | ||
|
||
def __len__(self): | ||
return len(self.AB_paths) | ||
|
||
def name(self): | ||
return 'AlignedDataset' |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import torch.utils.data as data | ||
|
||
class BaseDataset(data.Dataset): | ||
def __init__(self): | ||
super(BaseDataset, self).__init__() | ||
|
||
def name(self): | ||
return 'BaseDataset' | ||
|
||
def initialize(self, opt): | ||
pass | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import torch.utils.data | ||
from data.base_data_loader import BaseDataLoader | ||
|
||
|
||
def CreateDataset(opt): | ||
dataset = None | ||
if opt.dataset_mode == 'aligned': | ||
from data.aligned_dataset import AlignedDataset | ||
dataset = AlignedDataset() | ||
elif opt.dataset_mode == 'unaligned': | ||
from data.unaligned_dataset import UnalignedDataset | ||
dataset = UnalignedDataset() | ||
elif opt.dataset_mode == 'single': | ||
from data.single_dataset import SingleDataset | ||
dataset = SingleDataset() | ||
else: | ||
raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode) | ||
|
||
print("dataset [%s] was created" % (dataset.name())) | ||
dataset.initialize(opt) | ||
return dataset | ||
|
||
|
||
class CustomDatasetDataLoader(BaseDataLoader): | ||
def name(self): | ||
return 'CustomDatasetDataLoader' | ||
|
||
def initialize(self, opt): | ||
BaseDataLoader.initialize(self, opt) | ||
self.dataset = CreateDataset(opt) | ||
self.dataloader = torch.utils.data.DataLoader( | ||
self.dataset, | ||
batch_size=opt.batchSize, | ||
shuffle=not opt.serial_batches, | ||
num_workers=int(opt.nThreads)) | ||
|
||
def load_data(self): | ||
return self.dataloader | ||
|
||
def __len__(self): | ||
return min(len(self.dataset), self.opt.max_dataset_size) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,7 @@ | ||
|
||
def CreateDataLoader(opt): | ||
data_loader = None | ||
if opt.align_data > 0: | ||
from data.aligned_data_loader import AlignedDataLoader | ||
data_loader = AlignedDataLoader() | ||
else: | ||
from data.unaligned_data_loader import UnalignedDataLoader | ||
data_loader = UnalignedDataLoader() | ||
from data.custom_dataset_data_loader import CustomDatasetDataLoader | ||
data_loader = CustomDatasetDataLoader() | ||
print(data_loader.name()) | ||
data_loader.initialize(opt) | ||
return data_loader |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import os.path | ||
import torchvision.transforms as transforms | ||
from data.base_dataset import BaseDataset | ||
from data.image_folder import make_dataset | ||
from PIL import Image | ||
|
||
|
||
class SingleDataset(BaseDataset): | ||
def initialize(self, opt): | ||
self.opt = opt | ||
self.root = opt.dataroot | ||
self.dir_A = os.path.join(opt.dataroot) | ||
|
||
self.A_paths = make_dataset(self.dir_A) | ||
|
||
self.A_paths = sorted(self.A_paths) | ||
|
||
transform_list = [] | ||
if opt.resize_or_crop == 'resize_and_crop': | ||
transform_list.append(transforms.Scale(opt.loadSize)) | ||
|
||
if opt.isTrain and not opt.no_flip: | ||
transform_list.append(transforms.RandomHorizontalFlip()) | ||
|
||
if opt.resize_or_crop != 'no_resize': | ||
transform_list.append(transforms.RandomCrop(opt.fineSize)) | ||
|
||
transform_list += [transforms.ToTensor(), | ||
transforms.Normalize((0.5, 0.5, 0.5), | ||
(0.5, 0.5, 0.5))] | ||
|
||
self.transform = transforms.Compose(transform_list) | ||
|
||
def __getitem__(self, index): | ||
A_path = self.A_paths[index] | ||
|
||
A_img = Image.open(A_path).convert('RGB') | ||
|
||
A_img = self.transform(A_img) | ||
|
||
return {'A': A_img, 'A_paths': A_path} | ||
|
||
def __len__(self): | ||
return len(self.A_paths) | ||
|
||
def name(self): | ||
return 'SingleImageDataset' |
Oops, something went wrong.