-
Notifications
You must be signed in to change notification settings - Fork 245
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add tests for shared prefix alignment with flex attn
Signed-off-by: Austin Liu <[email protected]> nit Signed-off-by: Austin Liu <[email protected]>
- Loading branch information
1 parent
5c5a7b4
commit 031707d
Showing
1 changed file
with
301 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,301 @@ | ||
from test.utils import assert_verbose_allclose, set_seed, supports_bfloat16 | ||
|
||
import pytest | ||
import torch | ||
import torch.nn.functional as F | ||
from torch.nn.attention.flex_attention import ( | ||
create_block_mask, | ||
create_mask, | ||
flex_attention, | ||
) | ||
|
||
from liger_kernel.utils import infer_device | ||
|
||
|
||
def causal_mask(b, h, q_idx, kv_idx): | ||
return q_idx >= kv_idx | ||
|
||
|
||
def prefix_mask(b, h, q_idx, kv_idx, rejected_index, chosen_index): | ||
return ( | ||
~( | ||
(q_idx >= rejected_index[b]) | ||
& (chosen_index[b] <= kv_idx) | ||
& (kv_idx < rejected_index[b]) | ||
) | ||
) & (q_idx >= kv_idx) | ||
|
||
|
||
device = infer_device() | ||
set_seed(42) | ||
|
||
|
||
def _test_correctness_flex(B, H, S, D, mask_func, dtype, atol, rtol, device="cuda"): | ||
""" | ||
Test attention mechanisms with various implementations. | ||
Parameters: | ||
B (int): Batch size | ||
H (int): Number of attention heads | ||
S (int): Sequence length | ||
D (int): Hidden dimension per head | ||
mask_func: A function that generates custom attention mask | ||
dtype: Data type for computation | ||
atol (float): Absolute tolerance for comparison | ||
rtol (float): Relative tolerance for comparison | ||
""" | ||
torch.manual_seed(0) | ||
|
||
# Initialize input tensors, i.e. the tensors after q, k, and v projections of hidden states (attention head input) | ||
query_torch = torch.randn( | ||
B, H, S, D, device=device, dtype=dtype, requires_grad=True | ||
) | ||
key_torch = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True) | ||
value_torch = torch.randn( | ||
B, H, S, D, device=device, dtype=dtype, requires_grad=True | ||
) | ||
|
||
query_flex = query_torch.clone().detach().requires_grad_(True) | ||
key_flex = key_torch.clone().detach().requires_grad_(True) | ||
value_flex = value_torch.clone().detach().requires_grad_(True) | ||
|
||
block_mask = create_block_mask( | ||
mask_func, B, H, S, S, device=device | ||
) # Sparsity block mask | ||
mask = create_mask(mask_func, B, H, S, S, device=device) # Regular mask | ||
|
||
# If you are using a causal mask with FA2, you can enable `is_causal`." | ||
# e.g., | ||
# F.scaled_dot_product_attention(query, key, value, is_causal=is_causal) | ||
|
||
torch_out = F.scaled_dot_product_attention( | ||
query_torch, key_torch, value_torch, attn_mask=mask | ||
) | ||
|
||
flex_out = flex_attention(query_flex, key_flex, value_flex, block_mask=block_mask) | ||
|
||
# Check forward pass | ||
assert_verbose_allclose(flex_out, torch_out, atol=atol, rtol=rtol) | ||
|
||
grad_out = torch.randn_like(torch_out) | ||
torch_out.backward(grad_out) | ||
flex_out.backward(grad_out) | ||
|
||
# Check gradients | ||
assert_verbose_allclose(query_flex.grad, query_torch.grad, atol=atol, rtol=rtol) | ||
assert_verbose_allclose(key_flex.grad, key_torch.grad, atol=atol, rtol=rtol) | ||
assert_verbose_allclose(value_flex.grad, value_torch.grad, atol=atol, rtol=rtol) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"B, H, S, D", | ||
[ | ||
(2, 8, 1024, 32), | ||
(3, 12, 2048, 64), | ||
], | ||
) | ||
@pytest.mark.parametrize( | ||
"dtype, atol, rtol", | ||
[ | ||
pytest.param( | ||
torch.bfloat16, | ||
3e-2, | ||
5e-1, | ||
marks=pytest.mark.skipif( | ||
not supports_bfloat16(), reason="bfloat16 not supported on this GPU" | ||
), | ||
), | ||
(torch.float16, 1e-2, 5e-3), | ||
(torch.float32, 1e-3, 5e-4), | ||
], | ||
) | ||
def test_correctness_flex(B, H, S, D, dtype, atol, rtol): | ||
_test_correctness_flex(B, H, S, D, causal_mask, dtype, atol, rtol) | ||
|
||
# Roughly generate custom rejected and chosen indices for each batch | ||
chosen_index = torch.randint(0, S // 2, (B,), device="cuda") | ||
rejected_index = torch.randint(S // 2, S, (B,), device="cuda") | ||
|
||
def wrapped_prefix_mask(b, h, q_idx, kv_idx): | ||
return prefix_mask(b, h, q_idx, kv_idx, rejected_index, chosen_index) | ||
|
||
_test_correctness_flex(B, H, S, D, wrapped_prefix_mask, dtype, atol, rtol) | ||
|
||
|
||
def _test_correctness_prefix( | ||
B=2, | ||
H=8, | ||
P=512, | ||
C=256, | ||
R=256, | ||
D=32, | ||
dtype=torch.float32, | ||
atol=1e-3, | ||
rtol=5e-4, | ||
device="cuda", | ||
): | ||
""" | ||
Test that prefix sharing attention matches separate computations (i.e. two separate casual masked attention, prefix+chosen and prefix+rejected). | ||
The mental model is: | ||
A. prefix + chosen | ||
P | ||
P P | ||
P P P | ||
P P P C | ||
P P P C C | ||
P P P C C C | ||
B. prefix + rejected | ||
P | ||
P P | ||
P P P | ||
P P P R | ||
P P P R R | ||
P P P R R R | ||
C. shared prefix + chosen + rejected | ||
P | ||
P P | ||
P P P | ||
P P P C | ||
P P P C C | ||
P P P C C C | ||
P P P R | ||
P P P R R | ||
P P P R R R | ||
We test them as belwo to ensure attention value equivalence: | ||
1. prefix of shared attn (upper of C.) == prefix of chosen attn (upper of A.) | ||
2. prefix of shared attn (upper of C.) == prefix of rejected attn (upper of B.) | ||
P P | ||
P P = P P | ||
P P P P P P | ||
3. prefix of shared attn (middle right of C.) == prefix of chosen attn (lower right of A.) | ||
C C | ||
C C = C C | ||
C C C C C C | ||
4. prefix of shared attn (lower right of C.) == prefix of rejected attn (lower right of B.) | ||
R R | ||
R R = R R | ||
R R R R R R | ||
Args: | ||
B: batch size | ||
H: number of heads | ||
P: prefix length | ||
C: chosen response length | ||
R: rejected response length | ||
D: hidden dimension per head | ||
""" | ||
torch.manual_seed(0) | ||
|
||
# Total sequence length for shared version | ||
S = P + C + R | ||
|
||
# Initialize input tensors, i.e. the tensors after q, k, and v projections of hidden states (attention head input) | ||
query = torch.randn(B, H, S, D, device=device, dtype=dtype) | ||
key = torch.randn(B, H, S, D, device=device, dtype=dtype) | ||
value = torch.randn(B, H, S, D, device=device, dtype=dtype) | ||
|
||
# Split tensors for separate computation | ||
query_prefix = query[:, :, :P, :] | ||
key_prefix = key[:, :, :P, :] | ||
value_prefix = value[:, :, :P, :] | ||
|
||
query_chosen = query[:, :, P : P + C, :] | ||
key_chosen = key[:, :, P : P + C, :] | ||
value_chosen = value[:, :, P : P + C, :] | ||
|
||
query_rejected = query[:, :, P + C :, :] | ||
key_rejected = key[:, :, P + C :, :] | ||
value_rejected = value[:, :, P + C :, :] | ||
|
||
chosen_index = torch.full((B,), P + C, device=device) | ||
rejected_index = torch.full((B,), S, device=device) | ||
|
||
def wrapped_prefix_mask(b, h, q_idx, kv_idx): | ||
return prefix_mask(b, h, q_idx, kv_idx, rejected_index, chosen_index) | ||
|
||
block_mask = create_block_mask(wrapped_prefix_mask, B, H, S, S, device=device) | ||
shared_out = flex_attention(query, key, value, block_mask=block_mask) | ||
|
||
# Compute attention for prefix + chosen separately | ||
PC = P + C | ||
query_pc = torch.cat([query_prefix, query_chosen], dim=2) | ||
key_pc = torch.cat([key_prefix, key_chosen], dim=2) | ||
value_pc = torch.cat([value_prefix, value_chosen], dim=2) | ||
|
||
def causal_mask(b, h, q_idx, kv_idx): | ||
return q_idx >= kv_idx | ||
|
||
pc_block_mask = create_block_mask(causal_mask, B, H, PC, PC, device=device) | ||
pc_out = flex_attention(query_pc, key_pc, value_pc, block_mask=pc_block_mask) | ||
|
||
# Compute attention for prefix + rejected separately | ||
PR = P + R | ||
query_pr = torch.cat([query_prefix, query_rejected], dim=2) | ||
key_pr = torch.cat([key_prefix, key_rejected], dim=2) | ||
value_pr = torch.cat([value_prefix, value_rejected], dim=2) | ||
|
||
pr_block_mask = create_block_mask(causal_mask, B, H, PR, PR, device=device) | ||
pr_out = flex_attention(query_pr, key_pr, value_pr, block_mask=pr_block_mask) | ||
|
||
shared_prefix = shared_out[:, :, :P, :P] | ||
shared_chosen = shared_out[:, :, P : P + C, P : P + C] | ||
shared_rejected = shared_out[:, :, P + C :, P + C :] | ||
|
||
separate_prefix_c = pc_out[:, :, :P, :P] | ||
separate_chosen = pc_out[:, :, P:, P:] | ||
separate_prefix_r = pr_out[:, :, :P, :P] | ||
separate_rejected = pr_out[:, :, P:, P:] | ||
|
||
# Verify prefix outputs are identical | ||
assert torch.allclose( | ||
shared_prefix, separate_prefix_c, atol=atol, rtol=rtol | ||
), "Prefix attention from shared computation doesn't match prefix+chosen computation" | ||
assert torch.allclose( | ||
shared_prefix, separate_prefix_r, atol=atol, rtol=rtol | ||
), "Prefix attention from shared computation doesn't match prefix+rejected computation" | ||
|
||
# Verify chosen and rejected outputs | ||
assert torch.allclose( | ||
shared_chosen, separate_chosen, atol=atol, rtol=rtol | ||
), "Chosen response attention doesn't match between shared and separate computation" | ||
assert torch.allclose( | ||
shared_rejected, separate_rejected, atol=atol, rtol=rtol | ||
), "Rejected response attention doesn't match between shared and separate computation" | ||
|
||
print("All attention values match between shared and separate computations!") | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"B, H, P, C, R, D", | ||
[ | ||
(2, 8, 512, 256, 256, 32), | ||
(3, 12, 1024, 512, 512, 64), | ||
], | ||
) | ||
@pytest.mark.parametrize( | ||
"dtype, atol, rtol", | ||
[ | ||
pytest.param( | ||
torch.bfloat16, | ||
3e-2, | ||
5e-1, | ||
marks=pytest.mark.skipif( | ||
not supports_bfloat16(), reason="bfloat16 not supported on this GPU" | ||
), | ||
), | ||
(torch.float16, 1e-2, 5e-3), | ||
(torch.float32, 1e-3, 5e-4), | ||
], | ||
) | ||
def test_correctness_prefix(B, H, P, C, R, D, dtype, atol, rtol): | ||
"""Parametrized test for different configurations""" | ||
_test_correctness_prefix( | ||
B=B, H=H, P=P, C=C, R=R, D=D, dtype=dtype, atol=atol, rtol=rtol | ||
) |