Add tests for shared prefix alignment with flex attn
Signed-off-by: Austin Liu <[email protected]>

nit

Signed-off-by: Austin Liu <[email protected]>
austin362667 committed Dec 30, 2024
1 parent 42ff02a commit 30001e7
Showing 1 changed file with 301 additions and 0 deletions.
301 changes: 301 additions & 0 deletions test/transformers/test_flex_attention.py
@@ -0,0 +1,301 @@
from test.utils import assert_verbose_allclose, set_seed, supports_bfloat16

import pytest
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import (
    create_block_mask,
    create_mask,
    flex_attention,
)

from liger_kernel.utils import infer_device


def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def prefix_mask(b, h, q_idx, kv_idx, rejected_index, chosen_index):
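    """
    mask_mod for a shared-prefix [prefix | chosen | rejected] sequence layout.

    chosen_index[b] is the first chosen position and rejected_index[b] is the
    first rejected position for batch element b. Attention is causal, except
    that rejected queries (q_idx >= rejected_index[b]) may not attend to
    chosen keys (chosen_index[b] <= kv_idx < rejected_index[b]), so the chosen
    and rejected responses share only the prefix.
    """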
    return (
        ~(
            (q_idx >= rejected_index[b])
            & (chosen_index[b] <= kv_idx)
            & (kv_idx < rejected_index[b])
        )
    ) & (q_idx >= kv_idx)


device = infer_device()
set_seed(42)


def _test_correctness_flex(B, H, S, D, mask_func, dtype, atol, rtol, device=device):
    """
    Compare flex_attention (with a block mask) against
    F.scaled_dot_product_attention (with the equivalent dense mask).
    Parameters:
        B (int): Batch size
        H (int): Number of attention heads
        S (int): Sequence length
        D (int): Hidden dimension per head
        mask_func: Mask function defining the custom attention pattern
        dtype: Data type for computation
        atol (float): Absolute tolerance for comparison
        rtol (float): Relative tolerance for comparison
        device: Device to run the comparison on (defaults to infer_device())
    """
    torch.manual_seed(0)

    # Initialize input tensors, i.e. the tensors after the q, k, and v projections
    # of the hidden states (attention head inputs)
    query_torch = torch.randn(
        B, H, S, D, device=device, dtype=dtype, requires_grad=True
    )
    key_torch = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
    value_torch = torch.randn(
        B, H, S, D, device=device, dtype=dtype, requires_grad=True
    )

    query_flex = query_torch.clone().detach().requires_grad_(True)
    key_flex = key_torch.clone().detach().requires_grad_(True)
    value_flex = value_torch.clone().detach().requires_grad_(True)

    block_mask = create_block_mask(
        mask_func, B, H, S, S, device=device
    )  # Block-sparse mask consumed by flex_attention
    mask = create_mask(mask_func, B, H, S, S, device=device)  # Dense mask for SDPA

    # If you are using a purely causal mask with SDPA (FA2), you can enable `is_causal` instead,
    # e.g.,
    # F.scaled_dot_product_attention(query, key, value, is_causal=True)

    torch_out = F.scaled_dot_product_attention(
        query_torch, key_torch, value_torch, attn_mask=mask
    )

    flex_out = flex_attention(query_flex, key_flex, value_flex, block_mask=block_mask)

    # Check forward pass
    assert_verbose_allclose(flex_out, torch_out, atol=atol, rtol=rtol)

    grad_out = torch.randn_like(torch_out)
    torch_out.backward(grad_out)
    flex_out.backward(grad_out)

    # Check gradients
    assert_verbose_allclose(query_flex.grad, query_torch.grad, atol=atol, rtol=rtol)
    assert_verbose_allclose(key_flex.grad, key_torch.grad, atol=atol, rtol=rtol)
    assert_verbose_allclose(value_flex.grad, value_torch.grad, atol=atol, rtol=rtol)


@pytest.mark.parametrize(
    "B, H, S, D",
    [
        (2, 8, 1024, 32),
        (3, 12, 2048, 64),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        pytest.param(
            torch.bfloat16,
            3e-2,
            5e-1,
            marks=pytest.mark.skipif(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
        (torch.float16, 1e-2, 5e-3),
        (torch.float32, 1e-3, 5e-4),
    ],
)
def test_correctness_flex(B, H, S, D, dtype, atol, rtol):
    _test_correctness_flex(B, H, S, D, causal_mask, dtype, atol, rtol)

    # Randomly generate chosen/rejected start indices for each batch
    chosen_index = torch.randint(0, S // 2, (B,), device=device)
    rejected_index = torch.randint(S // 2, S, (B,), device=device)

    def wrapped_prefix_mask(b, h, q_idx, kv_idx):
        return prefix_mask(b, h, q_idx, kv_idx, rejected_index, chosen_index)

    _test_correctness_flex(B, H, S, D, wrapped_prefix_mask, dtype, atol, rtol)


def _test_correctness_prefix(
    B=2,
    H=8,
    P=512,
    C=256,
    R=256,
    D=32,
    dtype=torch.float32,
    atol=1e-3,
    rtol=5e-4,
    device=device,
):
    """
    Test that shared-prefix attention matches two separate causally masked
    attention computations: prefix + chosen and prefix + rejected.
    The mental model is:
    A. prefix + chosen
        P
        P P
        P P P
        P P P C
        P P P C C
        P P P C C C
    B. prefix + rejected
        P
        P P
        P P P
        P P P R
        P P P R R
        P P P R R R
    C. shared prefix + chosen + rejected
        P
        P P
        P P P
        P P P C
        P P P C C
        P P P C C C
        P P P       R
        P P P       R R
        P P P       R R R
    We check the following equivalences of attention outputs:
    1. prefix rows of shared attn (top of C.) == prefix rows of prefix+chosen attn (top of A.)
    2. prefix rows of shared attn (top of C.) == prefix rows of prefix+rejected attn (top of B.)
        P              P
        P P      =     P P
        P P P          P P P
    3. chosen rows of shared attn (middle of C.) == chosen rows of prefix+chosen attn (bottom right of A.)
        C              C
        C C      =     C C
        C C C          C C C
    4. rejected rows of shared attn (bottom of C.) == rejected rows of prefix+rejected attn (bottom right of B.)
        R              R
        R R      =     R R
        R R R          R R R
    Args:
        B: batch size
        H: number of heads
        P: prefix length
        C: chosen response length
        R: rejected response length
        D: hidden dimension per head
        dtype: data type for computation
        atol: absolute tolerance for comparison
        rtol: relative tolerance for comparison
        device: device to run the test on (defaults to infer_device())
    """
    torch.manual_seed(0)

    # Total sequence length for shared version
    S = P + C + R
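    # Layout of the shared sequence: [0, P) prefix | [P, P+C) chosen | [P+C, S) rejected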

    # Initialize input tensors, i.e. the tensors after the q, k, and v projections
    # of the hidden states (attention head inputs)
    query = torch.randn(B, H, S, D, device=device, dtype=dtype)
    key = torch.randn(B, H, S, D, device=device, dtype=dtype)
    value = torch.randn(B, H, S, D, device=device, dtype=dtype)

    # Split tensors for separate computation
    query_prefix = query[:, :, :P, :]
    key_prefix = key[:, :, :P, :]
    value_prefix = value[:, :, :P, :]

    query_chosen = query[:, :, P : P + C, :]
    key_chosen = key[:, :, P : P + C, :]
    value_chosen = value[:, :, P : P + C, :]

    query_rejected = query[:, :, P + C :, :]
    key_rejected = key[:, :, P + C :, :]
    value_rejected = value[:, :, P + C :, :]

    # Chosen responses start at index P; rejected responses start at index P + C
    chosen_index = torch.full((B,), P, device=device)
    rejected_index = torch.full((B,), P + C, device=device)

    def wrapped_prefix_mask(b, h, q_idx, kv_idx):
        return prefix_mask(b, h, q_idx, kv_idx, rejected_index, chosen_index)

    block_mask = create_block_mask(wrapped_prefix_mask, B, H, S, S, device=device)
    shared_out = flex_attention(query, key, value, block_mask=block_mask)

    # Compute attention for prefix + chosen separately
    PC = P + C
    query_pc = torch.cat([query_prefix, query_chosen], dim=2)
    key_pc = torch.cat([key_prefix, key_chosen], dim=2)
    value_pc = torch.cat([value_prefix, value_chosen], dim=2)

    def causal_mask(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    pc_block_mask = create_block_mask(causal_mask, B, H, PC, PC, device=device)
    pc_out = flex_attention(query_pc, key_pc, value_pc, block_mask=pc_block_mask)

    # Compute attention for prefix + rejected separately
    PR = P + R
    query_pr = torch.cat([query_prefix, query_rejected], dim=2)
    key_pr = torch.cat([key_prefix, key_rejected], dim=2)
    value_pr = torch.cat([value_prefix, value_rejected], dim=2)

    pr_block_mask = create_block_mask(causal_mask, B, H, PR, PR, device=device)
    pr_out = flex_attention(query_pr, key_pr, value_pr, block_mask=pr_block_mask)

    # Outputs have shape (B, H, seq_len, D); slice only the query/sequence dimension
    shared_prefix = shared_out[:, :, :P, :]
    shared_chosen = shared_out[:, :, P : P + C, :]
    shared_rejected = shared_out[:, :, P + C :, :]

    separate_prefix_c = pc_out[:, :, :P, :]
    separate_chosen = pc_out[:, :, P:, :]
    separate_prefix_r = pr_out[:, :, :P, :]
    separate_rejected = pr_out[:, :, P:, :]

    # Verify prefix outputs are identical
    assert torch.allclose(
        shared_prefix, separate_prefix_c, atol=atol, rtol=rtol
    ), "Prefix attention from shared computation doesn't match prefix+chosen computation"
    assert torch.allclose(
        shared_prefix, separate_prefix_r, atol=atol, rtol=rtol
    ), "Prefix attention from shared computation doesn't match prefix+rejected computation"

    # Verify chosen and rejected outputs
    assert torch.allclose(
        shared_chosen, separate_chosen, atol=atol, rtol=rtol
    ), "Chosen response attention doesn't match between shared and separate computation"
    assert torch.allclose(
        shared_rejected, separate_rejected, atol=atol, rtol=rtol
    ), "Rejected response attention doesn't match between shared and separate computation"

    print("All attention values match between shared and separate computations!")


@pytest.mark.parametrize(
    "B, H, P, C, R, D",
    [
        (2, 8, 512, 256, 256, 32),
        (3, 12, 1024, 512, 512, 64),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        pytest.param(
            torch.bfloat16,
            3e-2,
            5e-1,
            marks=pytest.mark.skipif(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
        (torch.float16, 1e-2, 5e-3),
        (torch.float32, 1e-3, 5e-4),
    ],
)
def test_correctness_prefix(B, H, P, C, R, D, dtype, atol, rtol):
    """Parametrized test for different configurations."""
    _test_correctness_prefix(
        B=B, H=H, P=P, C=C, R=R, D=D, dtype=dtype, atol=atol, rtol=rtol
    )
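
# For reference, these tests can be run directly with pytest, e.g.:
#   pytest test/transformers/test_flex_attention.py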
