Skip to content

Commit

Permalink
test varlen
Browse files Browse the repository at this point in the history
  • Loading branch information
micmelesse committed Sep 19, 2024
1 parent 937968d commit e2f325d
Showing 1 changed file with 23 additions and 4 deletions.
27 changes: 23 additions & 4 deletions python/perf-kernels/flash-attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,10 +1203,29 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype):
torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)


@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 8192, 64), (4, 48, 256, 64), (4, 48, 512, 64),
(4, 48, 1024, 64), (8, 48, 4096, 64), (4, 48, 8192, 64),
(4, 48, 128, 128), (4, 48, 4096, 128), (4, 48, 16384, 128),
(4, 16, 1024, 128), (4, 16, 8192, 128), (32, 48, 8192, 128)])
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [
(1, 2, 256, 63),
(2, 8, 247, 63),
(3, 4, 256, 64),
(4, 48, 256, 64),
(4, 48, 512, 64),
(4, 48, 1024, 64),
(8, 48, 4096, 64),
(4, 48, 8192, 64),
(4, 48, 128, 128),
(4, 48, 4096, 128),
(4, 48, 16384, 128),
(4, 16, 1024, 128),
(4, 16, 8192, 128),
(32, 48, 8192, 128),
(32, 48, 8192, 128),
(32, 48, 8192, 128),
(4, 48, 517, 256),
(4, 48, 1024, 256),
(4, 16, 1024, 512),
(4, 16, 1024, 593),
(4, 16, 1024, 1024),
])
@pytest.mark.parametrize('causal', [True, False])
def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):

Expand Down

0 comments on commit e2f325d

Please sign in to comment.