
[LoRA] feat: support loading loras into 4bit quantized Flux models. #10578

Merged
merged 6 commits from 4bit-lora-loading into main on Jan 15, 2025

Conversation

sayakpaul
Member

@sayakpaul sayakpaul commented Jan 14, 2025

What does this PR do?

We broke Flux LoRA (not Control LoRA) loading for 4-bit BnB-quantized Flux models in v0.32.0 when we added support for Flux Control LoRAs (yes, this only applies to Flux).

To reproduce:

Code
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, FluxTransformer2DModel, FluxPipeline
from huggingface_hub import hf_hub_download


transformer_4bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=DiffusersBitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
    ),
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer_4bit,
    torch_dtype=torch.bfloat16,
).to("cuda")

pipe.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), 
    adapter_name="hyper-sd"
)
pipe.set_adapters("hyper-sd", adapter_weights=0.125)

prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."

image = pipe(
    prompt=prompt,
    height=1024,
    width=1024,
    max_sequence_length=512,
    num_inference_steps=8,
    guidance_scale=50,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("out.jpg")

The code above runs on the v0.31.0-release branch but fails on v0.32.0-release as well as on main. This went uncaught because we don't have a test for it.

This PR attempts to partially fix the problem so that we at least restore behavior similar to what we had on the v0.31.0-release branch. I want to use this PR to refine how we're doing that, ship it first, and tackle the TODOs in a follow-up. Once this PR is done, we might have to do a patch release.

Related issue: #10550

@sayakpaul sayakpaul changed the title [WIP] [LoRA] feat: support loading loras into 4bit quantized models. [WIP] [LoRA] feat: support loading loras into 4bit quantized Flux models. Jan 14, 2025
Comment on lines 1989 to 1991
if quantization_config.load_in_4bit:
    expansion_shape = torch.Size(expansion_shape).numel()
    expansion_shape = ((expansion_shape + 1) // 2, 1)
Member Author

Only 4bit bnb models flatten.

Member

How about adding a comment along the lines of: "Handle 4bit bnb weights, which are flattened and compress 2 params into 1".

I'm not quite sure why we need (shape+1) // 2, maybe this could be added to the comment too.

Member Author

Yeah, this comes from bitsandbytes. Cc: @matthewdouglas

Member

This is for rounding, i.e., if expansion_shape is odd there will be an additional 8-bit element with just one 4-bit value packed into it.
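
For illustration, a minimal sketch of the packing arithmetic being described (the dimensions are illustrative, taken from the Flux context_embedder shape mentioned later in this thread):

import torch

# Logical (unquantized) weight shape of a hypothetical linear layer.
out_features, in_features = 3072, 4096
numel = torch.Size((out_features, in_features)).numel()  # 12582912

# bnb 4-bit storage packs two 4-bit values into one uint8 element and keeps the
# result as a flattened (n, 1) tensor; the "+ 1" rounds up when numel is odd, so
# the last byte then holds only a single 4-bit value.
packed_shape = ((numel + 1) // 2, 1)
print(packed_shape)  # (6291456, 1)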

@BenjaminBossan BenjaminBossan left a comment
Member

Thanks for fixing the issue with loading LoRA into Flux models that are quantized with 4 bit bnb.

The issue of the actual parameter shape is a common trap; it would be great if the original shape could be retrieved from bnb. I'm not sure if there is a method for that.

Sayak, could you quickly sketch why this used to work and no longer does? IIRC, there was no special handling for 4 bit bnb previously, was there?

module_weight_shape = module_weight.shape
expansion_shape = (out_features, in_features)
quantization_config = getattr(transformer, "quantization_config", None)
if quantization_config and quantization_config.quant_method == "bitsandbytes":
Member

Would it make sense to have a utility function to get the shape to avoid code duplication?

Member Author

Yes, this needs to happen. As I mentioned, this PR is very much a PoC to gather feedback, and I will refine it. But I wanted to first explore if this is a good way to approach the problem.

@sayakpaul
Member Author

Sayak, could you quickly sketch why this used to work and no longer does? IIRC, there was no special handling for 4 bit bnb previously, was there?

It used to work because we didn't have any support for loading Flux Control LoRAs. Some relevant pieces of code:

def _maybe_expand_transformer_param_shape_or_error_(

def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):

@BenjaminBossan
Member

It used to because we didn't have any support for loading Flux Control LoRA

Okay, so Flux control LoRA + 4bit bnb was never possible. From the initial description, I got the wrong impression that this is a full regression.

@sayakpaul
Member Author

Okay, so Flux control LoRA + 4bit bnb was never possible. From the initial description, I got the wrong impression that this is a full regression.

  1. Flux control LoRA + 4bit bnb was never possible.
  2. However, Flux LoRA + 4bit bnb was possible. So, I think Flux Control LoRA support introduced a regression at this point as we cannot do Flux LoRA + 4bit bnb anymore? What am I missing?

@BenjaminBossan
Member

  1. Flux control LoRA + 4bit bnb was never possible.
  2. However, Flux LoRA + 4bit bnb was possible. So, I think Flux Control LoRA support introduced a regression at this point as we cannot do Flux LoRA + 4bit bnb anymore? What am I missing?

We fully agree there. What I meant is that I misunderstood your original message to mean that Flux control LoRA + 4 bit bnb used to work and was trying to understand why it breaks now. This is what I meant by "full regression".

@matthewdouglas
Member

The issue of the actual parameter shape is a common trap, it would be great if the original shape could be retrieved from bnb, not sure if there is a method for that.

@BenjaminBossan That's a good point. The Params4bit and Linear4bit should have a quant_state attribute on them, which has a shape to indicate the original shape. I haven't used this much myself but it could be an option here.

@sayakpaul
Member Author

We fully agree there. What I meant is that I misunderstood your original message to mean that Flux control LoRA + 4 bit bnb used to work and was trying to understand why it breaks now. This is what I meant by "full regression".

Thanks, @BenjaminBossan. I will work on the suggestions to make it ready for another review pass.

That's a good point. The Params4bit and Linear4bit should have a quant_state attribute on them, which has a shape to indicate the original shape. I haven't used this much myself but it could be an option here.

Thanks @matthewdouglas! Will try to see if we can use it here.

@sayakpaul
Member Author

Just confirmed that this works:

from diffusers import FluxTransformer2DModel

model = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/flux.1-dev-nf4-pkg", subfolder="transformer"
)
print(model.context_embedder.weight.quant_state.shape)
# torch.Size([3072, 4096])

@sayakpaul sayakpaul marked this pull request as ready for review January 15, 2025 04:14
@sayakpaul sayakpaul changed the title [WIP] [LoRA] feat: support loading loras into 4bit quantized Flux models. [LoRA] feat: support loading loras into 4bit quantized Flux models. Jan 15, 2025
@sayakpaul sayakpaul requested a review from DN6 January 15, 2025 04:14
@sayakpaul
Member Author

@BenjaminBossan I have addressed the rest of the feedback. PTAL.

@DN6 ready for your review, too. Just as an FYI, I have run all the LoRA-related integration tests for Flux and they all pass.

    base_weight_param_name: str = None,
) -> "torch.Size":
    def _get_weight_shape(weight: torch.Tensor):
        return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
Collaborator

8bit params preserve the original shape?

Member Author

They do.
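
For reference, a quick sketch contrasting the two shapes on the 4-bit side, using the NF4 test checkpoint already referenced in this thread (the packed storage shape in the last comment is an expectation based on the packing arithmetic above, not a verified output):

from diffusers import FluxTransformer2DModel

model = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/flux.1-dev-nf4-pkg", subfolder="transformer"
)
weight = model.context_embedder.weight
print(type(weight).__name__)     # Params4bit
print(weight.quant_state.shape)  # original shape: torch.Size([3072, 4096])
print(weight.shape)              # flattened, packed storage shape, expected ~torch.Size([6291456, 1])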

module_path = (
    base_weight_param_name.rsplit(".weight", 1)[0]
    if base_weight_param_name.endswith(".weight")
    else base_weight_param_name
Collaborator

In what case would this param name not end with .weight, given that the subsequent call to _get_weight_shape assumes there is a weight attribute?

Member Author

True. Let me modify that. Thanks for catching.

@sayakpaul sayakpaul requested a review from DN6 January 15, 2025 06:51
@sayakpaul
Member Author

@DN6 applied the changes and re-ran all the required tests; they pass.

@sayakpaul sayakpaul merged commit 2432f80 into main Jan 15, 2025
15 checks passed
@sayakpaul sayakpaul deleted the 4bit-lora-loading branch January 15, 2025 07:10
@BenjaminBossan
Member

Thanks for fixing this, Sayak.

DN6 pushed a commit that referenced this pull request Jan 15, 2025
…10578)

* feat: support loading loras into 4bit quantized models.

* updates

* update

* remove weight check.