Skip to content

Commit

Permalink
Bringing back check for tmem copy in the test
Browse files Browse the repository at this point in the history
  • Loading branch information
pawelszczerbuk committed Feb 26, 2025
1 parent e8a0b93 commit a77a213
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions python/test/unit/language/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ def test_blocked_scale_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, USE_
b.stride(1), output.stride(0), output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K,
NUM_STAGES=NUM_STAGES, USE_2D_SCALE_LOAD=USE_2D_SCALE_LOAD)
ttgir = out.asm["ttgir"]
ptx = out.asm["ptx"]

def flatten_scale(scale):
num_chunk_m, num_chunk_k, _, _, _ = scale.shape
Expand All @@ -505,6 +506,10 @@ def flatten_scale(scale):
rtol = 0.0001
torch.testing.assert_close(ref_out, output, atol=atol, rtol=rtol)

if USE_2D_SCALE_LOAD:
# Due to an issue in the coalescing pass, tmem_copy cannot be generated for the 5D load.
# The issue is fixed using the patch from https://github.com/triton-lang/triton/pull/4914
assert "tcgen05.cp" in ptx
if NUM_STAGES > 1:
if BLOCK_M == BLOCK_K and BLOCK_N == BLOCK_K:
load_pipelined = ttgir.count(f"ttg.local_alloc : () -> !ttg.memdesc<{NUM_STAGES}x{BLOCK_M}x{BLOCK_K}") == 2
Expand Down

0 comments on commit a77a213

Please sign in to comment.