
Commit

update; ~very workaround based implementation but it seems to work as expected; needs cleanup and rewrite
a-r-r-o-w committed Jan 16, 2025
1 parent 80ac5a7 commit d2a2981
Showing 2 changed files with 89 additions and 31 deletions.
115 changes: 84 additions & 31 deletions src/diffusers/hooks/group_offloading.py
@@ -21,7 +21,7 @@
from .hooks import HookRegistry, ModelHook


logger = get_logger(__name__) # pylint: disable=invalid-name
logger = get_logger(__name__) # pylint: disable=invalid-name


class ModuleGroup:
@@ -32,12 +32,16 @@ def __init__(
onload_device: torch.device,
offload_leader: torch.nn.Module,
onload_leader: Optional[torch.nn.Module] = None,
parameters: Optional[List[torch.nn.Parameter]] = None,
buffers: Optional[List[torch.Tensor]] = None,
) -> None:
self.modules = modules
self.offload_device = offload_device
self.onload_device = onload_device
self.offload_leader = offload_leader
self.onload_leader = onload_leader
self.parameters = parameters
self.buffers = buffers
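A minimal construction sketch (editor's illustration; the modules, devices, and keyword usage below are assumptions, not code from this commit):

import torch

blocks = [torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)]
group = ModuleGroup(
    modules=blocks,
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    offload_leader=blocks[-1],  # the module whose forward triggers offloading the whole group
    onload_leader=blocks[0],    # the module whose forward triggers onloading the whole group
)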


class GroupOffloadingHook(ModelHook):
@@ -64,13 +68,15 @@ def __init__(
stream: Optional[torch.cuda.Stream] = None,
next_group: Optional[ModuleGroup] = None,
cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
onload_self: bool = False,
) -> None:
self.group = group
self.offload_on_init = offload_on_init
self.non_blocking = non_blocking
self.stream = stream
self.next_group = next_group
self.cpu_param_dict = cpu_param_dict
self.onload_self = onload_self

def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
if self.offload_on_init:
@@ -100,9 +106,16 @@ def onload_(self, module: torch.nn.Module) -> None:
with torch.cuda.stream(self.stream):
for group_module in self.next_group.modules:
group_module.to(self.next_group.onload_device, non_blocking=True)
else:

if self.stream is None or self.onload_self:
for group_module in self.group.modules:
group_module.to(self.group.onload_device, non_blocking=self.non_blocking)
if self.group.parameters is not None:
for param in self.group.parameters:
param.data = param.data.to(self.group.onload_device, non_blocking=self.non_blocking)
if self.group.buffers is not None:
for buffer in self.group.buffers:
buffer.data = buffer.data.to(self.group.onload_device, non_blocking=self.non_blocking)
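In isolation, the side-stream prefetch used above follows this general CUDA pattern (editor's sketch; `next_block` is a stand-in name and CUDA availability is assumed):

import torch

next_block = torch.nn.Linear(8, 8)
transfer_stream = torch.cuda.Stream()

# Copy the next group's weights on a separate stream so the transfer can
# overlap with compute running on the default stream.
with torch.cuda.stream(transfer_stream):
    next_block.to("cuda", non_blocking=True)

# The default stream must wait for the copy before next_block is used.
torch.cuda.current_stream().wait_stream(transfer_stream)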

def offload_(self, module: torch.nn.Module) -> None:
if self.group.offload_leader == module:
@@ -113,6 +126,13 @@ def offload_(self, module: torch.nn.Module) -> None:
else:
for group_module in self.group.modules:
group_module.to(self.group.offload_device, non_blocking=self.non_blocking)
if self.group.parameters is not None:
for param in self.group.parameters:
param.data = param.data.to(self.group.offload_device, non_blocking=self.non_blocking)
if self.group.buffers is not None:
for buffer in self.group.buffers:
buffer.data = buffer.data.to(self.group.offload_device, non_blocking=self.non_blocking)

# TODO: do we need to sync here because of GPU->CPU transfer?
if self.non_blocking and self.group.offload_device.type == "cpu":
torch.cpu.synchronize()
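On the TODO above: a non_blocking device-to-host copy may still be in flight when the CPU reads the tensor, so some synchronization is required before the data is touched. One safe pattern (editor's sketch, assuming CUDA):

import torch

if torch.cuda.is_available():
    x = torch.randn(4, device="cuda")
    y = x.to("cpu", non_blocking=True)
    torch.cuda.synchronize()  # guarantee the device-to-host copy has finished
    print(y)  # now safe to read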
@@ -128,9 +148,9 @@ def apply_group_offloading(
non_blocking: bool = False,
cuda_stream: bool = False,
) -> None:
# stream = None
# if cuda_stream:
# stream = torch.cuda.Stream()
stream = None
if cuda_stream:
stream = torch.cuda.Stream()
if offload_group_patterns == "modulelist_or_sequential":
if num_blocks_per_group is None:
raise ValueError(
@@ -148,7 +168,7 @@
offload_group_patterns = _get_modulelist_or_sequential_group_patterns(module, num_blocks_per_group)

_apply_group_offloading_group_patterns(
module, offload_group_patterns, offload_device, onload_device, force_offload, non_blocking
module, offload_group_patterns, offload_device, onload_device, force_offload, non_blocking, stream=stream
)
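For orientation, a hypothetical end-to-end call (editor's sketch; only parameters visible in this diff are used, and their defaults are assumed):

import torch

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList(torch.nn.Linear(8, 8) for _ in range(4))

model = ToyModel()
apply_group_offloading(
    model,
    offload_group_patterns="modulelist_or_sequential",
    num_blocks_per_group=2,
    cuda_stream=True,  # opts in to the pinned-memory + side-stream path added in this commit
)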


@@ -231,6 +251,7 @@ def _apply_group_offloading_group_patterns(
onload_device: torch.device,
force_offload: bool,
non_blocking: bool,
stream: Optional[torch.cuda.Stream] = None,
) -> None:
r"""
This function applies offloading to groups of modules based on the provided regex patterns. Each group of modules
@@ -269,8 +290,17 @@ def _apply_group_offloading_group_patterns(
non_blocking (`bool`):
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
and data transfer.
stream (`torch.cuda.Stream`, *optional*):
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
for overlapping computation and data transfer.
"""

cpu_param_dict = None
if stream is not None:
for param in module.parameters():
param.data = param.data.cpu().pin_memory()
cpu_param_dict = {param: param.data for param in module.parameters()}
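# Editor's note (not part of this commit): pinning matters because only
# page-locked CPU memory allows non_blocking copies to run truly
# asynchronously, e.g.:
#     t = torch.randn(1024, 1024).pin_memory()
#     t_gpu = t.to("cuda", non_blocking=True)  # may overlap with other work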

per_group_modules = [[] for _ in range(len(offload_group_patterns))]
per_group_offload_leaders = [None] * len(offload_group_patterns)
per_group_onload_leaders = [None] * len(offload_group_patterns)
@@ -280,20 +310,20 @@ def _apply_group_offloading_group_patterns(
offload_leader_patterns = [pattern[1] for pattern in offload_group_patterns]
onload_leader_patterns = [pattern[2] for pattern in offload_group_patterns]

for name, module in module.named_modules():
if name.count(".") > 1:
for name, submodule in module.named_modules():
if name == "" or name.count(".") > 1:
# We only want the layers that are top-level in the module (encompass all the other submodules)
# for enabling offloading. This method is specifically targeted for diffusers format models,
# so we can ignore submodules.
# TODO(aryan): This is not the case and is just a workaround to make the benchmark code work
# for now. We need to support the arbitrary nesting of modules here.
continue
num_matches = 0

# Check if the module matches any of the offload group patterns
num_matches = 0
for i, pattern in enumerate(group_patterns):
if re.search(pattern, name) is not None:
per_group_modules[i].append(module)
per_group_modules[i].append(submodule)
num_matches += 1

# Check if the module matches any of the offload leader patterns
@@ -303,7 +333,7 @@ def _apply_group_offloading_group_patterns(
raise ValueError(
f"Module {name} matches multiple offload leader patterns. Please ensure that offload leader patterns are mutually exclusive."
)
per_group_offload_leaders[i] = module
per_group_offload_leaders[i] = submodule

# Check if the module matches any of the onload leader patterns
for i, pattern in enumerate(onload_leader_patterns):
@@ -314,16 +344,17 @@ def _apply_group_offloading_group_patterns(
raise ValueError(
f"Module {name} matches multiple onload leader patterns. Please ensure that onload leader patterns are mutually exclusive."
)
per_group_onload_leaders[i] = module
per_group_onload_leaders[i] = submodule

if num_matches == 0:
unmatched_group_modules.append(module)
unmatched_group_modules.append((name, submodule))
elif num_matches > 1:
raise ValueError(
f"Module {name} matches multiple offload group patterns. Please ensure that offloading group patterns are mutually exclusive."
)

# Handle modules that matched patterns
groups = []
for i in range(len(per_group_modules)):
if per_group_offload_leaders[i] is None:
raise ValueError(
@@ -336,21 +367,40 @@ def _apply_group_offloading_group_patterns(
offload_leader=per_group_offload_leaders[i],
onload_leader=per_group_onload_leaders[i],
)
_apply_group_offloading(group, force_offload, non_blocking)

# Handle modules that did not match patterns
for module in unmatched_group_modules:
group = ModuleGroup([module], offload_device, onload_device, offload_leader=module, onload_leader=module)
_apply_group_offloading(group, force_offload, non_blocking)

# TODO(aryan): When you add stream support, this may need to be put in an if-branch
# Always keep parameters and buffers on onload_device
for name, param in module.named_parameters(recurse=False):
if torch.is_tensor(param.data):
param.data = param.data.to(onload_device)
groups.append(group)

for i in range(len(groups)):
next_group = groups[i + 1] if i + 1 < len(groups) and stream is not None else None
should_offload = force_offload or i > 0
_apply_group_offloading(
groups[i], should_offload, non_blocking, stream, next_group, cpu_param_dict, onload_self=False
)

# Ignore parameters/buffers if they're already accounted for in unmatched_group_modules (for example, a nn.Linear
# in the top-level module will also be present in the named_parameters iterator)
parameters = []
for name, parameter in module.named_parameters(recurse=False):
if not any(name.startswith(unmatched_name) for unmatched_name, _ in unmatched_group_modules):
parameters.append(parameter)

buffers = []
for name, buffer in module.named_buffers(recurse=False):
if torch.is_tensor(buffer.data):
buffer.data = buffer.data.to(onload_device)
if not any(name.startswith(unmatched_name) for unmatched_name, _ in unmatched_group_modules):
buffers.append(buffer)

unmatched_modules = [module for _, module in unmatched_group_modules]
unmatched_group = ModuleGroup(
unmatched_modules,
offload_device,
onload_device,
offload_leader=module,
onload_leader=None,
parameters=parameters,
buffers=buffers,
)
_apply_group_offloading(
unmatched_group, force_offload, non_blocking, stream, groups[0], cpu_param_dict, onload_self=True
)
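# Editor's note on the resulting schedule (as this editor reads the code, not
# an authoritative description): with a stream, each matched group's hook
# prefetches the weights of the *next* group while the current group computes,
# and the top-level unmatched group both onloads itself (onload_self=True) and
# prefetches groups[0], closing the loop for the next forward pass.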


def _apply_group_offloading(
@@ -360,9 +410,12 @@ def _apply_group_offloading(
stream: Optional[torch.cuda.Stream] = None,
next_group: Optional[ModuleGroup] = None,
cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
onload_self: bool = False,
) -> None:
for module in group.modules:
hook = GroupOffloadingHook(group, offload_on_init, non_blocking, stream, next_group, cpu_param_dict)
hook = GroupOffloadingHook(
group, offload_on_init, non_blocking, stream, next_group, cpu_param_dict, onload_self
)
registry = HookRegistry.check_if_exists_or_initialize(module)
registry.register_hook(hook, "group_offloading")

@@ -375,11 +428,11 @@ def _get_modulelist_or_sequential_group_patterns(module: torch.nn.Module, num_bl
blocks. The generated patterns can be used to create ModuleGroup objects which are offloaded and onloaded together.
"""
group_patterns = []

# We only want the layers that are top-level in the module (encompass all the other submodules)
# for enabling offloading. This method is specifically targeted for diffusers format models,
# so we can ignore everything but the children of this module.
for name, submodule in module.children():
for name, submodule in module.named_children():
if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
continue
for i in range(0, len(submodule), num_blocks_per_group):
@@ -389,6 +442,6 @@ def _get_modulelist_or_sequential_group_patterns(module: torch.nn.Module, num_bl
onload_leader_pattern = rf"{name}\.{i}\b"
offload_leader_pattern = rf"{name}\.{i + num_modules - 1}\b"
group_patterns.append((pattern, offload_leader_pattern, onload_leader_pattern))

logger.debug(f"Generated group patterns for apply_groupwise_offloading: {group_patterns}")
return group_patterns
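# Worked example (editor's sketch): for a child `blocks = nn.ModuleList(...)` with
# 4 entries and num_blocks_per_group=2, the appended tuples have the shape
# (group_pattern, offload_leader_pattern, onload_leader_pattern), with leaders:
#   group 0: offload leader r"blocks\.1\b", onload leader r"blocks\.0\b"
#   group 1: offload leader r"blocks\.3\b", onload leader r"blocks\.2\b"
# (the group_pattern itself is built in lines collapsed from this view).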
5 changes: 5 additions & 0 deletions src/diffusers/hooks/hooks.py
@@ -33,6 +33,7 @@ class ModelHook:
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
r"""
Hook that is executed when a model is initialized.
Args:
module (`torch.nn.Module`):
The module attached to this hook.
@@ -42,6 +43,7 @@ def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
r"""
Hook that is executed when a model is deinitalized.
Args:
module (`torch.nn.Module`):
The module attached to this hook.
@@ -53,6 +55,7 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
r"""
Hook that is executed just before the forward method of the model.
Args:
module (`torch.nn.Module`):
The module whose forward pass will be executed just after this event.
@@ -69,6 +72,7 @@ def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
r"""
Hook that is executed just after the forward method of the model.
Args:
module (`torch.nn.Module`):
The module whose forward pass has been executed just before this event.
@@ -82,6 +86,7 @@ def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
r"""
Hook that is executed when the hook is detached from a module.
Args:
module (`torch.nn.Module`):
The module detached from this hook.
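To make the lifecycle concrete, a minimal custom hook built on this interface (editor's illustration, not part of this commit):

import torch
from diffusers.hooks.hooks import HookRegistry, ModelHook

class LoggingHook(ModelHook):
    def pre_forward(self, module, *args, **kwargs):
        print(f"entering {module.__class__.__name__}")
        return args, kwargs

    def post_forward(self, module, output):
        print(f"leaving {module.__class__.__name__}")
        return output

linear = torch.nn.Linear(4, 4)
registry = HookRegistry.check_if_exists_or_initialize(linear)
registry.register_hook(LoggingHook(), "logging")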
