From beacaa55282e003d57d5f3e0cc6bc9c270620506 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 22 Jan 2025 19:49:37 +0530 Subject: [PATCH] [core] Layerwise Upcasting (#10347) * update * update * make style * remove dynamo disable * add coauthor Co-Authored-By: Dhruv Nair * update * update * update * update mixin * add some basic tests * update * update * non_blocking * improvements * update * norm.* -> norm * apply suggestions from review * add example * update hook implementation to the latest changes from pyramid attention broadcast * deinitialize should raise an error * update doc page * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * update docs * update * refactor * fix _always_upcast_modules for asym ae and vq_model * fix lumina embedding forward to not depend on weight dtype * refactor tests * add simple lora inference tests * _always_upcast_modules -> _precision_sensitive_module_patterns * remove todo comments about review; revert changes to self.dtype in unets because .dtype on ModelMixin should be able to handle fp8 weight case * check layer dtypes in lora test * fix UNet1DModelTests::test_layerwise_upcasting_inference * _precision_sensitive_module_patterns -> _skip_layerwise_casting_patterns based on feedback * skip test in NCSNppModelTests * skip tests for AutoencoderTinyTests * skip tests for AutoencoderOobleckTests * skip tests for UNet1DModelTests - unsupported pytorch operations * layerwise_upcasting -> layerwise_casting * skip tests for UNetRLModelTests; needs next pytorch release for currently unimplemented operation support * add layerwise fp8 pipeline test * use xfail * Apply suggestions from code review Co-authored-by: Dhruv Nair * add assertion with fp32 comparison; add tolerance to fp8-fp32 vs fp32-fp32 comparison (required for a few models' test to pass) * add note about memory consumption on tesla CI runner for failing test --------- Co-authored-by: Dhruv Nair Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/api/utilities.md | 4 + docs/source/en/optimization/memory.md | 37 ++++ src/diffusers/hooks/__init__.py | 5 + src/diffusers/hooks/hooks.py | 188 +++++++++++++++++ src/diffusers/hooks/layerwise_casting.py | 191 ++++++++++++++++++ .../autoencoders/autoencoder_asym_kl.py | 2 + src/diffusers/models/autoencoders/vq_model.py | 2 + src/diffusers/models/embeddings.py | 2 +- src/diffusers/models/modeling_utils.py | 100 ++++++++- .../transformers/auraflow_transformer_2d.py | 1 + .../transformers/cogvideox_transformer_3d.py | 1 + .../models/transformers/dit_transformer_2d.py | 1 + .../transformers/hunyuan_transformer_2d.py | 2 + .../transformers/latte_transformer_3d.py | 2 + .../models/transformers/lumina_nextdit2d.py | 2 + .../transformers/pixart_transformer_2d.py | 1 + .../models/transformers/sana_transformer.py | 1 + .../transformers/stable_audio_transformer.py | 1 + .../models/transformers/transformer_2d.py | 1 + .../transformers/transformer_allegro.py | 1 + .../transformers/transformer_cogview3plus.py | 1 + .../models/transformers/transformer_flux.py | 1 + .../transformers/transformer_hunyuan_video.py | 1 + .../models/transformers/transformer_ltx.py | 1 + .../models/transformers/transformer_mochi.py | 1 + .../models/transformers/transformer_sd3.py | 1 + .../transformers/transformer_temporal.py | 2 + src/diffusers/models/unets/unet_1d.py | 4 +- src/diffusers/models/unets/unet_2d.py | 1 + .../models/unets/unet_2d_condition.py | 1 + .../models/unets/unet_3d_condition.py | 1 + .../models/unets/unet_motion_model.py | 1 + tests/lora/utils.py | 59 ++++++ .../test_models_autoencoder_oobleck.py | 18 ++ .../test_models_autoencoder_tiny.py | 16 ++ tests/models/test_modeling_common.py | 101 +++++++++ tests/models/unets/test_models_unet_1d.py | 45 +++++ tests/models/unets/test_models_unet_2d.py | 12 ++ tests/pipelines/allegro/test_allegro.py | 1 + tests/pipelines/amused/test_amused.py | 1 + .../pipelines/animatediff/test_animatediff.py | 1 + .../aura_flow/test_pipeline_aura_flow.py | 1 + tests/pipelines/cogvideo/test_cogvideox.py | 1 + .../cogvideo/test_cogvideox_fun_control.py | 1 + tests/pipelines/cogview3/test_cogview3plus.py | 1 + tests/pipelines/consisid/test_consisid.py | 1 + tests/pipelines/controlnet/test_controlnet.py | 1 + .../controlnet/test_controlnet_sdxl.py | 1 + .../controlnet_flux/test_controlnet_flux.py | 1 + .../test_controlnet_hunyuandit.py | 1 + .../controlnet_sd3/test_controlnet_sd3.py | 1 + .../controlnet_xs/test_controlnetxs.py | 1 + .../controlnet_xs/test_controlnetxs_sdxl.py | 1 + tests/pipelines/flux/test_pipeline_flux.py | 1 + .../flux/test_pipeline_flux_control.py | 1 + .../pipelines/flux/test_pipeline_flux_fill.py | 1 + .../pipelines/hunyuan_dit/test_hunyuan_dit.py | 1 + .../hunyuan_video/test_hunyuan_video.py | 1 + tests/pipelines/i2vgen_xl/test_i2vgenxl.py | 1 + tests/pipelines/kolors/test_kolors.py | 1 + tests/pipelines/latte/test_latte.py | 1 + tests/pipelines/ltx/test_ltx.py | 1 + tests/pipelines/lumina/test_lumina_nextdit.py | 1 + tests/pipelines/mochi/test_mochi.py | 1 + tests/pipelines/pia/test_pia.py | 1 + tests/pipelines/pixart_alpha/test_pixart.py | 1 + tests/pipelines/pixart_sigma/test_pixart.py | 1 + tests/pipelines/sana/test_sana.py | 1 + .../stable_diffusion/test_stable_diffusion.py | 1 + .../test_stable_diffusion.py | 1 + .../test_pipeline_stable_diffusion_3.py | 1 + .../test_stable_diffusion_xl.py | 1 + tests/pipelines/test_pipelines_common.py | 17 +- 73 files changed, 859 insertions(+), 4 deletions(-) create mode 100644 src/diffusers/hooks/__init__.py create mode 100644 src/diffusers/hooks/hooks.py create mode 100644 src/diffusers/hooks/layerwise_casting.py diff --git a/docs/source/en/api/utilities.md b/docs/source/en/api/utilities.md index d4f4d7d7964f..b0b78928fb4b 100644 --- a/docs/source/en/api/utilities.md +++ b/docs/source/en/api/utilities.md @@ -41,3 +41,7 @@ Utility and helper functions for working with 🤗 Diffusers. ## randn_tensor [[autodoc]] utils.torch_utils.randn_tensor + +## apply_layerwise_casting + +[[autodoc]] hooks.layerwise_casting.apply_layerwise_casting diff --git a/docs/source/en/optimization/memory.md b/docs/source/en/optimization/memory.md index a2150f9aa0b7..4cdc60401914 100644 --- a/docs/source/en/optimization/memory.md +++ b/docs/source/en/optimization/memory.md @@ -158,6 +158,43 @@ In order to properly offload models after they're called, it is required to run +## FP8 layerwise weight-casting + +PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting. + +Typically, inference on most models is done with `torch.float16` or `torch.bfloat16` weight/computation precision. Layerwise weight-casting cuts down the memory footprint of the model weights by approximately half. + +```python +import torch +from diffusers import CogVideoXPipeline, CogVideoXTransformer3DModel +from diffusers.utils import export_to_video + +model_id = "THUDM/CogVideoX-5b" + +# Load the model in bfloat16 and enable layerwise casting +transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16) +transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) + +# Load the pipeline +pipe = CogVideoXPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16) +pipe.to("cuda") + +prompt = ( + "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " + "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " + "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " + "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " + "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " + "atmosphere of this unique musical performance." +) +video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] +export_to_video(video, "output.mp4", fps=8) +``` + +In the above example, layerwise casting is enabled on the transformer component of the pipeline. By default, certain layers are skipped from the FP8 weight casting because it can lead to significant degradation of generation quality. The normalization and modulation related weight parameters are also skipped by default. + +However, you gain more control and flexibility by directly utilizing the [`~hooks.layerwise_casting.apply_layerwise_casting`] function instead of [`~ModelMixin.enable_layerwise_casting`]. + ## Channels-last memory format The channels-last memory format is an alternative way of ordering NCHW tensors in memory to preserve dimension ordering. Channels-last tensors are ordered in such a way that the channels become the densest dimension (storing images pixel-per-pixel). Since not all operators currently support the channels-last format, it may result in worst performance but you should still try and see if it works for your model. diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py new file mode 100644 index 000000000000..91b2760acad0 --- /dev/null +++ b/src/diffusers/hooks/__init__.py @@ -0,0 +1,5 @@ +from ..utils import is_torch_available + + +if is_torch_available(): + from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py new file mode 100644 index 000000000000..bef4c65c41e1 --- /dev/null +++ b/src/diffusers/hooks/hooks.py @@ -0,0 +1,188 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from typing import Any, Dict, Optional, Tuple + +import torch + +from ..utils.logging import get_logger + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +class ModelHook: + r""" + A hook that contains callbacks to be executed just before and after the forward method of a model. + """ + + _is_stateful = False + + def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: + r""" + Hook that is executed when a model is initialized. + + Args: + module (`torch.nn.Module`): + The module attached to this hook. + """ + return module + + def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module: + r""" + Hook that is executed when a model is deinitalized. + + Args: + module (`torch.nn.Module`): + The module attached to this hook. + """ + module.forward = module._old_forward + del module._old_forward + return module + + def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: + r""" + Hook that is executed just before the forward method of the model. + + Args: + module (`torch.nn.Module`): + The module whose forward pass will be executed just after this event. + args (`Tuple[Any]`): + The positional arguments passed to the module. + kwargs (`Dict[Str, Any]`): + The keyword arguments passed to the module. + Returns: + `Tuple[Tuple[Any], Dict[Str, Any]]`: + A tuple with the treated `args` and `kwargs`. + """ + return args, kwargs + + def post_forward(self, module: torch.nn.Module, output: Any) -> Any: + r""" + Hook that is executed just after the forward method of the model. + + Args: + module (`torch.nn.Module`): + The module whose forward pass been executed just before this event. + output (`Any`): + The output of the module. + Returns: + `Any`: The processed `output`. + """ + return output + + def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: + r""" + Hook that is executed when the hook is detached from a module. + + Args: + module (`torch.nn.Module`): + The module detached from this hook. + """ + return module + + def reset_state(self, module: torch.nn.Module): + if self._is_stateful: + raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.") + return module + + +class HookRegistry: + def __init__(self, module_ref: torch.nn.Module) -> None: + super().__init__() + + self.hooks: Dict[str, ModelHook] = {} + + self._module_ref = module_ref + self._hook_order = [] + + def register_hook(self, hook: ModelHook, name: str) -> None: + if name in self.hooks.keys(): + logger.warning(f"Hook with name {name} already exists, replacing it.") + + if hasattr(self._module_ref, "_old_forward"): + old_forward = self._module_ref._old_forward + else: + old_forward = self._module_ref.forward + self._module_ref._old_forward = self._module_ref.forward + + self._module_ref = hook.initialize_hook(self._module_ref) + + if hasattr(hook, "new_forward"): + rewritten_forward = hook.new_forward + + def new_forward(module, *args, **kwargs): + args, kwargs = hook.pre_forward(module, *args, **kwargs) + output = rewritten_forward(module, *args, **kwargs) + return hook.post_forward(module, output) + else: + + def new_forward(module, *args, **kwargs): + args, kwargs = hook.pre_forward(module, *args, **kwargs) + output = old_forward(*args, **kwargs) + return hook.post_forward(module, output) + + self._module_ref.forward = functools.update_wrapper( + functools.partial(new_forward, self._module_ref), old_forward + ) + + self.hooks[name] = hook + self._hook_order.append(name) + + def get_hook(self, name: str) -> Optional[ModelHook]: + if name not in self.hooks.keys(): + return None + return self.hooks[name] + + def remove_hook(self, name: str, recurse: bool = True) -> None: + if name in self.hooks.keys(): + hook = self.hooks[name] + self._module_ref = hook.deinitalize_hook(self._module_ref) + del self.hooks[name] + self._hook_order.remove(name) + + if recurse: + for module_name, module in self._module_ref.named_modules(): + if module_name == "": + continue + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook.remove_hook(name, recurse=False) + + def reset_stateful_hooks(self, recurse: bool = True) -> None: + for hook_name in self._hook_order: + hook = self.hooks[hook_name] + if hook._is_stateful: + hook.reset_state(self._module_ref) + + if recurse: + for module_name, module in self._module_ref.named_modules(): + if module_name == "": + continue + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook.reset_stateful_hooks(recurse=False) + + @classmethod + def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry": + if not hasattr(module, "_diffusers_hook"): + module._diffusers_hook = cls(module) + return module._diffusers_hook + + def __repr__(self) -> str: + hook_repr = "" + for i, hook_name in enumerate(self._hook_order): + hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})" + if i < len(self._hook_order) - 1: + hook_repr += "\n" + return f"HookRegistry(\n{hook_repr}\n)" diff --git a/src/diffusers/hooks/layerwise_casting.py b/src/diffusers/hooks/layerwise_casting.py new file mode 100644 index 000000000000..038625e21f0d --- /dev/null +++ b/src/diffusers/hooks/layerwise_casting.py @@ -0,0 +1,191 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Optional, Tuple, Type, Union + +import torch + +from ..utils import get_logger +from .hooks import HookRegistry, ModelHook + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +# fmt: off +SUPPORTED_PYTORCH_LAYERS = ( + torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, + torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, + torch.nn.Linear, +) + +DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$") +# fmt: on + + +class LayerwiseCastingHook(ModelHook): + r""" + A hook that casts the weights of a module to a high precision dtype for computation, and to a low precision dtype + for storage. This process may lead to quality loss in the output, but can significantly reduce the memory + footprint. + """ + + _is_stateful = False + + def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None: + self.storage_dtype = storage_dtype + self.compute_dtype = compute_dtype + self.non_blocking = non_blocking + + def initialize_hook(self, module: torch.nn.Module): + module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking) + return module + + def deinitalize_hook(self, module: torch.nn.Module): + raise NotImplementedError( + "LayerwiseCastingHook does not support deinitalization. A model once enabled with layerwise casting will " + "have casted its weights to a lower precision dtype for storage. Casting this back to the original dtype " + "will lead to precision loss, which might have an impact on the model's generation quality. The model should " + "be re-initialized and loaded in the original dtype." + ) + + def pre_forward(self, module: torch.nn.Module, *args, **kwargs): + module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking) + return args, kwargs + + def post_forward(self, module: torch.nn.Module, output): + module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking) + return output + + +def apply_layerwise_casting( + module: torch.nn.Module, + storage_dtype: torch.dtype, + compute_dtype: torch.dtype, + skip_modules_pattern: Union[str, Tuple[str, ...]] = "auto", + skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None, + non_blocking: bool = False, +) -> None: + r""" + Applies layerwise casting to a given module. The module expected here is a Diffusers ModelMixin but it can be any + nn.Module using diffusers layers or pytorch primitives. + + Example: + + ```python + >>> import torch + >>> from diffusers import CogVideoXTransformer3DModel + + >>> transformer = CogVideoXTransformer3DModel.from_pretrained( + ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 + ... ) + + >>> apply_layerwise_casting( + ... transformer, + ... storage_dtype=torch.float8_e4m3fn, + ... compute_dtype=torch.bfloat16, + ... skip_modules_pattern=["patch_embed", "norm", "proj_out"], + ... non_blocking=True, + ... ) + ``` + + Args: + module (`torch.nn.Module`): + The module whose leaf modules will be cast to a high precision dtype for computation, and to a low + precision dtype for storage. + storage_dtype (`torch.dtype`): + The dtype to cast the module to before/after the forward pass for storage. + compute_dtype (`torch.dtype`): + The dtype to cast the module to during the forward pass for computation. + skip_modules_pattern (`Tuple[str, ...]`, defaults to `"auto"`): + A list of patterns to match the names of the modules to skip during the layerwise casting process. If set + to `"auto"`, the default patterns are used. If set to `None`, no modules are skipped. If set to `None` + alongside `skip_modules_classes` being `None`, the layerwise casting is applied directly to the module + instead of its internal submodules. + skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, defaults to `None`): + A list of module classes to skip during the layerwise casting process. + non_blocking (`bool`, defaults to `False`): + If `True`, the weight casting operations are non-blocking. + """ + if skip_modules_pattern == "auto": + skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN + + if skip_modules_classes is None and skip_modules_pattern is None: + apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking) + return + + _apply_layerwise_casting( + module, + storage_dtype, + compute_dtype, + skip_modules_pattern, + skip_modules_classes, + non_blocking, + ) + + +def _apply_layerwise_casting( + module: torch.nn.Module, + storage_dtype: torch.dtype, + compute_dtype: torch.dtype, + skip_modules_pattern: Optional[Tuple[str, ...]] = None, + skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None, + non_blocking: bool = False, + _prefix: str = "", +) -> None: + should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or ( + skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern) + ) + if should_skip: + logger.debug(f'Skipping layerwise casting for layer "{_prefix}"') + return + + if isinstance(module, SUPPORTED_PYTORCH_LAYERS): + logger.debug(f'Applying layerwise casting to layer "{_prefix}"') + apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking) + return + + for name, submodule in module.named_children(): + layer_name = f"{_prefix}.{name}" if _prefix else name + _apply_layerwise_casting( + submodule, + storage_dtype, + compute_dtype, + skip_modules_pattern, + skip_modules_classes, + non_blocking, + _prefix=layer_name, + ) + + +def apply_layerwise_casting_hook( + module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool +) -> None: + r""" + Applies a `LayerwiseCastingHook` to a given module. + + Args: + module (`torch.nn.Module`): + The module to attach the hook to. + storage_dtype (`torch.dtype`): + The dtype to cast the module to before the forward pass. + compute_dtype (`torch.dtype`): + The dtype to cast the module to during the forward pass. + non_blocking (`bool`): + If `True`, the weight casting operations are non-blocking. + """ + registry = HookRegistry.check_if_exists_or_initialize(module) + hook = LayerwiseCastingHook(storage_dtype, compute_dtype, non_blocking) + registry.register_hook(hook, "layerwise_casting") diff --git a/src/diffusers/models/autoencoders/autoencoder_asym_kl.py b/src/diffusers/models/autoencoders/autoencoder_asym_kl.py index 3f4d46557bf7..c643dcc72a34 100644 --- a/src/diffusers/models/autoencoders/autoencoder_asym_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_asym_kl.py @@ -60,6 +60,8 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. """ + _skip_layerwise_casting_patterns = ["decoder"] + @register_to_config def __init__( self, diff --git a/src/diffusers/models/autoencoders/vq_model.py b/src/diffusers/models/autoencoders/vq_model.py index ae8a118d719a..e754e134b35f 100644 --- a/src/diffusers/models/autoencoders/vq_model.py +++ b/src/diffusers/models/autoencoders/vq_model.py @@ -71,6 +71,8 @@ class VQModel(ModelMixin, ConfigMixin): Type of normalization layer to use. Can be one of `"group"` or `"spatial"`. """ + _skip_layerwise_casting_patterns = ["quantize"] + @register_to_config def __init__( self, diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index c64b9587be77..bd3237c24c1c 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1787,7 +1787,7 @@ def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embeddi def forward(self, timestep, caption_feat, caption_mask): # timestep embedding: time_freq = self.time_proj(timestep) - time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype)) + time_embed = self.timestep_embedder(time_freq.to(dtype=caption_feat.dtype)) # caption condition embedding: caption_mask_float = caption_mask.float().unsqueeze(-1) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index b57cfb9b1750..4d5669e37f5a 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -23,7 +23,7 @@ from collections import OrderedDict from functools import partial, wraps from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import safetensors import torch @@ -32,6 +32,7 @@ from torch import Tensor, nn from .. import __version__ +from ..hooks import apply_layerwise_casting from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer from ..quantizers.quantization_config import QuantizationMethod from ..utils import ( @@ -48,6 +49,7 @@ is_accelerate_available, is_bitsandbytes_available, is_bitsandbytes_version, + is_peft_available, is_torch_version, logging, ) @@ -102,6 +104,17 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype: """ Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found. """ + # 1. Check if we have attached any dtype modifying hooks (eg. layerwise casting) + if isinstance(parameter, nn.Module): + for name, submodule in parameter.named_modules(): + if not hasattr(submodule, "_diffusers_hook"): + continue + registry = submodule._diffusers_hook + hook = registry.get_hook("layerwise_casting") + if hook is not None: + return hook.compute_dtype + + # 2. If no dtype modifying hooks are attached, return the dtype of the first floating point parameter/buffer last_dtype = None for param in parameter.parameters(): last_dtype = param.dtype @@ -150,6 +163,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): _keys_to_ignore_on_load_unexpected = None _no_split_modules = None _keep_in_fp32_modules = None + _skip_layerwise_casting_patterns = None def __init__(self): super().__init__() @@ -314,6 +328,90 @@ def disable_xformers_memory_efficient_attention(self) -> None: """ self.set_use_memory_efficient_attention_xformers(False) + def enable_layerwise_casting( + self, + storage_dtype: torch.dtype = torch.float8_e4m3fn, + compute_dtype: Optional[torch.dtype] = None, + skip_modules_pattern: Optional[Tuple[str, ...]] = None, + skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None, + non_blocking: bool = False, + ) -> None: + r""" + Activates layerwise casting for the current model. + + Layerwise casting is a technique that casts the model weights to a lower precision dtype for storage but + upcasts them on-the-fly to a higher precision dtype for computation. This process can significantly reduce the + memory footprint from model weights, but may lead to some quality degradation in the outputs. Most degradations + are negligible, mostly stemming from weight casting in normalization and modulation layers. + + By default, most models in diffusers set the `_skip_layerwise_casting_patterns` attribute to ignore patch + embedding, positional embedding and normalization layers. This is because these layers are most likely + precision-critical for quality. If you wish to change this behavior, you can set the + `_skip_layerwise_casting_patterns` attribute to `None`, or call + [`~hooks.layerwise_casting.apply_layerwise_casting`] with custom arguments. + + Example: + Using [`~models.ModelMixin.enable_layerwise_casting`]: + + ```python + >>> from diffusers import CogVideoXTransformer3DModel + + >>> transformer = CogVideoXTransformer3DModel.from_pretrained( + ... "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16 + ... ) + + >>> # Enable layerwise casting via the model, which ignores certain modules by default + >>> transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) + ``` + + Args: + storage_dtype (`torch.dtype`): + The dtype to which the model should be cast for storage. + compute_dtype (`torch.dtype`): + The dtype to which the model weights should be cast during the forward pass. + skip_modules_pattern (`Tuple[str, ...]`, *optional*): + A list of patterns to match the names of the modules to skip during the layerwise casting process. If + set to `None`, default skip patterns are used to ignore certain internal layers of modules and PEFT + layers. + skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, *optional*): + A list of module classes to skip during the layerwise casting process. + non_blocking (`bool`, *optional*, defaults to `False`): + If `True`, the weight casting operations are non-blocking. + """ + + user_provided_patterns = True + if skip_modules_pattern is None: + from ..hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN + + skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN + user_provided_patterns = False + if self._keep_in_fp32_modules is not None: + skip_modules_pattern += tuple(self._keep_in_fp32_modules) + if self._skip_layerwise_casting_patterns is not None: + skip_modules_pattern += tuple(self._skip_layerwise_casting_patterns) + skip_modules_pattern = tuple(set(skip_modules_pattern)) + + if is_peft_available() and not user_provided_patterns: + # By default, we want to skip all peft layers because they have a very low memory footprint. + # If users want to apply layerwise casting on peft layers as well, they can utilize the + # `~diffusers.hooks.layerwise_casting.apply_layerwise_casting` function which provides + # them with more flexibility and control. + + from peft.tuners.loha.layer import LoHaLayer + from peft.tuners.lokr.layer import LoKrLayer + from peft.tuners.lora.layer import LoraLayer + + for layer in (LoHaLayer, LoKrLayer, LoraLayer): + skip_modules_pattern += tuple(layer.adapter_layer_names) + + if compute_dtype is None: + logger.info("`compute_dtype` not provided when enabling layerwise casting. Using dtype of the model.") + compute_dtype = self.dtype + + apply_layerwise_casting( + self, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking + ) + def save_pretrained( self, save_directory: Union[str, os.PathLike], diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index b35488a89282..f1f36b87987d 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -276,6 +276,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin """ _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] _supports_gradient_checkpointing = True @register_to_config diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 51634780692d..c3039180b81d 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -212,6 +212,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): Scaling factor to apply in 3D positional embeddings across temporal dimensions. """ + _skip_layerwise_casting_patterns = ["patch_embed", "norm"] _supports_gradient_checkpointing = True _no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"] diff --git a/src/diffusers/models/transformers/dit_transformer_2d.py b/src/diffusers/models/transformers/dit_transformer_2d.py index f787c5279499..7eac313c14db 100644 --- a/src/diffusers/models/transformers/dit_transformer_2d.py +++ b/src/diffusers/models/transformers/dit_transformer_2d.py @@ -64,6 +64,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin): A small constant added to the denominator in normalization layers to prevent division by zero. """ + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] _supports_gradient_checkpointing = True @register_to_config diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py index 7f3dab220aaa..13aa7d076d03 100644 --- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py +++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py @@ -244,6 +244,8 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin): Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2 """ + _skip_layerwise_casting_patterns = ["pos_embed", "norm", "pooler"] + @register_to_config def __init__( self, diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index d34ccfd20108..be06f44a9efe 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -65,6 +65,8 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin): The number of frames in the video-like data. """ + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + @register_to_config def __init__( self, diff --git a/src/diffusers/models/transformers/lumina_nextdit2d.py b/src/diffusers/models/transformers/lumina_nextdit2d.py index d4f5b4658542..fb2b3815bcd5 100644 --- a/src/diffusers/models/transformers/lumina_nextdit2d.py +++ b/src/diffusers/models/transformers/lumina_nextdit2d.py @@ -221,6 +221,8 @@ class LuminaNextDiT2DModel(ModelMixin, ConfigMixin): overall scale of the model's operations. """ + _skip_layerwise_casting_patterns = ["patch_embedder", "norm", "ffn_norm"] + @register_to_config def __init__( self, diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py index 7f145edf16fb..b1740cc08fdf 100644 --- a/src/diffusers/models/transformers/pixart_transformer_2d.py +++ b/src/diffusers/models/transformers/pixart_transformer_2d.py @@ -79,6 +79,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index 3dac0d5dc7bf..a2a54406430d 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -236,6 +236,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): _supports_gradient_checkpointing = True _no_split_modules = ["SanaTransformerBlock", "PatchEmbed", "SanaModulatedNorm"] + _skip_layerwise_casting_patterns = ["patch_embed", "norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/stable_audio_transformer.py b/src/diffusers/models/transformers/stable_audio_transformer.py index d687dbabf317..bb370f20f21b 100644 --- a/src/diffusers/models/transformers/stable_audio_transformer.py +++ b/src/diffusers/models/transformers/stable_audio_transformer.py @@ -211,6 +211,7 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["preprocess_conv", "postprocess_conv", "^proj_in$", "^proj_out$", "norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py index e208a1c10ed4..35e78877f27e 100644 --- a/src/diffusers/models/transformers/transformer_2d.py +++ b/src/diffusers/models/transformers/transformer_2d.py @@ -66,6 +66,7 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin): _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock"] + _skip_layerwise_casting_patterns = ["latent_image_embedding", "norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index 81039fd49e0d..f32c38394ba4 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -222,6 +222,7 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py index 369509a3a35e..0376cc2fd70d 100644 --- a/src/diffusers/models/transformers/transformer_cogview3plus.py +++ b/src/diffusers/models/transformers/transformer_cogview3plus.py @@ -166,6 +166,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["patch_embed", "norm"] _no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"] @register_to_config diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index f5e92700b2f3..db8d73856689 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -262,6 +262,7 @@ class FluxTransformer2DModel( _supports_gradient_checkpointing = True _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 4495623119e5..210a2e711972 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -542,6 +542,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, """ _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"] _no_split_modules = [ "HunyuanVideoTransformerBlock", "HunyuanVideoSingleTransformerBlock", diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index a895340bd124..b5498c0aed01 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -295,6 +295,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin """ _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 8763ea450253..d16430f27931 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -336,6 +336,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri _supports_gradient_checkpointing = True _no_split_modules = ["MochiTransformerBlock"] + _skip_layerwise_casting_patterns = ["patch_embed", "norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 415540ef7f6a..2688d3640ea5 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -127,6 +127,7 @@ class SD3Transformer2DModel( """ _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_temporal.py b/src/diffusers/models/transformers/transformer_temporal.py index 6ca42b9745fd..3b5aedb79e3c 100644 --- a/src/diffusers/models/transformers/transformer_temporal.py +++ b/src/diffusers/models/transformers/transformer_temporal.py @@ -67,6 +67,8 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin): The maximum length of the sequence over which to apply positional embeddings. """ + _skip_layerwise_casting_patterns = ["norm"] + @register_to_config def __init__( self, diff --git a/src/diffusers/models/unets/unet_1d.py b/src/diffusers/models/unets/unet_1d.py index 8efabd98ee7d..ce496fd6baf8 100644 --- a/src/diffusers/models/unets/unet_1d.py +++ b/src/diffusers/models/unets/unet_1d.py @@ -71,6 +71,8 @@ class UNet1DModel(ModelMixin, ConfigMixin): Experimental feature for using a UNet without upsampling. """ + _skip_layerwise_casting_patterns = ["norm"] + @register_to_config def __init__( self, @@ -223,7 +225,7 @@ def forward( timestep_embed = self.time_proj(timesteps) if self.config.use_timestep_embedding: - timestep_embed = self.time_mlp(timestep_embed) + timestep_embed = self.time_mlp(timestep_embed.to(sample.dtype)) else: timestep_embed = timestep_embed[..., None] timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype) diff --git a/src/diffusers/models/unets/unet_2d.py b/src/diffusers/models/unets/unet_2d.py index 090357237f46..84a1322d2a95 100644 --- a/src/diffusers/models/unets/unet_2d.py +++ b/src/diffusers/models/unets/unet_2d.py @@ -90,6 +90,7 @@ class UNet2DModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 2b896f89e484..3447fa0674bc 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -166,6 +166,7 @@ class conditioning with `class_embed_type` equal to `None`. _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"] + _skip_layerwise_casting_patterns = ["norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py index 56739ac24c11..398609778e65 100644 --- a/src/diffusers/models/unets/unet_3d_condition.py +++ b/src/diffusers/models/unets/unet_3d_condition.py @@ -97,6 +97,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) """ _supports_gradient_checkpointing = False + _skip_layerwise_casting_patterns = ["norm", "time_embedding"] @register_to_config def __init__( diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 1c07a0760f62..1d0a38a8fb13 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -1301,6 +1301,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft """ _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["norm"] @register_to_config def __init__( diff --git a/tests/lora/utils.py b/tests/lora/utils.py index a22f86ad6b89..d0d39d05b08a 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -14,6 +14,7 @@ # limitations under the License. import inspect import os +import re import tempfile import unittest from itertools import product @@ -2098,3 +2099,61 @@ def test_correct_lora_configs_with_different_ranks(self): lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3)) self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3)) + + def test_layerwise_casting_inference_denoiser(self): + from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS + + def check_linear_dtype(module, storage_dtype, compute_dtype): + patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN + if getattr(module, "_skip_layerwise_casting_patterns", None) is not None: + patterns_to_check += tuple(module._skip_layerwise_casting_patterns) + for name, submodule in module.named_modules(): + if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS): + continue + dtype_to_check = storage_dtype + if "lora" in name or any(re.search(pattern, name) for pattern in patterns_to_check): + dtype_to_check = compute_dtype + if getattr(submodule, "weight", None) is not None: + self.assertEqual(submodule.weight.dtype, dtype_to_check) + if getattr(submodule, "bias", None) is not None: + self.assertEqual(submodule.bias.dtype, dtype_to_check) + + def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32): + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0]) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device, dtype=compute_dtype) + pipe.set_progress_bar_config(disable=None) + + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + ) + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config) + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + + if storage_dtype is not None: + denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) + check_linear_dtype(denoiser, storage_dtype, compute_dtype) + + return pipe + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe_fp32 = initialize_pipeline(storage_dtype=None) + pipe_fp32(**inputs, generator=torch.manual_seed(0))[0] + + pipe_float8_e4m3_fp32 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32) + pipe_float8_e4m3_fp32(**inputs, generator=torch.manual_seed(0))[0] + + pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) + pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0] diff --git a/tests/models/autoencoders/test_models_autoencoder_oobleck.py b/tests/models/autoencoders/test_models_autoencoder_oobleck.py index 4807fa298344..1f922a9842ee 100644 --- a/tests/models/autoencoders/test_models_autoencoder_oobleck.py +++ b/tests/models/autoencoders/test_models_autoencoder_oobleck.py @@ -114,6 +114,24 @@ def test_forward_with_norm_groups(self): def test_set_attn_processor_for_determinism(self): return + @unittest.skip( + "The convolution layers of AutoencoderOobleck are wrapped with torch.nn.utils.weight_norm. This causes the hook's pre_forward to not " + "cast the module weights to compute_dtype (as required by forward pass). As a result, forward pass errors out. To fix:\n" + "1. Make sure `nn::Module::to` works with `torch.nn.utils.weight_norm` wrapped convolution layer.\n" + "2. Unskip this test." + ) + def test_layerwise_casting_inference(self): + pass + + @unittest.skip( + "The convolution layers of AutoencoderOobleck are wrapped with torch.nn.utils.weight_norm. This causes the hook's pre_forward to not " + "cast the module weights to compute_dtype (as required by forward pass). As a result, forward pass errors out. To fix:\n" + "1. Make sure `nn::Module::to` works with `torch.nn.utils.weight_norm` wrapped convolution layer.\n" + "2. Unskip this test." + ) + def test_layerwise_casting_memory(self): + pass + @slow class AutoencoderOobleckIntegrationTests(unittest.TestCase): diff --git a/tests/models/autoencoders/test_models_autoencoder_tiny.py b/tests/models/autoencoders/test_models_autoencoder_tiny.py index 4de3822fa835..bfbfb7ab8593 100644 --- a/tests/models/autoencoders/test_models_autoencoder_tiny.py +++ b/tests/models/autoencoders/test_models_autoencoder_tiny.py @@ -173,6 +173,22 @@ def test_effective_gradient_checkpointing(self): continue self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=3e-2)) + @unittest.skip( + "The forward pass of AutoencoderTiny creates a torch.float32 tensor. This causes inference in compute_dtype=torch.bfloat16 to fail. To fix:\n" + "1. Change the forward pass to be dtype agnostic.\n" + "2. Unskip this test." + ) + def test_layerwise_casting_inference(self): + pass + + @unittest.skip( + "The forward pass of AutoencoderTiny creates a torch.float32 tensor. This causes inference in compute_dtype=torch.bfloat16 to fail. To fix:\n" + "1. Change the forward pass to be dtype agnostic.\n" + "2. Unskip this test." + ) + def test_layerwise_casting_memory(self): + pass + @slow class AutoencoderTinyIntegrationTests(unittest.TestCase): diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index ac3a59d8abe5..05050e05bb19 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -14,9 +14,11 @@ # limitations under the License. import copy +import gc import inspect import json import os +import re import tempfile import traceback import unittest @@ -56,9 +58,11 @@ CaptureLogger, get_python_version, is_torch_compile, + numpy_cosine_similarity_distance, require_torch_2, require_torch_accelerator, require_torch_accelerator_with_training, + require_torch_gpu, require_torch_multi_gpu, run_test_in_subprocess, torch_all_close, @@ -181,6 +185,16 @@ def compute_module_persistent_sizes( return module_sizes +def cast_maybe_tensor_dtype(maybe_tensor, current_dtype, target_dtype): + if torch.is_tensor(maybe_tensor): + return maybe_tensor.to(target_dtype) if maybe_tensor.dtype == current_dtype else maybe_tensor + if isinstance(maybe_tensor, dict): + return {k: cast_maybe_tensor_dtype(v, current_dtype, target_dtype) for k, v in maybe_tensor.items()} + if isinstance(maybe_tensor, list): + return [cast_maybe_tensor_dtype(v, current_dtype, target_dtype) for v in maybe_tensor] + return maybe_tensor + + class ModelUtilsTest(unittest.TestCase): def tearDown(self): super().tearDown() @@ -1332,6 +1346,93 @@ def test_variant_sharded_ckpt_right_format(self): # Example: diffusion_pytorch_model.fp16-00001-of-00002.safetensors assert all(f.split(".")[1].split("-")[0] == variant for f in shard_files) + def test_layerwise_casting_inference(self): + from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS + + torch.manual_seed(0) + config, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**config).eval() + model = model.to(torch_device) + base_slice = model(**inputs_dict)[0].flatten().detach().cpu().numpy() + + def check_linear_dtype(module, storage_dtype, compute_dtype): + patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN + if getattr(module, "_skip_layerwise_casting_patterns", None) is not None: + patterns_to_check += tuple(module._skip_layerwise_casting_patterns) + for name, submodule in module.named_modules(): + if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS): + continue + dtype_to_check = storage_dtype + if any(re.search(pattern, name) for pattern in patterns_to_check): + dtype_to_check = compute_dtype + if getattr(submodule, "weight", None) is not None: + self.assertEqual(submodule.weight.dtype, dtype_to_check) + if getattr(submodule, "bias", None) is not None: + self.assertEqual(submodule.bias.dtype, dtype_to_check) + + def test_layerwise_casting(storage_dtype, compute_dtype): + torch.manual_seed(0) + config, inputs_dict = self.prepare_init_args_and_inputs_for_common() + inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype) + model = self.model_class(**config).eval() + model = model.to(torch_device, dtype=compute_dtype) + model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) + + check_linear_dtype(model, storage_dtype, compute_dtype) + output = model(**inputs_dict)[0].float().flatten().detach().cpu().numpy() + + # The precision test is not very important for fast tests. In most cases, the outputs will not be the same. + # We just want to make sure that the layerwise casting is working as expected. + self.assertTrue(numpy_cosine_similarity_distance(base_slice, output) < 1.0) + + test_layerwise_casting(torch.float16, torch.float32) + test_layerwise_casting(torch.float8_e4m3fn, torch.float32) + test_layerwise_casting(torch.float8_e5m2, torch.float32) + test_layerwise_casting(torch.float8_e4m3fn, torch.bfloat16) + + @require_torch_gpu + def test_layerwise_casting_memory(self): + MB_TOLERANCE = 0.2 + + def reset_memory_stats(): + gc.collect() + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + def get_memory_usage(storage_dtype, compute_dtype): + torch.manual_seed(0) + config, inputs_dict = self.prepare_init_args_and_inputs_for_common() + inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype) + model = self.model_class(**config).eval() + model = model.to(torch_device, dtype=compute_dtype) + model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) + + reset_memory_stats() + model(**inputs_dict) + model_memory_footprint = model.get_memory_footprint() + peak_inference_memory_allocated_mb = torch.cuda.max_memory_allocated() / 1024**2 + + return model_memory_footprint, peak_inference_memory_allocated_mb + + fp32_memory_footprint, fp32_max_memory = get_memory_usage(torch.float32, torch.float32) + fp8_e4m3_fp32_memory_footprint, fp8_e4m3_fp32_max_memory = get_memory_usage(torch.float8_e4m3fn, torch.float32) + fp8_e4m3_bf16_memory_footprint, fp8_e4m3_bf16_max_memory = get_memory_usage( + torch.float8_e4m3fn, torch.bfloat16 + ) + + self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint) + # NOTE: the following assertion will fail on our CI (running Tesla T4) due to bf16 using more memory than fp32. + # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. + self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory) + # On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few + # bytes. This only happens for some models, so we allow a small tolerance. + # For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32. + self.assertTrue( + fp8_e4m3_fp32_max_memory < fp32_max_memory + or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE + ) + @is_staging_test class ModelPushToHubTester(unittest.TestCase): diff --git a/tests/models/unets/test_models_unet_1d.py b/tests/models/unets/test_models_unet_1d.py index 6eb7d3485c8b..0f81807b895c 100644 --- a/tests/models/unets/test_models_unet_1d.py +++ b/tests/models/unets/test_models_unet_1d.py @@ -15,6 +15,7 @@ import unittest +import pytest import torch from diffusers import UNet1DModel @@ -152,6 +153,28 @@ def test_unet_1d_maestro(self): assert (output_sum - 224.0896).abs() < 0.5 assert (output_max - 0.0607).abs() < 4e-4 + @pytest.mark.xfail( + reason=( + "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations " + "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n" + "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n" + "2. Unskip this test." + ), + ) + def test_layerwise_casting_inference(self): + super().test_layerwise_casting_inference() + + @pytest.mark.xfail( + reason=( + "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations " + "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n" + "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n" + "2. Unskip this test." + ), + ) + def test_layerwise_casting_memory(self): + pass + class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): model_class = UNet1DModel @@ -274,3 +297,25 @@ def test_output_pretrained(self): def test_forward_with_norm_groups(self): # Not implemented yet for this UNet pass + + @pytest.mark.xfail( + reason=( + "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations " + "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n" + "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n" + "2. Unskip this test." + ), + ) + def test_layerwise_casting_inference(self): + pass + + @pytest.mark.xfail( + reason=( + "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations " + "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n" + "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n" + "2. Unskip this test." + ), + ) + def test_layerwise_casting_memory(self): + pass diff --git a/tests/models/unets/test_models_unet_2d.py b/tests/models/unets/test_models_unet_2d.py index 05bece23efd6..0e5fdc4bba2e 100644 --- a/tests/models/unets/test_models_unet_2d.py +++ b/tests/models/unets/test_models_unet_2d.py @@ -401,3 +401,15 @@ def test_gradient_checkpointing_is_applied(self): def test_effective_gradient_checkpointing(self): super().test_effective_gradient_checkpointing(skip={"time_proj.weight"}) + + @unittest.skip( + "To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here." + ) + def test_layerwise_casting_inference(self): + pass + + @unittest.skip( + "To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here." + ) + def test_layerwise_casting_memory(self): + pass diff --git a/tests/pipelines/allegro/test_allegro.py b/tests/pipelines/allegro/test_allegro.py index 6a5a81bf160f..322be373641a 100644 --- a/tests/pipelines/allegro/test_allegro.py +++ b/tests/pipelines/allegro/test_allegro.py @@ -57,6 +57,7 @@ class AllegroPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] ) test_xformers_attention = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/amused/test_amused.py b/tests/pipelines/amused/test_amused.py index f28d8708d309..2dfc36a6ce45 100644 --- a/tests/pipelines/amused/test_amused.py +++ b/tests/pipelines/amused/test_amused.py @@ -38,6 +38,7 @@ class AmusedPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = AmusedPipeline params = TEXT_TO_IMAGE_PARAMS | {"encoder_hidden_states", "negative_encoder_hidden_states"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index c7411a7145c5..1b3115c8eb1d 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -60,6 +60,7 @@ class AnimateDiffPipelineFastTests( "callback_on_step_end_tensor_inputs", ] ) + test_layerwise_casting = True def get_dummy_components(self): cross_attention_dim = 8 diff --git a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py index 14bc588df905..bee905f9ae13 100644 --- a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py +++ b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py @@ -30,6 +30,7 @@ class AuraFlowPipelineFastTests(unittest.TestCase, PipelineTesterMixin): ] ) batch_params = frozenset(["prompt", "negative_prompt"]) + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py index 78fe9d4ef3be..9ce3d8e9de31 100644 --- a/tests/pipelines/cogvideo/test_cogvideox.py +++ b/tests/pipelines/cogvideo/test_cogvideox.py @@ -58,6 +58,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] ) test_xformers_attention = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py index 2a51fc65798c..c936bad4c3d5 100644 --- a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py +++ b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py @@ -55,6 +55,7 @@ class CogVideoXFunControlPipelineFastTests(PipelineTesterMixin, unittest.TestCas ] ) test_xformers_attention = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/cogview3/test_cogview3plus.py b/tests/pipelines/cogview3/test_cogview3plus.py index dcb746e0a55d..102a5c66e624 100644 --- a/tests/pipelines/cogview3/test_cogview3plus.py +++ b/tests/pipelines/cogview3/test_cogview3plus.py @@ -56,6 +56,7 @@ class CogView3PlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] ) test_xformers_attention = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/consisid/test_consisid.py b/tests/pipelines/consisid/test_consisid.py index 31f2bc024af6..f949cfb2d36d 100644 --- a/tests/pipelines/consisid/test_consisid.py +++ b/tests/pipelines/consisid/test_consisid.py @@ -58,6 +58,7 @@ class ConsisIDPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] ) test_xformers_attention = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py index 43814b2b2211..e0fc00171031 100644 --- a/tests/pipelines/controlnet/test_controlnet.py +++ b/tests/pipelines/controlnet/test_controlnet.py @@ -126,6 +126,7 @@ class ControlNetPipelineFastTests( batch_params = TEXT_TO_IMAGE_BATCH_PARAMS image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + test_layerwise_casting = True def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index 27f676b15b1c..e75fe8903134 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -75,6 +75,7 @@ class StableDiffusionXLControlNetPipelineFastTests( batch_params = TEXT_TO_IMAGE_BATCH_PARAMS image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + test_layerwise_casting = True def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux.py b/tests/pipelines/controlnet_flux/test_controlnet_flux.py index 5e856b125f32..8b9852dbec6e 100644 --- a/tests/pipelines/controlnet_flux/test_controlnet_flux.py +++ b/tests/pipelines/controlnet_flux/test_controlnet_flux.py @@ -50,6 +50,7 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin): params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) batch_params = frozenset(["prompt"]) + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py b/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py index 30dfe94e50f1..5c6054ccb605 100644 --- a/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py +++ b/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py @@ -57,6 +57,7 @@ class HunyuanDiTControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMix ] ) batch_params = frozenset(["prompt", "negative_prompt"]) + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py index 7527d17af32a..e1894d555c3c 100644 --- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py +++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py @@ -59,6 +59,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes ] ) batch_params = frozenset(["prompt", "negative_prompt"]) + test_layerwise_casting = True def get_dummy_components( self, num_controlnet_layers: int = 3, qk_norm: Optional[str] = "rms_norm", use_dual_attention=False diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py index 6d53d0618959..4c184db99630 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -139,6 +139,7 @@ class ControlNetXSPipelineFastTests( image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS test_attention_slicing = False + test_layerwise_casting = True def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py index d7ecf92f41cd..7537efe0bbf9 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py @@ -78,6 +78,7 @@ class StableDiffusionXLControlNetXSPipelineFastTests( image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS test_attention_slicing = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index addc29e14670..a3bc1658de74 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -31,6 +31,7 @@ class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapte # there is no xformers processor for Flux test_xformers_attention = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/flux/test_pipeline_flux_control.py b/tests/pipelines/flux/test_pipeline_flux_control.py index 2bd511db3d65..7fdb19327213 100644 --- a/tests/pipelines/flux/test_pipeline_flux_control.py +++ b/tests/pipelines/flux/test_pipeline_flux_control.py @@ -22,6 +22,7 @@ class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin): # there is no xformers processor for Flux test_xformers_attention = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/flux/test_pipeline_flux_fill.py b/tests/pipelines/flux/test_pipeline_flux_fill.py index 6c6ec138c781..620ecb8a831f 100644 --- a/tests/pipelines/flux/test_pipeline_flux_fill.py +++ b/tests/pipelines/flux/test_pipeline_flux_fill.py @@ -23,6 +23,7 @@ class FluxFillPipelineFastTests(unittest.TestCase, PipelineTesterMixin): params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) batch_params = frozenset(["prompt"]) test_xformers_attention = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py b/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py index b295b280a560..6c9117a55c36 100644 --- a/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py +++ b/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py @@ -55,6 +55,7 @@ class HunyuanDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS required_optional_params = PipelineTesterMixin.required_optional_params + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py index 567002268106..ce03381f90d2 100644 --- a/tests/pipelines/hunyuan_video/test_hunyuan_video.py +++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py @@ -53,6 +53,7 @@ class HunyuanVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): # there is no xformers processor for Flux test_xformers_attention = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/i2vgen_xl/test_i2vgenxl.py b/tests/pipelines/i2vgen_xl/test_i2vgenxl.py index 22ece0e6d75f..f6ac22a9b575 100644 --- a/tests/pipelines/i2vgen_xl/test_i2vgenxl.py +++ b/tests/pipelines/i2vgen_xl/test_i2vgenxl.py @@ -61,6 +61,7 @@ class I2VGenXLPipelineFastTests(SDFunctionTesterMixin, PipelineTesterMixin, unit required_optional_params = frozenset(["num_inference_steps", "generator", "latents", "return_dict"]) supports_dduf = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/kolors/test_kolors.py b/tests/pipelines/kolors/test_kolors.py index e88ba0282096..cf0b392ddc06 100644 --- a/tests/pipelines/kolors/test_kolors.py +++ b/tests/pipelines/kolors/test_kolors.py @@ -48,6 +48,7 @@ class KolorsPipelineFastTests(PipelineTesterMixin, unittest.TestCase): callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"}) supports_dduf = False + test_layerwise_casting = True def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) diff --git a/tests/pipelines/latte/test_latte.py b/tests/pipelines/latte/test_latte.py index 9667ebff249d..2d5bcba8237a 100644 --- a/tests/pipelines/latte/test_latte.py +++ b/tests/pipelines/latte/test_latte.py @@ -52,6 +52,7 @@ class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS required_optional_params = PipelineTesterMixin.required_optional_params + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/ltx/test_ltx.py b/tests/pipelines/ltx/test_ltx.py index dd166c6242fc..64b366ea8ad6 100644 --- a/tests/pipelines/ltx/test_ltx.py +++ b/tests/pipelines/ltx/test_ltx.py @@ -46,6 +46,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] ) test_xformers_attention = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/lumina/test_lumina_nextdit.py b/tests/pipelines/lumina/test_lumina_nextdit.py index e0fd06847b77..7c1923313b23 100644 --- a/tests/pipelines/lumina/test_lumina_nextdit.py +++ b/tests/pipelines/lumina/test_lumina_nextdit.py @@ -32,6 +32,7 @@ class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterM batch_params = frozenset(["prompt", "negative_prompt"]) supports_dduf = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/mochi/test_mochi.py b/tests/pipelines/mochi/test_mochi.py index c9df5785897c..b7bb844ff311 100644 --- a/tests/pipelines/mochi/test_mochi.py +++ b/tests/pipelines/mochi/test_mochi.py @@ -55,6 +55,7 @@ class MochiPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] ) test_xformers_attention = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/pia/test_pia.py b/tests/pipelines/pia/test_pia.py index e461860eff65..747be38d495c 100644 --- a/tests/pipelines/pia/test_pia.py +++ b/tests/pipelines/pia/test_pia.py @@ -55,6 +55,7 @@ class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, PipelineFr "callback_on_step_end_tensor_inputs", ] ) + test_layerwise_casting = True def get_dummy_components(self): cross_attention_dim = 8 diff --git a/tests/pipelines/pixart_alpha/test_pixart.py b/tests/pipelines/pixart_alpha/test_pixart.py index e7039c61a448..7df6656f6f87 100644 --- a/tests/pipelines/pixart_alpha/test_pixart.py +++ b/tests/pipelines/pixart_alpha/test_pixart.py @@ -50,6 +50,7 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS required_optional_params = PipelineTesterMixin.required_optional_params + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/pixart_sigma/test_pixart.py b/tests/pipelines/pixart_sigma/test_pixart.py index a92e99366ee3..6e265b9d5eb8 100644 --- a/tests/pipelines/pixart_sigma/test_pixart.py +++ b/tests/pipelines/pixart_sigma/test_pixart.py @@ -55,6 +55,7 @@ class PixArtSigmaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS required_optional_params = PipelineTesterMixin.required_optional_params + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/sana/test_sana.py b/tests/pipelines/sana/test_sana.py index 7109a700403c..f70f9d91f19c 100644 --- a/tests/pipelines/sana/test_sana.py +++ b/tests/pipelines/sana/test_sana.py @@ -52,6 +52,7 @@ class SanaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] ) test_xformers_attention = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index ccd5567106d2..1e700bed03f8 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -123,6 +123,7 @@ class StableDiffusionPipelineFastTests( image_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS + test_layerwise_casting = True def get_dummy_components(self, time_cond_proj_dim=None): cross_attention_dim = 8 diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py index e7114d19e208..10b8a1818a29 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py @@ -75,6 +75,7 @@ class StableDiffusion2PipelineFastTests( image_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py index a6f718ae4fbb..df37090eeba2 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py @@ -35,6 +35,7 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin): ] ) batch_params = frozenset(["prompt", "negative_prompt"]) + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index 8550f258045e..f1422022a7aa 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -75,6 +75,7 @@ class StableDiffusionXLPipelineFastTests( image_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"}) + test_layerwise_casting = True def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 83b628e09f88..139778994b87 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -987,7 +987,7 @@ class PipelineTesterMixin: test_attention_slicing = True test_xformers_attention = True - + test_layerwise_casting = False supports_dduf = True def get_generator(self, seed): @@ -2027,6 +2027,21 @@ def test_save_load_dduf(self, atol=1e-4, rtol=1e-4): elif isinstance(pipeline_out, torch.Tensor) and isinstance(loaded_pipeline_out, torch.Tensor): assert torch.allclose(pipeline_out, loaded_pipeline_out, atol=atol, rtol=rtol) + def test_layerwise_casting_inference(self): + if not self.test_layerwise_casting: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device, dtype=torch.bfloat16) + pipe.set_progress_bar_config(disable=None) + + denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet + denoiser.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) + + inputs = self.get_dummy_inputs(torch_device) + _ = pipe(**inputs)[0] + @is_staging_test class PipelinePushToHubTester(unittest.TestCase):