[core] Pyramid Attention Broadcast #9562
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Co-Authored-By: Xuanlei Zhao <[email protected]>
I can't seem to replicate the results for PAB on CogVideoX-5b T2V or I2V. This is what I get:
@oahzxl Would you be able to give this a review when you're free? I'm unable to figure out what I'm doing wrong that's causing poor results in these cases. Thank you!
Sure, thanks for your code! I guess it may be related to the pos embed or the encoder concat of the 5b model. I can have a look at the code soon!
Hi, I have done some experiments and here are my conclusions. I first tried a simple implementation; the original attention is:
For simplicity, I just added PAB's logic here:
This should be exactly the same as the logic in the PAB processor. I then found that PAB is numerically unstable with fp16 for CogVideoX-5b, so I changed to bfloat16 and it works! output-bf16-new.mp4 ->> So the first problem is float16!
But it still fails even if I use bfloat16: I find that even when I set spatial_attn_skip_range to 1 (which means no broadcast), it generates random noise. ->> So I think the second problem is in the processor, but I have no clue for now. Hope this helps!
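For illustration, here is a self-contained sketch of the kind of inline broadcast/caching logic being described (all names are hypothetical and this is not the snippet from the experiment; the stand-in attention function just doubles its input):

```python
import torch

def attention_with_pab(attn_fn, hidden_states, step, skip_range, cache):
    # Recompute attention every `skip_range` steps (and whenever the cache is
    # empty); otherwise broadcast (reuse) the previously cached output.
    if cache.get("output") is None or step % skip_range == 0:
        cache["output"] = attn_fn(hidden_states)
    return cache["output"]

cache = {}
fake_attn = lambda x: x * 2  # stand-in for the original attention computation
for step in range(5):
    out = attention_with_pab(fake_attn, torch.ones(1, 4, 8), step, skip_range=2, cache=cache)
```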
Thank you so much for the investigation! I think I found the bug. This line checks whether the processor signature supports specific keyword arguments before passing them. In this case, since we replace the attention processor with …
Glad I can help :)!
Thanks for the PR! The numbers are extremely promising!
I want to brainstorm a bit about how we should incorporate PAB design-wise.
IIUC, PAB can be applied at a model-level and it rejigs the attention computation of the concerned model. IMO, this is a bit similar to how we do QKV fusion. Entry point to QKV fusion can either be from a pipeline or from a model (if there's support).
If this is correct, then I wonder if supporting PAB through a Mixin class makes the most elegant design, as opposed to enabling it via set_attn_processor().
Or are we relying on a Mixin because we need to depend on pipeline-level attributes which may not be suitable for a model?
Very nice, thanks for adding docs for this method! Same comments apply to latte.md :)
Thanks! I think it will be a nice feature, but I'm not very sure about the design.
- The Attention processor wrapper is not aligned with our design; we should just make custom attention processors (even though we might have to make one for each model that has a non-default attention processor).
- For another thing, I think this also won't be compatible with torch.compile, no? I think we should consider a design similar to https://github.com/huggingface/diffusers/pull/9524/files. We can probably store the attention output cache (a dict) on the pipeline and pass it as cross_attention_kwargs on each iteration (just putting the ideas here, not something I have carefully thought through, so it might not work; feel free to brainstorm — see the sketch below).
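As a toy illustration of the cache-on-pipeline idea above (not actual diffusers API; every name here is made up for the sketch): a plain dict owned by the "pipeline" acts as the cache and is threaded through the attention call as a keyword argument on each iteration.

```python
import torch

def toy_attn_processor(hidden_states, *, cache, block_id, step, skip_range=2):
    # Reuse the cached output on skipped steps; recompute and refresh the cache otherwise.
    if step > 0 and step % skip_range != 0 and block_id in cache:
        return cache[block_id]
    cache[block_id] = hidden_states.softmax(dim=-1)  # stand-in for real attention
    return cache[block_id]

pipeline_cache = {}  # in the proposed design, this dict would live on the pipeline
x = torch.randn(1, 4, 8)
for step in range(4):
    out = toy_attn_processor(x, cache=pipeline_cache, block_id=0, step=step)
```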
Co-authored-by: Steven Liu <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
```python
return False

should_compute_attention = pab_state.iteration > 0 and pab_state.iteration % block_skip_range == 0
return not should_compute_attention
```
Might be misunderstanding here, but shouldn't we just return should_compute_attention directly here? Why use return not should_compute_attention?
This function is used to determine whether attention computation should be skipped or not. So, if skip_callback were to return True, it means that should_compute_attention had to have been False, and vice versa, so this is correct.
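For a concrete feel of that inversion, here is a minimal trace of the predicate above with hypothetical values (block_skip_range=2): the callback returns True (skip) exactly when should_compute_attention is False.

```python
block_skip_range = 2  # hypothetical value for the trace

for iteration in range(6):
    should_compute_attention = iteration > 0 and iteration % block_skip_range == 0
    skip = not should_compute_attention  # what skip_callback would return
    print(iteration, "compute attention" if should_compute_attention else "reuse cached output")
```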
@DN6 Addressed the review comments. Could you give this another look?
We need to make some more updates here before merging to address the case of using multiple hooks at once. The current implementation does not really work if, say, both FP8 and PAB are enabled together. I will take that up in this PR after layerwise upcasting is merged: #10347. This has already been addressed in the group offloading PR, but that will take some more time to complete: #10503
With the latest changes, it is now possible to use multiple forward-modifying hooks. Here's an example with FP8 layerwise-upcasting and PAB:

```python
import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug

set_verbosity_debug()

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(150, 700),
    current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)
pipe.transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = (
    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
    "atmosphere of this unique musical performance."
)

video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
```
src/diffusers/hooks/hooks.py
```python
forward = self._module_ref.forward

fn_ref = FunctionReference()
fn_ref.pre_forward = hook.pre_forward
fn_ref.post_forward = hook.post_forward
fn_ref.old_forward = forward

if hasattr(hook, "new_forward"):
    fn_ref.overwritten_forward = forward
    fn_ref.old_forward = functools.update_wrapper(
        functools.partial(hook.new_forward, self._module_ref), hook.new_forward
    )

rewritten_forward = create_new_forward(fn_ref)
```
@DN6 Some major changes were made to the hook addition/removal process to support:
- adding multiple hooks that affect the forward pass
- removing hooks arbitrarily (out-of-order removal is supported as well)
Please take a look when you can. Happy to answer any questions and iterate further if needed.
src/diffusers/hooks/hooks.py
```python
index = self._hook_order.index(name)
fn_ref = self._fn_refs[index]

old_forward = fn_ref.old_forward
if fn_ref.overwritten_forward is not None:
    old_forward = fn_ref.overwritten_forward

if index == num_hooks - 1:
    self._module_ref.forward = old_forward
else:
    self._fn_refs[index + 1].old_forward = old_forward
```
This may look a bit weird: why are we assigning the removed hook's old forward to the function reference at the next index?
TL;DR: Hook invocation order is the reverse of the order in which hooks are added. Please take a look at the invocation-related tests for a better understanding.
When we add hook A followed by hook B, the execution order of methods looks like:

forward:
    pre_forward of B
    pre_forward of A
    actual_module_original_forward
    post_forward of A
    post_forward of B
In this example, the function references would have A-related ones at index 0 and B-related ones at index 1. Removing hook A requires pointing B->old_forward to actual_module_forward. Removing hook B requires pointing module->forward to new_forward(A). We handle both cases here.
Let's take a more complex example to understand better. We add hook A that only has pre/post-forward. We add hook B that has a new_forward implementation. We add hook C that only has pre/post-forward. The invocation order would be:
forward:
    pre_forward of C
    pre_forward of B
    new_forward of B, which calls into:
        pre_forward of A
        actual_module_original_forward
        post_forward of A
    post_forward of B
    post_forward of C
In this example, the function references would have A-related ones at index 0, B-related ones at index 1, and C-related ones at index 2. Removing hook A requires pointing B->old_forward (since we overwrote the original forward implementation by making use of a new_forward method) to actual_module_original_forward. Removing hook B requires pointing C->old_forward to new_forward(A). Removing hook C requires pointing module->forward to new_forward(B).
On a separate note from the explanation, the invocation design here is very friendly to parallelism too IMO, so we can eventually introduce different parallel methods utilizing the same hook design, without being invasive in the actual modeling implementations themselves.
There are a few tests added to make sure execution order is correct.
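To make the reverse-of-addition ordering concrete, here is a standalone sketch (not the actual HookRegistry implementation; Hook and apply_hook are made-up names) that wraps a module's forward twice and records the call order:

```python
import torch

calls = []

class Hook:
    def __init__(self, name):
        self.name = name

    def pre_forward(self, *args, **kwargs):
        calls.append(f"pre_forward of {self.name}")
        return args, kwargs

    def post_forward(self, output):
        calls.append(f"post_forward of {self.name}")
        return output

def apply_hook(module, hook):
    # Each application wraps the current forward, so the hook added last
    # runs its pre_forward first and its post_forward last.
    old_forward = module.forward

    def new_forward(*args, **kwargs):
        args, kwargs = hook.pre_forward(*args, **kwargs)
        output = old_forward(*args, **kwargs)
        return hook.post_forward(output)

    module.forward = new_forward

module = torch.nn.Linear(4, 4)
apply_hook(module, Hook("A"))
apply_hook(module, Hook("B"))
module(torch.randn(1, 4))
print(calls)
# ['pre_forward of B', 'pre_forward of A', 'post_forward of A', 'post_forward of B']
```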
```python
if should_compute_attention:
    output = self.fn_ref.overwritten_forward(*args, **kwargs)
else:
    output = self.state.cache
```
If a hook implements a new_forward method, it can choose to make a call to the forward method it overwrote. The overwritten function is always stored in the overwritten_forward attribute of FunctionReference objects.
```python
return output


class HookTests(unittest.TestCase):
```
Some basic fast tests to check simple functionality of the hooks on a dummy model. This is necessary to make sure all methods are being invoked correctly, and that the hooks behave in a predictable manner when added or arbitrarily removed out-of-order.
We could also support out-of-order hook addition, but currently there is no such use case, so support has not been added.
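A hypothetical shape of such a fast test (not the actual diffusers test file; the "hook" here is a plain forward wrapper), checking that adding a hook changes the forward pass and that removing it restores the original behavior on a dummy model:

```python
import unittest
import torch

class ToyHookAddRemoveTest(unittest.TestCase):
    def test_add_and_remove_hook(self):
        model = torch.nn.Linear(3, 3)
        original_forward = model.forward

        def hooked_forward(*args, **kwargs):
            # Toy "hook": adds 1.0 to the original output.
            return original_forward(*args, **kwargs) + 1.0

        x = torch.randn(1, 3)
        base = model(x)

        model.forward = hooked_forward  # add the hook
        self.assertTrue(torch.allclose(model(x), base + 1.0))

        del model.forward  # remove the hook, restoring the original forward
        self.assertTrue(torch.allclose(model(x), base))

if __name__ == "__main__":
    unittest.main()
```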
* update * update * update --------- Co-authored-by: DN6 <[email protected]>
I think we're good to merge now and also got the approval from Dhruv after working together on the latest changes! Thanks for the patience and the reviews, everyone 🤗 Will merge once CI is green and then wrap up the open cache PRs.
@oahzxl Congratulations on the success of your new work, Data centric parallel! I also really liked reading about the pyramid activation checkpointing that was introduced in VideoSys. Thanks for your patience and help, and also for your work that inspired multiple other papers researching caching mechanisms specific to video models. We will be sure to integrate as much as possible to make the methods more easily accessible :)
What does this PR do?
Adds support for Pyramid Attention Broadcast.
Usage
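A minimal usage sketch, following the CogVideoX example shown earlier in this thread (model choice and parameter values are illustrative):

```python
import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Reuse (broadcast) spatial attention outputs every other step within the given
# timestep range; values mirror the example earlier in this thread.
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(150, 700),
    current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)

video = pipe(prompt="A panda playing a guitar in a bamboo forest", num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
```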
Benchmark code
mochi---dtype-bf16---cache_method-none---compile-False.mp4
mochi---dtype-bf16---cache_method-pyramid_attention_broadcast---compile-False.mp4
hunyuan_video---dtype-bf16---cache_method-none---compile-False.mp4
hunyuan_video---dtype-bf16---cache_method-pyramid_attention_broadcast---compile-False.mp4
cogvideox_2b.mp4
cogvideox_pab_2b.mp4
cogvideox_5b.mp4
cogvideox_pab_5b.mp4
cogvideox_5b_i2v.mp4
cogvideox_pab_5b_i2v.mp4
latte---dtype-fp16---cache_method-none---compile-False.mp4
latte---dtype-fp16---cache_method-pyramid_attention_broadcast---compile-False.mp4
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@yiyixuxu @sayakpaul
@oahzxl for PAB, @zRzRzRzRzRzRzR for CogVideoX-related changes, @maxin-cn for Latte-related changes