From 72b6259ecb9a43d2e915246a239126aca67b9a87 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Sat, 18 Jan 2025 11:04:13 +0100 Subject: [PATCH] fix dduf --- src/diffusers/loaders/unet.py | 2 +- src/diffusers/models/model_loading_utils.py | 19 ------------------- src/diffusers/models/modeling_utils.py | 6 ++++-- 3 files changed, 5 insertions(+), 22 deletions(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index c47f27fbf171..c68349c36dba 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -30,7 +30,7 @@ IPAdapterPlusImageProjection, MultiIPAdapterImageProjection, ) -from ..models.modeling_utils import load_state_dict, load_model_dict_into_meta +from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict from ..utils import ( USE_PEFT_BACKEND, _get_model_file, diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 33c07a2e2f9a..93b3a7fbc609 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -307,25 +307,6 @@ def load_model_dict_into_meta( return error_msgs, offload_index, state_dict_index -def load_model_dict_into_meta( - model, - state_dict: OrderedDict, - dtype: Optional[Union[str, torch.dtype]] = None, - model_name_or_path: Optional[str] = None, - hf_quantizer=None, - keep_in_fp32_modules=None, - device_map=None, - unexpected_keys=None, - is_safetensors=None, - offload_folder=None, - offload_index=None, - state_dict_index=None, - state_dict_folder=None, -) -> List[str]: - error_msgs = [] - return error_msgs, offload_index, state_dict_index - - def _load_state_dict_into_model( model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False ) -> List[str]: diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index c91e1c042ecd..9d93a1946e88 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -64,8 +64,8 @@ _fetch_index_file, _fetch_index_file_legacy, _load_state_dict_into_model, - load_state_dict, load_model_dict_into_meta, + load_state_dict, ) @@ -1033,6 +1033,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P dtype=torch_dtype, hf_quantizer=hf_quantizer, keep_in_fp32_modules=keep_in_fp32_modules, + dduf_entries=dduf_entries, ) loading_info = { "missing_keys": missing_keys, @@ -1156,6 +1157,7 @@ def _load_pretrained_model( device_map=None, offload_state_dict=None, offload_folder=None, + dduf_entries=None, ): model_state_dict = model.state_dict() expected_keys = list(model_state_dict.keys()) @@ -1209,7 +1211,7 @@ def _load_pretrained_model( if len(resolved_archive_file) > 1: resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards") for shard_file in resolved_archive_file: - state_dict = load_state_dict(shard_file) + state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries) model._fix_state_dict_keys_on_load(state_dict) def _find_mismatched_keys(