
Add PEFT to advanced training script #6294

Merged
merged 14 commits into main from add-peft-to-advanced-training-script on Dec 27, 2023

Conversation

@apolinario (Collaborator) commented Dec 22, 2023:

What does this PR do?

Adds PEFT to the advanced training script.

Some questions are still open regarding the PEFT integration and whether we should change more things in the script to accommodate the textual inversion training that takes place in it.

Closes #6118

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@younesbelkada (Contributor) left a comment:

Thanks, it looks good overall. I left one open question about handling additional unfrozen params.

- text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
-     text_encoder_two, dtype=torch.float32, rank=args.rank
+ text_lora_config = LoraConfig(
+     r=args.rank, init_lora_weights="gaussian", target_modules=["q_proj", "k_proj", "v_proj", "out_proj"]
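
For context, a minimal sketch of how a config like this typically gets attached to the two SDXL text encoders in the PEFT-based diffusers scripts. The variable names and the add_adapter call are assumptions based on similar scripts rather than this exact diff, and lora_alpha is included here per the discussion below:

from peft import LoraConfig

text_lora_config = LoraConfig(
    r=args.rank,                 # LoRA rank, taken from the script's CLI args
    lora_alpha=args.rank,        # scaling factor; see the alpha discussion below
    init_lora_weights="gaussian",
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)

# add_adapter comes from the transformers/PEFT integration and injects LoRA
# layers into the listed attention projections of each text encoder.
text_encoder_one.add_adapter(text_lora_config)
text_encoder_two.add_adapter(text_lora_config)
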
Contributor:

This should be correct with respect to the alpha issue we discussed offline, right @sayakpaul? Related: #6225

Collaborator (Author):

Thanks for pointing to that PR! Adding the lora alpha fixed the problem and now the script is working! (before it was giving jumbled results)

@apolinario (Collaborator, Author) commented Dec 23, 2023:

[image: generation without alpha]

[image: generation with alpha]

Member:

Holy cow 🐮

Contributor:

Amazing! 🤩


  if args.train_text_encoder:
      text_encoder_one = accelerator.unwrap_model(text_encoder_one)
-     text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
+     text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
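
As a rough sketch of where these state dicts end up (assuming the usual diffusers helpers used by the related LoRA scripts, and the script's unet / text_encoder_one / args variables; not necessarily this PR's exact code):

import torch
from peft.utils import get_peft_model_state_dict
from diffusers import StableDiffusionXLPipeline
from diffusers.utils import convert_state_dict_to_diffusers

# Pull the LoRA weights out of the adapter-injected models and convert the keys
# to the format diffusers expects in its LoRA checkpoints.
unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
text_encoder_lora_layers = convert_state_dict_to_diffusers(
    get_peft_model_state_dict(text_encoder_one.to(torch.float32))
)

StableDiffusionXLPipeline.save_lora_weights(
    save_directory=args.output_dir,
    unet_lora_layers=unet_lora_layers,
    text_encoder_lora_layers=text_encoder_lora_layers,
)
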
Contributor:

If you train extra parameters that you keep unfrozen for the text encoder, you need to add them in modules_to_save=[xxx] when defining the LoRA config for the text encoder.
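
For illustration, a sketch of where that argument would go; the module name "token_embedding" below is hypothetical, purely to show the shape of the argument:

from peft import LoraConfig

text_lora_config = LoraConfig(
    r=args.rank,
    lora_alpha=args.rank,
    init_lora_weights="gaussian",
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
    # Extra, fully trainable (non-LoRA) modules would be listed here so PEFT
    # saves and loads them together with the adapter weights.
    modules_to_save=["token_embedding"],  # hypothetical module name
)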

@apolinario (Collaborator, Author) commented Dec 22, 2023:

Got it! But I'm not training extra parameters with that operation. With args.train_text_encoder we're doing regular text encoder training.

However, with args.train_text_encoder_ti (which is mutually exclusive with args.train_text_encoder), the goal is to freeze all but the token_embedding of the text encoder and train the text embeddings for the new tokens introduced in the model.

This is where that takes place, and it was working prior to adding PEFT elsewhere: https://github.com/huggingface/diffusers/pull/6294/files#diff-24abe8b0339a563b68e03c979ee9e498ab7c49f3fd749ffb784156f4e2d54d90R1249
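
In code, that boils down to roughly the following (a sketch of the idea, not the exact lines at that link):

# Freeze both text encoders entirely, then re-enable gradients only for the
# token embeddings, which hold the newly inserted tokens.
for text_encoder in (text_encoder_one, text_encoder_two):
    text_encoder.requires_grad_(False)
    text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)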

Member:

So, the only thing that is currently outside of the PEFT paradigm is what's happening when args.train_text_encoder_ti is True, yeah?

Collaborator (Author):

Yes, which makes sense because it is not training an adapter per se.

Contributor:

Thanks @apolinario for explaining! Makes sense.

@apolinario apolinario marked this pull request as ready for review December 23, 2023 01:53
@sayakpaul (Member) left a comment:

Fantastic work.

To answer your questions, here's what I think:

https://github.com/huggingface/diffusers/blob/add-peft-to-advanced-training-script/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py#L1259

I think, to be on the safe side, we could tackle the token embedding part after we're done handling train_text_encoder, which is what is happening now, so that's good. I would maybe remove the parameter upcasting part because we're already doing it later in the script.

Also, args.train_text_encoder_ti and args.train_text_encoder ==> can both be set to True?

For the sub-question, I think you meant this line:

text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))

Given that's true, yeah, I think this step should be able to pick that up. But just to be sure, I'd maybe print out the param names.
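
Something along these lines would do for that check (a tiny sketch using the script's text_encoder_one):

# Print which text-encoder parameters are actually trainable after the PEFT
# adapter has been attached.
for name, param in text_encoder_one.named_parameters():
    if param.requires_grad:
        print(name, tuple(param.shape))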

For pivotal tuning, we "Pivot Halfway" meaning that we can stop the textual inversion training at a % of the steps. I don't see how PEFT affects that but flagging in case someone sees something there:

No, it shouldn't affect anything related to PEFT.
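
For reference, the "Pivot Halfway" behaviour described above amounts to something like this; the 0.5 fraction and the variable names are illustrative, not necessarily what the script uses:

# Stop the textual-inversion half of the training once a chosen fraction of the
# total steps has passed; the LoRA training itself continues unaffected.
pivot_step = int(0.5 * args.max_train_steps)
if global_step == pivot_step:
    for text_encoder in (text_encoder_one, text_encoder_two):
        text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)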

@@ -37,6 +37,8 @@
  from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
  from huggingface_hub import create_repo, upload_folder
  from packaging import version
+ from peft import LoraConfig
Member:

Will need to add peft as a dependency in the requirements.txt.

- unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
+ unet_lora_config = LoraConfig(
+     r=args.rank,
+     lora_alpha=args.rank,
Member:

Very important!
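
Some context on why: PEFT scales the LoRA update by lora_alpha / r before adding it to the frozen weight, and LoraConfig defaults to lora_alpha=8, so raising the rank without touching lora_alpha silently shrinks the effective update; that is the likely cause of the jumbled results mentioned earlier. A quick illustration:

def lora_scaling(rank: int, lora_alpha: int) -> float:
    # PEFT multiplies the LoRA delta (B @ A) by lora_alpha / rank.
    return lora_alpha / rank

print(lora_scaling(rank=32, lora_alpha=8))   # 0.25 -> adapter output is heavily damped
print(lora_scaling(rank=32, lora_alpha=32))  # 1.0  -> what lora_alpha=args.rank gives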

Comment on lines +1275 to +1284
# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
    models = [unet]
    if args.train_text_encoder:
        models.extend([text_encoder_one, text_encoder_two])
    for model in models:
        for param in model.parameters():
            # only upcast trainable parameters (LoRA) into fp32
            if param.requires_grad:
                param.data = param.to(torch.float32)
Member:

Another important one!


@apolinario (Collaborator, Author) replied:

> Also, args.train_text_encoder_ti and args.train_text_encoder ==> can both be set to True?

Not yet. We plan to support training both the text encoder and textual inversion, but that's for a future version of the script.
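
A guard reflecting that current behaviour would look roughly like this (a sketch; where the check lives and its exact message are assumptions):

# The two text-encoder training modes are mutually exclusive for now.
if args.train_text_encoder and args.train_text_encoder_ti:
    raise ValueError(
        "--train_text_encoder and --train_text_encoder_ti cannot be enabled together; "
        "pick one of the two."
    )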

@apolinario apolinario merged commit 645a62b into main Dec 27, 2023
16 checks passed
donhardman pushed a commit to donhardman/diffusers that referenced this pull request Dec 29, 2023
* Fix ProdigyOPT in SDXL Dreambooth script

* style

* style

* Add PEFT to Advanced Training Script

* style

* style

* ✨ style ✨

* change order for logic operation

* add lora alpha

* style

* Align PEFT to new format

* Update train_dreambooth_lora_sdxl_advanced.py

Apply huggingface#6355 fix

---------

Co-authored-by: multimodalart <[email protected]>
Successfully merging this pull request may close these issues.

[Advanced Diffusion Training] Adapt train_dreambooth_lora_sdxl_advanced.py to use peft