Clean Up Comments in LCM(-LoRA) Distillation Scripts. (huggingface#6145)
* Clean up comments in LCM(-LoRA) distillation scripts.

* Calculate predicted source noise noise_pred correctly for all prediction_types.

* make style

* apply suggestions from review

---------

Co-authored-by: Sayak Paul <[email protected]>
2 people authored and Jimmy committed Apr 26, 2024
1 parent 6beab9a commit 86087cd
Showing 4 changed files with 350 additions and 147 deletions.
127 changes: 90 additions & 37 deletions examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
@@ -156,7 +156,7 @@ def __call__(self, x):
return False


class Text2ImageDataset:
class SDText2ImageDataset:
def __init__(
self,
train_shards_path_or_url: Union[str, List[str]],
@@ -359,19 +359,43 @@ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=


# Compare LCMScheduler.step, Step 4
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
if prediction_type == "epsilon":
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
pred_x_0 = (sample - sigmas * model_output) / alphas
elif prediction_type == "sample":
pred_x_0 = model_output
elif prediction_type == "v_prediction":
pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
pred_x_0 = alphas * sample - sigmas * model_output
else:
raise ValueError(f"Prediction type {prediction_type} currently not supported.")
raise ValueError(
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
f" are supported."
)

return pred_x_0
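
These branches invert the forward-diffusion relation x_t = alpha_t * x_0 + sigma_t * eps. A minimal numeric sketch of the `epsilon` branch, using made-up scalar values in place of the script's schedules:

import torch

alpha, sigma = 0.8, 0.6  # example alpha_t, sigma_t (alpha**2 + sigma**2 == 1)
x_0 = torch.tensor([1.0, -2.0])  # hypothetical clean latents
eps = torch.tensor([0.5, 0.3])  # hypothetical source noise
x_t = alpha * x_0 + sigma * eps  # forward diffusion
assert torch.allclose((x_t - sigma * eps) / alpha, x_0)  # the "epsilon" branch recovers x_0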


# Based on step 4 in DDIMScheduler.step
def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
alphas = extract_into_tensor(alphas, timesteps, sample.shape)
sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
if prediction_type == "epsilon":
pred_epsilon = model_output
elif prediction_type == "sample":
pred_epsilon = (sample - alphas * model_output) / sigmas
elif prediction_type == "v_prediction":
pred_epsilon = alphas * model_output + sigmas * sample
else:
raise ValueError(
f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
f" are supported."
)

return pred_epsilon
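
The two helpers are mutually consistent: for every prediction type, alpha_t * pred_x_0 + sigma_t * pred_epsilon reconstructs the noisy sample. A quick sketch checking this for `v_prediction`, again with scalar stand-ins for the schedules:

import torch

alpha, sigma = 0.8, 0.6
x_0, eps = torch.tensor([1.0]), torch.tensor([0.5])
x_t = alpha * x_0 + sigma * eps
v = alpha * eps - sigma * x_0  # the v-prediction target
pred_x_0 = alpha * x_t - sigma * v  # v_prediction branch of get_predicted_original_sample
pred_eps = alpha * v + sigma * x_t  # v_prediction branch of get_predicted_noise
assert torch.allclose(alpha * pred_x_0 + sigma * pred_eps, x_t)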


def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
Expand Down Expand Up @@ -835,34 +859,35 @@ def main(args):
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
)

# The scheduler calculates the alpha and sigma schedule for us
# DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
# Initialize the DDIM ODE solver for distillation.
solver = DDIMSolver(
noise_scheduler.alphas_cumprod.numpy(),
timesteps=noise_scheduler.config.num_train_timesteps,
ddim_timesteps=args.num_ddim_timesteps,
)
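
For intuition: with a variance-preserving DDPM schedule, alpha_schedule[t]**2 + sigma_schedule[t]**2 == 1 at every timestep, which is exactly the identity that lets get_predicted_original_sample and get_predicted_noise invert the forward process. A standalone check, assuming a default diffusers DDPMScheduler as a stand-in for the teacher's scheduler:

import torch
from diffusers import DDPMScheduler

sched = DDPMScheduler()
alpha_schedule = torch.sqrt(sched.alphas_cumprod)
sigma_schedule = torch.sqrt(1 - sched.alphas_cumprod)
assert torch.allclose(alpha_schedule**2 + sigma_schedule**2, torch.ones_like(alpha_schedule))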

# 2. Load tokenizers from SD-XL checkpoint.
# 2. Load tokenizers from SD 1.X/2.X checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
)

# 3. Load text encoders from SD-1.5 checkpoint.
# 3. Load text encoders from SD 1.X/2.X checkpoint.
# import correct text encoder classes
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
)

# 4. Load VAE from SD-XL checkpoint (or more stable VAE)
# 4. Load VAE from SD 1.X/2.X checkpoint
vae = AutoencoderKL.from_pretrained(
args.pretrained_teacher_model,
subfolder="vae",
revision=args.teacher_revision,
)

# 5. Load teacher U-Net from SD-XL checkpoint
# 5. Load teacher U-Net from SD 1.X/2.X checkpoint
teacher_unet = UNet2DConditionModel.from_pretrained(
args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
)
@@ -872,7 +897,7 @@ def main(args):
text_encoder.requires_grad_(False)
teacher_unet.requires_grad_(False)

# 7. Create online (`unet`) student U-Nets.
# 7. Create online student U-Net.
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
)
@@ -935,6 +960,7 @@ def main(args):
# Also move the alpha and sigma noise schedules to accelerator.device.
alpha_schedule = alpha_schedule.to(accelerator.device)
sigma_schedule = sigma_schedule.to(accelerator.device)
# Move the ODE solver to accelerator.device.
solver = solver.to(accelerator.device)

# 10. Handle saving and loading of checkpoints
@@ -1011,13 +1037,14 @@ def load_model_hook(models, input_dir):
eps=args.adam_epsilon,
)

# 13. Dataset creation and data processing
# Here, we compute not just the text embeddings but also the additional embeddings
# needed for the SD XL UNet to operate.
def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True):
prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train)
return {"prompt_embeds": prompt_embeds}

dataset = Text2ImageDataset(
dataset = SDText2ImageDataset(
train_shards_path_or_url=args.train_shards_path_or_url,
num_train_examples=args.max_train_samples,
per_gpu_batch_size=args.train_batch_size,
@@ -1037,6 +1064,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
tokenizer=tokenizer,
)

# 14. LR Scheduler creation
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
@@ -1051,6 +1079,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
num_training_steps=args.max_train_steps,
)

# 15. Prepare for training
# Prepare everything with our `accelerator`.
unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)

@@ -1072,7 +1101,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
).input_ids.to(accelerator.device)
uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]

# Train!
# 16. Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

logger.info("***** Running training *****")
@@ -1123,6 +1152,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
for epoch in range(first_epoch, args.num_train_epochs):
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
# 1. Load and process the image and text conditioning
image, text = batch

image = image.to(accelerator.device, non_blocking=True)
@@ -1140,37 +1170,37 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok

latents = latents * vae.config.scaling_factor
latents = latents.to(weight_dtype)

# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]

# Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
# 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
# For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
start_timesteps = solver.ddim_timesteps[index]
timesteps = start_timesteps - topk
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
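# Worked example (assuming the common defaults T = 1000 train timesteps and
# num_ddim_timesteps = 50): topk = 20 and the solver schedule contains
# {999, 979, ..., 19}, so e.g. start_timesteps = 219 pairs with timesteps = 199,
# while the smallest entry 19 would give -1 and is clamped to 0 above.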

# 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
# 3. Get boundary scalings for start_timesteps and (end) timesteps.
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
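# As in the LCM paper's parameterization, c_skip -> 1 and c_out -> 0 as the
# timestep -> 0, enforcing the consistency boundary condition f(x, 0) = x.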

# 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
# 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
# timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
noise = torch.randn_like(latents)
noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)

# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
# 5. Sample a random guidance scale w from U[w_min, w_max]
# Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w = w.reshape(bsz, 1, 1, 1)
w = w.to(device=latents.device, dtype=latents.dtype)

# 20.4.8. Prepare prompt embeds and unet_added_conditions
# 6. Prepare prompt embeds and unet_added_conditions
prompt_embeds = encoded_text.pop("prompt_embeds")

# 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
# 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
noise_pred = unet(
noisy_model_input,
start_timesteps,
@@ -1179,7 +1209,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
added_cond_kwargs=encoded_text,
).sample

pred_x_0 = predicted_origin(
pred_x_0 = get_predicted_original_sample(
noise_pred,
start_timesteps,
noisy_model_input,
@@ -1190,17 +1220,27 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok

model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
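# i.e. the consistency-model parameterization f_theta(z, t) = c_skip(t) * z + c_out(t) * pred_x_0(z, t)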

# 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
# noisy_latents with both the conditioning embedding c and unconditional embedding 0
# Get teacher model prediction on noisy_latents and conditional embedding
# 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
# predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
# solver timestep.
with torch.no_grad():
with torch.autocast("cuda"):
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
cond_teacher_output = teacher_unet(
noisy_model_input.to(weight_dtype),
start_timesteps,
encoder_hidden_states=prompt_embeds.to(weight_dtype),
).sample
cond_pred_x0 = predicted_origin(
cond_pred_x0 = get_predicted_original_sample(
cond_teacher_output,
start_timesteps,
noisy_model_input,
noise_scheduler.config.prediction_type,
alpha_schedule,
sigma_schedule,
)
cond_pred_noise = get_predicted_noise(
cond_teacher_output,
start_timesteps,
noisy_model_input,
@@ -1209,13 +1249,21 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
sigma_schedule,
)

# Get teacher model prediction on noisy_latents and unconditional embedding
# 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
uncond_teacher_output = teacher_unet(
noisy_model_input.to(weight_dtype),
start_timesteps,
encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
).sample
uncond_pred_x0 = predicted_origin(
uncond_pred_x0 = get_predicted_original_sample(
uncond_teacher_output,
start_timesteps,
noisy_model_input,
noise_scheduler.config.prediction_type,
alpha_schedule,
sigma_schedule,
)
uncond_pred_noise = get_predicted_noise(
uncond_teacher_output,
start_timesteps,
noisy_model_input,
Expand All @@ -1224,12 +1272,17 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
sigma_schedule,
)

# 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
# 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
# Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
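# Algebraically, cond + w * (cond - uncond) == (1 + w) * cond - w * uncond, so w here
# plays the role of (guidance_scale - 1) in the usual diffusers convention
# uncond + guidance_scale * (cond - uncond).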
# 4. Run one step of the ODE solver to estimate the next point x_prev on the
# augmented PF-ODE trajectory (solving backward in time)
# Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
x_prev = solver.ddim_step(pred_x0, pred_noise, index)
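# Hedged sketch of what ddim_step computes (deterministic DDIM, i.e. eta = 0):
# x_prev = sqrt(alphas_cumprod[t_prev]) * pred_x0 + sqrt(1 - alphas_cumprod[t_prev]) * pred_noise,
# where t_prev is the solver timestep preceding start_timesteps.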

# 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
# Note that we do not use a separate target network for LCM-LoRA distillation.
with torch.no_grad():
with torch.autocast("cuda", dtype=weight_dtype):
target_noise_pred = unet(
Expand All @@ -1238,7 +1291,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
timestep_cond=None,
encoder_hidden_states=prompt_embeds.float(),
).sample
pred_x_0 = predicted_origin(
pred_x_0 = get_predicted_original_sample(
target_noise_pred,
timesteps,
x_prev,
@@ -1248,15 +1301,15 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
)
target = c_skip * x_prev + c_out * pred_x_0

# 20.4.13. Calculate loss
# 10. Calculate loss
if args.loss_type == "l2":
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
elif args.loss_type == "huber":
loss = torch.mean(
torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
)
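# The "huber" branch is the pseudo-Huber loss sqrt(d**2 + c**2) - c: roughly
# d**2 / (2 * c) for |d| << c (L2-like) and roughly |d| for |d| >> c (L1-like,
# more robust to outliers).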

# 20.4.14. Backpropagate on the online student model (`unet`)
# 11. Backpropagate on the online student model (`unet`)
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)