diff --git a/python/perf-kernels/flash-attention.py b/python/perf-kernels/flash-attention.py
index 86141cc16125..2af965ed78f4 100644
--- a/python/perf-kernels/flash-attention.py
+++ b/python/perf-kernels/flash-attention.py
@@ -1203,10 +1203,29 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype):
     torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)


-@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 8192, 64), (4, 48, 256, 64), (4, 48, 512, 64),
-                                                 (4, 48, 1024, 64), (8, 48, 4096, 64), (4, 48, 8192, 64),
-                                                 (4, 48, 128, 128), (4, 48, 4096, 128), (4, 48, 16384, 128),
-                                                 (4, 16, 1024, 128), (4, 16, 8192, 128), (32, 48, 8192, 128)])
+@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [
+    (1, 2, 256, 63),
+    (2, 8, 247, 63),
+    (3, 4, 256, 64),
+    (4, 48, 256, 64),
+    (4, 48, 512, 64),
+    (4, 48, 1024, 64),
+    (8, 48, 4096, 64),
+    (4, 48, 8192, 64),
+    (4, 48, 128, 128),
+    (4, 48, 4096, 128),
+    (4, 48, 16384, 128),
+    (4, 16, 1024, 128),
+    (4, 16, 8192, 128),
+    (32, 48, 8192, 128),
+    (32, 48, 8192, 128),
+    (32, 48, 8192, 128),
+    (4, 48, 517, 256),
+    (4, 48, 1024, 256),
+    (4, 16, 1024, 512),
+    (4, 16, 1024, 593),
+    (4, 16, 1024, 1024),
+])
 @pytest.mark.parametrize('causal', [True, False])
 def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):