
Commit

remove weight check.
sayakpaul committed Jan 15, 2025

1 parent 8b13c1e commit c92758f
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions src/diffusers/loaders/lora_pipeline.py
@@ -1983,7 +1983,8 @@ def _maybe_expand_transformer_param_shape_or_error_(
     out_features = state_dict[lora_B_weight_name].shape[0]

     # Model maybe loaded with different quantization schemes which may flatten the params.
-    # `bitsandbytes`, for example, flatten the weights when using 4bit.
+    # `bitsandbytes`, for example, flatten the weights when using 4bit. 8bit bnb models
+    # preserve weight shape.
     module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)

     # This means there's no need for an expansion in the params, so we simply skip.
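For context on the comment changed above: `bitsandbytes` stores a 4-bit quantized weight as a flattened tensor, so the logical 2-D shape has to be recovered from the quantization state rather than from `.weight.shape`, whereas 8-bit bnb modules keep the original shape. Below is a minimal sketch of that recovery, not the diffusers `_calculate_module_shape` helper itself; the function name `get_logical_weight_shape` is hypothetical, and it assumes bnb's `Params4bit` exposes the pre-quantization shape via `quant_state.shape`.

```python
import torch


def get_logical_weight_shape(weight: torch.Tensor) -> torch.Size:
    # bnb 4-bit params (`Params4bit`) are stored flattened, e.g. (n, 1); the
    # pre-quantization shape is kept on `quant_state`. 8-bit params and plain
    # tensors keep their original 2-D shape, so `.shape` is already correct.
    quant_state = getattr(weight, "quant_state", None)
    if quant_state is not None and getattr(quant_state, "shape", None) is not None:
        return quant_state.shape
    return weight.shape
```

Comparing this logical shape against the LoRA's `in_features`/`out_features` is what lets the loader decide whether a parameter expansion is actually needed.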
@@ -2120,11 +2121,11 @@ def _get_weight_shape(weight: torch.Tensor):
     if base_module is not None:
         return _get_weight_shape(base_module.weight)
     elif base_weight_param_name is not None:
-        module_path = (
-            base_weight_param_name.rsplit(".weight", 1)[0]
-            if base_weight_param_name.endswith(".weight")
-            else base_weight_param_name
-        )
+        if not base_weight_param_name.endswith(".weight"):
+            raise ValueError(
+                f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}."
+            )
+        module_path = base_weight_param_name.rsplit(".weight", 1)[0]
         submodule = get_submodule_by_name(model, module_path)
         return _get_weight_shape(submodule.weight)

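The second hunk tightens how the module path is derived from a state-dict key: instead of silently falling back to the raw name when it does not end in ".weight", the helper now raises. Below is a minimal sketch of the same path handling, using PyTorch's standard `nn.Module.get_submodule` in place of the diffusers helper `get_submodule_by_name`; the function name `resolve_base_module` is hypothetical.

```python
import torch.nn as nn


def resolve_base_module(model: nn.Module, base_weight_param_name: str) -> nn.Module:
    """Map a parameter name like 'transformer_blocks.0.attn.to_q.weight' to its owning module."""
    if not base_weight_param_name.endswith(".weight"):
        # Mirrors the check added in this commit: any other suffix is ambiguous,
        # so fail loudly instead of guessing a module path.
        raise ValueError(
            f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}."
        )
    module_path = base_weight_param_name.rsplit(".weight", 1)[0]
    return model.get_submodule(module_path)
```

For example, `resolve_base_module(transformer, "transformer_blocks.0.attn.to_q.weight")` returns the `to_q` module, whose weight shape can then be read as in the first hunk.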