Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Training examples] Follow up of huggingface#6306 (huggingface#6346)
Browse files Browse the repository at this point in the history
* add to dreambooth lora.

* add: t2i lora.

* add: sdxl t2i lora.

* style

* lcm lora sdxl.

* unwrap

* fix: enable_adapters().
sayakpaul authored and Jimmy committed Apr 26, 2024

Verified

This commit was signed with the committer’s verified signature.
1 parent bc0a908 commit 509ac93
Showing 4 changed files with 29 additions and 21 deletions.
12 changes: 6 additions & 6 deletions examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
Original file line number Diff line number Diff line change
@@ -51,7 +51,7 @@
UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


@@ -113,7 +113,7 @@ def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_fin
if unet is None:
raise ValueError("Must provide a `unet` when doing intermediate validation.")
unet = accelerator.unwrap_model(unet)
state_dict = get_peft_model_state_dict(unet)
state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
to_load = state_dict
else:
to_load = args.output_dir
@@ -819,7 +819,7 @@ def save_model_hook(models, weights, output_dir):
unet_ = accelerator.unwrap_model(unet)
# also save the checkpoints in native `diffusers` format so that it can be easily
# be independently loaded via `load_lora_weights()`.
state_dict = get_peft_model_state_dict(unet_)
state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet_))
StableDiffusionXLPipeline.save_lora_weights(output_dir, unet_lora_layers=state_dict)

for _, model in enumerate(models):
@@ -1184,7 +1184,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
# solver timestep.

# With the adapters disabled, the `unet` is the regular teacher model.
unet.disable_adapters()
accelerator.unwrap_model(unet).disable_adapters()
with torch.no_grad():
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
cond_teacher_output = unet(
@@ -1248,7 +1248,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
x_prev = solver.ddim_step(pred_x0, pred_noise, index).to(unet.dtype)

# re-enable unet adapters to turn the `unet` into a student unet.
unet.enable_adapters()
accelerator.unwrap_model(unet).enable_adapters()

# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
# Note that we do not use a separate target network for LCM-LoRA distillation.
@@ -1332,7 +1332,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = get_peft_model_state_dict(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
StableDiffusionXLPipeline.save_lora_weights(args.output_dir, unet_lora_layers=unet_lora_state_dict)

if args.push_to_hub:
12 changes: 7 additions & 5 deletions examples/dreambooth/train_dreambooth_lora.py
Original file line number Diff line number Diff line change
@@ -54,7 +54,7 @@
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


@@ -853,9 +853,11 @@ def save_model_hook(models, weights, output_dir):

for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = get_peft_model_state_dict(model)
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
else:
raise ValueError(f"unexpected save model: {model.__class__}")

@@ -1285,11 +1287,11 @@ def compute_text_embeddings(prompt):
unet = accelerator.unwrap_model(unet)
unet = unet.to(torch.float32)

unet_lora_state_dict = get_peft_model_state_dict(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))

if args.train_text_encoder:
text_encoder = accelerator.unwrap_model(text_encoder)
text_encoder_state_dict = get_peft_model_state_dict(text_encoder)
text_encoder_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder))
else:
text_encoder_state_dict = None

8 changes: 5 additions & 3 deletions examples/text_to_image/train_text_to_image_lora.py
Original file line number Diff line number Diff line change
@@ -44,7 +44,7 @@
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


@@ -809,7 +809,9 @@ def collate_fn(examples):
accelerator.save_state(save_path)

unwrapped_unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(
get_peft_model_state_dict(unwrapped_unet)
)

StableDiffusionPipeline.save_lora_weights(
save_directory=save_path,
@@ -876,7 +878,7 @@ def collate_fn(examples):
unet = unet.to(torch.float32)

unwrapped_unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
StableDiffusionPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
18 changes: 11 additions & 7 deletions examples/text_to_image/train_text_to_image_lora_sdxl.py
Original file line number Diff line number Diff line change
@@ -52,7 +52,7 @@
from diffusers.loaders import LoraLoaderMixin
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


@@ -651,11 +651,15 @@ def save_model_hook(models, weights, output_dir):

for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = get_peft_model_state_dict(model)
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
else:
raise ValueError(f"unexpected save model: {model.__class__}")

@@ -1160,14 +1164,14 @@ def compute_time_ids(original_size, crops_coords_top_left):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet_lora_state_dict = get_peft_model_state_dict(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))

if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
text_encoder_two = accelerator.unwrap_model(text_encoder_two)

text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one)
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two)
text_encoder_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_one))
text_encoder_2_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_two))
else:
text_encoder_lora_layers = None
text_encoder_2_lora_layers = None

0 comments on commit 509ac93

Please sign in to comment.