updates
sayakpaul committed Jan 18, 2025
1 parent d978f18 commit b23f875
Showing 1 changed file with 75 additions and 0 deletions.
src/diffusers/loaders/lora_conversion_utils.py (75 additions, 0 deletions)
@@ -558,13 +558,88 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
        new_state_dict = {**ait_sd, **te_state_dict}
        return new_state_dict

    def _convert_mixture_state_dict_to_diffusers(state_dict):
        new_state_dict = {}

        def _convert(original_key, diffusers_key, state_dict, new_state_dict):
            down_key = f"{original_key}.lora_down.weight"
            down_weight = state_dict.pop(down_key)
            lora_rank = down_weight.shape[0]

            up_weight_key = f"{original_key}.lora_up.weight"
            up_weight = state_dict.pop(up_weight_key)

            alpha_key = f"{original_key}.alpha"
            alpha = state_dict.pop(alpha_key)

            # scale weight by alpha and dim
            scale = alpha / lora_rank
            # calculate scale_down and scale_up
            scale_down = scale
            scale_up = 1.0
            while scale_down * 2 < scale_up:
                scale_down *= 2
                scale_up /= 2
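            # The loop shifts powers of two from scale_up to scale_down, preserving
            # scale_down * scale_up == scale while bringing the two factors closer in
            # magnitude; e.g. alpha=4, rank=32 gives scale=0.125 and the loop
            # rebalances (0.125, 1.0) -> (0.25, 0.5).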
            down_weight = down_weight * scale_down
            up_weight = up_weight * scale_up

            diffusers_down_key = f"{diffusers_key}.lora_A.weight"
            new_state_dict[diffusers_down_key] = down_weight
            new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight

        all_unique_keys = {
            k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "") for k in state_dict
        }
        all_unique_keys = sorted(all_unique_keys)
        assert all("lora_transformer_" in k for k in all_unique_keys), f"{all_unique_keys=}"
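        # Each base key now stands for one (lora_down, lora_up, alpha) triple,
        # e.g. "lora_transformer_transformer_blocks_0_attn_to_q".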

        for k in all_unique_keys:
            if k.startswith("lora_transformer_single_transformer_blocks_"):
                i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0])
                diffusers_key = f"single_transformer_blocks.{i}"
            elif k.startswith("lora_transformer_transformer_blocks_"):
                i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0])
                diffusers_key = f"transformer_blocks.{i}"
            else:
                raise NotImplementedError

            if "attn_" in k:
                if "_to_out_0" in k:
                    diffusers_key += ".attn.to_out.0"
                elif "_to_add_out" in k:
                    diffusers_key += ".attn.to_add_out"
                elif any(qkv in k for qkv in ["to_q", "to_k", "to_v"]):
                    remaining = k.split("attn_")[-1]
                    diffusers_key += f".attn.{remaining}"
                elif any(add_qkv in k for add_qkv in ["add_q_proj", "add_k_proj", "add_v_proj"]):
                    remaining = k.split("attn_")[-1]
                    diffusers_key += f".attn.{remaining}"
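            # e.g. "lora_transformer_transformer_blocks_0_attn_to_q" ends up as
            # "transformer_blocks.0.attn.to_q".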

            # Debug aid: flags keys that fell through every attention branch above
            # and are converted under the bare block prefix.
            if diffusers_key == f"transformer_blocks.{i}":
                print(k, diffusers_key)
            _convert(k, diffusers_key, state_dict, new_state_dict)

        if len(state_dict) > 0:
            raise ValueError(
                f"Expected an empty state dict at this point but it still has these keys, which couldn't be parsed: {list(state_dict.keys())}."
            )

        new_state_dict = {f"transformer.{k}": v for k, v in new_state_dict.items()}
        return new_state_dict

    # This is weird.
    # https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=sharp_detailed_foot.safetensors
    # has both `peft` and non-peft state dict.
    has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
    if has_peft_state_dict:
        state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
        return state_dict

    # Another weird one.
    has_mixture = any(
        k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
    )
    if has_mixture:
        return _convert_mixture_state_dict_to_diffusers(state_dict)
    return _convert_sd_scripts_to_ai_toolkit(state_dict)
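As a quick sanity check (not part of the commit), the new branch can be exercised on a synthetic one-entry state dict. The key below is hypothetical but follows the patterns parsed above, and the sketch assumes _convert_mixture_state_dict_to_diffusers is reachable at module scope for testing; in the file it is a nested helper reached via the has_mixture dispatch.

import torch

rank, features = 4, 64
prefix = "lora_transformer_transformer_blocks_0_attn_to_q"
sd = {
    f"{prefix}.lora_down.weight": torch.randn(rank, features),
    f"{prefix}.lora_up.weight": torch.randn(features, rank),
    f"{prefix}.alpha": torch.tensor(2.0),  # scale = 2.0 / 4 = 0.5
}
converted = _convert_mixture_state_dict_to_diffusers(sd)
assert sorted(converted) == [
    "transformer.transformer_blocks.0.attn.to_q.lora_A.weight",
    "transformer.transformer_blocks.0.attn.to_q.lora_B.weight",
]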


