Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unit tests for shared prefix masked attention with torch.FlexAttention #504

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

austin362667
Copy link
Collaborator

@austin362667 austin362667 commented Dec 29, 2024

Summary

TLDR of #476: The shared prefix attention mask is an optimization for paired-preference alignment training.

To pave the way for #476, this PR aims to set up basic unit tests of flex attn with casual and shared prefix mask.

Testing Done

Benchmarks

  1. Casual Attention Mask (Flash Attention 2 vs. Torch Scaled Dot Product Attention vs. FlexAttention)

image

  1. Shared Prefix Attention Mask (Flash Attention 2 vs. Torch Scaled Dot Product Attention vs. FlexAttention)

image

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@austin362667 austin362667 force-pushed the austin362667/prefix_sharing branch 4 times, most recently from 45759a3 to 30001e7 Compare December 30, 2024 06:19
@austin362667
Copy link
Collaborator Author

austin362667 commented Dec 30, 2024

The FlexAttention correctness check failed only on AMD GPU might be related to this ISSUE and PR.

Signed-off-by: Austin Liu <[email protected]>

nit

Signed-off-by: Austin Liu <[email protected]>
Signed-off-by: Austin Liu <[email protected]>
@austin362667 austin362667 force-pushed the austin362667/prefix_sharing branch from 30001e7 to bd4c9f5 Compare January 8, 2025 03:13
@austin362667
Copy link
Collaborator Author

austin362667 commented Jan 20, 2025

The FlexAttention correctness check failed only on AMD GPU might be related to this pytorch/pytorch#138300 and pytorch/pytorch#140172.

Maybe this is relevant: #506. It looks like PyTorch ROCm 6.2 is not as solid as PyTorch ROCm 6.3

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants