[LoRA] Add LoRA support to AuraFlow #10216
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
(force-pushed from da41f32 to c8364bc)
Thanks for the helping hand, @hlky!
See https://github.com/huggingface/diffusers/blob/main/tests/lora/test_lora_layers_flux.py and https://github.com/huggingface/diffusers/blob/main/tests/lora/test_lora_layers_mochi.py as examples for tests. Tests seem to be missing here.
(force-pushed from 20b5f2d to 80ac0d4)
@sayakpaul Okay, I'm at a point where I've got actual, valid test failures but have no idea where to look.
(force-pushed from 912fb8d to f5b9f90)
Here's the log after the latest commit: pytest.log
Thanks for the PR. I have left some comments to fix a couple of things. LMK if they're unclear.
(force-pushed from f5b9f90 to 1c79095)
Latest test log.
Failures:

It seems nothing but CLIP is supported in `src/diffusers/models/lora.py` (lines 41 to 66 at 532013f), which is called from `src/diffusers/loaders/lora_pipeline.py` (lines 2185 to 2197 at 532013f). Essentially, this means we either need more plumbing to support arbitrary text encoders, or we support only the transformer for AuraFlow, whose sole text encoder is UMT5.
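For reference, the referenced helper is CLIP-specific by construction. This is a condensed sketch of what `text_encoder_attn_modules` does at that commit (paraphrased from the referenced lines, not copied verbatim):

```python
from transformers import CLIPTextModel, CLIPTextModelWithProjection


def text_encoder_attn_modules(text_encoder):
    # Collects (name, module) pairs for each self-attention block,
    # but only CLIP's layer layout is handled; any other encoder
    # (e.g. AuraFlow's UMT5) falls through to the error branch.
    attn_modules = []
    if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
        for i, layer in enumerate(text_encoder.text_model.encoder.layers):
            name = f"text_model.encoder.layers.{i}.self_attn"
            attn_modules.append((name, layer.self_attn))
    else:
        raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")
    return attn_modules
```

`text_encoder_mlp_modules` follows the same shape for the `fc1`/`fc2` layers, so any non-CLIP encoder hits the `ValueError` during text-encoder LoRA loading.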
Thanks, let's support only the `transformer`, then.
So I skipped all tests requiring TE in e06d8eb. Latest failures are: pytest.log
I'm not entirely sure how to get past these.
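For context, the pattern used by other TE-less models (e.g. the Mochi LoRA tests) is to override the shared mixin's text-encoder cases with skips. A minimal sketch, assuming the shared `PeftLoraLoaderMixinTests` from tests/lora/utils.py; the method names follow the shared mixin and may not exactly match the set skipped in e06d8eb:

```python
import sys
import unittest

sys.path.append(".")
from utils import PeftLoraLoaderMixinTests  # noqa: E402, shared LoRA test mixin


class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    # Pipeline/scheduler/transformer configuration omitted for brevity.

    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
    def test_simple_inference_with_text_lora(self):
        pass

    @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
    def test_simple_inference_with_text_lora_and_scale(self):
        pass
```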
Could we try to look into each failure and debug?
(force-pushed from b59b25e to 00c921e)
(force-pushed from 00c921e to 1ec07a1)
src/diffusers/loaders/lora_base.py (outdated)

```diff
-text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
-
-for name, _ in text_encoder_attn_modules(text_encoder):
-    for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
-        rank_key = f"{name}.{module}.lora_B.weight"
-        if rank_key not in text_encoder_lora_state_dict:
-            continue
-        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-
-for name, _ in text_encoder_mlp_modules(text_encoder):
-    for module in ("fc1", "fc2"):
-        rank_key = f"{name}.{module}.lora_B.weight"
-        if rank_key not in text_encoder_lora_state_dict:
-            continue
-        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+text_encoder_lora_state_dict = convert_state_dict_to_peft(
+    text_encoder_lora_state_dict, original_type=StateDictType.DIFFUSERS
+)
+
+for name, module in text_encoder.named_modules():
+    if "lora_A" not in name and "lora_B" not in name and isinstance(module, (nn.Linear, nn.Conv2d)):
+        rank_key = f"{name.removesuffix('.base_layer')}.lora_B.weight"
+        if rank_key in text_encoder_lora_state_dict:
+            rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
```
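The replacement scans `text_encoder.named_modules()` for any `nn.Linear` or `nn.Conv2d` that isn't itself a LoRA adapter layer, so rank detection no longer depends on the CLIP-only helpers and works for encoders like UMT5 as well. The `removesuffix('.base_layer')` handles modules that PEFT has already wrapped, whose original layer sits under `.base_layer`.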
These changes support TE LoRA for encoders other than CLIP, but they cause one other model to fail, likely an edge case.
I think it's fine not to add these changes in this PR, as they seem unrelated.
I'll make a separate PR for this and we can review that one first -- it seems to be the crux of what's failing with the tests IIUC. Then coming back to this one should be easy.
I don't think so. If we don't support loading LoRAs into a module (`_lora_loadable_modules`), then that shouldn't cause any failures.
It still causes the failures here, and I'm not sure how to fix them: #10216 (comment)
For context, this is the first model we'd support without a TE, and the tests aren't written for that.
I don't think so. Mochi doesn't have a text encoder in its LoRA loader either; see `src/diffusers/loaders/lora_pipeline.py`, line 2560 at edb8c1b: `class Mochi1LoraLoaderMixin(LoraBaseMixin):`
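That mixin declares only the transformer as LoRA-loadable. A trimmed sketch of its shape (docstring text and methods omitted; attribute values as recalled from lora_pipeline.py, so treat as approximate):

```python
class Mochi1LoraLoaderMixin(LoraBaseMixin):
    r"""Load LoRA layers into [`MochiTransformer3DModel`]."""

    # No text encoder listed, so TE LoRA keys are simply not loaded.
    _lora_loadable_modules = ["transformer"]
    transformer_name = TRANSFORMER_NAME
```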
Neither does OmniGen.
(force-pushed from 9a89ed0 to 5620384)
(force-pushed from 0751c4b to 12fbd11)
Okay, I've updated the code to mirror Mochi1 instead of Flux, as that was a closer match (no TE LoRA), and removed the TE from loadable modules. I've also skipped TE-only LoRA tests. Here are the remaining failures: Failure list
Latest pytest.log
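Once the transformer-only mixin is wired into the pipeline, loading would presumably follow the standard `load_lora_weights` flow. A sketch, with a hypothetical LoRA repo id and weight name for illustration:

```python
import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")
# Hypothetical LoRA checkpoint, for illustration only.
pipe.load_lora_weights("user/auraflow-lora", weight_name="pytorch_lora_weights.safetensors")
image = pipe("a watercolor fox in a forest").images[0]
```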
I think a better approach to debugging is taking a single test failure and trying to see what's causing it. We could compare the implementation with another model that has similarities (Mochi-1 is a good one) and take things from there. Have we tried that?

We're grateful for your contributions so far, but it might be even better if we pushed the debugging a bit further. This will help solidify your contributions, too.

See `src/diffusers/loaders/lora_pipeline.py`, line 3199 at 328e0d2.
What does this PR do?
This PR is a simple rebase of #9017.
cc @sayakpaul for review.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.