From 2432f80ca37f882af733244df24b46f2d447fbcf Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Wed, 15 Jan 2025 12:40:40 +0530
Subject: [PATCH] [LoRA] feat: support loading loras into 4bit quantized Flux models. (#10578)

* feat: support loading loras into 4bit quantized models.

* updates

* update

* remove weight check.
---
 src/diffusers/loaders/lora_pipeline.py | 39 ++++++++++++++++++++++++--
 src/diffusers/utils/__init__.py        |  2 +-
 src/diffusers/utils/loading_utils.py   | 12 ++++++++
 tests/quantization/bnb/test_4bit.py    | 22 +++++++++++++++
 4 files changed, 71 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 7492ba028c81..efefe5264daa 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -21,6 +21,7 @@
 from ..utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    get_submodule_by_name,
     is_peft_available,
     is_peft_version,
     is_torch_version,
@@ -1981,10 +1982,17 @@ def _maybe_expand_transformer_param_shape_or_error_(
             in_features = state_dict[lora_A_weight_name].shape[1]
             out_features = state_dict[lora_B_weight_name].shape[0]
 
+            # Model maybe loaded with different quantization schemes which may flatten the params.
+            # `bitsandbytes`, for example, flatten the weights when using 4bit. 8bit bnb models
+            # preserve weight shape.
+            module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)
+
             # This means there's no need for an expansion in the params, so we simply skip.
-            if tuple(module_weight.shape) == (out_features, in_features):
+            if tuple(module_weight_shape) == (out_features, in_features):
                 continue
 
+            # TODO (sayakpaul): We still need to consider if the module we're expanding is
+            # quantized and handle it accordingly if that is the case.
             module_out_features, module_in_features = module_weight.shape
             debug_message = ""
             if in_features > module_in_features:
@@ -2080,13 +2088,16 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
                 base_weight_param = transformer_state_dict[base_param_name]
                 lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
 
-                if base_weight_param.shape[1] > lora_A_param.shape[1]:
+                # TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization.
+                base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name)
+
+                if base_module_shape[1] > lora_A_param.shape[1]:
                     shape = (lora_A_param.shape[0], base_weight_param.shape[1])
                     expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
                     expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
                     lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
                     expanded_module_names.add(k)
-                elif base_weight_param.shape[1] < lora_A_param.shape[1]:
+                elif base_module_shape[1] < lora_A_param.shape[1]:
                     raise NotImplementedError(
                         f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
                     )
@@ -2098,6 +2109,28 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
 
         return lora_state_dict
 
+    @staticmethod
+    def _calculate_module_shape(
+        model: "torch.nn.Module",
+        base_module: "torch.nn.Linear" = None,
+        base_weight_param_name: str = None,
+    ) -> "torch.Size":
+        def _get_weight_shape(weight: torch.Tensor):
+            return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
+
+        if base_module is not None:
+            return _get_weight_shape(base_module.weight)
+        elif base_weight_param_name is not None:
+            if not base_weight_param_name.endswith(".weight"):
+                raise ValueError(
+                    f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}."
+                )
+            module_path = base_weight_param_name.rsplit(".weight", 1)[0]
+            submodule = get_submodule_by_name(model, module_path)
+            return _get_weight_shape(submodule.weight)
+
+        raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
+
 
 # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
 # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 5a171d078ce3..0c0613f3c43e 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -101,7 +101,7 @@
     is_xformers_available,
     requires_backends,
 )
-from .loading_utils import get_module_from_name, load_image, load_video
+from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video
 from .logging import get_logger
 from .outputs import BaseOutput
 from .peft_utils import (
diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py
index bac24fa23e63..fd66aaa4da6e 100644
--- a/src/diffusers/utils/loading_utils.py
+++ b/src/diffusers/utils/loading_utils.py
@@ -148,3 +148,15 @@ def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]:
         module = new_module
         tensor_name = splits[-1]
     return module, tensor_name
+
+
+def get_submodule_by_name(root_module, module_path: str):
+    current = root_module
+    parts = module_path.split(".")
+    for part in parts:
+        if part.isdigit():
+            idx = int(part)
+            current = current[idx]  # e.g., for nn.ModuleList or nn.Sequential
+        else:
+            current = getattr(current, part)
+    return current
diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index 1e631114f038..a9b9ab753084 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -20,6 +20,7 @@
 import numpy as np
 import pytest
 import safetensors.torch
+from huggingface_hub import hf_hub_download
 
 from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
 from diffusers.utils import is_accelerate_version, logging
@@ -568,6 +569,27 @@ def test_quality(self):
         max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
         self.assertTrue(max_diff < 1e-3)
 
+    def test_lora_loading(self):
+        self.pipeline_4bit.load_lora_weights(
+            hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
+        )
+        self.pipeline_4bit.set_adapters("hyper-sd", adapter_weights=0.125)
+
+        output = self.pipeline_4bit(
+            prompt=self.prompt,
+            height=256,
+            width=256,
+            max_sequence_length=64,
+            output_type="np",
+            num_inference_steps=8,
+            generator=torch.Generator().manual_seed(42),
+        ).images
+        out_slice = output[0, -3:, -3:, -1].flatten()
+        expected_slice = np.array([0.5347, 0.5342, 0.5283, 0.5093, 0.4988, 0.5093, 0.5044, 0.5015, 0.4946])
+
+        max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+        self.assertTrue(max_diff < 1e-3)
+
 
 @slow
 class BaseBnb4BitSerializationTests(Base4bitTests):
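
Usage illustration (not part of the patch): the sketch below loads a LoRA on top of a 4-bit bitsandbytes-quantized Flux transformer, mirroring the new test_lora_loading test. The Hyper-SD LoRA file, adapter name, and 0.125 adapter weight come from that test; the black-forest-labs/FLUX.1-dev checkpoint, NF4 settings, prompt, and enable_model_cpu_offload() call are assumptions about a typical setup rather than anything this commit prescribes.

# Sketch: load a LoRA into a 4-bit (NF4) quantized Flux pipeline.
# Assumes diffusers with this patch, plus bitsandbytes and peft installed.
import torch
from huggingface_hub import hf_hub_download

from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel

model_id = "black-forest-labs/FLUX.1-dev"  # assumed base checkpoint

# Quantize only the transformer to 4-bit NF4. bitsandbytes flattens 4-bit weights,
# which is why the LoRA loader now reads shapes from `weight.quant_state` instead
# of the (flattened) parameter itself.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", quantization_config=nf4_config, torch_dtype=torch.bfloat16
)

pipe = DiffusionPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # optional; keeps peak VRAM low

# Same LoRA and adapter weight as the new test_lora_loading test.
pipe.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
)
pipe.set_adapters("hyper-sd", adapter_weights=0.125)

image = pipe(
    "a puppy wearing a party hat",  # placeholder prompt
    num_inference_steps=8,
    guidance_scale=3.5,
    generator=torch.manual_seed(42),
).images[0]
image.save("flux_4bit_lora.png")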