Uses pytest markers instead of module skip. #941

Draft: wants to merge 1 commit into main
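
This change removes the module-level skip in gpu_attention_test.py (which skipped the whole file at import time whenever jax.default_backend() was neither "gpu" nor "cpu") and instead tags the GPU tests with @pytest.mark.gpu and the TPU FlashAttention tests with @pytest.mark.tpu, so run_tests.sh can deselect them through its existing -m marker expression.

When pytest is invoked directly rather than through run_tests.sh, marker-based selection typically relies on either an explicit -m filter or a conftest.py hook that skips marked tests when the required backend is unavailable. The sketch below is one hypothetical way to wire that up; it is not part of this PR, only the gpu/tpu marker names come from the diff, and a real hook would probably keep the old allowance for running the GPU kernels on CPU.

# conftest.py sketch (illustrative only, not included in this PR):
# skip backend-marked tests when the active JAX backend does not match.
import jax
import pytest

_BACKEND_MARKERS = ("gpu", "tpu")

def pytest_collection_modifyitems(config, items):
    backend = jax.default_backend()
    for item in items:
        for marker in _BACKEND_MARKERS:
            if item.get_closest_marker(marker) is not None and backend != marker:
                item.add_marker(
                    pytest.mark.skip(reason=f"Requires the {marker} backend, found {backend}.")
                )

The new gpu marker presumably also needs to be registered in the project's pytest configuration (alongside the existing tpu, gs_login, high_cpu, and fp64 markers used by run_tests.sh) to avoid unknown-marker warnings.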
10 changes: 7 additions & 3 deletions axlearn/common/flash_attention/gpu_attention_test.py
@@ -10,6 +10,7 @@

Currently tested on A100/H100.
"""

import functools
from typing import Literal

@@ -28,9 +29,6 @@
from axlearn.common.flash_attention.utils import _repeat_kv_heads, mha_reference
from axlearn.common.test_utils import TestCase

- if jax.default_backend() not in ("gpu", "cpu"):
-     pytest.skip(reason="Incompatible hardware", allow_module_level=True)


@pytest.mark.parametrize(
"batch_size,seq_len,num_heads,per_head_dim",
@@ -51,6 +49,7 @@
@pytest.mark.parametrize("attention_bias_type", [None, "2d", "4d"])
@pytest.mark.parametrize("use_segment_ids", [True, False])
@pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.float32])
+ @pytest.mark.gpu
def test_triton_fwd_only_against_ref(
batch_size: int,
seq_len: int,
@@ -122,6 +121,7 @@ def test_triton_fwd_only_against_ref(
chex.assert_trees_all_close(o, o_ref, atol=0.03)


+ @pytest.mark.gpu
class FlashDecodingTest(TestCase):
"""Tests FlashDecoding."""

@@ -234,6 +234,7 @@ def test_decode_against_ref(
@pytest.mark.parametrize("block_size", [64, 128])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.float32])
+ @pytest.mark.gpu
def test_triton_against_xla_ref(
batch_size: int,
num_heads: int,
@@ -353,6 +354,7 @@ def ref_fn(q, k, v, bias, segment_ids, k5):
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16])
+ @pytest.mark.gpu
def test_cudnn_against_triton_ref(
batch_size: int,
num_heads: int,
@@ -433,6 +435,7 @@ def ref_fn(q, k, v):
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16])
@pytest.mark.parametrize("dropout_rate", [0.1, 0.25])
+ @pytest.mark.gpu
def test_cudnn_dropout_against_xla_dropout(
batch_size: int,
num_heads: int,
@@ -515,6 +518,7 @@ def ref_fn(q, k, v):
raise ValueError(f"Unsupported dtype: {dtype}")


+ @pytest.mark.gpu
def test_cudnn_dropout_determinism():
"""Tests that cuDNN dropout produces identical outputs across runs."""
if jax.default_backend() == "cpu":
2 changes: 2 additions & 0 deletions axlearn/common/flash_attention/tpu_attention_test.py
@@ -1,6 +1,7 @@
# Copyright © 2023 Apple Inc.

"""Tests TPU FlashAttention kernels."""

from __future__ import annotations

import unittest
@@ -46,6 +47,7 @@ def jax_fn_mask(query_position: Tensor, key_position: Tensor) -> Tensor:
return jnp.greater_equal(query_position, key_position)


+ @pytest.mark.tpu
class TestFlashAttention(TestCase):
"""Tests FlashAttention layer."""

2 changes: 1 addition & 1 deletion run_tests.sh
@@ -52,7 +52,7 @@ fi

UNQUOTED_PYTEST_FILES=$(echo $1 | tr -d "'")
pytest --durations=100 -v -n auto \
-m "not (gs_login or tpu or high_cpu or fp64)" ${UNQUOTED_PYTEST_FILES} \
-m "not (gs_login or tpu or gpu or high_cpu or fp64)" ${UNQUOTED_PYTEST_FILES} \
--dist worksteal &
TEST_PIDS[$!]=1
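
With gpu added to the deselection expression above, these flash-attention tests no longer run in the default CPU test job; on a GPU host they can presumably be selected explicitly with something like pytest -m gpu axlearn/common/flash_attention/gpu_attention_test.py.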
