Skip to content

Commit

Permalink
allow loading of sd models from safetensors without online lookups us…
Browse files Browse the repository at this point in the history
…ing local config files (#5019)

finish config_files implementation
  • Loading branch information
vladmandic authored and patrickvonplaten committed Sep 14, 2023
1 parent 0c2f1cc commit 7512fc4
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 7 deletions.
1 change: 1 addition & 0 deletions scripts/convert_original_stable_diffusion_to_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@
pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path_or_dict=args.checkpoint_path,
original_config_file=args.original_config_file,
config_files=args.config_files,
image_size=args.image_size,
prediction_type=args.prediction_type,
model_type=args.pipeline_type,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -2098,6 +2098,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt

original_config_file = kwargs.pop("original_config_file", None)
config_files = kwargs.pop("config_files", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
Expand Down Expand Up @@ -2215,6 +2216,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
vae=vae,
tokenizer=tokenizer,
original_config_file=original_config_file,
config_files=config_files,
)

if torch_dtype is not None:
Expand Down
26 changes: 19 additions & 7 deletions src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,25 +1256,37 @@ def download_from_original_stable_diffusion_ckpt(
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
config_url = None

# model_type = "v1"
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
if config_files is not None and "v1" in config_files:
original_config_file = config_files["v1"]
else:
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"

if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
# model_type = "v2"
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"

if config_files is not None and "v2" in config_files:
original_config_file = config_files["v2"]
else:
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
if global_step == 110000:
# v2.1 needs to upcast attention
upcast_attention = True
elif key_name_sd_xl_base in checkpoint:
# only base xl has two text embedders
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
if config_files is not None and "xl" in config_files:
original_config_file = config_files["xl"]
else:
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
elif key_name_sd_xl_refiner in checkpoint:
# only refiner xl has embedder and one text embedders
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"

original_config_file = BytesIO(requests.get(config_url).content)
if config_files is not None and "xl_refiner" in config_files:
original_config_file = config_files["xl_refiner"]
else:
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
if config_url is not None:
original_config_file = BytesIO(requests.get(config_url).content)

original_config = OmegaConf.load(original_config_file)

Expand Down

0 comments on commit 7512fc4

Please sign in to comment.