From 779c17b7840690741480d2d719eb440a547a7328 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 14 Jan 2025 17:25:31 +0530 Subject: [PATCH 1/4] feat: support loading loras into 4bit quantized models. --- src/diffusers/loaders/lora_pipeline.py | 25 ++++++++++++++++++++++--- tests/quantization/bnb/test_4bit.py | 24 ++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 7492ba028c81..65474930b9ea 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1982,9 +1982,19 @@ def _maybe_expand_transformer_param_shape_or_error_( out_features = state_dict[lora_B_weight_name].shape[0] # This means there's no need for an expansion in the params, so we simply skip. - if tuple(module_weight.shape) == (out_features, in_features): + module_weight_shape = module_weight.shape + expansion_shape = (out_features, in_features) + quantization_config = getattr(transformer, "quantization_config", None) + if quantization_config and quantization_config.quant_method == "bitsandbytes": + if quantization_config.load_in_4bit: + expansion_shape = torch.Size(expansion_shape).numel() + expansion_shape = ((expansion_shape + 1) // 2, 1) + + if tuple(module_weight_shape) == expansion_shape: continue + # TODO (sayakpaul): We still need to consider if the module we're expanding is + # quantized and handle it accordingly if that is the case. module_out_features, module_in_features = module_weight.shape debug_message = "" if in_features > module_in_features: @@ -2080,13 +2090,22 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict): base_weight_param = transformer_state_dict[base_param_name] lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"] - if base_weight_param.shape[1] > lora_A_param.shape[1]: + # TODO (sayakpaul): Handle the cases when we actually need to expand. + base_out_feature_shape = base_weight_param.shape[1] + lora_A_out_feature_shape = lora_A_param.shape[1] + quantization_config = getattr(transformer, "quantization_config", None) + if quantization_config and quantization_config.quant_method == "bitsandbytes": + if quantization_config.load_in_4bit: + lora_A_out_feature_shape = lora_A_param.shape.numel() + lora_A_out_feature_shape = ((lora_A_out_feature_shape + 1) // 2, 1)[1] + + if base_out_feature_shape > lora_A_out_feature_shape: shape = (lora_A_param.shape[0], base_weight_param.shape[1]) expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device) expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param) lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight expanded_module_names.add(k) - elif base_weight_param.shape[1] < lora_A_param.shape[1]: + elif lora_A_out_feature_shape < lora_A_out_feature_shape: raise NotImplementedError( f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new." 
) diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 1e631114f038..6a56a60f4f8e 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -20,6 +20,7 @@ import numpy as np import pytest import safetensors.torch +from huggingface_hub import hf_hub_download from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel from diffusers.utils import is_accelerate_version, logging @@ -32,6 +33,7 @@ numpy_cosine_similarity_distance, require_accelerate, require_bitsandbytes_version_greater, + require_peft_version_greater, require_torch, require_torch_gpu, require_transformers_version_greater, @@ -568,6 +570,28 @@ def test_quality(self): max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) self.assertTrue(max_diff < 1e-3) + @require_peft_version_greater("0.14.0") + def test_lora_loading_works(self): + self.pipeline_4bit.load_lora_weights( + hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd" + ) + self.pipeline_4bit.set_adapters("hyper-sd", adapter_weights=0.125) + + output = self.pipeline_4bit( + prompt=self.prompt, + height=256, + width=256, + max_sequence_length=64, + output_type="np", + num_inference_steps=8, + generator=torch.Generator().manual_seed(42), + ).images + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.5347, 0.5342, 0.5283, 0.5093, 0.4988, 0.5093, 0.5044, 0.5015, 0.4946]) + + max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) + self.assertTrue(max_diff < 1e-3) + @slow class BaseBnb4BitSerializationTests(Base4bitTests): From d3d8ef28e16fa4027d938b364c80005e2e029a7f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 15 Jan 2025 09:44:14 +0530 Subject: [PATCH 2/4] updates --- src/diffusers/loaders/lora_pipeline.py | 51 ++++++++++++++++---------- src/diffusers/utils/__init__.py | 2 +- src/diffusers/utils/loading_utils.py | 13 +++++++ tests/quantization/bnb/test_4bit.py | 4 +- 4 files changed, 47 insertions(+), 23 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 65474930b9ea..e40a8a4cc69b 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -21,6 +21,7 @@ from ..utils import ( USE_PEFT_BACKEND, deprecate, + get_submodule_by_name, is_peft_available, is_peft_version, is_torch_version, @@ -1981,16 +1982,12 @@ def _maybe_expand_transformer_param_shape_or_error_( in_features = state_dict[lora_A_weight_name].shape[1] out_features = state_dict[lora_B_weight_name].shape[0] + # Model maybe loaded with different quantization schemes which may flatten the params. + # `bitsandbytes`, for example, flatten the weights when using 4bit. + module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module) + # This means there's no need for an expansion in the params, so we simply skip. 
- module_weight_shape = module_weight.shape - expansion_shape = (out_features, in_features) - quantization_config = getattr(transformer, "quantization_config", None) - if quantization_config and quantization_config.quant_method == "bitsandbytes": - if quantization_config.load_in_4bit: - expansion_shape = torch.Size(expansion_shape).numel() - expansion_shape = ((expansion_shape + 1) // 2, 1) - - if tuple(module_weight_shape) == expansion_shape: + if tuple(module_weight_shape) == (out_features, in_features): continue # TODO (sayakpaul): We still need to consider if the module we're expanding is @@ -2090,22 +2087,16 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict): base_weight_param = transformer_state_dict[base_param_name] lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"] - # TODO (sayakpaul): Handle the cases when we actually need to expand. - base_out_feature_shape = base_weight_param.shape[1] - lora_A_out_feature_shape = lora_A_param.shape[1] - quantization_config = getattr(transformer, "quantization_config", None) - if quantization_config and quantization_config.quant_method == "bitsandbytes": - if quantization_config.load_in_4bit: - lora_A_out_feature_shape = lora_A_param.shape.numel() - lora_A_out_feature_shape = ((lora_A_out_feature_shape + 1) // 2, 1)[1] + # TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization. + base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name) - if base_out_feature_shape > lora_A_out_feature_shape: + if base_module_shape[1] > lora_A_param.shape[1]: shape = (lora_A_param.shape[0], base_weight_param.shape[1]) expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device) expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param) lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight expanded_module_names.add(k) - elif lora_A_out_feature_shape < lora_A_out_feature_shape: + elif base_module_shape[1] < lora_A_param.shape[1]: raise NotImplementedError( f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new." ) @@ -2117,6 +2108,28 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict): return lora_state_dict + @staticmethod + def _calculate_module_shape( + model: "torch.nn.Module", + base_module: "torch.nn.Linear" = None, + base_weight_param_name: str = None, + ) -> "torch.Size": + def _get_weight_shape(weight: torch.Tensor): + return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape + + if base_module is not None: + return _get_weight_shape(base_module.weight) + elif base_weight_param_name is not None: + module_path = ( + base_weight_param_name.rsplit(".weight", 1)[0] + if base_weight_param_name.endswith(".weight") + else base_weight_param_name + ) + submodule = get_submodule_by_name(model, module_path) + return _get_weight_shape(submodule.weight) + + raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.") + # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support. 
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 5a171d078ce3..0c0613f3c43e 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -101,7 +101,7 @@ is_xformers_available, requires_backends, ) -from .loading_utils import get_module_from_name, load_image, load_video +from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video from .logging import get_logger from .outputs import BaseOutput from .peft_utils import ( diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py index bac24fa23e63..02e9f3cd0757 100644 --- a/src/diffusers/utils/loading_utils.py +++ b/src/diffusers/utils/loading_utils.py @@ -148,3 +148,16 @@ def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]: module = new_module tensor_name = splits[-1] return module, tensor_name + + +def get_submodule_by_name(root_module, module_path: str): + current = root_module + parts = module_path.split(".") + for part in parts: + # If part is integer-like and the current module supports indexing, convert to int + if part.isdigit(): + idx = int(part) + current = current[idx] # e.g., for nn.ModuleList or nn.Sequential + else: + current = getattr(current, part) + return current diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 6a56a60f4f8e..a9b9ab753084 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -33,7 +33,6 @@ numpy_cosine_similarity_distance, require_accelerate, require_bitsandbytes_version_greater, - require_peft_version_greater, require_torch, require_torch_gpu, require_transformers_version_greater, @@ -570,8 +569,7 @@ def test_quality(self): max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) self.assertTrue(max_diff < 1e-3) - @require_peft_version_greater("0.14.0") - def test_lora_loading_works(self): + def test_lora_loading(self): self.pipeline_4bit.load_lora_weights( hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd" ) From 8b13c1e412206ed86678ee423343ca179d64f6e9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 15 Jan 2025 09:51:08 +0530 Subject: [PATCH 3/4] update --- src/diffusers/utils/loading_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py index 02e9f3cd0757..fd66aaa4da6e 100644 --- a/src/diffusers/utils/loading_utils.py +++ b/src/diffusers/utils/loading_utils.py @@ -154,7 +154,6 @@ def get_submodule_by_name(root_module, module_path: str): current = root_module parts = module_path.split(".") for part in parts: - # If part is integer-like and the current module supports indexing, convert to int if part.isdigit(): idx = int(part) current = current[idx] # e.g., for nn.ModuleList or nn.Sequential From c92758fb328865750cc8c383ccf849ba7738a00a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 15 Jan 2025 11:59:19 +0530 Subject: [PATCH 4/4] remove weight check. 
--- src/diffusers/loaders/lora_pipeline.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index e40a8a4cc69b..efefe5264daa 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1983,7 +1983,8 @@ def _maybe_expand_transformer_param_shape_or_error_( out_features = state_dict[lora_B_weight_name].shape[0] # Model maybe loaded with different quantization schemes which may flatten the params. - # `bitsandbytes`, for example, flatten the weights when using 4bit. + # `bitsandbytes`, for example, flatten the weights when using 4bit. 8bit bnb models + # preserve weight shape. module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module) # This means there's no need for an expansion in the params, so we simply skip. @@ -2120,11 +2121,11 @@ def _get_weight_shape(weight: torch.Tensor): if base_module is not None: return _get_weight_shape(base_module.weight) elif base_weight_param_name is not None: - module_path = ( - base_weight_param_name.rsplit(".weight", 1)[0] - if base_weight_param_name.endswith(".weight") - else base_weight_param_name - ) + if not base_weight_param_name.endswith(".weight"): + raise ValueError( + f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}." + ) + module_path = base_weight_param_name.rsplit(".weight", 1)[0] submodule = get_submodule_by_name(model, module_path) return _get_weight_shape(submodule.weight)
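
A minimal usage sketch of the feature these patches add follows, modeled on the new test_lora_loading test in tests/quantization/bnb/test_4bit.py: quantize the Flux transformer to 4-bit with bitsandbytes, build the pipeline around it, then load a Hyper-SD LoRA into the quantized model. The base checkpoint id ("black-forest-labs/FLUX.1-dev"), the NF4 quantization settings, the prompt, and the CPU-offload call are illustrative assumptions and are not taken from the test file.

import torch
from huggingface_hub import hf_hub_download

from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel

model_id = "black-forest-labs/FLUX.1-dev"  # assumed base checkpoint for the Hyper-FLUX.1-dev LoRA

# Quantize only the transformer to 4-bit (NF4) with bitsandbytes.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", quantization_config=quant_config, torch_dtype=torch.bfloat16
)
pipe = DiffusionPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # optional; keeps non-quantized components off the GPU until needed

# Load a LoRA into the 4-bit quantized transformer (the capability added by these patches).
pipe.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
)
pipe.set_adapters("hyper-sd", adapter_weights=0.125)

image = pipe(
    "a puppy playing in a pond, watercolor",  # illustrative prompt
    num_inference_steps=8,
    guidance_scale=3.5,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("hyper_flux_4bit_lora.png")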