[LoRA] feat: support loading loras into 4bit quantized Flux models. #10578
Conversation
if quantization_config.load_in_4bit:
    expansion_shape = torch.Size(expansion_shape).numel()
    expansion_shape = ((expansion_shape + 1) // 2, 1)
Only 4bit bnb models flatten.
How about adding a comment along the lines of: "Handle 4bit bnb weights, which are flattened and compress 2 params into 1".
I'm not quite sure why we need (shape + 1) // 2; maybe this could be added to the comment too.
Yeah, this comes from bitsandbytes. Cc: @matthewdouglas
This is for rounding, i.e. if expansion_shape is odd, it will have an additional 8bit tensor with just one value packed into it.
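A quick sketch of that packing arithmetic, for illustration only (this is not actual bitsandbytes code): two 4bit values share one uint8 byte, so n elements need (n + 1) // 2 bytes.

# Illustrative sketch of the packed 4bit storage shape; not bitsandbytes internals.
import torch

def packed_4bit_shape(original_shape):
    # Two 4bit values are packed into one uint8 byte; an odd element count
    # still needs one extra byte for the leftover value, hence (n + 1) // 2.
    n = torch.Size(original_shape).numel()
    return ((n + 1) // 2, 1)

print(packed_4bit_shape((3072, 4096)))  # (6291456, 1): even count, no padding byte
print(packed_4bit_shape((5, 3)))        # (8, 1): 15 values round up to 8 bytes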
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for fixing the issue with loading LoRA into Flux models that are quantized with 4bit bnb.
The issue of the actual parameter shape is a common trap; it would be great if the original shape could be retrieved from bnb, though I'm not sure if there is a method for that.
Sayak, could you quickly sketch why this used to work and no longer does? IIRC, there was no special handling for 4bit bnb previously, was there?
module_weight_shape = module_weight.shape
expansion_shape = (out_features, in_features)
quantization_config = getattr(transformer, "quantization_config", None)
if quantization_config and quantization_config.quant_method == "bitsandbytes":
Would it make sense to have a utility function to get the shape to avoid code duplication?
Yes, this needs to happen. As I mentioned, this PR is very much a PoC to gather feedback and I will refine it. But I wanted to first explore if this is a good way to approach the problem.
It used to work because we didn't have any support for loading Flux Control LoRAs. Some relevant pieces of vital code:
diffusers/src/diffusers/loaders/lora_pipeline.py Line 1941 in be62c85
diffusers/src/diffusers/loaders/lora_pipeline.py Line 2055 in be62c85
Okay, so Flux control LoRA + 4bit bnb was never possible. From the initial description, I got the wrong impression that this is a full regression.
We fully agree there. What I meant is that I misunderstood your original message to mean that Flux control LoRA + 4bit bnb used to work and was trying to understand why it breaks now. This is what I meant by "full regression".
@BenjaminBossan That's a good point. The
Thanks, @BenjaminBossan. I will work on the suggestions to make it ready for another review pass.
Thanks @matthewdouglas! Will try to see if we can use it here.
Just confirmed that this works:

from diffusers import FluxTransformer2DModel

model = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/flux.1-dev-nf4-pkg", subfolder="transformer"
)
print(model.context_embedder.weight.quant_state.shape)
# torch.Size([3072, 4096])
@BenjaminBossan I have addressed the rest of the feedback. PTAL. @DN6 ready for your review, too. Just as an FYI, I have run all the LoRA-related integration tests for Flux and they all pass.
base_weight_param_name: str = None,
) -> "torch.Size":
    def _get_weight_shape(weight: torch.Tensor):
        return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
8bit params preserve the original shape?
They do.
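For illustration, a minimal usage sketch of a helper along these lines against the 4bit NF4 checkpoint confirmed earlier in this thread; the helper body mirrors the diff above, while the surrounding usage is an assumed example rather than the PR code.

# Sketch: resolving the original 2D shape of a 4bit bnb weight.
# The checkpoint and the quant_state access come from the confirmation earlier in this thread.
import torch
from diffusers import FluxTransformer2DModel

def _get_weight_shape(weight: torch.Tensor) -> torch.Size:
    # Params4bit stores a packed, flattened buffer; the original shape lives on quant_state.
    # 8bit and unquantized params keep their original shape, so .shape works for them.
    return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape

model = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/flux.1-dev-nf4-pkg", subfolder="transformer"
)
print(_get_weight_shape(model.context_embedder.weight))  # torch.Size([3072, 4096])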
module_path = (
    base_weight_param_name.rsplit(".weight", 1)[0]
    if base_weight_param_name.endswith(".weight")
    else base_weight_param_name
In what case would this param name not end with weight? The subsequent call to get_weight_shape assumes that there is a weight attribute.
True. Let me modify that. Thanks for catching that.
@DN6 applied the changes and re-ran all the required tests; they passed.
Thanks for fixing this, Sayak.
…10578)
* feat: support loading loras into 4bit quantized models.
* updates
* update
* remove weight check.
What does this PR do?
We broke Flux LoRA (not Control LoRA) loading for 4bit BnB Flux in 0.32.0 when adding support for Flux Control LoRAs (yes, Control LoRAs only apply to Flux).
To reproduce:
Code
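The collapsed snippet is not reproduced here; below is a hypothetical sketch of what a reproduction could look like, using the NF4 test checkpoint referenced elsewhere in this thread (the LoRA repo id is a placeholder, not a real repository).

# Hypothetical reproduction sketch; the actual collapsed snippet is not shown.
# The NF4 checkpoint is the one used in this thread; the LoRA repo id is a placeholder.
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/flux.1-dev-nf4-pkg", subfolder="transformer"
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
)
# Loading a regular (non-Control) Flux LoRA fails on v0.32.0 and main but works on v0.31.0.
pipe.load_lora_weights("placeholder-user/flux-lora")  # hypothetical LoRA repo id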
The above code will run with the v0.31.0-release branch but will fail with v0.32.0-release as well as main. This went uncaught because we don't test for it.

This PR attempts to partially fix the problem so that we can at least resort to a behavior similar to what was happening in the v0.31.0-release branch. I want to use this PR to refine how we're doing that: ship this PR first and tackle the TODOs in a follow-up. Once this PR is done, we might have to do a patch release.

Related issue: #10550