Skip to content

Commit

Permalink
named_buffers
Browse files Browse the repository at this point in the history
  • Loading branch information
hlky committed Jan 14, 2025
1 parent 3d56e94 commit bb2e228
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/diffusers/loaders/single_file_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,13 +362,15 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =

if is_accelerate_available():
param_device = torch.device(device) if device else torch.device("cpu")
named_buffers = model.named_buffers()
unexpected_keys = load_model_dict_into_meta(
model,
diffusers_format_checkpoint,
dtype=torch_dtype,
device=param_device,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
named_buffers=named_buffers,
)

else:
Expand Down
8 changes: 6 additions & 2 deletions src/diffusers/models/model_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from array import array
from collections import OrderedDict
from pathlib import Path
from typing import List, Optional, Union
from typing import Iterator, List, Optional, Tuple, Union

import safetensors
import torch
Expand Down Expand Up @@ -185,6 +185,7 @@ def load_model_dict_into_meta(
model_name_or_path: Optional[str] = None,
hf_quantizer=None,
keep_in_fp32_modules=None,
named_buffers: Optional[Iterator[Tuple[str, torch.Tensor]]] = None,
) -> List[str]:
if device is not None and not isinstance(device, (str, torch.device)):
raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
Expand Down Expand Up @@ -246,7 +247,10 @@ def load_model_dict_into_meta(
else:
set_module_tensor_to_device(model, param_name, device, value=param)

for param_name, param in model.named_buffers():
if named_buffers is None:
return unexpected_keys

for param_name, param in named_buffers:
if is_quantized and (
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
):
Expand Down
3 changes: 3 additions & 0 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
" those weights or else make sure your checkpoint file is correct."
)

named_buffers = model.named_buffers()

unexpected_keys = load_model_dict_into_meta(
model,
state_dict,
Expand All @@ -910,6 +912,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
model_name_or_path=pretrained_model_name_or_path,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
named_buffers=named_buffers,
)

if cls._keys_to_ignore_on_load_unexpected is not None:
Expand Down

0 comments on commit bb2e228

Please sign in to comment.