diff --git a/python/test/unit/language/test_core_amd.py b/python/test/unit/language/test_core_amd.py index e63ec27397b7..67218c5bc588 100644 --- a/python/test/unit/language/test_core_amd.py +++ b/python/test/unit/language/test_core_amd.py @@ -1588,10 +1588,11 @@ def get_variant_golden(a, b): return c_padded[:SIZE_M, :SIZE_N] -@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K', [ - [64, 32, 128, 4, 64, 32, 64], +@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,NUM_STAGES', [ + [64, 32, 128, 4, 64, 32, 64, 0], + [64, 32, 128, 4, 64, 32, 64, 2] ]) -def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K): +def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, NUM_STAGES): a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16) b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16) c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32) @@ -1603,7 +1604,7 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO M=a.shape[0], N=b.shape[1], K=a.shape[1], BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, num_warps=NUM_WARPS, - num_stages=2) + num_stages=NUM_STAGES) golden = torch.matmul(a, b) # It's not easy to get a proper error threshold in different size diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index e03d030680c0..ab7832258341 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -91,7 +91,7 @@ def optimize_ttgir(mod, num_stages, arch): pm.add_tritongpu_accelerate_matmul_pass(matrix_core_version) pm.add_tritongpu_remove_layout_conversions_pass() pm.add_tritongpu_optimize_dot_operands_pass() - if num_stages == 0 and is_hip() and gpu_has_mfma(): + if num_stages == 0 and is_hip() and gpu_matrix_core_version() != 0: pm.add_tritongpu_stream_pipeline_pass() pm.add_canonicalizer_pass() else: