Skip to content

Commit

Permalink
fix dduf
Browse files Browse the repository at this point in the history
  • Loading branch information
SunMarc committed Jan 18, 2025
1 parent 17c1be2 commit 72b6259
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 22 deletions.
2 changes: 1 addition & 1 deletion src/diffusers/loaders/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
IPAdapterPlusImageProjection,
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import load_state_dict, load_model_dict_into_meta
from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
_get_model_file,
Expand Down
19 changes: 0 additions & 19 deletions src/diffusers/models/model_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,25 +307,6 @@ def load_model_dict_into_meta(
return error_msgs, offload_index, state_dict_index


def load_model_dict_into_meta(
model,
state_dict: OrderedDict,
dtype: Optional[Union[str, torch.dtype]] = None,
model_name_or_path: Optional[str] = None,
hf_quantizer=None,
keep_in_fp32_modules=None,
device_map=None,
unexpected_keys=None,
is_safetensors=None,
offload_folder=None,
offload_index=None,
state_dict_index=None,
state_dict_folder=None,
) -> List[str]:
error_msgs = []
return error_msgs, offload_index, state_dict_index


def _load_state_dict_into_model(
model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
) -> List[str]:
Expand Down
6 changes: 4 additions & 2 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@
_fetch_index_file,
_fetch_index_file_legacy,
_load_state_dict_into_model,
load_state_dict,
load_model_dict_into_meta,
load_state_dict,
)


Expand Down Expand Up @@ -1033,6 +1033,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
dtype=torch_dtype,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
dduf_entries=dduf_entries,
)
loading_info = {
"missing_keys": missing_keys,
Expand Down Expand Up @@ -1156,6 +1157,7 @@ def _load_pretrained_model(
device_map=None,
offload_state_dict=None,
offload_folder=None,
dduf_entries=None,
):
model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys())
Expand Down Expand Up @@ -1209,7 +1211,7 @@ def _load_pretrained_model(
if len(resolved_archive_file) > 1:
resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
for shard_file in resolved_archive_file:
state_dict = load_state_dict(shard_file)
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
model._fix_state_dict_keys_on_load(state_dict)

def _find_mismatched_keys(
Expand Down

0 comments on commit 72b6259

Please sign in to comment.