From 36b0c37dd6e8bd3a3ca05c0270d1b6cf89d3018a Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 22 Dec 2024 12:30:20 +0100 Subject: [PATCH 01/45] update --- src/diffusers/models/hooks.py | 340 ++++++++++++++++++++++++++++++++++ 1 file changed, 340 insertions(+) create mode 100644 src/diffusers/models/hooks.py diff --git a/src/diffusers/models/hooks.py b/src/diffusers/models/hooks.py new file mode 100644 index 000000000000..dd9674fbe30a --- /dev/null +++ b/src/diffusers/models/hooks.py @@ -0,0 +1,340 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import re +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, Tuple, List, Type + +import torch + +from ..utils import get_logger + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +# Reference: https://github.com/huggingface/accelerate/blob/ba7ab93f5e688466ea56908ea3b056fae2f9a023/src/accelerate/hooks.py +class ModelHook: + r""" + A hook that contains callbacks to be executed just before and after the forward method of a model. The difference + with PyTorch existing hooks is that they get passed along the kwargs. + """ + + def init_hook(self, module: torch.nn.Module) -> torch.nn.Module: + r""" + Hook that is executed when a model is initialized. + Args: + module (`torch.nn.Module`): + The module attached to this hook. + """ + return module + + def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: + r""" + Hook that is executed just before the forward method of the model. + Args: + module (`torch.nn.Module`): + The module whose forward pass will be executed just after this event. + args (`Tuple[Any]`): + The positional arguments passed to the module. + kwargs (`Dict[Str, Any]`): + The keyword arguments passed to the module. + Returns: + `Tuple[Tuple[Any], Dict[Str, Any]]`: + A tuple with the treated `args` and `kwargs`. + """ + return args, kwargs + + def post_forward(self, module: torch.nn.Module, output: Any) -> Any: + r""" + Hook that is executed just after the forward method of the model. + Args: + module (`torch.nn.Module`): + The module whose forward pass been executed just before this event. + output (`Any`): + The output of the module. + Returns: + `Any`: The processed `output`. + """ + return output + + def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: + r""" + Hook that is executed when the hook is detached from a module. + Args: + module (`torch.nn.Module`): + The module detached from this hook. + """ + return module + + +class SequentialHook(ModelHook): + r"""A hook that can contain several hooks and iterates through them at each event.""" + + def __init__(self, *hooks): + self.hooks = hooks + + def init_hook(self, module): + for hook in self.hooks: + module = hook.init_hook(module) + return module + + def pre_forward(self, module, *args, **kwargs): + for hook in self.hooks: + args, kwargs = hook.pre_forward(module, *args, **kwargs) + return args, kwargs + + def post_forward(self, module, output): + for hook in self.hooks: + output = hook.post_forward(module, output) + return output + + def detach_hook(self, module): + for hook in self.hooks: + module = hook.detach_hook(module) + return module + + +class LayerwiseUpcastingHook(ModelHook): + r""" + A hook that cast the input tensors and torch.nn.Module to a pre-specified dtype before the forward pass + and cast the module back to the original dtype after the forward pass. This is useful when a model is + loaded/stored in a lower precision dtype but performs computation in a higher precision dtype. This + process may lead to quality loss in the output, but can significantly reduce the memory footprint. + """ + + def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> None: + self.storage_dtype = storage_dtype + self.compute_dtype = compute_dtype + + def init_hook(self, module: torch.nn.Module): + module.to(dtype=self.storage_dtype) + return module + + @torch._dynamo.disable(recursive=False) + def pre_forward(self, module: torch.nn.Module, *args, **kwargs): + module.to(dtype=self.compute_dtype) + # How do we account for LongTensor, BoolTensor, etc.? + # args = tuple(align_maybe_tensor_dtype(arg, self.compute_dtype) for arg in args) + # kwargs = {k: align_maybe_tensor_dtype(v, self.compute_dtype) for k, v in kwargs.items()} + return args, kwargs + + @torch._dynamo.disable(recursive=False) + def post_forward(self, module: torch.nn.Module, output): + module.to(dtype=self.storage_dtype) + return output + + +def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False): + r""" + Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove + this behavior and restore the original `forward` method, use `remove_hook_from_module`. + + If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks + together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class. + + Args: + module (`torch.nn.Module`): + The module to attach a hook to. + hook (`ModelHook`): + The hook to attach. + append (`bool`, *optional*, defaults to `False`): + Whether the hook should be chained with an existing one (if module already contains a hook) or not. + Returns: + `torch.nn.Module`: + The same module, with the hook attached (the module is modified in place, so the result can be discarded). + """ + original_hook = hook + + if append and getattr(module, "_diffusers_hook", None) is not None: + old_hook = module._diffusers_hook + remove_hook_from_module(module) + hook = SequentialHook(old_hook, hook) + + if hasattr(module, "_diffusers_hook") and hasattr(module, "_old_forward"): + # If we already put some hook on this module, we replace it with the new one. + old_forward = module._old_forward + else: + old_forward = module.forward + module._old_forward = old_forward + + module = hook.init_hook(module) + module._diffusers_hook = hook + + if hasattr(original_hook, "new_forward"): + new_forward = original_hook.new_forward + else: + + def new_forward(module, *args, **kwargs): + args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) + output = module._old_forward(*args, **kwargs) + return module._diffusers_hook.post_forward(module, output) + + # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail. + # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409 + if "GraphModuleImpl" in str(type(module)): + module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward) + else: + module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward) + + return module + + +def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> torch.nn.Module: + """ + Removes any hook attached to a module via `add_hook_to_module`. + Args: + module (`torch.nn.Module`): + The module to attach a hook to. + recurse (`bool`, defaults to `False`): + Whether to remove the hooks recursively + Returns: + `torch.nn.Module`: + The same module, with the hook detached (the module is modified in place, so the result can be discarded). + """ + + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook.detach_hook(module) + delattr(module, "_diffusers_hook") + + if hasattr(module, "_old_forward"): + # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail. + # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409 + if "GraphModuleImpl" in str(type(module)): + module.__class__.forward = module._old_forward + else: + module.forward = module._old_forward + delattr(module, "_old_forward") + + if recurse: + for child in module.children(): + remove_hook_from_module(child, recurse) + + return module + + +def align_maybe_tensor_dtype(input: Any, dtype: torch.dtype) -> Any: + r""" + Aligns the dtype of a tensor or a list of tensors to a given dtype. + Args: + input (`Any`): + The input tensor, list of tensors, or dictionary of tensors to align. If the input is neither + of these types, it will be returned as is. + dtype (`torch.dtype`): + The dtype to align the tensor(s) to. + Returns: + `Any`: + The tensor or list of tensors aligned to the given dtype. + """ + if isinstance(input, torch.Tensor): + return input.to(dtype=dtype) + if isinstance(input, (list, tuple)): + return [align_maybe_tensor_dtype(t, dtype) for t in input] + if isinstance(input, dict): + return {k: align_maybe_tensor_dtype(v, dtype) for k, v in input.items()} + return input + + +class LayerwiseUpcastingGranualarity(str, Enum): + DIFFUSERS_MODEL = "diffusers_model" + DIFFUSERS_LAYER = "diffusers_layer" + PYTORCH_LAYER = "pytorch_layer" + +# fmt: off +_SUPPORTED_PYTORCH_LAYERS = [ + torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, + torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, + torch.nn.Linear, +] +# fmt: on + + +def apply_layerwise_upcasting_hook(module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> torch.nn.Module: + r""" + Applies a `LayerwiseUpcastingHook` to a given module. + + Args: + module (`torch.nn.Module`): + The module to attach the hook to. + storage_dtype (`torch.dtype`): + The dtype to cast the module to before the forward pass. + compute_dtype (`torch.dtype`): + The dtype to cast the module to during the forward pass. + + Returns: + `torch.nn.Module`: + The same module, with the hook attached (the module is modified in place, so the result can be discarded). + """ + hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype) + return add_hook_to_module(module, hook, append=True) + + +def apply_layerwise_upcasting( + module: torch.nn.Module, + storage_dtype: torch.dtype, + compute_dtype: torch.dtype, + granularity: LayerwiseUpcastingGranualarity = LayerwiseUpcastingGranualarity.PYTORCH_LAYER, + skip_modules_pattern: List[str] = [], + skip_modules_classes: List[Type[torch.nn.Module]] = [], +) -> torch.nn.Module: + if granularity == LayerwiseUpcastingGranualarity.DIFFUSERS_MODEL: + return _apply_layerwise_upcasting_diffusers_model(module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes) + if granularity == LayerwiseUpcastingGranualarity.DIFFUSERS_LAYER: + raise NotImplementedError(f"{LayerwiseUpcastingGranualarity.DIFFUSERS_LAYER} is not yet supported") + if granularity == LayerwiseUpcastingGranualarity.PYTORCH_LAYER: + return _apply_layerwise_upcasting_pytorch_layer(module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes) + + +def _apply_layerwise_upcasting_diffusers_model( + module: torch.nn.Module, + storage_dtype: torch.dtype, + compute_dtype: torch.dtype, + skip_modules_pattern: List[str] = [], + skip_modules_classes: List[Type[torch.nn.Module]] = [], +) -> torch.nn.Module: + from .modeling_utils import ModelMixin + + for name, submodule in module.named_modules(): + if ( + any(re.search(pattern, name) for pattern in skip_modules_pattern) + or any(isinstance(submodule, module_class) for module_class in skip_modules_classes) + or not isinstance(submodule, ModelMixin) + ): + logger.debug(f"Skipping layerwise upcasting for layer \"{name}\"") + continue + logger.debug(f"Applying layerwise upcasting to layer \"{name}\"") + apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype) + return module + + +def _apply_layerwise_upcasting_pytorch_layer( + module: torch.nn.Module, + storage_dtype: torch.dtype, + compute_dtype: torch.dtype, + skip_modules_pattern: List[str] = [], + skip_modules_classes: List[Type[torch.nn.Module]] = [], +) -> torch.nn.Module: + for name, submodule in module.named_modules(): + if ( + any(re.search(pattern, name) for pattern in skip_modules_pattern) + or any(isinstance(submodule, module_class) for module_class in skip_modules_classes) + or not isinstance(submodule, tuple(_SUPPORTED_PYTORCH_LAYERS)) + ): + logger.debug(f"Skipping layerwise upcasting for layer \"{name}\"") + continue + logger.debug(f"Applying layerwise upcasting to layer \"{name}\"") + apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype) + return module From 42046c090f811d7fb295128bce4826cab5540984 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 23 Dec 2024 00:17:34 +0100 Subject: [PATCH 02/45] update --- src/diffusers/models/hooks.py | 69 ++++++++++++++++++++++++++++++++--- 1 file changed, 63 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/hooks.py b/src/diffusers/models/hooks.py index dd9674fbe30a..3fa52d91d1bd 100644 --- a/src/diffusers/models/hooks.py +++ b/src/diffusers/models/hooks.py @@ -21,6 +21,8 @@ import torch from ..utils import get_logger +from .attention import FeedForward, LuminaFeedForward +from .embeddings import LuminaPatchEmbed, CogVideoXPatchEmbed, CogView3PlusPatchEmbed, TimestepEmbedding, HunyuanDiTAttentionPool, AttentionPooling, MochiAttentionPool, GLIGENTextBoundingboxProjection, PixArtAlphaTextProjection logger = get_logger(__name__) # pylint: disable=invalid-name @@ -249,16 +251,58 @@ def align_maybe_tensor_dtype(input: Any, dtype: torch.dtype) -> Any: class LayerwiseUpcastingGranualarity(str, Enum): + r""" + An enumeration class that defines the granularity of the layerwise upcasting process. + + Granularity can be one of the following: + - `DIFFUSERS_MODEL`: + Applies layerwise upcasting to the entire model at the highest diffusers modeling level. This + will cast all the layers of model to the specified storage dtype. This results in the lowest + memory usage for storing the model in memory, but may incur significant loss in quality because + layers that perform normalization with learned parameters (e.g., RMSNorm with elementwise affinity) + are cast to a lower dtype, but this is known to cause quality issues. This method will not reduce the + memory required for the forward pass (which comprises of intermediate activations and gradients) of a + given modeling component, but may be useful in cases like lowering the memory footprint of text + encoders in a pipeline. + - `DIFFUSERS_BLOCK`: + TODO??? + - `DIFFUSERS_LAYER`: + Applies layerwise upcasting to the lower-level diffusers layers of the model. This is more granular + than the `DIFFUSERS_MODEL` level, but less granular than the `PYTORCH_LAYER` level. This method is + applied to only those layers that are a group of linear layers, while excluding precision-critical + layers like modulation and normalization layers. + - `PYTORCH_LAYER`: + Applies layerwise upcasting to lower-level PyTorch primitive layers of the model. This is the most + granular level of layerwise upcasting. The memory footprint for inference and training is greatly + reduced, while also ensuring important operations like normalization with learned parameters remain + unaffected from the downcasting/upcasting process, by default. As not all parameters are casted to + lower precision, the memory footprint for storing the model may be slightly higher than the alternatives. + This method causes the highest number of casting operations, which may contribute to a slight increase + in the overall computation time. + + Note: try and ensure that precision-critical layers like modulation and normalization layers are not casted + to lower precision, as this may lead to significant quality loss. + """ + DIFFUSERS_MODEL = "diffusers_model" DIFFUSERS_LAYER = "diffusers_layer" PYTORCH_LAYER = "pytorch_layer" # fmt: off +_SUPPORTED_DIFFUSERS_LAYERS = [ + AttentionPooling, MochiAttentionPool, HunyuanDiTAttentionPool, + CogVideoXPatchEmbed, CogView3PlusPatchEmbed, LuminaPatchEmbed, + TimestepEmbedding, GLIGENTextBoundingboxProjection, PixArtAlphaTextProjection, + FeedForward, LuminaFeedForward, +] + _SUPPORTED_PYTORCH_LAYERS = [ torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, torch.nn.Linear, ] + +_DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN = ["pos_embed", "patch_embed", "norm"] # fmt: on @@ -291,9 +335,9 @@ def apply_layerwise_upcasting( skip_modules_classes: List[Type[torch.nn.Module]] = [], ) -> torch.nn.Module: if granularity == LayerwiseUpcastingGranualarity.DIFFUSERS_MODEL: - return _apply_layerwise_upcasting_diffusers_model(module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes) + return _apply_layerwise_upcasting_diffusers_model(module, storage_dtype, compute_dtype) if granularity == LayerwiseUpcastingGranualarity.DIFFUSERS_LAYER: - raise NotImplementedError(f"{LayerwiseUpcastingGranualarity.DIFFUSERS_LAYER} is not yet supported") + return _apply_layerwise_upcasting_diffusers_layer(module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes) if granularity == LayerwiseUpcastingGranualarity.PYTORCH_LAYER: return _apply_layerwise_upcasting_pytorch_layer(module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes) @@ -302,16 +346,29 @@ def _apply_layerwise_upcasting_diffusers_model( module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, - skip_modules_pattern: List[str] = [], - skip_modules_classes: List[Type[torch.nn.Module]] = [], ) -> torch.nn.Module: from .modeling_utils import ModelMixin + if not isinstance(module, ModelMixin): + raise ValueError("The input module must be an instance of ModelMixin") + + logger.debug(f"Applying layerwise upcasting to model \"{module.__class__.__name__}\"") + apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype) + return module + + +def _apply_layerwise_upcasting_diffusers_layer( + module: torch.nn.Module, + storage_dtype: torch.dtype, + compute_dtype: torch.dtype, + skip_modules_pattern: List[str] = _DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN, + skip_modules_classes: List[Type[torch.nn.Module]] = [], +) -> torch.nn.Module: for name, submodule in module.named_modules(): if ( any(re.search(pattern, name) for pattern in skip_modules_pattern) or any(isinstance(submodule, module_class) for module_class in skip_modules_classes) - or not isinstance(submodule, ModelMixin) + or not isinstance(submodule, tuple(_SUPPORTED_DIFFUSERS_LAYERS)) ): logger.debug(f"Skipping layerwise upcasting for layer \"{name}\"") continue @@ -324,7 +381,7 @@ def _apply_layerwise_upcasting_pytorch_layer( module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, - skip_modules_pattern: List[str] = [], + skip_modules_pattern: List[str] = _DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN, skip_modules_classes: List[Type[torch.nn.Module]] = [], ) -> torch.nn.Module: for name, submodule in module.named_modules(): From 7dc739b1ced0ff162691762cc559cbfe261534b9 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 23 Dec 2024 00:17:49 +0100 Subject: [PATCH 03/45] make style --- src/diffusers/models/hooks.py | 109 +++++++++++++++++++++------------- 1 file changed, 67 insertions(+), 42 deletions(-) diff --git a/src/diffusers/models/hooks.py b/src/diffusers/models/hooks.py index 3fa52d91d1bd..16cf57408276 100644 --- a/src/diffusers/models/hooks.py +++ b/src/diffusers/models/hooks.py @@ -14,15 +14,24 @@ import functools import re -from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, Tuple, List, Type +from typing import Any, Dict, List, Tuple, Type import torch from ..utils import get_logger from .attention import FeedForward, LuminaFeedForward -from .embeddings import LuminaPatchEmbed, CogVideoXPatchEmbed, CogView3PlusPatchEmbed, TimestepEmbedding, HunyuanDiTAttentionPool, AttentionPooling, MochiAttentionPool, GLIGENTextBoundingboxProjection, PixArtAlphaTextProjection +from .embeddings import ( + AttentionPooling, + CogVideoXPatchEmbed, + CogView3PlusPatchEmbed, + GLIGENTextBoundingboxProjection, + HunyuanDiTAttentionPool, + LuminaPatchEmbed, + MochiAttentionPool, + PixArtAlphaTextProjection, + TimestepEmbedding, +) logger = get_logger(__name__) # pylint: disable=invalid-name @@ -38,6 +47,7 @@ class ModelHook: def init_hook(self, module: torch.nn.Module) -> torch.nn.Module: r""" Hook that is executed when a model is initialized. + Args: module (`torch.nn.Module`): The module attached to this hook. @@ -47,6 +57,7 @@ def init_hook(self, module: torch.nn.Module) -> torch.nn.Module: def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: r""" Hook that is executed just before the forward method of the model. + Args: module (`torch.nn.Module`): The module whose forward pass will be executed just after this event. @@ -63,6 +74,7 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[A def post_forward(self, module: torch.nn.Module, output: Any) -> Any: r""" Hook that is executed just after the forward method of the model. + Args: module (`torch.nn.Module`): The module whose forward pass been executed just before this event. @@ -76,6 +88,7 @@ def post_forward(self, module: torch.nn.Module, output: Any) -> Any: def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: r""" Hook that is executed when the hook is detached from a module. + Args: module (`torch.nn.Module`): The module detached from this hook. @@ -112,10 +125,10 @@ def detach_hook(self, module): class LayerwiseUpcastingHook(ModelHook): r""" - A hook that cast the input tensors and torch.nn.Module to a pre-specified dtype before the forward pass - and cast the module back to the original dtype after the forward pass. This is useful when a model is - loaded/stored in a lower precision dtype but performs computation in a higher precision dtype. This - process may lead to quality loss in the output, but can significantly reduce the memory footprint. + A hook that cast the input tensors and torch.nn.Module to a pre-specified dtype before the forward pass and cast + the module back to the original dtype after the forward pass. This is useful when a model is loaded/stored in a + lower precision dtype but performs computation in a higher precision dtype. This process may lead to quality loss + in the output, but can significantly reduce the memory footprint. """ def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> None: @@ -144,10 +157,14 @@ def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = r""" Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove this behavior and restore the original `forward` method, use `remove_hook_from_module`. + + If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class. + + Args: module (`torch.nn.Module`): The module to attach a hook to. @@ -198,6 +215,7 @@ def new_forward(module, *args, **kwargs): def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> torch.nn.Module: """ Removes any hook attached to a module via `add_hook_to_module`. + Args: module (`torch.nn.Module`): The module to attach a hook to. @@ -231,10 +249,11 @@ def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> t def align_maybe_tensor_dtype(input: Any, dtype: torch.dtype) -> Any: r""" Aligns the dtype of a tensor or a list of tensors to a given dtype. + Args: input (`Any`): - The input tensor, list of tensors, or dictionary of tensors to align. If the input is neither - of these types, it will be returned as is. + The input tensor, list of tensors, or dictionary of tensors to align. If the input is neither of these + types, it will be returned as is. dtype (`torch.dtype`): The dtype to align the tensor(s) to. Returns: @@ -256,38 +275,38 @@ class LayerwiseUpcastingGranualarity(str, Enum): Granularity can be one of the following: - `DIFFUSERS_MODEL`: - Applies layerwise upcasting to the entire model at the highest diffusers modeling level. This - will cast all the layers of model to the specified storage dtype. This results in the lowest - memory usage for storing the model in memory, but may incur significant loss in quality because - layers that perform normalization with learned parameters (e.g., RMSNorm with elementwise affinity) - are cast to a lower dtype, but this is known to cause quality issues. This method will not reduce the - memory required for the forward pass (which comprises of intermediate activations and gradients) of a - given modeling component, but may be useful in cases like lowering the memory footprint of text - encoders in a pipeline. + Applies layerwise upcasting to the entire model at the highest diffusers modeling level. This will cast all + the layers of model to the specified storage dtype. This results in the lowest memory usage for storing the + model in memory, but may incur significant loss in quality because layers that perform normalization with + learned parameters (e.g., RMSNorm with elementwise affinity) are cast to a lower dtype, but this is known + to cause quality issues. This method will not reduce the memory required for the forward pass (which + comprises of intermediate activations and gradients) of a given modeling component, but may be useful in + cases like lowering the memory footprint of text encoders in a pipeline. - `DIFFUSERS_BLOCK`: TODO??? - `DIFFUSERS_LAYER`: - Applies layerwise upcasting to the lower-level diffusers layers of the model. This is more granular - than the `DIFFUSERS_MODEL` level, but less granular than the `PYTORCH_LAYER` level. This method is - applied to only those layers that are a group of linear layers, while excluding precision-critical - layers like modulation and normalization layers. + Applies layerwise upcasting to the lower-level diffusers layers of the model. This is more granular than + the `DIFFUSERS_MODEL` level, but less granular than the `PYTORCH_LAYER` level. This method is applied to + only those layers that are a group of linear layers, while excluding precision-critical layers like + modulation and normalization layers. - `PYTORCH_LAYER`: - Applies layerwise upcasting to lower-level PyTorch primitive layers of the model. This is the most - granular level of layerwise upcasting. The memory footprint for inference and training is greatly - reduced, while also ensuring important operations like normalization with learned parameters remain - unaffected from the downcasting/upcasting process, by default. As not all parameters are casted to - lower precision, the memory footprint for storing the model may be slightly higher than the alternatives. - This method causes the highest number of casting operations, which may contribute to a slight increase - in the overall computation time. - - Note: try and ensure that precision-critical layers like modulation and normalization layers are not casted - to lower precision, as this may lead to significant quality loss. + Applies layerwise upcasting to lower-level PyTorch primitive layers of the model. This is the most granular + level of layerwise upcasting. The memory footprint for inference and training is greatly reduced, while + also ensuring important operations like normalization with learned parameters remain unaffected from the + downcasting/upcasting process, by default. As not all parameters are casted to lower precision, the memory + footprint for storing the model may be slightly higher than the alternatives. This method causes the + highest number of casting operations, which may contribute to a slight increase in the overall computation + time. + + Note: try and ensure that precision-critical layers like modulation and normalization layers are not casted to + lower precision, as this may lead to significant quality loss. """ - + DIFFUSERS_MODEL = "diffusers_model" DIFFUSERS_LAYER = "diffusers_layer" PYTORCH_LAYER = "pytorch_layer" + # fmt: off _SUPPORTED_DIFFUSERS_LAYERS = [ AttentionPooling, MochiAttentionPool, HunyuanDiTAttentionPool, @@ -306,10 +325,12 @@ class LayerwiseUpcastingGranualarity(str, Enum): # fmt: on -def apply_layerwise_upcasting_hook(module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> torch.nn.Module: +def apply_layerwise_upcasting_hook( + module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype +) -> torch.nn.Module: r""" Applies a `LayerwiseUpcastingHook` to a given module. - + Args: module (`torch.nn.Module`): The module to attach the hook to. @@ -317,7 +338,7 @@ def apply_layerwise_upcasting_hook(module: torch.nn.Module, storage_dtype: torch The dtype to cast the module to before the forward pass. compute_dtype (`torch.dtype`): The dtype to cast the module to during the forward pass. - + Returns: `torch.nn.Module`: The same module, with the hook attached (the module is modified in place, so the result can be discarded). @@ -337,9 +358,13 @@ def apply_layerwise_upcasting( if granularity == LayerwiseUpcastingGranualarity.DIFFUSERS_MODEL: return _apply_layerwise_upcasting_diffusers_model(module, storage_dtype, compute_dtype) if granularity == LayerwiseUpcastingGranualarity.DIFFUSERS_LAYER: - return _apply_layerwise_upcasting_diffusers_layer(module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes) + return _apply_layerwise_upcasting_diffusers_layer( + module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes + ) if granularity == LayerwiseUpcastingGranualarity.PYTORCH_LAYER: - return _apply_layerwise_upcasting_pytorch_layer(module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes) + return _apply_layerwise_upcasting_pytorch_layer( + module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes + ) def _apply_layerwise_upcasting_diffusers_model( @@ -352,7 +377,7 @@ def _apply_layerwise_upcasting_diffusers_model( if not isinstance(module, ModelMixin): raise ValueError("The input module must be an instance of ModelMixin") - logger.debug(f"Applying layerwise upcasting to model \"{module.__class__.__name__}\"") + logger.debug(f'Applying layerwise upcasting to model "{module.__class__.__name__}"') apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype) return module @@ -370,9 +395,9 @@ def _apply_layerwise_upcasting_diffusers_layer( or any(isinstance(submodule, module_class) for module_class in skip_modules_classes) or not isinstance(submodule, tuple(_SUPPORTED_DIFFUSERS_LAYERS)) ): - logger.debug(f"Skipping layerwise upcasting for layer \"{name}\"") + logger.debug(f'Skipping layerwise upcasting for layer "{name}"') continue - logger.debug(f"Applying layerwise upcasting to layer \"{name}\"") + logger.debug(f'Applying layerwise upcasting to layer "{name}"') apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype) return module @@ -390,8 +415,8 @@ def _apply_layerwise_upcasting_pytorch_layer( or any(isinstance(submodule, module_class) for module_class in skip_modules_classes) or not isinstance(submodule, tuple(_SUPPORTED_PYTORCH_LAYERS)) ): - logger.debug(f"Skipping layerwise upcasting for layer \"{name}\"") + logger.debug(f'Skipping layerwise upcasting for layer "{name}"') continue - logger.debug(f"Applying layerwise upcasting to layer \"{name}\"") + logger.debug(f'Applying layerwise upcasting to layer "{name}"') apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype) return module From 1fa4ee5025a010feeeb9a0d51f145ba34cc308c5 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 23 Dec 2024 01:20:21 +0100 Subject: [PATCH 04/45] remove dynamo disable --- src/diffusers/models/hooks.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/models/hooks.py b/src/diffusers/models/hooks.py index 16cf57408276..c8758e37e5f8 100644 --- a/src/diffusers/models/hooks.py +++ b/src/diffusers/models/hooks.py @@ -139,7 +139,6 @@ def init_hook(self, module: torch.nn.Module): module.to(dtype=self.storage_dtype) return module - @torch._dynamo.disable(recursive=False) def pre_forward(self, module: torch.nn.Module, *args, **kwargs): module.to(dtype=self.compute_dtype) # How do we account for LongTensor, BoolTensor, etc.? @@ -147,7 +146,6 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs): # kwargs = {k: align_maybe_tensor_dtype(v, self.compute_dtype) for k, v in kwargs.items()} return args, kwargs - @torch._dynamo.disable(recursive=False) def post_forward(self, module: torch.nn.Module, output): module.to(dtype=self.storage_dtype) return output From da4907ea5aaf7bfeb3afd44832531d17518d3e44 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 23 Dec 2024 01:20:51 +0100 Subject: [PATCH 05/45] add coauthor Co-Authored-By: Dhruv Nair From bc2ada4dd7a756ffcd4d0c9af5202430502c9878 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 23 Dec 2024 01:33:07 +0100 Subject: [PATCH 06/45] update --- src/diffusers/models/hooks.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/hooks.py b/src/diffusers/models/hooks.py index c8758e37e5f8..8a83d47e1f69 100644 --- a/src/diffusers/models/hooks.py +++ b/src/diffusers/models/hooks.py @@ -267,7 +267,7 @@ def align_maybe_tensor_dtype(input: Any, dtype: torch.dtype) -> Any: return input -class LayerwiseUpcastingGranualarity(str, Enum): +class LayerwiseUpcastingGranularity(str, Enum): r""" An enumeration class that defines the granularity of the layerwise upcasting process. @@ -349,17 +349,17 @@ def apply_layerwise_upcasting( module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, - granularity: LayerwiseUpcastingGranualarity = LayerwiseUpcastingGranualarity.PYTORCH_LAYER, + granularity: LayerwiseUpcastingGranularity = LayerwiseUpcastingGranularity.PYTORCH_LAYER, skip_modules_pattern: List[str] = [], skip_modules_classes: List[Type[torch.nn.Module]] = [], ) -> torch.nn.Module: - if granularity == LayerwiseUpcastingGranualarity.DIFFUSERS_MODEL: + if granularity == LayerwiseUpcastingGranularity.DIFFUSERS_MODEL: return _apply_layerwise_upcasting_diffusers_model(module, storage_dtype, compute_dtype) - if granularity == LayerwiseUpcastingGranualarity.DIFFUSERS_LAYER: + if granularity == LayerwiseUpcastingGranularity.DIFFUSERS_LAYER: return _apply_layerwise_upcasting_diffusers_layer( module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes ) - if granularity == LayerwiseUpcastingGranualarity.PYTORCH_LAYER: + if granularity == LayerwiseUpcastingGranularity.PYTORCH_LAYER: return _apply_layerwise_upcasting_pytorch_layer( module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes ) From 7c31bb03f34b1d2198a3a80a35330a28dbe819c9 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 2 Jan 2025 21:00:31 +0100 Subject: [PATCH 07/45] update --- src/diffusers/models/hooks.py | 31 +------------------------------ 1 file changed, 1 insertion(+), 30 deletions(-) diff --git a/src/diffusers/models/hooks.py b/src/diffusers/models/hooks.py index 8a83d47e1f69..7d363cfaf84d 100644 --- a/src/diffusers/models/hooks.py +++ b/src/diffusers/models/hooks.py @@ -272,19 +272,8 @@ class LayerwiseUpcastingGranularity(str, Enum): An enumeration class that defines the granularity of the layerwise upcasting process. Granularity can be one of the following: - - `DIFFUSERS_MODEL`: - Applies layerwise upcasting to the entire model at the highest diffusers modeling level. This will cast all - the layers of model to the specified storage dtype. This results in the lowest memory usage for storing the - model in memory, but may incur significant loss in quality because layers that perform normalization with - learned parameters (e.g., RMSNorm with elementwise affinity) are cast to a lower dtype, but this is known - to cause quality issues. This method will not reduce the memory required for the forward pass (which - comprises of intermediate activations and gradients) of a given modeling component, but may be useful in - cases like lowering the memory footprint of text encoders in a pipeline. - - `DIFFUSERS_BLOCK`: - TODO??? - `DIFFUSERS_LAYER`: - Applies layerwise upcasting to the lower-level diffusers layers of the model. This is more granular than - the `DIFFUSERS_MODEL` level, but less granular than the `PYTORCH_LAYER` level. This method is applied to + Applies layerwise upcasting to the lower-level diffusers layers of the model. This method is applied to only those layers that are a group of linear layers, while excluding precision-critical layers like modulation and normalization layers. - `PYTORCH_LAYER`: @@ -300,7 +289,6 @@ class LayerwiseUpcastingGranularity(str, Enum): lower precision, as this may lead to significant quality loss. """ - DIFFUSERS_MODEL = "diffusers_model" DIFFUSERS_LAYER = "diffusers_layer" PYTORCH_LAYER = "pytorch_layer" @@ -353,8 +341,6 @@ def apply_layerwise_upcasting( skip_modules_pattern: List[str] = [], skip_modules_classes: List[Type[torch.nn.Module]] = [], ) -> torch.nn.Module: - if granularity == LayerwiseUpcastingGranularity.DIFFUSERS_MODEL: - return _apply_layerwise_upcasting_diffusers_model(module, storage_dtype, compute_dtype) if granularity == LayerwiseUpcastingGranularity.DIFFUSERS_LAYER: return _apply_layerwise_upcasting_diffusers_layer( module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes @@ -365,21 +351,6 @@ def apply_layerwise_upcasting( ) -def _apply_layerwise_upcasting_diffusers_model( - module: torch.nn.Module, - storage_dtype: torch.dtype, - compute_dtype: torch.dtype, -) -> torch.nn.Module: - from .modeling_utils import ModelMixin - - if not isinstance(module, ModelMixin): - raise ValueError("The input module must be an instance of ModelMixin") - - logger.debug(f'Applying layerwise upcasting to model "{module.__class__.__name__}"') - apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype) - return module - - def _apply_layerwise_upcasting_diffusers_layer( module: torch.nn.Module, storage_dtype: torch.dtype, From 8975bbfbb7368533cd96bacab3ebedc40c6b2b7d Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 2 Jan 2025 21:10:45 +0100 Subject: [PATCH 08/45] update --- src/diffusers/__init__.py | 6 + src/diffusers/models/__init__.py | 5 + src/diffusers/models/hooks.py | 202 ++--------------- .../models/layerwise_upcasting_utils.py | 212 ++++++++++++++++++ src/diffusers/utils/dummy_pt_objects.py | 23 ++ 5 files changed, 268 insertions(+), 180 deletions(-) create mode 100644 src/diffusers/models/layerwise_upcasting_utils.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 5e9ab2a117d1..d51597a286a4 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -107,6 +107,7 @@ "I2VGenXLUNet", "Kandinsky3UNet", "LatteTransformer3DModel", + "LayerwiseUpcastingGranularity", "LTXVideoTransformer3DModel", "LuminaNextDiT2DModel", "MochiTransformer3DModel", @@ -135,6 +136,8 @@ "UNetSpatioTemporalConditionModel", "UVit2DModel", "VQModel", + "apply_layerwise_upcasting", + "apply_layerwise_upcasting_hook", ] ) _import_structure["optimization"] = [ @@ -617,6 +620,7 @@ I2VGenXLUNet, Kandinsky3UNet, LatteTransformer3DModel, + LayerwiseUpcastingGranularity, LTXVideoTransformer3DModel, LuminaNextDiT2DModel, MochiTransformer3DModel, @@ -644,6 +648,8 @@ UNetSpatioTemporalConditionModel, UVit2DModel, VQModel, + apply_layerwise_upcasting, + apply_layerwise_upcasting_hook, ) from .optimization import ( get_constant_schedule, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 01e67b01d91a..23f220ced20d 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -123,6 +123,11 @@ UNetControlNetXSModel, ) from .embeddings import ImageProjection + from .layerwise_upcasting_utils import ( + LayerwiseUpcastingGranularity, + apply_layerwise_upcasting, + apply_layerwise_upcasting_hook, + ) from .modeling_utils import ModelMixin from .transformers import ( AllegroTransformer3DModel, diff --git a/src/diffusers/models/hooks.py b/src/diffusers/models/hooks.py index 7d363cfaf84d..4b449620b13c 100644 --- a/src/diffusers/models/hooks.py +++ b/src/diffusers/models/hooks.py @@ -13,25 +13,11 @@ # limitations under the License. import functools -import re -from enum import Enum -from typing import Any, Dict, List, Tuple, Type +from typing import Any, Dict, Tuple import torch from ..utils import get_logger -from .attention import FeedForward, LuminaFeedForward -from .embeddings import ( - AttentionPooling, - CogVideoXPatchEmbed, - CogView3PlusPatchEmbed, - GLIGENTextBoundingboxProjection, - HunyuanDiTAttentionPool, - LuminaPatchEmbed, - MochiAttentionPool, - PixArtAlphaTextProjection, - TimestepEmbedding, -) logger = get_logger(__name__) # pylint: disable=invalid-name @@ -44,6 +30,8 @@ class ModelHook: with PyTorch existing hooks is that they get passed along the kwargs. """ + _is_stateful = False + def init_hook(self, module: torch.nn.Module) -> torch.nn.Module: r""" Hook that is executed when a model is initialized. @@ -95,6 +83,11 @@ def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: """ return module + def reset_state(self, module: torch.nn.Module): + if self._is_stateful: + raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.") + return module + class SequentialHook(ModelHook): r"""A hook that can contain several hooks and iterates through them at each event.""" @@ -122,34 +115,12 @@ def detach_hook(self, module): module = hook.detach_hook(module) return module - -class LayerwiseUpcastingHook(ModelHook): - r""" - A hook that cast the input tensors and torch.nn.Module to a pre-specified dtype before the forward pass and cast - the module back to the original dtype after the forward pass. This is useful when a model is loaded/stored in a - lower precision dtype but performs computation in a higher precision dtype. This process may lead to quality loss - in the output, but can significantly reduce the memory footprint. - """ - - def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> None: - self.storage_dtype = storage_dtype - self.compute_dtype = compute_dtype - - def init_hook(self, module: torch.nn.Module): - module.to(dtype=self.storage_dtype) + def reset_state(self, module): + for hook in self.hooks: + if hook._is_stateful: + hook.reset_state(module) return module - def pre_forward(self, module: torch.nn.Module, *args, **kwargs): - module.to(dtype=self.compute_dtype) - # How do we account for LongTensor, BoolTensor, etc.? - # args = tuple(align_maybe_tensor_dtype(arg, self.compute_dtype) for arg in args) - # kwargs = {k: align_maybe_tensor_dtype(v, self.compute_dtype) for k, v in kwargs.items()} - return args, kwargs - - def post_forward(self, module: torch.nn.Module, output): - module.to(dtype=self.storage_dtype) - return output - def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False): r""" @@ -244,148 +215,19 @@ def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> t return module -def align_maybe_tensor_dtype(input: Any, dtype: torch.dtype) -> Any: - r""" - Aligns the dtype of a tensor or a list of tensors to a given dtype. - - Args: - input (`Any`): - The input tensor, list of tensors, or dictionary of tensors to align. If the input is neither of these - types, it will be returned as is. - dtype (`torch.dtype`): - The dtype to align the tensor(s) to. - Returns: - `Any`: - The tensor or list of tensors aligned to the given dtype. - """ - if isinstance(input, torch.Tensor): - return input.to(dtype=dtype) - if isinstance(input, (list, tuple)): - return [align_maybe_tensor_dtype(t, dtype) for t in input] - if isinstance(input, dict): - return {k: align_maybe_tensor_dtype(v, dtype) for k, v in input.items()} - return input - - -class LayerwiseUpcastingGranularity(str, Enum): - r""" - An enumeration class that defines the granularity of the layerwise upcasting process. - - Granularity can be one of the following: - - `DIFFUSERS_LAYER`: - Applies layerwise upcasting to the lower-level diffusers layers of the model. This method is applied to - only those layers that are a group of linear layers, while excluding precision-critical layers like - modulation and normalization layers. - - `PYTORCH_LAYER`: - Applies layerwise upcasting to lower-level PyTorch primitive layers of the model. This is the most granular - level of layerwise upcasting. The memory footprint for inference and training is greatly reduced, while - also ensuring important operations like normalization with learned parameters remain unaffected from the - downcasting/upcasting process, by default. As not all parameters are casted to lower precision, the memory - footprint for storing the model may be slightly higher than the alternatives. This method causes the - highest number of casting operations, which may contribute to a slight increase in the overall computation - time. - - Note: try and ensure that precision-critical layers like modulation and normalization layers are not casted to - lower precision, as this may lead to significant quality loss. +def reset_stateful_hooks(module: torch.nn.Module, recurse: bool = False): """ - - DIFFUSERS_LAYER = "diffusers_layer" - PYTORCH_LAYER = "pytorch_layer" - - -# fmt: off -_SUPPORTED_DIFFUSERS_LAYERS = [ - AttentionPooling, MochiAttentionPool, HunyuanDiTAttentionPool, - CogVideoXPatchEmbed, CogView3PlusPatchEmbed, LuminaPatchEmbed, - TimestepEmbedding, GLIGENTextBoundingboxProjection, PixArtAlphaTextProjection, - FeedForward, LuminaFeedForward, -] - -_SUPPORTED_PYTORCH_LAYERS = [ - torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, - torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, - torch.nn.Linear, -] - -_DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN = ["pos_embed", "patch_embed", "norm"] -# fmt: on - - -def apply_layerwise_upcasting_hook( - module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype -) -> torch.nn.Module: - r""" - Applies a `LayerwiseUpcastingHook` to a given module. + Resets the state of all stateful hooks attached to a module. Args: module (`torch.nn.Module`): - The module to attach the hook to. - storage_dtype (`torch.dtype`): - The dtype to cast the module to before the forward pass. - compute_dtype (`torch.dtype`): - The dtype to cast the module to during the forward pass. - - Returns: - `torch.nn.Module`: - The same module, with the hook attached (the module is modified in place, so the result can be discarded). + The module to reset the stateful hooks from. """ - hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype) - return add_hook_to_module(module, hook, append=True) - - -def apply_layerwise_upcasting( - module: torch.nn.Module, - storage_dtype: torch.dtype, - compute_dtype: torch.dtype, - granularity: LayerwiseUpcastingGranularity = LayerwiseUpcastingGranularity.PYTORCH_LAYER, - skip_modules_pattern: List[str] = [], - skip_modules_classes: List[Type[torch.nn.Module]] = [], -) -> torch.nn.Module: - if granularity == LayerwiseUpcastingGranularity.DIFFUSERS_LAYER: - return _apply_layerwise_upcasting_diffusers_layer( - module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes - ) - if granularity == LayerwiseUpcastingGranularity.PYTORCH_LAYER: - return _apply_layerwise_upcasting_pytorch_layer( - module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes - ) - - -def _apply_layerwise_upcasting_diffusers_layer( - module: torch.nn.Module, - storage_dtype: torch.dtype, - compute_dtype: torch.dtype, - skip_modules_pattern: List[str] = _DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN, - skip_modules_classes: List[Type[torch.nn.Module]] = [], -) -> torch.nn.Module: - for name, submodule in module.named_modules(): - if ( - any(re.search(pattern, name) for pattern in skip_modules_pattern) - or any(isinstance(submodule, module_class) for module_class in skip_modules_classes) - or not isinstance(submodule, tuple(_SUPPORTED_DIFFUSERS_LAYERS)) - ): - logger.debug(f'Skipping layerwise upcasting for layer "{name}"') - continue - logger.debug(f'Applying layerwise upcasting to layer "{name}"') - apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype) - return module + if hasattr(module, "_diffusers_hook") and ( + module._diffusers_hook._is_stateful or isinstance(module._diffusers_hook, SequentialHook) + ): + module._diffusers_hook.reset_state(module) - -def _apply_layerwise_upcasting_pytorch_layer( - module: torch.nn.Module, - storage_dtype: torch.dtype, - compute_dtype: torch.dtype, - skip_modules_pattern: List[str] = _DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN, - skip_modules_classes: List[Type[torch.nn.Module]] = [], -) -> torch.nn.Module: - for name, submodule in module.named_modules(): - if ( - any(re.search(pattern, name) for pattern in skip_modules_pattern) - or any(isinstance(submodule, module_class) for module_class in skip_modules_classes) - or not isinstance(submodule, tuple(_SUPPORTED_PYTORCH_LAYERS)) - ): - logger.debug(f'Skipping layerwise upcasting for layer "{name}"') - continue - logger.debug(f'Applying layerwise upcasting to layer "{name}"') - apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype) - return module + if recurse: + for child in module.children(): + reset_stateful_hooks(child, recurse) diff --git a/src/diffusers/models/layerwise_upcasting_utils.py b/src/diffusers/models/layerwise_upcasting_utils.py new file mode 100644 index 000000000000..7db07d62dd51 --- /dev/null +++ b/src/diffusers/models/layerwise_upcasting_utils.py @@ -0,0 +1,212 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from enum import Enum +from typing import Any, List, Type + +import torch + +from ..utils import get_logger +from .attention import FeedForward, LuminaFeedForward +from .embeddings import ( + AttentionPooling, + CogVideoXPatchEmbed, + CogView3PlusPatchEmbed, + GLIGENTextBoundingboxProjection, + HunyuanDiTAttentionPool, + LuminaPatchEmbed, + MochiAttentionPool, + PixArtAlphaTextProjection, + TimestepEmbedding, +) +from .hooks import ModelHook, add_hook_to_module + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +class LayerwiseUpcastingHook(ModelHook): + r""" + A hook that cast the input tensors and torch.nn.Module to a pre-specified dtype before the forward pass and cast + the module back to the original dtype after the forward pass. This is useful when a model is loaded/stored in a + lower precision dtype but performs computation in a higher precision dtype. This process may lead to quality loss + in the output, but can significantly reduce the memory footprint. + """ + + def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> None: + self.storage_dtype = storage_dtype + self.compute_dtype = compute_dtype + + def init_hook(self, module: torch.nn.Module): + module.to(dtype=self.storage_dtype) + return module + + def pre_forward(self, module: torch.nn.Module, *args, **kwargs): + module.to(dtype=self.compute_dtype) + # How do we account for LongTensor, BoolTensor, etc.? + # args = tuple(align_maybe_tensor_dtype(arg, self.compute_dtype) for arg in args) + # kwargs = {k: align_maybe_tensor_dtype(v, self.compute_dtype) for k, v in kwargs.items()} + return args, kwargs + + def post_forward(self, module: torch.nn.Module, output): + module.to(dtype=self.storage_dtype) + return output + + +class LayerwiseUpcastingGranularity(str, Enum): + r""" + An enumeration class that defines the granularity of the layerwise upcasting process. + + Granularity can be one of the following: + - `DIFFUSERS_LAYER`: + Applies layerwise upcasting to the lower-level diffusers layers of the model. This method is applied to + only those layers that are a group of linear layers, while excluding precision-critical layers like + modulation and normalization layers. + - `PYTORCH_LAYER`: + Applies layerwise upcasting to lower-level PyTorch primitive layers of the model. This is the most granular + level of layerwise upcasting. The memory footprint for inference and training is greatly reduced, while + also ensuring important operations like normalization with learned parameters remain unaffected from the + downcasting/upcasting process, by default. As not all parameters are casted to lower precision, the memory + footprint for storing the model may be slightly higher than the alternatives. This method causes the + highest number of casting operations, which may contribute to a slight increase in the overall computation + time. + + Note: try and ensure that precision-critical layers like modulation and normalization layers are not casted to + lower precision, as this may lead to significant quality loss. + """ + + DIFFUSERS_LAYER = "diffusers_layer" + PYTORCH_LAYER = "pytorch_layer" + + +# fmt: off +_SUPPORTED_DIFFUSERS_LAYERS = [ + AttentionPooling, MochiAttentionPool, HunyuanDiTAttentionPool, + CogVideoXPatchEmbed, CogView3PlusPatchEmbed, LuminaPatchEmbed, + TimestepEmbedding, GLIGENTextBoundingboxProjection, PixArtAlphaTextProjection, + FeedForward, LuminaFeedForward, +] + +_SUPPORTED_PYTORCH_LAYERS = [ + torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, + torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, + torch.nn.Linear, +] + +_DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN = ["pos_embed", "patch_embed", "norm"] +# fmt: on + + +def apply_layerwise_upcasting( + module: torch.nn.Module, + storage_dtype: torch.dtype, + compute_dtype: torch.dtype, + granularity: LayerwiseUpcastingGranularity = LayerwiseUpcastingGranularity.PYTORCH_LAYER, + skip_modules_pattern: List[str] = [], + skip_modules_classes: List[Type[torch.nn.Module]] = [], +) -> torch.nn.Module: + if granularity == LayerwiseUpcastingGranularity.DIFFUSERS_LAYER: + return _apply_layerwise_upcasting_diffusers_layer( + module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes + ) + if granularity == LayerwiseUpcastingGranularity.PYTORCH_LAYER: + return _apply_layerwise_upcasting_pytorch_layer( + module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes + ) + + +def apply_layerwise_upcasting_hook( + module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype +) -> torch.nn.Module: + r""" + Applies a `LayerwiseUpcastingHook` to a given module. + + Args: + module (`torch.nn.Module`): + The module to attach the hook to. + storage_dtype (`torch.dtype`): + The dtype to cast the module to before the forward pass. + compute_dtype (`torch.dtype`): + The dtype to cast the module to during the forward pass. + + Returns: + `torch.nn.Module`: + The same module, with the hook attached (the module is modified in place, so the result can be discarded). + """ + hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype) + return add_hook_to_module(module, hook, append=True) + + +def _apply_layerwise_upcasting_diffusers_layer( + module: torch.nn.Module, + storage_dtype: torch.dtype, + compute_dtype: torch.dtype, + skip_modules_pattern: List[str] = _DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN, + skip_modules_classes: List[Type[torch.nn.Module]] = [], +) -> torch.nn.Module: + for name, submodule in module.named_modules(): + if ( + any(re.search(pattern, name) for pattern in skip_modules_pattern) + or any(isinstance(submodule, module_class) for module_class in skip_modules_classes) + or not isinstance(submodule, tuple(_SUPPORTED_DIFFUSERS_LAYERS)) + ): + logger.debug(f'Skipping layerwise upcasting for layer "{name}"') + continue + logger.debug(f'Applying layerwise upcasting to layer "{name}"') + apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype) + return module + + +def _apply_layerwise_upcasting_pytorch_layer( + module: torch.nn.Module, + storage_dtype: torch.dtype, + compute_dtype: torch.dtype, + skip_modules_pattern: List[str] = _DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN, + skip_modules_classes: List[Type[torch.nn.Module]] = [], +) -> torch.nn.Module: + for name, submodule in module.named_modules(): + if ( + any(re.search(pattern, name) for pattern in skip_modules_pattern) + or any(isinstance(submodule, module_class) for module_class in skip_modules_classes) + or not isinstance(submodule, tuple(_SUPPORTED_PYTORCH_LAYERS)) + ): + logger.debug(f'Skipping layerwise upcasting for layer "{name}"') + continue + logger.debug(f'Applying layerwise upcasting to layer "{name}"') + apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype) + return module + + +def align_maybe_tensor_dtype(input: Any, dtype: torch.dtype) -> Any: + r""" + Aligns the dtype of a tensor or a list of tensors to a given dtype. + + Args: + input (`Any`): + The input tensor, list of tensors, or dictionary of tensors to align. If the input is neither of these + types, it will be returned as is. + dtype (`torch.dtype`): + The dtype to align the tensor(s) to. + Returns: + `Any`: + The tensor or list of tensors aligned to the given dtype. + """ + if isinstance(input, torch.Tensor): + return input.to(dtype=dtype) + if isinstance(input, (list, tuple)): + return [align_maybe_tensor_dtype(t, dtype) for t in input] + if isinstance(input, dict): + return {k: align_maybe_tensor_dtype(v, dtype) for k, v in input.items()} + return input diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 4b6ac10385cf..c8e533a36b34 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -452,6 +452,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class LayerwiseUpcastingGranularity(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class LTXVideoTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] @@ -857,6 +872,14 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +def apply_layerwise_upcasting(*args, **kwargs): + requires_backends(apply_layerwise_upcasting, ["torch"]) + + +def apply_layerwise_upcasting_hook(*args, **kwargs): + requires_backends(apply_layerwise_upcasting_hook, ["torch"]) + + def get_constant_schedule(*args, **kwargs): requires_backends(get_constant_schedule, ["torch"]) From 341fbfca1d4a1551b20383c3e297ce129076bfbc Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 2 Jan 2025 22:09:09 +0100 Subject: [PATCH 09/45] update mixin --- .../autoencoders/autoencoder_asym_kl.py | 2 + src/diffusers/models/autoencoders/vq_model.py | 2 + .../models/layerwise_upcasting_utils.py | 39 +++++++++--- src/diffusers/models/modeling_utils.py | 63 +++++++++++++++++++ .../transformers/auraflow_transformer_2d.py | 1 + .../transformers/cogvideox_transformer_3d.py | 1 + .../models/transformers/dit_transformer_2d.py | 1 + .../transformers/hunyuan_transformer_2d.py | 2 + .../transformers/latte_transformer_3d.py | 2 + .../models/transformers/lumina_nextdit2d.py | 2 + .../transformers/pixart_transformer_2d.py | 1 + .../models/transformers/sana_transformer.py | 1 + .../models/transformers/transformer_2d.py | 1 + .../transformers/transformer_allegro.py | 3 + .../transformers/transformer_cogview3plus.py | 1 + .../models/transformers/transformer_flux.py | 1 + .../transformers/transformer_hunyuan_video.py | 1 + .../models/transformers/transformer_ltx.py | 1 + .../models/transformers/transformer_mochi.py | 1 + .../models/transformers/transformer_sd3.py | 1 + .../transformers/transformer_temporal.py | 2 + src/diffusers/models/unets/unet_1d.py | 2 + src/diffusers/models/unets/unet_2d.py | 1 + .../models/unets/unet_2d_condition.py | 1 + .../models/unets/unet_3d_condition.py | 1 + .../models/unets/unet_motion_model.py | 1 + 26 files changed, 126 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_asym_kl.py b/src/diffusers/models/autoencoders/autoencoder_asym_kl.py index 3f4d46557bf7..3c16b766c23d 100644 --- a/src/diffusers/models/autoencoders/autoencoder_asym_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_asym_kl.py @@ -60,6 +60,8 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. """ + _always_upcast_modules = ["MaskConditionDecoder"] + @register_to_config def __init__( self, diff --git a/src/diffusers/models/autoencoders/vq_model.py b/src/diffusers/models/autoencoders/vq_model.py index ae8a118d719a..82e9dd8479a7 100644 --- a/src/diffusers/models/autoencoders/vq_model.py +++ b/src/diffusers/models/autoencoders/vq_model.py @@ -71,6 +71,8 @@ class VQModel(ModelMixin, ConfigMixin): Type of normalization layer to use. Can be one of `"group"` or `"spatial"`. """ + _always_upcast_modules = ["VectorQuantizer"] + @register_to_config def __init__( self, diff --git a/src/diffusers/models/layerwise_upcasting_utils.py b/src/diffusers/models/layerwise_upcasting_utils.py index 7db07d62dd51..20cd7c65f72f 100644 --- a/src/diffusers/models/layerwise_upcasting_utils.py +++ b/src/diffusers/models/layerwise_upcasting_utils.py @@ -45,6 +45,8 @@ class LayerwiseUpcastingHook(ModelHook): in the output, but can significantly reduce the memory footprint. """ + _is_stateful = False + def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> None: self.storage_dtype = storage_dtype self.compute_dtype = compute_dtype @@ -56,8 +58,8 @@ def init_hook(self, module: torch.nn.Module): def pre_forward(self, module: torch.nn.Module, *args, **kwargs): module.to(dtype=self.compute_dtype) # How do we account for LongTensor, BoolTensor, etc.? - # args = tuple(align_maybe_tensor_dtype(arg, self.compute_dtype) for arg in args) - # kwargs = {k: align_maybe_tensor_dtype(v, self.compute_dtype) for k, v in kwargs.items()} + # args = tuple(_align_maybe_tensor_dtype(arg, self.compute_dtype) for arg in args) + # kwargs = {k: _align_maybe_tensor_dtype(v, self.compute_dtype) for k, v in kwargs.items()} return args, kwargs def post_forward(self, module: torch.nn.Module, output): @@ -105,7 +107,7 @@ class LayerwiseUpcastingGranularity(str, Enum): torch.nn.Linear, ] -_DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN = ["pos_embed", "patch_embed", "norm"] +_DEFAULT_SKIP_MODULES_PATTERN = ["pos_embed", "patch_embed", "norm"] # fmt: on @@ -114,9 +116,27 @@ def apply_layerwise_upcasting( storage_dtype: torch.dtype, compute_dtype: torch.dtype, granularity: LayerwiseUpcastingGranularity = LayerwiseUpcastingGranularity.PYTORCH_LAYER, - skip_modules_pattern: List[str] = [], + skip_modules_pattern: List[str] = _DEFAULT_SKIP_MODULES_PATTERN, skip_modules_classes: List[Type[torch.nn.Module]] = [], ) -> torch.nn.Module: + r""" + Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin but it can be any + nn.Module using diffusers layers or pytorch primitives. + + Args: + module (`torch.nn.Module`): + The module to attach the hook to. + storage_dtype (`torch.dtype`): + The dtype to cast the module to before the forward pass. + compute_dtype (`torch.dtype`): + The dtype to cast the module to during the forward pass. + granularity (`LayerwiseUpcastingGranularity`, *optional*, defaults to `LayerwiseUpcastingGranularity.PYTORCH_LAYER`): + The granularity of the layerwise upcasting process. + skip_modules_pattern (`List[str]`, defaults to `["pos_embed", "patch_embed", "norm"]`): + A list of patterns to match the names of the modules to skip during the layerwise upcasting process. + skip_modules_classes (`List[Type[torch.nn.Module]]`, defaults to `[]`): + A list of module classes to skip during the layerwise upcasting process. + """ if granularity == LayerwiseUpcastingGranularity.DIFFUSERS_LAYER: return _apply_layerwise_upcasting_diffusers_layer( module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes @@ -153,7 +173,7 @@ def _apply_layerwise_upcasting_diffusers_layer( module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, - skip_modules_pattern: List[str] = _DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN, + skip_modules_pattern: List[str] = _DEFAULT_SKIP_MODULES_PATTERN, skip_modules_classes: List[Type[torch.nn.Module]] = [], ) -> torch.nn.Module: for name, submodule in module.named_modules(): @@ -173,7 +193,7 @@ def _apply_layerwise_upcasting_pytorch_layer( module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, - skip_modules_pattern: List[str] = _DEFAULT_PYTORCH_LAYER_SKIP_MODULES_PATTERN, + skip_modules_pattern: List[str] = _DEFAULT_SKIP_MODULES_PATTERN, skip_modules_classes: List[Type[torch.nn.Module]] = [], ) -> torch.nn.Module: for name, submodule in module.named_modules(): @@ -189,7 +209,7 @@ def _apply_layerwise_upcasting_pytorch_layer( return module -def align_maybe_tensor_dtype(input: Any, dtype: torch.dtype) -> Any: +def _align_maybe_tensor_dtype(input: Any, dtype: torch.dtype) -> Any: r""" Aligns the dtype of a tensor or a list of tensors to a given dtype. @@ -199,6 +219,7 @@ def align_maybe_tensor_dtype(input: Any, dtype: torch.dtype) -> Any: types, it will be returned as is. dtype (`torch.dtype`): The dtype to align the tensor(s) to. + Returns: `Any`: The tensor or list of tensors aligned to the given dtype. @@ -206,7 +227,7 @@ def align_maybe_tensor_dtype(input: Any, dtype: torch.dtype) -> Any: if isinstance(input, torch.Tensor): return input.to(dtype=dtype) if isinstance(input, (list, tuple)): - return [align_maybe_tensor_dtype(t, dtype) for t in input] + return [_align_maybe_tensor_dtype(t, dtype) for t in input] if isinstance(input, dict): - return {k: align_maybe_tensor_dtype(v, dtype) for k, v in input.items()} + return {k: _align_maybe_tensor_dtype(v, dtype) for k, v in input.items()} return input diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index d6efcc736487..0df0aa051531 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -56,6 +56,7 @@ load_or_create_model_card, populate_model_card, ) +from .layerwise_upcasting_utils import LayerwiseUpcastingGranularity, apply_layerwise_upcasting from .model_loading_utils import ( _determine_device_map, _fetch_index_file, @@ -150,6 +151,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): _keys_to_ignore_on_load_unexpected = None _no_split_modules = None _keep_in_fp32_modules = None + _always_upcast_modules = None def __init__(self): super().__init__() @@ -314,6 +316,67 @@ def disable_xformers_memory_efficient_attention(self) -> None: """ self.set_use_memory_efficient_attention_xformers(False) + def enable_layerwise_upcasting( + self, + storage_dtype: torch.dtype = torch.float8_e4m3fn, + compute_dtype: Optional[torch.dtype] = None, + granularity: LayerwiseUpcastingGranularity = LayerwiseUpcastingGranularity.PYTORCH_LAYER, + ) -> None: + r""" + Activates layerwise upcasting for the current model. + + Layerwise upcasting is a technique that casts the model weights to a lower precision dtype for storage but + upcasts them on-the-fly to a higher precision dtype for computation. This process can significantly reduce the + memory footprint from model weights, but may lead to some quality degradation in the outputs. Most degradations + are negligible, mostly stemming from weight casting in normalization and modulation layers. + + By default, most models in diffusers set the `_always_upcast_modules` attribute to ignore patch embedding, + positional embedding and normalization layers. This is because these layers are most likely precision-critical + for quality. If you wish to change this behavior, you can set the `_always_upcast_modules` attribute to `None`, + or call [`~apply_layerwise_upcasting`] with custom arguments. + + Example: + Using [`~models.ModelMixin.enable_layerwise_upcasting`]: + + ```python + >>> from diffusers import CogVideoXTransformer3DModel, apply_layerwise_upcasting + + >>> transformer = CogVideoXTransformer3DModel.from_pretrained( + ... "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16 + ... ) + + >>> # Enable layerwise upcasting via the model, which ignores certain modules by default + >>> transformer.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) + + >>> # Or, enable layerwise upcasting with custom arguments via the `apply_layerwise_upcasting` function + >>> apply_layerwise_upcasting( + ... transformer, torch.float8_e4m3fn, torch.bfloat16, skip_modules_pattern=["patch_embed", "norm.*"] + ... ) + ``` + + Args: + storage_dtype (`torch.dtype`): + The dtype to which the model should be cast for storage. + compute_dtype (`torch.dtype`): + The dtype to which the model weights should be cast during the forward pass. + granularity (`LayerwiseUpcastingGranularity`, defaults to "pytorch_layer"): + The granularity of the layerwise upcasting process. Read the documentation of + [`~LayerwiseUpcastingGranularity`] for more information. + """ + + skip_modules_pattern = [] + if self._keep_in_fp32_modules is not None: + skip_modules_pattern.extend(self._keep_in_fp32_modules) + if self._always_upcast_modules is not None: + skip_modules_pattern.extend(self._always_upcast_modules) + skip_modules_pattern = list(set(skip_modules_pattern)) + + if compute_dtype is None: + logger.info("`compute_dtype` not provided when enabling layerwise upcasting. Using `storage_dtype`.") + compute_dtype = self.dtype + + apply_layerwise_upcasting(self, storage_dtype, compute_dtype, granularity, skip_modules_pattern) + def save_pretrained( self, save_directory: Union[str, os.PathLike], diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index b3f29e6b6224..f0078344a940 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -275,6 +275,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin): """ _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"] + _always_upcast_modules = ["pos_embed", "norm.*"] _supports_gradient_checkpointing = True @register_to_config diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index b47d439774cc..6b1736a6ddcd 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -209,6 +209,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): Scaling factor to apply in 3D positional embeddings across temporal dimensions. """ + _always_upcast_modules = ["patch_embed", "norm.*"] _supports_gradient_checkpointing = True @register_to_config diff --git a/src/diffusers/models/transformers/dit_transformer_2d.py b/src/diffusers/models/transformers/dit_transformer_2d.py index f787c5279499..b4dbe43bc88a 100644 --- a/src/diffusers/models/transformers/dit_transformer_2d.py +++ b/src/diffusers/models/transformers/dit_transformer_2d.py @@ -64,6 +64,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin): A small constant added to the denominator in normalization layers to prevent division by zero. """ + _always_upcast_modules = ["pos_embed", "norm.*"] _supports_gradient_checkpointing = True @register_to_config diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py index 7f3dab220aaa..3fe0c870c624 100644 --- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py +++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py @@ -244,6 +244,8 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin): Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2 """ + _always_upcast_modules = ["pos_embed", "norm.*", "pooler"] + @register_to_config def __init__( self, diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index d34ccfd20108..7d26843ba13d 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -65,6 +65,8 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin): The number of frames in the video-like data. """ + _always_upcast_modules = ["pos_embed", "norm.*"] + @register_to_config def __init__( self, diff --git a/src/diffusers/models/transformers/lumina_nextdit2d.py b/src/diffusers/models/transformers/lumina_nextdit2d.py index d4f5b4658542..55797511dfaf 100644 --- a/src/diffusers/models/transformers/lumina_nextdit2d.py +++ b/src/diffusers/models/transformers/lumina_nextdit2d.py @@ -221,6 +221,8 @@ class LuminaNextDiT2DModel(ModelMixin, ConfigMixin): overall scale of the model's operations. """ + _always_upcast_modules = ["patch_embedder", "norm.*", "ffn_norm.*"] + @register_to_config def __init__( self, diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py index 7f145edf16fb..395eba1f4630 100644 --- a/src/diffusers/models/transformers/pixart_transformer_2d.py +++ b/src/diffusers/models/transformers/pixart_transformer_2d.py @@ -79,6 +79,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"] + _always_upcast_modules = ["pos_embed", "norm.*", "adaln_single"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index bc3877627529..bcd53a865a58 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -222,6 +222,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): _supports_gradient_checkpointing = True _no_split_modules = ["SanaTransformerBlock", "PatchEmbed"] + _always_upcast_modules = ["patch_embed", "norm.*"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py index e208a1c10ed4..94db38f46edd 100644 --- a/src/diffusers/models/transformers/transformer_2d.py +++ b/src/diffusers/models/transformers/transformer_2d.py @@ -66,6 +66,7 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin): _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock"] + _always_upcast_modules = ["latent_image_embedding", "norm.*"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index fe9c7290b063..e4388e31a696 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -221,6 +221,9 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin): Scaling factor to apply in 3D positional embeddings across time dimension. """ + _supports_gradient_checkpointing = True + _always_upcast_modules = ["pos_embed", "norm.*", "adaln_single"] + @register_to_config def __init__( self, diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py index 94d852f6df4b..9e470d46c3c8 100644 --- a/src/diffusers/models/transformers/transformer_cogview3plus.py +++ b/src/diffusers/models/transformers/transformer_cogview3plus.py @@ -166,6 +166,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = True + _always_upcast_modules = ["patch_embed", "norm.*"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index dc2eb26f9d30..503093a36e70 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -241,6 +241,7 @@ class FluxTransformer2DModel( _supports_gradient_checkpointing = True _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] + _always_upcast_modules = ["pos_embed", "norm.*"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index e3f24d97f3fa..636d52d6b843 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -542,6 +542,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, """ _supports_gradient_checkpointing = True + _always_upcast_modules = ["x_embedder", "context_embedder", "norm.*"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index a895340bd124..c65215bcad59 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -295,6 +295,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin """ _supports_gradient_checkpointing = True + _always_upcast_modules = ["norm.*"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 8763ea450253..1f2d8dcc4828 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -336,6 +336,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri _supports_gradient_checkpointing = True _no_split_modules = ["MochiTransformerBlock"] + _always_upcast_modules = ["patch_embed", "norm.*"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 415540ef7f6a..ff2f62162ae2 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -127,6 +127,7 @@ class SD3Transformer2DModel( """ _supports_gradient_checkpointing = True + _always_upcast_modules = ["pos_embed", "norm.*"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_temporal.py b/src/diffusers/models/transformers/transformer_temporal.py index 6ca42b9745fd..1e9234afba9b 100644 --- a/src/diffusers/models/transformers/transformer_temporal.py +++ b/src/diffusers/models/transformers/transformer_temporal.py @@ -67,6 +67,8 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin): The maximum length of the sequence over which to apply positional embeddings. """ + _always_upcast_modules = ["norm.*"] + @register_to_config def __init__( self, diff --git a/src/diffusers/models/unets/unet_1d.py b/src/diffusers/models/unets/unet_1d.py index 8efabd98ee7d..b7f2135a58ad 100644 --- a/src/diffusers/models/unets/unet_1d.py +++ b/src/diffusers/models/unets/unet_1d.py @@ -71,6 +71,8 @@ class UNet1DModel(ModelMixin, ConfigMixin): Experimental feature for using a UNet without upsampling. """ + _always_upcast_modules = ["norm.*"] + @register_to_config def __init__( self, diff --git a/src/diffusers/models/unets/unet_2d.py b/src/diffusers/models/unets/unet_2d.py index bec62ce5cf45..6fd2773c62da 100644 --- a/src/diffusers/models/unets/unet_2d.py +++ b/src/diffusers/models/unets/unet_2d.py @@ -90,6 +90,7 @@ class UNet2DModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = True + _always_upcast_modules = ["norm.*"] @register_to_config def __init__( diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index e488f5897ebc..991fd21ab646 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -166,6 +166,7 @@ class conditioning with `class_embed_type` equal to `None`. _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"] + _always_upcast_modules = ["norm.*"] @register_to_config def __init__( diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py index 3081fdc4700c..9257336c13a7 100644 --- a/src/diffusers/models/unets/unet_3d_condition.py +++ b/src/diffusers/models/unets/unet_3d_condition.py @@ -97,6 +97,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) """ _supports_gradient_checkpointing = False + _always_upcast_modules = ["norm.*"] @register_to_config def __init__( diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index ddc3e41c340d..bb1bea8e885d 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -1301,6 +1301,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft """ _supports_gradient_checkpointing = True + _always_upcast_modules = ["norm.*"] @register_to_config def __init__( From 5f898a1fce77f2ea70fc0f29968fc24e1ff87a36 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 2 Jan 2025 23:12:09 +0100 Subject: [PATCH 10/45] add some basic tests --- src/diffusers/models/unets/unet_2d.py | 5 +- .../models/unets/unet_3d_condition.py | 2 +- .../models/unets/unet_motion_model.py | 3 +- tests/models/test_modeling_common.py | 78 +++++++++++++++++++ 4 files changed, 84 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/unets/unet_2d.py b/src/diffusers/models/unets/unet_2d.py index 6fd2773c62da..291f466ff5e4 100644 --- a/src/diffusers/models/unets/unet_2d.py +++ b/src/diffusers/models/unets/unet_2d.py @@ -291,7 +291,8 @@ def forward( # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=self.dtype) + # TODO(aryan): Need to have this reviewed + t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb) if self.class_embedding is not None: @@ -301,7 +302,7 @@ def forward( if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) - class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) emb = emb + class_emb elif self.class_embedding is None and class_labels is not None: raise ValueError("class_embedding needs to be initialized in order to use class conditioning") diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py index 9257336c13a7..460a65af7512 100644 --- a/src/diffusers/models/unets/unet_3d_condition.py +++ b/src/diffusers/models/unets/unet_3d_condition.py @@ -97,7 +97,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) """ _supports_gradient_checkpointing = False - _always_upcast_modules = ["norm.*"] + _always_upcast_modules = ["norm.*", "time_embedding"] @register_to_config def __init__( diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index bb1bea8e885d..6f042e03f718 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -2132,7 +2132,8 @@ def forward( # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=self.dtype) + # TODO(aryan): Need to have this reviewed + t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) aug_emb = None diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 4fc14804475a..882adbde79d4 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -14,6 +14,7 @@ # limitations under the License. import copy +import gc import inspect import json import os @@ -56,6 +57,7 @@ CaptureLogger, get_python_version, is_torch_compile, + numpy_cosine_similarity_distance, require_torch_2, require_torch_accelerator_with_training, require_torch_gpu, @@ -1331,6 +1333,82 @@ def test_variant_sharded_ckpt_right_format(self): # Example: diffusion_pytorch_model.fp16-00001-of-00002.safetensors assert all(f.split(".")[1].split("-")[0] == variant for f in shard_files) + def test_layerwise_upcasting_inference(self): + torch.manual_seed(0) + config, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**config).eval() + model = model.to(torch_device) + base_slice = model(**inputs_dict)[0].flatten().detach().cpu().numpy() + + # fp16-fp32 + torch.manual_seed(0) + config, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**config).eval() + model = model.to(torch_device) + model.enable_layerwise_upcasting(storage_dtype=torch.float16, compute_dtype=torch.float32) + layerwise_upcast_slice_fp16 = model(**inputs_dict)[0].flatten().detach().cpu().numpy() + + # The precision test is not very important for fast tests. In most cases, the outputs will not be the same. + # We just want to make sure that the layerwise upcasting is working as expected. + self.assertTrue(numpy_cosine_similarity_distance(base_slice, layerwise_upcast_slice_fp16) < 1.0) + + # fp8_e4m3-fp32 + torch.manual_seed(0) + config, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**config).eval() + model = model.to(torch_device) + model.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32) + layerwise_upcast_slice_fp8_e4m3 = model(**inputs_dict)[0].flatten().detach().cpu().numpy() + + self.assertTrue(numpy_cosine_similarity_distance(base_slice, layerwise_upcast_slice_fp8_e4m3) < 1.0) + + # fp8_e5m2-fp32 + torch.manual_seed(0) + config, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**config).eval() + model = model.to(torch_device) + model.enable_layerwise_upcasting(storage_dtype=torch.float8_e5m2, compute_dtype=torch.float32) + layerwise_upcast_slice_fp8_e5m2 = model(**inputs_dict)[0].flatten().detach().cpu().numpy() + + self.assertTrue(numpy_cosine_similarity_distance(base_slice, layerwise_upcast_slice_fp8_e5m2) < 1.0) + + @require_torch_gpu + def test_layerwise_upcasting_memory(self): + # fp32 + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + + torch.manual_seed(0) + config, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**config).eval() + model = model.to(torch_device) + model(**inputs_dict) + base_memory_footprint = model.get_memory_footprint() + base_max_memory = torch.cuda.max_memory_allocated() + + model.to("cpu") + del model + + # fp8_e4m3-fp32 + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + + torch.manual_seed(0) + config, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**config).eval() + model = model.to(torch_device) + model.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32) + model(**inputs_dict) + fp8_e4m3_memory_footprint = model.get_memory_footprint() + fp8_e4m3_max_memory = torch.cuda.max_memory_allocated() + + self.assertTrue(fp8_e4m3_memory_footprint < base_memory_footprint) + self.assertTrue(fp8_e4m3_max_memory < base_max_memory) + @is_staging_test class ModelPushToHubTester(unittest.TestCase): From 558c64e34276b6ead33b1da2a14ec9fa18c6461d Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 4 Jan 2025 23:45:43 +0100 Subject: [PATCH 11/45] update --- src/diffusers/models/__init__.py | 5 +++++ src/diffusers/models/modeling_utils.py | 4 +++- src/diffusers/pipelines/latte/pipeline_latte.py | 2 +- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 23f220ced20d..74b8c8d9403f 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -51,6 +51,11 @@ _import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] _import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"] _import_structure["embeddings"] = ["ImageProjection"] + _import_structure["layerwise_upcasting_utils"] = [ + "LayerwiseUpcastingGranularity", + "apply_layerwise_upcasting", + "apply_layerwise_upcasting_hook", + ] _import_structure["modeling_utils"] = ["ModelMixin"] _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"] _import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"] diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 0df0aa051531..870f3b73a15c 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -321,6 +321,7 @@ def enable_layerwise_upcasting( storage_dtype: torch.dtype = torch.float8_e4m3fn, compute_dtype: Optional[torch.dtype] = None, granularity: LayerwiseUpcastingGranularity = LayerwiseUpcastingGranularity.PYTORCH_LAYER, + skip_modules_pattern: Optional[List[str]] = None, ) -> None: r""" Activates layerwise upcasting for the current model. @@ -364,7 +365,8 @@ def enable_layerwise_upcasting( [`~LayerwiseUpcastingGranularity`] for more information. """ - skip_modules_pattern = [] + if skip_modules_pattern is None: + skip_modules_pattern = [] if self._keep_in_fp32_modules is not None: skip_modules_pattern.extend(self._keep_in_fp32_modules) if self._always_upcast_modules is not None: diff --git a/src/diffusers/pipelines/latte/pipeline_latte.py b/src/diffusers/pipelines/latte/pipeline_latte.py index 19c4a6d1ddf9..1cfe22e2d8e8 100644 --- a/src/diffusers/pipelines/latte/pipeline_latte.py +++ b/src/diffusers/pipelines/latte/pipeline_latte.py @@ -836,7 +836,7 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - if not output_type == "latents": + if not output_type == "latent": video = self.decode_latents(latents, video_length, decode_chunk_size=14) video = self.video_processor.postprocess_video(video=video, output_type=output_type) else: From 7858f2c295f8df375cb87675af95cff0415df38c Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 12 Jan 2025 08:45:48 +0100 Subject: [PATCH 12/45] update --- src/diffusers/__init__.py | 6 - src/diffusers/hooks/__init__.py | 5 + src/diffusers/hooks/hooks.py | 164 ++++++++++++ src/diffusers/hooks/layerwise_upcasting.py | 122 +++++++++ src/diffusers/models/__init__.py | 10 - src/diffusers/models/hooks.py | 233 ------------------ .../models/layerwise_upcasting_utils.py | 233 ------------------ src/diffusers/models/modeling_utils.py | 22 +- src/diffusers/utils/dummy_pt_objects.py | 23 -- 9 files changed, 308 insertions(+), 510 deletions(-) create mode 100644 src/diffusers/hooks/__init__.py create mode 100644 src/diffusers/hooks/hooks.py create mode 100644 src/diffusers/hooks/layerwise_upcasting.py delete mode 100644 src/diffusers/models/hooks.py delete mode 100644 src/diffusers/models/layerwise_upcasting_utils.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index d51597a286a4..5e9ab2a117d1 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -107,7 +107,6 @@ "I2VGenXLUNet", "Kandinsky3UNet", "LatteTransformer3DModel", - "LayerwiseUpcastingGranularity", "LTXVideoTransformer3DModel", "LuminaNextDiT2DModel", "MochiTransformer3DModel", @@ -136,8 +135,6 @@ "UNetSpatioTemporalConditionModel", "UVit2DModel", "VQModel", - "apply_layerwise_upcasting", - "apply_layerwise_upcasting_hook", ] ) _import_structure["optimization"] = [ @@ -620,7 +617,6 @@ I2VGenXLUNet, Kandinsky3UNet, LatteTransformer3DModel, - LayerwiseUpcastingGranularity, LTXVideoTransformer3DModel, LuminaNextDiT2DModel, MochiTransformer3DModel, @@ -648,8 +644,6 @@ UNetSpatioTemporalConditionModel, UVit2DModel, VQModel, - apply_layerwise_upcasting, - apply_layerwise_upcasting_hook, ) from .optimization import ( get_constant_schedule, diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py new file mode 100644 index 000000000000..14c16c7d3236 --- /dev/null +++ b/src/diffusers/hooks/__init__.py @@ -0,0 +1,5 @@ +from ..utils import is_torch_available + + +if is_torch_available(): + from .layerwise_upcasting import apply_layerwise_upcasting, apply_layerwise_upcasting_hook diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py new file mode 100644 index 000000000000..9d61e294742f --- /dev/null +++ b/src/diffusers/hooks/hooks.py @@ -0,0 +1,164 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from typing import Any, Dict, Tuple + +import torch + +from ..utils.logging import get_logger + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +class ModelHook: + r""" + A hook that contains callbacks to be executed just before and after the forward method of a model. + """ + + _is_stateful = False + + def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: + r""" + Hook that is executed when a model is initialized. + + Args: + module (`torch.nn.Module`): + The module attached to this hook. + """ + return module + + def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module: + r""" + Hook that is executed when a model is deinitalized. + + Args: + module (`torch.nn.Module`): + The module attached to this hook. + """ + module.forward = module._old_forward + del module._old_forward + return module + + def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: + r""" + Hook that is executed just before the forward method of the model. + + Args: + module (`torch.nn.Module`): + The module whose forward pass will be executed just after this event. + args (`Tuple[Any]`): + The positional arguments passed to the module. + kwargs (`Dict[Str, Any]`): + The keyword arguments passed to the module. + + Returns: + `Tuple[Tuple[Any], Dict[Str, Any]]`: + A tuple with the treated `args` and `kwargs`. + """ + return args, kwargs + + def post_forward(self, module: torch.nn.Module, output: Any) -> Any: + r""" + Hook that is executed just after the forward method of the model. + + Args: + module (`torch.nn.Module`): + The module whose forward pass been executed just before this event. + output (`Any`): + The output of the module. + + Returns: + `Any`: The processed `output`. + """ + return output + + def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: + r""" + Hook that is executed when the hook is detached from a module. + + Args: + module (`torch.nn.Module`): + The module detached from this hook. + """ + return module + + def reset_state(self, module: torch.nn.Module): + if self._is_stateful: + raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.") + return module + + +class HookRegistry: + def __init__(self, module_ref: torch.nn.Module) -> None: + super().__init__() + + self.hooks: Dict[str, ModelHook] = {} + + self._module_ref = module_ref + self._hook_order = [] + + def register_hook(self, hook: ModelHook, name: str) -> None: + if name in self.hooks.keys(): + logger.warning(f"Hook with name {name} already exists, replacing it.") + + if hasattr(self._module_ref, "_old_forward"): + old_forward = self._module_ref._old_forward + else: + old_forward = self._module_ref.forward + self._module_ref._old_forward = self._module_ref.forward + + self._module_ref = hook.initialize_hook(self._module_ref) + + if hasattr(hook, "new_forward"): + new_forward = hook.new_forward + else: + + def new_forward(module, *args, **kwargs): + args, kwargs = hook.pre_forward(module, *args, **kwargs) + output = old_forward(*args, **kwargs) + return hook.post_forward(module, output) + + new_forward = functools.update_wrapper(new_forward, old_forward) + self._module_ref.forward = new_forward.__get__(self._module_ref) + + self.hooks[name] = hook + self._hook_order.append(name) + + def get_hook(self, name: str) -> ModelHook: + if name not in self.hooks.keys(): + raise ValueError(f"Hook with name {name} not found.") + return self.hooks[name] + + def remove_hook(self, name: str) -> None: + if name not in self.hooks.keys(): + raise ValueError(f"Hook with name {name} not found.") + self.hooks[name].deinitalize_hook(self._module_ref) + del self.hooks[name] + self._hook_order.remove(name) + + @classmethod + def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry": + if not hasattr(module, "_diffusers_hook"): + module._diffusers_hook = cls(module) + return module._diffusers_hook + + def __repr__(self) -> str: + hook_repr = "" + for i, hook_name in enumerate(self._hook_order): + hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})" + if i < len(self._hook_order) - 1: + hook_repr += "\n" + return f"HookRegistry(\n{hook_repr}\n)" diff --git a/src/diffusers/hooks/layerwise_upcasting.py b/src/diffusers/hooks/layerwise_upcasting.py new file mode 100644 index 000000000000..61dfd3fe6814 --- /dev/null +++ b/src/diffusers/hooks/layerwise_upcasting.py @@ -0,0 +1,122 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import List, Type + +import torch + +from ..utils import get_logger +from .hooks import HookRegistry, ModelHook + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +# fmt: off +_SUPPORTED_PYTORCH_LAYERS = [ + torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, + torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, + torch.nn.Linear, +] + +_DEFAULT_SKIP_MODULES_PATTERN = ["pos_embed", "patch_embed", "norm"] +# fmt: on + + +class LayerwiseUpcastingHook(ModelHook): + r""" + A hook that casts the weights of a module to a high precision dtype for computation, and to a low precision dtype + for storage. This process may lead to quality loss in the output, but can significantly reduce the memory + footprint. + """ + + _is_stateful = False + + def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> None: + self.storage_dtype = storage_dtype + self.compute_dtype = compute_dtype + + def initialize_hook(self, module: torch.nn.Module): + module.to(dtype=self.storage_dtype) + return module + + def pre_forward(self, module: torch.nn.Module, *args, **kwargs): + module.to(dtype=self.compute_dtype) + return args, kwargs + + def post_forward(self, module: torch.nn.Module, output): + module.to(dtype=self.storage_dtype) + return output + + +def apply_layerwise_upcasting( + module: torch.nn.Module, + storage_dtype: torch.dtype, + compute_dtype: torch.dtype, + skip_modules_pattern: List[str] = _DEFAULT_SKIP_MODULES_PATTERN, + skip_modules_classes: List[Type[torch.nn.Module]] = [], +) -> torch.nn.Module: + r""" + Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin but it can be any + nn.Module using diffusers layers or pytorch primitives. + + Args: + module (`torch.nn.Module`): + The module whose leaf modules will be cast to a high precision dtype for computation, and to a low + precision dtype for storage. + storage_dtype (`torch.dtype`): + The dtype to cast the module to before/after the forward pass for storage. + compute_dtype (`torch.dtype`): + The dtype to cast the module to during the forward pass for computation. + skip_modules_pattern (`List[str]`, defaults to `["pos_embed", "patch_embed", "norm"]`): + A list of patterns to match the names of the modules to skip during the layerwise upcasting process. + skip_modules_classes (`List[Type[torch.nn.Module]]`, defaults to `[]`): + A list of module classes to skip during the layerwise upcasting process. + """ + for name, submodule in module.named_modules(): + if ( + any(re.search(pattern, name) for pattern in skip_modules_pattern) + or any(isinstance(submodule, module_class) for module_class in skip_modules_classes) + or not isinstance(submodule, tuple(_SUPPORTED_PYTORCH_LAYERS)) + or len(list(submodule.children())) > 0 + ): + logger.debug(f'Skipping layerwise upcasting for layer "{name}"') + continue + logger.debug(f'Applying layerwise upcasting to layer "{name}"') + apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype) + return module + + +def apply_layerwise_upcasting_hook( + module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype +) -> torch.nn.Module: + r""" + Applies a `LayerwiseUpcastingHook` to a given module. + + Args: + module (`torch.nn.Module`): + The module to attach the hook to. + storage_dtype (`torch.dtype`): + The dtype to cast the module to before the forward pass. + compute_dtype (`torch.dtype`): + The dtype to cast the module to during the forward pass. + + Returns: + `torch.nn.Module`: + The same module, with the hook attached (the module is modified in place). + """ + registry = HookRegistry.check_if_exists_or_initialize(module) + hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype) + registry.register_hook(hook, "layerwise_upcasting") diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 74b8c8d9403f..01e67b01d91a 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -51,11 +51,6 @@ _import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] _import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"] _import_structure["embeddings"] = ["ImageProjection"] - _import_structure["layerwise_upcasting_utils"] = [ - "LayerwiseUpcastingGranularity", - "apply_layerwise_upcasting", - "apply_layerwise_upcasting_hook", - ] _import_structure["modeling_utils"] = ["ModelMixin"] _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"] _import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"] @@ -128,11 +123,6 @@ UNetControlNetXSModel, ) from .embeddings import ImageProjection - from .layerwise_upcasting_utils import ( - LayerwiseUpcastingGranularity, - apply_layerwise_upcasting, - apply_layerwise_upcasting_hook, - ) from .modeling_utils import ModelMixin from .transformers import ( AllegroTransformer3DModel, diff --git a/src/diffusers/models/hooks.py b/src/diffusers/models/hooks.py deleted file mode 100644 index 4b449620b13c..000000000000 --- a/src/diffusers/models/hooks.py +++ /dev/null @@ -1,233 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import functools -from typing import Any, Dict, Tuple - -import torch - -from ..utils import get_logger - - -logger = get_logger(__name__) # pylint: disable=invalid-name - - -# Reference: https://github.com/huggingface/accelerate/blob/ba7ab93f5e688466ea56908ea3b056fae2f9a023/src/accelerate/hooks.py -class ModelHook: - r""" - A hook that contains callbacks to be executed just before and after the forward method of a model. The difference - with PyTorch existing hooks is that they get passed along the kwargs. - """ - - _is_stateful = False - - def init_hook(self, module: torch.nn.Module) -> torch.nn.Module: - r""" - Hook that is executed when a model is initialized. - - Args: - module (`torch.nn.Module`): - The module attached to this hook. - """ - return module - - def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: - r""" - Hook that is executed just before the forward method of the model. - - Args: - module (`torch.nn.Module`): - The module whose forward pass will be executed just after this event. - args (`Tuple[Any]`): - The positional arguments passed to the module. - kwargs (`Dict[Str, Any]`): - The keyword arguments passed to the module. - Returns: - `Tuple[Tuple[Any], Dict[Str, Any]]`: - A tuple with the treated `args` and `kwargs`. - """ - return args, kwargs - - def post_forward(self, module: torch.nn.Module, output: Any) -> Any: - r""" - Hook that is executed just after the forward method of the model. - - Args: - module (`torch.nn.Module`): - The module whose forward pass been executed just before this event. - output (`Any`): - The output of the module. - Returns: - `Any`: The processed `output`. - """ - return output - - def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: - r""" - Hook that is executed when the hook is detached from a module. - - Args: - module (`torch.nn.Module`): - The module detached from this hook. - """ - return module - - def reset_state(self, module: torch.nn.Module): - if self._is_stateful: - raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.") - return module - - -class SequentialHook(ModelHook): - r"""A hook that can contain several hooks and iterates through them at each event.""" - - def __init__(self, *hooks): - self.hooks = hooks - - def init_hook(self, module): - for hook in self.hooks: - module = hook.init_hook(module) - return module - - def pre_forward(self, module, *args, **kwargs): - for hook in self.hooks: - args, kwargs = hook.pre_forward(module, *args, **kwargs) - return args, kwargs - - def post_forward(self, module, output): - for hook in self.hooks: - output = hook.post_forward(module, output) - return output - - def detach_hook(self, module): - for hook in self.hooks: - module = hook.detach_hook(module) - return module - - def reset_state(self, module): - for hook in self.hooks: - if hook._is_stateful: - hook.reset_state(module) - return module - - -def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False): - r""" - Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove - this behavior and restore the original `forward` method, use `remove_hook_from_module`. - - - - If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks - together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class. - - - - Args: - module (`torch.nn.Module`): - The module to attach a hook to. - hook (`ModelHook`): - The hook to attach. - append (`bool`, *optional*, defaults to `False`): - Whether the hook should be chained with an existing one (if module already contains a hook) or not. - Returns: - `torch.nn.Module`: - The same module, with the hook attached (the module is modified in place, so the result can be discarded). - """ - original_hook = hook - - if append and getattr(module, "_diffusers_hook", None) is not None: - old_hook = module._diffusers_hook - remove_hook_from_module(module) - hook = SequentialHook(old_hook, hook) - - if hasattr(module, "_diffusers_hook") and hasattr(module, "_old_forward"): - # If we already put some hook on this module, we replace it with the new one. - old_forward = module._old_forward - else: - old_forward = module.forward - module._old_forward = old_forward - - module = hook.init_hook(module) - module._diffusers_hook = hook - - if hasattr(original_hook, "new_forward"): - new_forward = original_hook.new_forward - else: - - def new_forward(module, *args, **kwargs): - args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) - output = module._old_forward(*args, **kwargs) - return module._diffusers_hook.post_forward(module, output) - - # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail. - # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409 - if "GraphModuleImpl" in str(type(module)): - module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward) - else: - module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward) - - return module - - -def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> torch.nn.Module: - """ - Removes any hook attached to a module via `add_hook_to_module`. - - Args: - module (`torch.nn.Module`): - The module to attach a hook to. - recurse (`bool`, defaults to `False`): - Whether to remove the hooks recursively - Returns: - `torch.nn.Module`: - The same module, with the hook detached (the module is modified in place, so the result can be discarded). - """ - - if hasattr(module, "_diffusers_hook"): - module._diffusers_hook.detach_hook(module) - delattr(module, "_diffusers_hook") - - if hasattr(module, "_old_forward"): - # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail. - # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409 - if "GraphModuleImpl" in str(type(module)): - module.__class__.forward = module._old_forward - else: - module.forward = module._old_forward - delattr(module, "_old_forward") - - if recurse: - for child in module.children(): - remove_hook_from_module(child, recurse) - - return module - - -def reset_stateful_hooks(module: torch.nn.Module, recurse: bool = False): - """ - Resets the state of all stateful hooks attached to a module. - - Args: - module (`torch.nn.Module`): - The module to reset the stateful hooks from. - """ - if hasattr(module, "_diffusers_hook") and ( - module._diffusers_hook._is_stateful or isinstance(module._diffusers_hook, SequentialHook) - ): - module._diffusers_hook.reset_state(module) - - if recurse: - for child in module.children(): - reset_stateful_hooks(child, recurse) diff --git a/src/diffusers/models/layerwise_upcasting_utils.py b/src/diffusers/models/layerwise_upcasting_utils.py deleted file mode 100644 index 20cd7c65f72f..000000000000 --- a/src/diffusers/models/layerwise_upcasting_utils.py +++ /dev/null @@ -1,233 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -from enum import Enum -from typing import Any, List, Type - -import torch - -from ..utils import get_logger -from .attention import FeedForward, LuminaFeedForward -from .embeddings import ( - AttentionPooling, - CogVideoXPatchEmbed, - CogView3PlusPatchEmbed, - GLIGENTextBoundingboxProjection, - HunyuanDiTAttentionPool, - LuminaPatchEmbed, - MochiAttentionPool, - PixArtAlphaTextProjection, - TimestepEmbedding, -) -from .hooks import ModelHook, add_hook_to_module - - -logger = get_logger(__name__) # pylint: disable=invalid-name - - -class LayerwiseUpcastingHook(ModelHook): - r""" - A hook that cast the input tensors and torch.nn.Module to a pre-specified dtype before the forward pass and cast - the module back to the original dtype after the forward pass. This is useful when a model is loaded/stored in a - lower precision dtype but performs computation in a higher precision dtype. This process may lead to quality loss - in the output, but can significantly reduce the memory footprint. - """ - - _is_stateful = False - - def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> None: - self.storage_dtype = storage_dtype - self.compute_dtype = compute_dtype - - def init_hook(self, module: torch.nn.Module): - module.to(dtype=self.storage_dtype) - return module - - def pre_forward(self, module: torch.nn.Module, *args, **kwargs): - module.to(dtype=self.compute_dtype) - # How do we account for LongTensor, BoolTensor, etc.? - # args = tuple(_align_maybe_tensor_dtype(arg, self.compute_dtype) for arg in args) - # kwargs = {k: _align_maybe_tensor_dtype(v, self.compute_dtype) for k, v in kwargs.items()} - return args, kwargs - - def post_forward(self, module: torch.nn.Module, output): - module.to(dtype=self.storage_dtype) - return output - - -class LayerwiseUpcastingGranularity(str, Enum): - r""" - An enumeration class that defines the granularity of the layerwise upcasting process. - - Granularity can be one of the following: - - `DIFFUSERS_LAYER`: - Applies layerwise upcasting to the lower-level diffusers layers of the model. This method is applied to - only those layers that are a group of linear layers, while excluding precision-critical layers like - modulation and normalization layers. - - `PYTORCH_LAYER`: - Applies layerwise upcasting to lower-level PyTorch primitive layers of the model. This is the most granular - level of layerwise upcasting. The memory footprint for inference and training is greatly reduced, while - also ensuring important operations like normalization with learned parameters remain unaffected from the - downcasting/upcasting process, by default. As not all parameters are casted to lower precision, the memory - footprint for storing the model may be slightly higher than the alternatives. This method causes the - highest number of casting operations, which may contribute to a slight increase in the overall computation - time. - - Note: try and ensure that precision-critical layers like modulation and normalization layers are not casted to - lower precision, as this may lead to significant quality loss. - """ - - DIFFUSERS_LAYER = "diffusers_layer" - PYTORCH_LAYER = "pytorch_layer" - - -# fmt: off -_SUPPORTED_DIFFUSERS_LAYERS = [ - AttentionPooling, MochiAttentionPool, HunyuanDiTAttentionPool, - CogVideoXPatchEmbed, CogView3PlusPatchEmbed, LuminaPatchEmbed, - TimestepEmbedding, GLIGENTextBoundingboxProjection, PixArtAlphaTextProjection, - FeedForward, LuminaFeedForward, -] - -_SUPPORTED_PYTORCH_LAYERS = [ - torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, - torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, - torch.nn.Linear, -] - -_DEFAULT_SKIP_MODULES_PATTERN = ["pos_embed", "patch_embed", "norm"] -# fmt: on - - -def apply_layerwise_upcasting( - module: torch.nn.Module, - storage_dtype: torch.dtype, - compute_dtype: torch.dtype, - granularity: LayerwiseUpcastingGranularity = LayerwiseUpcastingGranularity.PYTORCH_LAYER, - skip_modules_pattern: List[str] = _DEFAULT_SKIP_MODULES_PATTERN, - skip_modules_classes: List[Type[torch.nn.Module]] = [], -) -> torch.nn.Module: - r""" - Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin but it can be any - nn.Module using diffusers layers or pytorch primitives. - - Args: - module (`torch.nn.Module`): - The module to attach the hook to. - storage_dtype (`torch.dtype`): - The dtype to cast the module to before the forward pass. - compute_dtype (`torch.dtype`): - The dtype to cast the module to during the forward pass. - granularity (`LayerwiseUpcastingGranularity`, *optional*, defaults to `LayerwiseUpcastingGranularity.PYTORCH_LAYER`): - The granularity of the layerwise upcasting process. - skip_modules_pattern (`List[str]`, defaults to `["pos_embed", "patch_embed", "norm"]`): - A list of patterns to match the names of the modules to skip during the layerwise upcasting process. - skip_modules_classes (`List[Type[torch.nn.Module]]`, defaults to `[]`): - A list of module classes to skip during the layerwise upcasting process. - """ - if granularity == LayerwiseUpcastingGranularity.DIFFUSERS_LAYER: - return _apply_layerwise_upcasting_diffusers_layer( - module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes - ) - if granularity == LayerwiseUpcastingGranularity.PYTORCH_LAYER: - return _apply_layerwise_upcasting_pytorch_layer( - module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes - ) - - -def apply_layerwise_upcasting_hook( - module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype -) -> torch.nn.Module: - r""" - Applies a `LayerwiseUpcastingHook` to a given module. - - Args: - module (`torch.nn.Module`): - The module to attach the hook to. - storage_dtype (`torch.dtype`): - The dtype to cast the module to before the forward pass. - compute_dtype (`torch.dtype`): - The dtype to cast the module to during the forward pass. - - Returns: - `torch.nn.Module`: - The same module, with the hook attached (the module is modified in place, so the result can be discarded). - """ - hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype) - return add_hook_to_module(module, hook, append=True) - - -def _apply_layerwise_upcasting_diffusers_layer( - module: torch.nn.Module, - storage_dtype: torch.dtype, - compute_dtype: torch.dtype, - skip_modules_pattern: List[str] = _DEFAULT_SKIP_MODULES_PATTERN, - skip_modules_classes: List[Type[torch.nn.Module]] = [], -) -> torch.nn.Module: - for name, submodule in module.named_modules(): - if ( - any(re.search(pattern, name) for pattern in skip_modules_pattern) - or any(isinstance(submodule, module_class) for module_class in skip_modules_classes) - or not isinstance(submodule, tuple(_SUPPORTED_DIFFUSERS_LAYERS)) - ): - logger.debug(f'Skipping layerwise upcasting for layer "{name}"') - continue - logger.debug(f'Applying layerwise upcasting to layer "{name}"') - apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype) - return module - - -def _apply_layerwise_upcasting_pytorch_layer( - module: torch.nn.Module, - storage_dtype: torch.dtype, - compute_dtype: torch.dtype, - skip_modules_pattern: List[str] = _DEFAULT_SKIP_MODULES_PATTERN, - skip_modules_classes: List[Type[torch.nn.Module]] = [], -) -> torch.nn.Module: - for name, submodule in module.named_modules(): - if ( - any(re.search(pattern, name) for pattern in skip_modules_pattern) - or any(isinstance(submodule, module_class) for module_class in skip_modules_classes) - or not isinstance(submodule, tuple(_SUPPORTED_PYTORCH_LAYERS)) - ): - logger.debug(f'Skipping layerwise upcasting for layer "{name}"') - continue - logger.debug(f'Applying layerwise upcasting to layer "{name}"') - apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype) - return module - - -def _align_maybe_tensor_dtype(input: Any, dtype: torch.dtype) -> Any: - r""" - Aligns the dtype of a tensor or a list of tensors to a given dtype. - - Args: - input (`Any`): - The input tensor, list of tensors, or dictionary of tensors to align. If the input is neither of these - types, it will be returned as is. - dtype (`torch.dtype`): - The dtype to align the tensor(s) to. - - Returns: - `Any`: - The tensor or list of tensors aligned to the given dtype. - """ - if isinstance(input, torch.Tensor): - return input.to(dtype=dtype) - if isinstance(input, (list, tuple)): - return [_align_maybe_tensor_dtype(t, dtype) for t in input] - if isinstance(input, dict): - return {k: _align_maybe_tensor_dtype(v, dtype) for k, v in input.items()} - return input diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 870f3b73a15c..75570ccc3c29 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -23,7 +23,7 @@ from collections import OrderedDict from functools import partial, wraps from pathlib import Path -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Type, Union import safetensors import torch @@ -32,6 +32,7 @@ from torch import Tensor, nn from .. import __version__ +from ..hooks import apply_layerwise_upcasting from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer from ..quantizers.quantization_config import QuantizationMethod from ..utils import ( @@ -48,6 +49,7 @@ is_accelerate_available, is_bitsandbytes_available, is_bitsandbytes_version, + is_peft_available, is_torch_version, logging, ) @@ -56,7 +58,6 @@ load_or_create_model_card, populate_model_card, ) -from .layerwise_upcasting_utils import LayerwiseUpcastingGranularity, apply_layerwise_upcasting from .model_loading_utils import ( _determine_device_map, _fetch_index_file, @@ -320,8 +321,8 @@ def enable_layerwise_upcasting( self, storage_dtype: torch.dtype = torch.float8_e4m3fn, compute_dtype: Optional[torch.dtype] = None, - granularity: LayerwiseUpcastingGranularity = LayerwiseUpcastingGranularity.PYTORCH_LAYER, skip_modules_pattern: Optional[List[str]] = None, + skip_modules_classes: Optional[List[Type[torch.nn.Module]]] = None, ) -> None: r""" Activates layerwise upcasting for the current model. @@ -373,11 +374,22 @@ def enable_layerwise_upcasting( skip_modules_pattern.extend(self._always_upcast_modules) skip_modules_pattern = list(set(skip_modules_pattern)) + if skip_modules_classes is None: + skip_modules_classes = [] + if is_peft_available(): + # By default, we want to skip all peft layers because they have a very low memory footprint. + # If users want to apply layerwise upcasting on peft layers as well, they can utilize the + # `~diffusers.hooks.layerwise_upcasting.apply_layerwise_upcasting` function which provides + # them with more flexibility and control. + from peft.tuners.tuners_utils import BaseTunerLayer + + skip_modules_classes.append(BaseTunerLayer) + if compute_dtype is None: - logger.info("`compute_dtype` not provided when enabling layerwise upcasting. Using `storage_dtype`.") + logger.info("`compute_dtype` not provided when enabling layerwise upcasting. Using dtype of the model.") compute_dtype = self.dtype - apply_layerwise_upcasting(self, storage_dtype, compute_dtype, granularity, skip_modules_pattern) + apply_layerwise_upcasting(self, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes) def save_pretrained( self, diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index c8e533a36b34..4b6ac10385cf 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -452,21 +452,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class LayerwiseUpcastingGranularity(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - class LTXVideoTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] @@ -872,14 +857,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -def apply_layerwise_upcasting(*args, **kwargs): - requires_backends(apply_layerwise_upcasting, ["torch"]) - - -def apply_layerwise_upcasting_hook(*args, **kwargs): - requires_backends(apply_layerwise_upcasting_hook, ["torch"]) - - def get_constant_schedule(*args, **kwargs): requires_backends(get_constant_schedule, ["torch"]) From 3d84b9e25278c72f3dfb460973e6c56d76904f27 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 12 Jan 2025 09:06:04 +0100 Subject: [PATCH 13/45] non_blocking --- src/diffusers/hooks/layerwise_upcasting.py | 20 +++++++++++++------- src/diffusers/models/modeling_utils.py | 14 ++++++++++---- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/src/diffusers/hooks/layerwise_upcasting.py b/src/diffusers/hooks/layerwise_upcasting.py index 61dfd3fe6814..11f0668b1b88 100644 --- a/src/diffusers/hooks/layerwise_upcasting.py +++ b/src/diffusers/hooks/layerwise_upcasting.py @@ -44,20 +44,21 @@ class LayerwiseUpcastingHook(ModelHook): _is_stateful = False - def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> None: + def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None: self.storage_dtype = storage_dtype self.compute_dtype = compute_dtype + self.non_blocking = non_blocking def initialize_hook(self, module: torch.nn.Module): - module.to(dtype=self.storage_dtype) + module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking) return module def pre_forward(self, module: torch.nn.Module, *args, **kwargs): - module.to(dtype=self.compute_dtype) + module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking) return args, kwargs def post_forward(self, module: torch.nn.Module, output): - module.to(dtype=self.storage_dtype) + module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking) return output @@ -67,6 +68,7 @@ def apply_layerwise_upcasting( compute_dtype: torch.dtype, skip_modules_pattern: List[str] = _DEFAULT_SKIP_MODULES_PATTERN, skip_modules_classes: List[Type[torch.nn.Module]] = [], + non_blocking: bool = False, ) -> torch.nn.Module: r""" Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin but it can be any @@ -84,6 +86,8 @@ def apply_layerwise_upcasting( A list of patterns to match the names of the modules to skip during the layerwise upcasting process. skip_modules_classes (`List[Type[torch.nn.Module]]`, defaults to `[]`): A list of module classes to skip during the layerwise upcasting process. + non_blocking (`bool`, defaults to `False`): + If `True`, the weight casting operations are non-blocking. """ for name, submodule in module.named_modules(): if ( @@ -95,12 +99,12 @@ def apply_layerwise_upcasting( logger.debug(f'Skipping layerwise upcasting for layer "{name}"') continue logger.debug(f'Applying layerwise upcasting to layer "{name}"') - apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype) + apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype, non_blocking) return module def apply_layerwise_upcasting_hook( - module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype + module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool ) -> torch.nn.Module: r""" Applies a `LayerwiseUpcastingHook` to a given module. @@ -112,11 +116,13 @@ def apply_layerwise_upcasting_hook( The dtype to cast the module to before the forward pass. compute_dtype (`torch.dtype`): The dtype to cast the module to during the forward pass. + non_blocking (`bool`): + If `True`, the weight casting operations are non-blocking. Returns: `torch.nn.Module`: The same module, with the hook attached (the module is modified in place). """ registry = HookRegistry.check_if_exists_or_initialize(module) - hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype) + hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype, non_blocking) registry.register_hook(hook, "layerwise_upcasting") diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 9252e987a15c..e7fbd5bae9d2 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -323,6 +323,7 @@ def enable_layerwise_upcasting( compute_dtype: Optional[torch.dtype] = None, skip_modules_pattern: Optional[List[str]] = None, skip_modules_classes: Optional[List[Type[torch.nn.Module]]] = None, + non_blocking: bool = False, ) -> None: r""" Activates layerwise upcasting for the current model. @@ -361,9 +362,12 @@ def enable_layerwise_upcasting( The dtype to which the model should be cast for storage. compute_dtype (`torch.dtype`): The dtype to which the model weights should be cast during the forward pass. - granularity (`LayerwiseUpcastingGranularity`, defaults to "pytorch_layer"): - The granularity of the layerwise upcasting process. Read the documentation of - [`~LayerwiseUpcastingGranularity`] for more information. + skip_modules_pattern (`List[str]`, *optional*): + A list of patterns to match the names of the modules to skip during the layerwise upcasting process. + skip_modules_classes (`List[Type[torch.nn.Module]]`, *optional*): + A list of module classes to skip during the layerwise upcasting process. + non_blocking (`bool`, *optional*, defaults to `False`): + If `True`, the weight casting operations are non-blocking. """ if skip_modules_pattern is None: @@ -389,7 +393,9 @@ def enable_layerwise_upcasting( logger.info("`compute_dtype` not provided when enabling layerwise upcasting. Using dtype of the model.") compute_dtype = self.dtype - apply_layerwise_upcasting(self, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes) + apply_layerwise_upcasting( + self, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking + ) def save_pretrained( self, From 937264736859ed9748148e98ed5da02ebb6415bf Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 12 Jan 2025 20:08:45 +0100 Subject: [PATCH 14/45] improvements --- src/diffusers/hooks/layerwise_upcasting.py | 64 +++++++++++++--------- src/diffusers/models/modeling_utils.py | 9 ++- 2 files changed, 42 insertions(+), 31 deletions(-) diff --git a/src/diffusers/hooks/layerwise_upcasting.py b/src/diffusers/hooks/layerwise_upcasting.py index 11f0668b1b88..657d3b05e437 100644 --- a/src/diffusers/hooks/layerwise_upcasting.py +++ b/src/diffusers/hooks/layerwise_upcasting.py @@ -13,7 +13,7 @@ # limitations under the License. import re -from typing import List, Type +from typing import Optional, Tuple, Type import torch @@ -25,13 +25,13 @@ # fmt: off -_SUPPORTED_PYTORCH_LAYERS = [ +_SUPPORTED_PYTORCH_LAYERS = ( torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, torch.nn.Linear, -] +) -_DEFAULT_SKIP_MODULES_PATTERN = ["pos_embed", "patch_embed", "norm"] +_DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm") # fmt: on @@ -66,10 +66,11 @@ def apply_layerwise_upcasting( module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, - skip_modules_pattern: List[str] = _DEFAULT_SKIP_MODULES_PATTERN, - skip_modules_classes: List[Type[torch.nn.Module]] = [], + skip_modules_pattern: Optional[Tuple[str]] = _DEFAULT_SKIP_MODULES_PATTERN, + skip_modules_classes: Optional[Tuple[Type[torch.nn.Module]]] = [], non_blocking: bool = False, -) -> torch.nn.Module: + _prefix: str = "", +) -> None: r""" Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin but it can be any nn.Module using diffusers layers or pytorch primitives. @@ -82,30 +83,45 @@ def apply_layerwise_upcasting( The dtype to cast the module to before/after the forward pass for storage. compute_dtype (`torch.dtype`): The dtype to cast the module to during the forward pass for computation. - skip_modules_pattern (`List[str]`, defaults to `["pos_embed", "patch_embed", "norm"]`): + skip_modules_pattern (`Tuple[str]`, defaults to `["pos_embed", "patch_embed", "norm"]`): A list of patterns to match the names of the modules to skip during the layerwise upcasting process. - skip_modules_classes (`List[Type[torch.nn.Module]]`, defaults to `[]`): + skip_modules_classes (`Tuple[Type[torch.nn.Module]]`, defaults to `[]`): A list of module classes to skip during the layerwise upcasting process. non_blocking (`bool`, defaults to `False`): If `True`, the weight casting operations are non-blocking. """ - for name, submodule in module.named_modules(): - if ( - any(re.search(pattern, name) for pattern in skip_modules_pattern) - or any(isinstance(submodule, module_class) for module_class in skip_modules_classes) - or not isinstance(submodule, tuple(_SUPPORTED_PYTORCH_LAYERS)) - or len(list(submodule.children())) > 0 - ): - logger.debug(f'Skipping layerwise upcasting for layer "{name}"') - continue - logger.debug(f'Applying layerwise upcasting to layer "{name}"') - apply_layerwise_upcasting_hook(submodule, storage_dtype, compute_dtype, non_blocking) - return module + if skip_modules_classes is None and skip_modules_pattern is None: + apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking) + return + + should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or ( + skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern) + ) + if should_skip: + logger.debug(f'Skipping layerwise upcasting for layer "{_prefix}"') + return + + if isinstance(module, _SUPPORTED_PYTORCH_LAYERS): + logger.debug(f'Applying layerwise upcasting to layer "{_prefix}"') + apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking) + return + + for name, submodule in module.named_children(): + layer_name = f"{_prefix}.{name}" if _prefix else name + apply_layerwise_upcasting( + submodule, + storage_dtype, + compute_dtype, + skip_modules_pattern, + skip_modules_classes, + non_blocking, + _prefix=layer_name, + ) def apply_layerwise_upcasting_hook( module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool -) -> torch.nn.Module: +) -> None: r""" Applies a `LayerwiseUpcastingHook` to a given module. @@ -118,10 +134,6 @@ def apply_layerwise_upcasting_hook( The dtype to cast the module to during the forward pass. non_blocking (`bool`): If `True`, the weight casting operations are non-blocking. - - Returns: - `torch.nn.Module`: - The same module, with the hook attached (the module is modified in place). """ registry = HookRegistry.check_if_exists_or_initialize(module) hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype, non_blocking) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index e7fbd5bae9d2..4591f742cb50 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -376,18 +376,17 @@ def enable_layerwise_upcasting( skip_modules_pattern.extend(self._keep_in_fp32_modules) if self._always_upcast_modules is not None: skip_modules_pattern.extend(self._always_upcast_modules) - skip_modules_pattern = list(set(skip_modules_pattern)) + skip_modules_pattern = tuple(set(skip_modules_pattern)) if skip_modules_classes is None: - skip_modules_classes = [] + skip_modules_classes = () if is_peft_available(): # By default, we want to skip all peft layers because they have a very low memory footprint. # If users want to apply layerwise upcasting on peft layers as well, they can utilize the # `~diffusers.hooks.layerwise_upcasting.apply_layerwise_upcasting` function which provides # them with more flexibility and control. - from peft.tuners.tuners_utils import BaseTunerLayer - - skip_modules_classes.append(BaseTunerLayer) + if "lora" not in skip_modules_pattern: + skip_modules_pattern += ("lora",) if compute_dtype is None: logger.info("`compute_dtype` not provided when enabling layerwise upcasting. Using dtype of the model.") From e586ef38273cbb2d651b9b2372290a69e83ce2d4 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 13 Jan 2025 05:42:06 +0100 Subject: [PATCH 15/45] update --- src/diffusers/hooks/hooks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index 9d61e294742f..398b45d96ed2 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -13,7 +13,7 @@ # limitations under the License. import functools -from typing import Any, Dict, Tuple +from typing import Any, Dict, Optional, Tuple import torch @@ -137,9 +137,9 @@ def new_forward(module, *args, **kwargs): self.hooks[name] = hook self._hook_order.append(name) - def get_hook(self, name: str) -> ModelHook: + def get_hook(self, name: str) -> Optional[ModelHook]: if name not in self.hooks.keys(): - raise ValueError(f"Hook with name {name} not found.") + raise None return self.hooks[name] def remove_hook(self, name: str) -> None: From cfe63182305e5e373318203a7cb5d632de7df251 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 14 Jan 2025 05:36:40 +0100 Subject: [PATCH 16/45] norm.* -> norm --- src/diffusers/models/modeling_utils.py | 2 +- src/diffusers/models/transformers/auraflow_transformer_2d.py | 2 +- src/diffusers/models/transformers/cogvideox_transformer_3d.py | 2 +- src/diffusers/models/transformers/dit_transformer_2d.py | 2 +- src/diffusers/models/transformers/hunyuan_transformer_2d.py | 2 +- src/diffusers/models/transformers/latte_transformer_3d.py | 2 +- src/diffusers/models/transformers/lumina_nextdit2d.py | 2 +- src/diffusers/models/transformers/pixart_transformer_2d.py | 2 +- src/diffusers/models/transformers/sana_transformer.py | 2 +- src/diffusers/models/transformers/transformer_2d.py | 2 +- src/diffusers/models/transformers/transformer_allegro.py | 2 +- src/diffusers/models/transformers/transformer_cogview3plus.py | 2 +- src/diffusers/models/transformers/transformer_flux.py | 2 +- src/diffusers/models/transformers/transformer_hunyuan_video.py | 2 +- src/diffusers/models/transformers/transformer_ltx.py | 2 +- src/diffusers/models/transformers/transformer_mochi.py | 2 +- src/diffusers/models/transformers/transformer_sd3.py | 2 +- src/diffusers/models/transformers/transformer_temporal.py | 2 +- src/diffusers/models/unets/unet_1d.py | 2 +- src/diffusers/models/unets/unet_2d.py | 2 +- src/diffusers/models/unets/unet_2d_condition.py | 2 +- src/diffusers/models/unets/unet_3d_condition.py | 2 +- src/diffusers/models/unets/unet_motion_model.py | 2 +- 23 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 4591f742cb50..d41eb725a4df 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -353,7 +353,7 @@ def enable_layerwise_upcasting( >>> # Or, enable layerwise upcasting with custom arguments via the `apply_layerwise_upcasting` function >>> apply_layerwise_upcasting( - ... transformer, torch.float8_e4m3fn, torch.bfloat16, skip_modules_pattern=["patch_embed", "norm.*"] + ... transformer, torch.float8_e4m3fn, torch.bfloat16, skip_modules_pattern=["patch_embed", "norm"] ... ) ``` diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index 7f745e67fec6..834d141b6889 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -276,7 +276,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin """ _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"] - _always_upcast_modules = ["pos_embed", "norm.*"] + _always_upcast_modules = ["pos_embed", "norm"] _supports_gradient_checkpointing = True @register_to_config diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 52571ee4905c..368c7f98afc4 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -212,7 +212,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): Scaling factor to apply in 3D positional embeddings across temporal dimensions. """ - _always_upcast_modules = ["patch_embed", "norm.*"] + _always_upcast_modules = ["patch_embed", "norm"] _supports_gradient_checkpointing = True _no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"] diff --git a/src/diffusers/models/transformers/dit_transformer_2d.py b/src/diffusers/models/transformers/dit_transformer_2d.py index b4dbe43bc88a..6484d59d9464 100644 --- a/src/diffusers/models/transformers/dit_transformer_2d.py +++ b/src/diffusers/models/transformers/dit_transformer_2d.py @@ -64,7 +64,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin): A small constant added to the denominator in normalization layers to prevent division by zero. """ - _always_upcast_modules = ["pos_embed", "norm.*"] + _always_upcast_modules = ["pos_embed", "norm"] _supports_gradient_checkpointing = True @register_to_config diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py index 3fe0c870c624..dde728e12b20 100644 --- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py +++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py @@ -244,7 +244,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin): Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2 """ - _always_upcast_modules = ["pos_embed", "norm.*", "pooler"] + _always_upcast_modules = ["pos_embed", "norm", "pooler"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index 7d26843ba13d..4f534c02704f 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -65,7 +65,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin): The number of frames in the video-like data. """ - _always_upcast_modules = ["pos_embed", "norm.*"] + _always_upcast_modules = ["pos_embed", "norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/lumina_nextdit2d.py b/src/diffusers/models/transformers/lumina_nextdit2d.py index 55797511dfaf..56f8768c489d 100644 --- a/src/diffusers/models/transformers/lumina_nextdit2d.py +++ b/src/diffusers/models/transformers/lumina_nextdit2d.py @@ -221,7 +221,7 @@ class LuminaNextDiT2DModel(ModelMixin, ConfigMixin): overall scale of the model's operations. """ - _always_upcast_modules = ["patch_embedder", "norm.*", "ffn_norm.*"] + _always_upcast_modules = ["patch_embedder", "norm", "ffn_norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py index 395eba1f4630..61b1ac730e56 100644 --- a/src/diffusers/models/transformers/pixart_transformer_2d.py +++ b/src/diffusers/models/transformers/pixart_transformer_2d.py @@ -79,7 +79,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"] - _always_upcast_modules = ["pos_embed", "norm.*", "adaln_single"] + _always_upcast_modules = ["pos_embed", "norm", "adaln_single"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index bcd53a865a58..a2201a105958 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -222,7 +222,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): _supports_gradient_checkpointing = True _no_split_modules = ["SanaTransformerBlock", "PatchEmbed"] - _always_upcast_modules = ["patch_embed", "norm.*"] + _always_upcast_modules = ["patch_embed", "norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py index 94db38f46edd..41134fc14907 100644 --- a/src/diffusers/models/transformers/transformer_2d.py +++ b/src/diffusers/models/transformers/transformer_2d.py @@ -66,7 +66,7 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin): _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock"] - _always_upcast_modules = ["latent_image_embedding", "norm.*"] + _always_upcast_modules = ["latent_image_embedding", "norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index e4388e31a696..a62a23ad5b58 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -222,7 +222,7 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = True - _always_upcast_modules = ["pos_embed", "norm.*", "adaln_single"] + _always_upcast_modules = ["pos_embed", "norm", "adaln_single"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py index b4b720725ed5..417dbdf06439 100644 --- a/src/diffusers/models/transformers/transformer_cogview3plus.py +++ b/src/diffusers/models/transformers/transformer_cogview3plus.py @@ -166,7 +166,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = True - _always_upcast_modules = ["patch_embed", "norm.*"] + _always_upcast_modules = ["patch_embed", "norm"] _no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"] @register_to_config diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index b9fef932ab00..81fa52c36445 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -262,7 +262,7 @@ class FluxTransformer2DModel( _supports_gradient_checkpointing = True _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] - _always_upcast_modules = ["pos_embed", "norm.*"] + _always_upcast_modules = ["pos_embed", "norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 7abed05269f0..71bd874a5b06 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -542,7 +542,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, """ _supports_gradient_checkpointing = True - _always_upcast_modules = ["x_embedder", "context_embedder", "norm.*"] + _always_upcast_modules = ["x_embedder", "context_embedder", "norm"] _no_split_modules = [ "HunyuanVideoTransformerBlock", "HunyuanVideoSingleTransformerBlock", diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index c65215bcad59..07368df92822 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -295,7 +295,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin """ _supports_gradient_checkpointing = True - _always_upcast_modules = ["norm.*"] + _always_upcast_modules = ["norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 1f2d8dcc4828..5a3d82f31813 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -336,7 +336,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri _supports_gradient_checkpointing = True _no_split_modules = ["MochiTransformerBlock"] - _always_upcast_modules = ["patch_embed", "norm.*"] + _always_upcast_modules = ["patch_embed", "norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index ff2f62162ae2..92344061f920 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -127,7 +127,7 @@ class SD3Transformer2DModel( """ _supports_gradient_checkpointing = True - _always_upcast_modules = ["pos_embed", "norm.*"] + _always_upcast_modules = ["pos_embed", "norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_temporal.py b/src/diffusers/models/transformers/transformer_temporal.py index 1e9234afba9b..174e6a52ef73 100644 --- a/src/diffusers/models/transformers/transformer_temporal.py +++ b/src/diffusers/models/transformers/transformer_temporal.py @@ -67,7 +67,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin): The maximum length of the sequence over which to apply positional embeddings. """ - _always_upcast_modules = ["norm.*"] + _always_upcast_modules = ["norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/unets/unet_1d.py b/src/diffusers/models/unets/unet_1d.py index b7f2135a58ad..18ac1d63e0d0 100644 --- a/src/diffusers/models/unets/unet_1d.py +++ b/src/diffusers/models/unets/unet_1d.py @@ -71,7 +71,7 @@ class UNet1DModel(ModelMixin, ConfigMixin): Experimental feature for using a UNet without upsampling. """ - _always_upcast_modules = ["norm.*"] + _always_upcast_modules = ["norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/unets/unet_2d.py b/src/diffusers/models/unets/unet_2d.py index 3517e64a87bc..a673b9265251 100644 --- a/src/diffusers/models/unets/unet_2d.py +++ b/src/diffusers/models/unets/unet_2d.py @@ -90,7 +90,7 @@ class UNet2DModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = True - _always_upcast_modules = ["norm.*"] + _always_upcast_modules = ["norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 991fd21ab646..18e94eaccf84 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -166,7 +166,7 @@ class conditioning with `class_embed_type` equal to `None`. _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"] - _always_upcast_modules = ["norm.*"] + _always_upcast_modules = ["norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py index 460a65af7512..b1bf24cfb849 100644 --- a/src/diffusers/models/unets/unet_3d_condition.py +++ b/src/diffusers/models/unets/unet_3d_condition.py @@ -97,7 +97,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) """ _supports_gradient_checkpointing = False - _always_upcast_modules = ["norm.*", "time_embedding"] + _always_upcast_modules = ["norm", "time_embedding"] @register_to_config def __init__( diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 6f042e03f718..d4f6b57b5519 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -1301,7 +1301,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft """ _supports_gradient_checkpointing = True - _always_upcast_modules = ["norm.*"] + _always_upcast_modules = ["norm"] @register_to_config def __init__( From 762741513ef7d769a746e18d564e034d4208fd46 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 15 Jan 2025 15:04:55 +0100 Subject: [PATCH 17/45] apply suggestions from review --- src/diffusers/models/modeling_utils.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 0399b4e2e749..666bb8c0f77d 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -370,23 +370,28 @@ def enable_layerwise_upcasting( If `True`, the weight casting operations are non-blocking. """ + user_provided_patterns = True if skip_modules_pattern is None: skip_modules_pattern = [] + user_provided_patterns = False if self._keep_in_fp32_modules is not None: skip_modules_pattern.extend(self._keep_in_fp32_modules) if self._always_upcast_modules is not None: skip_modules_pattern.extend(self._always_upcast_modules) skip_modules_pattern = tuple(set(skip_modules_pattern)) - if skip_modules_classes is None: - skip_modules_classes = () - if is_peft_available(): + if is_peft_available() and not user_provided_patterns: # By default, we want to skip all peft layers because they have a very low memory footprint. # If users want to apply layerwise upcasting on peft layers as well, they can utilize the # `~diffusers.hooks.layerwise_upcasting.apply_layerwise_upcasting` function which provides # them with more flexibility and control. - if "lora" not in skip_modules_pattern: - skip_modules_pattern += ("lora",) + + from peft.tuners.loha.layer import LoHaLayer + from peft.tuners.lokr.layer import LoKrLayer + from peft.tuners.lora.layer import LoraLayer + + for layer in (LoHaLayer, LoKrLayer, LoraLayer): + skip_modules_pattern += tuple(layer.adapter_layer_names) if compute_dtype is None: logger.info("`compute_dtype` not provided when enabling layerwise upcasting. Using dtype of the model.") From b9e12171431fa989412adb706e83171e9a189c18 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 15 Jan 2025 15:10:14 +0100 Subject: [PATCH 18/45] add example --- src/diffusers/hooks/layerwise_upcasting.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/src/diffusers/hooks/layerwise_upcasting.py b/src/diffusers/hooks/layerwise_upcasting.py index 657d3b05e437..2feef9d67b8a 100644 --- a/src/diffusers/hooks/layerwise_upcasting.py +++ b/src/diffusers/hooks/layerwise_upcasting.py @@ -67,7 +67,7 @@ def apply_layerwise_upcasting( storage_dtype: torch.dtype, compute_dtype: torch.dtype, skip_modules_pattern: Optional[Tuple[str]] = _DEFAULT_SKIP_MODULES_PATTERN, - skip_modules_classes: Optional[Tuple[Type[torch.nn.Module]]] = [], + skip_modules_classes: Optional[Tuple[Type[torch.nn.Module]]] = None, non_blocking: bool = False, _prefix: str = "", ) -> None: @@ -75,6 +75,24 @@ def apply_layerwise_upcasting( Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin but it can be any nn.Module using diffusers layers or pytorch primitives. + Example: + + ```python + >>> import torch + >>> from diffusers import CogVideoXPipeline, apply_layerwise_upcasting + + >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> apply_layerwise_upcasting( + ... pipe.transformer, + ... storage_dtype=torch.float8_e4m3fn, + ... compute_dtype=torch.bfloat16, + ... skip_modules_pattern=["patch_embed", "norm"], + ... non_blocking=True, + ... ) + ``` + Args: module (`torch.nn.Module`): The module whose leaf modules will be cast to a high precision dtype for computation, and to a low @@ -85,7 +103,7 @@ def apply_layerwise_upcasting( The dtype to cast the module to during the forward pass for computation. skip_modules_pattern (`Tuple[str]`, defaults to `["pos_embed", "patch_embed", "norm"]`): A list of patterns to match the names of the modules to skip during the layerwise upcasting process. - skip_modules_classes (`Tuple[Type[torch.nn.Module]]`, defaults to `[]`): + skip_modules_classes (`Tuple[Type[torch.nn.Module]]`, defaults to `None`): A list of module classes to skip during the layerwise upcasting process. non_blocking (`bool`, defaults to `False`): If `True`, the weight casting operations are non-blocking. From bde103c1d24b8db48e90a46e49ff5597cf82fe01 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 15 Jan 2025 15:13:47 +0100 Subject: [PATCH 19/45] update hook implementation to the latest changes from pyramid attention broadcast --- src/diffusers/hooks/hooks.py | 48 +++++++++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index 398b45d96ed2..bef4c65c41e1 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -63,7 +63,6 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[A The positional arguments passed to the module. kwargs (`Dict[Str, Any]`): The keyword arguments passed to the module. - Returns: `Tuple[Tuple[Any], Dict[Str, Any]]`: A tuple with the treated `args` and `kwargs`. @@ -79,7 +78,6 @@ def post_forward(self, module: torch.nn.Module, output: Any) -> Any: The module whose forward pass been executed just before this event. output (`Any`): The output of the module. - Returns: `Any`: The processed `output`. """ @@ -123,7 +121,12 @@ def register_hook(self, hook: ModelHook, name: str) -> None: self._module_ref = hook.initialize_hook(self._module_ref) if hasattr(hook, "new_forward"): - new_forward = hook.new_forward + rewritten_forward = hook.new_forward + + def new_forward(module, *args, **kwargs): + args, kwargs = hook.pre_forward(module, *args, **kwargs) + output = rewritten_forward(module, *args, **kwargs) + return hook.post_forward(module, output) else: def new_forward(module, *args, **kwargs): @@ -131,23 +134,44 @@ def new_forward(module, *args, **kwargs): output = old_forward(*args, **kwargs) return hook.post_forward(module, output) - new_forward = functools.update_wrapper(new_forward, old_forward) - self._module_ref.forward = new_forward.__get__(self._module_ref) + self._module_ref.forward = functools.update_wrapper( + functools.partial(new_forward, self._module_ref), old_forward + ) self.hooks[name] = hook self._hook_order.append(name) def get_hook(self, name: str) -> Optional[ModelHook]: if name not in self.hooks.keys(): - raise None + return None return self.hooks[name] - def remove_hook(self, name: str) -> None: - if name not in self.hooks.keys(): - raise ValueError(f"Hook with name {name} not found.") - self.hooks[name].deinitalize_hook(self._module_ref) - del self.hooks[name] - self._hook_order.remove(name) + def remove_hook(self, name: str, recurse: bool = True) -> None: + if name in self.hooks.keys(): + hook = self.hooks[name] + self._module_ref = hook.deinitalize_hook(self._module_ref) + del self.hooks[name] + self._hook_order.remove(name) + + if recurse: + for module_name, module in self._module_ref.named_modules(): + if module_name == "": + continue + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook.remove_hook(name, recurse=False) + + def reset_stateful_hooks(self, recurse: bool = True) -> None: + for hook_name in self._hook_order: + hook = self.hooks[hook_name] + if hook._is_stateful: + hook.reset_state(self._module_ref) + + if recurse: + for module_name, module in self._module_ref.named_modules(): + if module_name == "": + continue + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook.reset_stateful_hooks(recurse=False) @classmethod def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry": From 64e6c9c9d52c26d83b3514db60fa51514f9249ba Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 15 Jan 2025 15:40:44 +0100 Subject: [PATCH 20/45] deinitialize should raise an error --- src/diffusers/hooks/layerwise_upcasting.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/diffusers/hooks/layerwise_upcasting.py b/src/diffusers/hooks/layerwise_upcasting.py index 2feef9d67b8a..1c2f7237eaef 100644 --- a/src/diffusers/hooks/layerwise_upcasting.py +++ b/src/diffusers/hooks/layerwise_upcasting.py @@ -53,6 +53,14 @@ def initialize_hook(self, module: torch.nn.Module): module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking) return module + def deinitalize_hook(self, module: torch.nn.Module): + raise NotImplementedError( + "LayerwiseUpcastingHook does not support deinitalization. A model once enabled with layerwise upcasting will " + "have casted its weights to a lower precision dtype for storage. Casting this back to the original dtype " + "will lead to precision loss, which might have an impact on the model's generation quality. The model should " + "be re-initialized and loaded in the original dtype." + ) + def pre_forward(self, module: torch.nn.Module, *args, **kwargs): module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking) return args, kwargs From 7037133c7782d0babfa3daba1f23f0945f51b504 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 15 Jan 2025 22:12:34 +0100 Subject: [PATCH 21/45] update doc page --- docs/source/en/optimization/memory.md | 33 +++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/docs/source/en/optimization/memory.md b/docs/source/en/optimization/memory.md index a2150f9aa0b7..e64db46e2769 100644 --- a/docs/source/en/optimization/memory.md +++ b/docs/source/en/optimization/memory.md @@ -158,6 +158,39 @@ In order to properly offload models after they're called, it is required to run +## FP8 layerwise weight-casting + +PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes. This precision cannot be used for performing computation directly for many different tensor operations due to unimplemented kernel support. However, one can still use these dtypes for storing model weights in lower FP8 precision. For computation, the weights can be upcasted on-the-fly as and when layers are invoked in the forward pass. + +Typically, inference on most models is done with `torch.float16` or `torch.bfloat16` weight/computation precision. Applying layerwise weight-casting, by storing the weights in FP8 precision, cuts down the memory footprint of the model weights by half approximately. + +```python +import torch +from diffusers import CogVideoXPipeline +from diffusers.utils import export_to_video + +pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) +pipe.to("cuda") +pipe.transformer.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) + +prompt = ( + "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " + "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " + "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " + "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " + "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " + "atmosphere of this unique musical performance." +) +video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] +export_to_video(video, "output.mp4", fps=8) +``` + +In the above example, we enable layerwise upcasting on the transformer component of the pipeline. By default, certain layers are skipped from the FP8 weight casting because it can lead to significant degradation of generation quality. For most cases, skipping the normalization and modulation related weight parameters is a good choice (which is also the default choice). However, more control and flexibility can be obtained by directly utilizing the [`~hooks.layerwise_upcasting.apply_layerwise_upcasting`] function instead of using [`~ModelMixin.enable_layerwise_upcasting`]. + +[[autodoc]] ModelMixin.enable_layerwise_upcasting + +[[autodoc]] hooks.layerwise_upcasting.apply_layerwise_upcasting + ## Channels-last memory format The channels-last memory format is an alternative way of ordering NCHW tensors in memory to preserve dimension ordering. Channels-last tensors are ordered in such a way that the channels become the densest dimension (storing images pixel-per-pixel). Since not all operators currently support the channels-last format, it may result in worst performance but you should still try and see if it works for your model. From 390742b0819df60134b0b1d5aeb38f4c338b81d2 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 17 Jan 2025 02:10:01 +0530 Subject: [PATCH 22/45] Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/optimization/memory.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/source/en/optimization/memory.md b/docs/source/en/optimization/memory.md index e64db46e2769..153839de632d 100644 --- a/docs/source/en/optimization/memory.md +++ b/docs/source/en/optimization/memory.md @@ -160,9 +160,9 @@ In order to properly offload models after they're called, it is required to run ## FP8 layerwise weight-casting -PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes. This precision cannot be used for performing computation directly for many different tensor operations due to unimplemented kernel support. However, one can still use these dtypes for storing model weights in lower FP8 precision. For computation, the weights can be upcasted on-the-fly as and when layers are invoked in the forward pass. +PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting. -Typically, inference on most models is done with `torch.float16` or `torch.bfloat16` weight/computation precision. Applying layerwise weight-casting, by storing the weights in FP8 precision, cuts down the memory footprint of the model weights by half approximately. +Typically, inference on most models is done with `torch.float16` or `torch.bfloat16` weight/computation precision. Layerwise weight-casting cuts down the memory footprint of the model weights by approximately half. ```python import torch @@ -185,7 +185,9 @@ video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] export_to_video(video, "output.mp4", fps=8) ``` -In the above example, we enable layerwise upcasting on the transformer component of the pipeline. By default, certain layers are skipped from the FP8 weight casting because it can lead to significant degradation of generation quality. For most cases, skipping the normalization and modulation related weight parameters is a good choice (which is also the default choice). However, more control and flexibility can be obtained by directly utilizing the [`~hooks.layerwise_upcasting.apply_layerwise_upcasting`] function instead of using [`~ModelMixin.enable_layerwise_upcasting`]. +In the above example, layerwise upcasting is enabled on the transformer component of the pipeline. By default, certain layers are skipped from the FP8 weight casting because it can lead to significant degradation of generation quality. The normalization and modulation related weight parameters are also skipped by default. + +However, you gain more control and flexibility by directly utilizing the [`~hooks.layerwise_upcasting.apply_layerwise_upcasting`] function instead of [`~ModelMixin.enable_layerwise_upcasting`]. [[autodoc]] ModelMixin.enable_layerwise_upcasting From 19901e7a8cea926f14452abc7083949810cd8191 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 17 Jan 2025 07:50:54 +0100 Subject: [PATCH 23/45] update docs --- docs/source/en/api/utilities.md | 4 ++++ docs/source/en/optimization/memory.md | 4 ---- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/en/api/utilities.md b/docs/source/en/api/utilities.md index d4f4d7d7964f..a17ab1184957 100644 --- a/docs/source/en/api/utilities.md +++ b/docs/source/en/api/utilities.md @@ -41,3 +41,7 @@ Utility and helper functions for working with 🤗 Diffusers. ## randn_tensor [[autodoc]] utils.torch_utils.randn_tensor + +## apply_layerwise_upcasting + +[[autodoc]] hooks.layerwise_upcasting.apply_layerwise_upcasting diff --git a/docs/source/en/optimization/memory.md b/docs/source/en/optimization/memory.md index 153839de632d..ce481a05c722 100644 --- a/docs/source/en/optimization/memory.md +++ b/docs/source/en/optimization/memory.md @@ -189,10 +189,6 @@ In the above example, layerwise upcasting is enabled on the transformer componen However, you gain more control and flexibility by directly utilizing the [`~hooks.layerwise_upcasting.apply_layerwise_upcasting`] function instead of [`~ModelMixin.enable_layerwise_upcasting`]. -[[autodoc]] ModelMixin.enable_layerwise_upcasting - -[[autodoc]] hooks.layerwise_upcasting.apply_layerwise_upcasting - ## Channels-last memory format The channels-last memory format is an alternative way of ordering NCHW tensors in memory to preserve dimension ordering. Channels-last tensors are ordered in such a way that the channels become the densest dimension (storing images pixel-per-pixel). Since not all operators currently support the channels-last format, it may result in worst performance but you should still try and see if it works for your model. From 3ae32b471cfc2e06695a003f59401070f8bd6f18 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 17 Jan 2025 08:35:03 +0100 Subject: [PATCH 24/45] update --- docs/source/en/optimization/memory.md | 12 ++++++--- src/diffusers/hooks/layerwise_upcasting.py | 31 +++++++++++++--------- src/diffusers/models/modeling_utils.py | 27 ++++++++++++++----- 3 files changed, 47 insertions(+), 23 deletions(-) diff --git a/docs/source/en/optimization/memory.md b/docs/source/en/optimization/memory.md index ce481a05c722..50e10f0caa48 100644 --- a/docs/source/en/optimization/memory.md +++ b/docs/source/en/optimization/memory.md @@ -166,12 +166,18 @@ Typically, inference on most models is done with `torch.float16` or `torch.bfloa ```python import torch -from diffusers import CogVideoXPipeline +from diffusers import CogVideoXPipeline, CogVideoXTransformer3DModel from diffusers.utils import export_to_video -pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) +model_id = "THUDM/CogVideoX-5b" + +# Load the model in bfloat16 and enable layerwise upcasting +transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16) +transformer.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) + +# Load the pipeline +pipe = CogVideoXPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16) pipe.to("cuda") -pipe.transformer.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) prompt = ( "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " diff --git a/src/diffusers/hooks/layerwise_upcasting.py b/src/diffusers/hooks/layerwise_upcasting.py index 1c2f7237eaef..7ee920c1e002 100644 --- a/src/diffusers/hooks/layerwise_upcasting.py +++ b/src/diffusers/hooks/layerwise_upcasting.py @@ -13,7 +13,7 @@ # limitations under the License. import re -from typing import Optional, Tuple, Type +from typing import Optional, Tuple, Type, Union import torch @@ -25,13 +25,13 @@ # fmt: off -_SUPPORTED_PYTORCH_LAYERS = ( +SUPPORTED_PYTORCH_LAYERS = ( torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, torch.nn.Linear, ) -_DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm") +DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$") # fmt: on @@ -74,8 +74,8 @@ def apply_layerwise_upcasting( module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, - skip_modules_pattern: Optional[Tuple[str]] = _DEFAULT_SKIP_MODULES_PATTERN, - skip_modules_classes: Optional[Tuple[Type[torch.nn.Module]]] = None, + skip_modules_pattern: Union[str, Tuple[str, ...]] = "default", + skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None, non_blocking: bool = False, _prefix: str = "", ) -> None: @@ -87,13 +87,14 @@ def apply_layerwise_upcasting( ```python >>> import torch - >>> from diffusers import CogVideoXPipeline, apply_layerwise_upcasting + >>> from diffusers import CogVideoXTransformer3DModel - >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) - >>> pipe.to("cuda") + >>> transformer = CogVideoXTransformer3DModel.from_pretrained( + ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 + ... ) >>> apply_layerwise_upcasting( - ... pipe.transformer, + ... transformer, ... storage_dtype=torch.float8_e4m3fn, ... compute_dtype=torch.bfloat16, ... skip_modules_pattern=["patch_embed", "norm"], @@ -109,13 +110,17 @@ def apply_layerwise_upcasting( The dtype to cast the module to before/after the forward pass for storage. compute_dtype (`torch.dtype`): The dtype to cast the module to during the forward pass for computation. - skip_modules_pattern (`Tuple[str]`, defaults to `["pos_embed", "patch_embed", "norm"]`): - A list of patterns to match the names of the modules to skip during the layerwise upcasting process. - skip_modules_classes (`Tuple[Type[torch.nn.Module]]`, defaults to `None`): + skip_modules_pattern (`Tuple[str, ...]`, defaults to `"default"`): + A list of patterns to match the names of the modules to skip during the layerwise upcasting process. If set + to `"default"`, the default patterns are used. + skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, defaults to `None`): A list of module classes to skip during the layerwise upcasting process. non_blocking (`bool`, defaults to `False`): If `True`, the weight casting operations are non-blocking. """ + if skip_modules_pattern == "default": + skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN + if skip_modules_classes is None and skip_modules_pattern is None: apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking) return @@ -127,7 +132,7 @@ def apply_layerwise_upcasting( logger.debug(f'Skipping layerwise upcasting for layer "{_prefix}"') return - if isinstance(module, _SUPPORTED_PYTORCH_LAYERS): + if isinstance(module, SUPPORTED_PYTORCH_LAYERS): logger.debug(f'Applying layerwise upcasting to layer "{_prefix}"') apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking) return diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 666bb8c0f77d..995560ecd0ec 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -104,6 +104,17 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype: """ Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found. """ + # 1. Check if we have attached any dtype modifying hooks (eg. layerwise upcasting) + if isinstance(parameter, nn.Module): + for name, submodule in parameter.named_modules(): + if not hasattr(submodule, "_diffusers_hook"): + continue + registry = submodule._diffusers_hook + hook = registry.get_hook("layerwise_upcasting") + if hook is not None: + return hook.compute_dtype + + # 2. If no dtype modifying hooks are attached, return the dtype of the first floating point parameter/buffer last_dtype = None for param in parameter.parameters(): last_dtype = param.dtype @@ -321,8 +332,8 @@ def enable_layerwise_upcasting( self, storage_dtype: torch.dtype = torch.float8_e4m3fn, compute_dtype: Optional[torch.dtype] = None, - skip_modules_pattern: Optional[List[str]] = None, - skip_modules_classes: Optional[List[Type[torch.nn.Module]]] = None, + skip_modules_pattern: Optional[Tuple[str, ...]] = None, + skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None, non_blocking: bool = False, ) -> None: r""" @@ -362,9 +373,9 @@ def enable_layerwise_upcasting( The dtype to which the model should be cast for storage. compute_dtype (`torch.dtype`): The dtype to which the model weights should be cast during the forward pass. - skip_modules_pattern (`List[str]`, *optional*): + skip_modules_pattern (`Tuple[str, ...]`, *optional*): A list of patterns to match the names of the modules to skip during the layerwise upcasting process. - skip_modules_classes (`List[Type[torch.nn.Module]]`, *optional*): + skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, *optional*): A list of module classes to skip during the layerwise upcasting process. non_blocking (`bool`, *optional*, defaults to `False`): If `True`, the weight casting operations are non-blocking. @@ -372,12 +383,14 @@ def enable_layerwise_upcasting( user_provided_patterns = True if skip_modules_pattern is None: - skip_modules_pattern = [] + from ..hooks.layerwise_upcasting import DEFAULT_SKIP_MODULES_PATTERN + + skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN user_provided_patterns = False if self._keep_in_fp32_modules is not None: - skip_modules_pattern.extend(self._keep_in_fp32_modules) + skip_modules_pattern += tuple(self._keep_in_fp32_modules) if self._always_upcast_modules is not None: - skip_modules_pattern.extend(self._always_upcast_modules) + skip_modules_pattern += tuple(self._always_upcast_modules) skip_modules_pattern = tuple(set(skip_modules_pattern)) if is_peft_available() and not user_provided_patterns: From bf797e7746d6bcbc59948d0a8c3151a8a8e9e3cd Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 17 Jan 2025 08:45:59 +0100 Subject: [PATCH 25/45] refactor --- src/diffusers/hooks/layerwise_upcasting.py | 28 ++++++++++++++++++---- src/diffusers/models/modeling_utils.py | 13 ++++------ 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/src/diffusers/hooks/layerwise_upcasting.py b/src/diffusers/hooks/layerwise_upcasting.py index 7ee920c1e002..3d85b56db72b 100644 --- a/src/diffusers/hooks/layerwise_upcasting.py +++ b/src/diffusers/hooks/layerwise_upcasting.py @@ -77,7 +77,6 @@ def apply_layerwise_upcasting( skip_modules_pattern: Union[str, Tuple[str, ...]] = "default", skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None, non_blocking: bool = False, - _prefix: str = "", ) -> None: r""" Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin but it can be any @@ -97,7 +96,7 @@ def apply_layerwise_upcasting( ... transformer, ... storage_dtype=torch.float8_e4m3fn, ... compute_dtype=torch.bfloat16, - ... skip_modules_pattern=["patch_embed", "norm"], + ... skip_modules_pattern=["patch_embed", "norm", "proj_out"], ... non_blocking=True, ... ) ``` @@ -112,7 +111,9 @@ def apply_layerwise_upcasting( The dtype to cast the module to during the forward pass for computation. skip_modules_pattern (`Tuple[str, ...]`, defaults to `"default"`): A list of patterns to match the names of the modules to skip during the layerwise upcasting process. If set - to `"default"`, the default patterns are used. + to `"default"`, the default patterns are used. If set to `None`, no modules are skipped. If set to `None` + alongside `skip_modules_classes` being `None`, the layerwise upcasting is applied directly to the module + instead of its internal submodules. skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, defaults to `None`): A list of module classes to skip during the layerwise upcasting process. non_blocking (`bool`, defaults to `False`): @@ -125,6 +126,25 @@ def apply_layerwise_upcasting( apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking) return + _apply_layerwise_upcasting( + module, + storage_dtype, + compute_dtype, + skip_modules_pattern, + skip_modules_classes, + non_blocking, + ) + + +def _apply_layerwise_upcasting( + module: torch.nn.Module, + storage_dtype: torch.dtype, + compute_dtype: torch.dtype, + skip_modules_pattern: Optional[Tuple[str, ...]] = None, + skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None, + non_blocking: bool = False, + _prefix: str = "", +) -> None: should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or ( skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern) ) @@ -139,7 +159,7 @@ def apply_layerwise_upcasting( for name, submodule in module.named_children(): layer_name = f"{_prefix}.{name}" if _prefix else name - apply_layerwise_upcasting( + _apply_layerwise_upcasting( submodule, storage_dtype, compute_dtype, diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 995560ecd0ec..402ad7788e5f 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -347,13 +347,13 @@ def enable_layerwise_upcasting( By default, most models in diffusers set the `_always_upcast_modules` attribute to ignore patch embedding, positional embedding and normalization layers. This is because these layers are most likely precision-critical for quality. If you wish to change this behavior, you can set the `_always_upcast_modules` attribute to `None`, - or call [`~apply_layerwise_upcasting`] with custom arguments. + or call [`~hooks.layerwise_upcasting.apply_layerwise_upcasting`] with custom arguments. Example: Using [`~models.ModelMixin.enable_layerwise_upcasting`]: ```python - >>> from diffusers import CogVideoXTransformer3DModel, apply_layerwise_upcasting + >>> from diffusers import CogVideoXTransformer3DModel >>> transformer = CogVideoXTransformer3DModel.from_pretrained( ... "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16 @@ -361,11 +361,6 @@ def enable_layerwise_upcasting( >>> # Enable layerwise upcasting via the model, which ignores certain modules by default >>> transformer.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) - - >>> # Or, enable layerwise upcasting with custom arguments via the `apply_layerwise_upcasting` function - >>> apply_layerwise_upcasting( - ... transformer, torch.float8_e4m3fn, torch.bfloat16, skip_modules_pattern=["patch_embed", "norm"] - ... ) ``` Args: @@ -374,7 +369,9 @@ def enable_layerwise_upcasting( compute_dtype (`torch.dtype`): The dtype to which the model weights should be cast during the forward pass. skip_modules_pattern (`Tuple[str, ...]`, *optional*): - A list of patterns to match the names of the modules to skip during the layerwise upcasting process. + A list of patterns to match the names of the modules to skip during the layerwise upcasting process. If + set to `None`, default skip patterns are used to ignore certain internal layers of modules and PEFT + layers. skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, *optional*): A list of module classes to skip during the layerwise upcasting process. non_blocking (`bool`, *optional*, defaults to `False`): From 5956a9e36405e05e0f144a5fd9d8e0d4cac473a5 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 17 Jan 2025 08:52:02 +0100 Subject: [PATCH 26/45] fix _always_upcast_modules for asym ae and vq_model --- src/diffusers/models/autoencoders/autoencoder_asym_kl.py | 2 +- src/diffusers/models/autoencoders/vq_model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_asym_kl.py b/src/diffusers/models/autoencoders/autoencoder_asym_kl.py index 3c16b766c23d..455828de40c3 100644 --- a/src/diffusers/models/autoencoders/autoencoder_asym_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_asym_kl.py @@ -60,7 +60,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. """ - _always_upcast_modules = ["MaskConditionDecoder"] + _always_upcast_modules = ["decoder"] @register_to_config def __init__( diff --git a/src/diffusers/models/autoencoders/vq_model.py b/src/diffusers/models/autoencoders/vq_model.py index 82e9dd8479a7..c0e2a3b64564 100644 --- a/src/diffusers/models/autoencoders/vq_model.py +++ b/src/diffusers/models/autoencoders/vq_model.py @@ -71,7 +71,7 @@ class VQModel(ModelMixin, ConfigMixin): Type of normalization layer to use. Can be one of `"group"` or `"spatial"`. """ - _always_upcast_modules = ["VectorQuantizer"] + _always_upcast_modules = ["quantize"] @register_to_config def __init__( From 93bd8eee4aed3034640082abd4730042daaaabc1 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 21 Jan 2025 10:49:29 +0100 Subject: [PATCH 27/45] fix lumina embedding forward to not depend on weight dtype --- src/diffusers/models/embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index c64b9587be77..bd3237c24c1c 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1787,7 +1787,7 @@ def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embeddi def forward(self, timestep, caption_feat, caption_mask): # timestep embedding: time_freq = self.time_proj(timestep) - time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype)) + time_embed = self.timestep_embedder(time_freq.to(dtype=caption_feat.dtype)) # caption condition embedding: caption_mask_float = caption_mask.float().unsqueeze(-1) From 77a32a7487e3b78e94d5fda3ac4f128d4701c843 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 21 Jan 2025 13:12:53 +0100 Subject: [PATCH 28/45] refactor tests --- tests/models/test_modeling_common.py | 125 ++++++++++++++------------- 1 file changed, 67 insertions(+), 58 deletions(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 882adbde79d4..a3e8bdf854e5 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -18,6 +18,7 @@ import inspect import json import os +import re import tempfile import traceback import unittest @@ -183,6 +184,16 @@ def compute_module_persistent_sizes( return module_sizes +def cast_maybe_tensor_dtype(maybe_tensor, current_dtype, target_dtype): + if torch.is_tensor(maybe_tensor): + return maybe_tensor.to(target_dtype) if maybe_tensor.dtype == current_dtype else maybe_tensor + if isinstance(maybe_tensor, dict): + return {k: cast_maybe_tensor_dtype(v, current_dtype, target_dtype) for k, v in maybe_tensor.items()} + if isinstance(maybe_tensor, list): + return [cast_maybe_tensor_dtype(v, current_dtype, target_dtype) for v in maybe_tensor] + return maybe_tensor + + class ModelUtilsTest(unittest.TestCase): def tearDown(self): super().tearDown() @@ -1334,80 +1345,78 @@ def test_variant_sharded_ckpt_right_format(self): assert all(f.split(".")[1].split("-")[0] == variant for f in shard_files) def test_layerwise_upcasting_inference(self): - torch.manual_seed(0) - config, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**config).eval() - model = model.to(torch_device) - base_slice = model(**inputs_dict)[0].flatten().detach().cpu().numpy() + from diffusers.hooks.layerwise_upcasting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS - # fp16-fp32 torch.manual_seed(0) config, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**config).eval() model = model.to(torch_device) - model.enable_layerwise_upcasting(storage_dtype=torch.float16, compute_dtype=torch.float32) - layerwise_upcast_slice_fp16 = model(**inputs_dict)[0].flatten().detach().cpu().numpy() + base_slice = model(**inputs_dict)[0].flatten().detach().cpu().numpy() - # The precision test is not very important for fast tests. In most cases, the outputs will not be the same. - # We just want to make sure that the layerwise upcasting is working as expected. - self.assertTrue(numpy_cosine_similarity_distance(base_slice, layerwise_upcast_slice_fp16) < 1.0) + def check_linear_dtype(module, storage_dtype, compute_dtype): + patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN + if getattr(module, "_always_upcast_modules", None) is not None: + patterns_to_check += tuple(module._always_upcast_modules) + for name, submodule in module.named_modules(): + if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS): + continue + dtype_to_check = storage_dtype + if any(re.search(pattern, name) for pattern in patterns_to_check): + dtype_to_check = compute_dtype + if getattr(submodule, "weight", None) is not None: + self.assertEqual(submodule.weight.dtype, dtype_to_check) + if getattr(submodule, "bias", None) is not None: + self.assertEqual(submodule.bias.dtype, dtype_to_check) + + def test_layerwise_upcasting(storage_dtype, compute_dtype): + torch.manual_seed(0) + config, inputs_dict = self.prepare_init_args_and_inputs_for_common() + inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype) + model = self.model_class(**config).eval() + model = model.to(torch_device, dtype=compute_dtype) + model.enable_layerwise_upcasting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) + check_linear_dtype(model, storage_dtype, compute_dtype) + output = model(**inputs_dict)[0].float().flatten().detach().cpu().numpy() - # fp8_e4m3-fp32 - torch.manual_seed(0) - config, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**config).eval() - model = model.to(torch_device) - model.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32) - layerwise_upcast_slice_fp8_e4m3 = model(**inputs_dict)[0].flatten().detach().cpu().numpy() + # The precision test is not very important for fast tests. In most cases, the outputs will not be the same. + # We just want to make sure that the layerwise upcasting is working as expected. + self.assertTrue(numpy_cosine_similarity_distance(base_slice, output) < 1.0) - self.assertTrue(numpy_cosine_similarity_distance(base_slice, layerwise_upcast_slice_fp8_e4m3) < 1.0) - - # fp8_e5m2-fp32 - torch.manual_seed(0) - config, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**config).eval() - model = model.to(torch_device) - model.enable_layerwise_upcasting(storage_dtype=torch.float8_e5m2, compute_dtype=torch.float32) - layerwise_upcast_slice_fp8_e5m2 = model(**inputs_dict)[0].flatten().detach().cpu().numpy() - - self.assertTrue(numpy_cosine_similarity_distance(base_slice, layerwise_upcast_slice_fp8_e5m2) < 1.0) + test_layerwise_upcasting(torch.float16, torch.float32) + test_layerwise_upcasting(torch.float8_e4m3fn, torch.float32) + test_layerwise_upcasting(torch.float8_e5m2, torch.float32) + test_layerwise_upcasting(torch.float8_e4m3fn, torch.bfloat16) @require_torch_gpu def test_layerwise_upcasting_memory(self): - # fp32 - gc.collect() - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - torch.cuda.synchronize() + def reset_memory_stats(): + gc.collect() + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() - torch.manual_seed(0) - config, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**config).eval() - model = model.to(torch_device) - model(**inputs_dict) - base_memory_footprint = model.get_memory_footprint() - base_max_memory = torch.cuda.max_memory_allocated() + def get_memory_usage(storage_dtype, compute_dtype): + torch.manual_seed(0) + config, inputs_dict = self.prepare_init_args_and_inputs_for_common() + inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype) + model = self.model_class(**config).eval() + model = model.to(torch_device, dtype=compute_dtype) + model.enable_layerwise_upcasting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) - model.to("cpu") - del model + reset_memory_stats() + model(**inputs_dict) + model_memory_footprint = model.get_memory_footprint() + peak_inference_memory_allocated = torch.cuda.max_memory_allocated() - # fp8_e4m3-fp32 - gc.collect() - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - torch.cuda.synchronize() + return model_memory_footprint, peak_inference_memory_allocated - torch.manual_seed(0) - config, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**config).eval() - model = model.to(torch_device) - model.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32) - model(**inputs_dict) - fp8_e4m3_memory_footprint = model.get_memory_footprint() - fp8_e4m3_max_memory = torch.cuda.max_memory_allocated() + fp8_e4m3_fp32_memory_footprint, fp8_e4m3_fp32_max_memory = get_memory_usage(torch.float8_e4m3fn, torch.float32) + fp8_e4m3_bf16_memory_footprint, fp8_e4m3_bf16_max_memory = get_memory_usage( + torch.float8_e4m3fn, torch.bfloat16 + ) - self.assertTrue(fp8_e4m3_memory_footprint < base_memory_footprint) - self.assertTrue(fp8_e4m3_max_memory < base_max_memory) + self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint) + self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory) @is_staging_test From 1335d7e1df1d4f3596b3f7e53bb8d288d92ab749 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 21 Jan 2025 13:52:20 +0100 Subject: [PATCH 29/45] add simple lora inference tests --- tests/lora/utils.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index a22f86ad6b89..20c77f312a22 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -2098,3 +2098,43 @@ def test_correct_lora_configs_with_different_ranks(self): lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3)) self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3)) + + def test_layerwise_upcasting_inference_denoiser(self): + def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32): + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0]) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device, dtype=compute_dtype) + pipe.set_progress_bar_config(disable=None) + + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + ) + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config) + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + + if storage_dtype is not None: + denoiser.enable_layerwise_upcasting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) + + return pipe + + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + pipe_fp32 = initialize_pipeline(storage_dtype=None) + pipe_fp32(**inputs, generator=torch.manual_seed(0))[0] + + pipe_float8_e4m3_fp32 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32) + pipe_float8_e4m3_fp32(**inputs, generator=torch.manual_seed(0))[0] + + pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) + pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0] From a263e1a09416095fc8895ea606f780fd36357347 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 21 Jan 2025 14:00:09 +0100 Subject: [PATCH 30/45] _always_upcast_modules -> _precision_sensitive_module_patterns --- .../models/autoencoders/autoencoder_asym_kl.py | 2 +- src/diffusers/models/autoencoders/vq_model.py | 2 +- src/diffusers/models/modeling_utils.py | 15 ++++++++------- .../transformers/auraflow_transformer_2d.py | 2 +- .../transformers/cogvideox_transformer_3d.py | 2 +- .../models/transformers/dit_transformer_2d.py | 2 +- .../models/transformers/hunyuan_transformer_2d.py | 2 +- .../models/transformers/latte_transformer_3d.py | 2 +- .../models/transformers/lumina_nextdit2d.py | 2 +- .../models/transformers/pixart_transformer_2d.py | 2 +- .../models/transformers/sana_transformer.py | 2 +- .../transformers/stable_audio_transformer.py | 1 + .../models/transformers/transformer_2d.py | 2 +- .../models/transformers/transformer_allegro.py | 2 +- .../transformers/transformer_cogview3plus.py | 2 +- .../models/transformers/transformer_flux.py | 2 +- .../transformers/transformer_hunyuan_video.py | 2 +- .../models/transformers/transformer_ltx.py | 2 +- .../models/transformers/transformer_mochi.py | 2 +- .../models/transformers/transformer_sd3.py | 2 +- .../models/transformers/transformer_temporal.py | 2 +- src/diffusers/models/unets/unet_1d.py | 2 +- src/diffusers/models/unets/unet_2d.py | 2 +- src/diffusers/models/unets/unet_2d_condition.py | 2 +- src/diffusers/models/unets/unet_3d_condition.py | 2 +- src/diffusers/models/unets/unet_motion_model.py | 2 +- tests/models/test_modeling_common.py | 4 ++-- 27 files changed, 35 insertions(+), 33 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_asym_kl.py b/src/diffusers/models/autoencoders/autoencoder_asym_kl.py index 455828de40c3..20b6ee7b1ad5 100644 --- a/src/diffusers/models/autoencoders/autoencoder_asym_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_asym_kl.py @@ -60,7 +60,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. """ - _always_upcast_modules = ["decoder"] + _precision_sensitive_module_patterns = ["decoder"] @register_to_config def __init__( diff --git a/src/diffusers/models/autoencoders/vq_model.py b/src/diffusers/models/autoencoders/vq_model.py index c0e2a3b64564..5339c88f09ce 100644 --- a/src/diffusers/models/autoencoders/vq_model.py +++ b/src/diffusers/models/autoencoders/vq_model.py @@ -71,7 +71,7 @@ class VQModel(ModelMixin, ConfigMixin): Type of normalization layer to use. Can be one of `"group"` or `"spatial"`. """ - _always_upcast_modules = ["quantize"] + _precision_sensitive_module_patterns = ["quantize"] @register_to_config def __init__( diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 48fa973dfd34..da92ee675faf 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -163,7 +163,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): _keys_to_ignore_on_load_unexpected = None _no_split_modules = None _keep_in_fp32_modules = None - _always_upcast_modules = None + _precision_sensitive_module_patterns = None def __init__(self): super().__init__() @@ -344,10 +344,11 @@ def enable_layerwise_upcasting( memory footprint from model weights, but may lead to some quality degradation in the outputs. Most degradations are negligible, mostly stemming from weight casting in normalization and modulation layers. - By default, most models in diffusers set the `_always_upcast_modules` attribute to ignore patch embedding, - positional embedding and normalization layers. This is because these layers are most likely precision-critical - for quality. If you wish to change this behavior, you can set the `_always_upcast_modules` attribute to `None`, - or call [`~hooks.layerwise_upcasting.apply_layerwise_upcasting`] with custom arguments. + By default, most models in diffusers set the `_precision_sensitive_module_patterns` attribute to ignore patch + embedding, positional embedding and normalization layers. This is because these layers are most likely + precision-critical for quality. If you wish to change this behavior, you can set the + `_precision_sensitive_module_patterns` attribute to `None`, or call + [`~hooks.layerwise_upcasting.apply_layerwise_upcasting`] with custom arguments. Example: Using [`~models.ModelMixin.enable_layerwise_upcasting`]: @@ -386,8 +387,8 @@ def enable_layerwise_upcasting( user_provided_patterns = False if self._keep_in_fp32_modules is not None: skip_modules_pattern += tuple(self._keep_in_fp32_modules) - if self._always_upcast_modules is not None: - skip_modules_pattern += tuple(self._always_upcast_modules) + if self._precision_sensitive_module_patterns is not None: + skip_modules_pattern += tuple(self._precision_sensitive_module_patterns) skip_modules_pattern = tuple(set(skip_modules_pattern)) if is_peft_available() and not user_provided_patterns: diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index 834d141b6889..bb332eca2e19 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -276,7 +276,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin """ _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"] - _always_upcast_modules = ["pos_embed", "norm"] + _precision_sensitive_module_patterns = ["pos_embed", "norm"] _supports_gradient_checkpointing = True @register_to_config diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index 368c7f98afc4..ebf987457df0 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -212,7 +212,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): Scaling factor to apply in 3D positional embeddings across temporal dimensions. """ - _always_upcast_modules = ["patch_embed", "norm"] + _precision_sensitive_module_patterns = ["patch_embed", "norm"] _supports_gradient_checkpointing = True _no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"] diff --git a/src/diffusers/models/transformers/dit_transformer_2d.py b/src/diffusers/models/transformers/dit_transformer_2d.py index 6484d59d9464..1660bff968dd 100644 --- a/src/diffusers/models/transformers/dit_transformer_2d.py +++ b/src/diffusers/models/transformers/dit_transformer_2d.py @@ -64,7 +64,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin): A small constant added to the denominator in normalization layers to prevent division by zero. """ - _always_upcast_modules = ["pos_embed", "norm"] + _precision_sensitive_module_patterns = ["pos_embed", "norm"] _supports_gradient_checkpointing = True @register_to_config diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py index dde728e12b20..59ff0dae35fe 100644 --- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py +++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py @@ -244,7 +244,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin): Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2 """ - _always_upcast_modules = ["pos_embed", "norm", "pooler"] + _precision_sensitive_module_patterns = ["pos_embed", "norm", "pooler"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index 4f534c02704f..3815bbed22fd 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -65,7 +65,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin): The number of frames in the video-like data. """ - _always_upcast_modules = ["pos_embed", "norm"] + _precision_sensitive_module_patterns = ["pos_embed", "norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/lumina_nextdit2d.py b/src/diffusers/models/transformers/lumina_nextdit2d.py index 56f8768c489d..4d7852175983 100644 --- a/src/diffusers/models/transformers/lumina_nextdit2d.py +++ b/src/diffusers/models/transformers/lumina_nextdit2d.py @@ -221,7 +221,7 @@ class LuminaNextDiT2DModel(ModelMixin, ConfigMixin): overall scale of the model's operations. """ - _always_upcast_modules = ["patch_embedder", "norm", "ffn_norm"] + _precision_sensitive_module_patterns = ["patch_embedder", "norm", "ffn_norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py index 61b1ac730e56..eba2c0497633 100644 --- a/src/diffusers/models/transformers/pixart_transformer_2d.py +++ b/src/diffusers/models/transformers/pixart_transformer_2d.py @@ -79,7 +79,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"] - _always_upcast_modules = ["pos_embed", "norm", "adaln_single"] + _precision_sensitive_module_patterns = ["pos_embed", "norm", "adaln_single"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index a2201a105958..b40ba3d5075b 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -222,7 +222,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): _supports_gradient_checkpointing = True _no_split_modules = ["SanaTransformerBlock", "PatchEmbed"] - _always_upcast_modules = ["patch_embed", "norm"] + _precision_sensitive_module_patterns = ["patch_embed", "norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/stable_audio_transformer.py b/src/diffusers/models/transformers/stable_audio_transformer.py index d687dbabf317..6a86e337b0e1 100644 --- a/src/diffusers/models/transformers/stable_audio_transformer.py +++ b/src/diffusers/models/transformers/stable_audio_transformer.py @@ -211,6 +211,7 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = True + _precision_sensitive_module_patterns = ["preprocess_conv", "postprocess_conv", "^proj_in$", "^proj_out$", "norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py index 41134fc14907..ecefeba7c3a1 100644 --- a/src/diffusers/models/transformers/transformer_2d.py +++ b/src/diffusers/models/transformers/transformer_2d.py @@ -66,7 +66,7 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin): _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock"] - _always_upcast_modules = ["latent_image_embedding", "norm"] + _precision_sensitive_module_patterns = ["latent_image_embedding", "norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index a62a23ad5b58..aa61326f37be 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -222,7 +222,7 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = True - _always_upcast_modules = ["pos_embed", "norm", "adaln_single"] + _precision_sensitive_module_patterns = ["pos_embed", "norm", "adaln_single"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py index 417dbdf06439..61cc1787968e 100644 --- a/src/diffusers/models/transformers/transformer_cogview3plus.py +++ b/src/diffusers/models/transformers/transformer_cogview3plus.py @@ -166,7 +166,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = True - _always_upcast_modules = ["patch_embed", "norm"] + _precision_sensitive_module_patterns = ["patch_embed", "norm"] _no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"] @register_to_config diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 81fa52c36445..936090379e02 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -262,7 +262,7 @@ class FluxTransformer2DModel( _supports_gradient_checkpointing = True _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] - _always_upcast_modules = ["pos_embed", "norm"] + _precision_sensitive_module_patterns = ["pos_embed", "norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index e968b8a7c054..7b223163b664 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -542,7 +542,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, """ _supports_gradient_checkpointing = True - _always_upcast_modules = ["x_embedder", "context_embedder", "norm"] + _precision_sensitive_module_patterns = ["x_embedder", "context_embedder", "norm"] _no_split_modules = [ "HunyuanVideoTransformerBlock", "HunyuanVideoSingleTransformerBlock", diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 07368df92822..2adfb2c7a23e 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -295,7 +295,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin """ _supports_gradient_checkpointing = True - _always_upcast_modules = ["norm"] + _precision_sensitive_module_patterns = ["norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 5a3d82f31813..81ec8e9f6f5b 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -336,7 +336,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri _supports_gradient_checkpointing = True _no_split_modules = ["MochiTransformerBlock"] - _always_upcast_modules = ["patch_embed", "norm"] + _precision_sensitive_module_patterns = ["patch_embed", "norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 92344061f920..110dbbbe7c5d 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -127,7 +127,7 @@ class SD3Transformer2DModel( """ _supports_gradient_checkpointing = True - _always_upcast_modules = ["pos_embed", "norm"] + _precision_sensitive_module_patterns = ["pos_embed", "norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_temporal.py b/src/diffusers/models/transformers/transformer_temporal.py index 174e6a52ef73..b27feb3bfd25 100644 --- a/src/diffusers/models/transformers/transformer_temporal.py +++ b/src/diffusers/models/transformers/transformer_temporal.py @@ -67,7 +67,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin): The maximum length of the sequence over which to apply positional embeddings. """ - _always_upcast_modules = ["norm"] + _precision_sensitive_module_patterns = ["norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/unets/unet_1d.py b/src/diffusers/models/unets/unet_1d.py index 18ac1d63e0d0..9b17c0bf7ccf 100644 --- a/src/diffusers/models/unets/unet_1d.py +++ b/src/diffusers/models/unets/unet_1d.py @@ -71,7 +71,7 @@ class UNet1DModel(ModelMixin, ConfigMixin): Experimental feature for using a UNet without upsampling. """ - _always_upcast_modules = ["norm"] + _precision_sensitive_module_patterns = ["norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/unets/unet_2d.py b/src/diffusers/models/unets/unet_2d.py index a673b9265251..08001bafe3ce 100644 --- a/src/diffusers/models/unets/unet_2d.py +++ b/src/diffusers/models/unets/unet_2d.py @@ -90,7 +90,7 @@ class UNet2DModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = True - _always_upcast_modules = ["norm"] + _precision_sensitive_module_patterns = ["norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 18e94eaccf84..f6d9bc22e65b 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -166,7 +166,7 @@ class conditioning with `class_embed_type` equal to `None`. _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"] - _always_upcast_modules = ["norm"] + _precision_sensitive_module_patterns = ["norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py index b1bf24cfb849..7a76bcf63b81 100644 --- a/src/diffusers/models/unets/unet_3d_condition.py +++ b/src/diffusers/models/unets/unet_3d_condition.py @@ -97,7 +97,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) """ _supports_gradient_checkpointing = False - _always_upcast_modules = ["norm", "time_embedding"] + _precision_sensitive_module_patterns = ["norm", "time_embedding"] @register_to_config def __init__( diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index d4f6b57b5519..403f7a3df550 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -1301,7 +1301,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft """ _supports_gradient_checkpointing = True - _always_upcast_modules = ["norm"] + _precision_sensitive_module_patterns = ["norm"] @register_to_config def __init__( diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index a3e8bdf854e5..bbd8b980f11b 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1355,8 +1355,8 @@ def test_layerwise_upcasting_inference(self): def check_linear_dtype(module, storage_dtype, compute_dtype): patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN - if getattr(module, "_always_upcast_modules", None) is not None: - patterns_to_check += tuple(module._always_upcast_modules) + if getattr(module, "_precision_sensitive_module_patterns", None) is not None: + patterns_to_check += tuple(module._precision_sensitive_module_patterns) for name, submodule in module.named_modules(): if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS): continue From 245137feee570859998f2c98c150ec9069cb5457 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 21 Jan 2025 14:06:53 +0100 Subject: [PATCH 31/45] remove todo comments about review; revert changes to self.dtype in unets because .dtype on ModelMixin should be able to handle fp8 weight case --- src/diffusers/models/unets/unet_2d.py | 5 ++--- src/diffusers/models/unets/unet_motion_model.py | 3 +-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/unets/unet_2d.py b/src/diffusers/models/unets/unet_2d.py index 08001bafe3ce..da34b14c4b35 100644 --- a/src/diffusers/models/unets/unet_2d.py +++ b/src/diffusers/models/unets/unet_2d.py @@ -295,8 +295,7 @@ def forward( # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. - # TODO(aryan): Need to have this reviewed - t_emb = t_emb.to(dtype=sample.dtype) + t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb) if self.class_embedding is not None: @@ -306,7 +305,7 @@ def forward( if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) - class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) emb = emb + class_emb elif self.class_embedding is None and class_labels is not None: raise ValueError("class_embedding needs to be initialized in order to use class conditioning") diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 5290aa738e18..5e0a5c1e2218 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -2133,8 +2133,7 @@ def forward( # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. - # TODO(aryan): Need to have this reviewed - t_emb = t_emb.to(dtype=sample.dtype) + t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb, timestep_cond) aug_emb = None From b713511dc3c7de83ed9db3acde5ce21cdeeca17b Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 21 Jan 2025 14:12:39 +0100 Subject: [PATCH 32/45] check layer dtypes in lora test --- tests/lora/utils.py | 19 +++++++++++++++++++ tests/models/test_modeling_common.py | 1 + 2 files changed, 20 insertions(+) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 20c77f312a22..f3b2cc695e07 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -14,6 +14,7 @@ # limitations under the License. import inspect import os +import re import tempfile import unittest from itertools import product @@ -2100,6 +2101,23 @@ def test_correct_lora_configs_with_different_ranks(self): self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3)) def test_layerwise_upcasting_inference_denoiser(self): + from diffusers.hooks.layerwise_upcasting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS + + def check_linear_dtype(module, storage_dtype, compute_dtype): + patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN + if getattr(module, "_precision_sensitive_module_patterns", None) is not None: + patterns_to_check += tuple(module._precision_sensitive_module_patterns) + for name, submodule in module.named_modules(): + if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS): + continue + dtype_to_check = storage_dtype + if "lora" in name or any(re.search(pattern, name) for pattern in patterns_to_check): + dtype_to_check = compute_dtype + if getattr(submodule, "weight", None) is not None: + self.assertEqual(submodule.weight.dtype, dtype_to_check) + if getattr(submodule, "bias", None) is not None: + self.assertEqual(submodule.bias.dtype, dtype_to_check) + def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32): components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0]) pipe = self.pipeline_class(**components) @@ -2125,6 +2143,7 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32): if storage_dtype is not None: denoiser.enable_layerwise_upcasting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) + check_linear_dtype(denoiser, storage_dtype, compute_dtype) return pipe diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 0b3c1dbafe18..6b5146c69e81 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -62,6 +62,7 @@ require_torch_2, require_torch_accelerator, require_torch_accelerator_with_training, + require_torch_gpu, require_torch_multi_gpu, run_test_in_subprocess, torch_all_close, From ed14d260de6766a0a53f31bf666b6b8d32a88e82 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 21 Jan 2025 14:40:02 +0100 Subject: [PATCH 33/45] fix UNet1DModelTests::test_layerwise_upcasting_inference --- src/diffusers/models/unets/unet_1d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unets/unet_1d.py b/src/diffusers/models/unets/unet_1d.py index 9b17c0bf7ccf..34a647028fc6 100644 --- a/src/diffusers/models/unets/unet_1d.py +++ b/src/diffusers/models/unets/unet_1d.py @@ -225,7 +225,7 @@ def forward( timestep_embed = self.time_proj(timesteps) if self.config.use_timestep_embedding: - timestep_embed = self.time_mlp(timestep_embed) + timestep_embed = self.time_mlp(timestep_embed.to(sample.dtype)) else: timestep_embed = timestep_embed[..., None] timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype) From 2c9c33f3f8e3548ae194e08a6c7cfe7094072a8d Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 21 Jan 2025 14:42:38 +0100 Subject: [PATCH 34/45] _precision_sensitive_module_patterns -> _skip_layerwise_casting_patterns based on feedback --- .../models/autoencoders/autoencoder_asym_kl.py | 2 +- src/diffusers/models/autoencoders/vq_model.py | 2 +- src/diffusers/models/modeling_utils.py | 10 +++++----- .../models/transformers/auraflow_transformer_2d.py | 2 +- .../models/transformers/cogvideox_transformer_3d.py | 2 +- .../models/transformers/dit_transformer_2d.py | 2 +- .../models/transformers/hunyuan_transformer_2d.py | 2 +- .../models/transformers/latte_transformer_3d.py | 2 +- src/diffusers/models/transformers/lumina_nextdit2d.py | 2 +- .../models/transformers/pixart_transformer_2d.py | 2 +- src/diffusers/models/transformers/sana_transformer.py | 2 +- .../models/transformers/stable_audio_transformer.py | 2 +- src/diffusers/models/transformers/transformer_2d.py | 2 +- .../models/transformers/transformer_allegro.py | 2 +- .../models/transformers/transformer_cogview3plus.py | 2 +- src/diffusers/models/transformers/transformer_flux.py | 2 +- .../models/transformers/transformer_hunyuan_video.py | 2 +- src/diffusers/models/transformers/transformer_ltx.py | 2 +- src/diffusers/models/transformers/transformer_mochi.py | 2 +- src/diffusers/models/transformers/transformer_sd3.py | 2 +- .../models/transformers/transformer_temporal.py | 2 +- src/diffusers/models/unets/unet_1d.py | 2 +- src/diffusers/models/unets/unet_2d.py | 2 +- src/diffusers/models/unets/unet_2d_condition.py | 2 +- src/diffusers/models/unets/unet_3d_condition.py | 2 +- src/diffusers/models/unets/unet_motion_model.py | 2 +- tests/lora/utils.py | 4 ++-- tests/models/test_modeling_common.py | 4 ++-- 28 files changed, 34 insertions(+), 34 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_asym_kl.py b/src/diffusers/models/autoencoders/autoencoder_asym_kl.py index 20b6ee7b1ad5..c643dcc72a34 100644 --- a/src/diffusers/models/autoencoders/autoencoder_asym_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_asym_kl.py @@ -60,7 +60,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. """ - _precision_sensitive_module_patterns = ["decoder"] + _skip_layerwise_casting_patterns = ["decoder"] @register_to_config def __init__( diff --git a/src/diffusers/models/autoencoders/vq_model.py b/src/diffusers/models/autoencoders/vq_model.py index 5339c88f09ce..e754e134b35f 100644 --- a/src/diffusers/models/autoencoders/vq_model.py +++ b/src/diffusers/models/autoencoders/vq_model.py @@ -71,7 +71,7 @@ class VQModel(ModelMixin, ConfigMixin): Type of normalization layer to use. Can be one of `"group"` or `"spatial"`. """ - _precision_sensitive_module_patterns = ["quantize"] + _skip_layerwise_casting_patterns = ["quantize"] @register_to_config def __init__( diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 6e108cef0857..8eebf944f5fe 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -163,7 +163,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): _keys_to_ignore_on_load_unexpected = None _no_split_modules = None _keep_in_fp32_modules = None - _precision_sensitive_module_patterns = None + _skip_layerwise_casting_patterns = None def __init__(self): super().__init__() @@ -344,10 +344,10 @@ def enable_layerwise_upcasting( memory footprint from model weights, but may lead to some quality degradation in the outputs. Most degradations are negligible, mostly stemming from weight casting in normalization and modulation layers. - By default, most models in diffusers set the `_precision_sensitive_module_patterns` attribute to ignore patch + By default, most models in diffusers set the `_skip_layerwise_casting_patterns` attribute to ignore patch embedding, positional embedding and normalization layers. This is because these layers are most likely precision-critical for quality. If you wish to change this behavior, you can set the - `_precision_sensitive_module_patterns` attribute to `None`, or call + `_skip_layerwise_casting_patterns` attribute to `None`, or call [`~hooks.layerwise_upcasting.apply_layerwise_upcasting`] with custom arguments. Example: @@ -387,8 +387,8 @@ def enable_layerwise_upcasting( user_provided_patterns = False if self._keep_in_fp32_modules is not None: skip_modules_pattern += tuple(self._keep_in_fp32_modules) - if self._precision_sensitive_module_patterns is not None: - skip_modules_pattern += tuple(self._precision_sensitive_module_patterns) + if self._skip_layerwise_casting_patterns is not None: + skip_modules_pattern += tuple(self._skip_layerwise_casting_patterns) skip_modules_pattern = tuple(set(skip_modules_pattern)) if is_peft_available() and not user_provided_patterns: diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index bb332eca2e19..f1f36b87987d 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -276,7 +276,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin """ _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"] - _precision_sensitive_module_patterns = ["pos_embed", "norm"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] _supports_gradient_checkpointing = True @register_to_config diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index ebf987457df0..c3039180b81d 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -212,7 +212,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): Scaling factor to apply in 3D positional embeddings across temporal dimensions. """ - _precision_sensitive_module_patterns = ["patch_embed", "norm"] + _skip_layerwise_casting_patterns = ["patch_embed", "norm"] _supports_gradient_checkpointing = True _no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"] diff --git a/src/diffusers/models/transformers/dit_transformer_2d.py b/src/diffusers/models/transformers/dit_transformer_2d.py index 1660bff968dd..7eac313c14db 100644 --- a/src/diffusers/models/transformers/dit_transformer_2d.py +++ b/src/diffusers/models/transformers/dit_transformer_2d.py @@ -64,7 +64,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin): A small constant added to the denominator in normalization layers to prevent division by zero. """ - _precision_sensitive_module_patterns = ["pos_embed", "norm"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] _supports_gradient_checkpointing = True @register_to_config diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py index 59ff0dae35fe..13aa7d076d03 100644 --- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py +++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py @@ -244,7 +244,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin): Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2 """ - _precision_sensitive_module_patterns = ["pos_embed", "norm", "pooler"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm", "pooler"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index 3815bbed22fd..be06f44a9efe 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -65,7 +65,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin): The number of frames in the video-like data. """ - _precision_sensitive_module_patterns = ["pos_embed", "norm"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/lumina_nextdit2d.py b/src/diffusers/models/transformers/lumina_nextdit2d.py index 4d7852175983..fb2b3815bcd5 100644 --- a/src/diffusers/models/transformers/lumina_nextdit2d.py +++ b/src/diffusers/models/transformers/lumina_nextdit2d.py @@ -221,7 +221,7 @@ class LuminaNextDiT2DModel(ModelMixin, ConfigMixin): overall scale of the model's operations. """ - _precision_sensitive_module_patterns = ["patch_embedder", "norm", "ffn_norm"] + _skip_layerwise_casting_patterns = ["patch_embedder", "norm", "ffn_norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py index eba2c0497633..b1740cc08fdf 100644 --- a/src/diffusers/models/transformers/pixart_transformer_2d.py +++ b/src/diffusers/models/transformers/pixart_transformer_2d.py @@ -79,7 +79,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"] - _precision_sensitive_module_patterns = ["pos_embed", "norm", "adaln_single"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index 5ef4c4aa8e06..a2a54406430d 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -236,7 +236,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): _supports_gradient_checkpointing = True _no_split_modules = ["SanaTransformerBlock", "PatchEmbed", "SanaModulatedNorm"] - _precision_sensitive_module_patterns = ["patch_embed", "norm"] + _skip_layerwise_casting_patterns = ["patch_embed", "norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/stable_audio_transformer.py b/src/diffusers/models/transformers/stable_audio_transformer.py index 6a86e337b0e1..bb370f20f21b 100644 --- a/src/diffusers/models/transformers/stable_audio_transformer.py +++ b/src/diffusers/models/transformers/stable_audio_transformer.py @@ -211,7 +211,7 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = True - _precision_sensitive_module_patterns = ["preprocess_conv", "postprocess_conv", "^proj_in$", "^proj_out$", "norm"] + _skip_layerwise_casting_patterns = ["preprocess_conv", "postprocess_conv", "^proj_in$", "^proj_out$", "norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py index ecefeba7c3a1..35e78877f27e 100644 --- a/src/diffusers/models/transformers/transformer_2d.py +++ b/src/diffusers/models/transformers/transformer_2d.py @@ -66,7 +66,7 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin): _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock"] - _precision_sensitive_module_patterns = ["latent_image_embedding", "norm"] + _skip_layerwise_casting_patterns = ["latent_image_embedding", "norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index aa61326f37be..f32c38394ba4 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -222,7 +222,7 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = True - _precision_sensitive_module_patterns = ["pos_embed", "norm", "adaln_single"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py index 61cc1787968e..0376cc2fd70d 100644 --- a/src/diffusers/models/transformers/transformer_cogview3plus.py +++ b/src/diffusers/models/transformers/transformer_cogview3plus.py @@ -166,7 +166,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = True - _precision_sensitive_module_patterns = ["patch_embed", "norm"] + _skip_layerwise_casting_patterns = ["patch_embed", "norm"] _no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"] @register_to_config diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 936090379e02..db8d73856689 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -262,7 +262,7 @@ class FluxTransformer2DModel( _supports_gradient_checkpointing = True _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] - _precision_sensitive_module_patterns = ["pos_embed", "norm"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 7b223163b664..210a2e711972 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -542,7 +542,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, """ _supports_gradient_checkpointing = True - _precision_sensitive_module_patterns = ["x_embedder", "context_embedder", "norm"] + _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"] _no_split_modules = [ "HunyuanVideoTransformerBlock", "HunyuanVideoSingleTransformerBlock", diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 2adfb2c7a23e..b5498c0aed01 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -295,7 +295,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin """ _supports_gradient_checkpointing = True - _precision_sensitive_module_patterns = ["norm"] + _skip_layerwise_casting_patterns = ["norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py index 81ec8e9f6f5b..d16430f27931 100644 --- a/src/diffusers/models/transformers/transformer_mochi.py +++ b/src/diffusers/models/transformers/transformer_mochi.py @@ -336,7 +336,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri _supports_gradient_checkpointing = True _no_split_modules = ["MochiTransformerBlock"] - _precision_sensitive_module_patterns = ["patch_embed", "norm"] + _skip_layerwise_casting_patterns = ["patch_embed", "norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 110dbbbe7c5d..2688d3640ea5 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -127,7 +127,7 @@ class SD3Transformer2DModel( """ _supports_gradient_checkpointing = True - _precision_sensitive_module_patterns = ["pos_embed", "norm"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_temporal.py b/src/diffusers/models/transformers/transformer_temporal.py index b27feb3bfd25..3b5aedb79e3c 100644 --- a/src/diffusers/models/transformers/transformer_temporal.py +++ b/src/diffusers/models/transformers/transformer_temporal.py @@ -67,7 +67,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin): The maximum length of the sequence over which to apply positional embeddings. """ - _precision_sensitive_module_patterns = ["norm"] + _skip_layerwise_casting_patterns = ["norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/unets/unet_1d.py b/src/diffusers/models/unets/unet_1d.py index 34a647028fc6..ce496fd6baf8 100644 --- a/src/diffusers/models/unets/unet_1d.py +++ b/src/diffusers/models/unets/unet_1d.py @@ -71,7 +71,7 @@ class UNet1DModel(ModelMixin, ConfigMixin): Experimental feature for using a UNet without upsampling. """ - _precision_sensitive_module_patterns = ["norm"] + _skip_layerwise_casting_patterns = ["norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/unets/unet_2d.py b/src/diffusers/models/unets/unet_2d.py index da34b14c4b35..84a1322d2a95 100644 --- a/src/diffusers/models/unets/unet_2d.py +++ b/src/diffusers/models/unets/unet_2d.py @@ -90,7 +90,7 @@ class UNet2DModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = True - _precision_sensitive_module_patterns = ["norm"] + _skip_layerwise_casting_patterns = ["norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 806d7a563a19..3447fa0674bc 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -166,7 +166,7 @@ class conditioning with `class_embed_type` equal to `None`. _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"] - _precision_sensitive_module_patterns = ["norm"] + _skip_layerwise_casting_patterns = ["norm"] @register_to_config def __init__( diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py index e29a76f373c5..398609778e65 100644 --- a/src/diffusers/models/unets/unet_3d_condition.py +++ b/src/diffusers/models/unets/unet_3d_condition.py @@ -97,7 +97,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) """ _supports_gradient_checkpointing = False - _precision_sensitive_module_patterns = ["norm", "time_embedding"] + _skip_layerwise_casting_patterns = ["norm", "time_embedding"] @register_to_config def __init__( diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index 5e0a5c1e2218..1d0a38a8fb13 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -1301,7 +1301,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft """ _supports_gradient_checkpointing = True - _precision_sensitive_module_patterns = ["norm"] + _skip_layerwise_casting_patterns = ["norm"] @register_to_config def __init__( diff --git a/tests/lora/utils.py b/tests/lora/utils.py index f3b2cc695e07..14dbd04101aa 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -2105,8 +2105,8 @@ def test_layerwise_upcasting_inference_denoiser(self): def check_linear_dtype(module, storage_dtype, compute_dtype): patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN - if getattr(module, "_precision_sensitive_module_patterns", None) is not None: - patterns_to_check += tuple(module._precision_sensitive_module_patterns) + if getattr(module, "_skip_layerwise_casting_patterns", None) is not None: + patterns_to_check += tuple(module._skip_layerwise_casting_patterns) for name, submodule in module.named_modules(): if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS): continue diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 9f966e687fd2..ba3931ff452f 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1357,8 +1357,8 @@ def test_layerwise_upcasting_inference(self): def check_linear_dtype(module, storage_dtype, compute_dtype): patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN - if getattr(module, "_precision_sensitive_module_patterns", None) is not None: - patterns_to_check += tuple(module._precision_sensitive_module_patterns) + if getattr(module, "_skip_layerwise_casting_patterns", None) is not None: + patterns_to_check += tuple(module._skip_layerwise_casting_patterns) for name, submodule in module.named_modules(): if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS): continue From 08211f75db7dd55982e2f8d82fbcb3abd935d5b6 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 21 Jan 2025 14:50:35 +0100 Subject: [PATCH 35/45] skip test in NCSNppModelTests --- tests/models/unets/test_models_unet_2d.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/models/unets/test_models_unet_2d.py b/tests/models/unets/test_models_unet_2d.py index 05bece23efd6..69e882202d12 100644 --- a/tests/models/unets/test_models_unet_2d.py +++ b/tests/models/unets/test_models_unet_2d.py @@ -401,3 +401,15 @@ def test_gradient_checkpointing_is_applied(self): def test_effective_gradient_checkpointing(self): super().test_effective_gradient_checkpointing(skip={"time_proj.weight"}) + + @unittest.skip( + "To make layerwise upcasting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here." + ) + def test_layerwise_upcasting_inference(self): + pass + + @unittest.skip( + "To make layerwise upcasting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here." + ) + def test_layerwise_upcasting_memory(self): + pass From 59e04c3dcd911ce37b6b4a6d10f7ff03f8ea2f0d Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 21 Jan 2025 15:11:03 +0100 Subject: [PATCH 36/45] skip tests for AutoencoderTinyTests --- .../autoencoders/test_models_autoencoder_tiny.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/models/autoencoders/test_models_autoencoder_tiny.py b/tests/models/autoencoders/test_models_autoencoder_tiny.py index 4de3822fa835..9b6ec42496ce 100644 --- a/tests/models/autoencoders/test_models_autoencoder_tiny.py +++ b/tests/models/autoencoders/test_models_autoencoder_tiny.py @@ -173,6 +173,22 @@ def test_effective_gradient_checkpointing(self): continue self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=3e-2)) + @unittest.skip( + "The forward pass of AutoencoderTiny creates a torch.float32 tensor. This causes inference in compute_dtype=torch.bfloat16 to fail. To fix:\n" + "1. Change the forward pass to be dtype agnostic.\n" + "2. Unskip this test." + ) + def test_layerwise_upcasting_inference(self): + pass + + @unittest.skip( + "The forward pass of AutoencoderTiny creates a torch.float32 tensor. This causes inference in compute_dtype=torch.bfloat16 to fail. To fix:\n" + "1. Change the forward pass to be dtype agnostic.\n" + "2. Unskip this test." + ) + def test_layerwise_upcasting_memory(self): + pass + @slow class AutoencoderTinyIntegrationTests(unittest.TestCase): From 0a16826d1e88114d7f4af37c538ca29f12034d3c Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 21 Jan 2025 15:16:14 +0100 Subject: [PATCH 37/45] skip tests for AutoencoderOobleckTests --- .../test_models_autoencoder_oobleck.py | 18 ++++++++++++++++++ tests/models/test_modeling_common.py | 2 ++ 2 files changed, 20 insertions(+) diff --git a/tests/models/autoencoders/test_models_autoencoder_oobleck.py b/tests/models/autoencoders/test_models_autoencoder_oobleck.py index 4807fa298344..1abda26233f7 100644 --- a/tests/models/autoencoders/test_models_autoencoder_oobleck.py +++ b/tests/models/autoencoders/test_models_autoencoder_oobleck.py @@ -114,6 +114,24 @@ def test_forward_with_norm_groups(self): def test_set_attn_processor_for_determinism(self): return + @unittest.skip( + "The convolution layers of AutoencoderOobleck are wrapped with torch.nn.utils.weight_norm. This causes the hook's pre_forward to not " + "cast the module weights to compute_dtype (as required by forward pass). As a result, forward pass errors out. To fix:\n" + "1. Make sure `nn::Module::to` works with `torch.nn.utils.weight_norm` wrapped convolution layer.\n" + "2. Unskip this test." + ) + def test_layerwise_upcasting_inference(self): + pass + + @unittest.skip( + "The convolution layers of AutoencoderOobleck are wrapped with torch.nn.utils.weight_norm. This causes the hook's pre_forward to not " + "cast the module weights to compute_dtype (as required by forward pass). As a result, forward pass errors out. To fix:\n" + "1. Make sure `nn::Module::to` works with `torch.nn.utils.weight_norm` wrapped convolution layer.\n" + "2. Unskip this test." + ) + def test_layerwise_upcasting_memory(self): + pass + @slow class AutoencoderOobleckIntegrationTests(unittest.TestCase): diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index ba3931ff452f..668d48715468 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1366,6 +1366,7 @@ def check_linear_dtype(module, storage_dtype, compute_dtype): if any(re.search(pattern, name) for pattern in patterns_to_check): dtype_to_check = compute_dtype if getattr(submodule, "weight", None) is not None: + print(name, submodule.weight.dtype, dtype_to_check, patterns_to_check) self.assertEqual(submodule.weight.dtype, dtype_to_check) if getattr(submodule, "bias", None) is not None: self.assertEqual(submodule.bias.dtype, dtype_to_check) @@ -1377,6 +1378,7 @@ def test_layerwise_upcasting(storage_dtype, compute_dtype): model = self.model_class(**config).eval() model = model.to(torch_device, dtype=compute_dtype) model.enable_layerwise_upcasting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) + check_linear_dtype(model, storage_dtype, compute_dtype) output = model(**inputs_dict)[0].float().flatten().detach().cpu().numpy() From 1d306b8ef7b95ba77f2720d529688f03eb4a4714 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 21 Jan 2025 15:28:03 +0100 Subject: [PATCH 38/45] skip tests for UNet1DModelTests - unsupported pytorch operations --- tests/models/unets/test_models_unet_1d.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/models/unets/test_models_unet_1d.py b/tests/models/unets/test_models_unet_1d.py index 6eb7d3485c8b..407e85f0db4f 100644 --- a/tests/models/unets/test_models_unet_1d.py +++ b/tests/models/unets/test_models_unet_1d.py @@ -152,6 +152,24 @@ def test_unet_1d_maestro(self): assert (output_sum - 224.0896).abs() < 0.5 assert (output_max - 0.0607).abs() < 4e-4 + @unittest.skip( + "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations " + "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n" + "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n" + "2. Unskip this test." + ) + def test_layerwise_casting_inference(self): + pass + + @unittest.skip( + "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations " + "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n" + "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n" + "2. Unskip this test." + ) + def test_layerwise_casting_memory(self): + pass + class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): model_class = UNet1DModel From a9364bd34dd3cc9fa9bf40ad43e0205ffebb95ee Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 21 Jan 2025 15:29:58 +0100 Subject: [PATCH 39/45] layerwise_upcasting -> layerwise_casting --- docs/source/en/api/utilities.md | 4 +- docs/source/en/optimization/memory.md | 8 ++-- src/diffusers/hooks/__init__.py | 2 +- ...wise_upcasting.py => layerwise_casting.py} | 38 +++++++++---------- src/diffusers/models/modeling_utils.py | 34 ++++++++--------- tests/lora/utils.py | 6 +-- .../test_models_autoencoder_oobleck.py | 4 +- .../test_models_autoencoder_tiny.py | 4 +- tests/models/test_modeling_common.py | 22 +++++------ tests/models/unets/test_models_unet_2d.py | 8 ++-- 10 files changed, 65 insertions(+), 65 deletions(-) rename src/diffusers/hooks/{layerwise_upcasting.py => layerwise_casting.py} (84%) diff --git a/docs/source/en/api/utilities.md b/docs/source/en/api/utilities.md index a17ab1184957..b0b78928fb4b 100644 --- a/docs/source/en/api/utilities.md +++ b/docs/source/en/api/utilities.md @@ -42,6 +42,6 @@ Utility and helper functions for working with 🤗 Diffusers. [[autodoc]] utils.torch_utils.randn_tensor -## apply_layerwise_upcasting +## apply_layerwise_casting -[[autodoc]] hooks.layerwise_upcasting.apply_layerwise_upcasting +[[autodoc]] hooks.layerwise_casting.apply_layerwise_casting diff --git a/docs/source/en/optimization/memory.md b/docs/source/en/optimization/memory.md index 50e10f0caa48..4cdc60401914 100644 --- a/docs/source/en/optimization/memory.md +++ b/docs/source/en/optimization/memory.md @@ -171,9 +171,9 @@ from diffusers.utils import export_to_video model_id = "THUDM/CogVideoX-5b" -# Load the model in bfloat16 and enable layerwise upcasting +# Load the model in bfloat16 and enable layerwise casting transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16) -transformer.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) +transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) # Load the pipeline pipe = CogVideoXPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16) @@ -191,9 +191,9 @@ video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] export_to_video(video, "output.mp4", fps=8) ``` -In the above example, layerwise upcasting is enabled on the transformer component of the pipeline. By default, certain layers are skipped from the FP8 weight casting because it can lead to significant degradation of generation quality. The normalization and modulation related weight parameters are also skipped by default. +In the above example, layerwise casting is enabled on the transformer component of the pipeline. By default, certain layers are skipped from the FP8 weight casting because it can lead to significant degradation of generation quality. The normalization and modulation related weight parameters are also skipped by default. -However, you gain more control and flexibility by directly utilizing the [`~hooks.layerwise_upcasting.apply_layerwise_upcasting`] function instead of [`~ModelMixin.enable_layerwise_upcasting`]. +However, you gain more control and flexibility by directly utilizing the [`~hooks.layerwise_casting.apply_layerwise_casting`] function instead of [`~ModelMixin.enable_layerwise_casting`]. ## Channels-last memory format diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 14c16c7d3236..91b2760acad0 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -2,4 +2,4 @@ if is_torch_available(): - from .layerwise_upcasting import apply_layerwise_upcasting, apply_layerwise_upcasting_hook + from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook diff --git a/src/diffusers/hooks/layerwise_upcasting.py b/src/diffusers/hooks/layerwise_casting.py similarity index 84% rename from src/diffusers/hooks/layerwise_upcasting.py rename to src/diffusers/hooks/layerwise_casting.py index 3d85b56db72b..95c361e10100 100644 --- a/src/diffusers/hooks/layerwise_upcasting.py +++ b/src/diffusers/hooks/layerwise_casting.py @@ -35,7 +35,7 @@ # fmt: on -class LayerwiseUpcastingHook(ModelHook): +class LayerwiseCastingHook(ModelHook): r""" A hook that casts the weights of a module to a high precision dtype for computation, and to a low precision dtype for storage. This process may lead to quality loss in the output, but can significantly reduce the memory @@ -55,7 +55,7 @@ def initialize_hook(self, module: torch.nn.Module): def deinitalize_hook(self, module: torch.nn.Module): raise NotImplementedError( - "LayerwiseUpcastingHook does not support deinitalization. A model once enabled with layerwise upcasting will " + "LayerwiseCastingHook does not support deinitalization. A model once enabled with layerwise casting will " "have casted its weights to a lower precision dtype for storage. Casting this back to the original dtype " "will lead to precision loss, which might have an impact on the model's generation quality. The model should " "be re-initialized and loaded in the original dtype." @@ -70,7 +70,7 @@ def post_forward(self, module: torch.nn.Module, output): return output -def apply_layerwise_upcasting( +def apply_layerwise_casting( module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, @@ -79,7 +79,7 @@ def apply_layerwise_upcasting( non_blocking: bool = False, ) -> None: r""" - Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin but it can be any + Applies layerwise casting to a given module. The module expected here is a Diffusers ModelMixin but it can be any nn.Module using diffusers layers or pytorch primitives. Example: @@ -92,7 +92,7 @@ def apply_layerwise_upcasting( ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 ... ) - >>> apply_layerwise_upcasting( + >>> apply_layerwise_casting( ... transformer, ... storage_dtype=torch.float8_e4m3fn, ... compute_dtype=torch.bfloat16, @@ -110,12 +110,12 @@ def apply_layerwise_upcasting( compute_dtype (`torch.dtype`): The dtype to cast the module to during the forward pass for computation. skip_modules_pattern (`Tuple[str, ...]`, defaults to `"default"`): - A list of patterns to match the names of the modules to skip during the layerwise upcasting process. If set + A list of patterns to match the names of the modules to skip during the layerwise casting process. If set to `"default"`, the default patterns are used. If set to `None`, no modules are skipped. If set to `None` - alongside `skip_modules_classes` being `None`, the layerwise upcasting is applied directly to the module + alongside `skip_modules_classes` being `None`, the layerwise casting is applied directly to the module instead of its internal submodules. skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, defaults to `None`): - A list of module classes to skip during the layerwise upcasting process. + A list of module classes to skip during the layerwise casting process. non_blocking (`bool`, defaults to `False`): If `True`, the weight casting operations are non-blocking. """ @@ -123,10 +123,10 @@ def apply_layerwise_upcasting( skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN if skip_modules_classes is None and skip_modules_pattern is None: - apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking) + apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking) return - _apply_layerwise_upcasting( + _apply_layerwise_casting( module, storage_dtype, compute_dtype, @@ -136,7 +136,7 @@ def apply_layerwise_upcasting( ) -def _apply_layerwise_upcasting( +def _apply_layerwise_casting( module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, @@ -149,17 +149,17 @@ def _apply_layerwise_upcasting( skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern) ) if should_skip: - logger.debug(f'Skipping layerwise upcasting for layer "{_prefix}"') + logger.debug(f'Skipping layerwise casting for layer "{_prefix}"') return if isinstance(module, SUPPORTED_PYTORCH_LAYERS): - logger.debug(f'Applying layerwise upcasting to layer "{_prefix}"') - apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking) + logger.debug(f'Applying layerwise casting to layer "{_prefix}"') + apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking) return for name, submodule in module.named_children(): layer_name = f"{_prefix}.{name}" if _prefix else name - _apply_layerwise_upcasting( + _apply_layerwise_casting( submodule, storage_dtype, compute_dtype, @@ -170,11 +170,11 @@ def _apply_layerwise_upcasting( ) -def apply_layerwise_upcasting_hook( +def apply_layerwise_casting_hook( module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool ) -> None: r""" - Applies a `LayerwiseUpcastingHook` to a given module. + Applies a `LayerwiseCastingHook` to a given module. Args: module (`torch.nn.Module`): @@ -187,5 +187,5 @@ def apply_layerwise_upcasting_hook( If `True`, the weight casting operations are non-blocking. """ registry = HookRegistry.check_if_exists_or_initialize(module) - hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype, non_blocking) - registry.register_hook(hook, "layerwise_upcasting") + hook = LayerwiseCastingHook(storage_dtype, compute_dtype, non_blocking) + registry.register_hook(hook, "layerwise_casting") diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 8eebf944f5fe..4d5669e37f5a 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -32,7 +32,7 @@ from torch import Tensor, nn from .. import __version__ -from ..hooks import apply_layerwise_upcasting +from ..hooks import apply_layerwise_casting from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer from ..quantizers.quantization_config import QuantizationMethod from ..utils import ( @@ -104,13 +104,13 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype: """ Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found. """ - # 1. Check if we have attached any dtype modifying hooks (eg. layerwise upcasting) + # 1. Check if we have attached any dtype modifying hooks (eg. layerwise casting) if isinstance(parameter, nn.Module): for name, submodule in parameter.named_modules(): if not hasattr(submodule, "_diffusers_hook"): continue registry = submodule._diffusers_hook - hook = registry.get_hook("layerwise_upcasting") + hook = registry.get_hook("layerwise_casting") if hook is not None: return hook.compute_dtype @@ -328,7 +328,7 @@ def disable_xformers_memory_efficient_attention(self) -> None: """ self.set_use_memory_efficient_attention_xformers(False) - def enable_layerwise_upcasting( + def enable_layerwise_casting( self, storage_dtype: torch.dtype = torch.float8_e4m3fn, compute_dtype: Optional[torch.dtype] = None, @@ -337,9 +337,9 @@ def enable_layerwise_upcasting( non_blocking: bool = False, ) -> None: r""" - Activates layerwise upcasting for the current model. + Activates layerwise casting for the current model. - Layerwise upcasting is a technique that casts the model weights to a lower precision dtype for storage but + Layerwise casting is a technique that casts the model weights to a lower precision dtype for storage but upcasts them on-the-fly to a higher precision dtype for computation. This process can significantly reduce the memory footprint from model weights, but may lead to some quality degradation in the outputs. Most degradations are negligible, mostly stemming from weight casting in normalization and modulation layers. @@ -348,10 +348,10 @@ def enable_layerwise_upcasting( embedding, positional embedding and normalization layers. This is because these layers are most likely precision-critical for quality. If you wish to change this behavior, you can set the `_skip_layerwise_casting_patterns` attribute to `None`, or call - [`~hooks.layerwise_upcasting.apply_layerwise_upcasting`] with custom arguments. + [`~hooks.layerwise_casting.apply_layerwise_casting`] with custom arguments. Example: - Using [`~models.ModelMixin.enable_layerwise_upcasting`]: + Using [`~models.ModelMixin.enable_layerwise_casting`]: ```python >>> from diffusers import CogVideoXTransformer3DModel @@ -360,8 +360,8 @@ def enable_layerwise_upcasting( ... "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16 ... ) - >>> # Enable layerwise upcasting via the model, which ignores certain modules by default - >>> transformer.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) + >>> # Enable layerwise casting via the model, which ignores certain modules by default + >>> transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) ``` Args: @@ -370,18 +370,18 @@ def enable_layerwise_upcasting( compute_dtype (`torch.dtype`): The dtype to which the model weights should be cast during the forward pass. skip_modules_pattern (`Tuple[str, ...]`, *optional*): - A list of patterns to match the names of the modules to skip during the layerwise upcasting process. If + A list of patterns to match the names of the modules to skip during the layerwise casting process. If set to `None`, default skip patterns are used to ignore certain internal layers of modules and PEFT layers. skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, *optional*): - A list of module classes to skip during the layerwise upcasting process. + A list of module classes to skip during the layerwise casting process. non_blocking (`bool`, *optional*, defaults to `False`): If `True`, the weight casting operations are non-blocking. """ user_provided_patterns = True if skip_modules_pattern is None: - from ..hooks.layerwise_upcasting import DEFAULT_SKIP_MODULES_PATTERN + from ..hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN user_provided_patterns = False @@ -393,8 +393,8 @@ def enable_layerwise_upcasting( if is_peft_available() and not user_provided_patterns: # By default, we want to skip all peft layers because they have a very low memory footprint. - # If users want to apply layerwise upcasting on peft layers as well, they can utilize the - # `~diffusers.hooks.layerwise_upcasting.apply_layerwise_upcasting` function which provides + # If users want to apply layerwise casting on peft layers as well, they can utilize the + # `~diffusers.hooks.layerwise_casting.apply_layerwise_casting` function which provides # them with more flexibility and control. from peft.tuners.loha.layer import LoHaLayer @@ -405,10 +405,10 @@ def enable_layerwise_upcasting( skip_modules_pattern += tuple(layer.adapter_layer_names) if compute_dtype is None: - logger.info("`compute_dtype` not provided when enabling layerwise upcasting. Using dtype of the model.") + logger.info("`compute_dtype` not provided when enabling layerwise casting. Using dtype of the model.") compute_dtype = self.dtype - apply_layerwise_upcasting( + apply_layerwise_casting( self, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking ) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 14dbd04101aa..d0d39d05b08a 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -2100,8 +2100,8 @@ def test_correct_lora_configs_with_different_ranks(self): self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3)) self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3)) - def test_layerwise_upcasting_inference_denoiser(self): - from diffusers.hooks.layerwise_upcasting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS + def test_layerwise_casting_inference_denoiser(self): + from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS def check_linear_dtype(module, storage_dtype, compute_dtype): patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN @@ -2142,7 +2142,7 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32): ) if storage_dtype is not None: - denoiser.enable_layerwise_upcasting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) + denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) check_linear_dtype(denoiser, storage_dtype, compute_dtype) return pipe diff --git a/tests/models/autoencoders/test_models_autoencoder_oobleck.py b/tests/models/autoencoders/test_models_autoencoder_oobleck.py index 1abda26233f7..1f922a9842ee 100644 --- a/tests/models/autoencoders/test_models_autoencoder_oobleck.py +++ b/tests/models/autoencoders/test_models_autoencoder_oobleck.py @@ -120,7 +120,7 @@ def test_set_attn_processor_for_determinism(self): "1. Make sure `nn::Module::to` works with `torch.nn.utils.weight_norm` wrapped convolution layer.\n" "2. Unskip this test." ) - def test_layerwise_upcasting_inference(self): + def test_layerwise_casting_inference(self): pass @unittest.skip( @@ -129,7 +129,7 @@ def test_layerwise_upcasting_inference(self): "1. Make sure `nn::Module::to` works with `torch.nn.utils.weight_norm` wrapped convolution layer.\n" "2. Unskip this test." ) - def test_layerwise_upcasting_memory(self): + def test_layerwise_casting_memory(self): pass diff --git a/tests/models/autoencoders/test_models_autoencoder_tiny.py b/tests/models/autoencoders/test_models_autoencoder_tiny.py index 9b6ec42496ce..bfbfb7ab8593 100644 --- a/tests/models/autoencoders/test_models_autoencoder_tiny.py +++ b/tests/models/autoencoders/test_models_autoencoder_tiny.py @@ -178,7 +178,7 @@ def test_effective_gradient_checkpointing(self): "1. Change the forward pass to be dtype agnostic.\n" "2. Unskip this test." ) - def test_layerwise_upcasting_inference(self): + def test_layerwise_casting_inference(self): pass @unittest.skip( @@ -186,7 +186,7 @@ def test_layerwise_upcasting_inference(self): "1. Change the forward pass to be dtype agnostic.\n" "2. Unskip this test." ) - def test_layerwise_upcasting_memory(self): + def test_layerwise_casting_memory(self): pass diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 668d48715468..0a6cdffb2a5b 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1346,8 +1346,8 @@ def test_variant_sharded_ckpt_right_format(self): # Example: diffusion_pytorch_model.fp16-00001-of-00002.safetensors assert all(f.split(".")[1].split("-")[0] == variant for f in shard_files) - def test_layerwise_upcasting_inference(self): - from diffusers.hooks.layerwise_upcasting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS + def test_layerwise_casting_inference(self): + from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS torch.manual_seed(0) config, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -1371,28 +1371,28 @@ def check_linear_dtype(module, storage_dtype, compute_dtype): if getattr(submodule, "bias", None) is not None: self.assertEqual(submodule.bias.dtype, dtype_to_check) - def test_layerwise_upcasting(storage_dtype, compute_dtype): + def test_layerwise_casting(storage_dtype, compute_dtype): torch.manual_seed(0) config, inputs_dict = self.prepare_init_args_and_inputs_for_common() inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype) model = self.model_class(**config).eval() model = model.to(torch_device, dtype=compute_dtype) - model.enable_layerwise_upcasting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) + model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) check_linear_dtype(model, storage_dtype, compute_dtype) output = model(**inputs_dict)[0].float().flatten().detach().cpu().numpy() # The precision test is not very important for fast tests. In most cases, the outputs will not be the same. - # We just want to make sure that the layerwise upcasting is working as expected. + # We just want to make sure that the layerwise casting is working as expected. self.assertTrue(numpy_cosine_similarity_distance(base_slice, output) < 1.0) - test_layerwise_upcasting(torch.float16, torch.float32) - test_layerwise_upcasting(torch.float8_e4m3fn, torch.float32) - test_layerwise_upcasting(torch.float8_e5m2, torch.float32) - test_layerwise_upcasting(torch.float8_e4m3fn, torch.bfloat16) + test_layerwise_casting(torch.float16, torch.float32) + test_layerwise_casting(torch.float8_e4m3fn, torch.float32) + test_layerwise_casting(torch.float8_e5m2, torch.float32) + test_layerwise_casting(torch.float8_e4m3fn, torch.bfloat16) @require_torch_gpu - def test_layerwise_upcasting_memory(self): + def test_layerwise_casting_memory(self): def reset_memory_stats(): gc.collect() torch.cuda.synchronize() @@ -1405,7 +1405,7 @@ def get_memory_usage(storage_dtype, compute_dtype): inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype) model = self.model_class(**config).eval() model = model.to(torch_device, dtype=compute_dtype) - model.enable_layerwise_upcasting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) + model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) reset_memory_stats() model(**inputs_dict) diff --git a/tests/models/unets/test_models_unet_2d.py b/tests/models/unets/test_models_unet_2d.py index 69e882202d12..0e5fdc4bba2e 100644 --- a/tests/models/unets/test_models_unet_2d.py +++ b/tests/models/unets/test_models_unet_2d.py @@ -403,13 +403,13 @@ def test_effective_gradient_checkpointing(self): super().test_effective_gradient_checkpointing(skip={"time_proj.weight"}) @unittest.skip( - "To make layerwise upcasting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here." + "To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here." ) - def test_layerwise_upcasting_inference(self): + def test_layerwise_casting_inference(self): pass @unittest.skip( - "To make layerwise upcasting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here." + "To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here." ) - def test_layerwise_upcasting_memory(self): + def test_layerwise_casting_memory(self): pass From c4d5a2b2912d67f35154cc1328ee0881121116d3 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 21 Jan 2025 15:42:42 +0100 Subject: [PATCH 40/45] skip tests for UNetRLModelTests; needs next pytorch release for currently unimplemented operation support --- tests/models/unets/test_models_unet_1d.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/models/unets/test_models_unet_1d.py b/tests/models/unets/test_models_unet_1d.py index 407e85f0db4f..dea1a57e1a76 100644 --- a/tests/models/unets/test_models_unet_1d.py +++ b/tests/models/unets/test_models_unet_1d.py @@ -292,3 +292,21 @@ def test_output_pretrained(self): def test_forward_with_norm_groups(self): # Not implemented yet for this UNet pass + + @unittest.skip( + "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations " + "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n" + "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n" + "2. Unskip this test." + ) + def test_layerwise_casting_inference(self): + pass + + @unittest.skip( + "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations " + "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n" + "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n" + "2. Unskip this test." + ) + def test_layerwise_casting_memory(self): + pass From d175d93d6832da67ed442f943876e7ded72b7280 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 21 Jan 2025 16:29:59 +0100 Subject: [PATCH 41/45] add layerwise fp8 pipeline test --- tests/pipelines/allegro/test_allegro.py | 1 + tests/pipelines/amused/test_amused.py | 1 + tests/pipelines/animatediff/test_animatediff.py | 1 + .../aura_flow/test_pipeline_aura_flow.py | 1 + tests/pipelines/cogvideo/test_cogvideox.py | 1 + .../cogvideo/test_cogvideox_fun_control.py | 1 + tests/pipelines/cogview3/test_cogview3plus.py | 1 + tests/pipelines/consisid/test_consisid.py | 1 + tests/pipelines/controlnet/test_controlnet.py | 1 + .../controlnet/test_controlnet_sdxl.py | 1 + .../controlnet_flux/test_controlnet_flux.py | 1 + .../test_controlnet_hunyuandit.py | 1 + .../controlnet_sd3/test_controlnet_sd3.py | 1 + .../controlnet_xs/test_controlnetxs.py | 1 + .../controlnet_xs/test_controlnetxs_sdxl.py | 1 + tests/pipelines/flux/test_pipeline_flux.py | 1 + .../flux/test_pipeline_flux_control.py | 1 + tests/pipelines/flux/test_pipeline_flux_fill.py | 1 + tests/pipelines/hunyuan_dit/test_hunyuan_dit.py | 1 + .../hunyuan_video/test_hunyuan_video.py | 1 + tests/pipelines/i2vgen_xl/test_i2vgenxl.py | 1 + tests/pipelines/kolors/test_kolors.py | 1 + tests/pipelines/latte/test_latte.py | 1 + tests/pipelines/ltx/test_ltx.py | 1 + tests/pipelines/lumina/test_lumina_nextdit.py | 1 + tests/pipelines/mochi/test_mochi.py | 1 + tests/pipelines/pia/test_pia.py | 1 + tests/pipelines/pixart_alpha/test_pixart.py | 1 + tests/pipelines/pixart_sigma/test_pixart.py | 1 + tests/pipelines/sana/test_sana.py | 1 + .../stable_diffusion/test_stable_diffusion.py | 1 + .../stable_diffusion_2/test_stable_diffusion.py | 1 + .../test_pipeline_stable_diffusion_3.py | 1 + .../test_stable_diffusion_xl.py | 1 + tests/pipelines/test_pipelines_common.py | 17 ++++++++++++++++- 35 files changed, 50 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/allegro/test_allegro.py b/tests/pipelines/allegro/test_allegro.py index 6a5a81bf160f..322be373641a 100644 --- a/tests/pipelines/allegro/test_allegro.py +++ b/tests/pipelines/allegro/test_allegro.py @@ -57,6 +57,7 @@ class AllegroPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] ) test_xformers_attention = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/amused/test_amused.py b/tests/pipelines/amused/test_amused.py index f28d8708d309..2dfc36a6ce45 100644 --- a/tests/pipelines/amused/test_amused.py +++ b/tests/pipelines/amused/test_amused.py @@ -38,6 +38,7 @@ class AmusedPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = AmusedPipeline params = TEXT_TO_IMAGE_PARAMS | {"encoder_hidden_states", "negative_encoder_hidden_states"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index c7411a7145c5..1b3115c8eb1d 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -60,6 +60,7 @@ class AnimateDiffPipelineFastTests( "callback_on_step_end_tensor_inputs", ] ) + test_layerwise_casting = True def get_dummy_components(self): cross_attention_dim = 8 diff --git a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py index 14bc588df905..bee905f9ae13 100644 --- a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py +++ b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py @@ -30,6 +30,7 @@ class AuraFlowPipelineFastTests(unittest.TestCase, PipelineTesterMixin): ] ) batch_params = frozenset(["prompt", "negative_prompt"]) + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py index 78fe9d4ef3be..9ce3d8e9de31 100644 --- a/tests/pipelines/cogvideo/test_cogvideox.py +++ b/tests/pipelines/cogvideo/test_cogvideox.py @@ -58,6 +58,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] ) test_xformers_attention = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py index 2a51fc65798c..c936bad4c3d5 100644 --- a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py +++ b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py @@ -55,6 +55,7 @@ class CogVideoXFunControlPipelineFastTests(PipelineTesterMixin, unittest.TestCas ] ) test_xformers_attention = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/cogview3/test_cogview3plus.py b/tests/pipelines/cogview3/test_cogview3plus.py index dcb746e0a55d..102a5c66e624 100644 --- a/tests/pipelines/cogview3/test_cogview3plus.py +++ b/tests/pipelines/cogview3/test_cogview3plus.py @@ -56,6 +56,7 @@ class CogView3PlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] ) test_xformers_attention = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/consisid/test_consisid.py b/tests/pipelines/consisid/test_consisid.py index 31f2bc024af6..f949cfb2d36d 100644 --- a/tests/pipelines/consisid/test_consisid.py +++ b/tests/pipelines/consisid/test_consisid.py @@ -58,6 +58,7 @@ class ConsisIDPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] ) test_xformers_attention = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py index 43814b2b2211..e0fc00171031 100644 --- a/tests/pipelines/controlnet/test_controlnet.py +++ b/tests/pipelines/controlnet/test_controlnet.py @@ -126,6 +126,7 @@ class ControlNetPipelineFastTests( batch_params = TEXT_TO_IMAGE_BATCH_PARAMS image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + test_layerwise_casting = True def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index 27f676b15b1c..e75fe8903134 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -75,6 +75,7 @@ class StableDiffusionXLControlNetPipelineFastTests( batch_params = TEXT_TO_IMAGE_BATCH_PARAMS image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + test_layerwise_casting = True def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux.py b/tests/pipelines/controlnet_flux/test_controlnet_flux.py index 5e856b125f32..8b9852dbec6e 100644 --- a/tests/pipelines/controlnet_flux/test_controlnet_flux.py +++ b/tests/pipelines/controlnet_flux/test_controlnet_flux.py @@ -50,6 +50,7 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin): params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) batch_params = frozenset(["prompt"]) + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py b/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py index 30dfe94e50f1..5c6054ccb605 100644 --- a/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py +++ b/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py @@ -57,6 +57,7 @@ class HunyuanDiTControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMix ] ) batch_params = frozenset(["prompt", "negative_prompt"]) + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py index 7527d17af32a..e1894d555c3c 100644 --- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py +++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py @@ -59,6 +59,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes ] ) batch_params = frozenset(["prompt", "negative_prompt"]) + test_layerwise_casting = True def get_dummy_components( self, num_controlnet_layers: int = 3, qk_norm: Optional[str] = "rms_norm", use_dual_attention=False diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py index 6d53d0618959..4c184db99630 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -139,6 +139,7 @@ class ControlNetXSPipelineFastTests( image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS test_attention_slicing = False + test_layerwise_casting = True def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py index d7ecf92f41cd..7537efe0bbf9 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py @@ -78,6 +78,7 @@ class StableDiffusionXLControlNetXSPipelineFastTests( image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS test_attention_slicing = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index addc29e14670..a3bc1658de74 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -31,6 +31,7 @@ class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapte # there is no xformers processor for Flux test_xformers_attention = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/flux/test_pipeline_flux_control.py b/tests/pipelines/flux/test_pipeline_flux_control.py index 2bd511db3d65..7fdb19327213 100644 --- a/tests/pipelines/flux/test_pipeline_flux_control.py +++ b/tests/pipelines/flux/test_pipeline_flux_control.py @@ -22,6 +22,7 @@ class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin): # there is no xformers processor for Flux test_xformers_attention = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/flux/test_pipeline_flux_fill.py b/tests/pipelines/flux/test_pipeline_flux_fill.py index 6c6ec138c781..620ecb8a831f 100644 --- a/tests/pipelines/flux/test_pipeline_flux_fill.py +++ b/tests/pipelines/flux/test_pipeline_flux_fill.py @@ -23,6 +23,7 @@ class FluxFillPipelineFastTests(unittest.TestCase, PipelineTesterMixin): params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) batch_params = frozenset(["prompt"]) test_xformers_attention = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py b/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py index b295b280a560..6c9117a55c36 100644 --- a/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py +++ b/tests/pipelines/hunyuan_dit/test_hunyuan_dit.py @@ -55,6 +55,7 @@ class HunyuanDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS required_optional_params = PipelineTesterMixin.required_optional_params + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py index 567002268106..ce03381f90d2 100644 --- a/tests/pipelines/hunyuan_video/test_hunyuan_video.py +++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py @@ -53,6 +53,7 @@ class HunyuanVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): # there is no xformers processor for Flux test_xformers_attention = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/i2vgen_xl/test_i2vgenxl.py b/tests/pipelines/i2vgen_xl/test_i2vgenxl.py index 22ece0e6d75f..f6ac22a9b575 100644 --- a/tests/pipelines/i2vgen_xl/test_i2vgenxl.py +++ b/tests/pipelines/i2vgen_xl/test_i2vgenxl.py @@ -61,6 +61,7 @@ class I2VGenXLPipelineFastTests(SDFunctionTesterMixin, PipelineTesterMixin, unit required_optional_params = frozenset(["num_inference_steps", "generator", "latents", "return_dict"]) supports_dduf = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/kolors/test_kolors.py b/tests/pipelines/kolors/test_kolors.py index e88ba0282096..cf0b392ddc06 100644 --- a/tests/pipelines/kolors/test_kolors.py +++ b/tests/pipelines/kolors/test_kolors.py @@ -48,6 +48,7 @@ class KolorsPipelineFastTests(PipelineTesterMixin, unittest.TestCase): callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"}) supports_dduf = False + test_layerwise_casting = True def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) diff --git a/tests/pipelines/latte/test_latte.py b/tests/pipelines/latte/test_latte.py index 9667ebff249d..2d5bcba8237a 100644 --- a/tests/pipelines/latte/test_latte.py +++ b/tests/pipelines/latte/test_latte.py @@ -52,6 +52,7 @@ class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS required_optional_params = PipelineTesterMixin.required_optional_params + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/ltx/test_ltx.py b/tests/pipelines/ltx/test_ltx.py index dd166c6242fc..64b366ea8ad6 100644 --- a/tests/pipelines/ltx/test_ltx.py +++ b/tests/pipelines/ltx/test_ltx.py @@ -46,6 +46,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] ) test_xformers_attention = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/lumina/test_lumina_nextdit.py b/tests/pipelines/lumina/test_lumina_nextdit.py index e0fd06847b77..7c1923313b23 100644 --- a/tests/pipelines/lumina/test_lumina_nextdit.py +++ b/tests/pipelines/lumina/test_lumina_nextdit.py @@ -32,6 +32,7 @@ class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterM batch_params = frozenset(["prompt", "negative_prompt"]) supports_dduf = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/mochi/test_mochi.py b/tests/pipelines/mochi/test_mochi.py index c9df5785897c..b7bb844ff311 100644 --- a/tests/pipelines/mochi/test_mochi.py +++ b/tests/pipelines/mochi/test_mochi.py @@ -55,6 +55,7 @@ class MochiPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] ) test_xformers_attention = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/pia/test_pia.py b/tests/pipelines/pia/test_pia.py index e461860eff65..747be38d495c 100644 --- a/tests/pipelines/pia/test_pia.py +++ b/tests/pipelines/pia/test_pia.py @@ -55,6 +55,7 @@ class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, PipelineFr "callback_on_step_end_tensor_inputs", ] ) + test_layerwise_casting = True def get_dummy_components(self): cross_attention_dim = 8 diff --git a/tests/pipelines/pixart_alpha/test_pixart.py b/tests/pipelines/pixart_alpha/test_pixart.py index e7039c61a448..7df6656f6f87 100644 --- a/tests/pipelines/pixart_alpha/test_pixart.py +++ b/tests/pipelines/pixart_alpha/test_pixart.py @@ -50,6 +50,7 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS required_optional_params = PipelineTesterMixin.required_optional_params + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/pixart_sigma/test_pixart.py b/tests/pipelines/pixart_sigma/test_pixart.py index a92e99366ee3..6e265b9d5eb8 100644 --- a/tests/pipelines/pixart_sigma/test_pixart.py +++ b/tests/pipelines/pixart_sigma/test_pixart.py @@ -55,6 +55,7 @@ class PixArtSigmaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS required_optional_params = PipelineTesterMixin.required_optional_params + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/sana/test_sana.py b/tests/pipelines/sana/test_sana.py index 7109a700403c..f70f9d91f19c 100644 --- a/tests/pipelines/sana/test_sana.py +++ b/tests/pipelines/sana/test_sana.py @@ -52,6 +52,7 @@ class SanaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] ) test_xformers_attention = False + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index ccd5567106d2..1e700bed03f8 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -123,6 +123,7 @@ class StableDiffusionPipelineFastTests( image_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS + test_layerwise_casting = True def get_dummy_components(self, time_cond_proj_dim=None): cross_attention_dim = 8 diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py index e7114d19e208..10b8a1818a29 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py @@ -75,6 +75,7 @@ class StableDiffusion2PipelineFastTests( image_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py index a6f718ae4fbb..df37090eeba2 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py @@ -35,6 +35,7 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin): ] ) batch_params = frozenset(["prompt", "negative_prompt"]) + test_layerwise_casting = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index 8550f258045e..f1422022a7aa 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -75,6 +75,7 @@ class StableDiffusionXLPipelineFastTests( image_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"}) + test_layerwise_casting = True def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 83b628e09f88..139778994b87 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -987,7 +987,7 @@ class PipelineTesterMixin: test_attention_slicing = True test_xformers_attention = True - + test_layerwise_casting = False supports_dduf = True def get_generator(self, seed): @@ -2027,6 +2027,21 @@ def test_save_load_dduf(self, atol=1e-4, rtol=1e-4): elif isinstance(pipeline_out, torch.Tensor) and isinstance(loaded_pipeline_out, torch.Tensor): assert torch.allclose(pipeline_out, loaded_pipeline_out, atol=atol, rtol=rtol) + def test_layerwise_casting_inference(self): + if not self.test_layerwise_casting: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device, dtype=torch.bfloat16) + pipe.set_progress_bar_config(disable=None) + + denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet + denoiser.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) + + inputs = self.get_dummy_inputs(torch_device) + _ = pipe(**inputs)[0] + @is_staging_test class PipelinePushToHubTester(unittest.TestCase): From bf116912b37f509ccf590ae182982b786520b63b Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 21 Jan 2025 16:36:25 +0100 Subject: [PATCH 42/45] use xfail --- tests/models/test_modeling_common.py | 1 - tests/models/unets/test_models_unet_1d.py | 53 +++++++++++++---------- 2 files changed, 31 insertions(+), 23 deletions(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 0a6cdffb2a5b..9513097b86ef 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1366,7 +1366,6 @@ def check_linear_dtype(module, storage_dtype, compute_dtype): if any(re.search(pattern, name) for pattern in patterns_to_check): dtype_to_check = compute_dtype if getattr(submodule, "weight", None) is not None: - print(name, submodule.weight.dtype, dtype_to_check, patterns_to_check) self.assertEqual(submodule.weight.dtype, dtype_to_check) if getattr(submodule, "bias", None) is not None: self.assertEqual(submodule.bias.dtype, dtype_to_check) diff --git a/tests/models/unets/test_models_unet_1d.py b/tests/models/unets/test_models_unet_1d.py index dea1a57e1a76..0f81807b895c 100644 --- a/tests/models/unets/test_models_unet_1d.py +++ b/tests/models/unets/test_models_unet_1d.py @@ -15,6 +15,7 @@ import unittest +import pytest import torch from diffusers import UNet1DModel @@ -152,20 +153,24 @@ def test_unet_1d_maestro(self): assert (output_sum - 224.0896).abs() < 0.5 assert (output_max - 0.0607).abs() < 4e-4 - @unittest.skip( - "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations " - "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n" - "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n" - "2. Unskip this test." + @pytest.mark.xfail( + reason=( + "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations " + "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n" + "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n" + "2. Unskip this test." + ), ) def test_layerwise_casting_inference(self): - pass - - @unittest.skip( - "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations " - "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n" - "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n" - "2. Unskip this test." + super().test_layerwise_casting_inference() + + @pytest.mark.xfail( + reason=( + "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations " + "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n" + "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n" + "2. Unskip this test." + ), ) def test_layerwise_casting_memory(self): pass @@ -293,20 +298,24 @@ def test_forward_with_norm_groups(self): # Not implemented yet for this UNet pass - @unittest.skip( - "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations " - "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n" - "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n" - "2. Unskip this test." + @pytest.mark.xfail( + reason=( + "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations " + "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n" + "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n" + "2. Unskip this test." + ), ) def test_layerwise_casting_inference(self): pass - @unittest.skip( - "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations " - "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n" - "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n" - "2. Unskip this test." + @pytest.mark.xfail( + reason=( + "RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations " + "not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n" + "1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n" + "2. Unskip this test." + ), ) def test_layerwise_casting_memory(self): pass From 1c523b294a1f4b8ab78738ec305b27bc63b4e15e Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 22 Jan 2025 15:22:36 +0530 Subject: [PATCH 43/45] Apply suggestions from code review Co-authored-by: Dhruv Nair --- src/diffusers/hooks/layerwise_casting.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/hooks/layerwise_casting.py b/src/diffusers/hooks/layerwise_casting.py index 95c361e10100..038625e21f0d 100644 --- a/src/diffusers/hooks/layerwise_casting.py +++ b/src/diffusers/hooks/layerwise_casting.py @@ -74,7 +74,7 @@ def apply_layerwise_casting( module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, - skip_modules_pattern: Union[str, Tuple[str, ...]] = "default", + skip_modules_pattern: Union[str, Tuple[str, ...]] = "auto", skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None, non_blocking: bool = False, ) -> None: @@ -109,9 +109,9 @@ def apply_layerwise_casting( The dtype to cast the module to before/after the forward pass for storage. compute_dtype (`torch.dtype`): The dtype to cast the module to during the forward pass for computation. - skip_modules_pattern (`Tuple[str, ...]`, defaults to `"default"`): + skip_modules_pattern (`Tuple[str, ...]`, defaults to `"auto"`): A list of patterns to match the names of the modules to skip during the layerwise casting process. If set - to `"default"`, the default patterns are used. If set to `None`, no modules are skipped. If set to `None` + to `"auto"`, the default patterns are used. If set to `None`, no modules are skipped. If set to `None` alongside `skip_modules_classes` being `None`, the layerwise casting is applied directly to the module instead of its internal submodules. skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, defaults to `None`): @@ -119,7 +119,7 @@ def apply_layerwise_casting( non_blocking (`bool`, defaults to `False`): If `True`, the weight casting operations are non-blocking. """ - if skip_modules_pattern == "default": + if skip_modules_pattern == "auto": skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN if skip_modules_classes is None and skip_modules_pattern is None: From 376adf907a62651077629f635477f9a4d406d27d Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 22 Jan 2025 11:29:26 +0100 Subject: [PATCH 44/45] add assertion with fp32 comparison; add tolerance to fp8-fp32 vs fp32-fp32 comparison (required for a few models' test to pass) --- tests/models/test_modeling_common.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 9513097b86ef..40e1b9e48919 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1392,6 +1392,8 @@ def test_layerwise_casting(storage_dtype, compute_dtype): @require_torch_gpu def test_layerwise_casting_memory(self): + MB_TOLERANCE = 0.2 + def reset_memory_stats(): gc.collect() torch.cuda.synchronize() @@ -1409,17 +1411,25 @@ def get_memory_usage(storage_dtype, compute_dtype): reset_memory_stats() model(**inputs_dict) model_memory_footprint = model.get_memory_footprint() - peak_inference_memory_allocated = torch.cuda.max_memory_allocated() + peak_inference_memory_allocated_mb = torch.cuda.max_memory_allocated() / 1024**2 - return model_memory_footprint, peak_inference_memory_allocated + return model_memory_footprint, peak_inference_memory_allocated_mb + fp32_memory_footprint, fp32_max_memory = get_memory_usage(torch.float32, torch.float32) fp8_e4m3_fp32_memory_footprint, fp8_e4m3_fp32_max_memory = get_memory_usage(torch.float8_e4m3fn, torch.float32) fp8_e4m3_bf16_memory_footprint, fp8_e4m3_bf16_max_memory = get_memory_usage( torch.float8_e4m3fn, torch.bfloat16 ) - self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint) + self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint) self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory) + # On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few + # bytes. This only happens for some models, so we allow a small tolerance. + # For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32. + self.assertTrue( + fp8_e4m3_fp32_max_memory < fp32_max_memory + or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE + ) @is_staging_test From 719e8d39bc7e9b0dba2494ff77cd81a8a2d4f829 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 22 Jan 2025 15:02:43 +0100 Subject: [PATCH 45/45] add note about memory consumption on tesla CI runner for failing test --- tests/models/test_modeling_common.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 40e1b9e48919..05050e05bb19 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1422,6 +1422,8 @@ def get_memory_usage(storage_dtype, compute_dtype): ) self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint) + # NOTE: the following assertion will fail on our CI (running Tesla T4) due to bf16 using more memory than fp32. + # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory) # On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few # bytes. This only happens for some models, so we allow a small tolerance.