[core] Pyramid Attention Broadcast #9562

Merged · 80 commits · Jan 27, 2025

Commits
67c729d
start pyramid attention broadcast
a-r-r-o-w Oct 1, 2024
6d3bdb5
add coauthor
a-r-r-o-w Oct 3, 2024
3737101
update
a-r-r-o-w Oct 3, 2024
d5c738d
make style
a-r-r-o-w Oct 3, 2024
1c97e04
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Oct 3, 2024
ae4abb1
update
a-r-r-o-w Oct 3, 2024
955e4f7
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Oct 3, 2024
9f6987f
make style
a-r-r-o-w Oct 3, 2024
b3547c6
add docs
a-r-r-o-w Oct 4, 2024
afd0c17
add tests
a-r-r-o-w Oct 4, 2024
6265b65
update
a-r-r-o-w Oct 5, 2024
6816fe1
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Oct 15, 2024
9cb4e87
Update docs/source/en/api/pipelines/cogvideox.md
a-r-r-o-w Oct 15, 2024
6b1f55e
Update docs/source/en/api/pipelines/cogvideox.md
a-r-r-o-w Oct 15, 2024
37d2366
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Oct 30, 2024
a5f51bb
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Oct 31, 2024
18b7d6d
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Nov 5, 2024
c52cf42
Pyramid Attention Broadcast rewrite + introduce hooks (#9826)
a-r-r-o-w Nov 8, 2024
3de2c18
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Nov 9, 2024
d95d61a
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Dec 6, 2024
6090575
merge pyramid-attention-rewrite-2
a-r-r-o-w Dec 9, 2024
af51f5d
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Dec 9, 2024
903514f
make style
a-r-r-o-w Dec 9, 2024
b690db2
remove changes from latte transformer
a-r-r-o-w Dec 9, 2024
63ab886
revert docs changes
a-r-r-o-w Dec 9, 2024
d40bced
better debug message
a-r-r-o-w Dec 9, 2024
0ea904e
add todos for future
a-r-r-o-w Dec 9, 2024
9d452dc
update tests
a-r-r-o-w Dec 9, 2024
cfe3921
make style
a-r-r-o-w Dec 9, 2024
b972c4b
cleanup
a-r-r-o-w Dec 9, 2024
2b558ff
fix
a-r-r-o-w Dec 9, 2024
0b2629d
improve log message; fix latte test
a-r-r-o-w Dec 9, 2024
9182f57
refactor
a-r-r-o-w Dec 9, 2024
d974401
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Dec 30, 2024
62b5b8d
update
a-r-r-o-w Dec 30, 2024
bb250d6
update
a-r-r-o-w Dec 30, 2024
cbc086f
update
a-r-r-o-w Dec 30, 2024
7debcec
revert changes to tests
a-r-r-o-w Dec 30, 2024
a5c34af
update docs
a-r-r-o-w Dec 30, 2024
ad24269
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Dec 30, 2024
bbcde6b
update tests
a-r-r-o-w Dec 30, 2024
b148ab4
Apply suggestions from code review
a-r-r-o-w Dec 31, 2024
d4ecd6c
update
a-r-r-o-w Dec 31, 2024
6cca58f
fix flux test
a-r-r-o-w Dec 31, 2024
c2e0e3b
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Dec 31, 2024
d9fad00
reorder
a-r-r-o-w Jan 2, 2025
35296eb
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Jan 13, 2025
2436b3f
refactor
a-r-r-o-w Jan 13, 2025
95c8148
make fix-copies
a-r-r-o-w Jan 13, 2025
76afc6a
update docs
a-r-r-o-w Jan 13, 2025
fb66167
fixes
a-r-r-o-w Jan 13, 2025
1040c91
more fixes
a-r-r-o-w Jan 13, 2025
ffbabb5
make style
a-r-r-o-w Jan 13, 2025
1b92b1d
update tests
a-r-r-o-w Jan 13, 2025
88d917d
update code example
a-r-r-o-w Jan 13, 2025
e4d8b12
make fix-copies
a-r-r-o-w Jan 13, 2025
cc94647
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Jan 13, 2025
071a0ba
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Jan 15, 2025
ae8bd99
refactor based on reviews
a-r-r-o-w Jan 15, 2025
a9ee5a4
use maybe_free_model_hooks
a-r-r-o-w Jan 15, 2025
1a59688
CacheMixin
a-r-r-o-w Jan 15, 2025
c8616a6
make style
a-r-r-o-w Jan 15, 2025
08a209d
update
a-r-r-o-w Jan 15, 2025
15e645d
add current_timestep property; update docs
a-r-r-o-w Jan 15, 2025
d6ce4ab
make fix-copies
a-r-r-o-w Jan 15, 2025
96fce86
update
a-r-r-o-w Jan 15, 2025
107e375
improve tests
a-r-r-o-w Jan 15, 2025
f7d7e38
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Jan 15, 2025
40fc7a5
try circular import fix
a-r-r-o-w Jan 15, 2025
248f103
apply suggestions from review
a-r-r-o-w Jan 15, 2025
0a290a6
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Jan 16, 2025
fe93975
address review comments
a-r-r-o-w Jan 16, 2025
2b59994
Apply suggestions from code review
a-r-r-o-w Jan 17, 2025
a8e460e
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Jan 22, 2025
8c74a7a
refactor hook implementation
a-r-r-o-w Jan 22, 2025
3f3e26a
add test suite for hooks
a-r-r-o-w Jan 22, 2025
83d221f
PAB Refactor (#10667)
a-r-r-o-w Jan 27, 2025
847760e
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Jan 27, 2025
3d269ad
update
a-r-r-o-w Jan 27, 2025
5535fd6
fix remove hook behaviour
a-r-r-o-w Jan 27, 2025
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Expand Up @@ -598,6 +598,8 @@
title: Attention Processor
- local: api/activations
title: Custom activation functions
- local: api/cache
title: Caching methods
- local: api/normalization
title: Custom normalization layers
- local: api/utilities
Expand Down
49 changes: 49 additions & 0 deletions docs/source/en/api/cache.md
@@ -0,0 +1,49 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# Caching methods

## Pyramid Attention Broadcast

[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.

Pyramid Attention Broadcast (PAB) speeds up inference in diffusion models by skipping attention computations at some inference steps and broadcasting cached attention states from earlier steps instead. Attention states change little between successive inference steps: the difference is most pronounced in the spatial attention blocks, smaller in the temporal attention blocks, and smallest in the cross-attention blocks. Cross-attention computations can therefore be skipped most often, followed by temporal and then spatial attention. By combining PAB with other techniques like sequence parallelism and classifier-free guidance parallelism, near real-time video generation becomes achievable.

Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.

```python
import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
# `spatial_attention_timestep_skip_range[1]` will narrow the interval in which pyramid attention
# broadcast is active, leading to slower inference speeds. However, larger intervals can lead to
# poorer quality of generated videos.
config = PyramidAttentionBroadcastConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(100, 800),
current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)
```
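
With the cache enabled, the pipeline is invoked as usual. A minimal usage sketch follows (the prompt and output path are placeholders):

```python
from diffusers.utils import export_to_video

prompt = "A panda playing a guitar in a bamboo forest"
video = pipe(prompt=prompt, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
```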

### CacheMixin

[[autodoc]] CacheMixin

### PyramidAttentionBroadcastConfig

[[autodoc]] PyramidAttentionBroadcastConfig

[[autodoc]] apply_pyramid_attention_broadcast
11 changes: 11 additions & 0 deletions src/diffusers/__init__.py
Expand Up @@ -28,6 +28,7 @@

_import_structure = {
"configuration_utils": ["ConfigMixin"],
"hooks": [],
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
Expand Down Expand Up @@ -75,6 +76,13 @@
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]

else:
_import_structure["hooks"].extend(
[
"HookRegistry",
"PyramidAttentionBroadcastConfig",
"apply_pyramid_attention_broadcast",
]
)
_import_structure["models"].extend(
[
"AllegroTransformer3DModel",
Expand All @@ -90,6 +98,7 @@
"AutoencoderKLTemporalDecoder",
"AutoencoderOobleck",
"AutoencoderTiny",
"CacheMixin",
"CogVideoXTransformer3DModel",
"CogView3PlusTransformer2DModel",
"ConsisIDTransformer3DModel",
Expand Down Expand Up @@ -588,6 +597,7 @@
except OptionalDependencyNotAvailable:
from .utils.dummy_pt_objects import * # noqa F403
else:
from .hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .models import (
AllegroTransformer3DModel,
AsymmetricAutoencoderKL,
Expand All @@ -602,6 +612,7 @@
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
AutoencoderTiny,
CacheMixin,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
ConsisIDTransformer3DModel,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/hooks/__init__.py
Expand Up @@ -2,4 +2,6 @@


if is_torch_available():
from .hooks import HookRegistry, ModelHook
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
91 changes: 62 additions & 29 deletions src/diffusers/hooks/hooks.py
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import functools
import gc
from typing import Any, Dict, Optional, Tuple

import torch
Expand All @@ -30,6 +31,9 @@ class ModelHook:

_is_stateful = False

def __init__(self):
self.fn_ref: "FunctionReference" = None

def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
r"""
Hook that is executed when a model is initialized.
Expand All @@ -48,8 +52,6 @@ def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
module (`torch.nn.Module`):
The module attached to this hook.
"""
module.forward = module._old_forward
del module._old_forward
return module

def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
Expand Down Expand Up @@ -99,6 +101,14 @@ def reset_state(self, module: torch.nn.Module):
return module


class FunctionReference:
def __init__(self) -> None:
self.pre_forward = None
self.post_forward = None
self.old_forward = None
self.overwritten_forward = None


class HookRegistry:
def __init__(self, module_ref: torch.nn.Module) -> None:
super().__init__()
Expand All @@ -107,51 +117,68 @@ def __init__(self, module_ref: torch.nn.Module) -> None:

self._module_ref = module_ref
self._hook_order = []
self._fn_refs = []

def register_hook(self, hook: ModelHook, name: str) -> None:
if name in self.hooks.keys():
logger.warning(f"Hook with name {name} already exists, replacing it.")

if hasattr(self._module_ref, "_old_forward"):
old_forward = self._module_ref._old_forward
else:
old_forward = self._module_ref.forward
self._module_ref._old_forward = self._module_ref.forward

self._module_ref = hook.initialize_hook(self._module_ref)

if hasattr(hook, "new_forward"):
rewritten_forward = hook.new_forward

def create_new_forward(function_reference: FunctionReference):
def new_forward(module, *args, **kwargs):
args, kwargs = hook.pre_forward(module, *args, **kwargs)
output = rewritten_forward(module, *args, **kwargs)
return hook.post_forward(module, output)
else:
args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
output = function_reference.old_forward(*args, **kwargs)
return function_reference.post_forward(module, output)

def new_forward(module, *args, **kwargs):
args, kwargs = hook.pre_forward(module, *args, **kwargs)
output = old_forward(*args, **kwargs)
return hook.post_forward(module, output)
return new_forward

forward = self._module_ref.forward

fn_ref = FunctionReference()
fn_ref.pre_forward = hook.pre_forward
fn_ref.post_forward = hook.post_forward
fn_ref.old_forward = forward

if hasattr(hook, "new_forward"):
fn_ref.overwritten_forward = forward
fn_ref.old_forward = functools.update_wrapper(
functools.partial(hook.new_forward, self._module_ref), hook.new_forward
)

rewritten_forward = create_new_forward(fn_ref)
Comment by @a-r-r-o-w (Member, Author) on Jan 22, 2025:
@DN6 Some major changes were made to the hooks addition/removal process to support:

  • adding multiple hooks that affect the forward pass
  • removing hooks arbitrarily (out-of-order removal is supported as well)

Please take a look when you can. Happy to answer any questions and iterate further if needed.
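
A hypothetical usage sketch of the resulting API (`CountingHook` is illustrative; `HookRegistry` and `ModelHook` are the classes introduced in this PR):

```python
import torch

from diffusers.hooks import HookRegistry, ModelHook


class CountingHook(ModelHook):
    # Illustrative hook that counts how often the wrapped forward is invoked.
    def __init__(self):
        super().__init__()
        self.num_calls = 0

    def pre_forward(self, module, *args, **kwargs):
        self.num_calls += 1
        return args, kwargs


model = torch.nn.Linear(8, 8)
registry = HookRegistry.check_if_exists_or_initialize(model)

# Multiple hooks can affect the same forward pass...
registry.register_hook(CountingHook(), "count_a")
registry.register_hook(CountingHook(), "count_b")
model(torch.randn(1, 8))

# ...and can be removed in any order, not just last-in-first-out.
registry.remove_hook("count_a")
registry.remove_hook("count_b")
```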

self._module_ref.forward = functools.update_wrapper(
functools.partial(new_forward, self._module_ref), old_forward
functools.partial(rewritten_forward, self._module_ref), rewritten_forward
)

hook.fn_ref = fn_ref
self.hooks[name] = hook
self._hook_order.append(name)
self._fn_refs.append(fn_ref)

def get_hook(self, name: str) -> Optional[ModelHook]:
if name not in self.hooks.keys():
return None
return self.hooks[name]
return self.hooks.get(name, None)

def remove_hook(self, name: str, recurse: bool = True) -> None:
num_hooks = len(self._hook_order)
if name in self.hooks.keys():
hook = self.hooks[name]
index = self._hook_order.index(name)
fn_ref = self._fn_refs[index]

old_forward = fn_ref.old_forward
if fn_ref.overwritten_forward is not None:
old_forward = fn_ref.overwritten_forward

if index == num_hooks - 1:
self._module_ref.forward = old_forward
else:
self._fn_refs[index + 1].old_forward = old_forward
Comment by @a-r-r-o-w (Member, Author):
This may look a bit weird: why are we assigning the next index's function reference to the old forward of the hook being removed?

TL;DR: hook invocation order is the reverse of the order in which hooks were added. Please take a look at the tests related to invocation for a better understanding.

When we add hook A followed by hook B, the execution order of methods looks like:

forward:
  pre_forward of B
    pre_forward of A
      actual_module_original_forward
    post_forward of A
  post_forward of B

In this example, the function references would have A-related ones at index 0 and B-related ones at index 1. Removing hook A requires pointing B->old_forward to actual_module_original_forward. Removing hook B requires pointing module->forward to new_forward(A). We handle both cases here.

Let's take a more complex example to understand better. We add hook A that only has pre/post-forward. We add hook B that has a new_forward implementation. We add hook C that only has pre/post-forward. The invocation order would be:

forward:
  pre_forward of C
    pre_forward of B
      new_forward of B:
        pre_forward of A
          actual_module_original_forward
        post_forward of A
    post_forward of B
  post_forward of C

In this example, the function references would have A-related ones at index 0, B-related ones at index 1, and C-related ones at index 2. Removing hook A requires pointing B->old_forward (since we overwrote the original forward implementation by making use of a new_forward method) to actual_module_original_forward. Removing hook B requires pointing C->old_forward to new_forward(A). Removing hook C requires pointing module->forward to new_forward(B).

On a separate note from the explanation, the invocation design here is very friendly with parallelism too IMO, so we can eventually introduce different parallel methods utilizing the same hook design, without being invasive in the actual modeling implementations themselves.

There are a few tests added to make sure execution order is correct.
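
For intuition, here is a minimal self-contained sketch of the chaining behaviour described above (illustrative only, not the PR's actual implementation):

```python
import torch


def wrap(forward, name, log):
    # Each wrapper closes over the forward it replaced, so the most recently
    # added hook runs first, and removing a hook re-points a single reference.
    def new_forward(*args, **kwargs):
        log.append(f"pre_forward of {name}")
        output = forward(*args, **kwargs)
        log.append(f"post_forward of {name}")
        return output

    return new_forward


log = []
module = torch.nn.Identity()
module.forward = wrap(module.forward, "A", log)  # hook A added first
module.forward = wrap(module.forward, "B", log)  # hook B added second

module(torch.zeros(1))
print(log)  # ['pre_forward of B', 'pre_forward of A', 'post_forward of A', 'post_forward of B']
```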


self._module_ref = hook.deinitalize_hook(self._module_ref)
del self.hooks[name]
self._hook_order.remove(name)
self._hook_order.pop(index)
self._fn_refs.pop(index)

if recurse:
for module_name, module in self._module_ref.named_modules():
Expand All @@ -160,8 +187,10 @@ def remove_hook(self, name: str, recurse: bool = True) -> None:
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook.remove_hook(name, recurse=False)

gc.collect()

def reset_stateful_hooks(self, recurse: bool = True) -> None:
for hook_name in self._hook_order:
for hook_name in reversed(self._hook_order):
hook = self.hooks[hook_name]
if hook._is_stateful:
hook.reset_state(self._module_ref)
Expand All @@ -180,9 +209,13 @@ def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry
return module._diffusers_hook

def __repr__(self) -> str:
hook_repr = ""
registry_repr = ""
for i, hook_name in enumerate(self._hook_order):
hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})"
if self.hooks[hook_name].__class__.__repr__ is not object.__repr__:
hook_repr = self.hooks[hook_name].__repr__()
else:
hook_repr = self.hooks[hook_name].__class__.__name__
registry_repr += f" ({i}) {hook_name} - {hook_repr}"
if i < len(self._hook_order) - 1:
hook_repr += "\n"
return f"HookRegistry(\n{hook_repr}\n)"
registry_repr += "\n"
return f"HookRegistry(\n{registry_repr}\n)"