From 6946facf6913ff76fbb6aa48fd69802b55677d5f Mon Sep 17 00:00:00 2001 From: Rafie Walker Date: Sun, 16 Jun 2024 21:15:50 +0200 Subject: [PATCH] Implement SD3 loss weighting (#8528) * Add lognorm and cosmap weighting * Implement mode sampling * Update examples/dreambooth/train_dreambooth_lora_sd3.py * Update examples/dreambooth/train_dreambooth_lora_sd3.py * Update examples/dreambooth/train_dreambooth_sd3.py * Update examples/dreambooth/train_dreambooth_sd3.py * Update examples/dreambooth/train_dreambooth_sd3.py * Update examples/dreambooth/train_dreambooth_lora_sd3.py * Update examples/dreambooth/train_dreambooth_sd3.py * Update examples/dreambooth/train_dreambooth_sd3.py * Update examples/dreambooth/train_dreambooth_lora_sd3.py * keep timestamp sampling fully on cpu --------- Co-authored-by: Kashif Rasul Co-authored-by: Sayak Paul --- .../dreambooth/train_dreambooth_lora_sd3.py | 28 +++++++++++------ examples/dreambooth/train_dreambooth_sd3.py | 30 ++++++++++++------- 2 files changed, 38 insertions(+), 20 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 2512da1c61f9..67227e2defc6 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -1462,7 +1462,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): bsz = model_input.shape[0] # Sample a random timestep for each image - indices = torch.randint(0, noise_scheduler_copy.config.num_train_timesteps, (bsz,)) + # for weighting schemes where we sample timesteps non-uniformly + if args.weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif args.weighting_scheme == "mode": + u = torch.rand(size=(bsz,), device="cpu") + u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(bsz,), device="cpu") + + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) # Add noise according to flow matching. @@ -1483,16 +1494,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_pred = model_pred * (-sigmas) + noisy_model_input # TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :) + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss if args.weighting_scheme == "sigma_sqrt": weighting = (sigmas**-2.0).float() - elif args.weighting_scheme == "logit_normal": - # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). - u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device) - weighting = torch.nn.functional.sigmoid(u) - elif args.weighting_scheme == "mode": - # See sec 3.1 in the SD3 paper (20). - u = torch.rand(size=(bsz,), device=accelerator.device) - weighting = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + elif args.weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) # simplified flow matching aka 0-rectified flow matching loss # target = model_input - noise diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index 17d35577ed36..7920b4c8e0fa 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -1526,7 +1526,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): bsz = model_input.shape[0] # Sample a random timestep for each image - indices = torch.randint(0, noise_scheduler_copy.config.num_train_timesteps, (bsz,)) + # for weighting schemes where we sample timesteps non-uniformly + if args.weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device="cpu") + u = torch.nn.functional.sigmoid(u) + elif args.weighting_scheme == "mode": + u = torch.rand(size=(bsz,), device="cpu") + u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(bsz,), device="cpu") + + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) # Add noise according to flow matching. @@ -1560,18 +1571,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Follow: Section 5 of https://arxiv.org/abs/2206.00364. # Preconditioning of the model outputs. model_pred = model_pred * (-sigmas) + noisy_model_input - - # TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :) + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss if args.weighting_scheme == "sigma_sqrt": weighting = (sigmas**-2.0).float() - elif args.weighting_scheme == "logit_normal": - # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). - u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device) - weighting = torch.nn.functional.sigmoid(u) - elif args.weighting_scheme == "mode": - # See sec 3.1 in the SD3 paper (20). - u = torch.rand(size=(bsz,), device=accelerator.device) - weighting = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + elif args.weighting_scheme == "cosmap": + bot = 1 - 2 * sigmas + 2 * sigmas**2 + weighting = 2 / (math.pi * bot) + else: + weighting = torch.ones_like(sigmas) # simplified flow matching aka 0-rectified flow matching loss # target = model_input - noise