From be983082f8e83c4bf470841029d33506a1e23560 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 10 Jan 2025 06:29:20 +0100 Subject: [PATCH 1/2] cuda stream prefetch --- src/diffusers/hooks/group_offloading.py | 95 ++++++++++++++++--- .../models/transformers/transformer_ltx.py | 1 + 2 files changed, 82 insertions(+), 14 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 8eda18053eb9..98ceff2f1a58 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -13,7 +13,7 @@ # limitations under the License. import re -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union import torch @@ -65,10 +65,21 @@ class GroupOffloadingHook(ModelHook): encounter such an error. """ - def __init__(self, group: ModuleGroup, offload_on_init: bool = True, non_blocking: bool = False) -> None: + def __init__( + self, + group: ModuleGroup, + offload_on_init: bool = True, + non_blocking: bool = False, + stream: Optional[torch.cuda.Stream] = None, + next_group: Optional[ModuleGroup] = None, + cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None, + ) -> None: self.group = group self.offload_on_init = offload_on_init self.non_blocking = non_blocking + self.stream = stream + self.next_group = next_group + self.cpu_param_dict = cpu_param_dict def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: if self.offload_on_init: @@ -87,16 +98,34 @@ def post_forward(self, module: torch.nn.Module, output): def onload_(self, module: torch.nn.Module) -> None: if self.group.onload_leader == module: - for group_module in self.group.modules: - group_module.to(self.group.onload_device, non_blocking=self.non_blocking) + breakpoint() + if self.stream is not None: + # Wait for previous Host->Device transfer to complete + self.stream.synchronize() + + if self.next_group is None: + return + + # Start Host->Device transfer for next group + with torch.cuda.stream(self.stream): + for group_module in self.next_group.modules: + group_module.to(self.next_group.onload_device, non_blocking=True) + else: + for group_module in self.group.modules: + group_module.to(self.group.onload_device, non_blocking=self.non_blocking) def offload_(self, module: torch.nn.Module) -> None: if self.group.offload_leader == module: - for group_module in self.group.modules: - group_module.to(self.group.offload_device, non_blocking=self.non_blocking) - # TODO: do we need to sync here because of GPU->CPU transfer? - if self.non_blocking and self.group.offload_device.type == "cpu": - torch.cpu.synchronize() + if self.stream is not None: + for group_module in self.group.modules: + for param in group_module.parameters(): + param.data = self.cpu_param_dict[param] + else: + for group_module in self.group.modules: + group_module.to(self.group.offload_device, non_blocking=self.non_blocking) + # TODO: do we need to sync here because of GPU->CPU transfer? + if self.non_blocking and self.group.offload_device.type == "cpu": + torch.cpu.synchronize() def apply_group_offloading( @@ -107,12 +136,22 @@ def apply_group_offloading( onload_device: torch.device = torch.device("cuda"), force_offload: bool = True, non_blocking: bool = False, + cuda_stream: bool = False, ) -> None: + stream = None + if cuda_stream: + stream = torch.cuda.Stream() if offload_group_patterns == "diffusers_block": if num_blocks_per_group is None: raise ValueError("num_blocks_per_group must be provided when using GroupOffloadingType.DIFFUSERS_BLOCK.") _apply_group_offloading_diffusers_block( - module, num_blocks_per_group, offload_device, onload_device, force_offload, non_blocking + module, + num_blocks_per_group, + offload_device, + onload_device, + force_offload, + non_blocking, + stream, ) else: _apply_group_offloading_group_patterns( @@ -127,7 +166,14 @@ def _apply_group_offloading_diffusers_block( onload_device: torch.device, force_offload: bool, non_blocking: bool, + stream: Optional[torch.cuda.Stream] = None, ) -> None: + cpu_param_dict = None + if stream is not None: + for param in module.parameters(): + param.data = param.data.cpu().pin_memory() + cpu_param_dict = {param: param.data for param in module.parameters()} + # Handle device offloading/onloading for unet/transformer stack modules for stack_identifier in _COMMON_STACK_IDENTIFIERS: if not hasattr(module, stack_identifier) or not isinstance( @@ -137,14 +183,29 @@ def _apply_group_offloading_diffusers_block( stack = getattr(module, stack_identifier) num_blocks = len(stack) + module_groups = [] for i in range(0, num_blocks, num_blocks_per_group): blocks = stack[i : i + num_blocks_per_group] group = ModuleGroup( blocks, offload_device, onload_device, offload_leader=blocks[-1], onload_leader=blocks[0] ) + module_groups.append(group) + + for i, group in enumerate(module_groups): + next_group = module_groups[i + 1] if i + 1 < len(module_groups) and stream is not None else None should_offload = force_offload or i > 0 - _apply_group_offloading(group, should_offload, non_blocking) + _apply_group_offloading(group, should_offload, non_blocking, stream, next_group, cpu_param_dict) + + if stream is not None: + # Start Host->Device transfer for the first group + with torch.cuda.stream(stream): + for group_module in module_groups[0].modules: + group_module.to(onload_device, non_blocking=True) + if len(module_groups) > 1: + # Assign the first module_group as the next_group for the last module_group + hook_registry = HookRegistry.check_if_exists_or_initialize(module_groups[-1].onload_leader) + hook_registry.hooks["group_offloading"].next_group = module_groups[0] # Handle device offloading/onloading for non-stack modules for name, submodule in module.named_modules(): @@ -154,7 +215,6 @@ def _apply_group_offloading_diffusers_block( # for enabling offloading. continue layer_name = name_split[0] - print(layer_name) if layer_name in _COMMON_STACK_IDENTIFIERS: continue group = ModuleGroup( @@ -211,8 +271,15 @@ def _apply_group_offloading_group_patterns( _apply_group_offloading(group, force_offload, non_blocking) -def _apply_group_offloading(group: ModuleGroup, offload_on_init: bool, non_blocking: bool) -> None: +def _apply_group_offloading( + group: ModuleGroup, + offload_on_init: bool, + non_blocking: bool, + stream: Optional[torch.cuda.Stream] = None, + next_group: Optional[ModuleGroup] = None, + cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None, +) -> None: for module in group.modules: - hook = GroupOffloadingHook(group, offload_on_init, non_blocking) + hook = GroupOffloadingHook(group, offload_on_init, non_blocking, stream, next_group, cpu_param_dict) registry = HookRegistry.check_if_exists_or_initialize(module) registry.register_hook(hook, "group_offloading") diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index a895340bd124..4229c97d01ee 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -240,6 +240,7 @@ def forward( norm_hidden_states = self.norm1(hidden_states) num_ada_params = self.scale_shift_table.shape[0] + breakpoint() ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2) norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa From 193b98c5260b924466de071dcc9174b4eb80d5e6 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 10 Jan 2025 06:29:47 +0100 Subject: [PATCH 2/2] remove breakpoints --- src/diffusers/hooks/group_offloading.py | 1 - src/diffusers/models/transformers/transformer_ltx.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 98ceff2f1a58..e2f0c73f1d0c 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -98,7 +98,6 @@ def post_forward(self, module: torch.nn.Module, output): def onload_(self, module: torch.nn.Module) -> None: if self.group.onload_leader == module: - breakpoint() if self.stream is not None: # Wait for previous Host->Device transfer to complete self.stream.synchronize() diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 4229c97d01ee..a895340bd124 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -240,7 +240,6 @@ def forward( norm_hidden_states = self.norm1(hidden_states) num_ada_params = self.scale_shift_table.shape[0] - breakpoint() ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2) norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa