diff --git a/test/prototype/mx_formats/test_custom_cast.py b/test/prototype/mx_formats/test_custom_cast.py
index 6f9a76cf19..d27e1831c9 100644
--- a/test/prototype/mx_formats/test_custom_cast.py
+++ b/test/prototype/mx_formats/test_custom_cast.py
@@ -40,7 +40,7 @@
     sem_vals_to_f32,
 )
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_100
 
 torch.manual_seed(0)
 
@@ -310,6 +310,9 @@ def test_fp4_pack_unpack():
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4")
+@pytest.mark.skipif(
+    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+)
 def test_fp4_triton_unscaled_cast():
     packed_vals = torch.arange(0, 255, dtype=torch.uint8, device="cuda")
     f32_ref = f4_unpacked_to_f32(unpack_uint4(packed_vals))
@@ -320,6 +323,9 @@
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4")
+@pytest.mark.skipif(
+    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+)
 def test_fp4_triton_scaled_cast():
     size = (256,)
     orig_vals = torch.randn(size, dtype=torch.float, device="cuda") * 100
diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py
index d280e38c36..35afeb7959 100644
--- a/test/prototype/mx_formats/test_mx_linear.py
+++ b/test/prototype/mx_formats/test_mx_linear.py
@@ -18,7 +18,11 @@
     swap_linear_with_mx_linear,
 )
 from torchao.quantization.utils import compute_error
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_89
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_4,
+    is_sm_at_least_89,
+    is_sm_at_least_100,
+)
 
 torch.manual_seed(2)
 
@@ -99,6 +103,9 @@ def test_activation_checkpointing():
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+)
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 @pytest.mark.parametrize("bias", [False, True])
 # TODO(future PR): figure out why torch.compile does not match eager when
@@ -184,6 +191,9 @@ def test_inference_linear(elem_dtype, bias, input_shape):
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+)
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_inference_compile_simple(elem_dtype):
     """
diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py
index ae87ee021e..21cb49c064 100644
--- a/test/prototype/mx_formats/test_mx_tensor.py
+++ b/test/prototype/mx_formats/test_mx_tensor.py
@@ -21,7 +21,11 @@
     to_dtype,
 )
 from torchao.quantization.utils import compute_error
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_89
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_4,
+    is_sm_at_least_89,
+    is_sm_at_least_100,
+)
 
 torch.manual_seed(2)
 
@@ -166,6 +170,8 @@ def test_transpose(elem_dtype, fp4_triton):
     """
     if elem_dtype != DTYPE_FP4 and fp4_triton:
         pytest.skip("unsupported configuration")
+    elif fp4_triton and is_sm_at_least_100():
+        pytest.skip("triton does not work yet on CUDA capability 10.0")
 
     M, K = 128, 256
     block_size = 32
@@ -205,6 +211,9 @@ def test_view(elem_dtype):
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+)
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 @pytest.mark.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
 @pytest.mark.parametrize("all_zeros", [False, True])
diff --git a/torchao/utils.py b/torchao/utils.py
index 7a17c1b104..f67463f9f7 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -630,6 +630,15 @@ def is_sm_at_least_90():
     )
 
 
+# TODO(future PR): rename to 8_9, 9_0, 10_0 instead of 89, 90, 100
+def is_sm_at_least_100():
+    return (
+        torch.cuda.is_available()
+        and torch.version.cuda
+        and torch.cuda.get_device_capability() >= (10, 0)
+    )
+
+
 TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev")
 TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev")
 TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev")
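
Usage note: a minimal sketch of the skip pattern the test-file hunks above repeat, assuming this patch is applied. The test name and body are hypothetical placeholders for illustration, not code from the patch:

    import pytest
    import torch

    from torchao.utils import is_sm_at_least_100

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    @pytest.mark.skipif(
        is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
    )
    def test_triton_kernel_example():
        # Hypothetical test body: runs only on CUDA devices below
        # compute capability 10.0, where triton is known to work.
        x = torch.randn(256, device="cuda")
        assert x.shape == (256,)

The skipif condition is a call, so it is evaluated once at import/collection time. Because `and` short-circuits inside is_sm_at_least_100, torch.cuda.get_device_capability() is never queried when torch.cuda.is_available() is False, so collecting these tests on a CPU-only machine remains safe.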