diff --git a/test/transformers/test_flex_attention.py b/test/transformers/test_flex_attention.py
new file mode 100644
index 000000000..3f32e23da
--- /dev/null
+++ b/test/transformers/test_flex_attention.py
@@ -0,0 +1,301 @@
+from test.utils import assert_verbose_allclose, set_seed, supports_bfloat16
+
+import pytest
+import torch
+import torch.nn.functional as F
+from torch.nn.attention.flex_attention import (
+    create_block_mask,
+    create_mask,
+    flex_attention,
+)
+
+from liger_kernel.utils import infer_device
+
+
+def causal_mask(b, h, q_idx, kv_idx):
+    return q_idx >= kv_idx
+
+
+def prefix_mask(b, h, q_idx, kv_idx, rejected_index, chosen_index):
+    # Causal mask that additionally blocks queries in the rejected segment
+    # (q_idx >= rejected_index[b]) from attending to keys in the chosen segment
+    # (chosen_index[b] <= kv_idx < rejected_index[b]), so a shared prefix can
+    # serve both responses independently.
+    return (
+        ~(
+            (q_idx >= rejected_index[b])
+            & (chosen_index[b] <= kv_idx)
+            & (kv_idx < rejected_index[b])
+        )
+    ) & (q_idx >= kv_idx)
+
+
+device = infer_device()
+set_seed(42)
+
+
+def _test_correctness_flex(B, H, S, D, mask_func, dtype, atol, rtol, device=device):
+    """
+    Compare flex_attention under a block mask against
+    F.scaled_dot_product_attention under the equivalent dense mask, for both
+    outputs and gradients.
+
+    Parameters:
+        B (int): Batch size
+        H (int): Number of attention heads
+        S (int): Sequence length
+        D (int): Hidden dimension per head
+        mask_func: A function that generates a custom attention mask
+        dtype: Data type for computation
+        atol (float): Absolute tolerance for comparison
+        rtol (float): Relative tolerance for comparison
+        device: Device to run on (defaults to the inferred device)
+    """
+    torch.manual_seed(0)
+
+    # Initialize input tensors, i.e. the tensors after the q, k, and v
+    # projections of the hidden states (per-head attention inputs)
+    query_torch = torch.randn(
+        B, H, S, D, device=device, dtype=dtype, requires_grad=True
+    )
+    key_torch = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
+    value_torch = torch.randn(
+        B, H, S, D, device=device, dtype=dtype, requires_grad=True
+    )
+
+    query_flex = query_torch.clone().detach().requires_grad_(True)
+    key_flex = key_torch.clone().detach().requires_grad_(True)
+    value_flex = value_torch.clone().detach().requires_grad_(True)
+
+    block_mask = create_block_mask(
+        mask_func, B, H, S, S, device=device
+    )  # Sparsity block mask
+    mask = create_mask(mask_func, B, H, S, S, device=device)  # Dense mask for SDPA
+
+    # If you are using a purely causal mask, SDPA's FlashAttention-2 fast path
+    # can be enabled via `is_causal` instead of a dense mask, e.g.:
+    # F.scaled_dot_product_attention(query, key, value, is_causal=True)
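+    # Illustration (comments only, not executed): create_mask materializes the
+    # dense boolean mask that SDPA consumes, while create_block_mask only
+    # records which tiles flex_attention may skip. For a toy causal case with
+    # S=4 the dense mask is lower-triangular:
+    #   create_mask(causal_mask, 1, 1, 4, 4, device=device)[0, 0]
+    #   -> tensor([[ True, False, False, False],
+    #              [ True,  True, False, False],
+    #              [ True,  True,  True, False],
+    #              [ True,  True,  True,  True]])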
+    torch_out = F.scaled_dot_product_attention(
+        query_torch, key_torch, value_torch, attn_mask=mask
+    )
+
+    flex_out = flex_attention(query_flex, key_flex, value_flex, block_mask=block_mask)
+
+    # Check forward pass
+    assert_verbose_allclose(flex_out, torch_out, atol=atol, rtol=rtol)
+
+    grad_out = torch.randn_like(torch_out)
+    torch_out.backward(grad_out)
+    flex_out.backward(grad_out)
+
+    # Check gradients
+    assert_verbose_allclose(query_flex.grad, query_torch.grad, atol=atol, rtol=rtol)
+    assert_verbose_allclose(key_flex.grad, key_torch.grad, atol=atol, rtol=rtol)
+    assert_verbose_allclose(value_flex.grad, value_torch.grad, atol=atol, rtol=rtol)
+
+
+@pytest.mark.parametrize(
+    "B, H, S, D",
+    [
+        (2, 8, 1024, 32),
+        (3, 12, 2048, 64),
+    ],
+)
+@pytest.mark.parametrize(
+    "dtype, atol, rtol",
+    [
+        pytest.param(
+            torch.bfloat16,
+            3e-2,
+            5e-1,
+            marks=pytest.mark.skipif(
+                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+            ),
+        ),
+        (torch.float16, 1e-2, 5e-3),
+        (torch.float32, 1e-3, 5e-4),
+    ],
+)
+def test_correctness_flex(B, H, S, D, dtype, atol, rtol):
+    _test_correctness_flex(B, H, S, D, causal_mask, dtype, atol, rtol)
+
+    # Randomly pick the chosen and rejected segment start indices for each batch
+    chosen_index = torch.randint(0, S // 2, (B,), device=device)
+    rejected_index = torch.randint(S // 2, S, (B,), device=device)
+
+    def wrapped_prefix_mask(b, h, q_idx, kv_idx):
+        return prefix_mask(b, h, q_idx, kv_idx, rejected_index, chosen_index)
+
+    _test_correctness_flex(B, H, S, D, wrapped_prefix_mask, dtype, atol, rtol)
+
+
+def _test_correctness_prefix(
+    B=2,
+    H=8,
+    P=512,
+    C=256,
+    R=256,
+    D=32,
+    dtype=torch.float32,
+    atol=1e-3,
+    rtol=5e-4,
+    device=device,
+):
+    """
+    Test that prefix-sharing attention matches separate computations
+    (i.e. two separate causally masked attentions: prefix+chosen and
+    prefix+rejected). The mental model is:
+
+    A. prefix + chosen
+     P
+     P P
+     P P P
+     P P P C
+     P P P C C
+     P P P C C C
+
+    B. prefix + rejected
+     P
+     P P
+     P P P
+     P P P R
+     P P P R R
+     P P P R R R
+
+    C. shared prefix + chosen + rejected
+     P
+     P P
+     P P P
+     P P P C
+     P P P C C
+     P P P C C C
+     P P P R
+     P P P R R
+     P P P R R R
+
+    We test the following equivalences of attention outputs:
+    1. prefix rows of shared attn (upper of C.) == prefix rows of prefix+chosen attn (upper of A.)
+    2. prefix rows of shared attn (upper of C.) == prefix rows of prefix+rejected attn (upper of B.)
+     P            P
+     P P     =    P P
+     P P P        P P P
+
+    3. chosen rows of shared attn (middle of C.) == chosen rows of prefix+chosen attn (lower rows of A.)
+     C            C
+     C C     =    C C
+     C C C        C C C
+
+    4. rejected rows of shared attn (lower of C.) == rejected rows of prefix+rejected attn (lower rows of B.)
+     R            R
+     R R     =    R R
+     R R R        R R R
+
+    Args:
+        B: batch size
+        H: number of heads
+        P: prefix length
+        C: chosen response length
+        R: rejected response length
+        D: hidden dimension per head
+        dtype: data type for computation
+        atol: absolute tolerance for comparison
+        rtol: relative tolerance for comparison
+        device: device to run on (defaults to the inferred device)
+    """
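+    # Toy walk-through (a hedged sketch with hypothetical sizes P=2, C=1, R=1,
+    # so S=4): the prefix mask blocks exactly the rejected->chosen pair
+    # (q_idx=3 attending kv_idx=2); every other causal pair stays allowed, so
+    # the rejected token sees {prefix, itself}, matching a separate
+    # prefix+rejected causal pass of length 3.
+    #   toy_chosen, toy_rejected = torch.tensor([2]), torch.tensor([3])
+    #   prefix_mask(0, 0, torch.tensor(3), torch.tensor(2), toy_rejected, toy_chosen)
+    #   -> tensor(False)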
+    torch.manual_seed(0)
+
+    # Total sequence length for the shared version
+    S = P + C + R
+
+    # Initialize input tensors, i.e. the tensors after the q, k, and v
+    # projections of the hidden states (per-head attention inputs)
+    query = torch.randn(B, H, S, D, device=device, dtype=dtype)
+    key = torch.randn(B, H, S, D, device=device, dtype=dtype)
+    value = torch.randn(B, H, S, D, device=device, dtype=dtype)
+
+    # Split tensors for separate computation
+    query_prefix = query[:, :, :P, :]
+    key_prefix = key[:, :, :P, :]
+    value_prefix = value[:, :, :P, :]
+
+    query_chosen = query[:, :, P : P + C, :]
+    key_chosen = key[:, :, P : P + C, :]
+    value_chosen = value[:, :, P : P + C, :]
+
+    query_rejected = query[:, :, P + C :, :]
+    key_rejected = key[:, :, P + C :, :]
+    value_rejected = value[:, :, P + C :, :]
+
+    # Segment start indices: chosen starts at P, rejected starts at P + C,
+    # so prefix_mask blocks rejected queries from chosen keys in [P, P + C).
+    chosen_index = torch.full((B,), P, device=device)
+    rejected_index = torch.full((B,), P + C, device=device)
+
+    def wrapped_prefix_mask(b, h, q_idx, kv_idx):
+        return prefix_mask(b, h, q_idx, kv_idx, rejected_index, chosen_index)
+
+    block_mask = create_block_mask(wrapped_prefix_mask, B, H, S, S, device=device)
+    shared_out = flex_attention(query, key, value, block_mask=block_mask)
+
+    # Compute attention for prefix + chosen separately
+    PC = P + C
+    query_pc = torch.cat([query_prefix, query_chosen], dim=2)
+    key_pc = torch.cat([key_prefix, key_chosen], dim=2)
+    value_pc = torch.cat([value_prefix, value_chosen], dim=2)
+
+    pc_block_mask = create_block_mask(causal_mask, B, H, PC, PC, device=device)
+    pc_out = flex_attention(query_pc, key_pc, value_pc, block_mask=pc_block_mask)
+
+    # Compute attention for prefix + rejected separately
+    PR = P + R
+    query_pr = torch.cat([query_prefix, query_rejected], dim=2)
+    key_pr = torch.cat([key_prefix, key_rejected], dim=2)
+    value_pr = torch.cat([value_prefix, value_rejected], dim=2)
+
+    pr_block_mask = create_block_mask(causal_mask, B, H, PR, PR, device=device)
+    pr_out = flex_attention(query_pr, key_pr, value_pr, block_mask=pr_block_mask)
+
+    # Slice output rows per segment; the last dimension is the head dimension D
+    # and must not be sliced (slicing it with sequence indices would silently
+    # compare empty tensors).
+    shared_prefix = shared_out[:, :, :P, :]
+    shared_chosen = shared_out[:, :, P : P + C, :]
+    shared_rejected = shared_out[:, :, P + C :, :]
+
+    separate_prefix_c = pc_out[:, :, :P, :]
+    separate_chosen = pc_out[:, :, P:, :]
+    separate_prefix_r = pr_out[:, :, :P, :]
+    separate_rejected = pr_out[:, :, P:, :]
+
+    # Verify prefix outputs are identical
+    assert torch.allclose(
+        shared_prefix, separate_prefix_c, atol=atol, rtol=rtol
+    ), "Prefix attention from shared computation doesn't match prefix+chosen computation"
+    assert torch.allclose(
+        shared_prefix, separate_prefix_r, atol=atol, rtol=rtol
+    ), "Prefix attention from shared computation doesn't match prefix+rejected computation"
+
+    # Verify chosen and rejected outputs
+    assert torch.allclose(
+        shared_chosen, separate_chosen, atol=atol, rtol=rtol
+    ), "Chosen response attention doesn't match between shared and separate computation"
+    assert torch.allclose(
+        shared_rejected, separate_rejected, atol=atol, rtol=rtol
+    ), "Rejected response attention doesn't match between shared and separate computation"
+
+    print("All attention values match between shared and separate computations!")
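+
+
+# Note (a hedged aside, not exercised by these tests): in real workloads
+# flex_attention is typically wrapped in torch.compile so the BlockMask
+# sparsity is exploited by a fused kernel, e.g.:
+#
+#   compiled_flex_attention = torch.compile(flex_attention)
+#   out = compiled_flex_attention(query, key, value, block_mask=block_mask)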
+
+
+@pytest.mark.parametrize(
+    "B, H, P, C, R, D",
+    [
+        (2, 8, 512, 256, 256, 32),
+        (3, 12, 1024, 512, 512, 64),
+    ],
+)
+@pytest.mark.parametrize(
+    "dtype, atol, rtol",
+    [
+        pytest.param(
+            torch.bfloat16,
+            3e-2,
+            5e-1,
+            marks=pytest.mark.skipif(
+                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+            ),
+        ),
+        (torch.float16, 1e-2, 5e-3),
+        (torch.float32, 1e-3, 5e-4),
+    ],
+)
+def test_correctness_prefix(B, H, P, C, R, D, dtype, atol, rtol):
+    """Parametrized prefix-sharing correctness test for different configurations"""
+    _test_correctness_prefix(
+        B=B, H=H, P=P, C=C, R=R, D=D, dtype=dtype, atol=atol, rtol=rtol
+    )
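+
+
+if __name__ == "__main__":
+    # Convenience smoke-run entry point for quick local debugging (a minimal
+    # sketch; pytest remains the canonical way to run this file).
+    _test_correctness_flex(2, 8, 1024, 32, causal_mask, torch.float32, 1e-3, 5e-4)
+    _test_correctness_prefix()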