[core] Pyramid Attention Broadcast #9562

Merged: 80 commits, Jan 27, 2025
Changes shown are from 72 of the 80 commits

Commits
67c729d
start pyramid attention broadcast
a-r-r-o-w Oct 1, 2024
6d3bdb5
add coauthor
a-r-r-o-w Oct 3, 2024
3737101
update
a-r-r-o-w Oct 3, 2024
d5c738d
make style
a-r-r-o-w Oct 3, 2024
1c97e04
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Oct 3, 2024
ae4abb1
update
a-r-r-o-w Oct 3, 2024
955e4f7
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Oct 3, 2024
9f6987f
make style
a-r-r-o-w Oct 3, 2024
b3547c6
add docs
a-r-r-o-w Oct 4, 2024
afd0c17
add tests
a-r-r-o-w Oct 4, 2024
6265b65
update
a-r-r-o-w Oct 5, 2024
6816fe1
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Oct 15, 2024
9cb4e87
Update docs/source/en/api/pipelines/cogvideox.md
a-r-r-o-w Oct 15, 2024
6b1f55e
Update docs/source/en/api/pipelines/cogvideox.md
a-r-r-o-w Oct 15, 2024
37d2366
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Oct 30, 2024
a5f51bb
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Oct 31, 2024
18b7d6d
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Nov 5, 2024
c52cf42
Pyramid Attention Broadcast rewrite + introduce hooks (#9826)
a-r-r-o-w Nov 8, 2024
3de2c18
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Nov 9, 2024
d95d61a
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Dec 6, 2024
6090575
merge pyramid-attention-rewrite-2
a-r-r-o-w Dec 9, 2024
af51f5d
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Dec 9, 2024
903514f
make style
a-r-r-o-w Dec 9, 2024
b690db2
remove changes from latte transformer
a-r-r-o-w Dec 9, 2024
63ab886
revert docs changes
a-r-r-o-w Dec 9, 2024
d40bced
better debug message
a-r-r-o-w Dec 9, 2024
0ea904e
add todos for future
a-r-r-o-w Dec 9, 2024
9d452dc
update tests
a-r-r-o-w Dec 9, 2024
cfe3921
make style
a-r-r-o-w Dec 9, 2024
b972c4b
cleanup
a-r-r-o-w Dec 9, 2024
2b558ff
fix
a-r-r-o-w Dec 9, 2024
0b2629d
improve log message; fix latte test
a-r-r-o-w Dec 9, 2024
9182f57
refactor
a-r-r-o-w Dec 9, 2024
d974401
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Dec 30, 2024
62b5b8d
update
a-r-r-o-w Dec 30, 2024
bb250d6
update
a-r-r-o-w Dec 30, 2024
cbc086f
update
a-r-r-o-w Dec 30, 2024
7debcec
revert changes to tests
a-r-r-o-w Dec 30, 2024
a5c34af
update docs
a-r-r-o-w Dec 30, 2024
ad24269
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Dec 30, 2024
bbcde6b
update tests
a-r-r-o-w Dec 30, 2024
b148ab4
Apply suggestions from code review
a-r-r-o-w Dec 31, 2024
d4ecd6c
update
a-r-r-o-w Dec 31, 2024
6cca58f
fix flux test
a-r-r-o-w Dec 31, 2024
c2e0e3b
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Dec 31, 2024
d9fad00
reorder
a-r-r-o-w Jan 2, 2025
35296eb
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Jan 13, 2025
2436b3f
refactor
a-r-r-o-w Jan 13, 2025
95c8148
make fix-copies
a-r-r-o-w Jan 13, 2025
76afc6a
update docs
a-r-r-o-w Jan 13, 2025
fb66167
fixes
a-r-r-o-w Jan 13, 2025
1040c91
more fixes
a-r-r-o-w Jan 13, 2025
ffbabb5
make style
a-r-r-o-w Jan 13, 2025
1b92b1d
update tests
a-r-r-o-w Jan 13, 2025
88d917d
update code example
a-r-r-o-w Jan 13, 2025
e4d8b12
make fix-copies
a-r-r-o-w Jan 13, 2025
cc94647
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Jan 13, 2025
071a0ba
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Jan 15, 2025
ae8bd99
refactor based on reviews
a-r-r-o-w Jan 15, 2025
a9ee5a4
use maybe_free_model_hooks
a-r-r-o-w Jan 15, 2025
1a59688
CacheMixin
a-r-r-o-w Jan 15, 2025
c8616a6
make style
a-r-r-o-w Jan 15, 2025
08a209d
update
a-r-r-o-w Jan 15, 2025
15e645d
add current_timestep property; update docs
a-r-r-o-w Jan 15, 2025
d6ce4ab
make fix-copies
a-r-r-o-w Jan 15, 2025
96fce86
update
a-r-r-o-w Jan 15, 2025
107e375
improve tests
a-r-r-o-w Jan 15, 2025
f7d7e38
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Jan 15, 2025
40fc7a5
try circular import fix
a-r-r-o-w Jan 15, 2025
248f103
apply suggestions from review
a-r-r-o-w Jan 15, 2025
0a290a6
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Jan 16, 2025
fe93975
address review comments
a-r-r-o-w Jan 16, 2025
2b59994
Apply suggestions from code review
a-r-r-o-w Jan 17, 2025
a8e460e
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Jan 22, 2025
8c74a7a
refactor hook implementation
a-r-r-o-w Jan 22, 2025
3f3e26a
add test suite for hooks
a-r-r-o-w Jan 22, 2025
83d221f
PAB Refactor (#10667)
a-r-r-o-w Jan 27, 2025
847760e
Merge branch 'main' into pyramid-attention-broadcast
a-r-r-o-w Jan 27, 2025
3d269ad
update
a-r-r-o-w Jan 27, 2025
5535fd6
fix remove hook behaviour
a-r-r-o-w Jan 27, 2025
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -592,6 +592,8 @@
title: Attention Processor
- local: api/activations
title: Custom activation functions
- local: api/cache
title: Caching methods
- local: api/normalization
title: Custom normalization layers
- local: api/utilities
49 changes: 49 additions & 0 deletions docs/source/en/api/cache.md
@@ -0,0 +1,49 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# Caching methods

## Pyramid Attention Broadcast

[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) by Xuanlei Zhao, Xiaolong Jin, Kai Wang, and Yang You.

Pyramid Attention Broadcast (PAB) speeds up inference in diffusion models by skipping attention computations on some inference steps and reusing (broadcasting) the attention states cached from earlier steps. This works because attention states change little between successive steps: the difference is largest in the spatial attention blocks, smaller in the temporal attention blocks, and smallest in the cross-attention blocks. Accordingly, cross-attention computations can be skipped most often, followed by temporal and then spatial attention. Combined with other techniques such as sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.

Enable PAB by passing a [`~PyramidAttentionBroadcastConfig`] to the model's `enable_cache` method, as shown in the example below. For benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.

```python
import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Increasing `spatial_attention_timestep_skip_range[0]` or decreasing `spatial_attention_timestep_skip_range[1]`
# narrows the timestep interval in which Pyramid Attention Broadcast is active, leading to slower
# inference. Conversely, intervals that are too wide can degrade the quality of the generated videos.
config = PyramidAttentionBroadcastConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(100, 800),
current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)
```
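
Once the cache is enabled, inference runs as usual and the broadcast hooks reuse cached attention states inside the configured timestep range. A minimal continuation of the example above (the prompt, step count, and output path are illustrative):

```python
from diffusers.utils import export_to_video

prompt = "A panda playing a guitar in a sunlit bamboo forest"

# On skipped steps inside the (100, 800) timestep range, the spatial attention
# outputs cached by PAB are broadcast instead of being recomputed.
video = pipe(prompt=prompt, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
```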

### CacheMixin

[[autodoc]] CacheMixin

### PyramidAttentionBroadcastConfig

[[autodoc]] PyramidAttentionBroadcastConfig

[[autodoc]] apply_pyramid_attention_broadcast
11 changes: 11 additions & 0 deletions src/diffusers/__init__.py
@@ -28,6 +28,7 @@

_import_structure = {
"configuration_utils": ["ConfigMixin"],
"hooks": [],
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
@@ -75,6 +76,13 @@
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]

else:
_import_structure["hooks"].extend(
[
"HookRegistry",
"PyramidAttentionBroadcastConfig",
"apply_pyramid_attention_broadcast",
]
)
_import_structure["models"].extend(
[
"AllegroTransformer3DModel",
@@ -90,6 +98,7 @@
"AutoencoderKLTemporalDecoder",
"AutoencoderOobleck",
"AutoencoderTiny",
"CacheMixin",
"CogVideoXTransformer3DModel",
"CogView3PlusTransformer2DModel",
"ConsistencyDecoderVAE",
@@ -586,6 +595,7 @@
except OptionalDependencyNotAvailable:
from .utils.dummy_pt_objects import * # noqa F403
else:
from .hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .models import (
AllegroTransformer3DModel,
AsymmetricAutoencoderKL,
@@ -600,6 +610,7 @@
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
AutoencoderTiny,
CacheMixin,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
ConsistencyDecoderVAE,
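For reference, a quick sketch of the imports this change exposes at the package root (assuming the final merged layout matches this diff):

```python
# New public objects added by this PR, importable from the package root.
from diffusers import (
    CacheMixin,
    HookRegistry,
    PyramidAttentionBroadcastConfig,
    apply_pyramid_attention_broadcast,
)
```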
6 changes: 6 additions & 0 deletions src/diffusers/hooks/__init__.py
@@ -0,0 +1,6 @@
from ..utils import is_torch_available


if is_torch_available():
from .hooks import HookRegistry
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
188 changes: 188 additions & 0 deletions src/diffusers/hooks/hooks.py
@@ -0,0 +1,188 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from typing import Any, Dict, Optional, Tuple

import torch

from ..utils.logging import get_logger


logger = get_logger(__name__) # pylint: disable=invalid-name


class ModelHook:
r"""
A hook that contains callbacks to be executed just before and after the forward method of a model.
"""

_is_stateful = False

def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
r"""
Hook that is executed when a model is initialized.

Args:
module (`torch.nn.Module`):
The module attached to this hook.
"""
return module

def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
r"""
Hook that is executed when a model is deinitialized.

Args:
module (`torch.nn.Module`):
The module attached to this hook.
"""
module.forward = module._old_forward
del module._old_forward
return module

def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
r"""
Hook that is executed just before the forward method of the model.

Args:
module (`torch.nn.Module`):
The module whose forward pass will be executed just after this event.
args (`Tuple[Any]`):
The positional arguments passed to the module.
kwargs (`Dict[str, Any]`):
The keyword arguments passed to the module.
Returns:
`Tuple[Tuple[Any], Dict[str, Any]]`:
A tuple with the treated `args` and `kwargs`.
"""
return args, kwargs

def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
r"""
Hook that is executed just after the forward method of the model.

Args:
module (`torch.nn.Module`):
The module whose forward pass was executed just before this event.
output (`Any`):
The output of the module.
Returns:
`Any`: The processed `output`.
"""
return output

def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
r"""
Hook that is executed when the hook is detached from a module.

Args:
module (`torch.nn.Module`):
The module detached from this hook.
"""
return module

def reset_state(self, module: torch.nn.Module):
if self._is_stateful:
raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
return module


class HookRegistry:
def __init__(self, module_ref: torch.nn.Module) -> None:
super().__init__()

self.hooks: Dict[str, ModelHook] = {}

self._module_ref = module_ref
self._hook_order = []

def register_hook(self, hook: ModelHook, name: str) -> None:
if name in self.hooks.keys():
logger.warning(f"Hook with name {name} already exists, replacing it.")

if hasattr(self._module_ref, "_old_forward"):
old_forward = self._module_ref._old_forward
else:
old_forward = self._module_ref.forward
self._module_ref._old_forward = self._module_ref.forward

self._module_ref = hook.initialize_hook(self._module_ref)

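# If the hook defines its own `new_forward`, wrap it with the pre/post callbacks; otherwise wrap the module's original forward.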
if hasattr(hook, "new_forward"):
rewritten_forward = hook.new_forward

def new_forward(module, *args, **kwargs):
args, kwargs = hook.pre_forward(module, *args, **kwargs)
output = rewritten_forward(module, *args, **kwargs)
return hook.post_forward(module, output)
else:

def new_forward(module, *args, **kwargs):
args, kwargs = hook.pre_forward(module, *args, **kwargs)
output = old_forward(*args, **kwargs)
return hook.post_forward(module, output)

self._module_ref.forward = functools.update_wrapper(
functools.partial(new_forward, self._module_ref), old_forward
)

self.hooks[name] = hook
self._hook_order.append(name)

def get_hook(self, name: str) -> Optional[ModelHook]:
if name not in self.hooks.keys():
return None
return self.hooks[name]

def remove_hook(self, name: str, recurse: bool = True) -> None:
if name in self.hooks.keys():
hook = self.hooks[name]
self._module_ref = hook.deinitalize_hook(self._module_ref)
del self.hooks[name]
self._hook_order.remove(name)

if recurse:
for module_name, module in self._module_ref.named_modules():
if module_name == "":
continue
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook.remove_hook(name, recurse=False)

def reset_stateful_hooks(self, recurse: bool = True) -> None:
for hook_name in self._hook_order:
hook = self.hooks[hook_name]
if hook._is_stateful:
hook.reset_state(self._module_ref)

if recurse:
for module_name, module in self._module_ref.named_modules():
if module_name == "":
continue
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook.reset_stateful_hooks(recurse=False)

@classmethod
def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry":
if not hasattr(module, "_diffusers_hook"):
module._diffusers_hook = cls(module)
return module._diffusers_hook

def __repr__(self) -> str:
hook_repr = ""
for i, hook_name in enumerate(self._hook_order):
hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})"
if i < len(self._hook_order) - 1:
hook_repr += "\n"
return f"HookRegistry(\n{hook_repr}\n)"
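
To illustrate how the hook machinery above fits together, here is a minimal sketch of a custom stateful hook attached through `HookRegistry`. It assumes the module path from this diff (`diffusers.hooks.hooks`); `ModelHook` is not re-exported in `src/diffusers/hooks/__init__.py`, so it is imported from the private module, and the `CountingHook` name is purely illustrative:

```python
import torch

from diffusers.hooks.hooks import HookRegistry, ModelHook


class CountingHook(ModelHook):
    """Toy hook that counts how many times the wrapped module's forward runs."""

    _is_stateful = True

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        self.call_count = 0
        return module

    def pre_forward(self, module, *args, **kwargs):
        self.call_count += 1
        return args, kwargs

    def reset_state(self, module: torch.nn.Module):
        self.call_count = 0
        return module


linear = torch.nn.Linear(4, 4)

# Attach (or fetch) the registry on the module and register the hook under a name.
registry = HookRegistry.check_if_exists_or_initialize(linear)
registry.register_hook(CountingHook(), name="counting")

linear(torch.randn(1, 4))         # pre_forward/post_forward run around the original forward
print(registry)                   # lists registered hooks in order
registry.reset_stateful_hooks()   # calls reset_state on stateful hooks
registry.remove_hook("counting")  # restores the original forward
```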