Skip to content

Commit

Permalink
Implement SD3 loss weighting (huggingface#8528)
Browse files Browse the repository at this point in the history
* Add lognorm and cosmap weighting

* Implement mode sampling

* Update examples/dreambooth/train_dreambooth_lora_sd3.py

* Update examples/dreambooth/train_dreambooth_lora_sd3.py

* Update examples/dreambooth/train_dreambooth_sd3.py

* Update examples/dreambooth/train_dreambooth_sd3.py

* Update examples/dreambooth/train_dreambooth_sd3.py

* Update examples/dreambooth/train_dreambooth_lora_sd3.py

* Update examples/dreambooth/train_dreambooth_sd3.py

* Update examples/dreambooth/train_dreambooth_sd3.py

* Update examples/dreambooth/train_dreambooth_lora_sd3.py

* keep timestamp sampling fully on cpu

---------

Co-authored-by: Kashif Rasul <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
  • Loading branch information
3 people authored Jun 16, 2024
1 parent 130dd93 commit 6946fac
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 20 deletions.
28 changes: 19 additions & 9 deletions examples/dreambooth/train_dreambooth_lora_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1462,7 +1462,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
bsz = model_input.shape[0]

# Sample a random timestep for each image
indices = torch.randint(0, noise_scheduler_copy.config.num_train_timesteps, (bsz,))
# for weighting schemes where we sample timesteps non-uniformly
if args.weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device="cpu")
u = torch.nn.functional.sigmoid(u)
elif args.weighting_scheme == "mode":
u = torch.rand(size=(bsz,), device="cpu")
u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
else:
u = torch.rand(size=(bsz,), device="cpu")

indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)

# Add noise according to flow matching.
Expand All @@ -1483,16 +1494,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
model_pred = model_pred * (-sigmas) + noisy_model_input

# TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :)
# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
if args.weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
elif args.weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device)
weighting = torch.nn.functional.sigmoid(u)
elif args.weighting_scheme == "mode":
# See sec 3.1 in the SD3 paper (20).
u = torch.rand(size=(bsz,), device=accelerator.device)
weighting = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
elif args.weighting_scheme == "cosmap":
bot = 1 - 2 * sigmas + 2 * sigmas**2
weighting = 2 / (math.pi * bot)
else:
weighting = torch.ones_like(sigmas)

# simplified flow matching aka 0-rectified flow matching loss
# target = model_input - noise
Expand Down
30 changes: 19 additions & 11 deletions examples/dreambooth/train_dreambooth_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1526,7 +1526,18 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
bsz = model_input.shape[0]

# Sample a random timestep for each image
indices = torch.randint(0, noise_scheduler_copy.config.num_train_timesteps, (bsz,))
# for weighting schemes where we sample timesteps non-uniformly
if args.weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device="cpu")
u = torch.nn.functional.sigmoid(u)
elif args.weighting_scheme == "mode":
u = torch.rand(size=(bsz,), device="cpu")
u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
else:
u = torch.rand(size=(bsz,), device="cpu")

indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)

# Add noise according to flow matching.
Expand Down Expand Up @@ -1560,18 +1571,15 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
# Preconditioning of the model outputs.
model_pred = model_pred * (-sigmas) + noisy_model_input

# TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :)
# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
if args.weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
elif args.weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device)
weighting = torch.nn.functional.sigmoid(u)
elif args.weighting_scheme == "mode":
# See sec 3.1 in the SD3 paper (20).
u = torch.rand(size=(bsz,), device=accelerator.device)
weighting = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
elif args.weighting_scheme == "cosmap":
bot = 1 - 2 * sigmas + 2 * sigmas**2
weighting = 2 / (math.pi * bot)
else:
weighting = torch.ones_like(sigmas)

# simplified flow matching aka 0-rectified flow matching loss
# target = model_input - noise
Expand Down

0 comments on commit 6946fac

Please sign in to comment.