[core] Pyramid Attention Broadcast #9562
Changes from 76 commits
@@ -0,0 +1,49 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# Caching methods

## Pyramid Attention Broadcast

[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.

Pyramid Attention Broadcast (PAB) speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. Attention states change very little between successive steps: the differences are most pronounced in the spatial attention blocks, smaller in the temporal attention blocks, and smallest in the cross attention blocks. Many cross attention computations can therefore be skipped, followed by the temporal and then the spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.

Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.

```python
import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
# broadcast is active, leading to slower inference speeds. However, larger intervals can lead to
# poorer quality of generated videos.
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)
```
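Since cross attention states change the least between steps, they can typically be skipped most aggressively, followed by the temporal and then the spatial attention blocks. The sketch below extends the configuration above in that direction; the `cross_attention_*` argument names are assumed to mirror the spatial ones, so double-check them against the [`~PyramidAttentionBroadcastConfig`] reference (and note that not every transformer has separate temporal attention layers).

```python
# A rough sketch, assuming the cross attention arguments mirror the spatial ones:
# skip cross attention most often and spatial attention least often, per the PAB ordering.
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    cross_attention_block_skip_range=6,
    spatial_attention_timestep_skip_range=(100, 800),
    cross_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)
```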
### CacheMixin

[[autodoc]] CacheMixin

### PyramidAttentionBroadcastConfig

[[autodoc]] PyramidAttentionBroadcastConfig

[[autodoc]] apply_pyramid_attention_broadcast
```diff
@@ -13,6 +13,7 @@
 # limitations under the License.

+import functools
 import gc
 from typing import Any, Dict, Optional, Tuple

 import torch
```
```diff
@@ -30,6 +31,9 @@ class ModelHook:

     _is_stateful = False

+    def __init__(self):
+        self.fn_ref: "FunctionReference" = None
+
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         r"""
         Hook that is executed when a model is initialized.
```
```diff
@@ -48,8 +52,6 @@ def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
             module (`torch.nn.Module`):
                 The module attached to this hook.
         """
-        module.forward = module._old_forward
-        del module._old_forward
         return module

     def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
```
```diff
@@ -99,6 +101,14 @@ def reset_state(self, module: torch.nn.Module):
         return module


+class FunctionReference:
+    def __init__(self) -> None:
+        self.pre_forward = None
+        self.post_forward = None
+        self.old_forward = None
+        self.overwritten_forward = None
+
+
 class HookRegistry:
     def __init__(self, module_ref: torch.nn.Module) -> None:
         super().__init__()
```
```diff
@@ -107,51 +117,68 @@ def __init__(self, module_ref: torch.nn.Module) -> None:

         self._module_ref = module_ref
         self._hook_order = []
+        self._fn_refs = []

     def register_hook(self, hook: ModelHook, name: str) -> None:
         if name in self.hooks.keys():
             logger.warning(f"Hook with name {name} already exists, replacing it.")

-        if hasattr(self._module_ref, "_old_forward"):
-            old_forward = self._module_ref._old_forward
-        else:
-            old_forward = self._module_ref.forward
-            self._module_ref._old_forward = self._module_ref.forward
-
         self._module_ref = hook.initialize_hook(self._module_ref)

-        if hasattr(hook, "new_forward"):
-            rewritten_forward = hook.new_forward
-
+        def create_new_forward(function_reference: FunctionReference):
             def new_forward(module, *args, **kwargs):
-                args, kwargs = hook.pre_forward(module, *args, **kwargs)
-                output = rewritten_forward(module, *args, **kwargs)
-                return hook.post_forward(module, output)
-        else:
+                args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
+                output = function_reference.old_forward(*args, **kwargs)
+                return function_reference.post_forward(module, output)

-            def new_forward(module, *args, **kwargs):
-                args, kwargs = hook.pre_forward(module, *args, **kwargs)
-                output = old_forward(*args, **kwargs)
-                return hook.post_forward(module, output)
+            return new_forward
+
+        forward = self._module_ref.forward
+
+        fn_ref = FunctionReference()
+        fn_ref.pre_forward = hook.pre_forward
+        fn_ref.post_forward = hook.post_forward
+        fn_ref.old_forward = forward
+
+        if hasattr(hook, "new_forward"):
+            fn_ref.overwritten_forward = forward
+            fn_ref.old_forward = functools.update_wrapper(
+                functools.partial(hook.new_forward, self._module_ref), hook.new_forward
+            )

+        rewritten_forward = create_new_forward(fn_ref)
         self._module_ref.forward = functools.update_wrapper(
-            functools.partial(new_forward, self._module_ref), old_forward
+            functools.partial(rewritten_forward, self._module_ref), rewritten_forward
         )

+        hook.fn_ref = fn_ref
         self.hooks[name] = hook
         self._hook_order.append(name)
+        self._fn_refs.append(fn_ref)

     def get_hook(self, name: str) -> Optional[ModelHook]:
-        if name not in self.hooks.keys():
-            return None
-        return self.hooks[name]
+        return self.hooks.get(name, None)

     def remove_hook(self, name: str, recurse: bool = True) -> None:
+        num_hooks = len(self._hook_order)
         if name in self.hooks.keys():
             hook = self.hooks[name]
+            index = self._hook_order.index(name)
+            fn_ref = self._fn_refs[index]
+
+            old_forward = fn_ref.old_forward
+            if fn_ref.overwritten_forward is not None:
+                old_forward = fn_ref.overwritten_forward
+
+            if index == num_hooks - 1:
+                self._module_ref.forward = old_forward
+            else:
+                self._fn_refs[index + 1].old_forward = old_forward
```
This may look a bit weird - why does removing a hook assign its `old_forward` to the function reference at the next index?

TLDR; hook invocation order is the reverse of the order in which hooks are added: each registered hook wraps the forward installed before it, so its `fn_ref.old_forward` points at the previous hook's wrapper (or at the module's original forward for the first hook, with `overwritten_forward` keeping that reference for hooks that define their own `new_forward`). Removing a hook means repairing that link: if it was the most recently added hook, the module's forward is restored to the removed hook's `old_forward`; otherwise the next hook in the chain inherits it. Please take a look at the tests related to invocation order for a better understanding; a small self-contained sketch of the idea follows below.

On a separate note from the explanation, the invocation design here is very friendly with parallelism too IMO, so we can eventually introduce different parallel methods utilizing the same hook design, without being invasive in the actual modeling implementations themselves. There are a few tests added to make sure execution order is correct.
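For illustration only - a stripped-down, self-contained model of this chaining (toy names, not the diffusers code itself):

```python
class ToyFunctionReference:
    # Mirrors FunctionReference above: old_forward points at whatever forward
    # was installed before this hook was registered.
    def __init__(self):
        self.old_forward = None


class ToyModule:
    def forward(self, x):
        return x + 1


def register(module, fn_refs, name, trace):
    fn_ref = ToyFunctionReference()
    fn_ref.old_forward = module.forward

    def new_forward(*args, **kwargs):
        trace.append(name)                     # stands in for pre_forward
        return fn_ref.old_forward(*args, **kwargs)

    module.forward = new_forward
    fn_refs.append(fn_ref)


def remove(module, fn_refs, index):
    # Same re-linking logic as HookRegistry.remove_hook above.
    old_forward = fn_refs[index].old_forward
    if index == len(fn_refs) - 1:
        module.forward = old_forward
    else:
        fn_refs[index + 1].old_forward = old_forward
    fn_refs.pop(index)


module, fn_refs, trace = ToyModule(), [], []
register(module, fn_refs, "A", trace)
register(module, fn_refs, "B", trace)

module.forward(0)
print(trace)                 # ['B', 'A'] -> invocation is the reverse of addition order

trace.clear()
remove(module, fn_refs, 0)   # remove "A": "B" now calls the original forward directly
module.forward(0)
print(trace)                 # ['B']
```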
```diff
             self._module_ref = hook.deinitalize_hook(self._module_ref)
             del self.hooks[name]
-            self._hook_order.remove(name)
+            self._hook_order.pop(index)
+            self._fn_refs.pop(index)

         if recurse:
             for module_name, module in self._module_ref.named_modules():
```
```diff
@@ -160,8 +187,10 @@ def remove_hook(self, name: str, recurse: bool = True) -> None:
                 if hasattr(module, "_diffusers_hook"):
                     module._diffusers_hook.remove_hook(name, recurse=False)

+        gc.collect()
+
     def reset_stateful_hooks(self, recurse: bool = True) -> None:
-        for hook_name in self._hook_order:
+        for hook_name in reversed(self._hook_order):
            hook = self.hooks[hook_name]
            if hook._is_stateful:
                hook.reset_state(self._module_ref)
```
```diff
@@ -180,9 +209,13 @@ def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry":
         return module._diffusers_hook

     def __repr__(self) -> str:
-        hook_repr = ""
+        registry_repr = ""
         for i, hook_name in enumerate(self._hook_order):
-            hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})"
+            if self.hooks[hook_name].__class__.__repr__ is not object.__repr__:
+                hook_repr = self.hooks[hook_name].__repr__()
+            else:
+                hook_repr = self.hooks[hook_name].__class__.__name__
+            registry_repr += f" ({i}) {hook_name} - {hook_repr}"
             if i < len(self._hook_order) - 1:
-                hook_repr += "\n"
-        return f"HookRegistry(\n{hook_repr}\n)"
+                registry_repr += "\n"
+        return f"HookRegistry(\n{registry_repr}\n)"
```
@DN6 Some major changes were made to the hooks addition/removal process to be able to support:

Please take a look when you can. Happy to answer any questions and iterate further if needed.
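To make the new flow concrete, here is a rough usage sketch of the addition/removal path. The import path is an assumption (these are internal utilities, so `diffusers.hooks.hooks` may not be the actual module layout), and `LoggingHook` is a made-up hook, not part of the library:

```python
import torch

# Assumed import path for the internal utilities changed in this PR; adjust if
# the module layout differs.
from diffusers.hooks.hooks import HookRegistry, ModelHook


class LoggingHook(ModelHook):
    # Hypothetical hook: just records that it ran around the wrapped forward.
    def __init__(self, name):
        super().__init__()
        self.name = name

    def pre_forward(self, module, *args, **kwargs):
        print(f"{self.name}: pre_forward")
        return args, kwargs

    def post_forward(self, module, output):
        print(f"{self.name}: post_forward")
        return output


module = torch.nn.Linear(4, 4)
registry = HookRegistry.check_if_exists_or_initialize(module)
registry.register_hook(LoggingHook("A"), "hook_a")
registry.register_hook(LoggingHook("B"), "hook_b")

module(torch.randn(1, 4))
# Expected order (reverse of registration): B pre, A pre, <forward>, A post, B post

registry.remove_hook("hook_b")
module(torch.randn(1, 4))
# Only hook_a wraps the forward now: A pre, <forward>, A post
```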