From d2a2981a90d247f766ec19ff299132a54b0054d2 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Thu, 16 Jan 2025 08:29:30 +0100
Subject: [PATCH] update; very workaround-based implementation, but it seems
 to work as expected; needs cleanup and rewrite

---
 src/diffusers/hooks/group_offloading.py | 115 +++++++++++++++++-------
 src/diffusers/hooks/hooks.py            |   5 ++
 2 files changed, 89 insertions(+), 31 deletions(-)

diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index c86025a901b6..41d6579779f2 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -21,7 +21,7 @@
 from .hooks import HookRegistry, ModelHook
 
 
-logger = get_logger(__name__) # pylint: disable=invalid-name
+logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 
 class ModuleGroup:
@@ -32,12 +32,16 @@ def __init__(
         onload_device: torch.device,
         offload_leader: torch.nn.Module,
         onload_leader: Optional[torch.nn.Module] = None,
+        parameters: Optional[List[torch.nn.Parameter]] = None,
+        buffers: Optional[List[torch.Tensor]] = None,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
         self.onload_device = onload_device
         self.offload_leader = offload_leader
         self.onload_leader = onload_leader
+        self.parameters = parameters
+        self.buffers = buffers
 
 
 class GroupOffloadingHook(ModelHook):
@@ -64,6 +68,7 @@ def __init__(
         stream: Optional[torch.cuda.Stream] = None,
         next_group: Optional[ModuleGroup] = None,
         cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
+        onload_self: bool = False,
     ) -> None:
         self.group = group
         self.offload_on_init = offload_on_init
@@ -71,6 +76,7 @@ def __init__(
         self.stream = stream
         self.next_group = next_group
         self.cpu_param_dict = cpu_param_dict
+        self.onload_self = onload_self
 
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         if self.offload_on_init:
@@ -100,9 +106,16 @@ def onload_(self, module: torch.nn.Module) -> None:
             with torch.cuda.stream(self.stream):
                 for group_module in self.next_group.modules:
                     group_module.to(self.next_group.onload_device, non_blocking=True)
-        else:
+
+        if self.stream is None or self.onload_self:
             for group_module in self.group.modules:
                 group_module.to(self.group.onload_device, non_blocking=self.non_blocking)
+            if self.group.parameters is not None:
+                for param in self.group.parameters:
+                    param.data = param.data.to(self.group.onload_device, non_blocking=self.non_blocking)
+            if self.group.buffers is not None:
+                for buffer in self.group.buffers:
+                    buffer.data = buffer.data.to(self.group.onload_device, non_blocking=self.non_blocking)
 
     def offload_(self, module: torch.nn.Module) -> None:
         if self.group.offload_leader == module:
@@ -113,6 +126,13 @@ def offload_(self, module: torch.nn.Module) -> None:
             else:
                 for group_module in self.group.modules:
                     group_module.to(self.group.offload_device, non_blocking=self.non_blocking)
+                if self.group.parameters is not None:
+                    for param in self.group.parameters:
+                        param.data = param.data.to(self.group.offload_device, non_blocking=self.non_blocking)
+                if self.group.buffers is not None:
+                    for buffer in self.group.buffers:
+                        buffer.data = buffer.data.to(self.group.offload_device, non_blocking=self.non_blocking)
+            # TODO: do we need to sync here because of GPU->CPU transfer?
             if self.non_blocking and self.group.offload_device.type == "cpu":
                 torch.cpu.synchronize()
 
@@ -128,9 +148,9 @@ def apply_group_offloading(
     non_blocking: bool = False,
     cuda_stream: bool = False,
 ) -> None:
-    # stream = None
-    # if cuda_stream:
-    #     stream = torch.cuda.Stream()
+    stream = None
+    if cuda_stream:
+        stream = torch.cuda.Stream()
     if offload_group_patterns == "modulelist_or_sequential":
         if num_blocks_per_group is None:
             raise ValueError(
@@ -148,7 +168,7 @@ def apply_group_offloading(
         offload_group_patterns = _get_modulelist_or_sequential_group_patterns(module, num_blocks_per_group)
 
     _apply_group_offloading_group_patterns(
-        module, offload_group_patterns, offload_device, onload_device, force_offload, non_blocking
+        module, offload_group_patterns, offload_device, onload_device, force_offload, non_blocking, stream=stream
     )
 
 
@@ -231,6 +251,7 @@ def _apply_group_offloading_group_patterns(
     onload_device: torch.device,
     force_offload: bool,
     non_blocking: bool,
+    stream: Optional[torch.cuda.Stream] = None,
 ) -> None:
     r"""
     This function applies offloading to groups of modules based on the provided regex patterns. Each group of modules
@@ -269,8 +290,17 @@ def _apply_group_offloading_group_patterns(
         non_blocking (`bool`):
             If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
            and data transfer.
+        stream (`torch.cuda.Stream`, *optional*):
+            If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
+            for overlapping computation and data transfer.
     """
 
+    cpu_param_dict = None
+    if stream is not None:
+        for param in module.parameters():
+            param.data = param.data.cpu().pin_memory()
+        cpu_param_dict = {param: param.data for param in module.parameters()}
+
     per_group_modules = [[] for _ in range(len(offload_group_patterns))]
     per_group_offload_leaders = [None] * len(offload_group_patterns)
     per_group_onload_leaders = [None] * len(offload_group_patterns)
@@ -280,20 +310,20 @@ def _apply_group_offloading_group_patterns(
     offload_leader_patterns = [pattern[1] for pattern in offload_group_patterns]
     onload_leader_patterns = [pattern[2] for pattern in offload_group_patterns]
 
-    for name, module in module.named_modules():
-        if name.count(".") > 1:
+    for name, submodule in module.named_modules():
+        if name == "" or name.count(".") > 1:
             # We only want the layers that are top-level in the module (encompass all the other submodules)
             # for enabling offloading. This method is specifically targeted for diffusers format models,
             # so we can ignore submodules.
             # TODO(aryan): This is not the case and is just a workaround to make the benchmark code work
             # for now. We need to support the arbitrary nesting of modules here.
             continue
 
-        num_matches = 0
         # Check if the module matches any of the offload group patterns
+        num_matches = 0
         for i, pattern in enumerate(group_patterns):
             if re.search(pattern, name) is not None:
-                per_group_modules[i].append(module)
+                per_group_modules[i].append(submodule)
                 num_matches += 1
 
         # Check if the module matches any of the offload leader patterns
@@ -303,7 +333,7 @@ def _apply_group_offloading_group_patterns(
                     raise ValueError(
                         f"Module {name} matches multiple offload leader patterns. Please ensure that offload leader patterns are mutually exclusive."
                    )
-            per_group_offload_leaders[i] = module
+            per_group_offload_leaders[i] = submodule
 
         # Check if the module matches any of the onload leader patterns
         for i, pattern in enumerate(onload_leader_patterns):
@@ -314,16 +344,17 @@ def _apply_group_offloading_group_patterns(
                     raise ValueError(
                         f"Module {name} matches multiple onload leader patterns. Please ensure that onload leader patterns are mutually exclusive."
                     )
-            per_group_onload_leaders[i] = module
+            per_group_onload_leaders[i] = submodule
 
         if num_matches == 0:
-            unmatched_group_modules.append(module)
+            unmatched_group_modules.append((name, submodule))
         elif num_matches > 1:
             raise ValueError(
                 f"Module {name} matches multiple offload group patterns. Please ensure that offloading group patterns are mutually exclusive."
             )
 
     # Handle modules that matched patterns
+    groups = []
     for i in range(len(per_group_modules)):
         if per_group_offload_leaders[i] is None:
             raise ValueError(
@@ -336,21 +367,40 @@ def _apply_group_offloading_group_patterns(
             offload_leader=per_group_offload_leaders[i],
             onload_leader=per_group_onload_leaders[i],
         )
-        _apply_group_offloading(group, force_offload, non_blocking)
-
-    # Handle modules that did not match patterns
-    for module in unmatched_group_modules:
-        group = ModuleGroup([module], offload_device, onload_device, offload_leader=module, onload_leader=module)
-        _apply_group_offloading(group, force_offload, non_blocking)
-
-    # TODO(aryan): When you add stream support, this may need to be put in an if-branch
-    # Always keep parameters and buffers on onload_device
-    for name, param in module.named_parameters(recurse=False):
-        if torch.is_tensor(param.data):
-            param.data = param.data.to(onload_device)
+        groups.append(group)
+
+    for i in range(len(groups)):
+        next_group = groups[i + 1] if i + 1 < len(groups) and stream is not None else None
+        should_offload = force_offload or i > 0
+        _apply_group_offloading(
+            groups[i], should_offload, non_blocking, stream, next_group, cpu_param_dict, onload_self=False
+        )
+
+    # Ignore parameters/buffers if they're already accounted for in unmatched_group_modules (for example, an nn.Linear
+    # in the top-level module will also be present in the named_parameters iterator)
+    parameters = []
+    for name, parameter in module.named_parameters(recurse=False):
+        if not any(name.startswith(unmatched_name) for unmatched_name, _ in unmatched_group_modules):
+            parameters.append(parameter)
+
+    buffers = []
     for name, buffer in module.named_buffers(recurse=False):
-        if torch.is_tensor(buffer.data):
-            buffer.data = buffer.data.to(onload_device)
+        if not any(name.startswith(unmatched_name) for unmatched_name, _ in unmatched_group_modules):
+            buffers.append(buffer)
+
+    unmatched_modules = [module for _, module in unmatched_group_modules]
+    unmatched_group = ModuleGroup(
+        unmatched_modules,
+        offload_device,
+        onload_device,
+        offload_leader=module,
+        onload_leader=None,
+        parameters=parameters,
+        buffers=buffers,
+    )
+    _apply_group_offloading(
+        unmatched_group, force_offload, non_blocking, stream, groups[0], cpu_param_dict, onload_self=True
+    )
 
 
 def _apply_group_offloading(
@@ -360,9 +410,12 @@ def _apply_group_offloading(
     stream: Optional[torch.cuda.Stream] = None,
     next_group: Optional[ModuleGroup] = None,
     cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
+    onload_self: bool = False,
 ) -> None:
     for module in group.modules:
-        hook = GroupOffloadingHook(group, offload_on_init, non_blocking, stream, next_group, cpu_param_dict)
+        hook = GroupOffloadingHook(
+            group, offload_on_init, non_blocking, stream, next_group, cpu_param_dict, onload_self
+        )
         registry = HookRegistry.check_if_exists_or_initialize(module)
         registry.register_hook(hook, "group_offloading")
 
@@ -375,11 +428,11 @@ def _get_modulelist_or_sequential_group_patterns(module: torch.nn.Module, num_bl
     blocks. The generated patterns can be used to create ModuleGroup objects which are offloaded and onloaded together.
     """
     group_patterns = []
-
+
     # We only want the layers that are top-level in the module (encompass all the other submodules)
     # for enabling offloading. This method is specifically targeted for diffusers format models,
     # so we can ignore everything but the children of this module.
-    for name, submodule in module.children():
+    for name, submodule in module.named_children():
        if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
            continue
        for i in range(0, len(submodule), num_blocks_per_group):
@@ -389,6 +442,6 @@ def _get_modulelist_or_sequential_group_patterns(module: torch.nn.Module, num_bl
             onload_leader_pattern = rf"{name}\.{i}\b"
             offload_leader_pattern = rf"{name}\.{i + num_modules - 1}\b"
             group_patterns.append((pattern, offload_leader_pattern, onload_leader_pattern))
-
+
     logger.debug(f"Generated group patterns for apply_groupwise_offloading: {group_patterns}")
     return group_patterns
 
diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py
index e80ac6e88389..bef4c65c41e1 100644
--- a/src/diffusers/hooks/hooks.py
+++ b/src/diffusers/hooks/hooks.py
@@ -33,6 +33,7 @@ class ModelHook:
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         r"""
         Hook that is executed when a model is initialized.
+
         Args:
             module (`torch.nn.Module`):
                 The module attached to this hook.
@@ -42,6 +43,7 @@
     def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         r"""
         Hook that is executed when a model is deinitalized.
+
         Args:
             module (`torch.nn.Module`):
                 The module attached to this hook.
@@ -53,6 +55,7 @@ def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
     def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
         r"""
         Hook that is executed just before the forward method of the model.
+
         Args:
             module (`torch.nn.Module`):
                 The module whose forward pass will be executed just after this event.
@@ -69,6 +72,7 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[A
     def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
         r"""
         Hook that is executed just after the forward method of the model.
+
         Args:
             module (`torch.nn.Module`):
                 The module whose forward pass been executed just before this event.
@@ -82,6 +86,7 @@ def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
     def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         r"""
         Hook that is executed when the hook is detached from a module.
+
         Args:
             module (`torch.nn.Module`):
                 The module detached from this hook.
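
Rough usage sketch: the entry point, as changed here, would be driven along the lines below. The keyword names come from the visible signature (offload_group_patterns, num_blocks_per_group, cuda_stream, and so on); the toy model and the explicit device arguments are illustrative assumptions, not part of the patch.

import torch

from diffusers.hooks.group_offloading import apply_group_offloading


class ToyModel(torch.nn.Module):
    # Stands in for a diffusers-format model: the top-level blocks live in an
    # nn.ModuleList, plus a top-level stem layer that stays "unmatched".
    def __init__(self) -> None:
        super().__init__()
        self.proj_in = torch.nn.Linear(64, 64)
        self.blocks = torch.nn.ModuleList(torch.nn.Linear(64, 64) for _ in range(4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj_in(x)
        for block in self.blocks:
            x = block(x)
        return x


model = ToyModel()
apply_group_offloading(
    model,
    offload_group_patterns="modulelist_or_sequential",  # auto-derive the patterns
    num_blocks_per_group=2,  # two blocks are onloaded/offloaded as one group
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    force_offload=True,
    non_blocking=True,
    cuda_stream=True,  # enables the prefetch stream added in this patch
)
out = model(torch.randn(2, 64, device="cuda"))  # hooks move groups on demand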
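
For reference while reviewing the pattern plumbing: each entry of offload_group_patterns is a (group_pattern, offload_leader_pattern, onload_leader_pattern) tuple, consumed above as pattern[0]/pattern[1]/pattern[2]. For blocks = nn.ModuleList(...) of length 4 with num_blocks_per_group=2, _get_modulelist_or_sequential_group_patterns would emit leader patterns as below; the group regex itself is built outside the visible hunks, so its exact shape here is an assumption.

# (group_pattern, offload_leader_pattern, onload_leader_pattern)
[
    (r"blocks\.(0|1)\b", r"blocks\.1\b", r"blocks\.0\b"),  # group 0: blocks 0 and 1
    (r"blocks\.(2|3)\b", r"blocks\.3\b", r"blocks\.2\b"),  # group 1: blocks 2 and 3
]

The onload leader is the first block of a group (reaching it in the forward pass triggers onloading) and the offload leader is the last (leaving it triggers offloading), matching the rf"{name}\.{i}\b" and rf"{name}\.{i + num_modules - 1}\b" construction above.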
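
The core overlap trick, reduced to a self-contained sketch (illustrative names, not the diffusers API): while group i computes on the default stream, group i+1's weights are copied host-to-device on a side stream, and a dict of pinned CPU tensors (the patch's cpu_param_dict) turns offloading into a cheap pointer swap back to pinned memory.

import torch

device = torch.device("cuda")
blocks = torch.nn.ModuleList(torch.nn.Linear(1024, 1024) for _ in range(8))

# Pin the CPU weights once and remember them: pinned memory is what makes the
# non_blocking host-to-device copies truly asynchronous.
cpu_param_dict = {}
for p in blocks.parameters():
    p.data = p.data.cpu().pin_memory()
    cpu_param_dict[p] = p.data

copy_stream = torch.cuda.Stream()


def onload(block: torch.nn.Module) -> None:
    for p in block.parameters():
        p.data = p.data.to(device, non_blocking=True)


def offload(block: torch.nn.Module) -> None:
    for p in block.parameters():
        p.data = cpu_param_dict[p]  # swap back to the pinned CPU copy


with torch.no_grad():
    x = torch.randn(16, 1024, device=device)
    onload(blocks[0])  # the first group must load itself (cf. onload_self)
    for i, block in enumerate(blocks):
        # Weights for block i were prefetched on copy_stream; wait for them.
        torch.cuda.current_stream().wait_stream(copy_stream)
        if i + 1 < len(blocks):
            with torch.cuda.stream(copy_stream):
                onload(blocks[i + 1])  # prefetch the next group
        x = block(x)  # compute overlaps with the prefetch above
        offload(block)
    torch.cuda.synchronize()

A real implementation also has to be careful about tensors allocated on one stream and consumed on another (torch.Tensor.record_stream) and about synchronizing GPU-to-CPU copies, which is what the TODO in offload_ is asking; the sketch glosses over both.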