Skip to content

Commit

Permalink
fix: 🚑 Avoid collapse with new StarGAN
Browse files Browse the repository at this point in the history
  • Loading branch information
adjavon committed Apr 30, 2024
1 parent 22c4f35 commit 75928ea
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 55 deletions.
5 changes: 4 additions & 1 deletion docs/tutorials/train.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ from quac.training.data import TrainingDataset
dataset = TrainingDataset(
source="path/to/training/data",
reference="path/to/training/data",
validation="path/to/validation/data"
img_size=128,
batch_size=4,
num_workers=4

)
```

Expand Down
10 changes: 6 additions & 4 deletions src/quac/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class DataConfig(BaseModel):
batch_size: int = 1
num_workers: int = 4
grayscale: bool = False
mean: Optional[float] = 0.5
std: Optional[float] = 0.5


class RunConfig(BaseModel):
Expand All @@ -39,7 +41,7 @@ class ValConfig(BaseModel):


class LossConfig(BaseModel):
lambda_ds: float = 1.0
lambda_ds: float = 0.0 # No diversity by default
lambda_sty: float = 1.0
lambda_cyc: float = 1.0
lambda_reg: float = 1.0
Expand All @@ -48,11 +50,11 @@ class LossConfig(BaseModel):

class SolverConfig(BaseModel):
root_dir: str
f_lr: float = 1e-4
f_lr: float = 1e-6
lr: float = 1e-4
beta1: float = 0.5
beta1: float = 0.0
beta2: float = 0.99
weight_decay: float = 0.1
weight_decay: float = 1e-4


class ExperimentConfig(BaseModel):
Expand Down
43 changes: 25 additions & 18 deletions src/quac/training/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,16 +156,15 @@ def get_train_loader(
if grayscale:
transform_list.append(transforms.Grayscale())

transform = transforms.Compose(
[
*transform_list,
transforms.Resize([img_size, img_size]),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
]
)
transform_list += [
transforms.Resize([img_size, img_size]),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.ToTensor(),
]
if mean is not None and std is not None:
transform_list.append(transforms.Normalize(mean=mean, std=std))
transform = transforms.Compose(transform_list)

if which == "source":
# dataset = ImageFolder(root, transform)
Expand Down Expand Up @@ -249,14 +248,14 @@ def get_test_loader(
transform_list = []
if grayscale:
transform_list.append(transforms.Grayscale())
transform = transforms.Compose(
[
*transform_list,
transforms.Resize([img_size, img_size]),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
]
)

transform_list += [
transforms.Resize([img_size, img_size]),
transforms.ToTensor(),
]
if mean is not None and std is not None:
transform_list.append(transforms.Normalize(mean=mean, std=std))
transform = transforms.Compose(transform_list)

dataset = ImageFolder(root, transform)
return data.DataLoader(
Expand Down Expand Up @@ -366,6 +365,8 @@ def __init__(
batch_size=8,
num_workers=4,
grayscale=False,
mean=None,
std=None,
):
self.src = get_train_loader(
root=source,
Expand All @@ -374,6 +375,8 @@ def __init__(
batch_size=batch_size,
num_workers=num_workers,
grayscale=grayscale,
mean=mean,
std=std,
)
self.reference = get_train_loader(
root=reference,
Expand All @@ -382,6 +385,8 @@ def __init__(
batch_size=batch_size,
num_workers=num_workers,
grayscale=grayscale,
mean=mean,
std=std,
)


Expand Down Expand Up @@ -513,6 +518,7 @@ def loader_src(self):
grayscale=self.grayscale,
mean=self.mean,
std=self.std,
drop_last=False,
)

@property
Expand All @@ -525,4 +531,5 @@ def loader_ref(self):
grayscale=self.grayscale,
mean=self.mean,
std=self.std,
drop_last=True,
)
82 changes: 50 additions & 32 deletions src/quac/training/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def train(
val_loader=None,
val_config=None,
):
start = datetime.datetime.now()
nets = self.nets
nets_ema = self.nets_ema
optims = self.optims
Expand Down Expand Up @@ -217,18 +218,20 @@ def train(
optims.generator.step()

# compute moving average of network parameters
nets_ema.update(nets)
moving_average(nets.generator, nets_ema.generator, beta=0.999)
moving_average(nets.style_encoder, nets_ema.style_encoder, beta=0.999)
moving_average(nets.mapping_network, nets_ema.mapping_network, beta=0.999)

# decay weight for diversity sensitive loss
if lambda_ds > 0:
lambda_ds -= initial_lambda_ds / ds_iter

if (i + 1) % eval_every == 0 and val_loader is not None:
self.evaluate(
val_loader, iteration=i + 1, mode="latent", val_config=val_config
val_loader, iteration=i + 1, mode="reference", val_config=val_config
)
self.evaluate(
val_loader, iteration=i + 1, mode="reference", val_config=val_config
val_loader, iteration=i + 1, mode="latent", val_config=val_config
)

# save model checkpoints
Expand All @@ -237,6 +240,7 @@ def train(

# print out log losses, images
if (i + 1) % log_every == 0:
elapsed = datetime.datetime.now() - start
self.log(
d_losses_latent,
d_losses_ref,
Expand All @@ -251,6 +255,7 @@ def train(
y_trg, # Target classes
step=i + 1,
total_iters=total_iters,
elapsed_time=elapsed,
)

def log(
Expand All @@ -268,6 +273,7 @@ def log(
y_target,
step,
total_iters,
elapsed_time,
):
all_losses = dict()
for loss, prefix in zip(
Expand All @@ -289,28 +295,26 @@ def log(
caption = " ".join([str(x) for x in label.cpu().tolist()])
self.run.log({name: [wandb.Image(img, caption=caption)]}, step=step)

else:
now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M")
print(
f"[{now}]: {step}/{total_iters}",
flush=True,
)
g_losses = "\t".join(
[
f"{key}: {value:.4f}"
for key, value in all_losses.items()
if not key.startswith("D/")
]
)
d_losses = "\t".join(
[
f"{key}: {value:.4f}"
for key, value in all_losses.items()
if key.startswith("D/")
]
)
print(f"G Losses: {g_losses}", flush=True)
print(f"D Losses: {d_losses}", flush=True)
print(
f"[{elapsed_time}]: {step}/{total_iters}",
flush=True,
)
g_losses = "\t".join(
[
f"{key}: {value:.4f}"
for key, value in all_losses.items()
if not key.startswith("D/")
]
)
d_losses = "\t".join(
[
f"{key}: {value:.4f}"
for key, value in all_losses.items()
if key.startswith("D/")
]
)
print(f"G Losses: {g_losses}", flush=True)
print(f"D Losses: {d_losses}", flush=True)

@torch.no_grad()
def evaluate(
Expand Down Expand Up @@ -354,6 +358,8 @@ def evaluate(
for trg_idx, trg_domain in enumerate(domains):
src_domains = [x for x in val_loader.available_sources if x != trg_domain]
val_loader.set_target(trg_domain)
if mode == "reference":
loader_ref = val_loader.loader_ref

for src_idx, src_domain in enumerate(src_domains):
task = "%s/%s" % (src_domain, trg_domain)
Expand All @@ -378,25 +384,37 @@ def evaluate(
z_trg = torch.randn(N, self.latent_dim).to(device)
s_trg = self.nets_ema.mapping_network(z_trg, y_trg)
else:
# x_ref = x_trg.clone()
try:
# TODO don't need to re-do this every time, just use
# the same set of reference images for the whole dataset!
x_ref = next(iter_ref).to(device)
except:
iter_ref = iter(val_loader.loader_ref)
iter_ref = iter(loader_ref)
x_ref = next(iter_ref).to(device)

if x_ref.size(0) > N:
x_ref = x_ref[:N]
elif x_ref.size(0) < N:
raise ValueError(
"Not enough reference images."
"Make sure that the batch size of the validation loader is bigger than `num_outs_per_domain`."
)
s_trg = self.nets_ema.style_encoder(x_ref, y_trg)

x_fake = self.nets_ema.generator(x_src, s_trg)
# Run the classification
predictions.append(
classifier(
x_fake, assume_normalized=val_config.assume_normalized
)
.cpu()
.numpy()
pred = classifier(
x_fake, assume_normalized=val_config.assume_normalized
)
predictions.append(pred.cpu().numpy())
# predictions.append(
# classifier(
# x_fake, assume_normalized=val_config.assume_normalized
# )
# .cpu()
# .numpy()
# )
predictions = np.stack(predictions, axis=0)
assert len(predictions) > 0
# Do it in a vectorized way, by reshaping the predictions
Expand Down

0 comments on commit 75928ea

Please sign in to comment.