From 82c881d12a25745a5ce02e3abbe1f46d1d7dc42e Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Sun, 24 Dec 2023 13:09:14 +0530
Subject: [PATCH 1/2] spit diffusers-native format from the get go.

---
 .../dreambooth/train_dreambooth_lora_sdxl.py  | 22 +++++++++++++------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 0f41ad47d1ac..aa6e7d21aa73 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -54,7 +54,7 @@
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import compute_snr
-from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
 
@@ -1019,11 +1019,15 @@ def save_model_hook(models, weights, output_dir):
 
             for model in models:
                 if isinstance(model, type(accelerator.unwrap_model(unet))):
-                    unet_lora_layers_to_save = get_peft_model_state_dict(model)
+                    unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
-                    text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+                    text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
+                        get_peft_model_state_dict(model)
+                    )
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
-                    text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
+                    text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
+                        get_peft_model_state_dict(model)
+                    )
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
@@ -1615,13 +1619,17 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     if accelerator.is_main_process:
         unet = accelerator.unwrap_model(unet)
         unet = unet.to(torch.float32)
-        unet_lora_layers = get_peft_model_state_dict(unet)
+        unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
 
         if args.train_text_encoder:
             text_encoder_one = accelerator.unwrap_model(text_encoder_one)
-            text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+            text_encoder_lora_layers = convert_state_dict_to_diffusers(
+                get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+            )
             text_encoder_two = accelerator.unwrap_model(text_encoder_two)
-            text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))
+            text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
+                get_peft_model_state_dict(text_encoder_two.to(torch.float32))
+            )
         else:
             text_encoder_lora_layers = None
             text_encoder_2_lora_layers = None
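For reference: `get_peft_model_state_dict` emits PEFT's `lora_A`/`lora_B` key naming, so the hooks above now pass every LoRA state dict through `convert_state_dict_to_diffusers` before serialization, making the saved checkpoint diffusers-native from the start. Below is a minimal sketch of that call pattern, not code from the script: `ToyAttention` is a made-up stand-in for a diffusers attention block, and converting its UNet-style `to_q`/`to_k`/`to_v`/`to_out.0` names relies on the mapping added in the next patch.

import torch
from peft import LoraConfig, get_peft_model
from peft.utils import get_peft_model_state_dict

from diffusers.utils import convert_state_dict_to_diffusers


class ToyAttention(torch.nn.Module):
    """Made-up stand-in reusing the submodule names of a diffusers attention block."""

    def __init__(self, dim=32):
        super().__init__()
        self.to_q = torch.nn.Linear(dim, dim)
        self.to_k = torch.nn.Linear(dim, dim)
        self.to_v = torch.nn.Linear(dim, dim)
        self.to_out = torch.nn.ModuleList([torch.nn.Linear(dim, dim)])


model = get_peft_model(
    ToyAttention(), LoraConfig(r=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"])
)

# PEFT names the low-rank factors `lora_A` (down-projection) and `lora_B` (up-projection)...
peft_state_dict = get_peft_model_state_dict(model)
# ...and the conversion rewrites them to the diffusers-native `lora.down` / `lora.up` keys.
diffusers_state_dict = convert_state_dict_to_diffusers(peft_state_dict)

Converting at save time means no separate conversion pass is needed later: the checkpoint is already in the format the diffusers LoRA loaders expect.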
From c05ed5a12f9cf453add0fdd8fe8dc6716c4a9b75 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Sun, 24 Dec 2023 13:54:36 +0530
Subject: [PATCH 2/2] rejig the peft_to_diffusers mapping.

---
 src/diffusers/utils/state_dict_utils.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py
index 777c611f7150..6c163034e7fd 100644
--- a/src/diffusers/utils/state_dict_utils.py
+++ b/src/diffusers/utils/state_dict_utils.py
@@ -79,6 +79,14 @@ class StateDictType(enum.Enum):
     ".v_proj.lora_A": ".v_proj.lora_linear_layer.down",
     ".out_proj.lora_B": ".out_proj.lora_linear_layer.up",
     ".out_proj.lora_A": ".out_proj.lora_linear_layer.down",
+    "to_k.lora_A": "to_k.lora.down",
+    "to_k.lora_B": "to_k.lora.up",
+    "to_q.lora_A": "to_q.lora.down",
+    "to_q.lora_B": "to_q.lora.up",
+    "to_v.lora_A": "to_v.lora.down",
+    "to_v.lora_B": "to_v.lora.up",
+    "to_out.0.lora_A": "to_out.0.lora.down",
+    "to_out.0.lora_B": "to_out.0.lora.up",
 }
 
 DIFFUSERS_OLD_TO_DIFFUSERS = {
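Context for this second patch: patch 1 routes the UNet's PEFT state dict through `convert_state_dict_to_diffusers`, but `PEFT_TO_DIFFUSERS` previously covered only the text encoders' `q_proj`/`k_proj`/`v_proj`/`out_proj` projections, so UNet attention keys would have passed through unchanged. The new entries handle the UNet's `to_q`/`to_k`/`to_v`/`to_out.0` projections. A quick illustration with made-up keys, assuming a diffusers build that includes this mapping (`lora_A` weights are rank x in_features, `lora_B` are out_features x rank):

import torch

from diffusers.utils import convert_state_dict_to_diffusers

# Two fabricated UNet attention keys in PEFT naming.
peft_sd = {
    "down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.lora_A.weight": torch.zeros(4, 320),
    "down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.lora_B.weight": torch.zeros(320, 4),
}
for key in convert_state_dict_to_diffusers(peft_sd):
    print(key)
# down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.lora.down.weight
# down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.lora.up.weight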