first draft model loading refactor
SunMarc committed Jan 17, 2025
1 parent aeac0a0 commit e54c540
Showing 16 changed files with 625 additions and 647 deletions.
4 changes: 2 additions & 2 deletions scripts/convert_sana_to_diffusers.py
@@ -18,7 +18,7 @@
SanaPipeline,
SanaTransformer2DModel,
)
-from diffusers.models.modeling_utils import load_model_dict_into_meta
+from diffusers.models.modeling_utils import load_state_dict_into_meta_model
from diffusers.utils.import_utils import is_accelerate_available


@@ -189,7 +189,7 @@ def main(args):
)

if is_accelerate_available():
-load_model_dict_into_meta(transformer, converted_state_dict)
+load_state_dict_into_meta_model(transformer, converted_state_dict)
else:
transformer.load_state_dict(converted_state_dict, strict=True, assign=True)

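The same two-line swap recurs in every conversion script below. For orientation, a minimal sketch of the shared pattern, assuming the renamed helper keeps the `(model, state_dict)` call shape used in these hunks; the default-config model and its own state dict are placeholders for a real converted checkpoint.

```python
# Sketch of the conversion-script loading path after the rename. Assumes
# load_state_dict_into_meta_model accepts (model, state_dict) as shown in the
# hunks above; the "converted" weights here are just a placeholder.
from diffusers import SanaTransformer2DModel
from diffusers.models.modeling_utils import load_state_dict_into_meta_model
from diffusers.utils.import_utils import is_accelerate_available

# Stand-in for the checkpoint produced by the script's conversion step.
converted_state_dict = SanaTransformer2DModel().state_dict()

if is_accelerate_available():
    from accelerate import init_empty_weights

    # Allocate the target model on the meta device, then materialize real weights.
    with init_empty_weights():
        transformer = SanaTransformer2DModel()
    load_state_dict_into_meta_model(transformer, converted_state_dict)
else:
    # Without accelerate, fall back to a plain strict load.
    transformer = SanaTransformer2DModel()
    transformer.load_state_dict(converted_state_dict, strict=True, assign=True)
```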
6 changes: 3 additions & 3 deletions scripts/convert_sd3_to_diffusers.py
@@ -7,7 +7,7 @@

from diffusers import AutoencoderKL, SD3Transformer2DModel
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
-from diffusers.models.modeling_utils import load_model_dict_into_meta
+from diffusers.models.modeling_utils import load_state_dict_into_meta_model
from diffusers.utils.import_utils import is_accelerate_available


@@ -319,7 +319,7 @@ def main(args):
dual_attention_layers=attn2_layers,
)
if is_accelerate_available():
-load_model_dict_into_meta(transformer, converted_transformer_state_dict)
+load_state_dict_into_meta_model(transformer, converted_transformer_state_dict)
else:
transformer.load_state_dict(converted_transformer_state_dict, strict=True)

@@ -339,7 +339,7 @@ def main(args):
)
converted_vae_state_dict = convert_ldm_vae_checkpoint(original_ckpt, vae.config)
if is_accelerate_available():
-load_model_dict_into_meta(vae, converted_vae_state_dict)
+load_state_dict_into_meta_model(vae, converted_vae_state_dict)
else:
vae.load_state_dict(converted_vae_state_dict, strict=True)

8 changes: 4 additions & 4 deletions scripts/convert_stable_audio.py
@@ -18,7 +18,7 @@
StableAudioPipeline,
StableAudioProjectionModel,
)
-from diffusers.models.modeling_utils import load_model_dict_into_meta
+from diffusers.models.modeling_utils import load_state_dict_into_meta_model
from diffusers.utils import is_accelerate_available


@@ -221,7 +221,7 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
], # assume `seconds_start` and `seconds_total` have the same min / max values.
)
if is_accelerate_available():
-load_model_dict_into_meta(projection_model, projection_model_state_dict)
+load_state_dict_into_meta_model(projection_model, projection_model_state_dict)
else:
projection_model.load_state_dict(projection_model_state_dict)

@@ -242,7 +242,7 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
cross_attention_input_dim=model_config["cond_token_dim"],
)
if is_accelerate_available():
-load_model_dict_into_meta(model, model_state_dict)
+load_state_dict_into_meta_model(model, model_state_dict)
else:
model.load_state_dict(model_state_dict)

@@ -260,7 +260,7 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
)

if is_accelerate_available():
-load_model_dict_into_meta(autoencoder, autoencoder_state_dict)
+load_state_dict_into_meta_model(autoencoder, autoencoder_state_dict)
else:
autoencoder.load_state_dict(autoencoder_state_dict)

6 changes: 3 additions & 3 deletions scripts/convert_stable_cascade.py
@@ -20,7 +20,7 @@
)
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
from diffusers.models import StableCascadeUNet
-from diffusers.models.modeling_utils import load_model_dict_into_meta
+from diffusers.models.modeling_utils import load_state_dict_into_meta_model
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils import is_accelerate_available

@@ -126,7 +126,7 @@
switch_level=[False],
)
if is_accelerate_available():
-load_model_dict_into_meta(prior_model, prior_state_dict)
+load_state_dict_into_meta_model(prior_model, prior_state_dict)
else:
prior_model.load_state_dict(prior_state_dict)

@@ -181,7 +181,7 @@
)

if is_accelerate_available():
-load_model_dict_into_meta(decoder, decoder_state_dict)
+load_state_dict_into_meta_model(decoder, decoder_state_dict)
else:
decoder.load_state_dict(decoder_state_dict)

6 changes: 3 additions & 3 deletions scripts/convert_stable_cascade_lite.py
@@ -20,7 +20,7 @@
)
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
from diffusers.models import StableCascadeUNet
-from diffusers.models.modeling_utils import load_model_dict_into_meta
+from diffusers.models.modeling_utils import load_state_dict_into_meta_model
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils import is_accelerate_available

@@ -133,7 +133,7 @@
)

if is_accelerate_available():
-load_model_dict_into_meta(prior_model, prior_state_dict)
+load_state_dict_into_meta_model(prior_model, prior_state_dict)
else:
prior_model.load_state_dict(prior_state_dict)

@@ -189,7 +189,7 @@
)

if is_accelerate_available():
-load_model_dict_into_meta(decoder, decoder_state_dict)
+load_state_dict_into_meta_model(decoder, decoder_state_dict)
else:
decoder.load_state_dict(decoder_state_dict)

84 changes: 6 additions & 78 deletions src/diffusers/loaders/single_file_model.py
@@ -13,15 +13,11 @@
# limitations under the License.
import importlib
import inspect
-import re
-from contextlib import nullcontext
from typing import Optional

-import torch
from huggingface_hub.utils import validate_hf_hub_args

-from ..quantizers import DiffusersAutoQuantizer
-from ..utils import deprecate, is_accelerate_available, logging
+from ..utils import deprecate, logging
from .single_file_utils import (
SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers,
@@ -49,12 +45,6 @@
logger = logging.get_logger(__name__)


-if is_accelerate_available():
-from accelerate import init_empty_weights
-
-from ..models.modeling_utils import load_model_dict_into_meta
-
-
SINGLE_FILE_LOADABLE_CLASSES = {
"StableCascadeUNet": {
"checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
@@ -234,9 +224,6 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
subfolder = kwargs.pop("subfolder", None)
revision = kwargs.pop("revision", None)
config_revision = kwargs.pop("config_revision", None)
-torch_dtype = kwargs.pop("torch_dtype", None)
-quantization_config = kwargs.pop("quantization_config", None)
-device = kwargs.pop("device", None)
disable_mmap = kwargs.pop("disable_mmap", False)

if isinstance(pretrained_model_link_or_path_or_dict, dict):
@@ -252,12 +239,6 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
revision=revision,
disable_mmap=disable_mmap,
)
-if quantization_config is not None:
-hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
-hf_quantizer.validate_environment()
-
-else:
-hf_quantizer = None

mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]

@@ -336,62 +317,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
)

-ctx = init_empty_weights if is_accelerate_available() else nullcontext
-with ctx():
-model = cls.from_config(diffusers_model_config)
-
-# Check if `_keep_in_fp32_modules` is not None
-use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
-(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
+return cls.from_pretrained(
+pretrained_model_name_or_path=None,
+state_dict=diffusers_format_checkpoint,
+config=diffusers_model_config,
+**kwargs,
)
-if use_keep_in_fp32_modules:
-keep_in_fp32_modules = cls._keep_in_fp32_modules
-if not isinstance(keep_in_fp32_modules, list):
-keep_in_fp32_modules = [keep_in_fp32_modules]
-
-else:
-keep_in_fp32_modules = []
-
-if hf_quantizer is not None:
-hf_quantizer.preprocess_model(
-model=model,
-device_map=None,
-state_dict=diffusers_format_checkpoint,
-keep_in_fp32_modules=keep_in_fp32_modules,
-)
-
-if is_accelerate_available():
-param_device = torch.device(device) if device else torch.device("cpu")
-named_buffers = model.named_buffers()
-unexpected_keys = load_model_dict_into_meta(
-model,
-diffusers_format_checkpoint,
-dtype=torch_dtype,
-device=param_device,
-hf_quantizer=hf_quantizer,
-keep_in_fp32_modules=keep_in_fp32_modules,
-named_buffers=named_buffers,
-)
-
-else:
-_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
-
-if model._keys_to_ignore_on_load_unexpected is not None:
-for pat in model._keys_to_ignore_on_load_unexpected:
-unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
-
-if len(unexpected_keys) > 0:
-logger.warning(
-f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
-)
-
-if hf_quantizer is not None:
-hf_quantizer.postprocess_model(model)
-model.hf_quantizer = hf_quantizer
-
-if torch_dtype is not None and hf_quantizer is None:
-model.to(torch_dtype)
-
-model.eval()
-
-return model
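Net effect of the `single_file_model.py` hunk above: `from_single_file` still resolves the config and converts the checkpoint, but weight loading, dtype casting, and quantization handling are deferred to `from_pretrained` via an in-memory `state_dict`. A hedged sketch of the user-facing call, which this refactor leaves unchanged; the checkpoint URL is illustrative only.

```python
# Illustrative only: from_single_file() now routes through
# cls.from_pretrained(pretrained_model_name_or_path=None, state_dict=..., config=...),
# so kwargs such as torch_dtype are forwarded there instead of being handled locally.
import torch

from diffusers import SD3Transformer2DModel

transformer = SD3Transformer2DModel.from_single_file(
    "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium.safetensors",
    torch_dtype=torch.float16,  # forwarded via **kwargs in the new code path
)
```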
26 changes: 4 additions & 22 deletions src/diffusers/loaders/single_file_utils.py
@@ -53,7 +53,7 @@
if is_accelerate_available():
from accelerate import init_empty_weights

-from ..models.modeling_utils import load_model_dict_into_meta
+from ..models.modeling_utils import load_state_dict_into_meta_model

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

@@ -1588,18 +1588,9 @@ def create_diffusers_clip_model_from_ldm(
raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")

if is_accelerate_available():
-unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+load_state_dict_into_meta_model(model, diffusers_format_checkpoint, dtype=torch_dtype)
else:
-_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
-
-if model._keys_to_ignore_on_load_unexpected is not None:
-for pat in model._keys_to_ignore_on_load_unexpected:
-unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
-
-if len(unexpected_keys) > 0:
-logger.warning(
-f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
-)
+model.load_state_dict(diffusers_format_checkpoint, strict=False)

if torch_dtype is not None:
model.to(torch_dtype)
@@ -2056,16 +2047,7 @@ def create_diffusers_t5_model_from_checkpoint(
diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)

if is_accelerate_available():
-unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
-if model._keys_to_ignore_on_load_unexpected is not None:
-for pat in model._keys_to_ignore_on_load_unexpected:
-unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
-
-if len(unexpected_keys) > 0:
-logger.warning(
-f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
-)
-
+load_state_dict_into_meta_model(model, diffusers_format_checkpoint, dtype=torch_dtype)
else:
model.load_state_dict(diffusers_format_checkpoint)

6 changes: 3 additions & 3 deletions src/diffusers/loaders/transformer_flux.py
@@ -17,7 +17,7 @@
ImageProjection,
MultiIPAdapterImageProjection,
)
-from ..models.modeling_utils import load_model_dict_into_meta
+from ..models.modeling_utils import load_state_dict_into_meta_model
from ..utils import (
is_accelerate_available,
is_torch_version,
@@ -82,7 +82,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
if not low_cpu_mem_usage:
image_projection.load_state_dict(updated_state_dict, strict=True)
else:
-load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
+load_state_dict_into_meta_model(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)

return image_projection

@@ -153,7 +153,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
else:
device = self.device
dtype = self.dtype
-load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
+load_state_dict_into_meta_model(attn_procs[name], value_dict, device=device, dtype=dtype)

key_id += 1

8 changes: 5 additions & 3 deletions src/diffusers/loaders/transformer_sd3.py
@@ -15,7 +15,7 @@

from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
from ..models.embeddings import IPAdapterTimeImageProjection
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict_into_meta_model


class SD3Transformer2DLoadersMixin:
@@ -59,7 +59,7 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _
if not low_cpu_mem_usage:
attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
else:
-load_model_dict_into_meta(
+load_state_dict_into_meta_model(
attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
)

@@ -86,4 +86,6 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _
if not low_cpu_mem_usage:
self.image_proj.load_state_dict(state_dict["image_proj"], strict=True)
else:
-load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype)
+load_state_dict_into_meta_model(
+self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype
+)
6 changes: 3 additions & 3 deletions src/diffusers/loaders/unet.py
@@ -30,7 +30,7 @@
IPAdapterPlusImageProjection,
MultiIPAdapterImageProjection,
)
-from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict
+from ..models.modeling_utils import load_state_dict, load_state_dict_into_meta_model
from ..utils import (
USE_PEFT_BACKEND,
_get_model_file,
@@ -753,7 +753,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
if not low_cpu_mem_usage:
image_projection.load_state_dict(updated_state_dict, strict=True)
else:
-load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
+load_state_dict_into_meta_model(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)

return image_projection

@@ -846,7 +846,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
else:
device = next(iter(value_dict.values())).device
dtype = next(iter(value_dict.values())).dtype
-load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
+load_state_dict_into_meta_model(attn_procs[name], value_dict, device=device, dtype=dtype)

key_id += 2

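The IP-Adapter loaders (`transformer_flux.py`, `transformer_sd3.py`, `unet.py`) keep their `low_cpu_mem_usage` branch and only pick up the new helper name. A minimal sketch of that branch, with `torch.nn.Linear` standing in for the image-projection / attention-processor modules and the `device`/`dtype` keywords assumed to match the calls in these hunks.

```python
# Sketch of the low_cpu_mem_usage branch: build the submodule on the meta device,
# then materialize real tensors with the renamed helper. torch.nn.Linear is a
# stand-in; the keyword names mirror the calls shown in this diff.
import torch
from accelerate import init_empty_weights

from diffusers.models.modeling_utils import load_state_dict_into_meta_model

value_dict = {"weight": torch.randn(16, 8), "bias": torch.zeros(16)}

with init_empty_weights():
    proj = torch.nn.Linear(8, 16)  # parameters are meta tensors, no memory allocated yet

load_state_dict_into_meta_model(proj, value_dict, device=torch.device("cpu"), dtype=torch.float32)
```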