Group offloading with cuda stream prefetching #10516

Merged · 2 commits · Jan 11, 2025

Changes from all commits
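The new surface added by this PR is the `cuda_stream` argument to `apply_group_offloading`. Below is a minimal usage sketch, not taken from the PR itself: the toy model and its `transformer_blocks` attribute are assumptions (the attribute must be one of the `_COMMON_STACK_IDENTIFIERS` the hook scans for), the import path simply mirrors the file path in this diff, and a CUDA device is required.

```python
import torch
import torch.nn as nn

from diffusers.hooks.group_offloading import apply_group_offloading

class TinyTransformer(nn.Module):
    # Toy stand-in for a diffusers model; `transformer_blocks` is assumed to be
    # one of the stack identifiers that the "diffusers_block" pattern recognizes.
    def __init__(self):
        super().__init__()
        self.transformer_blocks = nn.ModuleList(nn.Linear(64, 64) for _ in range(8))

    def forward(self, x):
        for block in self.transformer_blocks:
            x = block(x)
        return x

model = TinyTransformer()

apply_group_offloading(
    model,
    num_blocks_per_group=2,                    # required for the "diffusers_block" pattern
    offload_group_patterns="diffusers_block",
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    force_offload=True,
    non_blocking=True,
    cuda_stream=True,   # new in this PR: prefetch the next group on a side CUDA stream
)

with torch.no_grad():
    out = model(torch.randn(1, 64, device="cuda"))
```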
94 changes: 80 additions & 14 deletions src/diffusers/hooks/group_offloading.py
@@ -13,7 +13,7 @@
# limitations under the License.

import re
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union

import torch

@@ -65,10 +65,21 @@ class GroupOffloadingHook(ModelHook):
encounter such an error.
"""

def __init__(self, group: ModuleGroup, offload_on_init: bool = True, non_blocking: bool = False) -> None:
def __init__(
self,
group: ModuleGroup,
offload_on_init: bool = True,
non_blocking: bool = False,
stream: Optional[torch.cuda.Stream] = None,
next_group: Optional[ModuleGroup] = None,
cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
) -> None:
self.group = group
self.offload_on_init = offload_on_init
self.non_blocking = non_blocking
self.stream = stream
self.next_group = next_group
self.cpu_param_dict = cpu_param_dict

def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
if self.offload_on_init:
@@ -87,16 +98,33 @@ def post_forward(self, module: torch.nn.Module, output):

def onload_(self, module: torch.nn.Module) -> None:
if self.group.onload_leader == module:
for group_module in self.group.modules:
group_module.to(self.group.onload_device, non_blocking=self.non_blocking)
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()

if self.next_group is None:
return

# Start Host->Device transfer for next group
with torch.cuda.stream(self.stream):
for group_module in self.next_group.modules:
group_module.to(self.next_group.onload_device, non_blocking=True)
else:
for group_module in self.group.modules:
group_module.to(self.group.onload_device, non_blocking=self.non_blocking)

def offload_(self, module: torch.nn.Module) -> None:
if self.group.offload_leader == module:
for group_module in self.group.modules:
group_module.to(self.group.offload_device, non_blocking=self.non_blocking)
# TODO: do we need to sync here because of GPU->CPU transfer?
if self.non_blocking and self.group.offload_device.type == "cpu":
torch.cpu.synchronize()
if self.stream is not None:
for group_module in self.group.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
else:
for group_module in self.group.modules:
group_module.to(self.group.offload_device, non_blocking=self.non_blocking)
# TODO: do we need to sync here because of GPU->CPU transfer?
if self.non_blocking and self.group.offload_device.type == "cpu":
torch.cpu.synchronize()
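When a stream is used, the offload path above never issues a device-to-host copy: the parameters are simply re-pointed at the pinned CPU tensors kept alive in `cpu_param_dict` (built in `_apply_group_offloading_diffusers_block` further down). A standalone sketch of that pattern, with an illustrative single block and a CUDA device assumed:

```python
import torch
import torch.nn as nn

block = nn.Linear(64, 64)  # illustrative stand-in for one offloading group

# Pin the weights in CPU memory once and remember the pinned tensors.
cpu_param_dict = {}
for param in block.parameters():
    param.data = param.data.cpu().pin_memory()  # pinned memory makes non_blocking H2D copies truly async
    cpu_param_dict[param] = param.data

# Onload: asynchronous copy to the GPU on a side stream.
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
    block.to("cuda", non_blocking=True)
torch.cuda.current_stream().wait_stream(stream)  # weights must have arrived before compute
y = block(torch.randn(1, 64, device="cuda"))

# Offload: no GPU->CPU copy, just point the parameters back at the pinned CPU copies.
for param in block.parameters():
    param.data = cpu_param_dict[param]
```

Note this only holds for inference: if the GPU copy were updated (e.g. during training), re-pointing `param.data` would silently drop those updates.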


def apply_group_offloading(
@@ -107,12 +135,22 @@ def apply_group_offloading(
onload_device: torch.device = torch.device("cuda"),
force_offload: bool = True,
non_blocking: bool = False,
cuda_stream: bool = False,
) -> None:
stream = None
if cuda_stream:
stream = torch.cuda.Stream()
if offload_group_patterns == "diffusers_block":
if num_blocks_per_group is None:
raise ValueError("num_blocks_per_group must be provided when using GroupOffloadingType.DIFFUSERS_BLOCK.")
_apply_group_offloading_diffusers_block(
module, num_blocks_per_group, offload_device, onload_device, force_offload, non_blocking
module,
num_blocks_per_group,
offload_device,
onload_device,
force_offload,
non_blocking,
stream,
)
else:
_apply_group_offloading_group_patterns(
@@ -127,7 +165,14 @@ def _apply_group_offloading_diffusers_block(
onload_device: torch.device,
force_offload: bool,
non_blocking: bool,
stream: Optional[torch.cuda.Stream] = None,
) -> None:
cpu_param_dict = None
if stream is not None:
for param in module.parameters():
param.data = param.data.cpu().pin_memory()
cpu_param_dict = {param: param.data for param in module.parameters()}

# Handle device offloading/onloading for unet/transformer stack modules
for stack_identifier in _COMMON_STACK_IDENTIFIERS:
if not hasattr(module, stack_identifier) or not isinstance(
@@ -137,14 +182,29 @@ def _apply_group_offloading_diffusers_block(

stack = getattr(module, stack_identifier)
num_blocks = len(stack)
module_groups = []

for i in range(0, num_blocks, num_blocks_per_group):
blocks = stack[i : i + num_blocks_per_group]
group = ModuleGroup(
blocks, offload_device, onload_device, offload_leader=blocks[-1], onload_leader=blocks[0]
)
module_groups.append(group)

for i, group in enumerate(module_groups):
next_group = module_groups[i + 1] if i + 1 < len(module_groups) and stream is not None else None
should_offload = force_offload or i > 0
_apply_group_offloading(group, should_offload, non_blocking)
_apply_group_offloading(group, should_offload, non_blocking, stream, next_group, cpu_param_dict)

if stream is not None:
# Start Host->Device transfer for the first group
with torch.cuda.stream(stream):
for group_module in module_groups[0].modules:
group_module.to(onload_device, non_blocking=True)
if len(module_groups) > 1:
# Assign the first module_group as the next_group for the last module_group
hook_registry = HookRegistry.check_if_exists_or_initialize(module_groups[-1].onload_leader)
hook_registry.hooks["group_offloading"].next_group = module_groups[0]
Comment on lines +204 to +207
@a-r-r-o-w (Member, Author) commented on Jan 10, 2025:

This is a bit hacky for the moment, just so that I could get it running quickly without putting much thought into it. Will try to improve it soon.
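For context, the wiring referred to here makes the prefetch chain circular: while group *i* runs on the default stream, group *i+1*'s host-to-device copy runs on the side stream, and the last group prefetches the first so the next forward pass (e.g. the next denoising step) starts warm. A self-contained sketch of that schedule with plain modules instead of `ModuleGroup`s (illustrative only; it assumes pinned CPU weights, as the diff sets up, and requires a CUDA device):

```python
import torch
import torch.nn as nn

def run_with_prefetch(groups, x, stream):
    # Warm-up: start the H2D copy for the first group on the side stream.
    with torch.cuda.stream(stream):
        groups[0].to("cuda", non_blocking=True)

    for i, group in enumerate(groups):
        stream.synchronize()                    # weights of group i have arrived
        nxt = groups[(i + 1) % len(groups)]     # wrap-around: the last group prefetches the first
        with torch.cuda.stream(stream):
            nxt.to("cuda", non_blocking=True)   # overlap the next copy with this group's compute
        x = group(x)                            # compute on the default stream
        # (the real hook "offloads" here by re-pointing param.data at pinned CPU copies)
    return x

groups = [nn.Sequential(nn.Linear(64, 64), nn.ReLU()) for _ in range(4)]
for g in groups:
    for p in g.parameters():
        p.data = p.data.cpu().pin_memory()      # pinned so non_blocking=True is truly asynchronous

out = run_with_prefetch(groups, torch.randn(1, 64, device="cuda"), torch.cuda.Stream())
```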


# Handle device offloading/onloading for non-stack modules
for name, submodule in module.named_modules():
Expand All @@ -154,7 +214,6 @@ def _apply_group_offloading_diffusers_block(
# for enabling offloading.
continue
layer_name = name_split[0]
print(layer_name)
if layer_name in _COMMON_STACK_IDENTIFIERS:
continue
group = ModuleGroup(
@@ -211,8 +270,15 @@ def _apply_group_offloading_group_patterns(
_apply_group_offloading(group, force_offload, non_blocking)


def _apply_group_offloading(group: ModuleGroup, offload_on_init: bool, non_blocking: bool) -> None:
def _apply_group_offloading(
group: ModuleGroup,
offload_on_init: bool,
non_blocking: bool,
stream: Optional[torch.cuda.Stream] = None,
next_group: Optional[ModuleGroup] = None,
cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
) -> None:
for module in group.modules:
hook = GroupOffloadingHook(group, offload_on_init, non_blocking)
hook = GroupOffloadingHook(group, offload_on_init, non_blocking, stream, next_group, cpu_param_dict)
registry = HookRegistry.check_if_exists_or_initialize(module)
registry.register_hook(hook, "group_offloading")
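Finally, a note on the hook wiring: `GroupOffloadingHook` is registered on every module of a group, but `onload_` only acts when the module is the group's onload leader (the first block) and `offload_` only when it is the offload leader (the last block). A rough plain-PyTorch analogue of that leader pattern, without streams and without the diffusers `HookRegistry` API (function and variable names below are illustrative; running the usage lines requires a CUDA device):

```python
import torch
import torch.nn as nn

def attach_naive_group_offloading(blocks, onload_device="cuda", offload_device="cpu"):
    onload_leader, offload_leader = blocks[0], blocks[-1]

    def onload(module, args):
        # Before the first block of the group runs, bring the whole group onto the GPU.
        for block in blocks:
            block.to(onload_device, non_blocking=True)

    def offload(module, args, output):
        # After the last block of the group has run, push the whole group back to CPU.
        for block in blocks:
            block.to(offload_device)
        return output

    onload_leader.register_forward_pre_hook(onload)
    offload_leader.register_forward_hook(offload)

blocks = [nn.Linear(32, 32) for _ in range(4)]
attach_naive_group_offloading(blocks)

x = torch.randn(1, 32, device="cuda")
for block in blocks:
    x = block(x)   # hooks onload the group before block 0 and offload it after block 3
```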