
[LoRA] Quanto Flux LoRA can't load #10512

Open
Mino1289 opened this issue Jan 9, 2025 · 24 comments

Labels
bug Something isn't working

Comments

@Mino1289 commented Jan 9, 2025

Describe the bug

Cannot load LoRAs into quanto-quantized Flux.

import torch 
from diffusers import FluxTransformer2DModel, FluxPipeline
from huggingface_hub import hf_hub_download
from optimum.quanto import qfloat8, quantize, freeze
from transformers import T5EncoderModel

bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16

transformer = FluxTransformer2DModel.from_single_file("https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8.safetensors", torch_dtype=dtype)
quantize(transformer, weights=qfloat8)
freeze(transformer)

text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2

pipe.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
)

Logs

ERROR:
Traceback (most recent call last):
  File "/home/user/genAI/test.py", line 56, in <module>
    pipe.load_lora_weights(
  File "/home/user/miniconda3/lib/python3.12/site-packages/diffusers/loaders/lora_pipeline.py", line 1867, in load_lora_weights
    transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/miniconda3/lib/python3.12/site-packages/diffusers/loaders/lora_pipeline.py", line 2490, in _maybe_expand_lora_state_dict
    base_weight_param = transformer_state_dict[base_param_name]
                        ~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
KeyError: 'single_transformer_blocks.0.attn.to_k.weight'

System Info

Python 3.12
diffusers 0.32.0 (I also tested 0.32.1 and an install from git)

Who can help?

@sayakpaul

Mino1289 added the bug label on Jan 9, 2025
@sayakpaul (Member)

Can you try with diffusers installation from main?

pip uninstall diffusers -y
pip install git+https://github.com/huggingface/diffusers

@lhjlhj11

Can you try with diffusers installation from main?

pip uninstall diffusers -y
pip install git+https://github.com/huggingface/diffusers

The problem still exists after this operation.

@sayakpaul (Member)

Do you have a minimal reproducible snippet? The provided one isn't minimal and self-contained. I keep asking for that because we have an integration test for Kohya LoRAs here:

def test_flux_kohya(self):

It was run yesterday, too, and it worked fine.

@tyyff commented Jan 10, 2025

Do you have a minimal reproducible snippet? The provided one isn't minimal and self-contained. I keep asking for that because we have an integration test for Kohya LoRAs here:

def test_flux_kohya(self):

It was run yesterday, too, and it worked fine.

This issue only occurs when loading LoRA after quantizing the FLUX transformer using optimum.quanto. If the model is not quantized, LoRA can be loaded normally. In version 0.31 of diffusers, LoRA could be loaded successfully even after quantization.
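For contrast, here is a minimal sketch of the unquantized path (untested here, just to illustrate the claim; it uses the same Hyper-SD LoRA as the original report):

import torch
from diffusers import FluxPipeline
from huggingface_hub import hf_hub_download

# With no quantize()/freeze() applied, the transformer state dict still
# exposes plain `*.weight` keys, so the lookup inside
# _maybe_expand_lora_state_dict succeeds and the LoRA loads normally.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"),
    adapter_name="hyper-sd",
)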

@lhjlhj11

Do you have a minimal reproducible snippet? The provided one isn't minimal and self-contained. I keep asking for that because we have an integration test for Kohya LoRAs here:

def test_flux_kohya(self):

It was run yesterday, too, and it worked fine.

transformer = FluxTransformer2DModel.from_single_file("https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8.safetensors", torch_dtype=dtype)
quantize(transformer, weights=qfloat8)
freeze(transformer)
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)
pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2
# this is an 8-step lora; `load_file` is from safetensors.torch, and
# `model_root`/`config`/`device` are settings from my local test harness
pipe.load_lora_weights(load_file(os.path.join(model_root, config["8steps_lora"]), device=device), adapter_name="8steps")
pipe.set_adapters(["8steps"], adapter_weights=[0.125])

@sayakpaul (Member)

@tyyff if you could help me with a minimally reproducible snippet that would be great, ideally with a supported quantization backend like bitsandbytes.

@Mino1289 (Author) commented Jan 10, 2025

I used the script and quantization method from https://gist.github.com/sayakpaul/b664605caf0aa3bf8585ab109dd5ac9c (the script by AmericanPresidentJimmyCarter).

@lhjlhj11

@tyyff if you could help me with a minimally reproducible snippet that would be great, ideally with a supported quantization backend like bitsandbytes.

Can you solve the problem with the flux-fp8 version? Thanks!

@lhjlhj11

@tyyff if you could help me with a minimally reproducible snippet that would be great, ideally with a supported quantization backend like bitsandbytes.

Or can diffusers below 0.32.0 support Flux Redux?

@Yakonrus commented Jan 12, 2025

@tyyff if you could help me with a minimally reproducible snippet that would be great, ideally with a supported quantization backend like bitsandbytes.

It's just a combination of two examples from the article on using Flux:

import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, FluxControlPipeline, FluxTransformer2DModel
from transformers import T5EncoderModel
from diffusers.utils import load_image
from image_gen_aux import DepthPreprocessor
from huggingface_hub import hf_hub_download

text_encoder_8bit = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="text_encoder_2",
    quantization_config=DiffusersBitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
)

transformer_8bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=DiffusersBitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
)

control_pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder=text_encoder_8bit,
    transformer=transformer_8bit,
    torch_dtype=torch.float16,
    device_map="balanced",
)

control_pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora", adapter_name="depth")
control_pipe.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
)
control_pipe.set_adapters(["depth", "hyper-sd"], adapter_weights=[0.85, 0.125])
control_pipe.enable_model_cpu_offload()

prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")

processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
control_image = processor(control_image)[0].convert("RGB")

image = control_pipe(
    prompt=prompt,
    control_image=control_image,
    height=1024,
    width=1024,
    num_inference_steps=8,
    guidance_scale=10.0,
    generator=torch.Generator().manual_seed(42),
).images[0]

image.save("out.jpg")

The error:
    control_pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora", adapter_name="depth")
  File "/usr/local/lib/python3.10/dist-packages/diffusers/loaders/lora_pipeline.py", line 1856, in load_lora_weights
    has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
  File "/usr/local/lib/python3.10/dist-packages/diffusers/loaders/lora_pipeline.py", line 2359, in _maybe_expand_transformer_param_shape_or_error_
    expanded_module = torch.nn.Linear(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py", line 99, in __init__
    self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parameter.py", line 40, in __new__
    return torch.Tensor._make_subclass(cls, data, requires_grad)
RuntimeError: Only Tensors of floating point and complex dtype can require gradients

@tyyff commented Jan 13, 2025

@tyyff if you could help me with a minimally reproducible snippet that would be great, ideally with a supported quantization backend like bitsandbytes.

import torch
from diffusers import FluxTransformer2DModel, FluxPipeline
from transformers import T5EncoderModel
from optimum.quanto import freeze, qfloat8, quantize
import random

bfl_repo = "black-forest-labs/FLUX.1-schnell"
dtype = torch.bfloat16

transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder="transformer", torch_dtype=dtype)
quantize(transformer, weights=qfloat8)
freeze(transformer)

text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2
pipe.to(torch.device("cuda"))
pipe.load_lora_weights("Shakker-Labs/FLUX.1-dev-LoRA-Logo-Design", weight_name="FLUX-dev-lora-Logo-Design.safetensors")
seed = random.randint(1, 1 << 32)
image = pipe(
    prompt="logo,Minimalist,A bunch of grapes and a wine glass",
    guidance_scale=1.,
    output_type="pil",
    num_inference_steps=8,
    generator=torch.Generator("cpu").manual_seed(seed)
).images[0]

image.save("test.png")
  File "/nanjgrowth-train-public/root/nanjgrowth-public-1/tangweiye/training/flux_finetuning/test_utils/test_lora_snippet.py", l
ine 23, in <module>                                                                                                             
    pipe.load_lora_weights("Shakker-Labs/FLUX.1-dev-LoRA-Logo-Design", weight_name="FLUX-dev-lora-Logo-Design.safetensors")     
  File "/root/micromamba/envs/twy_diffusers/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py", line 1866, in load
_lora_weights                                                                                                                   
    transformer_lora_state_dict = self._maybe_expand_lora_state_dict(                                                           
  File "/root/micromamba/envs/twy_diffusers/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py", line 2415, in _may
be_expand_lora_state_dict                                                                                                       
    base_weight_param = transformer_state_dict[base_param_name]                                                                 
KeyError: 'single_transformer_blocks.0.attn.to_k.weight' 

This is the pip requirements.txt:

absl-py==2.1.0
accelerate==1.2.1
annotated-types==0.7.0
bitsandbytes==0.45.0
certifi==2024.12.14
charset-normalizer==3.4.1
deepspeed==0.15.4
diffusers==0.32.1
einops==0.8.0
filelock==3.13.1
fsspec==2024.2.0
grpcio==1.68.1
hjson==3.1.0
huggingface-hub==0.27.0
idna==3.10
importlib_metadata==8.5.0
Jinja2==3.1.3
Markdown==3.7
MarkupSafe==2.1.5
mpmath==1.3.0
msgpack==1.1.0
networkx==3.2.1
ninja==1.11.1.3
numpy==1.26.3
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-ml-py==12.560.30
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.1.105
nvidia-nvtx-cu12==12.1.105
optimum-quanto==0.2.6
packaging==24.2
peft==0.14.0
pillow==10.2.0
protobuf==5.29.2
psutil==6.1.1
py-cpuinfo==9.0.0
pydantic==2.10.4
pydantic_core==2.27.2
PyYAML==6.0.2
regex==2024.11.6
requests==2.32.3
safetensors==0.4.5
sentencepiece==0.2.0
six==1.17.0
sympy==1.13.1
tensorboard==2.18.0
tensorboard-data-server==0.7.2
tokenizers==0.21.0
torch==2.4.1+cu121
torchaudio==2.4.1+cu121
torchvision==0.19.1+cu121
tqdm==4.67.1
transformers==4.47.1
triton==3.0.0
typing_extensions==4.12.2
urllib3==2.3.0
Werkzeug==3.1.3
zipp==3.21.0

@sayakpaul (Member)

Tracking here: #10550.

@sayakpaul (Member)

I tested with v0.31.0-release and it fails with:

Error:

Traceback (most recent call last):
  File "/home/sayak/diffusers/check_fp8.py", line 22, in <module>
    pipe.load_lora_weights(
  File "/home/sayak/diffusers/src/diffusers/loaders/lora_pipeline.py", line 1846, in load_lora_weights
    self.load_lora_into_transformer(
  File "/home/sayak/diffusers/src/diffusers/loaders/lora_pipeline.py", line 1949, in load_lora_into_transformer
    incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
  File "/home/sayak/peft/src/peft/utils/save_and_load.py", line 445, in set_peft_model_state_dict
    load_result = model.load_state_dict(peft_model_state_dict, strict=False, assign=True)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2564, in load_state_dict
    load(self, state_dict)
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2552, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2552, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2552, in load
    load(child, child_state_dict, child_prefix)  # noqa: F821
  [Previous line repeated 1 more time]
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2535, in load
    module._load_from_state_dict(
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/optimum/quanto/nn/qmodule.py", line 160, in _load_from_state_dict
    deserialized_weight = WeightQBytesTensor.load_from_state_dict(
  File "/home/sayak/.pyenv/versions/3.10.12/envs/diffusers/lib/python3.10/site-packages/optimum/quanto/tensor/weights/qbytes.py", line 77, in load_from_state_dict
    inner_tensors_dict[name] = state_dict.pop(prefix + name)
KeyError: 'time_text_embed.timestep_embedder.linear_1.base_layer.weight._data'

Tracking it here:
#10550 (comment)

@Mino1289 (Author)

As I read the issue and PR you linked, the problem I'm facing is most likely due to quanto not being supported by peft.
Using BitsAndBytesConfig should bypass the problem, right?
I'll try it later.

@sayakpaul (Member)

Yes, you're right. 4-bit support is being added in #10578.

However, I just edited your issue title a bit to reflect that Quanto support needs to be added. Hope that is okay with you.

sayakpaul changed the title from "Flux LoRA can't load anymore after 0.32.0" to "[LoRA] Quanto Flux LoRA can't load" on Jan 15, 2025
@Mino1289 (Author)

And in 8-bit? My issue is about 8-bit qfloat8.

The title change is fine.
Thanks for the quick support!

@sayakpaul (Member)

Both 4-bit and 8-bit bitsandbytes models should be able to load LoRAs.

For 8-bit, make sure you install peft from source. If you face problems, please open a new issue.
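For reference, a minimal sketch of the suggested bitsandbytes route (assuming diffusers and peft from source; the 4-bit config and the Hyper-SD LoRA here are illustrative, not a verified reproduction):

import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers import FluxPipeline, FluxTransformer2DModel
from huggingface_hub import hf_hub_download

# Quantize the transformer with the supported bitsandbytes backend instead
# of optimum.quanto, then load the LoRA as usual.
transformer_4bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=DiffusersBitsAndBytesConfig(load_in_4bit=True),
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer_4bit,
    torch_dtype=torch.bfloat16,
)
pipe.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"),
    adapter_name="hyper-sd",
)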

@Mino1289 (Author)

And diffusers 0.32.1 or from source?

@sayakpaul (Member)

Source.

@Amitg1 commented Jan 23, 2025

Hi, we are getting this error only with diffusers > 0.31.
With diffusers == 0.31 we can load the LoRA and perform inference.
This setup works:
diffusers 0.31.0
transformers 4.48.1
torch 2.5.1
torchaudio 2.5.1
torchvision 0.20.1

The same setup with diffusers > 0.31 fails with KeyError: 'single_transformer_blocks.0.attn.to_k.weight'

@sayakpaul

@sayakpaul (Member)

Can you provide a reproducible snippet?

@nitinmukesh

@sayakpaul

I think he posted it here:
#10512 (comment)

@Amitg1 commented Jan 23, 2025

Not the same person :)
Our example is pretty much the same, except we are not executing quantize and freeze. We use the following class to load the quantized transformer:

    class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
        base_class = FluxTransformer2DModel

and the following to load the quantized encoder:

    class QuantizedModelForTextEncoding(QuantizedTransformersModel):
        auto_class = AutoModelForTextEncoding

The repo we are using is Disty0/FLUX.1-dev-qint8. A sketch of the loading code follows the version list below.

diffusers 0.31.0
transformers 4.48.1
torch 2.5.1
torchaudio 2.5.1
torchvision 0.20.1
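Roughly, the loading looks like this (a sketch, not our exact script; the snapshot_download step and the subfolder paths are assumptions about how the Disty0/FLUX.1-dev-qint8 repo is laid out):

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
from huggingface_hub import snapshot_download
from optimum.quanto import QuantizedDiffusersModel, QuantizedTransformersModel
from transformers import AutoModelForTextEncoding

class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    base_class = FluxTransformer2DModel

class QuantizedModelForTextEncoding(QuantizedTransformersModel):
    auto_class = AutoModelForTextEncoding

# Download the prequantized repo and load each quantized component from its
# subfolder; no quantize()/freeze() calls are made on our side.
local_dir = snapshot_download("Disty0/FLUX.1-dev-qint8")
transformer = QuantizedFluxTransformer2DModel.from_pretrained(f"{local_dir}/transformer")
text_encoder_2 = QuantizedModelForTextEncoding.from_pretrained(f"{local_dir}/text_encoder_2")

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=None,
    text_encoder_2=None,
    torch_dtype=torch.bfloat16,
)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2
# pipe.load_lora_weights(...) then raises the KeyError above on diffusers > 0.31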

@sayakpaul (Member)

#10512 (comment)
