skip failing MX tests on CUDA capability 10.0
Summary:

The Triton version shipped with PyTorch does not yet work on CUDA capability 10.0,
so the affected tests in the MX folder are skipped for now.

Test Plan:

```
pytest test/prototype/mx_formats/ -s -x
```


ghstack-source-id: 80c4e04efaa75d94df6d17a3d8b5ff45788c0179
ghstack-comment-id: 2614583779
Pull Request resolved: #1624
vkuzo committed Jan 26, 2025
1 parent 47f96f1 commit c68b696
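
For context, every touched test follows the same pattern: Triton-dependent MX tests gain a `pytest.mark.skipif` guard driven by a new device-capability helper. Below is a minimal self-contained sketch of that pattern; the helper is inlined here for readability and the test name is a placeholder, while in the actual diff the helper lives in `torchao/utils.py` and the decorators sit on the real MX tests.

```
import pytest
import torch


def is_sm_at_least_100():
    # Truthy only on CUDA builds whose current device reports
    # compute capability >= 10.0 (sm100).
    return (
        torch.cuda.is_available()
        and torch.version.cuda
        and torch.cuda.get_device_capability() >= (10, 0)
    )


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
def test_fp4_triton_cast_placeholder():
    ...  # exercises Triton-backed fp4 casts; runs only below capability 10.0
```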
Showing 4 changed files with 37 additions and 3 deletions.
8 changes: 7 additions & 1 deletion test/prototype/mx_formats/test_custom_cast.py
@@ -40,7 +40,7 @@
sem_vals_to_f32,
)
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_100

torch.manual_seed(0)

@@ -310,6 +310,9 @@ def test_fp4_pack_unpack():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
def test_fp4_triton_unscaled_cast():
packed_vals = torch.arange(0, 255, dtype=torch.uint8, device="cuda")
f32_ref = f4_unpacked_to_f32(unpack_uint4(packed_vals))
@@ -320,6 +323,9 @@ def test_fp4_triton_unscaled_cast():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
def test_fp4_triton_scaled_cast():
size = (256,)
orig_vals = torch.randn(size, dtype=torch.float, device="cuda") * 100
12 changes: 11 additions & 1 deletion test/prototype/mx_formats/test_mx_linear.py
@@ -18,7 +18,11 @@
swap_linear_with_mx_linear,
)
from torchao.quantization.utils import compute_error
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_89
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_4,
is_sm_at_least_89,
is_sm_at_least_100,
)

torch.manual_seed(2)

@@ -99,6 +103,9 @@ def test_activation_checkpointing():


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
@pytest.mark.parametrize("bias", [False, True])
# TODO(future PR): figure out why torch.compile does not match eager when
@@ -184,6 +191,9 @@ def test_inference_linear(elem_dtype, bias, input_shape):


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_inference_compile_simple(elem_dtype):
"""
11 changes: 10 additions & 1 deletion test/prototype/mx_formats/test_mx_tensor.py
@@ -21,7 +21,11 @@
to_dtype,
)
from torchao.quantization.utils import compute_error
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_89
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_4,
is_sm_at_least_89,
is_sm_at_least_100,
)

torch.manual_seed(2)

@@ -166,6 +170,8 @@ def test_transpose(elem_dtype, fp4_triton):
"""
if elem_dtype != DTYPE_FP4 and fp4_triton:
pytest.skip("unsupported configuration")
elif fp4_triton and is_sm_at_least_100():
pytest.skip("triton does not work yet on CUDA capability 10.0")

M, K = 128, 256
block_size = 32
@@ -205,6 +211,9 @@ def test_view(elem_dtype):


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
@pytest.mark.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("all_zeros", [False, True])
9 changes: 9 additions & 0 deletions torchao/utils.py
@@ -630,6 +630,15 @@ def is_sm_at_least_90():
)


# TODO(future PR): rename to 8_9, 9_0, 10_0 instead of 89, 90, 100
def is_sm_at_least_100():
return (
torch.cuda.is_available()
and torch.version.cuda
and torch.cuda.get_device_capability() >= (10, 0)
)


TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev")
TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev")
TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev")
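
Besides the decorator-level guards, the `test_mx_tensor.py` hunk above also skips inside a test body when only the Triton fp4 parametrization is affected. A rough sketch of that in-body usage follows; the test name and parametrization are placeholders, while the import matches what the diff adds to the test files.

```
import pytest

from torchao.utils import is_sm_at_least_100


@pytest.mark.parametrize("fp4_triton", [False, True])
def test_fp4_roundtrip_placeholder(fp4_triton):
    # Skip only the Triton-backed fp4 path on capability >= 10.0 devices;
    # the non-Triton parametrization still runs there.
    if fp4_triton and is_sm_at_least_100():
        pytest.skip("triton does not work yet on CUDA capability 10.0")
    ...
```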
