Add unit tests for shared prefix masked attention with torch.FlexAttention #504

Open · wants to merge 6 commits into base: main
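For context, a minimal sketch of the masking pattern these tests exercise: build a shared-prefix block mask with create_block_mask and pass it to flex_attention. The shapes, device, and segment boundaries (chosen_start, rejected_start) below are illustrative placeholders, not part of this PR.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, S, D = 2, 8, 1024, 32
# Hypothetical per-batch segment starts:
# prefix = [0, chosen_start), chosen = [chosen_start, rejected_start), rejected = [rejected_start, S)
chosen_start = torch.full((B,), 256, device="cuda")
rejected_start = torch.full((B,), 640, device="cuda")

def shared_prefix_mask(b, h, q_idx, kv_idx):
    # Causal attention, except rejected-segment queries are blocked from chosen-segment keys
    blocked = (q_idx >= rejected_start[b]) & (chosen_start[b] <= kv_idx) & (kv_idx < rejected_start[b])
    return (~blocked) & (q_idx >= kv_idx)

q, k, v = (torch.randn(B, H, S, D, device="cuda") for _ in range(3))
block_mask = create_block_mask(shared_prefix_mask, B, H, S, S, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)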
283 changes: 283 additions & 0 deletions test/transformers/test_flex_attention.py
@@ -0,0 +1,283 @@
import pytest
import torch
import torch.nn.functional as F

from test.utils import assert_verbose_allclose
from test.utils import set_seed
from test.utils import supports_bfloat16
from torch.nn.attention.flex_attention import create_block_mask
from torch.nn.attention.flex_attention import create_mask
from torch.nn.attention.flex_attention import flex_attention

from liger_kernel.utils import infer_device


def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx


def prefix_mask(b, h, q_idx, kv_idx, rejected_index, chosen_index):
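"""
Causal mask with shared-prefix blocking.

chosen_index[b] and rejected_index[b] mark the starts of the chosen and rejected segments,
i.e. prefix = [0, chosen_index[b]), chosen = [chosen_index[b], rejected_index[b]),
rejected = [rejected_index[b], S). Attention is causal, except that queries in the
rejected segment may not attend to keys in the chosen segment.
"""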
return (~((q_idx >= rejected_index[b]) & (chosen_index[b] <= kv_idx) & (kv_idx < rejected_index[b]))) & (
q_idx >= kv_idx
)


device = infer_device()
set_seed(42)


def _test_correctness_flex(B, H, S, D, mask_func, dtype, atol, rtol, device=device):
"""
Check that flex_attention with a block mask matches F.scaled_dot_product_attention with the equivalent dense mask, in both the forward and backward passes.

Parameters:
B (int): Batch size
H (int): Number of attention heads
S (int): Sequence length
D (int): Hidden dimension per head
mask_func: A function that generates custom attention mask
dtype: Data type for computation
atol (float): Absolute tolerance for comparison
rtol (float): Relative tolerance for comparison
device: Device to run on (defaults to the module-level inferred device)
"""
torch.manual_seed(0)

# Initialize input tensors, i.e. the tensors after q, k, and v projections of hidden states (attention head input)
query_torch = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
key_torch = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
value_torch = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)

query_flex = query_torch.clone().detach().requires_grad_(True)
key_flex = key_torch.clone().detach().requires_grad_(True)
value_flex = value_torch.clone().detach().requires_grad_(True)

block_mask = create_block_mask(mask_func, B, H, S, S, device=device) # Sparsity block mask
mask = create_mask(mask_func, B, H, S, S, device=device) # Regular mask

# Note: with a purely causal mask, SDPA (FA2) can instead use the `is_causal` flag, e.g.
# F.scaled_dot_product_attention(query, key, value, is_causal=True)

torch_out = F.scaled_dot_product_attention(query_torch, key_torch, value_torch, attn_mask=mask)

flex_out = flex_attention(query_flex, key_flex, value_flex, block_mask=block_mask)

# Check forward pass
assert_verbose_allclose(flex_out, torch_out, atol=atol, rtol=rtol)

grad_out = torch.randn_like(torch_out)
torch_out.backward(grad_out)
flex_out.backward(grad_out)

# Check gradients
assert_verbose_allclose(query_flex.grad, query_torch.grad, atol=atol, rtol=rtol)
assert_verbose_allclose(key_flex.grad, key_torch.grad, atol=atol, rtol=rtol)
assert_verbose_allclose(value_flex.grad, value_torch.grad, atol=atol, rtol=rtol)


@pytest.mark.parametrize(
"B, H, S, D",
[
(2, 8, 1024, 32),
(3, 12, 2048, 64),
],
)
@pytest.mark.parametrize(
"dtype, atol, rtol",
[
pytest.param(
torch.bfloat16,
3e-2,
5e-1,
marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
),
(torch.float16, 1e-2, 5e-3),
(torch.float32, 1e-3, 5e-4),
],
)
def test_correctness_flex(B, H, S, D, dtype, atol, rtol):
_test_correctness_flex(B, H, S, D, causal_mask, dtype, atol, rtol)

# Generate per-batch segment-start indices for the prefix mask
# (chosen segment starts at chosen_index, rejected segment starts at rejected_index)
chosen_index = torch.randint(0, S // 2, (B,), device=device)
rejected_index = torch.randint(S // 2, S, (B,), device=device)

def wrapped_prefix_mask(b, h, q_idx, kv_idx):
return prefix_mask(b, h, q_idx, kv_idx, rejected_index, chosen_index)

_test_correctness_flex(B, H, S, D, wrapped_prefix_mask, dtype, atol, rtol)


def _test_correctness_prefix(
B=2,
H=8,
P=512,
C=256,
R=256,
D=32,
dtype=torch.float32,
atol=1e-3,
rtol=5e-4,
device="cuda",
):
"""
Test that shared-prefix attention matches two separate causally masked attention computations (prefix+chosen and prefix+rejected).
The mental model is:

A. prefix + chosen
P
P P
P P P
P P P C
P P P C C
P P P C C C

B. prefix + rejected
P
P P
P P P
P P P R
P P P R R
P P P R R R

C. shared prefix + chosen + rejected
P
P P
P P P
P P P C
P P P C C
P P P C C C
P P P R
P P P R R
P P P R R R


We verify the following equivalences of the attention outputs:
1. prefix rows of shared attn (upper part of C.) == prefix rows of prefix+chosen attn (upper part of A.)
2. prefix rows of shared attn (upper part of C.) == prefix rows of prefix+rejected attn (upper part of B.)
P P
P P = P P
P P P P P P

3. chosen rows of shared attn (middle part of C.) == chosen rows of prefix+chosen attn (lower part of A.)
C C
C C = C C
C C C C C C

4. rejected rows of shared attn (lower part of C.) == rejected rows of prefix+rejected attn (lower part of B.)
R R
R R = R R
R R R R R R

Args:
B: batch size
H: number of heads
P: prefix length
C: chosen response length
R: rejected response length
D: hidden dimension per head
dtype: data type for computation
atol: absolute tolerance for comparison
rtol: relative tolerance for comparison
device: device to run on (defaults to the module-level inferred device)
"""
torch.manual_seed(0)

# Total sequence length for shared version
S = P + C + R

# Initialize input tensors, i.e. the tensors after q, k, and v projections of hidden states (attention head input)
query = torch.randn(B, H, S, D, device=device, dtype=dtype)
key = torch.randn(B, H, S, D, device=device, dtype=dtype)
value = torch.randn(B, H, S, D, device=device, dtype=dtype)

# Split tensors for separate computation
query_prefix = query[:, :, :P, :]
key_prefix = key[:, :, :P, :]
value_prefix = value[:, :, :P, :]

query_chosen = query[:, :, P : P + C, :]
key_chosen = key[:, :, P : P + C, :]
value_chosen = value[:, :, P : P + C, :]

query_rejected = query[:, :, P + C :, :]
key_rejected = key[:, :, P + C :, :]
value_rejected = value[:, :, P + C :, :]

# Segment starts for the prefix mask: chosen starts at P, rejected starts at P + C
chosen_index = torch.full((B,), P, device=device)
rejected_index = torch.full((B,), P + C, device=device)

def wrapped_prefix_mask(b, h, q_idx, kv_idx):
return prefix_mask(b, h, q_idx, kv_idx, rejected_index, chosen_index)

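# Compute attention over the full [prefix | chosen | rejected] sequence with the shared-prefix block mask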
block_mask = create_block_mask(wrapped_prefix_mask, B, H, S, S, device=device)
shared_out = flex_attention(query, key, value, block_mask=block_mask)

# Compute attention for prefix + chosen separately
PC = P + C
query_pc = torch.cat([query_prefix, query_chosen], dim=2)
key_pc = torch.cat([key_prefix, key_chosen], dim=2)
value_pc = torch.cat([value_prefix, value_chosen], dim=2)

# Reuse the module-level causal_mask for the two separate causal computations
pc_block_mask = create_block_mask(causal_mask, B, H, PC, PC, device=device)
pc_out = flex_attention(query_pc, key_pc, value_pc, block_mask=pc_block_mask)

# Compute attention for prefix + rejected separately
PR = P + R
query_pr = torch.cat([query_prefix, query_rejected], dim=2)
key_pr = torch.cat([key_prefix, key_rejected], dim=2)
value_pr = torch.cat([value_prefix, value_rejected], dim=2)

pr_block_mask = create_block_mask(causal_mask, B, H, PR, PR, device=device)
pr_out = flex_attention(query_pr, key_pr, value_pr, block_mask=pr_block_mask)
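# Note: q/k/v are raw random tensors with no positional encoding applied, so the rejected
# rows yield identical outputs whether they sit at positions [P+C, S) in the shared sequence
# or at positions [P, P+R) in the separate prefix+rejected sequence.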

# Slice along the sequence dimension only; attention outputs have shape (B, H, S, D)
shared_prefix = shared_out[:, :, :P, :]
shared_chosen = shared_out[:, :, P : P + C, :]
shared_rejected = shared_out[:, :, P + C :, :]

separate_prefix_c = pc_out[:, :, :P, :]
separate_chosen = pc_out[:, :, P:, :]
separate_prefix_r = pr_out[:, :, :P, :]
separate_rejected = pr_out[:, :, P:, :]

# Verify prefix outputs are identical
assert torch.allclose(
shared_prefix, separate_prefix_c, atol=atol, rtol=rtol
), "Prefix attention from shared computation doesn't match prefix+chosen computation"
assert torch.allclose(
shared_prefix, separate_prefix_r, atol=atol, rtol=rtol
), "Prefix attention from shared computation doesn't match prefix+rejected computation"

# Verify chosen and rejected outputs
assert torch.allclose(
shared_chosen, separate_chosen, atol=atol, rtol=rtol
), "Chosen response attention doesn't match between shared and separate computation"
assert torch.allclose(
shared_rejected, separate_rejected, atol=atol, rtol=rtol
), "Rejected response attention doesn't match between shared and separate computation"

print("All attention values match between shared and separate computations!")


@pytest.mark.parametrize(
"B, H, P, C, R, D",
[
(2, 8, 512, 256, 256, 32),
(3, 12, 1024, 512, 512, 64),
],
)
@pytest.mark.parametrize(
"dtype, atol, rtol",
[
pytest.param(
torch.bfloat16,
3e-2,
5e-1,
marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
),
(torch.float16, 1e-2, 5e-3),
(torch.float32, 1e-3, 5e-4),
],
)
def test_correctness_prefix(B, H, P, C, R, D, dtype, atol, rtol):
"""Parametrized test for different configurations"""
_test_correctness_prefix(B=B, H=H, P=P, C=C, R=R, D=D, dtype=dtype, atol=atol, rtol=rtol)