Skip to content

Commit

Permalink
revert name change
Browse files Browse the repository at this point in the history
  • Loading branch information
SunMarc committed Jan 17, 2025
1 parent e54c540 commit 645abc9
Show file tree
Hide file tree
Showing 11 changed files with 30 additions and 32 deletions.
4 changes: 2 additions & 2 deletions scripts/convert_sana_to_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
SanaPipeline,
SanaTransformer2DModel,
)
from diffusers.models.modeling_utils import load_state_dict_into_meta_model
from diffusers.models.modeling_utils import load_state_dict_into_meta
from diffusers.utils.import_utils import is_accelerate_available


Expand Down Expand Up @@ -189,7 +189,7 @@ def main(args):
)

if is_accelerate_available():
load_state_dict_into_meta_model(transformer, converted_state_dict)
load_state_dict_into_meta(transformer, converted_state_dict)
else:
transformer.load_state_dict(converted_state_dict, strict=True, assign=True)

Expand Down
6 changes: 3 additions & 3 deletions scripts/convert_sd3_to_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from diffusers import AutoencoderKL, SD3Transformer2DModel
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
from diffusers.models.modeling_utils import load_state_dict_into_meta_model
from diffusers.models.modeling_utils import load_state_dict_into_meta
from diffusers.utils.import_utils import is_accelerate_available


Expand Down Expand Up @@ -319,7 +319,7 @@ def main(args):
dual_attention_layers=attn2_layers,
)
if is_accelerate_available():
load_state_dict_into_meta_model(transformer, converted_transformer_state_dict)
load_state_dict_into_meta(transformer, converted_transformer_state_dict)
else:
transformer.load_state_dict(converted_transformer_state_dict, strict=True)

Expand All @@ -339,7 +339,7 @@ def main(args):
)
converted_vae_state_dict = convert_ldm_vae_checkpoint(original_ckpt, vae.config)
if is_accelerate_available():
load_state_dict_into_meta_model(vae, converted_vae_state_dict)
load_state_dict_into_meta(vae, converted_vae_state_dict)
else:
vae.load_state_dict(converted_vae_state_dict, strict=True)

Expand Down
8 changes: 4 additions & 4 deletions scripts/convert_stable_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
StableAudioPipeline,
StableAudioProjectionModel,
)
from diffusers.models.modeling_utils import load_state_dict_into_meta_model
from diffusers.models.modeling_utils import load_state_dict_into_meta
from diffusers.utils import is_accelerate_available


Expand Down Expand Up @@ -221,7 +221,7 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
], # assume `seconds_start` and `seconds_total` have the same min / max values.
)
if is_accelerate_available():
load_state_dict_into_meta_model(projection_model, projection_model_state_dict)
load_state_dict_into_meta(projection_model, projection_model_state_dict)
else:
projection_model.load_state_dict(projection_model_state_dict)

Expand All @@ -242,7 +242,7 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
cross_attention_input_dim=model_config["cond_token_dim"],
)
if is_accelerate_available():
load_state_dict_into_meta_model(model, model_state_dict)
load_state_dict_into_meta(model, model_state_dict)
else:
model.load_state_dict(model_state_dict)

Expand All @@ -260,7 +260,7 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
)

if is_accelerate_available():
load_state_dict_into_meta_model(autoencoder, autoencoder_state_dict)
load_state_dict_into_meta(autoencoder, autoencoder_state_dict)
else:
autoencoder.load_state_dict(autoencoder_state_dict)

Expand Down
6 changes: 3 additions & 3 deletions scripts/convert_stable_cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
from diffusers.models import StableCascadeUNet
from diffusers.models.modeling_utils import load_state_dict_into_meta_model
from diffusers.models.modeling_utils import load_state_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils import is_accelerate_available

Expand Down Expand Up @@ -126,7 +126,7 @@
switch_level=[False],
)
if is_accelerate_available():
load_state_dict_into_meta_model(prior_model, prior_state_dict)
load_state_dict_into_meta(prior_model, prior_state_dict)
else:
prior_model.load_state_dict(prior_state_dict)

Expand Down Expand Up @@ -181,7 +181,7 @@
)

if is_accelerate_available():
load_state_dict_into_meta_model(decoder, decoder_state_dict)
load_state_dict_into_meta(decoder, decoder_state_dict)
else:
decoder.load_state_dict(decoder_state_dict)

Expand Down
6 changes: 3 additions & 3 deletions scripts/convert_stable_cascade_lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
from diffusers.models import StableCascadeUNet
from diffusers.models.modeling_utils import load_state_dict_into_meta_model
from diffusers.models.modeling_utils import load_state_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils import is_accelerate_available

Expand Down Expand Up @@ -133,7 +133,7 @@
)

if is_accelerate_available():
load_state_dict_into_meta_model(prior_model, prior_state_dict)
load_state_dict_into_meta(prior_model, prior_state_dict)
else:
prior_model.load_state_dict(prior_state_dict)

Expand Down Expand Up @@ -189,7 +189,7 @@
)

if is_accelerate_available():
load_state_dict_into_meta_model(decoder, decoder_state_dict)
load_state_dict_into_meta(decoder, decoder_state_dict)
else:
decoder.load_state_dict(decoder_state_dict)

Expand Down
6 changes: 3 additions & 3 deletions src/diffusers/loaders/single_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
if is_accelerate_available():
from accelerate import init_empty_weights

from ..models.modeling_utils import load_state_dict_into_meta_model
from ..models.modeling_utils import load_state_dict_into_meta

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

Expand Down Expand Up @@ -1588,7 +1588,7 @@ def create_diffusers_clip_model_from_ldm(
raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")

if is_accelerate_available():
load_state_dict_into_meta_model(model, diffusers_format_checkpoint, dtype=torch_dtype)
load_state_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
else:
model.load_state_dict(diffusers_format_checkpoint, strict=False)

Expand Down Expand Up @@ -2047,7 +2047,7 @@ def create_diffusers_t5_model_from_checkpoint(
diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)

if is_accelerate_available():
load_state_dict_into_meta_model(model, diffusers_format_checkpoint, dtype=torch_dtype)
load_state_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
else:
model.load_state_dict(diffusers_format_checkpoint)

Expand Down
6 changes: 3 additions & 3 deletions src/diffusers/loaders/transformer_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
ImageProjection,
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import load_state_dict_into_meta_model
from ..models.modeling_utils import load_state_dict_into_meta
from ..utils import (
is_accelerate_available,
is_torch_version,
Expand Down Expand Up @@ -82,7 +82,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
if not low_cpu_mem_usage:
image_projection.load_state_dict(updated_state_dict, strict=True)
else:
load_state_dict_into_meta_model(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
load_state_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)

return image_projection

Expand Down Expand Up @@ -153,7 +153,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
else:
device = self.device
dtype = self.dtype
load_state_dict_into_meta_model(attn_procs[name], value_dict, device=device, dtype=dtype)
load_state_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)

key_id += 1

Expand Down
8 changes: 3 additions & 5 deletions src/diffusers/loaders/transformer_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
from ..models.embeddings import IPAdapterTimeImageProjection
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict_into_meta_model
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict_into_meta


class SD3Transformer2DLoadersMixin:
Expand Down Expand Up @@ -59,7 +59,7 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _
if not low_cpu_mem_usage:
attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
else:
load_state_dict_into_meta_model(
load_state_dict_into_meta(
attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
)

Expand All @@ -86,6 +86,4 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _
if not low_cpu_mem_usage:
self.image_proj.load_state_dict(state_dict["image_proj"], strict=True)
else:
load_state_dict_into_meta_model(
self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype
)
load_state_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype)
6 changes: 3 additions & 3 deletions src/diffusers/loaders/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
IPAdapterPlusImageProjection,
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import load_state_dict, load_state_dict_into_meta_model
from ..models.modeling_utils import load_state_dict, load_state_dict_into_meta
from ..utils import (
USE_PEFT_BACKEND,
_get_model_file,
Expand Down Expand Up @@ -753,7 +753,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
if not low_cpu_mem_usage:
image_projection.load_state_dict(updated_state_dict, strict=True)
else:
load_state_dict_into_meta_model(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
load_state_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)

return image_projection

Expand Down Expand Up @@ -846,7 +846,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
else:
device = next(iter(value_dict.values())).device
dtype = next(iter(value_dict.values())).dtype
load_state_dict_into_meta_model(attn_procs[name], value_dict, device=device, dtype=dtype)
load_state_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)

key_id += 2

Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/models/model_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def load_state_dict(
)


def load_state_dict_into_meta_model(
def load_state_dict_into_meta(
model,
state_dict: OrderedDict,
dtype: Optional[Union[str, torch.dtype]] = None,
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
_fetch_index_file_legacy,
_load_state_dict_into_model,
load_state_dict,
load_state_dict_into_meta_model,
load_state_dict_into_meta,
)


Expand Down Expand Up @@ -1244,7 +1244,7 @@ def _find_mismatched_keys(
)

if low_cpu_mem_usage:
new_error_msgs, offload_index, state_dict_index = load_state_dict_into_meta_model(
new_error_msgs, offload_index, state_dict_index = load_state_dict_into_meta(
model,
state_dict,
device_map=device_map,
Expand Down

0 comments on commit 645abc9

Please sign in to comment.