diff --git a/docs/tutorials/train.md b/docs/tutorials/train.md index 804461e..8aea482 100644 --- a/docs/tutorials/train.md +++ b/docs/tutorials/train.md @@ -25,7 +25,10 @@ from quac.training.data import TrainingDataset dataset = TrainingDataset( source="path/to/training/data", reference="path/to/training/data", - validation="path/to/validation/data" + img_size=128, + batch_size=4, + num_workers=4 + ) ``` diff --git a/src/quac/training/config.py b/src/quac/training/config.py index 6884e43..f29543d 100644 --- a/src/quac/training/config.py +++ b/src/quac/training/config.py @@ -18,6 +18,8 @@ class DataConfig(BaseModel): batch_size: int = 1 num_workers: int = 4 grayscale: bool = False + mean: Optional[float] = 0.5 + std: Optional[float] = 0.5 class RunConfig(BaseModel): @@ -39,7 +41,7 @@ class ValConfig(BaseModel): class LossConfig(BaseModel): - lambda_ds: float = 1.0 + lambda_ds: float = 0.0 # No diversity by default lambda_sty: float = 1.0 lambda_cyc: float = 1.0 lambda_reg: float = 1.0 @@ -48,11 +50,11 @@ class LossConfig(BaseModel): class SolverConfig(BaseModel): root_dir: str - f_lr: float = 1e-4 + f_lr: float = 1e-6 lr: float = 1e-4 - beta1: float = 0.5 + beta1: float = 0.0 beta2: float = 0.99 - weight_decay: float = 0.1 + weight_decay: float = 1e-4 class ExperimentConfig(BaseModel): diff --git a/src/quac/training/data_loader.py b/src/quac/training/data_loader.py index b65e5ff..051ab15 100644 --- a/src/quac/training/data_loader.py +++ b/src/quac/training/data_loader.py @@ -156,16 +156,15 @@ def get_train_loader( if grayscale: transform_list.append(transforms.Grayscale()) - transform = transforms.Compose( - [ - *transform_list, - transforms.Resize([img_size, img_size]), - transforms.RandomHorizontalFlip(), - transforms.RandomVerticalFlip(), - transforms.ToTensor(), - transforms.Normalize(mean=mean, std=std), - ] - ) + transform_list += [ + transforms.Resize([img_size, img_size]), + transforms.RandomHorizontalFlip(), + transforms.RandomVerticalFlip(), + 
transforms.ToTensor(), + ] + if mean is not None and std is not None: + transform_list.append(transforms.Normalize(mean=mean, std=std)) + transform = transforms.Compose(transform_list) if which == "source": # dataset = ImageFolder(root, transform) @@ -249,14 +248,14 @@ def get_test_loader( transform_list = [] if grayscale: transform_list.append(transforms.Grayscale()) - transform = transforms.Compose( - [ - *transform_list, - transforms.Resize([img_size, img_size]), - transforms.ToTensor(), - transforms.Normalize(mean=mean, std=std), - ] - ) + + transform_list += [ + transforms.Resize([img_size, img_size]), + transforms.ToTensor(), + ] + if mean is not None and std is not None: + transform_list.append(transforms.Normalize(mean=mean, std=std)) + transform = transforms.Compose(transform_list) dataset = ImageFolder(root, transform) return data.DataLoader( @@ -366,6 +365,8 @@ def __init__( batch_size=8, num_workers=4, grayscale=False, + mean=None, + std=None, ): self.src = get_train_loader( root=source, @@ -374,6 +375,8 @@ def __init__( batch_size=batch_size, num_workers=num_workers, grayscale=grayscale, + mean=mean, + std=std, ) self.reference = get_train_loader( root=reference, @@ -382,6 +385,8 @@ def __init__( batch_size=batch_size, num_workers=num_workers, grayscale=grayscale, + mean=mean, + std=std, ) @@ -513,6 +518,7 @@ def loader_src(self): grayscale=self.grayscale, mean=self.mean, std=self.std, + drop_last=False, ) @property @@ -525,4 +531,5 @@ def loader_ref(self): grayscale=self.grayscale, mean=self.mean, std=self.std, + drop_last=True, ) diff --git a/src/quac/training/solver.py b/src/quac/training/solver.py index 81de9a0..1e89016 100644 --- a/src/quac/training/solver.py +++ b/src/quac/training/solver.py @@ -143,6 +143,7 @@ def train( val_loader=None, val_config=None, ): + start = datetime.datetime.now() nets = self.nets nets_ema = self.nets_ema optims = self.optims @@ -217,7 +218,9 @@ def train( optims.generator.step() # compute moving average of network 
parameters - nets_ema.update(nets) + moving_average(nets.generator, nets_ema.generator, beta=0.999) + moving_average(nets.style_encoder, nets_ema.style_encoder, beta=0.999) + moving_average(nets.mapping_network, nets_ema.mapping_network, beta=0.999) # decay weight for diversity sensitive loss if lambda_ds > 0: @@ -225,10 +228,10 @@ def train( if (i + 1) % eval_every == 0 and val_loader is not None: self.evaluate( - val_loader, iteration=i + 1, mode="latent", val_config=val_config + val_loader, iteration=i + 1, mode="reference", val_config=val_config ) self.evaluate( - val_loader, iteration=i + 1, mode="reference", val_config=val_config + val_loader, iteration=i + 1, mode="latent", val_config=val_config ) # save model checkpoints @@ -237,6 +240,7 @@ def train( # print out log losses, images if (i + 1) % log_every == 0: + elapsed = datetime.datetime.now() - start self.log( d_losses_latent, d_losses_ref, @@ -251,6 +255,7 @@ def train( y_trg, # Target classes step=i + 1, total_iters=total_iters, + elapsed_time=elapsed, ) def log( @@ -268,6 +273,7 @@ def log( y_target, step, total_iters, + elapsed_time, ): all_losses = dict() for loss, prefix in zip( @@ -289,28 +295,26 @@ def log( caption = " ".join([str(x) for x in label.cpu().tolist()]) self.run.log({name: [wandb.Image(img, caption=caption)]}, step=step) - else: - now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M") - print( - f"[{now}]: {step}/{total_iters}", - flush=True, - ) - g_losses = "\t".join( - [ - f"{key}: {value:.4f}" - for key, value in all_losses.items() - if not key.startswith("D/") - ] - ) - d_losses = "\t".join( - [ - f"{key}: {value:.4f}" - for key, value in all_losses.items() - if key.startswith("D/") - ] - ) - print(f"G Losses: {g_losses}", flush=True) - print(f"D Losses: {d_losses}", flush=True) + print( + f"[{elapsed_time}]: {step}/{total_iters}", + flush=True, + ) + g_losses = "\t".join( + [ + f"{key}: {value:.4f}" + for key, value in all_losses.items() + if not key.startswith("D/") + ] + ) + 
d_losses = "\t".join( + [ + f"{key}: {value:.4f}" + for key, value in all_losses.items() + if key.startswith("D/") + ] + ) + print(f"G Losses: {g_losses}", flush=True) + print(f"D Losses: {d_losses}", flush=True) @torch.no_grad() def evaluate( @@ -354,6 +358,8 @@ def evaluate( for trg_idx, trg_domain in enumerate(domains): src_domains = [x for x in val_loader.available_sources if x != trg_domain] val_loader.set_target(trg_domain) + if mode == "reference": + loader_ref = val_loader.loader_ref for src_idx, src_domain in enumerate(src_domains): task = "%s/%s" % (src_domain, trg_domain) @@ -378,25 +384,37 @@ def evaluate( z_trg = torch.randn(N, self.latent_dim).to(device) s_trg = self.nets_ema.mapping_network(z_trg, y_trg) else: + # x_ref = x_trg.clone() try: + # TODO don't need to re-do this every time, just use + # the same set of reference images for the whole dataset! x_ref = next(iter_ref).to(device) except: - iter_ref = iter(val_loader.loader_ref) + iter_ref = iter(loader_ref) x_ref = next(iter_ref).to(device) if x_ref.size(0) > N: x_ref = x_ref[:N] + elif x_ref.size(0) < N: + raise ValueError( + "Not enough reference images. " + "Make sure that the batch size of the validation loader is bigger than `num_outs_per_domain`." + ) s_trg = self.nets_ema.style_encoder(x_ref, y_trg) x_fake = self.nets_ema.generator(x_src, s_trg) # Run the classification - predictions.append( - classifier( - x_fake, assume_normalized=val_config.assume_normalized - ) - .cpu() - .numpy() + pred = classifier( + x_fake, assume_normalized=val_config.assume_normalized ) + predictions.append(pred.cpu().numpy()) + # predictions.append( + # classifier( + # x_fake, assume_normalized=val_config.assume_normalized + # ) + # .cpu() + # .numpy() + # ) predictions = np.stack(predictions, axis=0) assert len(predictions) > 0 # Do it in a vectorized way, by reshaping the predictions