diff --git a/omnigan/blocks.py b/omnigan/blocks.py index da78b2df..592ffa2b 100644 --- a/omnigan/blocks.py +++ b/omnigan/blocks.py @@ -66,6 +66,8 @@ def __init__( self.activation = nn.SELU(inplace=True) elif activation == "tanh": self.activation = nn.Tanh() + elif activation == "sigmoid": + self.activation = nn.Sigmoid() elif activation == "none": self.activation = None else: @@ -74,14 +76,10 @@ def __init__( # initialize convolution if norm == "spectral": self.conv = SpectralNorm( - nn.Conv2d( - input_dim, output_dim, kernel_size, stride, bias=self.use_bias - ) + nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) ) else: - self.conv = nn.Conv2d( - input_dim, output_dim, kernel_size, stride, bias=self.use_bias - ) + self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) def forward(self, x): x = self.conv(self.pad(x)) @@ -123,15 +121,9 @@ def __init__(self, dim, norm="in", activation="relu", pad_type="zero"): self.activation = activation model = [] model += [ - Conv2dBlock( - dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type - ) - ] - model += [ - Conv2dBlock( - dim, dim, 3, 1, 1, norm=norm, activation="none", pad_type=pad_type - ) + Conv2dBlock(dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type) ] + model += [Conv2dBlock(dim, dim, 3, 1, 1, norm=norm, activation="none", pad_type=pad_type)] self.model = nn.Sequential(*model) def forward(self, x): @@ -157,6 +149,7 @@ def __init__( res_norm="instance", activ="relu", pad_type="zero", + output_activ="tanh", ): super().__init__() @@ -166,28 +159,14 @@ def __init__( self.model += [ nn.Upsample(scale_factor=2), Conv2dBlock( - dim, - dim // 2, - 5, - 1, - 2, - norm="layer", - activation=activ, - pad_type=pad_type, + dim, dim // 2, 5, 1, 2, norm="layer", activation=activ, pad_type=pad_type, ), ] dim //= 2 # use reflection padding in the last conv layer self.model += [ Conv2dBlock( - dim, - output_dim, - 7, - 1, - 3, - norm="none", - activation="tanh", - pad_type=pad_type, + dim, output_dim, 7, 1, 3, norm="none", activation=output_activ, pad_type=pad_type, ) ] self.model = nn.Sequential(*self.model) @@ -207,13 +186,7 @@ def __str__(self): # 0ff661e on 13 Apr 2019 class SPADEResnetBlock(nn.Module): def __init__( - self, - fin, - fout, - cond_nc, - spade_use_spectral_norm, - spade_param_free_norm, - spade_kernel_size, + self, fin, fout, cond_nc, spade_use_spectral_norm, spade_param_free_norm, spade_kernel_size, ): super().__init__() # Attributes diff --git a/omnigan/classifier.py b/omnigan/classifier.py index 65a8b0cc..09638439 100644 --- a/omnigan/classifier.py +++ b/omnigan/classifier.py @@ -31,7 +31,7 @@ def __init__(self, latent_space, loss): nn.AvgPool2d((int(self.feature_size / 4), int(self.feature_size / 4))), Squeeze(-1), Squeeze(-1), - nn.Linear(int(self.channels / 4), 4), + nn.Linear(int(self.channels / 4), 2), ] ) @@ -84,9 +84,7 @@ def __init__( self.stride = stride self.downsample = downsample if stride != 1 or inplanes != planes: - self.downsample = nn.Sequential( - conv1x1(inplanes, planes, stride), norm_layer(planes) - ) + self.downsample = nn.Sequential(conv1x1(inplanes, planes, stride), norm_layer(planes)) def forward(self, x): identity = x @@ -163,4 +161,3 @@ def conv1x1(in_planes, out_planes, stride=1): Default: 1 (default: {1}) """ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) - diff --git a/omnigan/data.py b/omnigan/data.py index d91ebf7c..849a8fa0 100644 --- a/omnigan/data.py +++ b/omnigan/data.py @@ -81,6 +81,11 @@ def pil_image_loader(path, task): arr = arr.astype(np.float32) arr[arr != 0] = 1 / arr[arr != 0] + if task == 'm': + arr[arr != 0] = 255 + #Make sure mask is single-channel + if len(arr.shape) >= 3: + arr = arr[:,:,0] # if task == "s": # arr = decode_segmap(arr) @@ -96,9 +101,7 @@ def __init__(self, mode, domain, opts, transform=None): self.mode = mode self.tasks = set(opts.tasks) self.tasks.add("x") - file_list_path = Path(opts.data.files[mode][domain]) - if "/" not in str(file_list_path): file_list_path = Path(opts.data.files.base) / Path( opts.data.files[mode][domain] @@ -116,6 +119,8 @@ def __init__(self, mode, domain, opts, transform=None): self.file_list_path = str(file_list_path) self.transform = transform + + def filter_samples(self): """ Filter out data which is not required for the model's tasks @@ -144,6 +149,7 @@ def __getitem__(self, i): """ paths = self.samples_paths[i] + if self.transform: return { "data": self.transform( @@ -156,7 +162,7 @@ def __getitem__(self, i): return { "data": { - task: pil_image_loader(path, task) for task, path in paths.items() + task: pil_image_loader(path, task) for task, path in paths.items() }, "paths": paths, "domain": self.domain, @@ -183,10 +189,10 @@ def check_samples(self): assert Path(v).exists(), f"{k} {v} does not exist" -def get_loader(domain, mode, opts): +def get_loader(mode, domain, opts): return DataLoader( OmniListDataset( - domain, mode, opts, transform=transforms.Compose(get_transforms(opts)) + mode, domain, opts, transform=transforms.Compose(get_transforms(opts)) ), batch_size=opts.data.loaders.get("batch_size", 4), # shuffle=opts.data.loaders.get("shuffle", True), @@ -199,7 +205,7 @@ def get_all_loaders(opts): loaders = {} for mode in ["train", "val"]: loaders[mode] = {} - for domain in ["rf", "rn", "sf", "sn"]: + for domain in ["r", "s"]: if mode in opts.data.files: if domain in opts.data.files[mode]: loaders[mode][domain] = get_loader(mode, domain, opts) diff --git a/omnigan/generator.py b/omnigan/generator.py index 56dc4f22..a8ef03a9 100644 --- a/omnigan/generator.py +++ b/omnigan/generator.py @@ -64,10 +64,7 @@ def __init__(self, opts, latent_shape=None, verbose=None): # call set_translation_decoder(latent_shape, device) else: self.decoders["t"] = nn.ModuleDict( - { - "f": BaseTranslationDecoder(opts), - "n": BaseTranslationDecoder(opts), - } + {"f": BaseTranslationDecoder(opts), "n": BaseTranslationDecoder(opts),} ) if "a" in opts.tasks and not opts.gen.a.ignore: @@ -84,8 +81,8 @@ def __init__(self, opts, latent_shape=None, verbose=None): if "s" in opts.tasks and not opts.gen.s.ignore: self.decoders["s"] = SegmentationDecoder(opts) - if "w" in opts.tasks and not opts.gen.w.ignore: - self.decoders["w"] = WaterDecoder(opts) + if "m" in opts.tasks and not opts.gen.m.ignore: + self.decoders["m"] = MaskDecoder(opts) self.decoders = nn.ModuleDict(self.decoders) @@ -137,11 +134,7 @@ def translate_batch(self, batch, translator="f", z=None): return y def decode_tasks(self, z): - return { - task: self.decoders[task](z) - for task in self.opts.tasks - if task not in {"t", "a"} - } + return {task: self.decoders[task](z) for task in self.opts.tasks if task not in {"t", "a"}} def encode(self, x): return self.encoder.forward(x) @@ -196,29 +189,16 @@ def __init__(self, opts): pad_type = opts.gen.encoder.pad_type self.model = [ - Conv2dBlock( - input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type - ) + Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type) ] # downsampling blocks for i in range(n_downsample): self.model += [ - Conv2dBlock( - dim, - 2 * dim, - 4, - 2, - 1, - norm=norm, - activation=activ, - pad_type=pad_type, - ) + Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type,) ] dim *= 2 # residual blocks - self.model += [ - ResBlocks(n_res, dim, norm=res_norm, activation=activ, pad_type=pad_type) - ] + self.model += [ResBlocks(n_res, dim, norm=res_norm, activation=activ, pad_type=pad_type)] self.model = nn.Sequential(*self.model) self.output_dim = dim @@ -242,16 +222,17 @@ def __init__(self, opts): ) -class WaterDecoder(BaseDecoder): +class MaskDecoder(BaseDecoder): def __init__(self, opts): super().__init__( - opts.gen.w.n_upsample, - opts.gen.w.n_res, - opts.gen.w.res_dim, - opts.gen.w.output_dim, - res_norm=opts.gen.w.res_norm, - activ=opts.gen.w.activ, - pad_type=opts.gen.w.pad_type, + opts.gen.m.n_upsample, + opts.gen.m.n_res, + opts.gen.m.res_dim, + opts.gen.m.output_dim, + res_norm=opts.gen.m.res_norm, + activ=opts.gen.m.activ, + pad_type=opts.gen.m.pad_type, + output_activ="sigmoid", ) @@ -327,9 +308,7 @@ def __getitem__(self, key): return self._model def forward(self, *args, **kwargs): - raise NotImplementedError( - "Cannot forward the SpadeTranslationDict, chose a domain" - ) + raise NotImplementedError("Cannot forward the SpadeTranslationDict, chose a domain") def __str__(self): return str(self._model).strip() @@ -373,9 +352,7 @@ def update_bit(self, key): def concat_bit_to_seg(self, seg): bit = get_4D_bit(seg.shape, self.bit) - return torch.cat( - [bit.to(torch.float32).to(seg.device), seg.to(torch.float32)], dim=1 - ) + return torch.cat([bit.to(torch.float32).to(seg.device), seg.to(torch.float32)], dim=1) def forward(self, z, seg): if self.use_bit_conditioning: diff --git a/omnigan/losses.py b/omnigan/losses.py index 3185974e..3ebadc47 100644 --- a/omnigan/losses.py +++ b/omnigan/losses.py @@ -80,6 +80,15 @@ def __call__(self, logits, target): return self.loss(logits, target.to(logits.device)) +class BinaryCrossEntropy(nn.Module): + def __init__(self): + super().__init__() + self.loss = torch.nn.BCELoss() + + def __call__(self, logits, target): + return self.loss(logits, target.to(logits.device)) + + class PixelCrossEntropy(CrossEntropy): """ Computes the cross entropy per pixel diff --git a/omnigan/trainer.py b/omnigan/trainer.py index 0b3587e0..740eab84 100644 --- a/omnigan/trainer.py +++ b/omnigan/trainer.py @@ -14,6 +14,7 @@ from omnigan.discriminator import get_dis from omnigan.generator import get_gen from omnigan.losses import ( + BinaryCrossEntropy, CrossEntropy, PixelCrossEntropy, L1Loss, @@ -32,12 +33,16 @@ slice_batch, shuffle_batch_tuple, get_conditioning_tensor, + get_num_params, ) +import torch.nn as nn +import torchvision.utils as vutils class Trainer: """Main trainer class """ + def __init__(self, opts, comet_exp=None, verbose=0): """Trainer class to gather various model training procedures such as training evaluating saving and logging @@ -73,7 +78,7 @@ def __init__(self, opts, comet_exp=None, verbose=0): if isinstance(comet_exp, Experiment): self.exp = comet_exp - def log_losses(self, model_to_update="G"): + def log_losses(self, model_to_update="G", mode="train"): """Logs metrics on comet.ml Args: @@ -85,22 +90,23 @@ def log_losses(self, model_to_update="G"): if self.exp is None: return - assert model_to_update in { - "G", - "D", - "C", - }, "unknown model to log losses {}".format(model_to_update) + assert model_to_update in {"G", "D", "C",}, "unknown model to log losses {}".format( + model_to_update + ) losses = self.logger.losses.copy() + if self.opts.train.log_level == 1: # Only log aggregated losses: delete other keys in losses for k in self.logger.losses: if k not in {"representation", "generator", "translation"}: del losses[k] # convert losses into a single-level dictionnary + losses = flatten_opts(losses) + self.exp.log_metrics( - losses, prefix=model_to_update, step=self.logger.global_step + losses, prefix=f"{model_to_update}_{mode}", step=self.logger.global_step ) def batch_to_device(self, b): @@ -188,6 +194,9 @@ def set_losses(self): if "w" in self.opts.tasks: self.losses["G"]["tasks"]["w"] = lambda x, y: (x + y).mean() + if "m" in self.opts.tasks: + self.losses["G"]["tasks"]["m"] = nn.BCELoss() + # undistinguishable features loss # TODO setup a get_losses func to assign the right loss according to the yaml if self.opts.classifier.loss == "l1": @@ -228,13 +237,14 @@ def setup(self): self.output_size = self.latent_shape[0] * 2 ** self.opts.gen.t.spade_n_up self.G.set_translation_decoder(self.latent_shape, self.device) self.D = get_dis(self.opts, verbose=self.verbose).to(self.device) - self.C = get_classifier(self.opts, self.latent_shape, verbose=self.verbose).to( - self.device - ) + self.C = get_classifier(self.opts, self.latent_shape, verbose=self.verbose).to(self.device) self.P = {"s": get_mega_model()} # P => pseudo labeling models self.g_opt, self.g_scheduler = get_optimizer(self.G, self.opts.gen.opt) - self.d_opt, self.d_scheduler = get_optimizer(self.D, self.opts.dis.opt) + if get_num_params(self.D) > 0: + self.d_opt, self.d_scheduler = get_optimizer(self.D, self.opts.dis.opt) + else: + self.d_opt, self.d_scheduler = None, None self.c_opt, self.c_scheduler = get_optimizer(self.C, self.opts.classifier.opt) self.set_losses() @@ -242,11 +252,23 @@ def setup(self): if self.verbose > 0: for mode, mode_dict in self.loaders.items(): for domain, domain_loader in mode_dict.items(): - print( - "Loader {} {} : {}".format( - mode, domain, len(domain_loader.dataset) - ) - ) + print("Loader {} {} : {}".format(mode, domain, len(domain_loader.dataset))) + + # Create display images: + print("Creating display images...", end="", flush=True) + + if type(self.opts.comet.display_size) == int: + display_indices = range(self.opts.comet.display_size) + else: + display_indices = self.opts.comet.display_size + + self.display_images = {} + for mode, mode_dict in self.loaders.items(): + self.display_images[mode] = {} + for domain, domain_loader in mode_dict.items(): + self.display_images[mode][domain] = [ + Dict(self.loaders[mode][domain].dataset[i]) for i in display_indices + ] self.is_setup = True @@ -254,9 +276,7 @@ def g_opt_step(self): """Run an optimizing step ; if using ExtraAdam, there needs to be an extrapolation step every other step """ - if "extra" in self.opts.gen.opt.optimizer.lower() and ( - self.logger.global_step % 2 == 0 - ): + if "extra" in self.opts.gen.opt.optimizer.lower() and (self.logger.global_step % 2 == 0): self.g_opt.extrapolation() else: self.g_opt.step() @@ -265,9 +285,7 @@ def d_opt_step(self): """Run an optimizing step ; if using ExtraAdam, there needs to be an extrapolation step every other step """ - if "extra" in self.opts.dis.opt.optimizer.lower() and ( - self.logger.global_step % 2 == 0 - ): + if "extra" in self.opts.dis.opt.optimizer.lower() and (self.logger.global_step % 2 == 0): self.d_opt.extrapolation() else: self.d_opt.step() @@ -324,25 +342,88 @@ def run_epoch(self): # (batch_domain_0, ..., batch_domain_i) # and send it to self.device print( - "\rEpoch {} batch {} step {}".format( - self.logger.epoch, i, self.logger.global_step - ) + "\rEpoch {} batch {} step {}".format(self.logger.epoch, i, self.logger.global_step) ) + multi_batch_tuple = shuffle_batch_tuple(multi_batch_tuple) multi_domain_batch = { - batch["domain"][0]: self.batch_to_device(batch) - for batch in multi_batch_tuple + batch["domain"][0]: self.batch_to_device(batch) for batch in multi_batch_tuple } + self.update_g(multi_domain_batch) - self.update_d(multi_domain_batch) + if self.d_opt is not None: + self.update_d(multi_domain_batch) self.update_c(multi_domain_batch) self.logger.global_step += 1 if self.should_freeze_representation(): freeze(self.G.encoder) # ? Freeze decoders != t for memory management purposes ; faster ? self.representation_is_frozen = True + + self.log_comet_images("train", "r") + self.log_comet_images("train", "s") self.update_learning_rates() + def log_comet_images(self, mode, domain): + + save_images = {} + + for im_set in self.display_images[mode][domain]: + x = im_set["data"]["x"].unsqueeze(0).to(self.device) + # print(im_set["data"].items()) + # print("x: ", x.shape) + self.z = self.G.encode(x) + + for update_task, update_target in im_set["data"].items(): + target = im_set["data"][update_task].unsqueeze(0).to(self.device) + task_saves = [] + if update_task != "x": + if update_task not in save_images: + save_images[update_task] = [] + prediction = self.G.decoders[update_task](self.z) + + if update_task in {"m"}: + prediction = prediction.repeat(1, 3, 1, 1) + task_saves.append(x * (1.0 - prediction)) + task_saves.append(x * (1.0 - target.repeat(1, 3, 1, 1))) + task_saves.append(prediction) + #! This assumes the output is some kind of image + save_images[update_task].append(x) + for im in task_saves: + save_images[update_task].append(im) + + for task in save_images.keys(): + # Write images: + self.write_images( + image_outputs=save_images[task], + mode=mode, + domain=domain, + task=task, + im_per_row=4, + comet_exp=self.exp, + ) + + return 0 + + def write_images(self, image_outputs, mode, domain, task, im_per_row=3, comet_exp=None): + """Save output image + Arguments: + image_outputs {Tensor list} -- list of output images + im_per_row {int} -- number of images to be displayed (per row) + file_name {str} -- name of the file where to save the images + """ + curr_iter = self.logger.global_step + image_outputs = torch.stack(image_outputs).squeeze() + image_grid = vutils.make_grid( + image_outputs, nrow=im_per_row, normalize=True, scale_each=True + ) + image_grid = image_grid.permute(1, 2, 0).cpu().detach().numpy() + + if comet_exp is not None: + comet_exp.log_image( + image_grid, name=f"{mode}_{domain}_{task}_{str(curr_iter)}", step=curr_iter + ) + def train(self): """For each epoch: * train @@ -353,7 +434,7 @@ def train(self): for self.logger.epoch in range(self.opts.train.epochs): self.run_epoch() - self.eval() + self.eval(verbose=1) self.save() def should_freeze_representation(self): @@ -401,11 +482,12 @@ def update_g(self, multi_domain_batch, verbose=0): self.g_opt.zero_grad() r_loss = t_loss = None - if self.should_compute_r_loss(): - r_loss = self.get_representation_loss(multi_domain_batch) + # For now, always compute "representation loss" + # if self.should_compute_r_loss(): + r_loss = self.get_representation_loss(multi_domain_batch) - if self.should_compute_t_loss(): - t_loss = self.get_translation_loss(multi_domain_batch) + # if self.should_compute_t_loss(): + # t_loss = self.get_translation_loss(multi_domain_batch) assert any(l is not None for l in [r_loss, t_loss]), "Both losses are None" @@ -415,17 +497,19 @@ def update_g(self, multi_domain_batch, verbose=0): if verbose > 0: print("adding r_loss {} to g_loss".format(r_loss)) self.logger.losses.representation = r_loss.item() + if t_loss is not None: g_loss += t_loss if verbose > 0: print("adding t_loss {} to g_loss".format(t_loss)) self.logger.losses.translation = t_loss.item() + if verbose > 0: print("g_loss is {}".format(g_loss)) self.logger.losses.generator = g_loss.item() g_loss.backward() self.g_opt_step() - self.log_losses(model_to_update="G") + self.log_losses(model_to_update="G", mode="train") def get_representation_loss(self, multi_domain_batch): """Only update the representation part of the model, meaning everything @@ -451,6 +535,7 @@ def get_representation_loss(self, multi_domain_batch): # ? loop) or update the networks for each domain sequentially # ? (.backward() and .step() n times)? for batch_domain, batch in multi_domain_batch.items(): + x = batch["data"]["x"] self.z = self.G.encode(x) # --------------------------------- @@ -458,12 +543,14 @@ def get_representation_loss(self, multi_domain_batch): # --------------------------------- # Forward pass through classifier, output : (batch_size, 4) output_classifier = self.C(self.z) + # Cross entropy loss (with sigmoid) with fake labels to fool C update_loss = self.losses["G"]["classifier"]( - output_classifier, - fake_domains_to_class_tensor(batch["domain"], one_hot), + output_classifier, fake_domains_to_class_tensor(batch["domain"], one_hot), ) + step_loss += lambdas.G.classifier * update_loss + # ------------------------------------------------- # ----- task-specific regression losses (2) ----- # ------------------------------------------------- @@ -475,15 +562,13 @@ def get_representation_loss(self, multi_domain_batch): # ? output features classifier prediction = self.G.decoders[update_task](self.z) task_tensors[update_task] = prediction - update_loss = self.losses["G"]["tasks"][update_task]( - prediction, update_target - ) + update_loss = self.losses["G"]["tasks"][update_task](prediction, update_target) + step_loss += lambdas.G[update_task] * update_loss - self.logger.losses.task_loss[update_task][ - batch_domain - ] = update_loss.item() + self.logger.losses.task_loss[update_task][batch_domain] = update_loss.item() - self.debug("get_representation_loss", locals(), 0) + #! Translation and Adaptation components. Ignore for now... + """ # ------------------------------------------------------ # ----- auto-encoding update for translation (3) ----- # ------------------------------------------------------ @@ -508,6 +593,7 @@ def get_representation_loss(self, multi_domain_batch): step_loss += lambdas.G.a.auto * update_loss self.logger.losses.a.auto[batch_domain] = update_loss.item() self.debug("get_representation_loss", locals(), 2) + """ # --------------------------------------------- # ----- Adaptation translation task (4) ----- @@ -518,7 +604,7 @@ def get_representation_loss(self, multi_domain_batch): # ? * how to use noisy labels Alex Lamb ICT (we don't have ground truth in the # ? real world so is it better to use noisy, noisy + ICT or no label in this # ? case?) - + """ # only do this if adaptation is specified in opts if "a" in self.opts.tasks: adaptation_tasks = [] @@ -567,6 +653,8 @@ def get_representation_loss(self, multi_domain_batch): self.logger.losses.a.cycle[ "{} > {}".format(source_domain, target_domain) ] = update_loss.item() + """ + return step_loss def get_translation_loss(self, multi_domain_batch): @@ -641,13 +729,9 @@ def get_translation_loss(self, multi_domain_batch): fake_s = self.G.decoders["s"](fake_z).detach() real_s_labels = torch.argmax(self.G.decoders["s"](real_z).detach(), 1) mask = ( - torch.randint(0, 2, real_s_labels.shape) - .to(torch.float32) - .to(self.device) + torch.randint(0, 2, real_s_labels.shape).to(torch.float32).to(self.device) ) # TODO : load mask - update_loss = ( - self.losses["G"]["t"]["sm"](fake_s, real_s_labels) * mask - ).mean() + update_loss = (self.losses["G"]["t"]["sm"](fake_s, real_s_labels) * mask).mean() step_loss += lambdas.G.t.sm * update_loss self.logger.losses.t.sm[ "{} > {}".format(source_domain, target_domain) @@ -777,38 +861,48 @@ def get_classifier_loss(self, multi_domain_batch): return lambdas.C * loss def eval(self, num_threads=5, verbose=0): - counter = {} + print("*******************EVALUATING***********************") + for i, multi_batch_tuple in enumerate(self.val_loaders): # create a dictionnay (domain => batch) from tuple # (batch_domain_0, ..., batch_domain_i) # and send it to self.device multi_domain_batch = { - batch["domain"][0]: self.batch_to_device(batch) - for batch in multi_batch_tuple + batch["domain"][0]: self.batch_to_device(batch) for batch in multi_batch_tuple } + # ---------------------------------------------- # ----- Infer separately for each domain ----- # ---------------------------------------------- for domain, domain_batch in multi_domain_batch.items(): + + x = domain_batch["data"]["x"] + self.z = self.G.encode(x) # Don't infer if domains has enough images - remaining = self.opts.val.max_log_images - counter.get(domain, 0) - if remaining <= 0: - continue + if verbose > 0: - print("\rInferring batch {} domain {}".format(i, domain), end="") - - translator = "f" if "n" in domain else "n" - domain_batch = slice_batch(domain_batch, remaining) - translated = self.G.translate(domain_batch, translator) - domain_batch["data"]["y"] = translated - multi_domain_batch[domain] = domain_batch - counter[domain] = counter.get(domain, 0) + translated.shape[0] - - write_path = Path(self.opts.output_path) / "eval_images" - step = self.logger.global_step - save_batch(multi_domain_batch, write_path, step, num_threads) - if verbose > 0: - print() + print(f"Inferring batch {i} domain {domain}") + + # translator = "f" if "n" in domain else "n" + # domain_batch = slice_batch(domain_batch, remaining) + # translated = self.G.translate_batch(domain_batch, translator) + # Get task losses: + task_tensors = {} + for update_task, update_target in domain_batch["data"].items(): + # task t (=translation) will be done in get_translation_loss + # task a (=adaptation) and x (=auto-encoding) will be done hereafter + if update_task not in {"t", "a", "x"}: + # ? output features classifier + prediction = self.G.decoders[update_task](self.z) + task_tensors[update_task] = prediction + update_loss = self.losses["G"]["tasks"][update_task]( + prediction, update_target + ) + self.logger.losses.task_loss[update_task][domain] = update_loss.item() + self.log_losses(model_to_update="G", mode="val") + self.log_comet_images("val", "r") + self.log_comet_images("val", "s") + print("******************DONE EVALUATING*********************") def save(self): pass diff --git a/omnigan/transforms.py b/omnigan/transforms.py index 876c11a6..1005fad2 100644 --- a/omnigan/transforms.py +++ b/omnigan/transforms.py @@ -38,10 +38,7 @@ def __call__(self, data): h, w = data["x"].size[-2:] top = np.random.randint(0, h - self.h) left = np.random.randint(0, w - self.w) - - return { - task: TF.crop(im, top, left, self.h, self.w) for task, im in data.items() - } + return {task: TF.crop(im, top, left, self.h, self.w) for task, im in data.items()} class RandomHorizontalFlip: @@ -65,12 +62,11 @@ def __call__(self, data): for task, im in data.items(): if task in {"x", "a"}: new_data[task] = self.ImagetoTensor(im) - elif task in {"h", "d", "w"}: + elif task in {"h", "d", "w", "m"}: new_data[task] = self.MaptoTensor(im) elif task == "s": - new_data[task] = torch.squeeze(torch.from_numpy(np.array(im))).to( - torch.int64 - ) + new_data[task] = torch.squeeze(torch.from_numpy(np.array(im))).to(torch.int64) + return new_data @@ -79,17 +75,18 @@ def __init__(self): self.normImage = trsfs.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # self.normSeg = trsfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) self.normDepth = trsfs.Normalize([1 / 255], [1 / 3]) + self.normMask = lambda x: x self.normalize = { "x": self.normImage, # "s": self.normSeg, "d": self.normDepth, + "m": self.normMask, } def __call__(self, data): return { - task: self.normalize.get(task, lambda x: x)(tensor) - for task, tensor in data.items() + task: self.normalize.get(task, lambda x: x)(tensor) for task, tensor in data.items() } diff --git a/omnigan/tutils.py b/omnigan/tutils.py index 678ea6b6..d8b885ff 100644 --- a/omnigan/tutils.py +++ b/omnigan/tutils.py @@ -1,6 +1,7 @@ """Tensor-utils """ from pathlib import Path + # from copy import copy from threading import Thread @@ -92,17 +93,15 @@ def domains_to_class_tensor(domains, one_hot=False): domain labels in a 2D tensor """ - mapping = {"rf": 0, "rn": 1, "sf": 2, "sn": 3} + mapping = {"r": 0, "s": 1} if not all(domain in mapping for domain in domains): - raise ValueError( - "Unknown domains {} should be in {}".format(domains, list(mapping.keys())) - ) + raise ValueError("Unknown domains {} should be in {}".format(domains, list(mapping.keys()))) target = torch.tensor([mapping[domain] for domain in domains]) if one_hot: - one_hot_target = torch.FloatTensor(len(target), 4) # 4 domains + one_hot_target = torch.FloatTensor(len(target), 2) # 2 domains one_hot_target.zero_() one_hot_target.scatter_(1, target.unsqueeze(1), 1) # https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507 @@ -114,12 +113,12 @@ def fake_domains_to_class_tensor(domains, one_hot=False): """Converts a list of strings to a 1D Tensor representing the fake domains (real or sim only) - fake_domains_to_class_tensor(["sf", "rn"], False) - >>> torch.Tensor([0, 3]) + fake_domains_to_class_tensor(["s", "r"], False) + >>> torch.Tensor([0, 2]) Args: - domain (list(str)): each element of the list should be in {rf, rn, sf, sn} + domain (list(str)): each element of the list should be in {r, s} one_hot (bool, optional): whether or not to 1-h encode class labels. Defaults to False. Raises: @@ -131,17 +130,15 @@ def fake_domains_to_class_tensor(domains, one_hot=False): for each domain). """ if one_hot: - target = torch.FloatTensor(len(domains), 4) + target = torch.FloatTensor(len(domains), 2) target.fill_(0.25) else: - mapping = {"rf": 2, "rn": 3, "sf": 0, "sn": 1} + mapping = {"r": 1, "s": 0} if not all(domain in mapping for domain in domains): raise ValueError( - "Unknown domains {} should be in {}".format( - domains, list(mapping.keys()) - ) + "Unknown domains {} should be in {}".format(domains, list(mapping.keys())) ) target = torch.tensor([mapping[domain] for domain in domains]) @@ -178,11 +175,7 @@ def get_4D_bit(shape, probs): torch.Tensor: batch x # of domains x h x w """ probs = probs if isinstance(probs, torch.Tensor) else torch.tensor(probs) - bit = ( - torch.ones(shape[0], probs.shape[-1], *shape[-2:]) - .to(torch.float32) - .to(probs.device) - ) + bit = torch.ones(shape[0], probs.shape[-1], *shape[-2:]).to(torch.float32).to(probs.device) bit *= probs[None, :, None, None].to(torch.float32) return bit @@ -304,7 +297,7 @@ def save_tanh_tensor(image, path): """ path = Path(path) if isinstance(image, torch.Tensor): - image = image.detach().numpy() + image = image.detach().cpu().numpy() if image.shape[-1] != 3 and image.shape[0] == 3: image = np.transpose(image, (1, 2, 0)) if image.min() < 0 and image.min() > -1: @@ -329,9 +322,7 @@ def save_batch(multi_domain_batch, root="./", step=0, num_threads=5): imtensor = torch.cat([x, y], dim=-1) for i, im in enumerate(imtensor): imid = Path(paths[i]).stem[:10] - images_to_save["paths"] += [ - root / "im_{}_{}_{}.png".format(step, domain, imid) - ] + images_to_save["paths"] += [root / "im_{}_{}_{}.png".format(step, domain, imid)] images_to_save["images"].append(im) if num_threads > 0: threaded_write(images_to_save["images"], images_to_save["paths"], num_threads) @@ -347,17 +338,17 @@ def threaded_write(images, paths, num_threads=5): t_im.append(im) t_p.append(p) if len(t_im) == num_threads: - ts = [ - Thread(target=save_tanh_tensor, args=(_i, _p)) - for _i, _p in zip(t_im, t_p) - ] + ts = [Thread(target=save_tanh_tensor, args=(_i, _p)) for _i, _p in zip(t_im, t_p)] list(map(lambda t: t.start(), ts)) list(map(lambda t: t.join(), ts)) t_im = [] t_p = [] if t_im: - ts = [ - Thread(target=save_tanh_tensor, args=(_i, _p)) for _i, _p in zip(t_im, t_p) - ] + ts = [Thread(target=save_tanh_tensor, args=(_i, _p)) for _i, _p in zip(t_im, t_p)] list(map(lambda t: t.start(), ts)) list(map(lambda t: t.join(), ts)) + + +def get_num_params(model): + total_params = sum(p.numel() for p in model.parameters()) + return total_params diff --git a/shared/trainer/defaults.yaml b/shared/trainer/defaults.yaml index 9f467154..686fc37a 100644 --- a/shared/trainer/defaults.yaml +++ b/shared/trainer/defaults.yaml @@ -2,7 +2,7 @@ output_path: /network/tmp1/schmidtv/yb_runs/test_v1 # ------------------- # ----- Tasks ----- # ------------------- -tasks: [a, d, h, s, t, w] +tasks: [a, d, h, s, t, m] # ---------------- # ----- Data ----- @@ -11,15 +11,12 @@ data: files: # if one is not none it will override the dirs location base: /path/to/data train: - rf: train_rf.json - rn: train_rn.json - sf: train_sf.json - sn: train_sn.json + r: train_r.json + s: train_s.json val: - rf: val_rf.json - rn: val_rn.json - sf: val_sf.json - sn: val_sn.json + r: val_r.json + s: val_s.json + loaders: batch_size: 2 shuffle: false @@ -93,6 +90,9 @@ gen: w: # specific params for the water-segmentation decoder <<: *default-gen output_dim: 1 + m: # specific params for the mask-generation decoder + <<: *default-gen + output_dim: 1 # ------------------------- # ----- Discriminator ----- # ------------------------- @@ -121,6 +121,8 @@ dis: <<: *default-dis w: <<: *default-dis + m: + <<: *default-dis # ------------------------------- # ----- Domain Classifier ----- # ------------------------------- @@ -152,6 +154,7 @@ train: h: 1 s: 1 w: 1 + m: 1 t: auto: 1 # auto-encoding, reconstruction cycle: 1 # cycle consistency @@ -175,3 +178,10 @@ val: store_images: false # write to disk on top of comet logging infer_rec: true infer_idt: true # order: real, translated, rec, idt + + +# ----------------------------- +# ----- Comet Params ---------- +# ----------------------------- +comet: + display_size: 5 diff --git a/tests/test_classifier.py b/tests/test_classifier.py index 4587e495..365e8f08 100644 --- a/tests/test_classifier.py +++ b/tests/test_classifier.py @@ -25,35 +25,35 @@ # ----- Test Setup ----- # ------------------------ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - target_domains = ["rf", "rn", "sf", "sn", "rf"] + target_domains = ["r", "s"] labels = domains_to_class_tensor(target_domains, one_hot=False).to(device) one_hot_labels = domains_to_class_tensor(target_domains, one_hot=True).to(device) - cross_entropy = CrossEntropy() loss_l1 = L1Loss() # ------------------------------ # ----- Test C.forward() ----- # ------------------------------ - z = torch.ones(5, 128, 32, 32).to(device) + z = torch.ones(len(target_domains), 128, 32, 32).to(device) latent_space = (128, 32, 32) C = get_classifier(opts, latent_space, 0).to(device) y = C(z) tprint( - "output of classifier's shape for latent space {} :".format(list(z.shape[1:])), - y.shape, + "output of classifier's shape for latent space {} :".format(list(z.shape[1:])), y.shape, ) # -------------------------------- # ----- Test cross_entropy ----- # -------------------------------- + tprint("CE loss:", cross_entropy(y, labels)) # -------------------------- # ----- Test l1_loss ----- # -------------------------- tprint("l1 loss:", loss_l1(y, one_hot_labels)) + print() - z = torch.ones(5, 256, 64, 64).to(device) + z = torch.ones(len(target_domains), 256, 64, 64).to(device) # ------------------------------------------ # ----- Test different latent shapes ----- # ------------------------------------------ @@ -61,20 +61,18 @@ C = get_classifier(opts, latent_space, 0).to(device) y = C(z) tprint( - "output of classifier's shape for latent space {} :".format(list(z.shape[1:])), - y.shape, + "output of classifier's shape for latent space {} :".format(list(z.shape[1:])), y.shape, ) tprint("CE loss:", cross_entropy(y, labels)) tprint("l1 loss:", loss_l1(y, one_hot_labels)) print() - z = torch.ones(5, 64, 16, 16).to(device) + z = torch.ones(len(target_domains), 64, 16, 16).to(device) latent_space = (64, 16, 16) C = get_classifier(opts, latent_space, 0).to(device) y = C(z) tprint( - "output of classifier's shape for latent space {} :".format(list(z.shape[1:])), - y.shape, + "output of classifier's shape for latent space {} :".format(list(z.shape[1:])), y.shape, ) tprint("CE loss:", cross_entropy(y, labels)) tprint("l1 loss:", loss_l1(y, one_hot_labels)) diff --git a/tests/test_data.py b/tests/test_data.py index 7381b499..8fd01bdb 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -5,7 +5,7 @@ from addict import Dict sys.path.append(str(Path(__file__).parent.parent.resolve())) -from omnigan.data import OmniListDataset, get_all_loaders +from omnigan.data import OmniListDataset, get_all_loaders, get_loader from omnigan.utils import load_test_opts from omnigan.tutils import transforms_string @@ -26,7 +26,7 @@ opts.data.loaders.shuffle = True loaders = get_all_loaders(opts) - ds = OmniListDataset("train", "rn", opts) + ds = OmniListDataset("train", "r", opts) # -------------------------------- # ----- Test task matching ----- @@ -36,16 +36,14 @@ for sample_path in ds.samples_paths: ds_vars = set(sample_path.keys()) assert ds_vars.issubset(tasks) - # ------------------------------------ # ----- Test transforms_string ----- # ------------------------------------ - print(transforms_string(loaders["train"]["rn"].dataset.transform)) + print(transforms_string(loaders["train"]["r"].dataset.transform)) sample = ds[0] - batch = Dict(next(iter(loaders["train"]["rn"]))) - + batch = Dict(next(iter(loaders["train"]["r"]))) print("Batch: ", "items, ", " ".join(batch.keys()), "keys") # ------------------------------- @@ -95,3 +93,6 @@ ) ) multi_domain_batch = {batch["domain"][0]: batch for batch in multi_batch} + + if i > 5: + break diff --git a/tests/test_gen.py b/tests/test_gen.py index d14bcd8f..eec0a090 100644 --- a/tests/test_gen.py +++ b/tests/test_gen.py @@ -13,7 +13,7 @@ parser = argparse.ArgumentParser() -parser.add_argument("-c", "--config", default="config/trainer/local_tests.yaml") +parser.add_argument("-c", "--config", default="config/trainer/local_tests.yaml") args = parser.parse_args() root = Path(__file__).parent.parent opts = load_test_opts(args.config) @@ -92,6 +92,8 @@ else: print(dec, G.decoders[dec](z).shape) + #! Holding off on translation... + """ # -------------------------------------------------------------------- # ----- Test translation depending on use_bit_conditioning and ----- # ----- use_spade ----- @@ -118,3 +120,5 @@ G = get_gen(opts).to(device) G.set_translation_decoder(latent_space_dims, device) print(G.forward(image, translator="f").shape) + """ + diff --git a/tests/test_losses.py b/tests/test_losses.py index af7c5f66..f8d5cbc8 100644 --- a/tests/test_losses.py +++ b/tests/test_losses.py @@ -13,7 +13,7 @@ parser = argparse.ArgumentParser() -parser.add_argument("-c", "--config", default="config/trainer/local_tests.yaml") +parser.add_argument("-c", "--config", default="config/trainer/local_tests.yaml") args = parser.parse_args() root = Path(__file__).parent.parent opts = load_test_opts(args.config) @@ -28,7 +28,7 @@ opts.data.loaders.shuffle = True device = torch.device("cuda" if torch.cuda.is_available() else "cpu") loaders = get_all_loaders(opts) - batch = next(iter(loaders["train"]["rn"])) + batch = next(iter(loaders["train"]["r"])) image = torch.randn(opts.data.loaders.batch_size, 3, 32, 32).to(device) G = get_gen(opts).to(device) z = G.encode(image) @@ -39,7 +39,7 @@ print_header("test_crossentroy_2d") prediction = G.decoders["s"](z) pce = PixelCrossEntropy() - print(pce(prediction, batch["data"]["s"].to(device))) + print(pce(prediction.squeeze(), batch["data"]["s"].long().squeeze().to(device))) # ! error how to infer from cropped data: input: 224 output: 256?? # TODO more test for the losses diff --git a/tests/test_mega_depth.py b/tests/test_mega_depth.py index 2f1020b5..e75e682a 100644 --- a/tests/test_mega_depth.py +++ b/tests/test_mega_depth.py @@ -14,7 +14,7 @@ from run import print_header parser = argparse.ArgumentParser() -parser.add_argument("-c", "--config", default="config/trainer/local_tests.yaml") +parser.add_argument("-c", "--config", default="config/trainer/local_tests.yaml") args = parser.parse_args() root = Path(__file__).parent.parent opts = load_test_opts(args.config) @@ -30,7 +30,7 @@ not_committed_path.mkdir() mega = get_mega_model().to(device) loaders = get_all_loaders(opts) - loader = loaders["train"]["rn"] + loader = loaders["train"]["r"] batch = next(iter(loader)) # ------------------------- # ----- Test Config ----- @@ -50,13 +50,15 @@ print("Done. Saving...") for i, im in enumerate(im_d): im_n = decode_mega_depth(im, numpy=True) - stem = Path(batch["paths"]["s"][i]).stem + stem = Path(batch["paths"]["x"][i]).stem if write_images: io.imsave( str(not_committed_path / (stem + "_depth.png")), im_n, ) print("Done.") + #! No translation, so holding off... + """ # --------------------------------------- # ----- Test MD after translation ----- # --------------------------------------- @@ -73,9 +75,12 @@ for i, im_d in enumerate(y_d): print(i, "/", len(y_d)) im_n = decode_mega_depth(im_d, numpy=True) - stem = Path(batch["paths"]["s"][i]).stem + + stem = Path(batch["paths"]["x"][i]).stem io.imsave( str(not_committed_path / (stem + "_translated_depth.png")), im_n, ) else: im_d = decode_mega_depth(y_d) + """ + diff --git a/tests/test_trainer.py b/tests/test_trainer.py index a47f1ddf..b480659c 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -9,7 +9,7 @@ from run import print_header parser = argparse.ArgumentParser() -parser.add_argument("-c", "--config", default="config/trainer/local_tests.yaml") +parser.add_argument("-c", "--config", default="config/trainer/local_tests.yaml") args = parser.parse_args() root = Path(__file__).parent.parent opts = load_test_opts(args.config) @@ -23,8 +23,7 @@ trainer.setup() multi_batch_tuple = next(iter(trainer.train_loaders)) multi_domain_batch = { - batch["domain"][0]: trainer.batch_to_device(batch) - for batch in multi_batch_tuple + batch["domain"][0]: trainer.batch_to_device(batch) for batch in multi_batch_tuple } # ------------------------- # ----- Test Config ----- @@ -132,9 +131,7 @@ trainer.opts.train.representational_training = False trainer.opts.train.representation_steps = 100 trainer.logger.global_step = 200 - print( - False, 100, 200, "Not Using repr_tr and step < repr_step and step % 2 == 0" - ) + print(False, 100, 200, "Not Using repr_tr and step < repr_step and step % 2 == 0") trainer.update_g(multi_domain_batch, 1) print() @@ -142,9 +139,7 @@ trainer.opts.train.representational_training = False trainer.opts.train.representation_steps = 100 trainer.logger.global_step = 201 - print( - False, 100, 201, "Not Using repr_tr and step > repr_step and step % 2 == 1" - ) + print(False, 100, 201, "Not Using repr_tr and step > repr_step and step % 2 == 1") trainer.update_g(multi_domain_batch, 1) print() @@ -181,15 +176,16 @@ print("Setting up") trainer.setup() - encoder_weights = [ - [p.detach().cpu().numpy()[:5] for p in trainer.G.encoder.parameters()] - ] + encoder_weights = [[p.detach().cpu().numpy()[:5] for p in trainer.G.encoder.parameters()]] print("First update: extrapolation") print(" - Update g") trainer.update_g(multi_domain_batch) - print(" - Update d") - trainer.update_d(multi_domain_batch) + + #! Ignoring discrim update since we aren't + #! yet doing translation + # print(" - Update d") + # trainer.update_d(multi_domain_batch) print(" - Update c") trainer.update_c(multi_domain_batch) @@ -198,24 +194,22 @@ print("Second update: gradient step") print(" - Update g") trainer.update_g(multi_domain_batch) - print(" - Update d") - trainer.update_d(multi_domain_batch) + # print(" - Update d") + # trainer.update_d(multi_domain_batch) print(" - Update c") trainer.update_c(multi_domain_batch) print("Freezing encoder") freeze(trainer.G.encoder) trainer.representation_is_frozen = True - encoder_weights += [ - [p.cpu().numpy()[:5] for p in trainer.G.encoder.parameters()] - ] + encoder_weights += [[p.cpu().numpy()[:5] for p in trainer.G.encoder.parameters()]] trainer.logger.global_step += 1 print("Third update: extrapolation") print(" - Update g") trainer.update_g(multi_domain_batch) - print(" - Update d") - trainer.update_d(multi_domain_batch) + # print(" - Update d") + # trainer.update_d(multi_domain_batch) print(" - Update c") trainer.update_c(multi_domain_batch) @@ -224,14 +218,12 @@ print("Fourth update: gradient step") print(" - Update g") trainer.update_g(multi_domain_batch) - print(" - Update d") - trainer.update_d(multi_domain_batch) + # print(" - Update d") + # trainer.update_d(multi_domain_batch) print(" - Update c") trainer.update_c(multi_domain_batch) - encoder_weights += [ - [p.cpu().numpy()[:5] for p in trainer.G.encoder.parameters()] - ] + encoder_weights += [[p.cpu().numpy()[:5] for p in trainer.G.encoder.parameters()]] # # ? triggers segmentation fault for some unknown reason # # encoder was updated diff --git a/tests/test_utils.py b/tests/test_utils.py index 6ecb176a..50f62ebc 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -38,10 +38,10 @@ # --------------------------------------------- # ----- Testing domains_to_class_tensor ----- # --------------------------------------------- - batch = next(iter(loaders["train"]["rn"])) + batch = next(iter(loaders["train"]["r"])) print(domains_to_class_tensor(batch["domain"], True)) print(domains_to_class_tensor(batch["domain"], False)) - domains = ["rn", "rf", "rf", "sn"] + domains = ["r", "s"] try: domains_to_class_tensor([1, "sg"]) diff --git a/train.py b/train.py index aadef06a..ef8d3632 100644 --- a/train.py +++ b/train.py @@ -21,7 +21,7 @@ def parsed_args(): "--config", default="./config/local_tests.yaml", type=str, - help="What configuration file to use to overwrite shared/defaults.yml", + help="What configuration file to use to overwrite shared/defaults.yaml", ) parser.add_argument( "--exp_desc", default="", type=str, help="Description of the experiment", @@ -75,7 +75,7 @@ def pprint(*args): # ----- Load opts ----- # ----------------------- - opts = load_opts(Path(args.config), default="./shared/trainer/defaults.yml") + opts = load_opts(Path(args.config), default="./shared/trainer/defaults.yaml") opts.output_path = env_to_path(opts.output_path) opts.output_path = get_increased_path(opts.output_path) pprint("Running model in", opts.output_path) @@ -106,9 +106,7 @@ def pprint(*args): if args.dev_mode: pprint("> /!\ Development mode ON") print("Cropping data to 32") - opts.data.transforms += [ - Dict({"name": "crop", "ignore": False, "height": 32, "width": 32}) - ] + opts.data.transforms += [Dict({"name": "crop", "ignore": False, "height": 32, "width": 32})] # ------------------- # ----- Train -----