
Module Group Offloading #10503

Merged: 50 commits, Feb 14, 2025

Changes from 42 commits

Commits
d1737e3
update
a-r-r-o-w Jan 9, 2025
2783669
fix
a-r-r-o-w Jan 9, 2025
6a9a3e5
non_blocking; handle parameters and buffers
a-r-r-o-w Jan 10, 2025
c426a34
update
a-r-r-o-w Jan 10, 2025
d579037
Group offloading with cuda stream prefetching (#10516)
a-r-r-o-w Jan 11, 2025
5f33621
Merge branch 'main' into groupwise-offloading
a-r-r-o-w Jan 11, 2025
a8eabd0
update
a-r-r-o-w Jan 12, 2025
deda9a3
Merge branch 'main' into groupwise-offloading
a-r-r-o-w Jan 16, 2025
80ac5a7
copy model hook implementation from pab
a-r-r-o-w Jan 16, 2025
d2a2981
update; ~very workaround based implementation but it seems to work as…
a-r-r-o-w Jan 16, 2025
01c7d22
more workarounds to make it actually work
a-r-r-o-w Jan 16, 2025
22aff34
cleanup
a-r-r-o-w Jan 16, 2025
42bc19b
rewrite
a-r-r-o-w Jan 17, 2025
8c63bf5
update
a-r-r-o-w Jan 19, 2025
e09e716
make sure to sync current stream before overwriting with pinned params
a-r-r-o-w Jan 19, 2025
bf379c1
Merge branch 'main' into groupwise-offloading
a-r-r-o-w Jan 19, 2025
0bf0baf
better check
a-r-r-o-w Jan 19, 2025
b850c75
update
a-r-r-o-w Jan 20, 2025
6ed9c2f
remove hook implementation to not deal with merge conflict
a-r-r-o-w Jan 23, 2025
13dd337
Merge branch 'main' into groupwise-offloading
a-r-r-o-w Jan 23, 2025
073d4bc
re-add hook changes
a-r-r-o-w Jan 23, 2025
8ba2bda
why use more memory when less memory do trick
a-r-r-o-w Jan 23, 2025
b2e838f
why still use slightly more memory when less memory do trick
a-r-r-o-w Jan 23, 2025
f30c55f
Merge branch 'main' into groupwise-offloading
a-r-r-o-w Jan 23, 2025
5ea3d8a
optimise
a-r-r-o-w Jan 26, 2025
db2fd3b
add model tests
a-r-r-o-w Jan 26, 2025
a0160e1
add pipeline tests
a-r-r-o-w Jan 26, 2025
aaa9a53
update docs
a-r-r-o-w Jan 26, 2025
17b2753
Merge branch 'main' into groupwise-offloading
a-r-r-o-w Jan 26, 2025
edf8103
add layernorm and groupnorm
a-r-r-o-w Jan 26, 2025
af62c93
Merge branch 'main' into groupwise-offloading
a-r-r-o-w Jan 28, 2025
f227e15
Merge branch 'main' into groupwise-offloading
a-r-r-o-w Feb 4, 2025
24f9273
address review comments
a-r-r-o-w Feb 4, 2025
8f10d05
improve tests; add docs
a-r-r-o-w Feb 4, 2025
06b411f
improve docs
a-r-r-o-w Feb 4, 2025
8bd7e3b
Merge branch 'main' into groupwise-offloading
a-r-r-o-w Feb 4, 2025
904e470
Apply suggestions from code review
a-r-r-o-w Feb 5, 2025
3172ed5
apply suggestions from code review
a-r-r-o-w Feb 5, 2025
72aa57f
Merge branch 'main' into groupwise-offloading
a-r-r-o-w Feb 5, 2025
aee24bc
update tests
a-r-r-o-w Feb 5, 2025
db125ce
Merge branch 'main' into groupwise-offloading
a-r-r-o-w Feb 6, 2025
3f20e6b
apply suggestions from review
a-r-r-o-w Feb 6, 2025
840576a
enable_group_offloading -> enable_group_offload for naming consistency
a-r-r-o-w Feb 6, 2025
8804d74
raise errors if multiple offloading strategies used; add relevant tests
a-r-r-o-w Feb 6, 2025
954bb7d
handle .to() when group offload applied
a-r-r-o-w Feb 6, 2025
ba6c4a8
Merge branch 'main' into groupwise-offloading
a-r-r-o-w Feb 6, 2025
da88c33
refactor some repeated code
a-r-r-o-w Feb 6, 2025
a872e84
remove unintentional change from merge conflict
a-r-r-o-w Feb 6, 2025
6be43b8
handle .cuda()
a-r-r-o-w Feb 6, 2025
274b84e
Merge branch 'main' into groupwise-offloading
a-r-r-o-w Feb 11, 2025
4 changes: 4 additions & 0 deletions docs/source/en/api/utilities.md
@@ -45,3 +45,7 @@ Utility and helper functions for working with 🤗 Diffusers.
## apply_layerwise_casting

[[autodoc]] hooks.layerwise_casting.apply_layerwise_casting

## apply_group_offloading

[[autodoc]] hooks.group_offloading.apply_group_offloading
40 changes: 40 additions & 0 deletions docs/source/en/optimization/memory.md
@@ -158,6 +158,46 @@ In order to properly offload models after they're called, it is required to run

</Tip>

## Group offloading

Group offloading is the middle ground between sequential and model offloading. It works by offloading groups of internal layers (either `torch.nn.ModuleList` or `torch.nn.Sequential`), which uses less memory than model-level offloading. It is also faster than sequential-level offloading because the number of device synchronizations is reduced.

To enable group offloading, call the [`~ModelMixin.enable_group_offloading`] method on the model if it is a Diffusers model implementation. For any other model implementation, use [`~hooks.group_offloading.apply_group_offloading`]:

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video

# Load the pipeline
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# We can utilize the enable_group_offloading method for Diffusers model implementations
pipe.transformer.enable_group_offloading(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True)

# For any other model implementations, the apply_group_offloading function can be used
apply_group_offloading(pipe.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2)
apply_group_offloading(pipe.vae, onload_device=onload_device, offload_type="leaf_level")

prompt = (
"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
"atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
# This example used about 14.79 GB of GPU memory. It can be reduced further by enabling tiling and applying leaf_level offloading throughout the pipeline.
print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
export_to_video(video, "output.mp4", fps=8)
```

On CUDA devices that support asynchronous data transfer streams, group offloading overlaps data transfer with computation to reduce overall execution time compared to sequential offloading. This is achieved by prefetching layers with CUDA streams: the next layer to be executed is loaded onto the accelerator device while the current layer is still running, which slightly increases memory requirements. Group offloading also supports leaf-level offloading (equivalent to sequential CPU offloading), which becomes much faster when streams are used.
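
To illustrate the prefetching idea, here is a minimal sketch using plain PyTorch CUDA streams. It is not the hook implementation added in `src/diffusers/hooks/group_offloading.py`; the `blocks` list, sizes, and transfer pattern are made up for the example, and truly asynchronous host-to-device copies additionally require pinned CPU memory:

```python
import torch
import torch.nn as nn

# A stand-in for a model's transformer blocks, kept on the CPU.
blocks = nn.ModuleList([nn.Linear(1024, 1024) for _ in range(8)])
transfer_stream = torch.cuda.Stream()
x = torch.randn(1, 1024, device="cuda")

# Prefetch the first block before the loop starts.
with torch.cuda.stream(transfer_stream):
    blocks[0].to("cuda", non_blocking=True)

for i, block in enumerate(blocks):
    # Make sure this block's weights have finished copying before computing with them.
    torch.cuda.current_stream().wait_stream(transfer_stream)
    # Start copying the next block while the current one runs on the default stream.
    if i + 1 < len(blocks):
        with torch.cuda.stream(transfer_stream):
            blocks[i + 1].to("cuda", non_blocking=True)
    x = block(x)
    # Offload the block that just ran (the real hooks manage this more carefully).
    block.to("cpu")
```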

## FP8 layerwise weight-casting

PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting.
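
As an example, layerwise casting can be enabled on a Diffusers model through the `enable_layerwise_casting` method referenced elsewhere in this diff. This is a minimal sketch; the keyword argument names follow the signature shown in `modeling_utils.py`:

```python
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# Store the transformer weights in fp8 and upcast them on the fly during the forward pass.
pipe.transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)
```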
1 change: 1 addition & 0 deletions src/diffusers/hooks/__init__.py
@@ -2,6 +2,7 @@


if is_torch_available():
from .group_offloading import apply_group_offloading
from .hooks import HookRegistry, ModelHook
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
647 changes: 647 additions & 0 deletions src/diffusers/hooks/group_offloading.py

Large diffs are not rendered by default.
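
The 647-line `group_offloading.py` module itself is not rendered in the diff. As a rough sketch of the technique it implements (hypothetical class and method names, loosely mirroring the `ModelHook` interface imported in `hooks/__init__.py`; the actual hooks also handle pinned buffers, streams, prefetching, and parameter/buffer groups), a group offloading hook onloads a group of layers just before they run and offloads them again afterwards:

```python
import torch


class GroupOffloadingHookSketch:
    """Illustrative only: moves a group of modules between devices around execution."""

    def __init__(self, group_modules, onload_device, offload_device=torch.device("cpu"), non_blocking=False):
        self.group_modules = list(group_modules)
        self.onload_device = onload_device
        self.offload_device = offload_device
        self.non_blocking = non_blocking

    def pre_forward(self, module, *args, **kwargs):
        # Onload the whole group before the first module in it executes.
        for m in self.group_modules:
            m.to(self.onload_device, non_blocking=self.non_blocking)
        return args, kwargs

    def post_forward(self, module, output):
        # Offload the group once its output has been produced.
        for m in self.group_modules:
            m.to(self.offload_device, non_blocking=self.non_blocking)
        return output
```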

1 change: 1 addition & 0 deletions src/diffusers/hooks/hooks.py
@@ -177,6 +177,7 @@ def get_hook(self, name: str) -> Optional[ModelHook]:
return self.hooks.get(name, None)

def remove_hook(self, name: str, recurse: bool = True) -> None:
num_hooks = len(self._hook_order)
if name in self.hooks.keys():
num_hooks = len(self._hook_order)
hook = self.hooks[name]
1 change: 1 addition & 0 deletions src/diffusers/models/autoencoders/autoencoder_oobleck.py
@@ -317,6 +317,7 @@ class AutoencoderOobleck(ModelMixin, ConfigMixin):
"""

_supports_gradient_checkpointing = False
_supports_group_offloading = False

@register_to_config
def __init__(
2 changes: 2 additions & 0 deletions src/diffusers/models/autoencoders/consistency_decoder_vae.py
@@ -68,6 +68,8 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
```
"""

_supports_group_offloading = False

@register_to_config
def __init__(
self,
1 change: 1 addition & 0 deletions src/diffusers/models/autoencoders/vq_model.py
@@ -72,6 +72,7 @@ class VQModel(ModelMixin, ConfigMixin):
"""

_skip_layerwise_casting_patterns = ["quantize"]
_supports_group_offloading = False

@register_to_config
def __init__(
61 changes: 60 additions & 1 deletion src/diffusers/models/modeling_utils.py
@@ -33,7 +33,7 @@
from torch import Tensor, nn

from .. import __version__
from ..hooks import apply_layerwise_casting
from ..hooks import apply_group_offloading, apply_layerwise_casting
from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
@@ -87,6 +87,15 @@

def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
try:
if hasattr(parameter, "_diffusers_hook"):
for submodule in parameter.modules():
if not hasattr(submodule, "_diffusers_hook"):
continue
registry = submodule._diffusers_hook
hook = registry.get_hook("group_offloading")
if hook is not None:
return hook.group.onload_device

parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
return next(parameters_and_buffers).device
except StopIteration:
@@ -165,6 +174,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
_no_split_modules = None
_keep_in_fp32_modules = None
_skip_layerwise_casting_patterns = None
_supports_group_offloading = True

def __init__(self):
super().__init__()
@@ -436,6 +446,55 @@ def enable_layerwise_casting(
self, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking
)

def enable_group_offloading(
self,
onload_device: torch.device,
offload_device: torch.device = torch.device("cpu"),
offload_type: str = "block_level",
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
) -> None:
r"""
Activates group offloading for the current model.

See [`~hooks.group_offloading.apply_group_offloading`] for more information.

Example:

```python
>>> import torch
>>> from diffusers import CogVideoXTransformer3DModel

>>> transformer = CogVideoXTransformer3DModel.from_pretrained(
... "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
... )

>>> transformer.enable_group_offloading(
... onload_device=torch.device("cuda"),
... offload_device=torch.device("cpu"),
... offload_type="leaf_level",
... use_stream=True,
... )
```
"""
if getattr(self, "enable_tiling", None) is not None and getattr(self, "use_tiling", False) and use_stream:
msg = (
"Applying group offloading on autoencoders, with CUDA streams, may not work as expected if the first "
"forward pass is executed with tiling enabled. Please make sure to either:\n"
"1. Run a forward pass with small input shapes.\n"
"2. Or, run a forward pass with tiling disabled (can still use small dummy inputs)."
)
logger.warning(msg)
if not self._supports_group_offloading:
raise ValueError(
f"{self.__class__.__name__} does not support group offloading. Please make sure to set the boolean attribute "
f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please "
f"open an issue at https://github.com/huggingface/diffusers/issues."
)
apply_group_offloading(
self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream
)

def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/dit_transformer_2d.py
@@ -66,6 +66,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):

_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_supports_gradient_checkpointing = True
_supports_group_offloading = False

@register_to_config
def __init__(
@@ -245,6 +245,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
"""

_skip_layerwise_casting_patterns = ["pos_embed", "norm", "pooler"]
_supports_group_offloading = False

@register_to_config
def __init__(
14 changes: 14 additions & 0 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -1023,6 +1023,20 @@ def _execution_device(self):
[`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
Accelerate's module hooks.
"""
# When applying group offloading at the leaf_level, we're in the same situation as accelerate's sequential
# offloading. We need to return the onload device of the group offloading hooks so that the intermediates
# required for computation (latents, prompt embeddings, etc.) can be created on the correct device.
for name, model in self.components.items():
if not isinstance(model, torch.nn.Module):
continue
for submodule in model.modules():
if not hasattr(submodule, "_diffusers_hook"):
continue
registry = submodule._diffusers_hook
hook = registry.get_hook("group_offloading")
if hook is not None:
return hook.group.onload_device

for name, model in self.components.items():
if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload:
continue
155 changes: 155 additions & 0 deletions tests/hooks/test_group_offloading.py
@@ -0,0 +1,155 @@
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import torch

from diffusers.models import ModelMixin
from diffusers.utils.testing_utils import require_torch_gpu, torch_device


class DummyBlock(torch.nn.Module):
def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
super().__init__()

self.proj_in = torch.nn.Linear(in_features, hidden_features)
self.activation = torch.nn.ReLU()
self.proj_out = torch.nn.Linear(hidden_features, out_features)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj_in(x)
x = self.activation(x)
x = self.proj_out(x)
return x


class DummyModel(ModelMixin):
def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
super().__init__()

self.linear_1 = torch.nn.Linear(in_features, hidden_features)
self.activation = torch.nn.ReLU()
self.blocks = torch.nn.ModuleList(
[DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
)
self.linear_2 = torch.nn.Linear(hidden_features, out_features)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear_1(x)
x = self.activation(x)
for block in self.blocks:
x = block(x)
x = self.linear_2(x)
return x


@require_torch_gpu
class GroupOffloadTests(unittest.TestCase):
in_features = 64
hidden_features = 256
out_features = 64
num_layers = 4

def setUp(self):
with torch.no_grad():
self.model = self.get_model()
self.input = torch.randn((4, self.in_features)).to(torch_device)

def tearDown(self):
super().tearDown()

del self.model
del self.input
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

def get_model(self):
torch.manual_seed(0)
return DummyModel(
in_features=self.in_features,
hidden_features=self.hidden_features,
out_features=self.out_features,
num_layers=self.num_layers,
)

def test_offloading_forward_pass(self):
@torch.no_grad()
def run_forward(model):
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
self.assertTrue(
all(
module._diffusers_hook.get_hook("group_offloading") is not None
for module in model.modules()
if hasattr(module, "_diffusers_hook")
)
)
model.eval()
output = model(self.input)[0].cpu()
max_memory_allocated = torch.cuda.max_memory_allocated()
return output, max_memory_allocated

self.model.to(torch_device)
output_without_group_offloading, mem_baseline = run_forward(self.model)
self.model.to("cpu")

model = self.get_model()
model.enable_group_offloading(torch_device, offload_type="block_level", num_blocks_per_group=3)
output_with_group_offloading1, mem1 = run_forward(model)

model = self.get_model()
model.enable_group_offloading(torch_device, offload_type="block_level", num_blocks_per_group=1)
output_with_group_offloading2, mem2 = run_forward(model)

model = self.get_model()
model.enable_group_offloading(
torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True
)
output_with_group_offloading3, mem3 = run_forward(model)

model = self.get_model()
model.enable_group_offloading(torch_device, offload_type="leaf_level")
output_with_group_offloading4, mem4 = run_forward(model)

model = self.get_model()
model.enable_group_offloading(torch_device, offload_type="leaf_level", use_stream=True)
output_with_group_offloading5, mem5 = run_forward(model)

# Precision assertions - offloading should not impact the output
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading5, atol=1e-5))

# Memory assertions - offloading should reduce memory usage
self.assertTrue(mem4 <= mem5 < mem2 < mem3 < mem1 < mem_baseline)

def test_error_raised_if_streams_used_and_no_cuda_device(self):
original_is_available = torch.cuda.is_available
torch.cuda.is_available = lambda: False
with self.assertRaises(ValueError):
self.model.enable_group_offloading(
onload_device=torch.device("cuda"), offload_type="leaf_level", use_stream=True
)
torch.cuda.is_available = original_is_available

def test_error_raised_if_supports_group_offloading_false(self):
self.model._supports_group_offloading = False
with self.assertRaisesRegex(ValueError, "does not support group offloading"):
self.model.enable_group_offloading(onload_device=torch.device("cuda"))