Fix dangling gpu_has_mfma use (#325)
* Fix dangling gpu_has_mfma use

This PR replaces the use of gpu_has_mfma with gpu_matrix_core_version.

* add basic test
binarman authored Sep 11, 2023
1 parent 6691de6 commit a06072f
Showing 2 changed files with 6 additions and 5 deletions.

python/test/unit/language/test_core_amd.py: 9 changes (5 additions, 4 deletions)

@@ -1588,10 +1588,11 @@ def get_variant_golden(a, b):
     return c_padded[:SIZE_M, :SIZE_N]
 
 
-@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K', [
-    [64, 32, 128, 4, 64, 32, 64],
+@pytest.mark.parametrize('SIZE_M,SIZE_N,SIZE_K,NUM_WARPS,BLOCK_SIZE_M,BLOCK_SIZE_N,BLOCK_SIZE_K,NUM_STAGES', [
+    [64, 32, 128, 4, 64, 32, 64, 0],
+    [64, 32, 128, 4, 64, 32, 64, 2]
 ])
-def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K):
+def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, NUM_STAGES):
     a = torch.randn((SIZE_M, SIZE_K), device='cuda', dtype=torch.float16)
     b = torch.randn((SIZE_K, SIZE_N), device='cuda', dtype=torch.float16)
     c = torch.empty((SIZE_M, SIZE_N), device=a.device, dtype=torch.float32)
@@ -1603,7 +1604,7 @@ def test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLO
         M=a.shape[0], N=b.shape[1], K=a.shape[1],
         BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
         num_warps=NUM_WARPS,
-        num_stages=2)
+        num_stages=NUM_STAGES)
     golden = torch.matmul(a, b)
 
     # It's not easy to get a proper error threshold in different size
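
The new NUM_STAGES parameter is forwarded unchanged to the kernel launch, so the two parametrized cases exercise both the default pipelined path (num_stages=2) and the num_stages=0 path gated by the compiler change below. A minimal sketch for running just this test, assuming pytest and a ROCm build of Triton are available; the wrapper script itself is hypothetical, while the file path and test name come from the diff above:

# Hypothetical runner for the updated GEMM test; not part of this commit.
# Assumes pytest plus a ROCm-enabled PyTorch/Triton installation.
import pytest

if __name__ == "__main__":
    raise SystemExit(pytest.main([
        "python/test/unit/language/test_core_amd.py",
        "-k", "test_gemm",  # matches both parametrized cases (NUM_STAGES=0 and NUM_STAGES=2)
        "-v",
    ]))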

python/triton/compiler/compiler.py: 2 changes (1 addition, 1 deletion)

@@ -91,7 +91,7 @@ def optimize_ttgir(mod, num_stages, arch):
     pm.add_tritongpu_accelerate_matmul_pass(matrix_core_version)
     pm.add_tritongpu_remove_layout_conversions_pass()
     pm.add_tritongpu_optimize_dot_operands_pass()
-    if num_stages == 0 and is_hip() and gpu_has_mfma():
+    if num_stages == 0 and is_hip() and gpu_matrix_core_version() != 0:
         pm.add_tritongpu_stream_pipeline_pass()
         pm.add_canonicalizer_pass()
     else:
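
The replacement works as a drop-in because gpu_matrix_core_version() is expected to return 0 on GPUs without matrix cores and a non-zero version otherwise, so the != 0 check takes over the role of the removed gpu_has_mfma(). A rough sketch of that relationship; the gfx-arch-to-version mapping below is an assumption for illustration, not the helper Triton actually ships:

# Illustrative sketch only: an assumed mapping from AMD gfx arch names to a
# matrix core version. The gpu_matrix_core_version used by optimize_ttgir
# may be implemented differently.
def gpu_matrix_core_version_sketch(gfx_arch: str) -> int:
    if gfx_arch.startswith("gfx908"):
        return 1  # MI100-class MFMA
    if gfx_arch.startswith("gfx90a"):
        return 2  # MI200-class MFMA
    if gfx_arch.startswith("gfx94"):
        return 3  # MI300-class MFMA
    return 0      # no matrix cores, so stream pipelining stays disabled


def gpu_has_mfma_sketch(gfx_arch: str) -> bool:
    # The removed helper is equivalent to "matrix core version is non-zero".
    return gpu_matrix_core_version_sketch(gfx_arch) != 0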
