Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LoRA] make LoRAs trained with peft loadable when peft isn't installed #6306

Merged
merged 3 commits into from
Dec 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from diffusers.loaders import LoraLoaderMixin
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


Expand Down Expand Up @@ -1019,11 +1019,15 @@ def save_model_hook(models, weights, output_dir):

for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = get_peft_model_state_dict(model)
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
else:
raise ValueError(f"unexpected save model: {model.__class__}")

Expand Down Expand Up @@ -1615,13 +1619,17 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet = unet.to(torch.float32)
unet_lora_layers = get_peft_model_state_dict(unet)
unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))

if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
text_encoder_lora_layers = convert_state_dict_to_diffusers(
get_peft_model_state_dict(text_encoder_one.to(torch.float32))
)
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))
text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
get_peft_model_state_dict(text_encoder_two.to(torch.float32))
)
else:
text_encoder_lora_layers = None
text_encoder_2_lora_layers = None
Expand Down
8 changes: 8 additions & 0 deletions src/diffusers/utils/state_dict_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,14 @@ class StateDictType(enum.Enum):
".v_proj.lora_A": ".v_proj.lora_linear_layer.down",
".out_proj.lora_B": ".out_proj.lora_linear_layer.up",
".out_proj.lora_A": ".out_proj.lora_linear_layer.down",
"to_k.lora_A": "to_k.lora.down",
"to_k.lora_B": "to_k.lora.up",
"to_q.lora_A": "to_q.lora.down",
"to_q.lora_B": "to_q.lora.up",
"to_v.lora_A": "to_v.lora.down",
"to_v.lora_B": "to_v.lora.up",
"to_out.0.lora_A": "to_out.0.lora.down",
"to_out.0.lora_B": "to_out.0.lora.up",
}

DIFFUSERS_OLD_TO_DIFFUSERS = {
Expand Down
Loading