From d1737e3d427ae693e5d52b25265c3d03e4b9b5fb Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 9 Jan 2025 02:09:43 +0100 Subject: [PATCH 01/37] update --- src/diffusers/hooks/__init__.py | 5 + src/diffusers/hooks/group_offloading.py | 184 ++++++++++++++++++++++++ src/diffusers/hooks/hooks.py | 172 ++++++++++++++++++++++ 3 files changed, 361 insertions(+) create mode 100644 src/diffusers/hooks/__init__.py create mode 100644 src/diffusers/hooks/group_offloading.py create mode 100644 src/diffusers/hooks/hooks.py diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py new file mode 100644 index 000000000000..6cfabea0c48e --- /dev/null +++ b/src/diffusers/hooks/__init__.py @@ -0,0 +1,5 @@ +from ..utils import is_torch_available + + +if is_torch_available(): + from .group_offloading import apply_group_offloading diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py new file mode 100644 index 000000000000..20beebc9b5d7 --- /dev/null +++ b/src/diffusers/hooks/group_offloading.py @@ -0,0 +1,184 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import List, Optional, Union + +import torch + +from .hooks import HookRegistry, ModelHook + + +_TRANSFORMER_STACK_IDENTIFIERS = [ + "transformer_blocks", + "single_transformer_blocks", + "temporal_transformer_blocks", + "transformer_layers", + "layers", + "blocks", +] + + +class ModuleGroup: + def __init__( + self, + modules: List[torch.nn.Module], + offload_device: torch.device, + onload_device: torch.device, + offload_leader: torch.nn.Module, + onload_leader: Optional[torch.nn.Module] = None, + ) -> None: + self.modules = modules + self.offload_device = offload_device + self.onload_device = onload_device + self.offload_leader = offload_leader + self.onload_leader = onload_leader + + +class GroupOffloadingHook(ModelHook): + r""" + A hook that offloads groups of torch.nn.Module to the CPU for storage and onloads to accelerator device for + computation. Each group has one "onload leader" module that is responsible for onloading, and an "offload leader" + module that is responsible for offloading. + + This implementation assumes the following: + - For offload_group_patterns="diffusers_block", the leader of a group can be automatically determined. For a custom + user-provided regex pattern, the module that triggers its forward pass first is considered the leader. + - The inputs are already on the correct device. This is expected because the hook does not modify the state of + inputs or outputs at any stage of the forward pass. If an error is raised due to the device of modules and inputs + not matching during the forward pass for any model in Diffusers, this means that the forward pass of the model is + not written in the expected. Please open an issue at https://github.com/huggingface/diffusers/issues if you + encounter such an error. + """ + + def __init__(self, group: ModuleGroup, offload_on_init: bool = True) -> None: + self.group = group + self.offload_on_init = offload_on_init + + def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: + if self.offload_on_init: + self.offload_(module) + return module + + def onload_(self, module: torch.nn.Module) -> None: + if self.group.onload_leader is not None and self.group.onload_leader == module: + for group_module in self.group.modules: + group_module.to(self.group.onload_device) + + def offload_(self, module: torch.nn.Module) -> None: + if self.group.offload_leader == module: + for group_module in self.group.modules: + group_module.to(self.group.offload_device) + + def pre_forward(self, module: torch.nn.Module, *args, **kwargs): + if self.group.onload_leader is None: + self.group.onload_leader = module + self.onload_(module) + return args, kwargs + + def post_forward(self, module: torch.nn.Module, output): + self.offload_(module) + return output + + +def apply_group_offloading( + module: torch.nn.Module, + offload_group_patterns: Union[str, List[str]] = "diffusers_block", + num_blocks_per_group: Optional[int] = None, + offload_device: torch.device = torch.device("cpu"), + onload_device: torch.device = torch.device("cuda"), + force_offload: bool = True, +) -> None: + if offload_group_patterns == "diffusers_block": + _apply_group_offloading_diffusers_block( + module, num_blocks_per_group, offload_device, onload_device, force_offload + ) + else: + _apply_group_offloading_group_patterns( + module, offload_group_patterns, offload_device, onload_device, force_offload + ) + + +def _apply_group_offloading_diffusers_block( + module: torch.nn.Module, + num_blocks_per_group: int, + offload_device: torch.device, + onload_device: torch.device, + force_offload: bool, +) -> None: + if num_blocks_per_group is None: + raise ValueError("num_blocks_per_group must be provided when using GroupOffloadingType.DIFFUSERS_BLOCK.") + + for transformer_stack_identifier in _TRANSFORMER_STACK_IDENTIFIERS: + if not hasattr(module, transformer_stack_identifier) or not isinstance( + getattr(module, transformer_stack_identifier), torch.nn.ModuleList + ): + continue + + transformer_stack = getattr(module, transformer_stack_identifier) + num_blocks = len(transformer_stack) + + for i in range(0, num_blocks, num_blocks_per_group): + blocks = transformer_stack[i : i + num_blocks_per_group] + group = ModuleGroup( + blocks, offload_device, onload_device, offload_leader=blocks[-1], onload_leader=blocks[0] + ) + should_offload = force_offload or i == 0 + _apply_group_offloading(group, should_offload) + + +def _apply_group_offloading_group_patterns( + module: torch.nn.Module, + offload_group_patterns: List[str], + offload_device: torch.device, + onload_device: torch.device, + force_offload: bool, +) -> None: + per_group_modules = [] + for i, offload_group_pattern in enumerate(offload_group_patterns): + group_modules = [] + group_module_names = [] + for name, module in module.named_modules(): + if re.search(offload_group_pattern, name) is not None: + group_modules.append(module) + group_module_names.append(name) + per_group_modules.append( + { + "modules": group_modules, + "module_names": group_module_names, + } + ) + + # Check if there are any overlapping modules between groups + for i, group in enumerate(per_group_modules): + for j, other_group in enumerate(per_group_modules): + if j <= i: + continue + if any(module_name in group["module_names"] for module_name in other_group["module_names"]): + raise ValueError( + f"Overlapping modules between groups {i} and {j}. Please ensure that offloading group patterns are mutually exclusive." + ) + + # Apply offloading to each group + for group in per_group_modules: + # TODO: handle offload leader correctly + group = ModuleGroup(group["modules"], offload_device, onload_device, offload_leader=group["modules"][-1]) + _apply_group_offloading(group, force_offload) + + +def _apply_group_offloading(group: ModuleGroup, offload_on_init) -> None: + for module in group.modules: + hook = GroupOffloadingHook(group, offload_on_init=offload_on_init) + registry = HookRegistry.check_if_exists_or_initialize(module) + registry.register_hook(hook, "group_offloading") diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py new file mode 100644 index 000000000000..d220f0c83b27 --- /dev/null +++ b/src/diffusers/hooks/hooks.py @@ -0,0 +1,172 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from typing import Any, Dict, Tuple + +import torch + +from ..utils.logging import get_logger + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +class ModelHook: + r""" + A hook that contains callbacks to be executed just before and after the forward method of a model. + """ + + _is_stateful = False + + def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: + r""" + Hook that is executed when a model is initialized. + + Args: + module (`torch.nn.Module`): + The module attached to this hook. + """ + return module + + def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module: + r""" + Hook that is executed when a model is deinitalized. + + Args: + module (`torch.nn.Module`): + The module attached to this hook. + """ + module.forward = module._old_forward + del module._old_forward + return module + + def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: + r""" + Hook that is executed just before the forward method of the model. + + Args: + module (`torch.nn.Module`): + The module whose forward pass will be executed just after this event. + args (`Tuple[Any]`): + The positional arguments passed to the module. + kwargs (`Dict[Str, Any]`): + The keyword arguments passed to the module. + + Returns: + `Tuple[Tuple[Any], Dict[Str, Any]]`: + A tuple with the treated `args` and `kwargs`. + """ + return args, kwargs + + def post_forward(self, module: torch.nn.Module, output: Any) -> Any: + r""" + Hook that is executed just after the forward method of the model. + + Args: + module (`torch.nn.Module`): + The module whose forward pass been executed just before this event. + output (`Any`): + The output of the module. + + Returns: + `Any`: The processed `output`. + """ + return output + + def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: + r""" + Hook that is executed when the hook is detached from a module. + + Args: + module (`torch.nn.Module`): + The module detached from this hook. + """ + return module + + def reset_state(self, module: torch.nn.Module): + if self._is_stateful: + raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.") + return module + + +class HookRegistry: + def __init__(self, module_ref: torch.nn.Module) -> None: + super().__init__() + + self.hooks: Dict[str, ModelHook] = {} + + self._module_ref = module_ref + self._hook_order = [] + + def register_hook(self, hook: ModelHook, name: str) -> None: + if name in self.hooks.keys(): + logger.warning(f"Hook with name {name} already exists, replacing it.") + + if hasattr(self._module_ref, "_old_forward"): + old_forward = self._module_ref._old_forward + else: + old_forward = self._module_ref.forward + self._module_ref._old_forward = self._module_ref.forward + + self._module_ref = hook.initialize_hook(self._module_ref) + + if hasattr(hook, "new_forward"): + new_forward = hook.new_forward + else: + + def new_forward(module, *args, **kwargs): + args, kwargs = hook.pre_forward(module, *args, **kwargs) + output = old_forward(*args, **kwargs) + return hook.post_forward(module, output) + + # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail. + # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409 + if "GraphModuleImpl" in str(type(self._module_ref)): + self._module_ref.__class__.forward = functools.update_wrapper( + functools.partial(new_forward, self._module_ref), old_forward + ) + else: + self._module_ref.forward = functools.update_wrapper( + functools.partial(new_forward, self._module_ref), old_forward + ) + + self.hooks[name] = hook + self._hook_order.append(name) + + def get_hook(self, name: str) -> ModelHook: + if name not in self.hooks.keys(): + raise ValueError(f"Hook with name {name} not found.") + return self.hooks[name] + + def remove_hook(self, name: str) -> None: + if name not in self.hooks.keys(): + raise ValueError(f"Hook with name {name} not found.") + self.hooks[name].deinitalize_hook(self._module_ref) + del self.hooks[name] + self._hook_order.remove(name) + + @classmethod + def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry": + if not hasattr(module, "_diffusers_hook"): + module._diffusers_hook = cls(module) + return module._diffusers_hook + + def __repr__(self) -> str: + hook_repr = "" + for i, hook_name in enumerate(self._hook_order): + hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})" + if i < len(self._hook_order) - 1: + hook_repr += "\n" + return f"HookRegistry(\n{hook_repr}\n)" From 278366999fd003a7e4976429dd88a8d6fe25e797 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 9 Jan 2025 02:37:18 +0100 Subject: [PATCH 02/37] fix --- src/diffusers/hooks/group_offloading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 20beebc9b5d7..445263cf2e8a 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -134,7 +134,7 @@ def _apply_group_offloading_diffusers_block( group = ModuleGroup( blocks, offload_device, onload_device, offload_leader=blocks[-1], onload_leader=blocks[0] ) - should_offload = force_offload or i == 0 + should_offload = force_offload or i > 0 _apply_group_offloading(group, should_offload) From 6a9a3e598f81ef1c6284e47990725609bfd85e0a Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 10 Jan 2025 02:50:18 +0100 Subject: [PATCH 03/37] non_blocking; handle parameters and buffers --- src/diffusers/hooks/group_offloading.py | 88 +++++++++++++++++-------- src/diffusers/hooks/hooks.py | 12 +--- 2 files changed, 62 insertions(+), 38 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 445263cf2e8a..48bbdb78e33f 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -20,14 +20,17 @@ from .hooks import HookRegistry, ModelHook -_TRANSFORMER_STACK_IDENTIFIERS = [ +_COMMON_STACK_IDENTIFIERS = { "transformer_blocks", "single_transformer_blocks", "temporal_transformer_blocks", "transformer_layers", "layers", "blocks", -] + "down_blocks", + "up_blocks", + "mid_blocks", +} class ModuleGroup: @@ -62,25 +65,16 @@ class GroupOffloadingHook(ModelHook): encounter such an error. """ - def __init__(self, group: ModuleGroup, offload_on_init: bool = True) -> None: + def __init__(self, group: ModuleGroup, offload_on_init: bool = True, non_blocking: bool = False) -> None: self.group = group self.offload_on_init = offload_on_init + self.non_blocking = non_blocking def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: if self.offload_on_init: self.offload_(module) return module - def onload_(self, module: torch.nn.Module) -> None: - if self.group.onload_leader is not None and self.group.onload_leader == module: - for group_module in self.group.modules: - group_module.to(self.group.onload_device) - - def offload_(self, module: torch.nn.Module) -> None: - if self.group.offload_leader == module: - for group_module in self.group.modules: - group_module.to(self.group.offload_device) - def pre_forward(self, module: torch.nn.Module, *args, **kwargs): if self.group.onload_leader is None: self.group.onload_leader = module @@ -91,6 +85,19 @@ def post_forward(self, module: torch.nn.Module, output): self.offload_(module) return output + def onload_(self, module: torch.nn.Module) -> None: + if self.group.onload_leader == module: + for group_module in self.group.modules: + group_module.to(self.group.onload_device, non_blocking=self.non_blocking) + + def offload_(self, module: torch.nn.Module) -> None: + if self.group.offload_leader == module: + for group_module in self.group.modules: + group_module.to(self.group.offload_device, non_blocking=self.non_blocking) + # TODO: do we need to sync here because of GPU->CPU transfer? + if self.non_blocking and self.group.offload_device.type == "cpu": + torch.cpu.synchronize() + def apply_group_offloading( module: torch.nn.Module, @@ -99,14 +106,17 @@ def apply_group_offloading( offload_device: torch.device = torch.device("cpu"), onload_device: torch.device = torch.device("cuda"), force_offload: bool = True, + non_blocking: bool = False, ) -> None: if offload_group_patterns == "diffusers_block": + if num_blocks_per_group is None: + raise ValueError("num_blocks_per_group must be provided when using GroupOffloadingType.DIFFUSERS_BLOCK.") _apply_group_offloading_diffusers_block( - module, num_blocks_per_group, offload_device, onload_device, force_offload + module, num_blocks_per_group, offload_device, onload_device, force_offload, non_blocking ) else: _apply_group_offloading_group_patterns( - module, offload_group_patterns, offload_device, onload_device, force_offload + module, offload_group_patterns, offload_device, onload_device, force_offload, non_blocking ) @@ -116,26 +126,47 @@ def _apply_group_offloading_diffusers_block( offload_device: torch.device, onload_device: torch.device, force_offload: bool, + non_blocking: bool, ) -> None: - if num_blocks_per_group is None: - raise ValueError("num_blocks_per_group must be provided when using GroupOffloadingType.DIFFUSERS_BLOCK.") - - for transformer_stack_identifier in _TRANSFORMER_STACK_IDENTIFIERS: - if not hasattr(module, transformer_stack_identifier) or not isinstance( - getattr(module, transformer_stack_identifier), torch.nn.ModuleList + # Handle device offloading/onloading for unet/transformer stack modules + for stack_identifier in _COMMON_STACK_IDENTIFIERS: + if not hasattr(module, stack_identifier) or not isinstance( + getattr(module, stack_identifier), torch.nn.ModuleList ): continue - transformer_stack = getattr(module, transformer_stack_identifier) - num_blocks = len(transformer_stack) + stack = getattr(module, stack_identifier) + num_blocks = len(stack) for i in range(0, num_blocks, num_blocks_per_group): - blocks = transformer_stack[i : i + num_blocks_per_group] + blocks = stack[i : i + num_blocks_per_group] group = ModuleGroup( blocks, offload_device, onload_device, offload_leader=blocks[-1], onload_leader=blocks[0] ) should_offload = force_offload or i > 0 - _apply_group_offloading(group, should_offload) + _apply_group_offloading(group, should_offload, non_blocking) + + # Handle device offloading/onloading for non-stack modules + for name, submodule in module.named_modules(): + name_split = name.split(".") + if not isinstance(submodule, torch.nn.Module) or name == "" or len(name_split) > 1: + # We only want the layers that are top-level in the module (encompass all the submodules) + # for enabling offloading. + continue + layer_name = name_split[0] + print(layer_name) + if layer_name in _COMMON_STACK_IDENTIFIERS: + continue + group = ModuleGroup( + [submodule], offload_device, onload_device, offload_leader=submodule, onload_leader=submodule + ) + _apply_group_offloading(group, force_offload, non_blocking) + + # Always keep parameters and buffers on onload_device + for name, param in module.named_parameters(recurse=False): + param.data = param.data.to(onload_device) + for name, buffer in module.named_buffers(recurse=False): + buffer.data = buffer.data.to(onload_device) def _apply_group_offloading_group_patterns( @@ -144,6 +175,7 @@ def _apply_group_offloading_group_patterns( offload_device: torch.device, onload_device: torch.device, force_offload: bool, + non_blocking: bool, ) -> None: per_group_modules = [] for i, offload_group_pattern in enumerate(offload_group_patterns): @@ -174,11 +206,11 @@ def _apply_group_offloading_group_patterns( for group in per_group_modules: # TODO: handle offload leader correctly group = ModuleGroup(group["modules"], offload_device, onload_device, offload_leader=group["modules"][-1]) - _apply_group_offloading(group, force_offload) + _apply_group_offloading(group, force_offload, non_blocking) -def _apply_group_offloading(group: ModuleGroup, offload_on_init) -> None: +def _apply_group_offloading(group: ModuleGroup, offload_on_init: bool, non_blocking: bool) -> None: for module in group.modules: - hook = GroupOffloadingHook(group, offload_on_init=offload_on_init) + hook = GroupOffloadingHook(group, offload_on_init, non_blocking) registry = HookRegistry.check_if_exists_or_initialize(module) registry.register_hook(hook, "group_offloading") diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index d220f0c83b27..9d61e294742f 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -131,16 +131,8 @@ def new_forward(module, *args, **kwargs): output = old_forward(*args, **kwargs) return hook.post_forward(module, output) - # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail. - # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409 - if "GraphModuleImpl" in str(type(self._module_ref)): - self._module_ref.__class__.forward = functools.update_wrapper( - functools.partial(new_forward, self._module_ref), old_forward - ) - else: - self._module_ref.forward = functools.update_wrapper( - functools.partial(new_forward, self._module_ref), old_forward - ) + new_forward = functools.update_wrapper(new_forward, old_forward) + self._module_ref.forward = new_forward.__get__(self._module_ref) self.hooks[name] = hook self._hook_order.append(name) From c426a343b78d0d2b6aeef54e23457e6046c36790 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 10 Jan 2025 02:51:05 +0100 Subject: [PATCH 04/37] update --- src/diffusers/hooks/group_offloading.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 48bbdb78e33f..8eda18053eb9 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -164,9 +164,11 @@ def _apply_group_offloading_diffusers_block( # Always keep parameters and buffers on onload_device for name, param in module.named_parameters(recurse=False): - param.data = param.data.to(onload_device) + if torch.is_tensor(param.data): + param.data = param.data.to(onload_device) for name, buffer in module.named_buffers(recurse=False): - buffer.data = buffer.data.to(onload_device) + if torch.is_tensor(buffer.data): + buffer.data = buffer.data.to(onload_device) def _apply_group_offloading_group_patterns( From d579037f1c4ea094523052800cfc7105b55ffa57 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 11 Jan 2025 12:45:58 +0530 Subject: [PATCH 05/37] Group offloading with cuda stream prefetching (#10516) * cuda stream prefetch * remove breakpoints --- src/diffusers/hooks/group_offloading.py | 94 +++++++++++++++++++++---- 1 file changed, 80 insertions(+), 14 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 8eda18053eb9..e2f0c73f1d0c 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -13,7 +13,7 @@ # limitations under the License. import re -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union import torch @@ -65,10 +65,21 @@ class GroupOffloadingHook(ModelHook): encounter such an error. """ - def __init__(self, group: ModuleGroup, offload_on_init: bool = True, non_blocking: bool = False) -> None: + def __init__( + self, + group: ModuleGroup, + offload_on_init: bool = True, + non_blocking: bool = False, + stream: Optional[torch.cuda.Stream] = None, + next_group: Optional[ModuleGroup] = None, + cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None, + ) -> None: self.group = group self.offload_on_init = offload_on_init self.non_blocking = non_blocking + self.stream = stream + self.next_group = next_group + self.cpu_param_dict = cpu_param_dict def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: if self.offload_on_init: @@ -87,16 +98,33 @@ def post_forward(self, module: torch.nn.Module, output): def onload_(self, module: torch.nn.Module) -> None: if self.group.onload_leader == module: - for group_module in self.group.modules: - group_module.to(self.group.onload_device, non_blocking=self.non_blocking) + if self.stream is not None: + # Wait for previous Host->Device transfer to complete + self.stream.synchronize() + + if self.next_group is None: + return + + # Start Host->Device transfer for next group + with torch.cuda.stream(self.stream): + for group_module in self.next_group.modules: + group_module.to(self.next_group.onload_device, non_blocking=True) + else: + for group_module in self.group.modules: + group_module.to(self.group.onload_device, non_blocking=self.non_blocking) def offload_(self, module: torch.nn.Module) -> None: if self.group.offload_leader == module: - for group_module in self.group.modules: - group_module.to(self.group.offload_device, non_blocking=self.non_blocking) - # TODO: do we need to sync here because of GPU->CPU transfer? - if self.non_blocking and self.group.offload_device.type == "cpu": - torch.cpu.synchronize() + if self.stream is not None: + for group_module in self.group.modules: + for param in group_module.parameters(): + param.data = self.cpu_param_dict[param] + else: + for group_module in self.group.modules: + group_module.to(self.group.offload_device, non_blocking=self.non_blocking) + # TODO: do we need to sync here because of GPU->CPU transfer? + if self.non_blocking and self.group.offload_device.type == "cpu": + torch.cpu.synchronize() def apply_group_offloading( @@ -107,12 +135,22 @@ def apply_group_offloading( onload_device: torch.device = torch.device("cuda"), force_offload: bool = True, non_blocking: bool = False, + cuda_stream: bool = False, ) -> None: + stream = None + if cuda_stream: + stream = torch.cuda.Stream() if offload_group_patterns == "diffusers_block": if num_blocks_per_group is None: raise ValueError("num_blocks_per_group must be provided when using GroupOffloadingType.DIFFUSERS_BLOCK.") _apply_group_offloading_diffusers_block( - module, num_blocks_per_group, offload_device, onload_device, force_offload, non_blocking + module, + num_blocks_per_group, + offload_device, + onload_device, + force_offload, + non_blocking, + stream, ) else: _apply_group_offloading_group_patterns( @@ -127,7 +165,14 @@ def _apply_group_offloading_diffusers_block( onload_device: torch.device, force_offload: bool, non_blocking: bool, + stream: Optional[torch.cuda.Stream] = None, ) -> None: + cpu_param_dict = None + if stream is not None: + for param in module.parameters(): + param.data = param.data.cpu().pin_memory() + cpu_param_dict = {param: param.data for param in module.parameters()} + # Handle device offloading/onloading for unet/transformer stack modules for stack_identifier in _COMMON_STACK_IDENTIFIERS: if not hasattr(module, stack_identifier) or not isinstance( @@ -137,14 +182,29 @@ def _apply_group_offloading_diffusers_block( stack = getattr(module, stack_identifier) num_blocks = len(stack) + module_groups = [] for i in range(0, num_blocks, num_blocks_per_group): blocks = stack[i : i + num_blocks_per_group] group = ModuleGroup( blocks, offload_device, onload_device, offload_leader=blocks[-1], onload_leader=blocks[0] ) + module_groups.append(group) + + for i, group in enumerate(module_groups): + next_group = module_groups[i + 1] if i + 1 < len(module_groups) and stream is not None else None should_offload = force_offload or i > 0 - _apply_group_offloading(group, should_offload, non_blocking) + _apply_group_offloading(group, should_offload, non_blocking, stream, next_group, cpu_param_dict) + + if stream is not None: + # Start Host->Device transfer for the first group + with torch.cuda.stream(stream): + for group_module in module_groups[0].modules: + group_module.to(onload_device, non_blocking=True) + if len(module_groups) > 1: + # Assign the first module_group as the next_group for the last module_group + hook_registry = HookRegistry.check_if_exists_or_initialize(module_groups[-1].onload_leader) + hook_registry.hooks["group_offloading"].next_group = module_groups[0] # Handle device offloading/onloading for non-stack modules for name, submodule in module.named_modules(): @@ -154,7 +214,6 @@ def _apply_group_offloading_diffusers_block( # for enabling offloading. continue layer_name = name_split[0] - print(layer_name) if layer_name in _COMMON_STACK_IDENTIFIERS: continue group = ModuleGroup( @@ -211,8 +270,15 @@ def _apply_group_offloading_group_patterns( _apply_group_offloading(group, force_offload, non_blocking) -def _apply_group_offloading(group: ModuleGroup, offload_on_init: bool, non_blocking: bool) -> None: +def _apply_group_offloading( + group: ModuleGroup, + offload_on_init: bool, + non_blocking: bool, + stream: Optional[torch.cuda.Stream] = None, + next_group: Optional[ModuleGroup] = None, + cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None, +) -> None: for module in group.modules: - hook = GroupOffloadingHook(group, offload_on_init, non_blocking) + hook = GroupOffloadingHook(group, offload_on_init, non_blocking, stream, next_group, cpu_param_dict) registry = HookRegistry.check_if_exists_or_initialize(module) registry.register_hook(hook, "group_offloading") From a8eabd07988a1a3c85314f1ae672025bff7380b8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 12 Jan 2025 03:54:20 +0100 Subject: [PATCH 06/37] update --- src/diffusers/hooks/group_offloading.py | 352 ++++++++++++++++-------- 1 file changed, 231 insertions(+), 121 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index e2f0c73f1d0c..c86025a901b6 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -13,24 +13,15 @@ # limitations under the License. import re -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union import torch +from ..utils import get_logger from .hooks import HookRegistry, ModelHook -_COMMON_STACK_IDENTIFIERS = { - "transformer_blocks", - "single_transformer_blocks", - "temporal_transformer_blocks", - "transformer_layers", - "layers", - "blocks", - "down_blocks", - "up_blocks", - "mid_blocks", -} +logger = get_logger(__name__) # pylint: disable=invalid-name class ModuleGroup: @@ -129,7 +120,7 @@ def offload_(self, module: torch.nn.Module) -> None: def apply_group_offloading( module: torch.nn.Module, - offload_group_patterns: Union[str, List[str]] = "diffusers_block", + offload_group_patterns: Union[str, List[str]] = "modulelist_or_sequential", num_blocks_per_group: Optional[int] = None, offload_device: torch.device = torch.device("cpu"), onload_device: torch.device = torch.device("cuda"), @@ -137,90 +128,222 @@ def apply_group_offloading( non_blocking: bool = False, cuda_stream: bool = False, ) -> None: - stream = None - if cuda_stream: - stream = torch.cuda.Stream() - if offload_group_patterns == "diffusers_block": + # stream = None + # if cuda_stream: + # stream = torch.cuda.Stream() + if offload_group_patterns == "modulelist_or_sequential": if num_blocks_per_group is None: - raise ValueError("num_blocks_per_group must be provided when using GroupOffloadingType.DIFFUSERS_BLOCK.") - _apply_group_offloading_diffusers_block( - module, - num_blocks_per_group, - offload_device, - onload_device, - force_offload, - non_blocking, - stream, - ) - else: - _apply_group_offloading_group_patterns( - module, offload_group_patterns, offload_device, onload_device, force_offload, non_blocking - ) + raise ValueError( + "num_blocks_per_group must be provided when using offload_group_patterns='modulelist_or_sequential'." + ) + # _apply_group_offloading_diffusers_block( + # module, + # num_blocks_per_group, + # offload_device, + # onload_device, + # force_offload, + # non_blocking, + # stream, + # ) + offload_group_patterns = _get_modulelist_or_sequential_group_patterns(module, num_blocks_per_group) + + _apply_group_offloading_group_patterns( + module, offload_group_patterns, offload_device, onload_device, force_offload, non_blocking + ) + + +# def _apply_group_offloading_diffusers_block( +# module: torch.nn.Module, +# num_blocks_per_group: int, +# offload_device: torch.device, +# onload_device: torch.device, +# force_offload: bool, +# non_blocking: bool, +# stream: Optional[torch.cuda.Stream] = None, +# ) -> None: +# cpu_param_dict = None +# if stream is not None: +# for param in module.parameters(): +# param.data = param.data.cpu().pin_memory() +# cpu_param_dict = {param: param.data for param in module.parameters()} + +# # Handle device offloading/onloading for unet/transformer stack modules +# for stack_identifier in _COMMON_STACK_IDENTIFIERS: +# if not hasattr(module, stack_identifier) or not isinstance( +# getattr(module, stack_identifier), torch.nn.ModuleList +# ): +# continue + +# stack = getattr(module, stack_identifier) +# num_blocks = len(stack) +# module_groups = [] + +# for i in range(0, num_blocks, num_blocks_per_group): +# blocks = stack[i : i + num_blocks_per_group] +# group = ModuleGroup( +# blocks, offload_device, onload_device, offload_leader=blocks[-1], onload_leader=blocks[0] +# ) +# module_groups.append(group) + +# for i, group in enumerate(module_groups): +# next_group = module_groups[i + 1] if i + 1 < len(module_groups) and stream is not None else None +# should_offload = force_offload or i > 0 +# _apply_group_offloading(group, should_offload, non_blocking, stream, next_group, cpu_param_dict) + +# if stream is not None: +# # Start Host->Device transfer for the first group +# with torch.cuda.stream(stream): +# for group_module in module_groups[0].modules: +# group_module.to(onload_device, non_blocking=True) +# if len(module_groups) > 1: +# # Assign the first module_group as the next_group for the last module_group +# hook_registry = HookRegistry.check_if_exists_or_initialize(module_groups[-1].onload_leader) +# hook_registry.hooks["group_offloading"].next_group = module_groups[0] + +# # Handle device offloading/onloading for non-stack modules +# for name, submodule in module.named_modules(): +# name_split = name.split(".") +# if not isinstance(submodule, torch.nn.Module) or name == "" or len(name_split) > 1: +# # We only want the layers that are top-level in the module (encompass all the submodules) +# # for enabling offloading. +# continue +# layer_name = name_split[0] +# if layer_name in _COMMON_STACK_IDENTIFIERS: +# continue +# group = ModuleGroup( +# [submodule], offload_device, onload_device, offload_leader=submodule, onload_leader=submodule +# ) +# _apply_group_offloading(group, force_offload, non_blocking) + +# # Always keep parameters and buffers on onload_device +# for name, param in module.named_parameters(recurse=False): +# if torch.is_tensor(param.data): +# param.data = param.data.to(onload_device) +# for name, buffer in module.named_buffers(recurse=False): +# if torch.is_tensor(buffer.data): +# buffer.data = buffer.data.to(onload_device) -def _apply_group_offloading_diffusers_block( +def _apply_group_offloading_group_patterns( module: torch.nn.Module, - num_blocks_per_group: int, + offload_group_patterns: List[Tuple[str, str, Optional[str]]], offload_device: torch.device, onload_device: torch.device, force_offload: bool, non_blocking: bool, - stream: Optional[torch.cuda.Stream] = None, ) -> None: - cpu_param_dict = None - if stream is not None: - for param in module.parameters(): - param.data = param.data.cpu().pin_memory() - cpu_param_dict = {param: param.data for param in module.parameters()} - - # Handle device offloading/onloading for unet/transformer stack modules - for stack_identifier in _COMMON_STACK_IDENTIFIERS: - if not hasattr(module, stack_identifier) or not isinstance( - getattr(module, stack_identifier), torch.nn.ModuleList - ): - continue + r""" + This function applies offloading to groups of modules based on the provided regex patterns. Each group of modules + that match a pattern are offloaded and onloaded together. The order of the patterns in the list is important as it + determines the order of execution of the forward pass. If the order is not correct, group offloading may almost + certainly fail with device mismatch errors. + + In the interest of simplicity, this function does not handle complicated cases where one regex pattern matches a + module, but another regex pattern matches an internal submodule of that module. This would be a difficult case to + handle and require a more complex checker, which is not implemented here. As a general rule of thumb, make sure to + provide regex patterns for all models that are at the same level of the computation graph in terms of invocation + order. For example, either all leaf modules, or all transformer blocks, etc. + + Note that parameters and buffers are always kept on the onload_device. This is because they are usually small + enough to not have any impact on memory usage. If you require support for offloading parameters and buffers, please + open an issue at https://github.com/huggingface/diffusers/issues. + + Args: + module (`torch.nn.Module`): + The module to which group offloading is applied. + offload_group_patterns (`List[Tuple[str, str, Optional[str]]]`): + A list of tuples that determine groups of modules that are offloaded and onloaded together. Each tuple + contains three elements: + - A regex pattern that matches the names of the modules in the group. + - A regex pattern that matches a single layer that is the offload leader of the group. + - An optional regex pattern that matches a single layer that is the onload leader of the group. This can be + set to None because it is easier to determine the onload leader based on the forward invocation order, + which triggers the call to GroupOffloadingHook. + offload_device (`torch.device`): + The device to which the group of modules are offloaded. This should typically be the CPU. + onload_device (`torch.device`): + The device to which the group of modules are onloaded. + force_offload (`bool`): + If True, all module groups are offloaded to the offload_device. If False, only layers that match + `offload_group_patterns` are offloaded to the offload_device. + non_blocking (`bool`): + If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation + and data transfer. + """ - stack = getattr(module, stack_identifier) - num_blocks = len(stack) - module_groups = [] + per_group_modules = [[] for _ in range(len(offload_group_patterns))] + per_group_offload_leaders = [None] * len(offload_group_patterns) + per_group_onload_leaders = [None] * len(offload_group_patterns) + unmatched_group_modules = [] + + group_patterns = [pattern[0] for pattern in offload_group_patterns] + offload_leader_patterns = [pattern[1] for pattern in offload_group_patterns] + onload_leader_patterns = [pattern[2] for pattern in offload_group_patterns] + + for name, module in module.named_modules(): + if name.count(".") > 1: + # We only want the layers that are top-level in the module (encompass all the other submodules) + # for enabling offloading. This method is specifically targeted for diffusers format models, + # so we can ignore submodules. + # TODO(aryan): This is not the case and is just a workaround to make the benchmark code work + # for now. We need to support the arbitrary nesting of modules here. + continue + num_matches = 0 + + # Check if the module matches any of the offload group patterns + for i, pattern in enumerate(group_patterns): + if re.search(pattern, name) is not None: + per_group_modules[i].append(module) + num_matches += 1 + + # Check if the module matches any of the offload leader patterns + for i, pattern in enumerate(offload_leader_patterns): + if re.search(pattern, name) is not None: + if per_group_offload_leaders[i] is not None: + raise ValueError( + f"Module {name} matches multiple offload leader patterns. Please ensure that offload leader patterns are mutually exclusive." + ) + per_group_offload_leaders[i] = module + + # Check if the module matches any of the onload leader patterns + for i, pattern in enumerate(onload_leader_patterns): + if pattern is None: + continue + if re.search(pattern, name) is not None: + if per_group_onload_leaders[i] is not None: + raise ValueError( + f"Module {name} matches multiple onload leader patterns. Please ensure that onload leader patterns are mutually exclusive." + ) + per_group_onload_leaders[i] = module + + if num_matches == 0: + unmatched_group_modules.append(module) + elif num_matches > 1: + raise ValueError( + f"Module {name} matches multiple offload group patterns. Please ensure that offloading group patterns are mutually exclusive." + ) - for i in range(0, num_blocks, num_blocks_per_group): - blocks = stack[i : i + num_blocks_per_group] - group = ModuleGroup( - blocks, offload_device, onload_device, offload_leader=blocks[-1], onload_leader=blocks[0] + # Handle modules that matched patterns + for i in range(len(per_group_modules)): + if per_group_offload_leaders[i] is None: + raise ValueError( + f"No offload leader found for group {i}. Please ensure that each group has a single offload leader." ) - module_groups.append(group) - - for i, group in enumerate(module_groups): - next_group = module_groups[i + 1] if i + 1 < len(module_groups) and stream is not None else None - should_offload = force_offload or i > 0 - _apply_group_offloading(group, should_offload, non_blocking, stream, next_group, cpu_param_dict) - - if stream is not None: - # Start Host->Device transfer for the first group - with torch.cuda.stream(stream): - for group_module in module_groups[0].modules: - group_module.to(onload_device, non_blocking=True) - if len(module_groups) > 1: - # Assign the first module_group as the next_group for the last module_group - hook_registry = HookRegistry.check_if_exists_or_initialize(module_groups[-1].onload_leader) - hook_registry.hooks["group_offloading"].next_group = module_groups[0] - - # Handle device offloading/onloading for non-stack modules - for name, submodule in module.named_modules(): - name_split = name.split(".") - if not isinstance(submodule, torch.nn.Module) or name == "" or len(name_split) > 1: - # We only want the layers that are top-level in the module (encompass all the submodules) - # for enabling offloading. - continue - layer_name = name_split[0] - if layer_name in _COMMON_STACK_IDENTIFIERS: - continue group = ModuleGroup( - [submodule], offload_device, onload_device, offload_leader=submodule, onload_leader=submodule + per_group_modules[i], + offload_device, + onload_device, + offload_leader=per_group_offload_leaders[i], + onload_leader=per_group_onload_leaders[i], ) _apply_group_offloading(group, force_offload, non_blocking) + # Handle modules that did not match patterns + for module in unmatched_group_modules: + group = ModuleGroup([module], offload_device, onload_device, offload_leader=module, onload_leader=module) + _apply_group_offloading(group, force_offload, non_blocking) + + # TODO(aryan): When you add stream support, this may need to be put in an if-branch # Always keep parameters and buffers on onload_device for name, param in module.named_parameters(recurse=False): if torch.is_tensor(param.data): @@ -230,46 +353,6 @@ def _apply_group_offloading_diffusers_block( buffer.data = buffer.data.to(onload_device) -def _apply_group_offloading_group_patterns( - module: torch.nn.Module, - offload_group_patterns: List[str], - offload_device: torch.device, - onload_device: torch.device, - force_offload: bool, - non_blocking: bool, -) -> None: - per_group_modules = [] - for i, offload_group_pattern in enumerate(offload_group_patterns): - group_modules = [] - group_module_names = [] - for name, module in module.named_modules(): - if re.search(offload_group_pattern, name) is not None: - group_modules.append(module) - group_module_names.append(name) - per_group_modules.append( - { - "modules": group_modules, - "module_names": group_module_names, - } - ) - - # Check if there are any overlapping modules between groups - for i, group in enumerate(per_group_modules): - for j, other_group in enumerate(per_group_modules): - if j <= i: - continue - if any(module_name in group["module_names"] for module_name in other_group["module_names"]): - raise ValueError( - f"Overlapping modules between groups {i} and {j}. Please ensure that offloading group patterns are mutually exclusive." - ) - - # Apply offloading to each group - for group in per_group_modules: - # TODO: handle offload leader correctly - group = ModuleGroup(group["modules"], offload_device, onload_device, offload_leader=group["modules"][-1]) - _apply_group_offloading(group, force_offload, non_blocking) - - def _apply_group_offloading( group: ModuleGroup, offload_on_init: bool, @@ -282,3 +365,30 @@ def _apply_group_offloading( hook = GroupOffloadingHook(group, offload_on_init, non_blocking, stream, next_group, cpu_param_dict) registry = HookRegistry.check_if_exists_or_initialize(module) registry.register_hook(hook, "group_offloading") + + +def _get_modulelist_or_sequential_group_patterns(module: torch.nn.Module, num_blocks_per_group: int) -> List[str]: + r""" + This function generates group patterns for offloading based on the number of blocks per group. Given a module, it + will iterate through the submodules and find usages of torch.nn.ModuleList and torch.nn.Sequential. For each group + of `num_blocks_per_group` consecutive blocks, it will generate a regex pattern that matches the names of these + blocks. The generated patterns can be used to create ModuleGroup objects which are offloaded and onloaded together. + """ + group_patterns = [] + + # We only want the layers that are top-level in the module (encompass all the other submodules) + # for enabling offloading. This method is specifically targeted for diffusers format models, + # so we can ignore everything but the children of this module. + for name, submodule in module.children(): + if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): + continue + for i in range(0, len(submodule), num_blocks_per_group): + num_modules = len(submodule[i : i + num_blocks_per_group]) + pattern = "|".join([rf"{name}\.{i + j}\b" for j in range(num_modules)]) + pattern = f"({pattern})" + onload_leader_pattern = rf"{name}\.{i}\b" + offload_leader_pattern = rf"{name}\.{i + num_modules - 1}\b" + group_patterns.append((pattern, offload_leader_pattern, onload_leader_pattern)) + + logger.debug(f"Generated group patterns for apply_groupwise_offloading: {group_patterns}") + return group_patterns From 80ac5a72756011c424f062943602e169018fcf3a Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 16 Jan 2025 07:07:12 +0100 Subject: [PATCH 07/37] copy model hook implementation from pab --- src/diffusers/hooks/hooks.py | 57 ++++++++++++++++++++++++------------ 1 file changed, 38 insertions(+), 19 deletions(-) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index 9d61e294742f..e80ac6e88389 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -13,7 +13,7 @@ # limitations under the License. import functools -from typing import Any, Dict, Tuple +from typing import Any, Dict, Optional, Tuple import torch @@ -33,7 +33,6 @@ class ModelHook: def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: r""" Hook that is executed when a model is initialized. - Args: module (`torch.nn.Module`): The module attached to this hook. @@ -43,7 +42,6 @@ def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module: r""" Hook that is executed when a model is deinitalized. - Args: module (`torch.nn.Module`): The module attached to this hook. @@ -55,7 +53,6 @@ def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module: def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: r""" Hook that is executed just before the forward method of the model. - Args: module (`torch.nn.Module`): The module whose forward pass will be executed just after this event. @@ -63,7 +60,6 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[A The positional arguments passed to the module. kwargs (`Dict[Str, Any]`): The keyword arguments passed to the module. - Returns: `Tuple[Tuple[Any], Dict[Str, Any]]`: A tuple with the treated `args` and `kwargs`. @@ -73,13 +69,11 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[A def post_forward(self, module: torch.nn.Module, output: Any) -> Any: r""" Hook that is executed just after the forward method of the model. - Args: module (`torch.nn.Module`): The module whose forward pass been executed just before this event. output (`Any`): The output of the module. - Returns: `Any`: The processed `output`. """ @@ -88,7 +82,6 @@ def post_forward(self, module: torch.nn.Module, output: Any) -> Any: def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: r""" Hook that is executed when the hook is detached from a module. - Args: module (`torch.nn.Module`): The module detached from this hook. @@ -123,7 +116,12 @@ def register_hook(self, hook: ModelHook, name: str) -> None: self._module_ref = hook.initialize_hook(self._module_ref) if hasattr(hook, "new_forward"): - new_forward = hook.new_forward + rewritten_forward = hook.new_forward + + def new_forward(module, *args, **kwargs): + args, kwargs = hook.pre_forward(module, *args, **kwargs) + output = rewritten_forward(module, *args, **kwargs) + return hook.post_forward(module, output) else: def new_forward(module, *args, **kwargs): @@ -131,23 +129,44 @@ def new_forward(module, *args, **kwargs): output = old_forward(*args, **kwargs) return hook.post_forward(module, output) - new_forward = functools.update_wrapper(new_forward, old_forward) - self._module_ref.forward = new_forward.__get__(self._module_ref) + self._module_ref.forward = functools.update_wrapper( + functools.partial(new_forward, self._module_ref), old_forward + ) self.hooks[name] = hook self._hook_order.append(name) - def get_hook(self, name: str) -> ModelHook: + def get_hook(self, name: str) -> Optional[ModelHook]: if name not in self.hooks.keys(): - raise ValueError(f"Hook with name {name} not found.") + return None return self.hooks[name] - def remove_hook(self, name: str) -> None: - if name not in self.hooks.keys(): - raise ValueError(f"Hook with name {name} not found.") - self.hooks[name].deinitalize_hook(self._module_ref) - del self.hooks[name] - self._hook_order.remove(name) + def remove_hook(self, name: str, recurse: bool = True) -> None: + if name in self.hooks.keys(): + hook = self.hooks[name] + self._module_ref = hook.deinitalize_hook(self._module_ref) + del self.hooks[name] + self._hook_order.remove(name) + + if recurse: + for module_name, module in self._module_ref.named_modules(): + if module_name == "": + continue + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook.remove_hook(name, recurse=False) + + def reset_stateful_hooks(self, recurse: bool = True) -> None: + for hook_name in self._hook_order: + hook = self.hooks[hook_name] + if hook._is_stateful: + hook.reset_state(self._module_ref) + + if recurse: + for module_name, module in self._module_ref.named_modules(): + if module_name == "": + continue + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook.reset_stateful_hooks(recurse=False) @classmethod def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry": From d2a2981a90d247f766ec19ff299132a54b0054d2 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 16 Jan 2025 08:29:30 +0100 Subject: [PATCH 08/37] update; ~very workaround based implementation but it seems to work as expected; needs cleanup and rewrite --- src/diffusers/hooks/group_offloading.py | 115 +++++++++++++++++------- src/diffusers/hooks/hooks.py | 5 ++ 2 files changed, 89 insertions(+), 31 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index c86025a901b6..41d6579779f2 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -21,7 +21,7 @@ from .hooks import HookRegistry, ModelHook -logger = get_logger(__name__) # pylint: disable=invalid-name +logger = get_logger(__name__) # pylint: disable=invalid-name class ModuleGroup: @@ -32,12 +32,16 @@ def __init__( onload_device: torch.device, offload_leader: torch.nn.Module, onload_leader: Optional[torch.nn.Module] = None, + parameters: Optional[List[torch.nn.Parameter]] = None, + buffers: Optional[List[torch.Tensor]] = None, ) -> None: self.modules = modules self.offload_device = offload_device self.onload_device = onload_device self.offload_leader = offload_leader self.onload_leader = onload_leader + self.parameters = parameters + self.buffers = buffers class GroupOffloadingHook(ModelHook): @@ -64,6 +68,7 @@ def __init__( stream: Optional[torch.cuda.Stream] = None, next_group: Optional[ModuleGroup] = None, cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None, + onload_self: bool = False, ) -> None: self.group = group self.offload_on_init = offload_on_init @@ -71,6 +76,7 @@ def __init__( self.stream = stream self.next_group = next_group self.cpu_param_dict = cpu_param_dict + self.onload_self = onload_self def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: if self.offload_on_init: @@ -100,9 +106,16 @@ def onload_(self, module: torch.nn.Module) -> None: with torch.cuda.stream(self.stream): for group_module in self.next_group.modules: group_module.to(self.next_group.onload_device, non_blocking=True) - else: + + if self.stream is None or self.onload_self: for group_module in self.group.modules: group_module.to(self.group.onload_device, non_blocking=self.non_blocking) + if self.group.parameters is not None: + for param in self.group.parameters: + param.data = param.data.to(self.group.onload_device, non_blocking=self.non_blocking) + if self.group.buffers is not None: + for buffer in self.group.buffers: + buffer.data = buffer.data.to(self.group.onload_device, non_blocking=self.non_blocking) def offload_(self, module: torch.nn.Module) -> None: if self.group.offload_leader == module: @@ -113,6 +126,13 @@ def offload_(self, module: torch.nn.Module) -> None: else: for group_module in self.group.modules: group_module.to(self.group.offload_device, non_blocking=self.non_blocking) + if self.group.parameters is not None: + for param in self.group.parameters: + param.data = param.data.to(self.group.offload_device, non_blocking=self.non_blocking) + if self.group.buffers is not None: + for buffer in self.group.buffers: + buffer.data = buffer.data.to(self.group.offload_device, non_blocking=self.non_blocking) + # TODO: do we need to sync here because of GPU->CPU transfer? if self.non_blocking and self.group.offload_device.type == "cpu": torch.cpu.synchronize() @@ -128,9 +148,9 @@ def apply_group_offloading( non_blocking: bool = False, cuda_stream: bool = False, ) -> None: - # stream = None - # if cuda_stream: - # stream = torch.cuda.Stream() + stream = None + if cuda_stream: + stream = torch.cuda.Stream() if offload_group_patterns == "modulelist_or_sequential": if num_blocks_per_group is None: raise ValueError( @@ -148,7 +168,7 @@ def apply_group_offloading( offload_group_patterns = _get_modulelist_or_sequential_group_patterns(module, num_blocks_per_group) _apply_group_offloading_group_patterns( - module, offload_group_patterns, offload_device, onload_device, force_offload, non_blocking + module, offload_group_patterns, offload_device, onload_device, force_offload, non_blocking, stream=stream ) @@ -231,6 +251,7 @@ def _apply_group_offloading_group_patterns( onload_device: torch.device, force_offload: bool, non_blocking: bool, + stream: Optional[torch.cuda.Stream] = None, ) -> None: r""" This function applies offloading to groups of modules based on the provided regex patterns. Each group of modules @@ -269,8 +290,17 @@ def _apply_group_offloading_group_patterns( non_blocking (`bool`): If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation and data transfer. + stream (`torch.cuda.Stream`, *optional*): + If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful + for overlapping computation and data transfer. """ + cpu_param_dict = None + if stream is not None: + for param in module.parameters(): + param.data = param.data.cpu().pin_memory() + cpu_param_dict = {param: param.data for param in module.parameters()} + per_group_modules = [[] for _ in range(len(offload_group_patterns))] per_group_offload_leaders = [None] * len(offload_group_patterns) per_group_onload_leaders = [None] * len(offload_group_patterns) @@ -280,20 +310,20 @@ def _apply_group_offloading_group_patterns( offload_leader_patterns = [pattern[1] for pattern in offload_group_patterns] onload_leader_patterns = [pattern[2] for pattern in offload_group_patterns] - for name, module in module.named_modules(): - if name.count(".") > 1: + for name, submodule in module.named_modules(): + if name == "" or name.count(".") > 1: # We only want the layers that are top-level in the module (encompass all the other submodules) # for enabling offloading. This method is specifically targeted for diffusers format models, # so we can ignore submodules. # TODO(aryan): This is not the case and is just a workaround to make the benchmark code work # for now. We need to support the arbitrary nesting of modules here. continue - num_matches = 0 # Check if the module matches any of the offload group patterns + num_matches = 0 for i, pattern in enumerate(group_patterns): if re.search(pattern, name) is not None: - per_group_modules[i].append(module) + per_group_modules[i].append(submodule) num_matches += 1 # Check if the module matches any of the offload leader patterns @@ -303,7 +333,7 @@ def _apply_group_offloading_group_patterns( raise ValueError( f"Module {name} matches multiple offload leader patterns. Please ensure that offload leader patterns are mutually exclusive." ) - per_group_offload_leaders[i] = module + per_group_offload_leaders[i] = submodule # Check if the module matches any of the onload leader patterns for i, pattern in enumerate(onload_leader_patterns): @@ -314,16 +344,17 @@ def _apply_group_offloading_group_patterns( raise ValueError( f"Module {name} matches multiple onload leader patterns. Please ensure that onload leader patterns are mutually exclusive." ) - per_group_onload_leaders[i] = module + per_group_onload_leaders[i] = submodule if num_matches == 0: - unmatched_group_modules.append(module) + unmatched_group_modules.append((name, submodule)) elif num_matches > 1: raise ValueError( f"Module {name} matches multiple offload group patterns. Please ensure that offloading group patterns are mutually exclusive." ) # Handle modules that matched patterns + groups = [] for i in range(len(per_group_modules)): if per_group_offload_leaders[i] is None: raise ValueError( @@ -336,21 +367,40 @@ def _apply_group_offloading_group_patterns( offload_leader=per_group_offload_leaders[i], onload_leader=per_group_onload_leaders[i], ) - _apply_group_offloading(group, force_offload, non_blocking) - - # Handle modules that did not match patterns - for module in unmatched_group_modules: - group = ModuleGroup([module], offload_device, onload_device, offload_leader=module, onload_leader=module) - _apply_group_offloading(group, force_offload, non_blocking) - - # TODO(aryan): When you add stream support, this may need to be put in an if-branch - # Always keep parameters and buffers on onload_device - for name, param in module.named_parameters(recurse=False): - if torch.is_tensor(param.data): - param.data = param.data.to(onload_device) + groups.append(group) + + for i in range(len(groups)): + next_group = groups[i + 1] if i + 1 < len(groups) and stream is not None else None + should_offload = force_offload or i > 0 + _apply_group_offloading( + groups[i], should_offload, non_blocking, stream, next_group, cpu_param_dict, onload_self=False + ) + + # Ignore parameters/buffers if they're already accounted for in unmatched_group_modules (for example, a nn.Linear + # in the top-level module will also be present in the named_parameters iterator) + parameters = [] + for name, parameter in module.named_parameters(recurse=False): + if not any(name.startswith(unmatched_name) for unmatched_name, _ in unmatched_group_modules): + parameters.append(parameter) + + buffers = [] for name, buffer in module.named_buffers(recurse=False): - if torch.is_tensor(buffer.data): - buffer.data = buffer.data.to(onload_device) + if not any(name.startswith(unmatched_name) for unmatched_name, _ in unmatched_group_modules): + buffers.append(buffer) + + unmatched_modules = [module for _, module in unmatched_group_modules] + unmatched_group = ModuleGroup( + unmatched_modules, + offload_device, + onload_device, + offload_leader=module, + onload_leader=None, + parameters=parameters, + buffers=buffers, + ) + _apply_group_offloading( + unmatched_group, force_offload, non_blocking, stream, groups[0], cpu_param_dict, onload_self=True + ) def _apply_group_offloading( @@ -360,9 +410,12 @@ def _apply_group_offloading( stream: Optional[torch.cuda.Stream] = None, next_group: Optional[ModuleGroup] = None, cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None, + onload_self: bool = False, ) -> None: for module in group.modules: - hook = GroupOffloadingHook(group, offload_on_init, non_blocking, stream, next_group, cpu_param_dict) + hook = GroupOffloadingHook( + group, offload_on_init, non_blocking, stream, next_group, cpu_param_dict, onload_self + ) registry = HookRegistry.check_if_exists_or_initialize(module) registry.register_hook(hook, "group_offloading") @@ -375,11 +428,11 @@ def _get_modulelist_or_sequential_group_patterns(module: torch.nn.Module, num_bl blocks. The generated patterns can be used to create ModuleGroup objects which are offloaded and onloaded together. """ group_patterns = [] - + # We only want the layers that are top-level in the module (encompass all the other submodules) # for enabling offloading. This method is specifically targeted for diffusers format models, # so we can ignore everything but the children of this module. - for name, submodule in module.children(): + for name, submodule in module.named_children(): if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): continue for i in range(0, len(submodule), num_blocks_per_group): @@ -389,6 +442,6 @@ def _get_modulelist_or_sequential_group_patterns(module: torch.nn.Module, num_bl onload_leader_pattern = rf"{name}\.{i}\b" offload_leader_pattern = rf"{name}\.{i + num_modules - 1}\b" group_patterns.append((pattern, offload_leader_pattern, onload_leader_pattern)) - + logger.debug(f"Generated group patterns for apply_groupwise_offloading: {group_patterns}") return group_patterns diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index e80ac6e88389..bef4c65c41e1 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -33,6 +33,7 @@ class ModelHook: def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: r""" Hook that is executed when a model is initialized. + Args: module (`torch.nn.Module`): The module attached to this hook. @@ -42,6 +43,7 @@ def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module: r""" Hook that is executed when a model is deinitalized. + Args: module (`torch.nn.Module`): The module attached to this hook. @@ -53,6 +55,7 @@ def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module: def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: r""" Hook that is executed just before the forward method of the model. + Args: module (`torch.nn.Module`): The module whose forward pass will be executed just after this event. @@ -69,6 +72,7 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[A def post_forward(self, module: torch.nn.Module, output: Any) -> Any: r""" Hook that is executed just after the forward method of the model. + Args: module (`torch.nn.Module`): The module whose forward pass been executed just before this event. @@ -82,6 +86,7 @@ def post_forward(self, module: torch.nn.Module, output: Any) -> Any: def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: r""" Hook that is executed when the hook is detached from a module. + Args: module (`torch.nn.Module`): The module detached from this hook. From 01c7d2200affa8f26fc9f06c8c43337aeaadfe74 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 16 Jan 2025 09:02:16 +0100 Subject: [PATCH 09/37] more workarounds to make it actually work --- src/diffusers/hooks/group_offloading.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 41d6579779f2..eaa8728ba608 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -116,6 +116,8 @@ def onload_(self, module: torch.nn.Module) -> None: if self.group.buffers is not None: for buffer in self.group.buffers: buffer.data = buffer.data.to(self.group.onload_device, non_blocking=self.non_blocking) + if self.onload_self: + torch.cuda.synchronize() def offload_(self, module: torch.nn.Module) -> None: if self.group.offload_leader == module: @@ -388,7 +390,8 @@ def _apply_group_offloading_group_patterns( if not any(name.startswith(unmatched_name) for unmatched_name, _ in unmatched_group_modules): buffers.append(buffer) - unmatched_modules = [module for _, module in unmatched_group_modules] + ignore_blocks = ["transformer_blocks", "single_transformer_blocks", "temporal_transformer_blocks", "blocks"] + unmatched_modules = [module for name, module in unmatched_group_modules if name not in ignore_blocks] unmatched_group = ModuleGroup( unmatched_modules, offload_device, From 22aff343763f5381602754815b7237e4ea17e6e8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 16 Jan 2025 17:21:51 +0100 Subject: [PATCH 10/37] cleanup --- src/diffusers/hooks/group_offloading.py | 444 ++++++++---------------- 1 file changed, 142 insertions(+), 302 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index eaa8728ba608..db4ba06f4ddb 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re -from typing import Dict, List, Optional, Tuple, Union +from contextlib import nullcontext +from typing import Dict, List, Optional import torch +from accelerate.utils import send_to_device from ..utils import get_logger from .hooks import HookRegistry, ModelHook @@ -34,6 +35,10 @@ def __init__( onload_leader: Optional[torch.nn.Module] = None, parameters: Optional[List[torch.nn.Parameter]] = None, buffers: Optional[List[torch.Tensor]] = None, + non_blocking: bool = False, + stream: Optional[torch.cuda.Stream] = None, + cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None, + onload_self: bool = True, ) -> None: self.modules = modules self.offload_device = offload_device @@ -42,213 +47,124 @@ def __init__( self.onload_leader = onload_leader self.parameters = parameters self.buffers = buffers + self.non_blocking = non_blocking or stream is not None + self.stream = stream + self.cpu_param_dict = cpu_param_dict + self.onload_self = onload_self + + if self.stream is not None and self.cpu_param_dict is None: + raise ValueError("cpu_param_dict must be provided when using stream for data transfer.") + + def onload_(self): + context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream) + if self.stream is not None: + # Wait for previous Host->Device transfer to complete + self.stream.synchronize() + + with context: + for group_module in self.modules: + group_module.to(self.onload_device, non_blocking=self.non_blocking) + if self.parameters is not None: + for param in self.parameters: + param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking) + if self.buffers is not None: + for buffer in self.buffers: + buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking) + + def offload_(self): + if self.stream is not None: + for group_module in self.modules: + for param in group_module.parameters(): + param.data = self.cpu_param_dict[param] + else: + for group_module in self.modules: + group_module.to(self.offload_device, non_blocking=self.non_blocking) + if self.parameters is not None: + for param in self.parameters: + param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking) + if self.buffers is not None: + for buffer in self.buffers: + buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking) + + # TODO: do we need to sync here because of GPU->CPU transfer? + if self.non_blocking and self.offload_device.type == "cpu": + torch.cpu.synchronize() class GroupOffloadingHook(ModelHook): r""" A hook that offloads groups of torch.nn.Module to the CPU for storage and onloads to accelerator device for computation. Each group has one "onload leader" module that is responsible for onloading, and an "offload leader" - module that is responsible for offloading. - - This implementation assumes the following: - - For offload_group_patterns="diffusers_block", the leader of a group can be automatically determined. For a custom - user-provided regex pattern, the module that triggers its forward pass first is considered the leader. - - The inputs are already on the correct device. This is expected because the hook does not modify the state of - inputs or outputs at any stage of the forward pass. If an error is raised due to the device of modules and inputs - not matching during the forward pass for any model in Diffusers, this means that the forward pass of the model is - not written in the expected. Please open an issue at https://github.com/huggingface/diffusers/issues if you - encounter such an error. + module that is responsible for offloading. If prefetching is enabled, the onload leader of the previous module + group is responsible for onloading the current module group. """ def __init__( self, group: ModuleGroup, offload_on_init: bool = True, - non_blocking: bool = False, - stream: Optional[torch.cuda.Stream] = None, next_group: Optional[ModuleGroup] = None, - cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None, - onload_self: bool = False, ) -> None: self.group = group self.offload_on_init = offload_on_init - self.non_blocking = non_blocking - self.stream = stream self.next_group = next_group - self.cpu_param_dict = cpu_param_dict - self.onload_self = onload_self def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: - if self.offload_on_init: - self.offload_(module) + if self.offload_on_init and self.group.offload_leader == module: + self.group.offload_() return module def pre_forward(self, module: torch.nn.Module, *args, **kwargs): if self.group.onload_leader is None: self.group.onload_leader = module - self.onload_(module) + if self.group.onload_leader == module: + if self.group.onload_self: + self.group.onload_() + if self.next_group is not None and not self.next_group.onload_self: + self.next_group.onload_() + args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking) + kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking) return args, kwargs def post_forward(self, module: torch.nn.Module, output): - self.offload_(module) - return output - - def onload_(self, module: torch.nn.Module) -> None: - if self.group.onload_leader == module: - if self.stream is not None: - # Wait for previous Host->Device transfer to complete - self.stream.synchronize() - - if self.next_group is None: - return - - # Start Host->Device transfer for next group - with torch.cuda.stream(self.stream): - for group_module in self.next_group.modules: - group_module.to(self.next_group.onload_device, non_blocking=True) - - if self.stream is None or self.onload_self: - for group_module in self.group.modules: - group_module.to(self.group.onload_device, non_blocking=self.non_blocking) - if self.group.parameters is not None: - for param in self.group.parameters: - param.data = param.data.to(self.group.onload_device, non_blocking=self.non_blocking) - if self.group.buffers is not None: - for buffer in self.group.buffers: - buffer.data = buffer.data.to(self.group.onload_device, non_blocking=self.non_blocking) - if self.onload_self: - torch.cuda.synchronize() - - def offload_(self, module: torch.nn.Module) -> None: if self.group.offload_leader == module: - if self.stream is not None: - for group_module in self.group.modules: - for param in group_module.parameters(): - param.data = self.cpu_param_dict[param] - else: - for group_module in self.group.modules: - group_module.to(self.group.offload_device, non_blocking=self.non_blocking) - if self.group.parameters is not None: - for param in self.group.parameters: - param.data = param.data.to(self.group.offload_device, non_blocking=self.non_blocking) - if self.group.buffers is not None: - for buffer in self.group.buffers: - buffer.data = buffer.data.to(self.group.offload_device, non_blocking=self.non_blocking) - - # TODO: do we need to sync here because of GPU->CPU transfer? - if self.non_blocking and self.group.offload_device.type == "cpu": - torch.cpu.synchronize() + self.group.offload_() + return output def apply_group_offloading( module: torch.nn.Module, - offload_group_patterns: Union[str, List[str]] = "modulelist_or_sequential", + offload_type: str = "block_level", num_blocks_per_group: Optional[int] = None, offload_device: torch.device = torch.device("cpu"), onload_device: torch.device = torch.device("cuda"), force_offload: bool = True, non_blocking: bool = False, - cuda_stream: bool = False, + use_stream: bool = False, ) -> None: stream = None - if cuda_stream: - stream = torch.cuda.Stream() - if offload_group_patterns == "modulelist_or_sequential": + if use_stream: + if torch.cuda.is_available(): + stream = torch.cuda.Stream() + else: + raise ValueError("Using streams for data transfer requires a CUDA device.") + + if offload_type == "block_level": if num_blocks_per_group is None: - raise ValueError( - "num_blocks_per_group must be provided when using offload_group_patterns='modulelist_or_sequential'." - ) - # _apply_group_offloading_diffusers_block( - # module, - # num_blocks_per_group, - # offload_device, - # onload_device, - # force_offload, - # non_blocking, - # stream, - # ) - offload_group_patterns = _get_modulelist_or_sequential_group_patterns(module, num_blocks_per_group) - - _apply_group_offloading_group_patterns( - module, offload_group_patterns, offload_device, onload_device, force_offload, non_blocking, stream=stream - ) + raise ValueError("num_blocks_per_group must be provided when using offload_group_patterns='block_level'.") + _apply_group_offloading_block_level( + module, num_blocks_per_group, offload_device, onload_device, force_offload, non_blocking, stream=stream + ) + # elif offload_type == "leaf_level": + # _apply_group_offloading_leaf_level( + # module, offload_device, onload_device, force_offload, non_blocking, stream=stream + # ) -# def _apply_group_offloading_diffusers_block( -# module: torch.nn.Module, -# num_blocks_per_group: int, -# offload_device: torch.device, -# onload_device: torch.device, -# force_offload: bool, -# non_blocking: bool, -# stream: Optional[torch.cuda.Stream] = None, -# ) -> None: -# cpu_param_dict = None -# if stream is not None: -# for param in module.parameters(): -# param.data = param.data.cpu().pin_memory() -# cpu_param_dict = {param: param.data for param in module.parameters()} -# # Handle device offloading/onloading for unet/transformer stack modules -# for stack_identifier in _COMMON_STACK_IDENTIFIERS: -# if not hasattr(module, stack_identifier) or not isinstance( -# getattr(module, stack_identifier), torch.nn.ModuleList -# ): -# continue - -# stack = getattr(module, stack_identifier) -# num_blocks = len(stack) -# module_groups = [] - -# for i in range(0, num_blocks, num_blocks_per_group): -# blocks = stack[i : i + num_blocks_per_group] -# group = ModuleGroup( -# blocks, offload_device, onload_device, offload_leader=blocks[-1], onload_leader=blocks[0] -# ) -# module_groups.append(group) - -# for i, group in enumerate(module_groups): -# next_group = module_groups[i + 1] if i + 1 < len(module_groups) and stream is not None else None -# should_offload = force_offload or i > 0 -# _apply_group_offloading(group, should_offload, non_blocking, stream, next_group, cpu_param_dict) - -# if stream is not None: -# # Start Host->Device transfer for the first group -# with torch.cuda.stream(stream): -# for group_module in module_groups[0].modules: -# group_module.to(onload_device, non_blocking=True) -# if len(module_groups) > 1: -# # Assign the first module_group as the next_group for the last module_group -# hook_registry = HookRegistry.check_if_exists_or_initialize(module_groups[-1].onload_leader) -# hook_registry.hooks["group_offloading"].next_group = module_groups[0] - -# # Handle device offloading/onloading for non-stack modules -# for name, submodule in module.named_modules(): -# name_split = name.split(".") -# if not isinstance(submodule, torch.nn.Module) or name == "" or len(name_split) > 1: -# # We only want the layers that are top-level in the module (encompass all the submodules) -# # for enabling offloading. -# continue -# layer_name = name_split[0] -# if layer_name in _COMMON_STACK_IDENTIFIERS: -# continue -# group = ModuleGroup( -# [submodule], offload_device, onload_device, offload_leader=submodule, onload_leader=submodule -# ) -# _apply_group_offloading(group, force_offload, non_blocking) - -# # Always keep parameters and buffers on onload_device -# for name, param in module.named_parameters(recurse=False): -# if torch.is_tensor(param.data): -# param.data = param.data.to(onload_device) -# for name, buffer in module.named_buffers(recurse=False): -# if torch.is_tensor(buffer.data): -# buffer.data = buffer.data.to(onload_device) - - -def _apply_group_offloading_group_patterns( +def _apply_group_offloading_block_level( module: torch.nn.Module, - offload_group_patterns: List[Tuple[str, str, Optional[str]]], + num_blocks_per_group: int, offload_device: torch.device, onload_device: torch.device, force_offload: bool, @@ -256,32 +172,11 @@ def _apply_group_offloading_group_patterns( stream: Optional[torch.cuda.Stream] = None, ) -> None: r""" - This function applies offloading to groups of modules based on the provided regex patterns. Each group of modules - that match a pattern are offloaded and onloaded together. The order of the patterns in the list is important as it - determines the order of execution of the forward pass. If the order is not correct, group offloading may almost - certainly fail with device mismatch errors. - - In the interest of simplicity, this function does not handle complicated cases where one regex pattern matches a - module, but another regex pattern matches an internal submodule of that module. This would be a difficult case to - handle and require a more complex checker, which is not implemented here. As a general rule of thumb, make sure to - provide regex patterns for all models that are at the same level of the computation graph in terms of invocation - order. For example, either all leaf modules, or all transformer blocks, etc. - - Note that parameters and buffers are always kept on the onload_device. This is because they are usually small - enough to not have any impact on memory usage. If you require support for offloading parameters and buffers, please - open an issue at https://github.com/huggingface/diffusers/issues. + This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. Args: module (`torch.nn.Module`): The module to which group offloading is applied. - offload_group_patterns (`List[Tuple[str, str, Optional[str]]]`): - A list of tuples that determine groups of modules that are offloaded and onloaded together. Each tuple - contains three elements: - - A regex pattern that matches the names of the modules in the group. - - A regex pattern that matches a single layer that is the offload leader of the group. - - An optional regex pattern that matches a single layer that is the onload leader of the group. This can be - set to None because it is easier to determine the onload leader based on the forward invocation order, - which triggers the call to GroupOffloadingHook. offload_device (`torch.device`): The device to which the group of modules are offloaded. This should typically be the CPU. onload_device (`torch.device`): @@ -303,148 +198,93 @@ def _apply_group_offloading_group_patterns( param.data = param.data.cpu().pin_memory() cpu_param_dict = {param: param.data for param in module.parameters()} - per_group_modules = [[] for _ in range(len(offload_group_patterns))] - per_group_offload_leaders = [None] * len(offload_group_patterns) - per_group_onload_leaders = [None] * len(offload_group_patterns) - unmatched_group_modules = [] - - group_patterns = [pattern[0] for pattern in offload_group_patterns] - offload_leader_patterns = [pattern[1] for pattern in offload_group_patterns] - onload_leader_patterns = [pattern[2] for pattern in offload_group_patterns] - - for name, submodule in module.named_modules(): - if name == "" or name.count(".") > 1: - # We only want the layers that are top-level in the module (encompass all the other submodules) - # for enabling offloading. This method is specifically targeted for diffusers format models, - # so we can ignore submodules. - # TODO(aryan): This is not the case and is just a workaround to make the benchmark code work - # for now. We need to support the arbitrary nesting of modules here. + unmatched_modules = [] + matched_module_groups = [] + for name, submodule in module.named_children(): + if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): + unmatched_modules.append((name, submodule)) continue - - # Check if the module matches any of the offload group patterns - num_matches = 0 - for i, pattern in enumerate(group_patterns): - if re.search(pattern, name) is not None: - per_group_modules[i].append(submodule) - num_matches += 1 - - # Check if the module matches any of the offload leader patterns - for i, pattern in enumerate(offload_leader_patterns): - if re.search(pattern, name) is not None: - if per_group_offload_leaders[i] is not None: - raise ValueError( - f"Module {name} matches multiple offload leader patterns. Please ensure that offload leader patterns are mutually exclusive." - ) - per_group_offload_leaders[i] = submodule - - # Check if the module matches any of the onload leader patterns - for i, pattern in enumerate(onload_leader_patterns): - if pattern is None: - continue - if re.search(pattern, name) is not None: - if per_group_onload_leaders[i] is not None: - raise ValueError( - f"Module {name} matches multiple onload leader patterns. Please ensure that onload leader patterns are mutually exclusive." - ) - per_group_onload_leaders[i] = submodule - - if num_matches == 0: - unmatched_group_modules.append((name, submodule)) - elif num_matches > 1: - raise ValueError( - f"Module {name} matches multiple offload group patterns. Please ensure that offloading group patterns are mutually exclusive." + for i in range(0, len(submodule), num_blocks_per_group): + group = ModuleGroup( + modules=submodule[i : i + num_blocks_per_group], + offload_device=offload_device, + onload_device=onload_device, + offload_leader=submodule[i], + onload_leader=None, + non_blocking=non_blocking, + stream=stream, + cpu_param_dict=cpu_param_dict, + onload_self=stream is None, ) + matched_module_groups.append(group) - # Handle modules that matched patterns - groups = [] - for i in range(len(per_group_modules)): - if per_group_offload_leaders[i] is None: - raise ValueError( - f"No offload leader found for group {i}. Please ensure that each group has a single offload leader." - ) - group = ModuleGroup( - per_group_modules[i], - offload_device, - onload_device, - offload_leader=per_group_offload_leaders[i], - onload_leader=per_group_onload_leaders[i], + for i, group in enumerate(matched_module_groups): + next_group = ( + matched_module_groups[i + 1] if i + 1 < len(matched_module_groups) and stream is not None else None ) - groups.append(group) - - for i in range(len(groups)): - next_group = groups[i + 1] if i + 1 < len(groups) and stream is not None else None should_offload = force_offload or i > 0 - _apply_group_offloading( - groups[i], should_offload, non_blocking, stream, next_group, cpu_param_dict, onload_self=False - ) + _apply_group_offloading(group, should_offload, next_group) - # Ignore parameters/buffers if they're already accounted for in unmatched_group_modules (for example, a nn.Linear - # in the top-level module will also be present in the named_parameters iterator) parameters = [] for name, parameter in module.named_parameters(recurse=False): - if not any(name.startswith(unmatched_name) for unmatched_name, _ in unmatched_group_modules): + if not any(name.startswith(unmatched_name) for unmatched_name, _ in unmatched_modules): parameters.append(parameter) buffers = [] for name, buffer in module.named_buffers(recurse=False): - if not any(name.startswith(unmatched_name) for unmatched_name, _ in unmatched_group_modules): + if not any(name.startswith(unmatched_name) for unmatched_name, _ in unmatched_modules): buffers.append(buffer) - ignore_blocks = ["transformer_blocks", "single_transformer_blocks", "temporal_transformer_blocks", "blocks"] - unmatched_modules = [module for name, module in unmatched_group_modules if name not in ignore_blocks] + unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules] unmatched_group = ModuleGroup( - unmatched_modules, - offload_device, - onload_device, + modules=unmatched_modules, + offload_device=offload_device, + onload_device=onload_device, offload_leader=module, onload_leader=None, parameters=parameters, buffers=buffers, + non_blocking=False, + stream=None, + cpu_param_dict=cpu_param_dict, + onload_self=True, ) - _apply_group_offloading( - unmatched_group, force_offload, non_blocking, stream, groups[0], cpu_param_dict, onload_self=True - ) + _apply_group_offloading(unmatched_group, force_offload, matched_module_groups[0]) + + +# def _apply_group_offloading_leaf_level( +# module: torch.nn.Module, +# offload_device: torch.device, +# onload_device: torch.device, +# force_offload: bool, +# non_blocking: bool, +# stream: Optional[torch.cuda.Stream] = None, +# ) -> None: +# r""" +# This function applies offloading to groups of leaf modules in a torch.nn.Module. + +# Args: # module (`torch.nn.Module`): # The module to which group offloading is applied. # offload_device +(`torch.device`): # The device to which the group of modules are offloaded. This should typically be the CPU. # +onload_device (`torch.device`): # The device to which the group of modules are onloaded. # force_offload (`bool`): # If +True, all module groups are offloaded to the offload_device. If False, only layers that match # +`offload_group_patterns` are offloaded to the offload_device. # non_blocking (`bool`): # If True, offloading and +onloading is done asynchronously. This can be useful for overlapping computation # and data transfer. # stream +(`torch.cuda.Stream`, *optional*): # If provided, offloading and onloading is done asynchronously using the provided +stream. This can be useful # for overlapping computation and data transfer. #""" + +# cpu_param_dict = None +# if stream is not None: +# for param in module.parameters(): +# param.data = param.data.cpu().pin_memory() +# cpu_param_dict = {param: param.data for param in module.parameters()} def _apply_group_offloading( group: ModuleGroup, offload_on_init: bool, - non_blocking: bool, - stream: Optional[torch.cuda.Stream] = None, next_group: Optional[ModuleGroup] = None, - cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None, - onload_self: bool = False, ) -> None: for module in group.modules: - hook = GroupOffloadingHook( - group, offload_on_init, non_blocking, stream, next_group, cpu_param_dict, onload_self - ) + hook = GroupOffloadingHook(group, offload_on_init, next_group) registry = HookRegistry.check_if_exists_or_initialize(module) registry.register_hook(hook, "group_offloading") - - -def _get_modulelist_or_sequential_group_patterns(module: torch.nn.Module, num_blocks_per_group: int) -> List[str]: - r""" - This function generates group patterns for offloading based on the number of blocks per group. Given a module, it - will iterate through the submodules and find usages of torch.nn.ModuleList and torch.nn.Sequential. For each group - of `num_blocks_per_group` consecutive blocks, it will generate a regex pattern that matches the names of these - blocks. The generated patterns can be used to create ModuleGroup objects which are offloaded and onloaded together. - """ - group_patterns = [] - - # We only want the layers that are top-level in the module (encompass all the other submodules) - # for enabling offloading. This method is specifically targeted for diffusers format models, - # so we can ignore everything but the children of this module. - for name, submodule in module.named_children(): - if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): - continue - for i in range(0, len(submodule), num_blocks_per_group): - num_modules = len(submodule[i : i + num_blocks_per_group]) - pattern = "|".join([rf"{name}\.{i + j}\b" for j in range(num_modules)]) - pattern = f"({pattern})" - onload_leader_pattern = rf"{name}\.{i}\b" - offload_leader_pattern = rf"{name}\.{i + num_modules - 1}\b" - group_patterns.append((pattern, offload_leader_pattern, onload_leader_pattern)) - - logger.debug(f"Generated group patterns for apply_groupwise_offloading: {group_patterns}") - return group_patterns From 42bc19b658c47b580b0338f7a9090c30637945be Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 17 Jan 2025 22:56:56 +0100 Subject: [PATCH 11/37] rewrite --- src/diffusers/hooks/__init__.py | 1 + src/diffusers/hooks/group_offloading.py | 248 +++++++++++++++++++----- src/diffusers/hooks/hooks.py | 81 +++++--- src/diffusers/models/modeling_utils.py | 7 + 4 files changed, 269 insertions(+), 68 deletions(-) diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 6cfabea0c48e..5c183abefac2 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -3,3 +3,4 @@ if is_torch_available(): from .group_offloading import apply_group_offloading + from .hooks import HookRegistry diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index db4ba06f4ddb..d738d0f6e7ca 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -13,7 +13,7 @@ # limitations under the License. from contextlib import nullcontext -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import torch from accelerate.utils import send_to_device @@ -25,6 +25,11 @@ logger = get_logger(__name__) # pylint: disable=invalid-name +_GROUP_OFFLOADING = "group_offloading" +_LAYER_EXECUTION_TRACKER = "layer_execution_tracker" +_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading" + + class ModuleGroup: def __init__( self, @@ -99,6 +104,8 @@ class GroupOffloadingHook(ModelHook): group is responsible for onloading the current module group. """ + _is_stateful = False + def __init__( self, group: ModuleGroup, @@ -132,6 +139,85 @@ def post_forward(self, module: torch.nn.Module, output): return output +class LazyPrefetchGroupOffloadingHook(ModelHook): + _is_stateful = False + + def __init__(self): + self.execution_order: List[Tuple[str, torch.nn.Module]] = [] + self._layer_execution_tracker_module_names = set() + + def initialize_hook(self, module): + for name, submodule in module.named_modules(): + if name == "" or not hasattr(submodule, "_diffusers_hook"): + continue + + registry = HookRegistry.check_if_exists_or_initialize(submodule) + group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING) + + if group_offloading_hook is not None: + + def make_execution_order_update_callback(current_name, current_submodule): + def callback(): + logger.debug(f"Adding {current_name} to the execution order") + self.execution_order.append((current_name, current_submodule)) + + return callback + + layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule)) + registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER) + self._layer_execution_tracker_module_names.add(name) + + return module + + def post_forward(self, module, output): + num_executed = len(self.execution_order) + execution_order_module_names = {name for name, _ in self.execution_order} + + # Check if the two sets are equal + if execution_order_module_names != self._layer_execution_tracker_module_names: + unexecuted_layers = list(self._layer_execution_tracker_module_names - execution_order_module_names) + logger.warning( + "It seems like some layers were not executed during the forward pass. This may lead to problems when " + "applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please " + "make sure that all layers are executed during the forward pass. The following layers were not executed:\n" + f"{unexecuted_layers=}" + ) + + base_module_registry = HookRegistry.check_if_exists_or_initialize(module) + registries = [HookRegistry.check_if_exists_or_initialize(submodule) for _, submodule in self.execution_order] + + for i in range(num_executed): + registries[i].remove_hook(_LAYER_EXECUTION_TRACKER) + + base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING) + + group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries] + if num_executed > 0: + base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING) + base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group + base_module_group_offloading_hook.next_group.onload_self = False + + for i in range(num_executed - 1): + name1, _ = self.execution_order[i] + name2, _ = self.execution_order[i + 1] + logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}") + group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group + group_offloading_hooks[i].next_group.onload_self = False + + return output + + +class LayerExecutionTrackerHook(ModelHook): + _is_stateful = False + + def __init__(self, execution_order_update_callback): + self.execution_order_update_callback = execution_order_update_callback + + def pre_forward(self, module, *args, **kwargs): + self.execution_order_update_callback() + return args, kwargs + + def apply_group_offloading( module: torch.nn.Module, offload_type: str = "block_level", @@ -156,10 +242,10 @@ def apply_group_offloading( _apply_group_offloading_block_level( module, num_blocks_per_group, offload_device, onload_device, force_offload, non_blocking, stream=stream ) - # elif offload_type == "leaf_level": - # _apply_group_offloading_leaf_level( - # module, offload_device, onload_device, force_offload, non_blocking, stream=stream - # ) + elif offload_type == "leaf_level": + _apply_group_offloading_leaf_level( + module, offload_device, onload_device, force_offload, non_blocking, stream=stream + ) def _apply_group_offloading_block_level( @@ -205,12 +291,13 @@ def _apply_group_offloading_block_level( unmatched_modules.append((name, submodule)) continue for i in range(0, len(submodule), num_blocks_per_group): + current_modules = submodule[i : i + num_blocks_per_group] group = ModuleGroup( modules=submodule[i : i + num_blocks_per_group], offload_device=offload_device, onload_device=onload_device, - offload_leader=submodule[i], - onload_leader=None, + offload_leader=current_modules[-1], + onload_leader=current_modules[0], non_blocking=non_blocking, stream=stream, cpu_param_dict=cpu_param_dict, @@ -223,7 +310,9 @@ def _apply_group_offloading_block_level( matched_module_groups[i + 1] if i + 1 < len(matched_module_groups) and stream is not None else None ) should_offload = force_offload or i > 0 - _apply_group_offloading(group, should_offload, next_group) + + for group_module in group.modules: + _apply_group_offloading_hook(group_module, group, should_offload, next_group) parameters = [] for name, parameter in module.named_parameters(recurse=False): @@ -241,7 +330,88 @@ def _apply_group_offloading_block_level( offload_device=offload_device, onload_device=onload_device, offload_leader=module, - onload_leader=None, + onload_leader=module, + parameters=parameters, + buffers=buffers, + non_blocking=False, + stream=None, + cpu_param_dict=None, + onload_self=True, + ) + _apply_group_offloading_hook(module, unmatched_group, force_offload, matched_module_groups[0]) + + +def _apply_group_offloading_leaf_level( + module: torch.nn.Module, + offload_device: torch.device, + onload_device: torch.device, + force_offload: bool, + non_blocking: bool, + stream: Optional[torch.cuda.Stream] = None, +) -> None: + r""" + This function applies offloading to groups of leaf modules in a torch.nn.Module. + + Args: + module (`torch.nn.Module`): + The module to which group offloading is applied. + offload_device (`torch.device`): + The device to which the group of modules are offloaded. This should typically be the CPU. + onload_device (`torch.device`): + The device to which the group of modules are onloaded. + force_offload (`bool`): + If True, all module groups are offloaded to the offload_device. If False, only layers that match + `offload_group_patterns` are offloaded to the offload_device. + non_blocking (`bool`): + If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation + and data transfer. + stream (`torch.cuda.Stream`, *optional*): + If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful + for overlapping computation and data transfer. + """ + + cpu_param_dict = None + if stream is not None: + for param in module.parameters(): + param.data = param.data.cpu().pin_memory() + cpu_param_dict = {param: param.data for param in module.parameters()} + + for submodule in module.modules(): + if len(list(submodule.children())) != 0: + continue + group = ModuleGroup( + modules=[submodule], + offload_device=offload_device, + onload_device=onload_device, + offload_leader=submodule, + onload_leader=submodule, + non_blocking=non_blocking, + stream=stream, + cpu_param_dict=cpu_param_dict, + onload_self=True, + ) + _apply_group_offloading_hook(submodule, group, True, None) + + parameters = [] + buffers = [] + + def gather_non_module_parameters_and_buffers(m: torch.nn.Module): + if len(list(m.children())) == 0: + return + for parameter in m.parameters(recurse=False): + parameters.append(parameter) + for buffer in m.buffers(recurse=False): + buffers.append(buffer) + for submodule in m.children(): + gather_non_module_parameters_and_buffers(submodule) + + gather_non_module_parameters_and_buffers(module) + unmatched_group = ModuleGroup( + modules=[], + offload_device=offload_device, + onload_device=onload_device, + offload_leader=module, + onload_leader=module, parameters=parameters, buffers=buffers, non_blocking=False, @@ -249,42 +419,32 @@ def _apply_group_offloading_block_level( cpu_param_dict=cpu_param_dict, onload_self=True, ) - _apply_group_offloading(unmatched_group, force_offload, matched_module_groups[0]) - - -# def _apply_group_offloading_leaf_level( -# module: torch.nn.Module, -# offload_device: torch.device, -# onload_device: torch.device, -# force_offload: bool, -# non_blocking: bool, -# stream: Optional[torch.cuda.Stream] = None, -# ) -> None: -# r""" -# This function applies offloading to groups of leaf modules in a torch.nn.Module. - -# Args: # module (`torch.nn.Module`): # The module to which group offloading is applied. # offload_device -(`torch.device`): # The device to which the group of modules are offloaded. This should typically be the CPU. # -onload_device (`torch.device`): # The device to which the group of modules are onloaded. # force_offload (`bool`): # If -True, all module groups are offloaded to the offload_device. If False, only layers that match # -`offload_group_patterns` are offloaded to the offload_device. # non_blocking (`bool`): # If True, offloading and -onloading is done asynchronously. This can be useful for overlapping computation # and data transfer. # stream -(`torch.cuda.Stream`, *optional*): # If provided, offloading and onloading is done asynchronously using the provided -stream. This can be useful # for overlapping computation and data transfer. #""" - -# cpu_param_dict = None -# if stream is not None: -# for param in module.parameters(): -# param.data = param.data.cpu().pin_memory() -# cpu_param_dict = {param: param.data for param in module.parameters()} - - -def _apply_group_offloading( + + if stream is None: + _apply_group_offloading_hook(module, unmatched_group, force_offload, None) + else: + _apply_lazy_group_offloading_hook(module, unmatched_group, force_offload, None) + + +def _apply_group_offloading_hook( + module: torch.nn.Module, + group: ModuleGroup, + offload_on_init: bool, + next_group: Optional[ModuleGroup] = None, +) -> None: + hook = GroupOffloadingHook(group, offload_on_init, next_group) + registry = HookRegistry.check_if_exists_or_initialize(module) + registry.register_hook(hook, _GROUP_OFFLOADING) + + +def _apply_lazy_group_offloading_hook( + module: torch.nn.Module, group: ModuleGroup, offload_on_init: bool, next_group: Optional[ModuleGroup] = None, ) -> None: - for module in group.modules: - hook = GroupOffloadingHook(group, offload_on_init, next_group) - registry = HookRegistry.check_if_exists_or_initialize(module) - registry.register_hook(hook, "group_offloading") + hook = GroupOffloadingHook(group, offload_on_init, next_group) + lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook() + registry = HookRegistry.check_if_exists_or_initialize(module) + registry.register_hook(hook, _GROUP_OFFLOADING) + registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index bef4c65c41e1..617783516df0 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -13,6 +13,7 @@ # limitations under the License. import functools +import gc from typing import Any, Dict, Optional, Tuple import torch @@ -48,8 +49,6 @@ def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module: module (`torch.nn.Module`): The module attached to this hook. """ - module.forward = module._old_forward - del module._old_forward return module def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: @@ -99,6 +98,13 @@ def reset_state(self, module: torch.nn.Module): return module +class FunctionReference: + def __init__(self) -> None: + self.pre_forward = None + self.post_forward = None + self.old_forward = None + + class HookRegistry: def __init__(self, module_ref: torch.nn.Module) -> None: super().__init__() @@ -107,39 +113,50 @@ def __init__(self, module_ref: torch.nn.Module) -> None: self._module_ref = module_ref self._hook_order = [] + self._fn_refs = [] def register_hook(self, hook: ModelHook, name: str) -> None: if name in self.hooks.keys(): logger.warning(f"Hook with name {name} already exists, replacing it.") - if hasattr(self._module_ref, "_old_forward"): - old_forward = self._module_ref._old_forward - else: - old_forward = self._module_ref.forward - self._module_ref._old_forward = self._module_ref.forward + forward = self._module_ref.forward - self._module_ref = hook.initialize_hook(self._module_ref) + fn_ref = FunctionReference() + fn_ref.pre_forward = hook.pre_forward + fn_ref.post_forward = hook.post_forward + fn_ref.old_forward = forward - if hasattr(hook, "new_forward"): - rewritten_forward = hook.new_forward + self._module_ref = hook.initialize_hook(self._module_ref) + def create_new_forward(function_reference: FunctionReference): def new_forward(module, *args, **kwargs): - args, kwargs = hook.pre_forward(module, *args, **kwargs) - output = rewritten_forward(module, *args, **kwargs) - return hook.post_forward(module, output) - else: + args, kwargs = function_reference.pre_forward(module, *args, **kwargs) + output = function_reference.old_forward(*args, **kwargs) + return function_reference.post_forward(module, output) - def new_forward(module, *args, **kwargs): - args, kwargs = hook.pre_forward(module, *args, **kwargs) - output = old_forward(*args, **kwargs) - return hook.post_forward(module, output) + return new_forward + + # if hasattr(hook, "new_forward"): + # fn_ref.old_forward = hook.new_forward - self._module_ref.forward = functools.update_wrapper( - functools.partial(new_forward, self._module_ref), old_forward - ) + # def new_forward(module, *args, **kwargs): + # args, kwargs = hook.pre_forward(module, *args, **kwargs) + # output = rewritten_forward(module, *args, **kwargs) + # return hook.post_forward(module, output) + # else: + + # def new_forward(module, *args, **kwargs): + # args, kwargs = hook.pre_forward(module, *args, **kwargs) + # output = forward(*args, **kwargs) + # return hook.post_forward(module, output) + + new_forward = create_new_forward(fn_ref) + new_forward = functools.update_wrapper(functools.partial(new_forward, self._module_ref), forward) + self._module_ref.forward = new_forward self.hooks[name] = hook self._hook_order.append(name) + self._fn_refs.append(fn_ref) def get_hook(self, name: str) -> Optional[ModelHook]: if name not in self.hooks.keys(): @@ -147,11 +164,23 @@ def get_hook(self, name: str) -> Optional[ModelHook]: return self.hooks[name] def remove_hook(self, name: str, recurse: bool = True) -> None: + num_hooks = len(self._hook_order) if name in self.hooks.keys(): hook = self.hooks[name] + index = self._hook_order.index(name) + + fn_ref = self._fn_refs[index] + + if index == num_hooks - 1: + self._module_ref.forward = fn_ref.old_forward + else: + next_fn_ref = self._fn_refs[index + 1] + next_fn_ref.old_forward = fn_ref.old_forward + self._module_ref = hook.deinitalize_hook(self._module_ref) del self.hooks[name] - self._hook_order.remove(name) + self._hook_order.pop(index) + self._fn_refs.pop(index) if recurse: for module_name, module in self._module_ref.named_modules(): @@ -160,8 +189,10 @@ def remove_hook(self, name: str, recurse: bool = True) -> None: if hasattr(module, "_diffusers_hook"): module._diffusers_hook.remove_hook(name, recurse=False) + gc.collect() + def reset_stateful_hooks(self, recurse: bool = True) -> None: - for hook_name in self._hook_order: + for hook_name in reversed(self._hook_order): hook = self.hooks[hook_name] if hook._is_stateful: hook.reset_state(self._module_ref) @@ -180,9 +211,11 @@ def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry return module._diffusers_hook def __repr__(self) -> str: + num_hooks = len(self._hook_order) hook_repr = "" for i, hook_name in enumerate(self._hook_order): - hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})" + hook_invocation_index = num_hooks - i - 1 + hook_repr += f" ({hook_invocation_index}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})" if i < len(self._hook_order) - 1: hook_repr += "\n" return f"HookRegistry(\n{hook_repr}\n)" diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index fcd7775fb608..16d57d2da808 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -84,6 +84,13 @@ def get_parameter_device(parameter: torch.nn.Module) -> torch.device: try: + if hasattr(parameter, "_diffusers_hook"): + for submodule in parameter.modules(): + if hasattr(submodule, "_diffusers_hook"): + registry = parameter._diffusers_hook + hook = registry.get_hook("group_offloading") + return hook.group.onload_device + parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers()) return next(parameters_and_buffers).device except StopIteration: From 8c63bf5a5f69d0b4b6fcfcd3fe74c02671ef38d0 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 19 Jan 2025 16:52:47 +0100 Subject: [PATCH 12/37] update --- src/diffusers/hooks/group_offloading.py | 80 +++++++++++++++++++---- src/diffusers/hooks/hooks.py | 4 +- src/diffusers/pipelines/pipeline_utils.py | 20 ++++++ 3 files changed, 89 insertions(+), 15 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index d738d0f6e7ca..d80b52e7541f 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -25,10 +25,18 @@ logger = get_logger(__name__) # pylint: disable=invalid-name +# fmt: off _GROUP_OFFLOADING = "group_offloading" _LAYER_EXECUTION_TRACKER = "layer_execution_tracker" _LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading" +_SUPPORTED_PYTORCH_LAYERS = ( + torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, + torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, + torch.nn.Linear, +) +# fmt: on + class ModuleGroup: def __init__( @@ -61,6 +69,7 @@ def __init__( raise ValueError("cpu_param_dict must be provided when using stream for data transfer.") def onload_(self): + r"""Onloads the group of modules to the onload_device.""" context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream) if self.stream is not None: # Wait for previous Host->Device transfer to complete @@ -77,6 +86,7 @@ def onload_(self): buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking) def offload_(self): + r"""Offloads the group of modules to the offload_device.""" if self.stream is not None: for group_module in self.modules: for param in group_module.parameters(): @@ -91,10 +101,6 @@ def offload_(self): for buffer in self.buffers: buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking) - # TODO: do we need to sync here because of GPU->CPU transfer? - if self.non_blocking and self.offload_device.type == "cpu": - torch.cpu.synchronize() - class GroupOffloadingHook(ModelHook): r""" @@ -122,13 +128,20 @@ def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: return module def pre_forward(self, module: torch.nn.Module, *args, **kwargs): + # If there wasn't an onload_leader assigned, we assume that the submodule that first called its forward + # method is the onload_leader of the group. if self.group.onload_leader is None: self.group.onload_leader = module + + # If the current module is the onload_leader of the group, we onload the group if it is supposed + # to onload itself. In the case of using prefetching with streams, we onload the next group if + # it is not supposed to onload itself. if self.group.onload_leader == module: if self.group.onload_self: self.group.onload_() if self.next_group is not None and not self.next_group.onload_self: self.next_group.onload_() + args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking) kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking) return args, kwargs @@ -140,6 +153,13 @@ def post_forward(self, module: torch.nn.Module, output): class LazyPrefetchGroupOffloadingHook(ModelHook): + r""" + A hook, used in conjuction with GroupOffloadingHook, that applies lazy prefetching to groups of torch.nn.Module. + This hook is used to determine the order in which the layers are executed during the forward pass. Once the layer + invocation order is known, assignments of the next_group attribute for prefetching can be made, which allows + prefetching groups in the correct order. + """ + _is_stateful = False def __init__(self): @@ -147,6 +167,9 @@ def __init__(self): self._layer_execution_tracker_module_names = set() def initialize_hook(self, module): + # To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any + # of the groups), we add a layer execution tracker hook that will be used to determine the order in which the + # layers are executed during the forward pass. for name, submodule in module.named_modules(): if name == "" or not hasattr(submodule, "_diffusers_hook"): continue @@ -170,10 +193,16 @@ def callback(): return module def post_forward(self, module, output): + # At this point, for the current modules' submodules, we know the execution order of the layers. We can now + # remove the layer execution tracker hooks and apply prefetching by setting the next_group attribute for each + # group offloading hook. num_executed = len(self.execution_order) execution_order_module_names = {name for name, _ in self.execution_order} - # Check if the two sets are equal + # It may be possible that some layers were not executed during the forward pass. This can happen if the layer + # is not used in the forward pass, or if the layer is not executed due to some other reason. In such cases, we + # may not be able to apply prefetching in the correct order, which can lead to device-mismatch related errors + # if the missing layers end up being executed in the future. if execution_order_module_names != self._layer_execution_tracker_module_names: unexecuted_layers = list(self._layer_execution_tracker_module_names - execution_order_module_names) logger.warning( @@ -183,14 +212,17 @@ def post_forward(self, module, output): f"{unexecuted_layers=}" ) - base_module_registry = HookRegistry.check_if_exists_or_initialize(module) - registries = [HookRegistry.check_if_exists_or_initialize(submodule) for _, submodule in self.execution_order] + # Remove the layer execution tracker hooks from the submodules + base_module_registry = module._diffusers_hook + registries = [submodule._diffusers_hook for _, submodule in self.execution_order] for i in range(num_executed): registries[i].remove_hook(_LAYER_EXECUTION_TRACKER) + # Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING) + # Apply lazy prefetching by setting required attributes group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries] if num_executed > 0: base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING) @@ -208,6 +240,11 @@ def post_forward(self, module, output): class LayerExecutionTrackerHook(ModelHook): + r""" + A hook that tracks the order in which the layers are executed during the forward pass by calling back to the + LazyPrefetchGroupOffloadingHook to update the execution order. + """ + _is_stateful = False def __init__(self, execution_order_update_callback): @@ -258,7 +295,8 @@ def _apply_group_offloading_block_level( stream: Optional[torch.cuda.Stream] = None, ) -> None: r""" - This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. + This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to + the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks. Args: module (`torch.nn.Module`): @@ -278,12 +316,14 @@ def _apply_group_offloading_block_level( for overlapping computation and data transfer. """ + # Create a pinned CPU parameter dict for async data transfer if streams are to be used cpu_param_dict = None if stream is not None: for param in module.parameters(): param.data = param.data.cpu().pin_memory() cpu_param_dict = {param: param.data for param in module.parameters()} + # Create module groups for ModuleList and Sequential blocks unmatched_modules = [] matched_module_groups = [] for name, submodule in module.named_children(): @@ -305,6 +345,7 @@ def _apply_group_offloading_block_level( ) matched_module_groups.append(group) + # Apply group offloading hooks to the module groups for i, group in enumerate(matched_module_groups): next_group = ( matched_module_groups[i + 1] if i + 1 < len(matched_module_groups) and stream is not None else None @@ -314,6 +355,9 @@ def _apply_group_offloading_block_level( for group_module in group.modules: _apply_group_offloading_hook(group_module, group, should_offload, next_group) + # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately + # when the forward pass of this module is called. This is because the top-level module is not + # part of any group (as doing so would lead to no VRAM savings). parameters = [] for name, parameter in module.named_parameters(recurse=False): if not any(name.startswith(unmatched_name) for unmatched_name, _ in unmatched_modules): @@ -324,6 +368,8 @@ def _apply_group_offloading_block_level( if not any(name.startswith(unmatched_name) for unmatched_name, _ in unmatched_modules): buffers.append(buffer) + # Create a group for the unmatched submodules of the top-level module so that they are on the correct + # device when the forward pass is called. unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules] unmatched_group = ModuleGroup( modules=unmatched_modules, @@ -350,7 +396,10 @@ def _apply_group_offloading_leaf_level( stream: Optional[torch.cuda.Stream] = None, ) -> None: r""" - This function applies offloading to groups of leaf modules in a torch.nn.Module. + This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory + requirements. However, it can be slower compared to other offloading methods due to the excessive number of device + synchronizations. When using devices that support streams to overlap data transfer and computation, this method can + reduce memory usage without any performance degradation. Args: module (`torch.nn.Module`): @@ -370,14 +419,16 @@ def _apply_group_offloading_leaf_level( for overlapping computation and data transfer. """ + # Create a pinned CPU parameter dict for async data transfer if streams are to be used cpu_param_dict = None if stream is not None: for param in module.parameters(): param.data = param.data.cpu().pin_memory() cpu_param_dict = {param: param.data for param in module.parameters()} - for submodule in module.modules(): - if len(list(submodule.children())) != 0: + # Create module groups for leaf modules and apply group offloading hooks + for name, submodule in module.named_modules(): + if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS): continue group = ModuleGroup( modules=[submodule], @@ -392,11 +443,13 @@ def _apply_group_offloading_leaf_level( ) _apply_group_offloading_hook(submodule, group, True, None) + # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass + # of the module is called parameters = [] buffers = [] def gather_non_module_parameters_and_buffers(m: torch.nn.Module): - if len(list(m.children())) == 0: + if isinstance(m, _SUPPORTED_PYTORCH_LAYERS): return for parameter in m.parameters(recurse=False): parameters.append(parameter) @@ -420,6 +473,9 @@ def gather_non_module_parameters_and_buffers(m: torch.nn.Module): onload_self=True, ) + # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer + # and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the + # execution order and apply prefetching in the correct order. if stream is None: _apply_group_offloading_hook(module, unmatched_group, force_offload, None) else: diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index 617783516df0..9b0165af74cb 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -159,9 +159,7 @@ def new_forward(module, *args, **kwargs): self._fn_refs.append(fn_ref) def get_hook(self, name: str) -> Optional[ModelHook]: - if name not in self.hooks.keys(): - return None - return self.hooks[name] + return self.hooks.get(name, None) def remove_hook(self, name: str, recurse: bool = True) -> None: num_hooks = len(self._hook_order) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 3cafb77e5d63..66022467dd09 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1020,6 +1020,26 @@ def _execution_device(self): [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from Accelerate's module hooks. """ + diffusers_hook_device = None + for name, model in self.components.items(): + if not isinstance(model, torch.nn.Module): + continue + + for submodule in model.modules(): + if not hasattr(submodule, "_diffusers_hook"): + continue + registry = submodule._diffusers_hook + hook = registry.get_hook("group_offloading") + if hook is not None: + diffusers_hook_device = hook.group.onload_device + break + + if diffusers_hook_device is not None: + break + + if diffusers_hook_device is not None: + return diffusers_hook_device + for name, model in self.components.items(): if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload: continue From e09e716202f480a2e40b2affa1d667a3cdbbee20 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 19 Jan 2025 17:36:58 +0100 Subject: [PATCH 13/37] make sure to sync current stream before overwriting with pinned params not doing so will lead to erroneous computations on the GPU and cause bad results --- src/diffusers/hooks/group_offloading.py | 3 ++- src/diffusers/hooks/hooks.py | 3 +-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index d80b52e7541f..68669c87134b 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -88,6 +88,7 @@ def onload_(self): def offload_(self): r"""Offloads the group of modules to the offload_device.""" if self.stream is not None: + torch.cuda.current_stream().synchronize() for group_module in self.modules: for param in group_module.parameters(): param.data = self.cpu_param_dict[param] @@ -427,7 +428,7 @@ def _apply_group_offloading_leaf_level( cpu_param_dict = {param: param.data for param in module.parameters()} # Create module groups for leaf modules and apply group offloading hooks - for name, submodule in module.named_modules(): + for submodule in module.modules(): if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS): continue group = ModuleGroup( diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index 9b0165af74cb..4fbf1a3c8938 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -151,8 +151,7 @@ def new_forward(module, *args, **kwargs): # return hook.post_forward(module, output) new_forward = create_new_forward(fn_ref) - new_forward = functools.update_wrapper(functools.partial(new_forward, self._module_ref), forward) - self._module_ref.forward = new_forward + self._module_ref.forward = functools.update_wrapper(functools.partial(new_forward, self._module_ref), forward) self.hooks[name] = hook self._hook_order.append(name) From 0bf0bafc91efef26f1c47c35481bad4ec1f4b47c Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 19 Jan 2025 17:43:37 +0100 Subject: [PATCH 14/37] better check --- src/diffusers/models/modeling_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 34c38c27b55f..da9bc4642610 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -86,9 +86,11 @@ def get_parameter_device(parameter: torch.nn.Module) -> torch.device: try: if hasattr(parameter, "_diffusers_hook"): for submodule in parameter.modules(): - if hasattr(submodule, "_diffusers_hook"): - registry = parameter._diffusers_hook - hook = registry.get_hook("group_offloading") + if not hasattr(submodule, "_diffusers_hook"): + continue + registry = parameter._diffusers_hook + hook = registry.get_hook("group_offloading") + if hook is not None: return hook.group.onload_device parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers()) From b850c759d71b671a4f7c2d08e3047407902759bb Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 20 Jan 2025 09:00:02 +0100 Subject: [PATCH 15/37] update --- src/diffusers/hooks/group_offloading.py | 24 +++++++------ src/diffusers/hooks/hooks.py | 45 ++++++++++++------------- 2 files changed, 36 insertions(+), 33 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 68669c87134b..3703ee70aa2c 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -448,18 +448,22 @@ def _apply_group_offloading_leaf_level( # of the module is called parameters = [] buffers = [] + module_dict = dict(module.named_modules()) - def gather_non_module_parameters_and_buffers(m: torch.nn.Module): - if isinstance(m, _SUPPORTED_PYTORCH_LAYERS): - return - for parameter in m.parameters(recurse=False): - parameters.append(parameter) - for buffer in m.buffers(recurse=False): - buffers.append(buffer) - for submodule in m.children(): - gather_non_module_parameters_and_buffers(submodule) + for name, parameter in module.named_parameters(): + atoms = name.split(".") + parent_name = ".".join(atoms[:-1]) + if parent_name in module_dict and isinstance(module_dict[parent_name], _SUPPORTED_PYTORCH_LAYERS): + continue + parameters.append(parameter) + + for name, buffer in module.named_buffers(): + atoms = name.split(".") + parent_name = ".".join(atoms[:-1]) + if parent_name in module_dict and isinstance(module_dict[parent_name], _SUPPORTED_PYTORCH_LAYERS): + continue + buffers.append(buffer) - gather_non_module_parameters_and_buffers(module) unmatched_group = ModuleGroup( modules=[], offload_device=offload_device, diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index 4fbf1a3c8938..dc1fef0077d3 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -31,6 +31,9 @@ class ModelHook: _is_stateful = False + def __init__(self) -> None: + self.fn_ref = None + def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: r""" Hook that is executed when a model is initialized. @@ -103,6 +106,7 @@ def __init__(self) -> None: self.pre_forward = None self.post_forward = None self.old_forward = None + self.is_overwritten_forward = False class HookRegistry: @@ -119,40 +123,36 @@ def register_hook(self, hook: ModelHook, name: str) -> None: if name in self.hooks.keys(): logger.warning(f"Hook with name {name} already exists, replacing it.") - forward = self._module_ref.forward - - fn_ref = FunctionReference() - fn_ref.pre_forward = hook.pre_forward - fn_ref.post_forward = hook.post_forward - fn_ref.old_forward = forward - self._module_ref = hook.initialize_hook(self._module_ref) - def create_new_forward(function_reference: FunctionReference): + def create_new_forward(function_reference: FunctionReference, forward): def new_forward(module, *args, **kwargs): args, kwargs = function_reference.pre_forward(module, *args, **kwargs) - output = function_reference.old_forward(*args, **kwargs) + output = forward(*args, **kwargs) return function_reference.post_forward(module, output) return new_forward - # if hasattr(hook, "new_forward"): - # fn_ref.old_forward = hook.new_forward + forward = self._module_ref.forward - # def new_forward(module, *args, **kwargs): - # args, kwargs = hook.pre_forward(module, *args, **kwargs) - # output = rewritten_forward(module, *args, **kwargs) - # return hook.post_forward(module, output) - # else: + fn_ref = FunctionReference() + fn_ref.pre_forward = hook.pre_forward + fn_ref.post_forward = hook.post_forward + fn_ref.old_forward = forward - # def new_forward(module, *args, **kwargs): - # args, kwargs = hook.pre_forward(module, *args, **kwargs) - # output = forward(*args, **kwargs) - # return hook.post_forward(module, output) + if hasattr(hook, "new_forward"): + new_forward = hook.new_forward + fn_ref.is_overwritten_forward = True + else: + new_forward = forward + fn_ref.is_overwritten_forward = False - new_forward = create_new_forward(fn_ref) - self._module_ref.forward = functools.update_wrapper(functools.partial(new_forward, self._module_ref), forward) + rewritten_forward = create_new_forward(fn_ref, new_forward) + self._module_ref.forward = functools.update_wrapper( + functools.partial(rewritten_forward, self._module_ref), forward + ) + hook.fn_ref = fn_ref self.hooks[name] = hook self._hook_order.append(name) self._fn_refs.append(fn_ref) @@ -165,7 +165,6 @@ def remove_hook(self, name: str, recurse: bool = True) -> None: if name in self.hooks.keys(): hook = self.hooks[name] index = self._hook_order.index(name) - fn_ref = self._fn_refs[index] if index == num_hooks - 1: From 6ed9c2f13f4e620c535438d45bca6c0c887e7b19 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 23 Jan 2025 12:18:26 +0100 Subject: [PATCH 16/37] remove hook implementation to not deal with merge conflict --- src/diffusers/hooks/hooks.py | 217 ----------------------------------- 1 file changed, 217 deletions(-) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index dc1fef0077d3..e69de29bb2d1 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -1,217 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import functools -import gc -from typing import Any, Dict, Optional, Tuple - -import torch - -from ..utils.logging import get_logger - - -logger = get_logger(__name__) # pylint: disable=invalid-name - - -class ModelHook: - r""" - A hook that contains callbacks to be executed just before and after the forward method of a model. - """ - - _is_stateful = False - - def __init__(self) -> None: - self.fn_ref = None - - def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: - r""" - Hook that is executed when a model is initialized. - - Args: - module (`torch.nn.Module`): - The module attached to this hook. - """ - return module - - def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module: - r""" - Hook that is executed when a model is deinitalized. - - Args: - module (`torch.nn.Module`): - The module attached to this hook. - """ - return module - - def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: - r""" - Hook that is executed just before the forward method of the model. - - Args: - module (`torch.nn.Module`): - The module whose forward pass will be executed just after this event. - args (`Tuple[Any]`): - The positional arguments passed to the module. - kwargs (`Dict[Str, Any]`): - The keyword arguments passed to the module. - Returns: - `Tuple[Tuple[Any], Dict[Str, Any]]`: - A tuple with the treated `args` and `kwargs`. - """ - return args, kwargs - - def post_forward(self, module: torch.nn.Module, output: Any) -> Any: - r""" - Hook that is executed just after the forward method of the model. - - Args: - module (`torch.nn.Module`): - The module whose forward pass been executed just before this event. - output (`Any`): - The output of the module. - Returns: - `Any`: The processed `output`. - """ - return output - - def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: - r""" - Hook that is executed when the hook is detached from a module. - - Args: - module (`torch.nn.Module`): - The module detached from this hook. - """ - return module - - def reset_state(self, module: torch.nn.Module): - if self._is_stateful: - raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.") - return module - - -class FunctionReference: - def __init__(self) -> None: - self.pre_forward = None - self.post_forward = None - self.old_forward = None - self.is_overwritten_forward = False - - -class HookRegistry: - def __init__(self, module_ref: torch.nn.Module) -> None: - super().__init__() - - self.hooks: Dict[str, ModelHook] = {} - - self._module_ref = module_ref - self._hook_order = [] - self._fn_refs = [] - - def register_hook(self, hook: ModelHook, name: str) -> None: - if name in self.hooks.keys(): - logger.warning(f"Hook with name {name} already exists, replacing it.") - - self._module_ref = hook.initialize_hook(self._module_ref) - - def create_new_forward(function_reference: FunctionReference, forward): - def new_forward(module, *args, **kwargs): - args, kwargs = function_reference.pre_forward(module, *args, **kwargs) - output = forward(*args, **kwargs) - return function_reference.post_forward(module, output) - - return new_forward - - forward = self._module_ref.forward - - fn_ref = FunctionReference() - fn_ref.pre_forward = hook.pre_forward - fn_ref.post_forward = hook.post_forward - fn_ref.old_forward = forward - - if hasattr(hook, "new_forward"): - new_forward = hook.new_forward - fn_ref.is_overwritten_forward = True - else: - new_forward = forward - fn_ref.is_overwritten_forward = False - - rewritten_forward = create_new_forward(fn_ref, new_forward) - self._module_ref.forward = functools.update_wrapper( - functools.partial(rewritten_forward, self._module_ref), forward - ) - - hook.fn_ref = fn_ref - self.hooks[name] = hook - self._hook_order.append(name) - self._fn_refs.append(fn_ref) - - def get_hook(self, name: str) -> Optional[ModelHook]: - return self.hooks.get(name, None) - - def remove_hook(self, name: str, recurse: bool = True) -> None: - num_hooks = len(self._hook_order) - if name in self.hooks.keys(): - hook = self.hooks[name] - index = self._hook_order.index(name) - fn_ref = self._fn_refs[index] - - if index == num_hooks - 1: - self._module_ref.forward = fn_ref.old_forward - else: - next_fn_ref = self._fn_refs[index + 1] - next_fn_ref.old_forward = fn_ref.old_forward - - self._module_ref = hook.deinitalize_hook(self._module_ref) - del self.hooks[name] - self._hook_order.pop(index) - self._fn_refs.pop(index) - - if recurse: - for module_name, module in self._module_ref.named_modules(): - if module_name == "": - continue - if hasattr(module, "_diffusers_hook"): - module._diffusers_hook.remove_hook(name, recurse=False) - - gc.collect() - - def reset_stateful_hooks(self, recurse: bool = True) -> None: - for hook_name in reversed(self._hook_order): - hook = self.hooks[hook_name] - if hook._is_stateful: - hook.reset_state(self._module_ref) - - if recurse: - for module_name, module in self._module_ref.named_modules(): - if module_name == "": - continue - if hasattr(module, "_diffusers_hook"): - module._diffusers_hook.reset_stateful_hooks(recurse=False) - - @classmethod - def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry": - if not hasattr(module, "_diffusers_hook"): - module._diffusers_hook = cls(module) - return module._diffusers_hook - - def __repr__(self) -> str: - num_hooks = len(self._hook_order) - hook_repr = "" - for i, hook_name in enumerate(self._hook_order): - hook_invocation_index = num_hooks - i - 1 - hook_repr += f" ({hook_invocation_index}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})" - if i < len(self._hook_order) - 1: - hook_repr += "\n" - return f"HookRegistry(\n{hook_repr}\n)" From 073d4bc18573846889523962def7d763d4c1e4b7 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 23 Jan 2025 12:20:25 +0100 Subject: [PATCH 17/37] re-add hook changes --- src/diffusers/hooks/hooks.py | 91 ++++++++++++++++++++++++------------ 1 file changed, 62 insertions(+), 29 deletions(-) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index bef4c65c41e1..f3968e853476 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -13,6 +13,7 @@ # limitations under the License. import functools +import gc from typing import Any, Dict, Optional, Tuple import torch @@ -30,6 +31,9 @@ class ModelHook: _is_stateful = False + def __init__(self): + self.fn_ref: "FunctionReference" = None + def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: r""" Hook that is executed when a model is initialized. @@ -48,8 +52,6 @@ def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module: module (`torch.nn.Module`): The module attached to this hook. """ - module.forward = module._old_forward - del module._old_forward return module def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: @@ -99,6 +101,14 @@ def reset_state(self, module: torch.nn.Module): return module +class FunctionReference: + def __init__(self) -> None: + self.pre_forward = None + self.post_forward = None + self.old_forward = None + self.overwritten_forward = None + + class HookRegistry: def __init__(self, module_ref: torch.nn.Module) -> None: super().__init__() @@ -107,51 +117,68 @@ def __init__(self, module_ref: torch.nn.Module) -> None: self._module_ref = module_ref self._hook_order = [] + self._fn_refs = [] def register_hook(self, hook: ModelHook, name: str) -> None: if name in self.hooks.keys(): logger.warning(f"Hook with name {name} already exists, replacing it.") - if hasattr(self._module_ref, "_old_forward"): - old_forward = self._module_ref._old_forward - else: - old_forward = self._module_ref.forward - self._module_ref._old_forward = self._module_ref.forward - self._module_ref = hook.initialize_hook(self._module_ref) - if hasattr(hook, "new_forward"): - rewritten_forward = hook.new_forward - + def create_new_forward(function_reference: FunctionReference): def new_forward(module, *args, **kwargs): - args, kwargs = hook.pre_forward(module, *args, **kwargs) - output = rewritten_forward(module, *args, **kwargs) - return hook.post_forward(module, output) - else: + args, kwargs = function_reference.pre_forward(module, *args, **kwargs) + output = function_reference.old_forward(*args, **kwargs) + return function_reference.post_forward(module, output) - def new_forward(module, *args, **kwargs): - args, kwargs = hook.pre_forward(module, *args, **kwargs) - output = old_forward(*args, **kwargs) - return hook.post_forward(module, output) + return new_forward + + forward = self._module_ref.forward + fn_ref = FunctionReference() + fn_ref.pre_forward = hook.pre_forward + fn_ref.post_forward = hook.post_forward + fn_ref.old_forward = forward + + if hasattr(hook, "new_forward"): + fn_ref.overwritten_forward = forward + fn_ref.old_forward = functools.update_wrapper( + functools.partial(hook.new_forward, self._module_ref), hook.new_forward + ) + + rewritten_forward = create_new_forward(fn_ref) self._module_ref.forward = functools.update_wrapper( - functools.partial(new_forward, self._module_ref), old_forward + functools.partial(rewritten_forward, self._module_ref), rewritten_forward ) + hook.fn_ref = fn_ref self.hooks[name] = hook self._hook_order.append(name) + self._fn_refs.append(fn_ref) def get_hook(self, name: str) -> Optional[ModelHook]: - if name not in self.hooks.keys(): - return None - return self.hooks[name] + return self.hooks.get(name, None) def remove_hook(self, name: str, recurse: bool = True) -> None: + num_hooks = len(self._hook_order) if name in self.hooks.keys(): hook = self.hooks[name] + index = self._hook_order.index(name) + fn_ref = self._fn_refs[index] + + old_forward = fn_ref.old_forward + if fn_ref.overwritten_forward is not None: + old_forward = fn_ref.overwritten_forward + + if index == num_hooks - 1: + self._module_ref.forward = old_forward + else: + self._fn_refs[index + 1].old_forward = old_forward + self._module_ref = hook.deinitalize_hook(self._module_ref) del self.hooks[name] - self._hook_order.remove(name) + self._hook_order.pop(index) + self._fn_refs.pop(index) if recurse: for module_name, module in self._module_ref.named_modules(): @@ -160,8 +187,10 @@ def remove_hook(self, name: str, recurse: bool = True) -> None: if hasattr(module, "_diffusers_hook"): module._diffusers_hook.remove_hook(name, recurse=False) + gc.collect() + def reset_stateful_hooks(self, recurse: bool = True) -> None: - for hook_name in self._hook_order: + for hook_name in reversed(self._hook_order): hook = self.hooks[hook_name] if hook._is_stateful: hook.reset_state(self._module_ref) @@ -180,9 +209,13 @@ def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry return module._diffusers_hook def __repr__(self) -> str: - hook_repr = "" + registry_repr = "" for i, hook_name in enumerate(self._hook_order): - hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})" + if self.hooks[hook_name].__class__.__repr__ is not object.__repr__: + hook_repr = self.hooks[hook_name].__repr__() + else: + hook_repr = self.hooks[hook_name].__class__.__name__ + registry_repr += f" ({i}) {hook_name} - {hook_repr}" if i < len(self._hook_order) - 1: - hook_repr += "\n" - return f"HookRegistry(\n{hook_repr}\n)" + registry_repr += "\n" + return f"HookRegistry(\n{registry_repr}\n)" From 8ba2bda231f6947318cd8eae5b9bdbc35431b7d0 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 23 Jan 2025 14:04:27 +0100 Subject: [PATCH 18/37] why use more memory when less memory do trick --- src/diffusers/hooks/group_offloading.py | 140 +++++++++++++++++++----- 1 file changed, 111 insertions(+), 29 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 3703ee70aa2c..db458570374e 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -13,7 +13,7 @@ # limitations under the License. from contextlib import nullcontext -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Set, Tuple import torch from accelerate.utils import send_to_device @@ -284,6 +284,8 @@ def apply_group_offloading( _apply_group_offloading_leaf_level( module, offload_device, onload_device, force_offload, non_blocking, stream=stream ) + else: + raise ValueError(f"Unsupported offload_type: {offload_type}") def _apply_group_offloading_block_level( @@ -325,12 +327,15 @@ def _apply_group_offloading_block_level( cpu_param_dict = {param: param.data for param in module.parameters()} # Create module groups for ModuleList and Sequential blocks + modules_with_group_offloading = set() unmatched_modules = [] matched_module_groups = [] for name, submodule in module.named_children(): if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): unmatched_modules.append((name, submodule)) + modules_with_group_offloading.add(name) continue + for i in range(0, len(submodule), num_blocks_per_group): current_modules = submodule[i : i + num_blocks_per_group] group = ModuleGroup( @@ -345,6 +350,8 @@ def _apply_group_offloading_block_level( onload_self=stream is None, ) matched_module_groups.append(group) + for j in range(i, i + len(current_modules)): + modules_with_group_offloading.add(f"{name}.{j}") # Apply group offloading hooks to the module groups for i, group in enumerate(matched_module_groups): @@ -359,15 +366,10 @@ def _apply_group_offloading_block_level( # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately # when the forward pass of this module is called. This is because the top-level module is not # part of any group (as doing so would lead to no VRAM savings). - parameters = [] - for name, parameter in module.named_parameters(recurse=False): - if not any(name.startswith(unmatched_name) for unmatched_name, _ in unmatched_modules): - parameters.append(parameter) - - buffers = [] - for name, buffer in module.named_buffers(recurse=False): - if not any(name.startswith(unmatched_name) for unmatched_name, _ in unmatched_modules): - buffers.append(buffer) + parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading) + buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading) + parameters = [param for _, param in parameters] + buffers = [buffer for _, buffer in buffers] # Create a group for the unmatched submodules of the top-level module so that they are on the correct # device when the forward pass is called. @@ -428,7 +430,8 @@ def _apply_group_offloading_leaf_level( cpu_param_dict = {param: param.data for param in module.parameters()} # Create module groups for leaf modules and apply group offloading hooks - for submodule in module.modules(): + modules_with_group_offloading = set() + for name, submodule in module.named_modules(): if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS): continue group = ModuleGroup( @@ -443,38 +446,65 @@ def _apply_group_offloading_leaf_level( onload_self=True, ) _apply_group_offloading_hook(submodule, group, True, None) + modules_with_group_offloading.add(name) # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass # of the module is called - parameters = [] - buffers = [] module_dict = dict(module.named_modules()) + parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading) + buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading) + + # Find closest module parent for each parameter and buffer, and attach group hooks + common_kwargs = { + "modules": [], + "offload_device": offload_device, + "onload_device": onload_device, + "non_blocking": non_blocking, + "stream": stream, + "cpu_param_dict": cpu_param_dict, + "onload_self": True, + } + + for name, param in parameters: + parent_name = _find_parent_module_in_module_dict(name, module_dict) + parent_module = module_dict[parent_name] + logger.info(f"TODO: REMOVETHIS Found parameter {name} with parent module {parent_name}") + assert getattr(parent_module, "_diffusers_hook", None) is None + group = ModuleGroup( + offload_leader=parent_module, + onload_leader=parent_module, + parameters=[param], + buffers=None, + **common_kwargs, + ) + _apply_group_offloading_hook(parent_module, group, True, None) - for name, parameter in module.named_parameters(): - atoms = name.split(".") - parent_name = ".".join(atoms[:-1]) - if parent_name in module_dict and isinstance(module_dict[parent_name], _SUPPORTED_PYTORCH_LAYERS): - continue - parameters.append(parameter) - - for name, buffer in module.named_buffers(): - atoms = name.split(".") - parent_name = ".".join(atoms[:-1]) - if parent_name in module_dict and isinstance(module_dict[parent_name], _SUPPORTED_PYTORCH_LAYERS): - continue - buffers.append(buffer) + for name, buffer in buffers: + parent_name = _find_parent_module_in_module_dict(name, module_dict) + parent_module = module_dict[parent_name] + logger.info(f"TODO: REMOVETHIS Found buffer {name} with parent module {parent_name}") + assert getattr(parent_module, "_diffusers_hook", None) is None + group = ModuleGroup( + offload_leader=parent_module, + onload_leader=parent_module, + parameters=None, + buffers=[buffer], + **common_kwargs, + ) + _apply_group_offloading_hook(parent_module, group, True, None) + # This is a dummy group that will handle lazy prefetching from the top-level module to the first leaf module unmatched_group = ModuleGroup( modules=[], offload_device=offload_device, onload_device=onload_device, offload_leader=module, onload_leader=module, - parameters=parameters, - buffers=buffers, + parameters=None, + buffers=None, non_blocking=False, stream=None, - cpu_param_dict=cpu_param_dict, + cpu_param_dict=None, onload_self=True, ) @@ -509,3 +539,55 @@ def _apply_lazy_group_offloading_hook( registry = HookRegistry.check_if_exists_or_initialize(module) registry.register_hook(hook, _GROUP_OFFLOADING) registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING) + + +def _gather_parameters_with_no_group_offloading_parent( + module: torch.nn.Module, modules_with_group_offloading: Set[str] +) -> List[torch.nn.Parameter]: + parameters = [] + for name, parameter in module.named_parameters(): + has_parent_with_group_offloading = False + atoms = name.split(".") + + while len(atoms) > 0: + parent_name = ".".join(atoms) + if parent_name in modules_with_group_offloading: + has_parent_with_group_offloading = True + break + atoms.pop() + + if not has_parent_with_group_offloading: + logger.info(f"TODO: REMOVETHIS Found parameter {name} with no parent module with group offloading") + parameters.append((name, parameter)) + return parameters + + +def _gather_buffers_with_no_group_offloading_parent( + module: torch.nn.Module, modules_with_group_offloading: Set[str] +) -> List[torch.Tensor]: + buffers = [] + for name, buffer in module.named_buffers(): + has_parent_with_group_offloading = False + atoms = name.split(".") + + while len(atoms) > 0: + parent_name = ".".join(atoms) + if parent_name in modules_with_group_offloading: + has_parent_with_group_offloading = True + break + atoms.pop() + + if not has_parent_with_group_offloading: + logger.info(f"TODO: REMOVETHIS Found buffer {name} with no parent module with group offloading") + buffers.append((name, buffer)) + return buffers + + +def _find_parent_module_in_module_dict(name: str, module_dict: Dict[str, torch.nn.Module]) -> str: + atoms = name.split(".") + while len(atoms) > 0: + parent_name = ".".join(atoms) + if parent_name in module_dict: + return parent_name + atoms.pop() + return "" From b2e838f5defffbd1fc0dce4ed749defc5dc58ce5 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 23 Jan 2025 15:29:35 +0100 Subject: [PATCH 19/37] why still use slightly more memory when less memory do trick --- src/diffusers/hooks/group_offloading.py | 53 ++++++++++++------------- 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index db458570374e..c445a0619909 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -455,41 +455,40 @@ def _apply_group_offloading_leaf_level( buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading) # Find closest module parent for each parameter and buffer, and attach group hooks - common_kwargs = { - "modules": [], - "offload_device": offload_device, - "onload_device": onload_device, - "non_blocking": non_blocking, - "stream": stream, - "cpu_param_dict": cpu_param_dict, - "onload_self": True, - } - + parent_to_parameters = {} for name, param in parameters: parent_name = _find_parent_module_in_module_dict(name, module_dict) - parent_module = module_dict[parent_name] - logger.info(f"TODO: REMOVETHIS Found parameter {name} with parent module {parent_name}") - assert getattr(parent_module, "_diffusers_hook", None) is None - group = ModuleGroup( - offload_leader=parent_module, - onload_leader=parent_module, - parameters=[param], - buffers=None, - **common_kwargs, - ) - _apply_group_offloading_hook(parent_module, group, True, None) + if parent_name in parent_to_parameters: + parent_to_parameters[parent_name].append(param) + else: + parent_to_parameters[parent_name] = [param] + parent_to_buffers = {} for name, buffer in buffers: parent_name = _find_parent_module_in_module_dict(name, module_dict) - parent_module = module_dict[parent_name] - logger.info(f"TODO: REMOVETHIS Found buffer {name} with parent module {parent_name}") + if parent_name in parent_to_buffers: + parent_to_buffers[parent_name].append(buffer) + else: + parent_to_buffers[parent_name] = [buffer] + + parent_names = set(parent_to_parameters.keys()) | set(parent_to_buffers.keys()) + for name in parent_names: + parameters = parent_to_parameters.get(name, []) + buffers = parent_to_buffers.get(name, []) + parent_module = module_dict[name] assert getattr(parent_module, "_diffusers_hook", None) is None group = ModuleGroup( + modules=[], + offload_device=offload_device, + onload_device=onload_device, offload_leader=parent_module, onload_leader=parent_module, - parameters=None, - buffers=[buffer], - **common_kwargs, + parameters=parameters, + buffers=buffers, + non_blocking=non_blocking, + stream=stream, + cpu_param_dict=cpu_param_dict, + onload_self=True, ) _apply_group_offloading_hook(parent_module, group, True, None) @@ -557,7 +556,6 @@ def _gather_parameters_with_no_group_offloading_parent( atoms.pop() if not has_parent_with_group_offloading: - logger.info(f"TODO: REMOVETHIS Found parameter {name} with no parent module with group offloading") parameters.append((name, parameter)) return parameters @@ -578,7 +576,6 @@ def _gather_buffers_with_no_group_offloading_parent( atoms.pop() if not has_parent_with_group_offloading: - logger.info(f"TODO: REMOVETHIS Found buffer {name} with no parent module with group offloading") buffers.append((name, buffer)) return buffers From 5ea3d8af549c83e68820a0062966b9090b241972 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 26 Jan 2025 16:48:39 +0100 Subject: [PATCH 20/37] optimise --- src/diffusers/hooks/group_offloading.py | 9 ++++++--- src/diffusers/hooks/hooks.py | 3 --- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index c445a0619909..04cd2c1ca53e 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -218,10 +218,10 @@ def post_forward(self, module, output): registries = [submodule._diffusers_hook for _, submodule in self.execution_order] for i in range(num_executed): - registries[i].remove_hook(_LAYER_EXECUTION_TRACKER) + registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False) # Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass - base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING) + base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False) # Apply lazy prefetching by setting required attributes group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries] @@ -536,7 +536,10 @@ def _apply_lazy_group_offloading_hook( hook = GroupOffloadingHook(group, offload_on_init, next_group) lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook() registry = HookRegistry.check_if_exists_or_initialize(module) - registry.register_hook(hook, _GROUP_OFFLOADING) + # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent + # is the current module. In such cases, we don't want to overwrite the existing group offloading hook. + if registry.get_hook(_GROUP_OFFLOADING) is None: + registry.register_hook(hook, _GROUP_OFFLOADING) registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index f3968e853476..df83b48efbdd 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -13,7 +13,6 @@ # limitations under the License. import functools -import gc from typing import Any, Dict, Optional, Tuple import torch @@ -187,8 +186,6 @@ def remove_hook(self, name: str, recurse: bool = True) -> None: if hasattr(module, "_diffusers_hook"): module._diffusers_hook.remove_hook(name, recurse=False) - gc.collect() - def reset_stateful_hooks(self, recurse: bool = True) -> None: for hook_name in reversed(self._hook_order): hook = self.hooks[hook_name] From db2fd3bab7509b1c25524b9280b4e0405b9e378b Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 26 Jan 2025 22:46:58 +0100 Subject: [PATCH 21/37] add model tests --- src/diffusers/hooks/group_offloading.py | 17 +++++--- src/diffusers/hooks/hooks.py | 5 ++- .../test_models_autoencoder_oobleck.py | 9 +++++ .../test_models_consistency_decoder_vae.py | 4 ++ tests/models/autoencoders/test_models_vq.py | 4 ++ tests/models/test_modeling_common.py | 40 +++++++++++++++++++ .../test_models_dit_transformer2d.py | 8 ++++ .../test_models_transformer_hunyuan_dit.py | 8 ++++ 8 files changed, 89 insertions(+), 6 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 04cd2c1ca53e..d892d41d1ad0 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -387,7 +387,8 @@ def _apply_group_offloading_block_level( cpu_param_dict=None, onload_self=True, ) - _apply_group_offloading_hook(module, unmatched_group, force_offload, matched_module_groups[0]) + next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None + _apply_group_offloading_hook(module, unmatched_group, force_offload, next_group) def _apply_group_offloading_leaf_level( @@ -522,9 +523,13 @@ def _apply_group_offloading_hook( offload_on_init: bool, next_group: Optional[ModuleGroup] = None, ) -> None: - hook = GroupOffloadingHook(group, offload_on_init, next_group) registry = HookRegistry.check_if_exists_or_initialize(module) - registry.register_hook(hook, _GROUP_OFFLOADING) + + # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent + # is the current module. In such cases, we don't want to overwrite the existing group offloading hook. + if registry.get_hook(_GROUP_OFFLOADING) is None: + hook = GroupOffloadingHook(group, offload_on_init, next_group) + registry.register_hook(hook, _GROUP_OFFLOADING) def _apply_lazy_group_offloading_hook( @@ -533,13 +538,15 @@ def _apply_lazy_group_offloading_hook( offload_on_init: bool, next_group: Optional[ModuleGroup] = None, ) -> None: - hook = GroupOffloadingHook(group, offload_on_init, next_group) - lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook() registry = HookRegistry.check_if_exists_or_initialize(module) + # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent # is the current module. In such cases, we don't want to overwrite the existing group offloading hook. if registry.get_hook(_GROUP_OFFLOADING) is None: + hook = GroupOffloadingHook(group, offload_on_init, next_group) registry.register_hook(hook, _GROUP_OFFLOADING) + + lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook() registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index df83b48efbdd..5d502dbc023f 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -120,7 +120,10 @@ def __init__(self, module_ref: torch.nn.Module) -> None: def register_hook(self, hook: ModelHook, name: str) -> None: if name in self.hooks.keys(): - logger.warning(f"Hook with name {name} already exists, replacing it.") + raise ValueError( + f"Hook with name {name} already exists in the registry. Please use a different name or " + f"first remove the existing hook and then add a new one." + ) self._module_ref = hook.initialize_hook(self._module_ref) diff --git a/tests/models/autoencoders/test_models_autoencoder_oobleck.py b/tests/models/autoencoders/test_models_autoencoder_oobleck.py index 1f922a9842ee..5e137451914e 100644 --- a/tests/models/autoencoders/test_models_autoencoder_oobleck.py +++ b/tests/models/autoencoders/test_models_autoencoder_oobleck.py @@ -132,6 +132,15 @@ def test_layerwise_casting_inference(self): def test_layerwise_casting_memory(self): pass + @unittest.skip( + "The convolution layers of AutoencoderOobleck are wrapped with torch.nn.utils.weight_norm. This causes the hook's pre_forward to not " + "cast the module weights to the expected device (as required by forward pass). As a result, forward pass errors out. To fix:\n" + "1. Make sure `nn::Module::to(device)` works with `torch.nn.utils.weight_norm` wrapped convolution layer.\n" + "2. Unskip this test." + ) + def test_group_offloading(self): + pass + @slow class AutoencoderOobleckIntegrationTests(unittest.TestCase): diff --git a/tests/models/autoencoders/test_models_consistency_decoder_vae.py b/tests/models/autoencoders/test_models_consistency_decoder_vae.py index 77977a78d83b..4f1af127cd44 100644 --- a/tests/models/autoencoders/test_models_consistency_decoder_vae.py +++ b/tests/models/autoencoders/test_models_consistency_decoder_vae.py @@ -155,6 +155,10 @@ def test_enable_disable_slicing(self): "Without slicing outputs should match with the outputs when slicing is manually disabled.", ) + @unittest.skip("Not quite sure why this test fails and unable to debug.") + def test_group_offloading(self): + pass + @slow class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase): diff --git a/tests/models/autoencoders/test_models_vq.py b/tests/models/autoencoders/test_models_vq.py index 77abe139d785..33822ed7a882 100644 --- a/tests/models/autoencoders/test_models_vq.py +++ b/tests/models/autoencoders/test_models_vq.py @@ -116,3 +116,7 @@ def test_loss_pretrained(self): expected_output = torch.tensor([0.1936]) # fmt: on self.assertTrue(torch.allclose(output, expected_output, atol=1e-3)) + + @unittest.skip("Group offloading for torch::nn::Embedding layers is not yet supported.") + def test_group_offloading(self): + pass diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 05050e05bb19..04a4d9d28e89 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -37,6 +37,7 @@ from parameterized import parameterized from requests.exceptions import HTTPError +from diffusers.hooks import apply_group_offloading from diffusers.models import UNet2DConditionModel from diffusers.models.attention_processor import ( AttnProcessor, @@ -1433,6 +1434,45 @@ def get_memory_usage(storage_dtype, compute_dtype): or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE ) + @require_torch_gpu + def test_group_offloading(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + torch.manual_seed(0) + + def run_forward(model): + model.eval() + with torch.no_grad(): + return model(**inputs_dict)[0] + + model = self.model_class(**init_dict) + model.to(torch_device) + output_without_group_offloading = run_forward(model) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + apply_group_offloading(model, offload_type="block_level", num_blocks_per_group=1) + output_with_group_offloading1 = run_forward(model) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + apply_group_offloading(model, offload_type="block_level", num_blocks_per_group=1, non_blocking=True) + output_with_group_offloading2 = run_forward(model) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + apply_group_offloading(model, offload_type="leaf_level") + output_with_group_offloading3 = run_forward(model) + + torch.manual_seed(0) + model = self.model_class(**init_dict) + apply_group_offloading(model, offload_type="leaf_level", use_stream=True) + output_with_group_offloading4 = run_forward(model) + + self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5)) + self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5)) + self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5)) + self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5)) + @is_staging_test class ModelPushToHubTester(unittest.TestCase): diff --git a/tests/models/transformers/test_models_dit_transformer2d.py b/tests/models/transformers/test_models_dit_transformer2d.py index 5f4a2f587e92..76db63647404 100644 --- a/tests/models/transformers/test_models_dit_transformer2d.py +++ b/tests/models/transformers/test_models_dit_transformer2d.py @@ -100,3 +100,11 @@ def test_correct_class_remapping_from_pretrained_config(self): def test_correct_class_remapping(self): model = Transformer2DModel.from_pretrained("facebook/DiT-XL-2-256", subfolder="transformer") assert isinstance(model, DiTTransformer2DModel) + + @unittest.skip( + "This model uses a direct call to self.transformer_blocks[0].norm1.emb. This causes attached hooks to not be invoked " + "when block offloading is enabled. In order for it to work, the model should correctly first invoke the forward pass " + "the transformer blocks, so that weights can be onloaded, instead of directly invoking a submodule of the block." + ) + def test_group_offloading(self): + pass diff --git a/tests/models/transformers/test_models_transformer_hunyuan_dit.py b/tests/models/transformers/test_models_transformer_hunyuan_dit.py index ea05abed38d9..1889d2aaaf4a 100644 --- a/tests/models/transformers/test_models_transformer_hunyuan_dit.py +++ b/tests/models/transformers/test_models_transformer_hunyuan_dit.py @@ -111,3 +111,11 @@ def test_set_xformers_attn_processor_for_determinism(self): @unittest.skip("HunyuanDIT use a custom processor HunyuanAttnProcessor2_0") def test_set_attn_processor_for_determinism(self): pass + + @unittest.skip( + "This model uses a direct call to F.multi_head_attention_forward instead of using a torch.nn.Module layer. This " + "usage is not yet supported with group offloading, because the call directly operates on the weights of the module. " + "We attach hooks correctly, but the onloading does not occur because the torch::nn::Module::forward is never invoked." + ) + def test_group_offloading(self): + pass From a0160e110038bfdde676797d8bb6e04635b5b2b7 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 26 Jan 2025 23:49:52 +0100 Subject: [PATCH 22/37] add pipeline tests --- tests/models/test_modeling_common.py | 7 ++ tests/pipelines/allegro/test_allegro.py | 1 + tests/pipelines/amused/test_amused.py | 1 + .../pipelines/animatediff/test_animatediff.py | 1 + .../aura_flow/test_pipeline_aura_flow.py | 1 + tests/pipelines/cogvideo/test_cogvideox.py | 1 + .../cogvideo/test_cogvideox_fun_control.py | 1 + tests/pipelines/cogview3/test_cogview3plus.py | 1 + tests/pipelines/consisid/test_consisid.py | 1 + tests/pipelines/controlnet/test_controlnet.py | 1 + .../controlnet/test_controlnet_sdxl.py | 1 + .../controlnet_flux/test_controlnet_flux.py | 1 + .../controlnet_sd3/test_controlnet_sd3.py | 1 + .../controlnet_xs/test_controlnetxs.py | 1 + .../controlnet_xs/test_controlnetxs_sdxl.py | 1 + tests/pipelines/flux/test_pipeline_flux.py | 1 + .../flux/test_pipeline_flux_control.py | 1 + .../pipelines/flux/test_pipeline_flux_fill.py | 1 + .../hunyuan_video/test_hunyuan_video.py | 1 + tests/pipelines/latte/test_latte.py | 1 + tests/pipelines/ltx/test_ltx.py | 1 + tests/pipelines/lumina/test_lumina_nextdit.py | 1 + tests/pipelines/mochi/test_mochi.py | 1 + tests/pipelines/pia/test_pia.py | 1 + tests/pipelines/pixart_alpha/test_pixart.py | 1 + tests/pipelines/pixart_sigma/test_pixart.py | 1 + tests/pipelines/sana/test_sana.py | 1 + .../stable_diffusion/test_stable_diffusion.py | 1 + .../test_stable_diffusion.py | 1 + .../test_pipeline_stable_diffusion_3.py | 1 + .../test_stable_diffusion_xl.py | 1 + tests/pipelines/test_pipelines_common.py | 67 +++++++++++++++++++ 32 files changed, 104 insertions(+) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 04a4d9d28e89..9ea8348e076e 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1440,6 +1440,13 @@ def test_group_offloading(self): torch.manual_seed(0) def run_forward(model): + self.assertTrue( + all( + module._diffusers_hook.get_hook("group_offloading") is not None + for module in model.modules() + if hasattr(module, "_diffusers_hook") + ) + ) model.eval() with torch.no_grad(): return model(**inputs_dict)[0] diff --git a/tests/pipelines/allegro/test_allegro.py b/tests/pipelines/allegro/test_allegro.py index 322be373641a..cbc6355ac84f 100644 --- a/tests/pipelines/allegro/test_allegro.py +++ b/tests/pipelines/allegro/test_allegro.py @@ -58,6 +58,7 @@ class AllegroPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) test_xformers_attention = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/amused/test_amused.py b/tests/pipelines/amused/test_amused.py index 2dfc36a6ce45..a0fbc5df1c28 100644 --- a/tests/pipelines/amused/test_amused.py +++ b/tests/pipelines/amused/test_amused.py @@ -39,6 +39,7 @@ class AmusedPipelineFastTests(PipelineTesterMixin, unittest.TestCase): params = TEXT_TO_IMAGE_PARAMS | {"encoder_hidden_states", "negative_encoder_hidden_states"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 1b3115c8eb1d..4913a46b8d4f 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -61,6 +61,7 @@ class AnimateDiffPipelineFastTests( ] ) test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): cross_attention_dim = 8 diff --git a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py index bee905f9ae13..f0b67afcc052 100644 --- a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py +++ b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py @@ -31,6 +31,7 @@ class AuraFlowPipelineFastTests(unittest.TestCase, PipelineTesterMixin): ) batch_params = frozenset(["prompt", "negative_prompt"]) test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py index 9ce3d8e9de31..99ce3a3a4f70 100644 --- a/tests/pipelines/cogvideo/test_cogvideox.py +++ b/tests/pipelines/cogvideo/test_cogvideox.py @@ -59,6 +59,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) test_xformers_attention = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py index c936bad4c3d5..2e962bd247b9 100644 --- a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py +++ b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py @@ -56,6 +56,7 @@ class CogVideoXFunControlPipelineFastTests(PipelineTesterMixin, unittest.TestCas ) test_xformers_attention = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/cogview3/test_cogview3plus.py b/tests/pipelines/cogview3/test_cogview3plus.py index 102a5c66e624..4619de81d535 100644 --- a/tests/pipelines/cogview3/test_cogview3plus.py +++ b/tests/pipelines/cogview3/test_cogview3plus.py @@ -57,6 +57,7 @@ class CogView3PlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) test_xformers_attention = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/consisid/test_consisid.py b/tests/pipelines/consisid/test_consisid.py index f949cfb2d36d..a39c17bb4f79 100644 --- a/tests/pipelines/consisid/test_consisid.py +++ b/tests/pipelines/consisid/test_consisid.py @@ -59,6 +59,7 @@ class ConsisIDPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) test_xformers_attention = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py index e0fc00171031..e2c0c60ddfa4 100644 --- a/tests/pipelines/controlnet/test_controlnet.py +++ b/tests/pipelines/controlnet/test_controlnet.py @@ -127,6 +127,7 @@ class ControlNetPipelineFastTests( image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index e75fe8903134..dda6339427f8 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -76,6 +76,7 @@ class StableDiffusionXLControlNetPipelineFastTests( image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux.py b/tests/pipelines/controlnet_flux/test_controlnet_flux.py index 8b9852dbec6e..cce14342699c 100644 --- a/tests/pipelines/controlnet_flux/test_controlnet_flux.py +++ b/tests/pipelines/controlnet_flux/test_controlnet_flux.py @@ -51,6 +51,7 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin): params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) batch_params = frozenset(["prompt"]) test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py index e1894d555c3c..04daca27c3dd 100644 --- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py +++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py @@ -60,6 +60,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes ) batch_params = frozenset(["prompt", "negative_prompt"]) test_layerwise_casting = True + test_group_offloading = True def get_dummy_components( self, num_controlnet_layers: int = 3, qk_norm: Optional[str] = "rms_norm", use_dual_attention=False diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py index 4c184db99630..1da5b52bd050 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py @@ -140,6 +140,7 @@ class ControlNetXSPipelineFastTests( test_attention_slicing = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py index 7537efe0bbf9..644bb669d8e8 100644 --- a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py +++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py @@ -79,6 +79,7 @@ class StableDiffusionXLControlNetXSPipelineFastTests( test_attention_slicing = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index a3bc1658de74..0735fc5fbef0 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -32,6 +32,7 @@ class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapte # there is no xformers processor for Flux test_xformers_attention = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/flux/test_pipeline_flux_control.py b/tests/pipelines/flux/test_pipeline_flux_control.py index 7fdb19327213..5bb7cdec034c 100644 --- a/tests/pipelines/flux/test_pipeline_flux_control.py +++ b/tests/pipelines/flux/test_pipeline_flux_control.py @@ -23,6 +23,7 @@ class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin): # there is no xformers processor for Flux test_xformers_attention = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/flux/test_pipeline_flux_fill.py b/tests/pipelines/flux/test_pipeline_flux_fill.py index 620ecb8a831f..1d488db71ced 100644 --- a/tests/pipelines/flux/test_pipeline_flux_fill.py +++ b/tests/pipelines/flux/test_pipeline_flux_fill.py @@ -24,6 +24,7 @@ class FluxFillPipelineFastTests(unittest.TestCase, PipelineTesterMixin): batch_params = frozenset(["prompt"]) test_xformers_attention = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py index ce03381f90d2..ae7347caff6a 100644 --- a/tests/pipelines/hunyuan_video/test_hunyuan_video.py +++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py @@ -54,6 +54,7 @@ class HunyuanVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): # there is no xformers processor for Flux test_xformers_attention = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/latte/test_latte.py b/tests/pipelines/latte/test_latte.py index 2d5bcba8237a..a28eda1f40d7 100644 --- a/tests/pipelines/latte/test_latte.py +++ b/tests/pipelines/latte/test_latte.py @@ -53,6 +53,7 @@ class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase): required_optional_params = PipelineTesterMixin.required_optional_params test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/ltx/test_ltx.py b/tests/pipelines/ltx/test_ltx.py index 64b366ea8ad6..4f72729fc9ce 100644 --- a/tests/pipelines/ltx/test_ltx.py +++ b/tests/pipelines/ltx/test_ltx.py @@ -47,6 +47,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) test_xformers_attention = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/lumina/test_lumina_nextdit.py b/tests/pipelines/lumina/test_lumina_nextdit.py index 7c1923313b23..18dcdef98d7d 100644 --- a/tests/pipelines/lumina/test_lumina_nextdit.py +++ b/tests/pipelines/lumina/test_lumina_nextdit.py @@ -33,6 +33,7 @@ class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterM supports_dduf = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/mochi/test_mochi.py b/tests/pipelines/mochi/test_mochi.py index b7bb844ff311..ed41e82aca9f 100644 --- a/tests/pipelines/mochi/test_mochi.py +++ b/tests/pipelines/mochi/test_mochi.py @@ -56,6 +56,7 @@ class MochiPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) test_xformers_attention = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/pia/test_pia.py b/tests/pipelines/pia/test_pia.py index 747be38d495c..ead6c2b208de 100644 --- a/tests/pipelines/pia/test_pia.py +++ b/tests/pipelines/pia/test_pia.py @@ -56,6 +56,7 @@ class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, PipelineFr ] ) test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): cross_attention_dim = 8 diff --git a/tests/pipelines/pixart_alpha/test_pixart.py b/tests/pipelines/pixart_alpha/test_pixart.py index 7df6656f6f87..ae0f9b50f74e 100644 --- a/tests/pipelines/pixart_alpha/test_pixart.py +++ b/tests/pipelines/pixart_alpha/test_pixart.py @@ -51,6 +51,7 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): required_optional_params = PipelineTesterMixin.required_optional_params test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/pixart_sigma/test_pixart.py b/tests/pipelines/pixart_sigma/test_pixart.py index 6e265b9d5eb8..9bfeb691d770 100644 --- a/tests/pipelines/pixart_sigma/test_pixart.py +++ b/tests/pipelines/pixart_sigma/test_pixart.py @@ -56,6 +56,7 @@ class PixArtSigmaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): required_optional_params = PipelineTesterMixin.required_optional_params test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/sana/test_sana.py b/tests/pipelines/sana/test_sana.py index f70f9d91f19c..34df808d3320 100644 --- a/tests/pipelines/sana/test_sana.py +++ b/tests/pipelines/sana/test_sana.py @@ -53,6 +53,7 @@ class SanaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) test_xformers_attention = False test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 1e700bed03f8..d60092c4e5cb 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -124,6 +124,7 @@ class StableDiffusionPipelineFastTests( image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self, time_cond_proj_dim=None): cross_attention_dim = 8 diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py index 10b8a1818a29..a7375d37eccd 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py @@ -76,6 +76,7 @@ class StableDiffusion2PipelineFastTests( image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py index df37090eeba2..24d03a035066 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py @@ -36,6 +36,7 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin): ) batch_params = frozenset(["prompt", "negative_prompt"]) test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index f1422022a7aa..dfd1c9c37271 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -76,6 +76,7 @@ class StableDiffusionXLPipelineFastTests( image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"}) test_layerwise_casting = True + test_group_offloading = True def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 139778994b87..8f5a6ea17651 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -28,6 +28,7 @@ StableDiffusionXLPipeline, UNet2DConditionModel, ) +from diffusers.hooks import apply_group_offloading from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin from diffusers.models.attention_processor import AttnProcessor @@ -45,6 +46,7 @@ require_accelerator, require_hf_hub_version_greater, require_torch, + require_torch_gpu, require_transformers_version_greater, skip_mps, torch_device, @@ -988,6 +990,7 @@ class PipelineTesterMixin: test_xformers_attention = True test_layerwise_casting = False + test_group_offloading = False supports_dduf = True def get_generator(self, seed): @@ -2042,6 +2045,70 @@ def test_layerwise_casting_inference(self): inputs = self.get_dummy_inputs(torch_device) _ = pipe(**inputs)[0] + @require_torch_gpu + def test_group_offloading_inference(self): + if not self.test_group_offloading: + return + + def create_pipe(): + torch.manual_seed(0) + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + return pipe + + def enable_group_offloading_on_component(pipe, group_offloading_kwargs): + # We intentionally don't test VAE's here. This is because some tests enable tiling on the VAE. If + # tiling is enabled and a forward pass is run, when cuda streams are used, the execution order of + # the layers is not traced correctly. This causes errors. For apply group offloading to VAE, a + # warmup forward pass (even with dummy small inputs) is recommended. + for component_name in [ + "text_encoder", + "text_encoder_2", + "text_encoder_3", + "transformer", + "unet", + "controlnet", + ]: + if not hasattr(pipe, component_name): + continue + component = getattr(pipe, component_name) + apply_group_offloading(component, **group_offloading_kwargs) + self.assertTrue( + all( + module._diffusers_hook.get_hook("group_offloading") is not None + for module in component.modules() + if hasattr(module, "_diffusers_hook") + ) + ) + for component_name in ["vae", "vqvae"]: + if hasattr(pipe, component_name): + getattr(pipe, component_name).to(torch_device) + + def run_forward(pipe): + torch.manual_seed(0) + inputs = self.get_dummy_inputs(torch_device) + return pipe(**inputs)[0] + + pipe = create_pipe().to(torch_device) + output_without_group_offloading = run_forward(pipe) + + pipe = create_pipe() + enable_group_offloading_on_component(pipe, {"offload_type": "block_level", "num_blocks_per_group": 1}) + output_with_group_offloading1 = run_forward(pipe) + + pipe = create_pipe() + enable_group_offloading_on_component(pipe, {"offload_type": "leaf_level"}) + output_with_group_offloading2 = run_forward(pipe) + + if torch.is_tensor(output_without_group_offloading): + output_without_group_offloading = output_without_group_offloading.detach().cpu().numpy() + output_with_group_offloading1 = output_with_group_offloading1.detach().cpu().numpy() + output_with_group_offloading2 = output_with_group_offloading2.detach().cpu().numpy() + + self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-4)) + self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-4)) + @is_staging_test class PipelinePushToHubTester(unittest.TestCase): From aaa9a53447a54b198c6d07b00bae3a17679f0800 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 26 Jan 2025 23:50:31 +0100 Subject: [PATCH 23/37] update docs --- docs/source/en/api/utilities.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/en/api/utilities.md b/docs/source/en/api/utilities.md index b0b78928fb4b..b653cdafbb28 100644 --- a/docs/source/en/api/utilities.md +++ b/docs/source/en/api/utilities.md @@ -45,3 +45,7 @@ Utility and helper functions for working with 🤗 Diffusers. ## apply_layerwise_casting [[autodoc]] hooks.layerwise_casting.apply_layerwise_casting + +## apply_group_offloading + +[[autodoc]] hooks.group_offloading.apply_group_offloading From edf81035505a6e5eea2e1261bdd798d705de37e9 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 27 Jan 2025 00:19:25 +0100 Subject: [PATCH 24/37] add layernorm and groupnorm --- src/diffusers/hooks/group_offloading.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index d892d41d1ad0..b9d6318b730c 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -16,12 +16,15 @@ from typing import Dict, List, Optional, Set, Tuple import torch -from accelerate.utils import send_to_device -from ..utils import get_logger +from ..utils import get_logger, is_accelerate_available from .hooks import HookRegistry, ModelHook +if is_accelerate_available(): + from accelerate.utils import send_to_device + + logger = get_logger(__name__) # pylint: disable=invalid-name @@ -34,6 +37,7 @@ torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, torch.nn.Linear, + torch.nn.LayerNorm, torch.nn.GroupNorm, ) # fmt: on From 24f92739b463c2105d692766cef19b4e04956453 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 4 Feb 2025 06:27:29 +0100 Subject: [PATCH 25/37] address review comments --- src/diffusers/hooks/group_offloading.py | 2 +- .../models/autoencoders/autoencoder_oobleck.py | 1 + .../models/autoencoders/consistency_decoder_vae.py | 2 ++ src/diffusers/models/autoencoders/vq_model.py | 1 + src/diffusers/models/modeling_utils.py | 1 + .../models/transformers/dit_transformer_2d.py | 1 + .../models/transformers/hunyuan_transformer_2d.py | 1 + src/diffusers/pipelines/pipeline_utils.py | 14 ++++---------- .../test_models_autoencoder_oobleck.py | 9 --------- .../test_models_consistency_decoder_vae.py | 4 ---- tests/models/autoencoders/test_models_vq.py | 4 ---- tests/models/test_modeling_common.py | 3 +++ .../transformers/test_models_dit_transformer2d.py | 8 -------- .../test_models_transformer_hunyuan_dit.py | 8 -------- tests/pipelines/test_pipelines_common.py | 2 ++ 15 files changed, 17 insertions(+), 44 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index b9d6318b730c..1c6353ea4ff2 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -343,7 +343,7 @@ def _apply_group_offloading_block_level( for i in range(0, len(submodule), num_blocks_per_group): current_modules = submodule[i : i + num_blocks_per_group] group = ModuleGroup( - modules=submodule[i : i + num_blocks_per_group], + modules=current_modules, offload_device=offload_device, onload_device=onload_device, offload_leader=current_modules[-1], diff --git a/src/diffusers/models/autoencoders/autoencoder_oobleck.py b/src/diffusers/models/autoencoders/autoencoder_oobleck.py index e8e372a709d7..a8c2a2fd3840 100644 --- a/src/diffusers/models/autoencoders/autoencoder_oobleck.py +++ b/src/diffusers/models/autoencoders/autoencoder_oobleck.py @@ -317,6 +317,7 @@ class AutoencoderOobleck(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = False + _supports_group_offloading = False @register_to_config def __init__( diff --git a/src/diffusers/models/autoencoders/consistency_decoder_vae.py b/src/diffusers/models/autoencoders/consistency_decoder_vae.py index 4759b9141242..a0b3309dc522 100644 --- a/src/diffusers/models/autoencoders/consistency_decoder_vae.py +++ b/src/diffusers/models/autoencoders/consistency_decoder_vae.py @@ -68,6 +68,8 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin): ``` """ + _supports_group_offloading = False + @register_to_config def __init__( self, diff --git a/src/diffusers/models/autoencoders/vq_model.py b/src/diffusers/models/autoencoders/vq_model.py index e754e134b35f..84215389bf6a 100644 --- a/src/diffusers/models/autoencoders/vq_model.py +++ b/src/diffusers/models/autoencoders/vq_model.py @@ -72,6 +72,7 @@ class VQModel(ModelMixin, ConfigMixin): """ _skip_layerwise_casting_patterns = ["quantize"] + _supports_group_offloading = False @register_to_config def __init__( diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 9d7a37cd7bf8..9f92d0f91f71 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -174,6 +174,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): _no_split_modules = None _keep_in_fp32_modules = None _skip_layerwise_casting_patterns = None + _supports_group_offloading = True def __init__(self): super().__init__() diff --git a/src/diffusers/models/transformers/dit_transformer_2d.py b/src/diffusers/models/transformers/dit_transformer_2d.py index 6e83f49db71c..cdc0738050e4 100644 --- a/src/diffusers/models/transformers/dit_transformer_2d.py +++ b/src/diffusers/models/transformers/dit_transformer_2d.py @@ -66,6 +66,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin): _skip_layerwise_casting_patterns = ["pos_embed", "norm"] _supports_gradient_checkpointing = True + _supports_group_offloading = False @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py index 13aa7d076d03..5608a0f605a6 100644 --- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py +++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py @@ -245,6 +245,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin): """ _skip_layerwise_casting_patterns = ["pos_embed", "norm", "pooler"] + _supports_group_offloading = False @register_to_config def __init__( diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index f776d3c83c24..a747d2955b5d 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1020,25 +1020,19 @@ def _execution_device(self): [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from Accelerate's module hooks. """ - diffusers_hook_device = None + # When apply group offloading at the leaf_level, we're in the same situation as accelerate's sequential + # offloading. We need to return the onload device of the group offloading hooks so that the intermediates + # required for computation (latents, prompt embeddings, etc.) can be created on the correct device. for name, model in self.components.items(): if not isinstance(model, torch.nn.Module): continue - for submodule in model.modules(): if not hasattr(submodule, "_diffusers_hook"): continue registry = submodule._diffusers_hook hook = registry.get_hook("group_offloading") if hook is not None: - diffusers_hook_device = hook.group.onload_device - break - - if diffusers_hook_device is not None: - break - - if diffusers_hook_device is not None: - return diffusers_hook_device + return hook.group.onload_device for name, model in self.components.items(): if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload: diff --git a/tests/models/autoencoders/test_models_autoencoder_oobleck.py b/tests/models/autoencoders/test_models_autoencoder_oobleck.py index 5e137451914e..1f922a9842ee 100644 --- a/tests/models/autoencoders/test_models_autoencoder_oobleck.py +++ b/tests/models/autoencoders/test_models_autoencoder_oobleck.py @@ -132,15 +132,6 @@ def test_layerwise_casting_inference(self): def test_layerwise_casting_memory(self): pass - @unittest.skip( - "The convolution layers of AutoencoderOobleck are wrapped with torch.nn.utils.weight_norm. This causes the hook's pre_forward to not " - "cast the module weights to the expected device (as required by forward pass). As a result, forward pass errors out. To fix:\n" - "1. Make sure `nn::Module::to(device)` works with `torch.nn.utils.weight_norm` wrapped convolution layer.\n" - "2. Unskip this test." - ) - def test_group_offloading(self): - pass - @slow class AutoencoderOobleckIntegrationTests(unittest.TestCase): diff --git a/tests/models/autoencoders/test_models_consistency_decoder_vae.py b/tests/models/autoencoders/test_models_consistency_decoder_vae.py index 4f1af127cd44..77977a78d83b 100644 --- a/tests/models/autoencoders/test_models_consistency_decoder_vae.py +++ b/tests/models/autoencoders/test_models_consistency_decoder_vae.py @@ -155,10 +155,6 @@ def test_enable_disable_slicing(self): "Without slicing outputs should match with the outputs when slicing is manually disabled.", ) - @unittest.skip("Not quite sure why this test fails and unable to debug.") - def test_group_offloading(self): - pass - @slow class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase): diff --git a/tests/models/autoencoders/test_models_vq.py b/tests/models/autoencoders/test_models_vq.py index 33822ed7a882..77abe139d785 100644 --- a/tests/models/autoencoders/test_models_vq.py +++ b/tests/models/autoencoders/test_models_vq.py @@ -116,7 +116,3 @@ def test_loss_pretrained(self): expected_output = torch.tensor([0.1936]) # fmt: on self.assertTrue(torch.allclose(output, expected_output, atol=1e-3)) - - @unittest.skip("Group offloading for torch::nn::Embedding layers is not yet supported.") - def test_group_offloading(self): - pass diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 4172be4dca56..e663198eb168 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1447,6 +1447,9 @@ def run_forward(model): return model(**inputs_dict)[0] model = self.model_class(**init_dict) + if not getattr(model, "_supports_group_offloading", True): + return + model.to(torch_device) output_without_group_offloading = run_forward(model) diff --git a/tests/models/transformers/test_models_dit_transformer2d.py b/tests/models/transformers/test_models_dit_transformer2d.py index 76db63647404..5f4a2f587e92 100644 --- a/tests/models/transformers/test_models_dit_transformer2d.py +++ b/tests/models/transformers/test_models_dit_transformer2d.py @@ -100,11 +100,3 @@ def test_correct_class_remapping_from_pretrained_config(self): def test_correct_class_remapping(self): model = Transformer2DModel.from_pretrained("facebook/DiT-XL-2-256", subfolder="transformer") assert isinstance(model, DiTTransformer2DModel) - - @unittest.skip( - "This model uses a direct call to self.transformer_blocks[0].norm1.emb. This causes attached hooks to not be invoked " - "when block offloading is enabled. In order for it to work, the model should correctly first invoke the forward pass " - "the transformer blocks, so that weights can be onloaded, instead of directly invoking a submodule of the block." - ) - def test_group_offloading(self): - pass diff --git a/tests/models/transformers/test_models_transformer_hunyuan_dit.py b/tests/models/transformers/test_models_transformer_hunyuan_dit.py index 1889d2aaaf4a..ea05abed38d9 100644 --- a/tests/models/transformers/test_models_transformer_hunyuan_dit.py +++ b/tests/models/transformers/test_models_transformer_hunyuan_dit.py @@ -111,11 +111,3 @@ def test_set_xformers_attn_processor_for_determinism(self): @unittest.skip("HunyuanDIT use a custom processor HunyuanAttnProcessor2_0") def test_set_attn_processor_for_determinism(self): pass - - @unittest.skip( - "This model uses a direct call to F.multi_head_attention_forward instead of using a torch.nn.Module layer. This " - "usage is not yet supported with group offloading, because the call directly operates on the weights of the module. " - "We attach hooks correctly, but the onloading does not occur because the torch::nn::Module::forward is never invoked." - ) - def test_group_offloading(self): - pass diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 9eded55866ca..50f19c37ad2a 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -2075,6 +2075,8 @@ def enable_group_offloading_on_component(pipe, group_offloading_kwargs): if not hasattr(pipe, component_name): continue component = getattr(pipe, component_name) + if not getattr(component, "_supports_group_offloading", True): + continue apply_group_offloading(component, **group_offloading_kwargs) self.assertTrue( all( From 8f10d05e5b0225b37d10664d7111790ebb86eeef Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 4 Feb 2025 08:50:19 +0100 Subject: [PATCH 26/37] improve tests; add docs --- docs/source/en/optimization/memory.md | 40 ++++++ src/diffusers/hooks/group_offloading.py | 116 ++++++++++++----- src/diffusers/models/modeling_utils.py | 51 +++++++- tests/hooks/test_group_offloading.py | 159 +++++++++++++++++++++++ tests/models/test_modeling_common.py | 17 +-- tests/pipelines/test_pipelines_common.py | 9 +- 6 files changed, 347 insertions(+), 45 deletions(-) create mode 100644 tests/hooks/test_group_offloading.py diff --git a/docs/source/en/optimization/memory.md b/docs/source/en/optimization/memory.md index 4cdc60401914..99fd981fdc79 100644 --- a/docs/source/en/optimization/memory.md +++ b/docs/source/en/optimization/memory.md @@ -158,6 +158,46 @@ In order to properly offload models after they're called, it is required to run +## Group offloading + +Group offloading is a middle ground between the two above methods. It works by offloading groups of internal layers (either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method is more memory-efficient than model-level offloading. It is also faster than sequential-level offloading, as the number of device synchronizations is reduced. + +Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability to overlap data transfer and computation to reduce the overall execution time. This is enabled using layer prefetching with CUDA streams, i.e., the layer that is to be executed next starts onloading to the accelerator device while the current layer is being executed - this increases the memory requirements slightly. Note that this implementation also supports leaf-level offloading but can be made much faster when using streams. + +To enable group offloading, either call the [`~ModelMixin.enable_group_offloading`] method on the model or pass use [`~hooks.group_offloading.apply_group_offloading`]: + +```python +import torch +from diffusers import CogVideoXPipeline +from diffusers.hooks import apply_group_offloading +from diffusers.utils import export_to_video + +# Load the pipeline +onload_device = torch.device("cuda") +offload_device = torch.device("cpu") +pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) + +# We can utilize the enable_group_offloading method for Diffusers model implementations +pipe.transformer.enable_group_offloading(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True) + +# For any other model implementations, the apply_group_offloading function can be used +apply_group_offloading(pipe.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2) +apply_group_offloading(pipe.vae, onload_device=onload_device, offload_type="leaf_level") + +prompt = ( + "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " + "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " + "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " + "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " + "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " + "atmosphere of this unique musical performance." +) +video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] +# This utilized about 14.79 GB. It can be further reduced by using tiling and using leaf_level offloading throughout the pipeline. +print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB") +export_to_video(video, "output.mp4", fps=8) +``` + ## FP8 layerwise weight-casting PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting. diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 1c6353ea4ff2..3aa91c2944c1 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -37,7 +37,8 @@ torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, torch.nn.Linear, - torch.nn.LayerNorm, torch.nn.GroupNorm, + # TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX + # because of double invocation of the same norm layer in CogVideoXLayerNorm ) # fmt: on @@ -120,15 +121,13 @@ class GroupOffloadingHook(ModelHook): def __init__( self, group: ModuleGroup, - offload_on_init: bool = True, next_group: Optional[ModuleGroup] = None, ) -> None: self.group = group - self.offload_on_init = offload_on_init self.next_group = next_group def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: - if self.offload_on_init and self.group.offload_leader == module: + if self.group.offload_leader == module: self.group.offload_() return module @@ -262,14 +261,78 @@ def pre_forward(self, module, *args, **kwargs): def apply_group_offloading( module: torch.nn.Module, + onload_device: torch.device, + offload_device: torch.device = torch.device("cpu"), offload_type: str = "block_level", num_blocks_per_group: Optional[int] = None, - offload_device: torch.device = torch.device("cpu"), - onload_device: torch.device = torch.device("cuda"), - force_offload: bool = True, non_blocking: bool = False, use_stream: bool = False, ) -> None: + r""" + Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and + where it is beneficial, we need to first provide some context on how other supported offloading methods work. + + Typically, offloading is done at two levels: + - Module-level: In Diffusers, this can be enabled using the `ModelMixin::enable_model_cpu_offload()` method. It + works by offloading each component of a pipeline to the CPU for storage, and onloading to the accelerator device + when needed for computation. This method is more memory-efficient than keeping all components on the accelerator, + but the memory requirements are still quite high. For this method to work, one needs memory equivalent to size of + the model in runtime dtype + size of largest intermediate activation tensors to be able to complete the forward + pass. + - Leaf-level: In Diffusers, this can be enabled using the `ModelMixin::enable_sequential_cpu_offload()` method. It + works by offloading the lowest leaf-level parameters of the computation graph to the CPU for storage, and + onloading only the leafs to the accelerator device for computation. This uses the lowest amount of accelerator + memory, but can be slower due to the excessive number of device synchronizations. + + Group offloading is a middle ground between the two methods. It works by offloading groups of internal layers, + (either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method is more memory-efficient than module-level + offloading. It is also faster than leaf-level offloading, as the number of device synchronizations is reduced. + + Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability to + overlap data transfer and computation to reduce the overall execution time. This is enabled using layer prefetching + with streams, i.e., the layer that is to be executed next starts onloading to the accelerator device while the + current layer is being executed - this increases the memory requirements slightly. Note that this implementation + also supports leaf-level offloading but can be made much faster when using streams. + + Args: + module (`torch.nn.Module`): + The module to which group offloading is applied. + onload_device (`torch.device`): + The device to which the group of modules are onloaded. + offload_device (`torch.device`, defaults to `torch.device("cpu")`): + The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU. + offload_type (`str`, defaults to "block_level"): + The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is + "block_level". + num_blocks_per_group (`int`, *optional*): + The number of blocks per group when using offload_type="block_level". This is required when using + offload_type="block_level". + non_blocking (`bool`, defaults to `False`): + If True, offloading and onloading is done with non-blocking data transfer. + use_stream (`bool`, defaults to `False`): + If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for + overlapping computation and data transfer. + + Example: + ```python + >>> from diffusers import CogVideoXTransformer3DModel + >>> from diffusers.hooks import apply_group_offloading + + >>> transformer = CogVideoXTransformer3DModel.from_pretrained( + ... "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16 + ... ) + + >>> apply_group_offloading( + ... transformer, + ... onload_device=torch.device("cuda"), + ... offload_device=torch.device("cpu"), + ... offload_type="block_level", + ... num_blocks_per_group=2, + ... use_stream=True, + ... ) + ``` + """ + stream = None if use_stream: if torch.cuda.is_available(): @@ -279,15 +342,13 @@ def apply_group_offloading( if offload_type == "block_level": if num_blocks_per_group is None: - raise ValueError("num_blocks_per_group must be provided when using offload_group_patterns='block_level'.") + raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.") _apply_group_offloading_block_level( - module, num_blocks_per_group, offload_device, onload_device, force_offload, non_blocking, stream=stream + module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream ) elif offload_type == "leaf_level": - _apply_group_offloading_leaf_level( - module, offload_device, onload_device, force_offload, non_blocking, stream=stream - ) + _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream) else: raise ValueError(f"Unsupported offload_type: {offload_type}") @@ -297,7 +358,6 @@ def _apply_group_offloading_block_level( num_blocks_per_group: int, offload_device: torch.device, onload_device: torch.device, - force_offload: bool, non_blocking: bool, stream: Optional[torch.cuda.Stream] = None, ) -> None: @@ -312,9 +372,6 @@ def _apply_group_offloading_block_level( The device to which the group of modules are offloaded. This should typically be the CPU. onload_device (`torch.device`): The device to which the group of modules are onloaded. - force_offload (`bool`): - If True, all module groups are offloaded to the offload_device. If False, only layers that match - `offload_group_patterns` are offloaded to the offload_device. non_blocking (`bool`): If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation and data transfer. @@ -362,10 +419,9 @@ def _apply_group_offloading_block_level( next_group = ( matched_module_groups[i + 1] if i + 1 < len(matched_module_groups) and stream is not None else None ) - should_offload = force_offload or i > 0 for group_module in group.modules: - _apply_group_offloading_hook(group_module, group, should_offload, next_group) + _apply_group_offloading_hook(group_module, group, next_group) # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately # when the forward pass of this module is called. This is because the top-level module is not @@ -392,14 +448,13 @@ def _apply_group_offloading_block_level( onload_self=True, ) next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None - _apply_group_offloading_hook(module, unmatched_group, force_offload, next_group) + _apply_group_offloading_hook(module, unmatched_group, next_group) def _apply_group_offloading_leaf_level( module: torch.nn.Module, offload_device: torch.device, onload_device: torch.device, - force_offload: bool, non_blocking: bool, stream: Optional[torch.cuda.Stream] = None, ) -> None: @@ -416,9 +471,6 @@ def _apply_group_offloading_leaf_level( The device to which the group of modules are offloaded. This should typically be the CPU. onload_device (`torch.device`): The device to which the group of modules are onloaded. - force_offload (`bool`): - If True, all module groups are offloaded to the offload_device. If False, only layers that match - `offload_group_patterns` are offloaded to the offload_device. non_blocking (`bool`): If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation and data transfer. @@ -450,7 +502,7 @@ def _apply_group_offloading_leaf_level( cpu_param_dict=cpu_param_dict, onload_self=True, ) - _apply_group_offloading_hook(submodule, group, True, None) + _apply_group_offloading_hook(submodule, group, None) modules_with_group_offloading.add(name) # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass @@ -495,7 +547,7 @@ def _apply_group_offloading_leaf_level( cpu_param_dict=cpu_param_dict, onload_self=True, ) - _apply_group_offloading_hook(parent_module, group, True, None) + _apply_group_offloading_hook(parent_module, group, None) # This is a dummy group that will handle lazy prefetching from the top-level module to the first leaf module unmatched_group = ModuleGroup( @@ -516,15 +568,14 @@ def _apply_group_offloading_leaf_level( # and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the # execution order and apply prefetching in the correct order. if stream is None: - _apply_group_offloading_hook(module, unmatched_group, force_offload, None) + _apply_group_offloading_hook(module, unmatched_group, None) else: - _apply_lazy_group_offloading_hook(module, unmatched_group, force_offload, None) + _apply_lazy_group_offloading_hook(module, unmatched_group, None) def _apply_group_offloading_hook( module: torch.nn.Module, group: ModuleGroup, - offload_on_init: bool, next_group: Optional[ModuleGroup] = None, ) -> None: registry = HookRegistry.check_if_exists_or_initialize(module) @@ -532,14 +583,13 @@ def _apply_group_offloading_hook( # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent # is the current module. In such cases, we don't want to overwrite the existing group offloading hook. if registry.get_hook(_GROUP_OFFLOADING) is None: - hook = GroupOffloadingHook(group, offload_on_init, next_group) + hook = GroupOffloadingHook(group, next_group) registry.register_hook(hook, _GROUP_OFFLOADING) def _apply_lazy_group_offloading_hook( module: torch.nn.Module, group: ModuleGroup, - offload_on_init: bool, next_group: Optional[ModuleGroup] = None, ) -> None: registry = HookRegistry.check_if_exists_or_initialize(module) @@ -547,7 +597,7 @@ def _apply_lazy_group_offloading_hook( # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent # is the current module. In such cases, we don't want to overwrite the existing group offloading hook. if registry.get_hook(_GROUP_OFFLOADING) is None: - hook = GroupOffloadingHook(group, offload_on_init, next_group) + hook = GroupOffloadingHook(group, next_group) registry.register_hook(hook, _GROUP_OFFLOADING) lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook() @@ -561,14 +611,12 @@ def _gather_parameters_with_no_group_offloading_parent( for name, parameter in module.named_parameters(): has_parent_with_group_offloading = False atoms = name.split(".") - while len(atoms) > 0: parent_name = ".".join(atoms) if parent_name in modules_with_group_offloading: has_parent_with_group_offloading = True break atoms.pop() - if not has_parent_with_group_offloading: parameters.append((name, parameter)) return parameters @@ -581,14 +629,12 @@ def _gather_buffers_with_no_group_offloading_parent( for name, buffer in module.named_buffers(): has_parent_with_group_offloading = False atoms = name.split(".") - while len(atoms) > 0: parent_name = ".".join(atoms) if parent_name in modules_with_group_offloading: has_parent_with_group_offloading = True break atoms.pop() - if not has_parent_with_group_offloading: buffers.append((name, buffer)) return buffers diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 9f92d0f91f71..adcedd4637e0 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -33,7 +33,7 @@ from torch import Tensor, nn from .. import __version__ -from ..hooks import apply_layerwise_casting +from ..hooks import apply_group_offloading, apply_layerwise_casting from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer from ..quantizers.quantization_config import QuantizationMethod from ..utils import ( @@ -446,6 +446,55 @@ def enable_layerwise_casting( self, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking ) + def enable_group_offloading( + self, + onload_device: torch.device, + offload_device: torch.device = torch.device("cpu"), + offload_type: str = "block_level", + num_blocks_per_group: Optional[int] = None, + non_blocking: bool = False, + use_stream: bool = False, + ) -> None: + r""" + Activates group offloading for the current model. + + See [`~hooks.group_offloading.apply_group_offloading`] for more information. + + Example: + + ```python + >>> from diffusers import CogVideoXTransformer3DModel + + >>> transformer = CogVideoXTransformer3DModel.from_pretrained( + ... "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16 + ... ) + + >>> transformer.enable_group_offloading( + ... onload_device=torch.device("cuda"), + ... offload_device=torch.device("cpu"), + ... offload_type="leaf_level", + ... use_stream=True, + ... ) + ``` + """ + if getattr(self, "enable_tiling", None) is not None and getattr(self, "use_tiling", False) and use_stream: + msg = ( + "Applying group offloading on autoencoders, with CUDA streams, may not work as expected if the first " + "forward pass is executed with tiling enabled. Please make sure to either:\n" + "1. Run a forward pass with small input shapes.\n" + "2. Or, run a forward pass with tiling disabled (can still use small dummy inputs)." + ) + logger.warning(msg) + if not self._supports_group_offloading: + raise ValueError( + f"{self.__class__.__name__} does not support group offloading. Please make sure to set the boolean attribute " + f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please " + f"open an issue at https://github.com/huggingface/diffusers/issues." + ) + apply_group_offloading( + self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream + ) + def save_pretrained( self, save_directory: Union[str, os.PathLike], diff --git a/tests/hooks/test_group_offloading.py b/tests/hooks/test_group_offloading.py new file mode 100644 index 000000000000..f13515b52fac --- /dev/null +++ b/tests/hooks/test_group_offloading.py @@ -0,0 +1,159 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import torch + +from diffusers.models import ModelMixin +from diffusers.utils.logging import get_logger +from diffusers.utils.testing_utils import require_torch_gpu, torch_device + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +class DummyBlock(torch.nn.Module): + def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None: + super().__init__() + + self.proj_in = torch.nn.Linear(in_features, hidden_features) + self.activation = torch.nn.ReLU() + self.proj_out = torch.nn.Linear(hidden_features, out_features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj_in(x) + x = self.activation(x) + x = self.proj_out(x) + return x + + +class DummyModel(ModelMixin): + def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None: + super().__init__() + + self.linear_1 = torch.nn.Linear(in_features, hidden_features) + self.activation = torch.nn.ReLU() + self.blocks = torch.nn.ModuleList( + [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)] + ) + self.linear_2 = torch.nn.Linear(hidden_features, out_features) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear_1(x) + x = self.activation(x) + for block in self.blocks: + x = block(x) + x = self.linear_2(x) + return x + + +@require_torch_gpu +class GroupOffloadTests(unittest.TestCase): + in_features = 64 + hidden_features = 256 + out_features = 64 + num_layers = 4 + + def setUp(self): + with torch.no_grad(): + self.model = self.get_model() + self.input = torch.randn((4, self.in_features)).to(torch_device) + + def tearDown(self): + super().tearDown() + + del self.model + del self.input + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + def get_model(self): + torch.manual_seed(0) + return DummyModel( + in_features=self.in_features, + hidden_features=self.hidden_features, + out_features=self.out_features, + num_layers=self.num_layers, + ) + + def test_offloading_forward_pass(self): + @torch.no_grad() + def run_forward(model): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + self.assertTrue( + all( + module._diffusers_hook.get_hook("group_offloading") is not None + for module in model.modules() + if hasattr(module, "_diffusers_hook") + ) + ) + model.eval() + output = model(self.input)[0].cpu() + max_memory_reserved = torch.cuda.max_memory_allocated() + return output, max_memory_reserved + + self.model.to(torch_device) + output_without_group_offloading, mem_baseline = run_forward(self.model) + self.model.to("cpu") + + model = self.get_model() + model.enable_group_offloading(torch_device, offload_type="block_level", num_blocks_per_group=3) + output_with_group_offloading1, mem1 = run_forward(model) + + model = self.get_model() + model.enable_group_offloading(torch_device, offload_type="block_level", num_blocks_per_group=1) + output_with_group_offloading2, mem2 = run_forward(model) + + model = self.get_model() + model.enable_group_offloading( + torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True + ) + output_with_group_offloading3, mem3 = run_forward(model) + + model = self.get_model() + model.enable_group_offloading(torch_device, offload_type="leaf_level") + output_with_group_offloading4, mem4 = run_forward(model) + + model = self.get_model() + model.enable_group_offloading(torch_device, offload_type="leaf_level", use_stream=True) + output_with_group_offloading5, mem5 = run_forward(model) + + # Precision assertions - offloading should not impact the output + self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5)) + self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5)) + self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5)) + self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5)) + self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading5, atol=1e-5)) + + # Memory assertions - offloading should reduce memory usage + self.assertTrue(mem4 <= mem5 < mem2 < mem3 < mem1 < mem_baseline) + + def test_error_raised_if_streams_used_and_no_cuda_device(self): + original_is_available = torch.cuda.is_available + torch.cuda.is_available = lambda: False + with self.assertRaises(ValueError): + self.model.enable_group_offloading( + onload_device=torch.device("cuda"), offload_type="leaf_level", use_stream=True + ) + torch.cuda.is_available = original_is_available + + def test_error_raised_if_supports_group_offloading_false(self): + self.model._supports_group_offloading = False + with self.assertRaisesRegex(ValueError, "does not support group offloading"): + self.model.enable_group_offloading(onload_device=torch.device("cuda")) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index e663198eb168..622ce2ce7081 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -37,7 +37,6 @@ from parameterized import parameterized from requests.exceptions import HTTPError -from diffusers.hooks import apply_group_offloading from diffusers.models import UNet2DConditionModel from diffusers.models.attention_processor import ( AttnProcessor, @@ -1434,6 +1433,7 @@ def test_group_offloading(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() torch.manual_seed(0) + @torch.no_grad() def run_forward(model): self.assertTrue( all( @@ -1443,34 +1443,35 @@ def run_forward(model): ) ) model.eval() - with torch.no_grad(): - return model(**inputs_dict)[0] + return model(**inputs_dict)[0] model = self.model_class(**init_dict) if not getattr(model, "_supports_group_offloading", True): return - + model.to(torch_device) output_without_group_offloading = run_forward(model) torch.manual_seed(0) model = self.model_class(**init_dict) - apply_group_offloading(model, offload_type="block_level", num_blocks_per_group=1) + model.enable_group_offloading(torch_device, offload_type="block_level", num_blocks_per_group=1) output_with_group_offloading1 = run_forward(model) torch.manual_seed(0) model = self.model_class(**init_dict) - apply_group_offloading(model, offload_type="block_level", num_blocks_per_group=1, non_blocking=True) + model.enable_group_offloading( + torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True + ) output_with_group_offloading2 = run_forward(model) torch.manual_seed(0) model = self.model_class(**init_dict) - apply_group_offloading(model, offload_type="leaf_level") + model.enable_group_offloading(torch_device, offload_type="leaf_level") output_with_group_offloading3 = run_forward(model) torch.manual_seed(0) model = self.model_class(**init_dict) - apply_group_offloading(model, offload_type="leaf_level", use_stream=True) + model.enable_group_offloading(torch_device, offload_type="leaf_level", use_stream=True) output_with_group_offloading4 = run_forward(model) self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5)) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 50f19c37ad2a..812208b17899 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -2077,7 +2077,14 @@ def enable_group_offloading_on_component(pipe, group_offloading_kwargs): component = getattr(pipe, component_name) if not getattr(component, "_supports_group_offloading", True): continue - apply_group_offloading(component, **group_offloading_kwargs) + if hasattr(component, "enable_group_offloading"): + # For diffusers ModelMixin implementations + component.enable_group_offloading(torch.device(torch_device), **group_offloading_kwargs) + else: + # For other models not part of diffusers + apply_group_offloading( + component, onload_device=torch.device(torch_device), **group_offloading_kwargs + ) self.assertTrue( all( module._diffusers_hook.get_hook("group_offloading") is not None From 06b411fc02dbcd44a4c793bbb6b1b4ac5b77171c Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 4 Feb 2025 08:53:36 +0100 Subject: [PATCH 27/37] improve docs --- docs/source/en/optimization/memory.md | 4 ++-- src/diffusers/hooks/group_offloading.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/source/en/optimization/memory.md b/docs/source/en/optimization/memory.md index 99fd981fdc79..7f8a7df7e48d 100644 --- a/docs/source/en/optimization/memory.md +++ b/docs/source/en/optimization/memory.md @@ -160,9 +160,9 @@ In order to properly offload models after they're called, it is required to run ## Group offloading -Group offloading is a middle ground between the two above methods. It works by offloading groups of internal layers (either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method is more memory-efficient than model-level offloading. It is also faster than sequential-level offloading, as the number of device synchronizations is reduced. +Group offloading is a middle ground between the two above methods. It works by offloading groups of internal layers (either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method uses lower memory than model-level offloading. It is also faster than sequential-level offloading, as the number of device synchronizations is reduced. -Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability to overlap data transfer and computation to reduce the overall execution time. This is enabled using layer prefetching with CUDA streams, i.e., the layer that is to be executed next starts onloading to the accelerator device while the current layer is being executed - this increases the memory requirements slightly. Note that this implementation also supports leaf-level offloading but can be made much faster when using streams. +Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability to overlap data transfer and computation to reduce the overall execution time compared to sequential offloading. This is enabled using layer prefetching with CUDA streams, i.e., the layer that is to be executed next starts onloading to the accelerator device while the current layer is being executed - this increases the memory requirements slightly. Note that this implementation also supports leaf-level offloading but can be made much faster when using streams. To enable group offloading, either call the [`~ModelMixin.enable_group_offloading`] method on the model or pass use [`~hooks.group_offloading.apply_group_offloading`]: diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 3aa91c2944c1..f5f188c86e63 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -285,14 +285,14 @@ def apply_group_offloading( memory, but can be slower due to the excessive number of device synchronizations. Group offloading is a middle ground between the two methods. It works by offloading groups of internal layers, - (either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method is more memory-efficient than module-level + (either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method uses lower memory than module-level offloading. It is also faster than leaf-level offloading, as the number of device synchronizations is reduced. Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability to - overlap data transfer and computation to reduce the overall execution time. This is enabled using layer prefetching - with streams, i.e., the layer that is to be executed next starts onloading to the accelerator device while the - current layer is being executed - this increases the memory requirements slightly. Note that this implementation - also supports leaf-level offloading but can be made much faster when using streams. + overlap data transfer and computation to reduce the overall execution time compared to sequential offloading. This + is enabled using layer prefetching with streams, i.e., the layer that is to be executed next starts onloading to + the accelerator device while the current layer is being executed - this increases the memory requirements slightly. + Note that this implementation also supports leaf-level offloading but can be made much faster when using streams. Args: module (`torch.nn.Module`): From 904e470fdecb63c6b1a3820280fdbe79662f9171 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 5 Feb 2025 06:02:28 +0530 Subject: [PATCH 28/37] Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/optimization/memory.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/en/optimization/memory.md b/docs/source/en/optimization/memory.md index 7f8a7df7e48d..d9311db4e25b 100644 --- a/docs/source/en/optimization/memory.md +++ b/docs/source/en/optimization/memory.md @@ -160,11 +160,11 @@ In order to properly offload models after they're called, it is required to run ## Group offloading -Group offloading is a middle ground between the two above methods. It works by offloading groups of internal layers (either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method uses lower memory than model-level offloading. It is also faster than sequential-level offloading, as the number of device synchronizations is reduced. +Group offloading is the middle ground between CPU and model offloading. It works by offloading groups of internal layers (either `torch.nn.ModuleList` or `torch.nn.Sequential`), which uses less memory than model-level offloading. It is also faster than sequential-level offloading because the number of device synchronizations is reduced. -Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability to overlap data transfer and computation to reduce the overall execution time compared to sequential offloading. This is enabled using layer prefetching with CUDA streams, i.e., the layer that is to be executed next starts onloading to the accelerator device while the current layer is being executed - this increases the memory requirements slightly. Note that this implementation also supports leaf-level offloading but can be made much faster when using streams. +Group offloading (for CUDA devices with support for asynchronous data transfer streams) overlaps data transfer and computation to reduce the overall execution time compared to sequential offloading. This is enabled using layer prefetching with CUDA streams. The next layer to be executed is loaded onto the accelerator device while the current layer is being executed - this increases the memory requirements slightly. Group offloading also supports leaf-level offloading but can be made much faster when using streams. -To enable group offloading, either call the [`~ModelMixin.enable_group_offloading`] method on the model or pass use [`~hooks.group_offloading.apply_group_offloading`]: +To enable group offloading, call the [`~ModelMixin.enable_group_offloading`] method on the model if it is a Diffusers model implementation. For any other model implementation, use [`~hooks.group_offloading.apply_group_offloading`]: ```python import torch From 3172ed5d52042ce0da5b67f13ab1be8c6e1eed10 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 5 Feb 2025 01:35:32 +0100 Subject: [PATCH 29/37] apply suggestions from code review --- docs/source/en/optimization/memory.md | 6 +++--- src/diffusers/hooks/group_offloading.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/source/en/optimization/memory.md b/docs/source/en/optimization/memory.md index d9311db4e25b..efb53b837346 100644 --- a/docs/source/en/optimization/memory.md +++ b/docs/source/en/optimization/memory.md @@ -160,9 +160,7 @@ In order to properly offload models after they're called, it is required to run ## Group offloading -Group offloading is the middle ground between CPU and model offloading. It works by offloading groups of internal layers (either `torch.nn.ModuleList` or `torch.nn.Sequential`), which uses less memory than model-level offloading. It is also faster than sequential-level offloading because the number of device synchronizations is reduced. - -Group offloading (for CUDA devices with support for asynchronous data transfer streams) overlaps data transfer and computation to reduce the overall execution time compared to sequential offloading. This is enabled using layer prefetching with CUDA streams. The next layer to be executed is loaded onto the accelerator device while the current layer is being executed - this increases the memory requirements slightly. Group offloading also supports leaf-level offloading but can be made much faster when using streams. +Group offloading is the middle ground between sequential and model offloading. It works by offloading groups of internal layers (either `torch.nn.ModuleList` or `torch.nn.Sequential`), which uses less memory than model-level offloading. It is also faster than sequential-level offloading because the number of device synchronizations is reduced. To enable group offloading, call the [`~ModelMixin.enable_group_offloading`] method on the model if it is a Diffusers model implementation. For any other model implementation, use [`~hooks.group_offloading.apply_group_offloading`]: @@ -198,6 +196,8 @@ print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} G export_to_video(video, "output.mp4", fps=8) ``` +Group offloading (for CUDA devices with support for asynchronous data transfer streams) overlaps data transfer and computation to reduce the overall execution time compared to sequential offloading. This is enabled using layer prefetching with CUDA streams. The next layer to be executed is loaded onto the accelerator device while the current layer is being executed - this increases the memory requirements slightly. Group offloading also supports leaf-level offloading (equivalent to sequential CPU offloading) but can be made much faster when using streams. + ## FP8 layerwise weight-casting PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting. diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index f5f188c86e63..2d40581dc653 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -286,7 +286,8 @@ def apply_group_offloading( Group offloading is a middle ground between the two methods. It works by offloading groups of internal layers, (either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method uses lower memory than module-level - offloading. It is also faster than leaf-level offloading, as the number of device synchronizations is reduced. + offloading. It is also faster than leaf-level/sequential offloading, as the number of device synchronizations is + reduced. Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability to overlap data transfer and computation to reduce the overall execution time compared to sequential offloading. This From aee24bcc0c99e4d21f24ef043218f602a6a319f1 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 5 Feb 2025 01:39:35 +0100 Subject: [PATCH 30/37] update tests --- tests/hooks/test_group_offloading.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/hooks/test_group_offloading.py b/tests/hooks/test_group_offloading.py index f13515b52fac..54b3e5140fb4 100644 --- a/tests/hooks/test_group_offloading.py +++ b/tests/hooks/test_group_offloading.py @@ -18,13 +18,9 @@ import torch from diffusers.models import ModelMixin -from diffusers.utils.logging import get_logger from diffusers.utils.testing_utils import require_torch_gpu, torch_device -logger = get_logger(__name__) # pylint: disable=invalid-name - - class DummyBlock(torch.nn.Module): def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None: super().__init__() @@ -105,8 +101,8 @@ def run_forward(model): ) model.eval() output = model(self.input)[0].cpu() - max_memory_reserved = torch.cuda.max_memory_allocated() - return output, max_memory_reserved + max_memory_allocated = torch.cuda.max_memory_allocated() + return output, max_memory_allocated self.model.to(torch_device) output_without_group_offloading, mem_baseline = run_forward(self.model) From 3f20e6bc1bbb7e2c1f90c1ecae59ff270fa89659 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 6 Feb 2025 06:05:59 +0100 Subject: [PATCH 31/37] apply suggestions from review --- src/diffusers/hooks/group_offloading.py | 38 +++++++++++-------------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 2d40581dc653..d355673dce3b 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -550,27 +550,23 @@ def _apply_group_offloading_leaf_level( ) _apply_group_offloading_hook(parent_module, group, None) - # This is a dummy group that will handle lazy prefetching from the top-level module to the first leaf module - unmatched_group = ModuleGroup( - modules=[], - offload_device=offload_device, - onload_device=onload_device, - offload_leader=module, - onload_leader=module, - parameters=None, - buffers=None, - non_blocking=False, - stream=None, - cpu_param_dict=None, - onload_self=True, - ) - - # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer - # and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the - # execution order and apply prefetching in the correct order. - if stream is None: - _apply_group_offloading_hook(module, unmatched_group, None) - else: + if stream is not None: + # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer + # and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the + # execution order and apply prefetching in the correct order. + unmatched_group = ModuleGroup( + modules=[], + offload_device=offload_device, + onload_device=onload_device, + offload_leader=module, + onload_leader=module, + parameters=None, + buffers=None, + non_blocking=False, + stream=None, + cpu_param_dict=None, + onload_self=True, + ) _apply_lazy_group_offloading_hook(module, unmatched_group, None) From 840576ac2b59993156ea3e37c15da365266faf5e Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 6 Feb 2025 12:37:17 +0100 Subject: [PATCH 32/37] enable_group_offloading -> enable_group_offload for naming consistency --- docs/source/en/optimization/memory.md | 6 +++--- src/diffusers/models/modeling_utils.py | 4 ++-- tests/hooks/test_group_offloading.py | 16 +++++++--------- tests/models/test_modeling_common.py | 10 ++++------ tests/pipelines/test_pipelines_common.py | 10 +++++----- 5 files changed, 21 insertions(+), 25 deletions(-) diff --git a/docs/source/en/optimization/memory.md b/docs/source/en/optimization/memory.md index efb53b837346..9467a770d484 100644 --- a/docs/source/en/optimization/memory.md +++ b/docs/source/en/optimization/memory.md @@ -162,7 +162,7 @@ In order to properly offload models after they're called, it is required to run Group offloading is the middle ground between sequential and model offloading. It works by offloading groups of internal layers (either `torch.nn.ModuleList` or `torch.nn.Sequential`), which uses less memory than model-level offloading. It is also faster than sequential-level offloading because the number of device synchronizations is reduced. -To enable group offloading, call the [`~ModelMixin.enable_group_offloading`] method on the model if it is a Diffusers model implementation. For any other model implementation, use [`~hooks.group_offloading.apply_group_offloading`]: +To enable group offloading, call the [`~ModelMixin.enable_group_offload`] method on the model if it is a Diffusers model implementation. For any other model implementation, use [`~hooks.group_offloading.apply_group_offloading`]: ```python import torch @@ -175,8 +175,8 @@ onload_device = torch.device("cuda") offload_device = torch.device("cpu") pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) -# We can utilize the enable_group_offloading method for Diffusers model implementations -pipe.transformer.enable_group_offloading(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True) +# We can utilize the enable_group_offload method for Diffusers model implementations +pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True) # For any other model implementations, the apply_group_offloading function can be used apply_group_offloading(pipe.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index adcedd4637e0..6485b8f751ab 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -446,7 +446,7 @@ def enable_layerwise_casting( self, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking ) - def enable_group_offloading( + def enable_group_offload( self, onload_device: torch.device, offload_device: torch.device = torch.device("cpu"), @@ -469,7 +469,7 @@ def enable_group_offloading( ... "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16 ... ) - >>> transformer.enable_group_offloading( + >>> transformer.enable_group_offload( ... onload_device=torch.device("cuda"), ... offload_device=torch.device("cpu"), ... offload_type="leaf_level", diff --git a/tests/hooks/test_group_offloading.py b/tests/hooks/test_group_offloading.py index 54b3e5140fb4..35bd2d43c7c9 100644 --- a/tests/hooks/test_group_offloading.py +++ b/tests/hooks/test_group_offloading.py @@ -109,25 +109,23 @@ def run_forward(model): self.model.to("cpu") model = self.get_model() - model.enable_group_offloading(torch_device, offload_type="block_level", num_blocks_per_group=3) + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3) output_with_group_offloading1, mem1 = run_forward(model) model = self.get_model() - model.enable_group_offloading(torch_device, offload_type="block_level", num_blocks_per_group=1) + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1) output_with_group_offloading2, mem2 = run_forward(model) model = self.get_model() - model.enable_group_offloading( - torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True - ) + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True) output_with_group_offloading3, mem3 = run_forward(model) model = self.get_model() - model.enable_group_offloading(torch_device, offload_type="leaf_level") + model.enable_group_offload(torch_device, offload_type="leaf_level") output_with_group_offloading4, mem4 = run_forward(model) model = self.get_model() - model.enable_group_offloading(torch_device, offload_type="leaf_level", use_stream=True) + model.enable_group_offload(torch_device, offload_type="leaf_level", use_stream=True) output_with_group_offloading5, mem5 = run_forward(model) # Precision assertions - offloading should not impact the output @@ -144,7 +142,7 @@ def test_error_raised_if_streams_used_and_no_cuda_device(self): original_is_available = torch.cuda.is_available torch.cuda.is_available = lambda: False with self.assertRaises(ValueError): - self.model.enable_group_offloading( + self.model.enable_group_offload( onload_device=torch.device("cuda"), offload_type="leaf_level", use_stream=True ) torch.cuda.is_available = original_is_available @@ -152,4 +150,4 @@ def test_error_raised_if_streams_used_and_no_cuda_device(self): def test_error_raised_if_supports_group_offloading_false(self): self.model._supports_group_offloading = False with self.assertRaisesRegex(ValueError, "does not support group offloading"): - self.model.enable_group_offloading(onload_device=torch.device("cuda")) + self.model.enable_group_offload(onload_device=torch.device("cuda")) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 622ce2ce7081..e848b32e24da 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1454,24 +1454,22 @@ def run_forward(model): torch.manual_seed(0) model = self.model_class(**init_dict) - model.enable_group_offloading(torch_device, offload_type="block_level", num_blocks_per_group=1) + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1) output_with_group_offloading1 = run_forward(model) torch.manual_seed(0) model = self.model_class(**init_dict) - model.enable_group_offloading( - torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True - ) + model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True) output_with_group_offloading2 = run_forward(model) torch.manual_seed(0) model = self.model_class(**init_dict) - model.enable_group_offloading(torch_device, offload_type="leaf_level") + model.enable_group_offload(torch_device, offload_type="leaf_level") output_with_group_offloading3 = run_forward(model) torch.manual_seed(0) model = self.model_class(**init_dict) - model.enable_group_offloading(torch_device, offload_type="leaf_level", use_stream=True) + model.enable_group_offload(torch_device, offload_type="leaf_level", use_stream=True) output_with_group_offloading4 = run_forward(model) self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5)) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 812208b17899..355e851f9fdd 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -2059,7 +2059,7 @@ def create_pipe(): pipe.set_progress_bar_config(disable=None) return pipe - def enable_group_offloading_on_component(pipe, group_offloading_kwargs): + def enable_group_offload_on_component(pipe, group_offloading_kwargs): # We intentionally don't test VAE's here. This is because some tests enable tiling on the VAE. If # tiling is enabled and a forward pass is run, when cuda streams are used, the execution order of # the layers is not traced correctly. This causes errors. For apply group offloading to VAE, a @@ -2077,9 +2077,9 @@ def enable_group_offloading_on_component(pipe, group_offloading_kwargs): component = getattr(pipe, component_name) if not getattr(component, "_supports_group_offloading", True): continue - if hasattr(component, "enable_group_offloading"): + if hasattr(component, "enable_group_offload"): # For diffusers ModelMixin implementations - component.enable_group_offloading(torch.device(torch_device), **group_offloading_kwargs) + component.enable_group_offload(torch.device(torch_device), **group_offloading_kwargs) else: # For other models not part of diffusers apply_group_offloading( @@ -2105,11 +2105,11 @@ def run_forward(pipe): output_without_group_offloading = run_forward(pipe) pipe = create_pipe() - enable_group_offloading_on_component(pipe, {"offload_type": "block_level", "num_blocks_per_group": 1}) + enable_group_offload_on_component(pipe, {"offload_type": "block_level", "num_blocks_per_group": 1}) output_with_group_offloading1 = run_forward(pipe) pipe = create_pipe() - enable_group_offloading_on_component(pipe, {"offload_type": "leaf_level"}) + enable_group_offload_on_component(pipe, {"offload_type": "leaf_level"}) output_with_group_offloading2 = run_forward(pipe) if torch.is_tensor(output_without_group_offloading): From 8804d746f09c299d55461dc0baf8c0cb4fa1dbd3 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 6 Feb 2025 13:33:31 +0100 Subject: [PATCH 33/37] raise errors if multiple offloading strategies used; add relevant tests --- src/diffusers/hooks/group_offloading.py | 17 ++++++++++ src/diffusers/pipelines/pipeline_utils.py | 22 +++++++++++++ tests/hooks/test_group_offloading.py | 39 +++++++++++++++++++++++ 3 files changed, 78 insertions(+) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index d355673dce3b..5e4f72d1f243 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -22,6 +22,7 @@ if is_accelerate_available(): + from accelerate.hooks import AlignDevicesHook, CpuOffload from accelerate.utils import send_to_device @@ -341,6 +342,8 @@ def apply_group_offloading( else: raise ValueError("Using streams for data transfer requires a CUDA device.") + _raise_error_if_accelerate_model_or_sequential_hook_present(module) + if offload_type == "block_level": if num_blocks_per_group is None: raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.") @@ -645,3 +648,17 @@ def _find_parent_module_in_module_dict(name: str, module_dict: Dict[str, torch.n return parent_name atoms.pop() return "" + + +def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn.Module) -> None: + if not is_accelerate_available(): + return + for name, submodule in module.named_modules(): + if not hasattr(submodule, "_hf_hook"): + continue + if isinstance(submodule._hf_hook, (AlignDevicesHook, CpuOffload)): + raise ValueError( + f"Cannot apply group offloading to a module that is already applying an alternative " + f"offloading strategy from Accelerate. If you want to apply group offloading, please " + f"disable the existing offloading strategy first. Offending module: {name} ({type(submodule)})" + ) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 9c20e492535a..9ba02c13714e 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1075,6 +1075,8 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will default to "cuda". """ + self._check_group_offloading_inactive_or_raise_error() + is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 if is_pipeline_device_mapped: raise ValueError( @@ -1186,6 +1188,8 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will default to "cuda". """ + self._check_group_offloading_inactive_or_raise_error() + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload else: @@ -1910,6 +1914,24 @@ def from_pipe(cls, pipeline, **kwargs): return new_pipeline + def _check_group_offloading_inactive_or_raise_error(self) -> None: + from ..hooks import HookRegistry + from ..hooks.group_offloading import _GROUP_OFFLOADING + + for name, component in self.components.items(): + if not isinstance(component, torch.nn.Module): + continue + for module in component.modules(): + if not hasattr(module, "_diffusers_hook"): + continue + registry: HookRegistry = module._diffusers_hook + if registry.get_hook(_GROUP_OFFLOADING) is not None: + raise ValueError( + f"You are trying to apply model/sequential CPU offloading to a pipeline that contains " + f"components with group offloading enabled. This is not supported. Please disable group " + f"offloading for the '{name}' component of the pipeline to use other offloading methods." + ) + class StableDiffusionMixin: r""" diff --git a/tests/hooks/test_group_offloading.py b/tests/hooks/test_group_offloading.py index 35bd2d43c7c9..f05626905185 100644 --- a/tests/hooks/test_group_offloading.py +++ b/tests/hooks/test_group_offloading.py @@ -18,6 +18,7 @@ import torch from diffusers.models import ModelMixin +from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.utils.testing_utils import require_torch_gpu, torch_device @@ -56,6 +57,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +class DummyPipeline(DiffusionPipeline): + model_cpu_offload_seq = "model" + + def __init__(self, model: torch.nn.Module) -> None: + super().__init__() + + self.register_modules(model=model) + + def __call__(self, x: torch.Tensor) -> torch.Tensor: + for _ in range(2): + x = x + 0.1 * self.model(x) + return x + + @require_torch_gpu class GroupOffloadTests(unittest.TestCase): in_features = 64 @@ -151,3 +166,27 @@ def test_error_raised_if_supports_group_offloading_false(self): self.model._supports_group_offloading = False with self.assertRaisesRegex(ValueError, "does not support group offloading"): self.model.enable_group_offload(onload_device=torch.device("cuda")) + + def test_error_raised_if_model_offloading_applied_on_group_offloaded_module(self): + pipe = DummyPipeline(self.model) + pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3) + with self.assertRaisesRegex(ValueError, "You are trying to apply model/sequential CPU offloading"): + pipe.enable_model_cpu_offload() + + def test_error_raised_if_sequential_offloading_applied_on_group_offloaded_module(self): + pipe = DummyPipeline(self.model) + pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3) + with self.assertRaisesRegex(ValueError, "You are trying to apply model/sequential CPU offloading"): + pipe.enable_sequential_cpu_offload() + + def test_error_raised_if_group_offloading_applied_on_model_offloaded_module(self): + pipe = DummyPipeline(self.model) + pipe.enable_model_cpu_offload() + with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"): + pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3) + + def test_error_raised_if_group_offloading_applied_on_sequential_offloaded_module(self): + pipe = DummyPipeline(self.model) + pipe.enable_sequential_cpu_offload() + with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"): + pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3) From 954bb7d0198aeb852b776d0232a5ca0787e2c7dc Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 6 Feb 2025 19:04:16 +0100 Subject: [PATCH 34/37] handle .to() when group offload applied --- src/diffusers/hooks/group_offloading.py | 7 ++++ src/diffusers/models/modeling_utils.py | 20 +++++++++ src/diffusers/pipelines/pipeline_utils.py | 50 ++++++++++++++--------- tests/hooks/test_group_offloading.py | 22 ++++++++++ 4 files changed, 80 insertions(+), 19 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 5e4f72d1f243..7377979f507b 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -662,3 +662,10 @@ def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn f"offloading strategy from Accelerate. If you want to apply group offloading, please " f"disable the existing offloading strategy first. Offending module: {name} ({type(submodule)})" ) + + +def _is_group_offload_enabled(module: torch.nn.Module) -> bool: + for submodule in module.modules(): + if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None: + return True + return False diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 6485b8f751ab..af6416ea8373 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1245,8 +1245,21 @@ def cuda(self, *args, **kwargs): # Adapted from `transformers`. @wraps(torch.nn.Module.to) def to(self, *args, **kwargs): + from ..hooks.group_offloading import _is_group_offload_enabled + + device_arg_or_kwarg_present = any(isinstance(arg, torch.device) for arg in args) or "device" in kwargs dtype_present_in_args = "dtype" in kwargs + # Try converting arguments to torch.device in case they are passed as strings + for arg in args: + if not isinstance(arg, str): + continue + try: + torch.device(arg) + device_arg_or_kwarg_present = True + except RuntimeError: + pass + if not dtype_present_in_args: for arg in args: if isinstance(arg, torch.dtype): @@ -1271,6 +1284,13 @@ def to(self, *args, **kwargs): "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." ) + + if _is_group_offload_enabled(self) and device_arg_or_kwarg_present: + logger.warning( + f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.to()` is not supported." + ) + return self + return super().to(*args, **kwargs) # Taken from `transformers`. diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 9ba02c13714e..0a09bfa610a2 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -394,6 +394,7 @@ def to(self, *args, **kwargs): ) device = device or device_arg + device_type = torch.device(device).type if device is not None else None pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items()) # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU. @@ -424,7 +425,7 @@ def module_is_offloaded(module): "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline." ) - if device and torch.device(device).type == "cuda": + if device_type == "cuda": if pipeline_is_sequentially_offloaded and not pipeline_has_bnb: raise ValueError( "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading." @@ -437,7 +438,7 @@ def module_is_offloaded(module): # Display a warning in this case (the operation succeeds but the benefits are lost) pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items()) - if pipeline_is_offloaded and device and torch.device(device).type == "cuda": + if pipeline_is_offloaded and device_type == "cuda": logger.warning( f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading." ) @@ -449,6 +450,7 @@ def module_is_offloaded(module): is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded for module in modules: _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module) + is_group_offloaded = self._maybe_raise_error_if_group_offload_active(module=module) if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None: logger.warning( @@ -460,11 +462,21 @@ def module_is_offloaded(module): f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}." ) + # Note: we also handle this as the ModelMixin level. The reason for doing it here too is that modeling + # components can be from outside diffusers too, but still have group offloading enabled. + if ( + self._maybe_raise_error_if_group_offload_active(raise_error=False, module=module) + and device is not None + ): + logger.warning( + f"The module '{module.__class__.__name__}' is group offloaded and moving it to {device} via `.to()` is not supported." + ) + # This can happen for `transformer` models. CPU placement was added in # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly. if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"): module.to(device=device) - elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb: + elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb and not is_group_offloaded: module.to(device, dtype) if ( @@ -1075,7 +1087,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will default to "cuda". """ - self._check_group_offloading_inactive_or_raise_error() + self._maybe_raise_error_if_group_offload_active(raise_error=True) is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 if is_pipeline_device_mapped: @@ -1188,7 +1200,7 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will default to "cuda". """ - self._check_group_offloading_inactive_or_raise_error() + self._maybe_raise_error_if_group_offload_active(raise_error=True) if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): from accelerate import cpu_offload @@ -1914,23 +1926,23 @@ def from_pipe(cls, pipeline, **kwargs): return new_pipeline - def _check_group_offloading_inactive_or_raise_error(self) -> None: - from ..hooks import HookRegistry - from ..hooks.group_offloading import _GROUP_OFFLOADING + def _maybe_raise_error_if_group_offload_active( + self, raise_error: bool = False, module: Optional[torch.nn.Module] = None + ) -> bool: + from ..hooks.group_offloading import _is_group_offload_enabled - for name, component in self.components.items(): - if not isinstance(component, torch.nn.Module): - continue - for module in component.modules(): - if not hasattr(module, "_diffusers_hook"): - continue - registry: HookRegistry = module._diffusers_hook - if registry.get_hook(_GROUP_OFFLOADING) is not None: + components = self.components.values() if module is None else [module] + components = [component for component in components if isinstance(component, torch.nn.Module)] + for component in components: + if _is_group_offload_enabled(component): + if raise_error: raise ValueError( - f"You are trying to apply model/sequential CPU offloading to a pipeline that contains " - f"components with group offloading enabled. This is not supported. Please disable group " - f"offloading for the '{name}' component of the pipeline to use other offloading methods." + "You are trying to apply model/sequential CPU offloading to a pipeline that contains components " + "with group offloading enabled. This is not supported. Please disable group offloading for " + "components of the pipeline to use other offloading methods." ) + return True + return False class StableDiffusionMixin: diff --git a/tests/hooks/test_group_offloading.py b/tests/hooks/test_group_offloading.py index f05626905185..d8f41fc2b1ae 100644 --- a/tests/hooks/test_group_offloading.py +++ b/tests/hooks/test_group_offloading.py @@ -19,6 +19,7 @@ from diffusers.models import ModelMixin from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import get_logger from diffusers.utils.testing_utils import require_torch_gpu, torch_device @@ -153,6 +154,27 @@ def run_forward(model): # Memory assertions - offloading should reduce memory usage self.assertTrue(mem4 <= mem5 < mem2 < mem3 < mem1 < mem_baseline) + def test_warning_logged_if_group_offloaded_module_moved_to_cuda(self): + if torch.device(torch_device).type != "cuda": + return + self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3) + logger = get_logger("diffusers.models.modeling_utils") + logger.setLevel("INFO") + with self.assertLogs(logger, level="WARNING") as cm: + self.model.to(torch_device) + self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0]) + + def test_warning_logged_if_group_offloaded_pipe_moved_to_cuda(self): + if torch.device(torch_device).type != "cuda": + return + pipe = DummyPipeline(self.model) + self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3) + logger = get_logger("diffusers.pipelines.pipeline_utils") + logger.setLevel("INFO") + with self.assertLogs(logger, level="WARNING") as cm: + pipe.to(torch_device) + self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0]) + def test_error_raised_if_streams_used_and_no_cuda_device(self): original_is_available = torch.cuda.is_available torch.cuda.is_available = lambda: False From da88c333c9b389fd4e0a89b84b9031ea5f0ba8b1 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 6 Feb 2025 19:21:45 +0100 Subject: [PATCH 35/37] refactor some repeated code --- src/diffusers/hooks/group_offloading.py | 7 +++++++ src/diffusers/models/modeling_utils.py | 17 +++++++++-------- src/diffusers/pipelines/pipeline_utils.py | 15 +++++++-------- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 7377979f507b..c389c5dc9826 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -669,3 +669,10 @@ def _is_group_offload_enabled(module: torch.nn.Module) -> bool: if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None: return True return False + + +def _get_group_onload_device(module: torch.nn.Module) -> torch.device: + for submodule in module.modules(): + if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None: + return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device + raise ValueError("Group offloading is not enabled for the provided module.") diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index af6416ea8373..be5ea6cf9991 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -86,16 +86,17 @@ def get_parameter_device(parameter: torch.nn.Module) -> torch.device: + from ..hooks.group_offloading import _get_group_onload_device + try: - if hasattr(parameter, "_diffusers_hook"): - for submodule in parameter.modules(): - if not hasattr(submodule, "_diffusers_hook"): - continue - registry = parameter._diffusers_hook - hook = registry.get_hook("group_offloading") - if hook is not None: - return hook.group.onload_device + # Try to get the onload device from the group offloading hook + return _get_group_onload_device(parameter) + except ValueError: + pass + try: + # If the onload device is not available due to no group offloading hooks, try to get the device + # from the first parameter or buffer parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers()) return next(parameters_and_buffers).device except StopIteration: diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 0a09bfa610a2..2a84af64f8e2 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -462,7 +462,7 @@ def module_is_offloaded(module): f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}." ) - # Note: we also handle this as the ModelMixin level. The reason for doing it here too is that modeling + # Note: we also handle this at the ModelMixin level. The reason for doing it here too is that modeling # components can be from outside diffusers too, but still have group offloading enabled. if ( self._maybe_raise_error_if_group_offload_active(raise_error=False, module=module) @@ -1035,19 +1035,18 @@ def _execution_device(self): [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from Accelerate's module hooks. """ + from ..hooks.group_offloading import _get_group_onload_device + # When apply group offloading at the leaf_level, we're in the same situation as accelerate's sequential # offloading. We need to return the onload device of the group offloading hooks so that the intermediates # required for computation (latents, prompt embeddings, etc.) can be created on the correct device. for name, model in self.components.items(): if not isinstance(model, torch.nn.Module): continue - for submodule in model.modules(): - if not hasattr(submodule, "_diffusers_hook"): - continue - registry = submodule._diffusers_hook - hook = registry.get_hook("group_offloading") - if hook is not None: - return hook.group.onload_device + try: + return _get_group_onload_device(model) + except ValueError: + pass for name, model in self.components.items(): if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload: From a872e84e5fcd748e411b03ee3fd929cd014a6250 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 6 Feb 2025 19:22:56 +0100 Subject: [PATCH 36/37] remove unintentional change from merge conflict --- src/diffusers/hooks/hooks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index d73ee307a4d7..3b2e4ed91c2f 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -177,7 +177,6 @@ def get_hook(self, name: str) -> Optional[ModelHook]: return self.hooks.get(name, None) def remove_hook(self, name: str, recurse: bool = True) -> None: - num_hooks = len(self._hook_order) if name in self.hooks.keys(): num_hooks = len(self._hook_order) hook = self.hooks[name] From 6be43b8a6badeec3514884566a04877f1c5df5bc Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 6 Feb 2025 19:26:10 +0100 Subject: [PATCH 37/37] handle .cuda() --- src/diffusers/models/modeling_utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index be5ea6cf9991..764a8f1e9307 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1229,6 +1229,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # Adapted from `transformers`. @wraps(torch.nn.Module.cuda) def cuda(self, *args, **kwargs): + from ..hooks.group_offloading import _is_group_offload_enabled + # Checks if the model has been loaded in 4-bit or 8-bit with BNB if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: if getattr(self, "is_loaded_in_8bit", False): @@ -1241,6 +1243,14 @@ def cuda(self, *args, **kwargs): "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." ) + + # Checks if group offloading is enabled + if _is_group_offload_enabled(self): + logger.warning( + f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.cuda()` is not supported." + ) + return self + return super().cuda(*args, **kwargs) # Adapted from `transformers`.