
Commit

Merge pull request #50 from cc-ai/2domain_masker

[WIP] From 4 to 2 domains, and adding masker functionality
melisandeteng authored May 28, 2020
2 parents dac7514 + 5711f71 commit 7b88449
Showing 17 changed files with 317 additions and 265 deletions.
47 changes: 10 additions & 37 deletions omnigan/blocks.py
@@ -66,6 +66,8 @@ def __init__(
             self.activation = nn.SELU(inplace=True)
         elif activation == "tanh":
             self.activation = nn.Tanh()
+        elif activation == "sigmoid":
+            self.activation = nn.Sigmoid()
         elif activation == "none":
             self.activation = None
         else:
@@ -74,14 +76,10 @@ def __init__(
         # initialize convolution
         if norm == "spectral":
             self.conv = SpectralNorm(
-                nn.Conv2d(
-                    input_dim, output_dim, kernel_size, stride, bias=self.use_bias
-                )
+                nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
             )
         else:
-            self.conv = nn.Conv2d(
-                input_dim, output_dim, kernel_size, stride, bias=self.use_bias
-            )
+            self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)

     def forward(self, x):
         x = self.conv(self.pad(x))
@@ -123,15 +121,9 @@ def __init__(self, dim, norm="in", activation="relu", pad_type="zero"):
         self.activation = activation
         model = []
-        model += [
-            Conv2dBlock(
-                dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type
-            )
-        ]
-        model += [
-            Conv2dBlock(
-                dim, dim, 3, 1, 1, norm=norm, activation="none", pad_type=pad_type
-            )
-        ]
+        model += [
+            Conv2dBlock(dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)
+        ]
+        model += [Conv2dBlock(dim, dim, 3, 1, 1, norm=norm, activation="none", pad_type=pad_type)]
         self.model = nn.Sequential(*model)

     def forward(self, x):
@@ -157,6 +149,7 @@ def __init__(
         res_norm="instance",
         activ="relu",
         pad_type="zero",
+        output_activ="tanh",
     ):
         super().__init__()

@@ -166,28 +159,14 @@ def __init__(
             self.model += [
                 nn.Upsample(scale_factor=2),
                 Conv2dBlock(
-                    dim,
-                    dim // 2,
-                    5,
-                    1,
-                    2,
-                    norm="layer",
-                    activation=activ,
-                    pad_type=pad_type,
+                    dim, dim // 2, 5, 1, 2, norm="layer", activation=activ, pad_type=pad_type,
                 ),
             ]
             dim //= 2
         # use reflection padding in the last conv layer
         self.model += [
             Conv2dBlock(
-                dim,
-                output_dim,
-                7,
-                1,
-                3,
-                norm="none",
-                activation="tanh",
-                pad_type=pad_type,
+                dim, output_dim, 7, 1, 3, norm="none", activation=output_activ, pad_type=pad_type,
             )
         ]
         self.model = nn.Sequential(*self.model)
@@ -207,13 +186,7 @@ def __str__(self):
 # 0ff661e on 13 Apr 2019
 class SPADEResnetBlock(nn.Module):
     def __init__(
-        self,
-        fin,
-        fout,
-        cond_nc,
-        spade_use_spectral_norm,
-        spade_param_free_norm,
-        spade_kernel_size,
+        self, fin, fout, cond_nc, spade_use_spectral_norm, spade_param_free_norm, spade_kernel_size,
     ):
         super().__init__()
         # Attributes
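Reviewer note: the two functional changes above are the new "sigmoid" option in Conv2dBlock and the new output_activ argument on BaseDecoder (defaulting to "tanh", so existing decoders keep their behavior). A minimal sketch of why a mask head wants the sigmoid; shapes are illustrative assumptions:

import torch
import torch.nn as nn

# A tanh head suits images normalized to [-1, 1]; a sigmoid head yields
# values in [0, 1] that read as per-pixel mask probabilities and can
# feed BCELoss directly.
x = torch.randn(2, 1, 8, 8)
assert nn.Tanh()(x).abs().max() <= 1      # image-decoder output range
s = nn.Sigmoid()(x)
assert s.min() >= 0 and s.max() <= 1      # mask-decoder output range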
7 changes: 2 additions & 5 deletions omnigan/classifier.py
@@ -31,7 +31,7 @@ def __init__(self, latent_space, loss):
             nn.AvgPool2d((int(self.feature_size / 4), int(self.feature_size / 4))),
             Squeeze(-1),
             Squeeze(-1),
-            nn.Linear(int(self.channels / 4), 4),
+            nn.Linear(int(self.channels / 4), 2),
         ]
     )

@@ -84,9 +84,7 @@ def __init__(
         self.stride = stride
         self.downsample = downsample
         if stride != 1 or inplanes != planes:
-            self.downsample = nn.Sequential(
-                conv1x1(inplanes, planes, stride), norm_layer(planes)
-            )
+            self.downsample = nn.Sequential(conv1x1(inplanes, planes, stride), norm_layer(planes))

     def forward(self, x):
         identity = x
@@ -163,4 +161,3 @@ def conv1x1(in_planes, out_planes, stride=1):
         Default: 1 (default: {1})
     """
     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
-
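Reviewer note: the classifier head shrinking from 4 to 2 logits tracks the move from four domains (rf, rn, sf, sn) to two (r, s). A sketch of the corresponding target encoding; the r/s-to-index mapping is an assumption for illustration:

import torch
import torch.nn.functional as F

domain_to_index = {"r": 0, "s": 1}    # assumed encoding: real vs. simulated
logits = torch.randn(4, 2)            # classifier output for a batch of 4
target = torch.tensor([0, 1, 1, 0])   # domains r, s, s, r
loss = F.cross_entropy(logits, target)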

18 changes: 12 additions & 6 deletions omnigan/data.py
@@ -81,6 +81,11 @@ def pil_image_loader(path, task):
     arr = arr.astype(np.float32)
     arr[arr != 0] = 1 / arr[arr != 0]

+    if task == "m":
+        arr[arr != 0] = 255
+        # Make sure the mask is single-channel
+        if len(arr.shape) >= 3:
+            arr = arr[:, :, 0]
     # if task == "s":
     #     arr = decode_segmap(arr)

@@ -96,9 +101,7 @@ def __init__(self, mode, domain, opts, transform=None):
         self.mode = mode
         self.tasks = set(opts.tasks)
         self.tasks.add("x")
-
         file_list_path = Path(opts.data.files[mode][domain])
-
         if "/" not in str(file_list_path):
             file_list_path = Path(opts.data.files.base) / Path(
                 opts.data.files[mode][domain]
@@ -116,6 +119,8 @@ def __init__(self, mode, domain, opts, transform=None):
         self.file_list_path = str(file_list_path)
         self.transform = transform

+
+
     def filter_samples(self):
         """
         Filter out data which is not required for the model's tasks
@@ -144,6 +149,7 @@ def __getitem__(self, i):
         """
         paths = self.samples_paths[i]

+
         if self.transform:
             return {
                 "data": self.transform(
@@ -156,7 +162,7 @@ def __getitem__(self, i):

         return {
             "data": {
-                task: pil_image_loader(path, task) for task, path in paths.items()
+                task: pil_image_loader(path, task) for task, path in paths.items()
             },
             "paths": paths,
             "domain": self.domain,
@@ -184,10 +190,10 @@ def check_samples(self):
             assert Path(v).exists(), f"{k} {v} does not exist"


-def get_loader(domain, mode, opts):
+def get_loader(mode, domain, opts):
     return DataLoader(
         OmniListDataset(
-            domain, mode, opts, transform=transforms.Compose(get_transforms(opts))
+            mode, domain, opts, transform=transforms.Compose(get_transforms(opts))
         ),
         batch_size=opts.data.loaders.get("batch_size", 4),
         # shuffle=opts.data.loaders.get("shuffle", True),
@@ -200,7 +206,7 @@ def get_all_loaders(opts):
     loaders = {}
     for mode in ["train", "val"]:
         loaders[mode] = {}
-        for domain in ["rf", "rn", "sf", "sn"]:
+        for domain in ["r", "s"]:
             if mode in opts.data.files:
                 if domain in opts.data.files[mode]:
                     loaders[mode][domain] = get_loader(mode, domain, opts)
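Reviewer note: the new "m" branch binarizes whatever nonzero values the mask file holds and keeps only the first channel. The same logic as a standalone sketch (the helper name is hypothetical):

import numpy as np
from PIL import Image

def load_mask(path):
    # Any nonzero pixel belongs to the mask; multi-channel masks are
    # reduced to their first channel (mirrors the task == "m" branch).
    arr = np.array(Image.open(path)).astype(np.float32)
    arr[arr != 0] = 255
    if arr.ndim >= 3:
        arr = arr[:, :, 0]
    return arr

Note also that get_loader's signature flipped to (mode, domain, opts); any external callers passing positional arguments need the same swap.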
59 changes: 18 additions & 41 deletions omnigan/generator.py
@@ -64,10 +64,7 @@ def __init__(self, opts, latent_shape=None, verbose=None):
             # call set_translation_decoder(latent_shape, device)
         else:
             self.decoders["t"] = nn.ModuleDict(
-                {
-                    "f": BaseTranslationDecoder(opts),
-                    "n": BaseTranslationDecoder(opts),
-                }
+                {"f": BaseTranslationDecoder(opts), "n": BaseTranslationDecoder(opts)}
             )

         if "a" in opts.tasks and not opts.gen.a.ignore:
@@ -84,8 +81,8 @@ def __init__(self, opts, latent_shape=None, verbose=None):
         if "s" in opts.tasks and not opts.gen.s.ignore:
             self.decoders["s"] = SegmentationDecoder(opts)

-        if "w" in opts.tasks and not opts.gen.w.ignore:
-            self.decoders["w"] = WaterDecoder(opts)
+        if "m" in opts.tasks and not opts.gen.m.ignore:
+            self.decoders["m"] = MaskDecoder(opts)

         self.decoders = nn.ModuleDict(self.decoders)

@@ -137,11 +134,7 @@ def translate_batch(self, batch, translator="f", z=None):
         return y

     def decode_tasks(self, z):
-        return {
-            task: self.decoders[task](z)
-            for task in self.opts.tasks
-            if task not in {"t", "a"}
-        }
+        return {task: self.decoders[task](z) for task in self.opts.tasks if task not in {"t", "a"}}

     def encode(self, x):
         return self.encoder.forward(x)
@@ -196,29 +189,16 @@ def __init__(self, opts):
         pad_type = opts.gen.encoder.pad_type

         self.model = [
-            Conv2dBlock(
-                input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type
-            )
+            Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)
         ]
         # downsampling blocks
         for i in range(n_downsample):
             self.model += [
-                Conv2dBlock(
-                    dim,
-                    2 * dim,
-                    4,
-                    2,
-                    1,
-                    norm=norm,
-                    activation=activ,
-                    pad_type=pad_type,
-                )
+                Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)
             ]
             dim *= 2
         # residual blocks
-        self.model += [
-            ResBlocks(n_res, dim, norm=res_norm, activation=activ, pad_type=pad_type)
-        ]
+        self.model += [ResBlocks(n_res, dim, norm=res_norm, activation=activ, pad_type=pad_type)]
         self.model = nn.Sequential(*self.model)
         self.output_dim = dim

@@ -242,16 +222,17 @@ def __init__(self, opts):
         )


-class WaterDecoder(BaseDecoder):
+class MaskDecoder(BaseDecoder):
     def __init__(self, opts):
         super().__init__(
-            opts.gen.w.n_upsample,
-            opts.gen.w.n_res,
-            opts.gen.w.res_dim,
-            opts.gen.w.output_dim,
-            res_norm=opts.gen.w.res_norm,
-            activ=opts.gen.w.activ,
-            pad_type=opts.gen.w.pad_type,
+            opts.gen.m.n_upsample,
+            opts.gen.m.n_res,
+            opts.gen.m.res_dim,
+            opts.gen.m.output_dim,
+            res_norm=opts.gen.m.res_norm,
+            activ=opts.gen.m.activ,
+            pad_type=opts.gen.m.pad_type,
+            output_activ="sigmoid",
         )


@@ -327,9 +308,7 @@ def __getitem__(self, key):
         return self._model

     def forward(self, *args, **kwargs):
-        raise NotImplementedError(
-            "Cannot forward the SpadeTranslationDict, chose a domain"
-        )
+        raise NotImplementedError("Cannot forward the SpadeTranslationDict, choose a domain")

     def __str__(self):
         return str(self._model).strip()
@@ -373,9 +352,7 @@ def update_bit(self, key):

     def concat_bit_to_seg(self, seg):
         bit = get_4D_bit(seg.shape, self.bit)
-        return torch.cat(
-            [bit.to(torch.float32).to(seg.device), seg.to(torch.float32)], dim=1
-        )
+        return torch.cat([bit.to(torch.float32).to(seg.device), seg.to(torch.float32)], dim=1)

     def forward(self, z, seg):
         if self.use_bit_conditioning:
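Reviewer note: MaskDecoder is the renamed WaterDecoder, reading its config from opts.gen.m and ending in a sigmoid via the new output_activ hook. A rough stand-in for the decode path, with made-up channel counts and resolutions:

import torch
import torch.nn as nn

# Not the real BaseDecoder, just the shape of the idea: the latent z is
# upsampled back to image resolution and squashed to one mask channel.
decoder = nn.Sequential(
    nn.Upsample(scale_factor=2),
    nn.Conv2d(256, 128, 5, 1, 2),
    nn.ReLU(inplace=True),
    nn.Upsample(scale_factor=2),
    nn.Conv2d(128, 1, 7, 1, 3),
    nn.Sigmoid(),
)

z = torch.randn(2, 256, 64, 64)   # hypothetical encoder output
mask = decoder(z)                 # (2, 1, 256, 256), values in [0, 1]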
9 changes: 9 additions & 0 deletions omnigan/losses.py
@@ -80,6 +80,15 @@ def __call__(self, logits, target):
         return self.loss(logits, target.to(logits.device))


+class BinaryCrossEntropy(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.loss = torch.nn.BCELoss()
+
+    def __call__(self, logits, target):
+        return self.loss(logits, target.to(logits.device))
+
+
 class PixelCrossEntropy(CrossEntropy):
     """
     Computes the cross entropy per pixel
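Reviewer note: BinaryCrossEntropy wraps torch.nn.BCELoss, which expects probabilities in [0, 1] rather than raw logits, so it pairs with the sigmoid-terminated MaskDecoder (the "logits" argument name is inherited from the sibling classes but receives post-sigmoid values). A usage sketch with assumed shapes:

import torch

bce = BinaryCrossEntropy()
pred = torch.rand(4, 1, 64, 64)                    # mask decoder output (post-sigmoid)
target = (torch.rand(4, 1, 64, 64) > 0.5).float()  # binary ground-truth mask
loss = bce(pred, target)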
(Diffs for the remaining 12 of the 17 changed files are not shown above.)
