Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] From 4 to 2 domains, and adding masker functionality #50

Merged
merged 6 commits into from
May 28, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions omnigan/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, latent_space, loss):
nn.AvgPool2d((int(self.feature_size / 4), int(self.feature_size / 4))),
Squeeze(-1),
Squeeze(-1),
nn.Linear(int(self.channels / 4), 4),
nn.Linear(int(self.channels / 4), 2),
]
)

Expand Down Expand Up @@ -84,9 +84,7 @@ def __init__(
self.stride = stride
self.downsample = downsample
if stride != 1 or inplanes != planes:
self.downsample = nn.Sequential(
conv1x1(inplanes, planes, stride), norm_layer(planes)
)
self.downsample = nn.Sequential(conv1x1(inplanes, planes, stride), norm_layer(planes))

def forward(self, x):
identity = x
Expand Down Expand Up @@ -163,4 +161,3 @@ def conv1x1(in_planes, out_planes, stride=1):
Default: 1 (default: {1})
"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

18 changes: 12 additions & 6 deletions omnigan/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ def pil_image_loader(path, task):
arr = arr.astype(np.float32)
arr[arr != 0] = 1 / arr[arr != 0]

if task == 'm':
arr[arr != 0] = 1 / arr[arr != 0]
# Keep only the first channel when the mask array has a channel axis (3+ dims)
51N84D marked this conversation as resolved.
Show resolved Hide resolved
if len(arr.shape) >= 3:
arr = arr[:,:,0]
51N84D marked this conversation as resolved.
Show resolved Hide resolved
# if task == "s":
# arr = decode_segmap(arr)

Expand All @@ -96,9 +101,7 @@ def __init__(self, mode, domain, opts, transform=None):
self.mode = mode
self.tasks = set(opts.tasks)
self.tasks.add("x")

file_list_path = Path(opts.data.files[mode][domain])

if "/" not in str(file_list_path):
file_list_path = Path(opts.data.files.base) / Path(
opts.data.files[mode][domain]
Expand All @@ -116,6 +119,8 @@ def __init__(self, mode, domain, opts, transform=None):
self.file_list_path = str(file_list_path)
self.transform = transform



def filter_samples(self):
"""
Filter out data which is not required for the model's tasks
Expand Down Expand Up @@ -144,6 +149,7 @@ def __getitem__(self, i):
"""
paths = self.samples_paths[i]


if self.transform:
return {
"data": self.transform(
Expand All @@ -156,7 +162,7 @@ def __getitem__(self, i):

return {
"data": {
task: pil_image_loader(path, task) for task, path in paths.items()
task: pil_image_loader(path, task) for task, path in paths.items()
},
"paths": paths,
"domain": self.domain,
Expand All @@ -183,10 +189,10 @@ def check_samples(self):
assert Path(v).exists(), f"{k} {v} does not exist"


def get_loader(domain, mode, opts):
def get_loader(mode, domain, opts):
vict0rsch marked this conversation as resolved.
Show resolved Hide resolved
return DataLoader(
OmniListDataset(
domain, mode, opts, transform=transforms.Compose(get_transforms(opts))
mode, domain, opts, transform=transforms.Compose(get_transforms(opts))
),
batch_size=opts.data.loaders.get("batch_size", 4),
# shuffle=opts.data.loaders.get("shuffle", True),
Expand All @@ -199,7 +205,7 @@ def get_all_loaders(opts):
loaders = {}
for mode in ["train", "val"]:
loaders[mode] = {}
for domain in ["rf", "rn", "sf", "sn"]:
for domain in ["r", "s"]:
if mode in opts.data.files:
if domain in opts.data.files[mode]:
loaders[mode][domain] = get_loader(mode, domain, opts)
Expand Down
54 changes: 23 additions & 31 deletions omnigan/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,7 @@ def __init__(self, opts, latent_shape=None, verbose=None):
# call set_translation_decoder(latent_shape, device)
else:
self.decoders["t"] = nn.ModuleDict(
{
"f": BaseTranslationDecoder(opts),
"n": BaseTranslationDecoder(opts),
}
{"f": BaseTranslationDecoder(opts), "n": BaseTranslationDecoder(opts),}
)

if "a" in opts.tasks and not opts.gen.a.ignore:
Expand All @@ -87,6 +84,9 @@ def __init__(self, opts, latent_shape=None, verbose=None):
if "w" in opts.tasks and not opts.gen.w.ignore:
self.decoders["w"] = WaterDecoder(opts)

if "m" in opts.tasks and not opts.gen.m.ignore:
vict0rsch marked this conversation as resolved.
Show resolved Hide resolved
self.decoders["m"] = MaskDecoder(opts)

self.decoders = nn.ModuleDict(self.decoders)

def set_translation_decoder(self, latent_shape, device):
Expand Down Expand Up @@ -137,11 +137,7 @@ def translate_batch(self, batch, translator="f", z=None):
return y

def decode_tasks(self, z):
return {
task: self.decoders[task](z)
for task in self.opts.tasks
if task not in {"t", "a"}
}
return {task: self.decoders[task](z) for task in self.opts.tasks if task not in {"t", "a"}}

def encode(self, x):
return self.encoder.forward(x)
Expand Down Expand Up @@ -196,29 +192,16 @@ def __init__(self, opts):
pad_type = opts.gen.encoder.pad_type

self.model = [
Conv2dBlock(
input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type
)
Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)
]
# downsampling blocks
for i in range(n_downsample):
self.model += [
Conv2dBlock(
dim,
2 * dim,
4,
2,
1,
norm=norm,
activation=activ,
pad_type=pad_type,
)
Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type,)
]
dim *= 2
# residual blocks
self.model += [
ResBlocks(n_res, dim, norm=res_norm, activation=activ, pad_type=pad_type)
]
self.model += [ResBlocks(n_res, dim, norm=res_norm, activation=activ, pad_type=pad_type)]
self.model = nn.Sequential(*self.model)
self.output_dim = dim

Expand Down Expand Up @@ -255,6 +238,19 @@ def __init__(self, opts):
)


class MaskDecoder(BaseDecoder):
    """Decoder registered by the generator under the ``"m"`` task key.

    All construction arguments are forwarded to ``BaseDecoder``; this class
    adds no behaviour of its own.
    """

    def __init__(self, opts):
        """Build the decoder from the generator options.

        Args:
            opts: options object exposing ``opts.gen.w`` with the fields
                ``n_upsample``, ``n_res``, ``res_dim``, ``output_dim``,
                ``res_norm``, ``activ`` and ``pad_type``.
        """
        # NOTE(review): every setting is read from ``opts.gen.w`` rather than
        # ``opts.gen.m`` -- confirm that sharing this config is intentional.
        cfg = opts.gen.w
        super().__init__(
            cfg.n_upsample,
            cfg.n_res,
            cfg.res_dim,
            cfg.output_dim,
            res_norm=cfg.res_norm,
            activ=cfg.activ,
            pad_type=cfg.pad_type,
        )


class DepthDecoder(BaseDecoder):
def __init__(self, opts):
super().__init__(
Expand Down Expand Up @@ -327,9 +323,7 @@ def __getitem__(self, key):
return self._model

def forward(self, *args, **kwargs):
raise NotImplementedError(
"Cannot forward the SpadeTranslationDict, chose a domain"
)
raise NotImplementedError("Cannot forward the SpadeTranslationDict, chose a domain")

def __str__(self):
return str(self._model).strip()
Expand Down Expand Up @@ -373,9 +367,7 @@ def update_bit(self, key):

def concat_bit_to_seg(self, seg):
bit = get_4D_bit(seg.shape, self.bit)
return torch.cat(
[bit.to(torch.float32).to(seg.device), seg.to(torch.float32)], dim=1
)
return torch.cat([bit.to(torch.float32).to(seg.device), seg.to(torch.float32)], dim=1)

def forward(self, z, seg):
if self.use_bit_conditioning:
Expand Down
9 changes: 9 additions & 0 deletions omnigan/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,15 @@ def __call__(self, logits, target):
return self.loss(logits, target.to(logits.device))


class BinaryCrossEntropy(nn.Module):
    """Binary cross-entropy loss, a thin wrapper around ``torch.nn.BCELoss``.

    The target is moved to the device of the prediction before the loss is
    evaluated, so callers do not have to align devices themselves.

    Note:
        ``torch.nn.BCELoss`` expects *probabilities* in ``[0, 1]`` (e.g. the
        output of a sigmoid). Despite its name, the ``logits`` argument must
        therefore already be normalised; raw scores will raise or give
        meaningless values.
    """

    def __init__(self):
        super().__init__()
        self.loss = torch.nn.BCELoss()

    # Implemented as ``forward`` (not ``__call__``) so that ``nn.Module``'s
    # own ``__call__`` still runs and registered hooks are honoured; the
    # call syntax ``loss(logits, target)`` is unchanged.
    def forward(self, logits, target):
        """Compute the binary cross-entropy between prediction and target.

        Args:
            logits: predicted probabilities in ``[0, 1]``.
            target: ground-truth tensor accepted by ``BCELoss`` for
                ``logits``; it is moved to ``logits.device`` first.

        Returns:
            The value returned by ``torch.nn.BCELoss`` with its default
            settings.
        """
        return self.loss(logits, target.to(logits.device))


class PixelCrossEntropy(CrossEntropy):
"""
Computes the cross entropy per pixel
Expand Down
Loading