Skip to content

Commit

Permalink
fix IP-Adapter weights conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
catwell authored and rodSiry committed Feb 28, 2024
1 parent 715c553 commit d6c57bd
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions scripts/conversion/convert_diffusers_ip_adapter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import argparse
from pathlib import Path
from typing import Any

import torch

from refiners.fluxion.utils import load_tensors, save_to_safetensors
from refiners.fluxion.utils import save_to_safetensors
from refiners.foundationals.latent_diffusion import SD1IPAdapter, SD1UNet, SDXLIPAdapter, SDXLUNet

# Running:
Expand Down Expand Up @@ -66,13 +65,17 @@ def main() -> None:
if args.output_path is None:
args.output_path = f"{Path(args.source_path).stem}.safetensors"

weights: dict[str, Any] = load_tensors(args.source_path, device="cpu")
# Do not use `load_tensors`: first-level values are not tensors.
weights: dict[str, dict[str, torch.Tensor]] = torch.load(args.source_path, "cpu") # type: ignore
assert isinstance(weights, dict)
assert sorted(weights.keys()) == ["image_proj", "ip_adapter"]

fine_grained = "latents" in weights["image_proj"] # aka IP-Adapter plus
image_proj_weights = weights["image_proj"]
ip_adapter_weights = weights["ip_adapter"]

fine_grained = "latents" in image_proj_weights # aka IP-Adapter plus

match len(weights["ip_adapter"]):
match len(ip_adapter_weights):
case 32:
ip_adapter = SD1IPAdapter(target=SD1UNet(in_channels=4), fine_grained=fine_grained)
cross_attn_mapping = CROSS_ATTN_MAPPING["sd15"]
Expand All @@ -87,7 +90,6 @@ def main() -> None:

state_dict: dict[str, torch.Tensor] = {}

image_proj_weights = weights["image_proj"]
image_proj_state_dict: dict[str, torch.Tensor]

if fine_grained:
Expand Down Expand Up @@ -130,7 +132,6 @@ def main() -> None:
for k, v in image_proj_state_dict.items():
state_dict[f"image_proj.{k}"] = v

ip_adapter_weights: dict[str, torch.Tensor] = weights["ip_adapter"]
assert len(ip_adapter.sub_adapters) == len(ip_adapter_weights.keys()) // 2

for i, _ in enumerate(ip_adapter.sub_adapters):
Expand Down

0 comments on commit d6c57bd

Please sign in to comment.