diff --git a/python/perf-kernels/03-matrix-multiplication-stream-k.py b/python/perf-kernels/03-matrix-multiplication-stream-k.py deleted file mode 100755 index 62d820719b9a..000000000000 --- a/python/perf-kernels/03-matrix-multiplication-stream-k.py +++ /dev/null @@ -1,395 +0,0 @@ -#!/usr/bin/env python -## matmul stream-k implementation -## Credit goes to @pommedeterresautee -## See https://github.com/openai/triton/issues/1393 - -# (echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"') | sudo tee -a /etc/modprobe.d/RestrictedProfiling.conf >/dev/null -# sudo update-initramfs -u -k all -# cat /proc/driver/nvidia/params | grep RmProfilingAdminOnly -# sudo apt-get install zlib1g-dev -# for reproductible experiments -# sudo nvidia-smi -pm 1 -i 0 -# sudo nvidia-smi -i 0 -pl 350 # 400 for A100 -# sudo nvidia-smi -i 0 -lgc 1005 -from typing import Optional - -import torch -import triton -import triton.language as tl -import random - -#from triton.runtime.driver import CudaUtils -import json - -torch.manual_seed(123) -random.seed(123) - -#device = torch.cuda.current_device() -#cuda_utils = CudaUtils() -#total_sm = cuda_utils.get_device_properties(device)["multiprocessor_count"] -#total_sm = 110 # for MI250 -total_sm = 304 # for MI300X -print(f"total SMs: {total_sm}") - -# --------------------------------------------------------------------------- -# Triton kernels -# --------------------------------------------------------------------------- - - -@triton.jit() -def swizzle_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr): - grid_m = tl.cdiv(M, BLOCK_M) - grid_n = tl.cdiv(N, BLOCK_N) - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = tile_id // width - group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (tile_id % group_size) - pid_n = (tile_id % width) // group_size - return pid_m, pid_n - - -@triton.jit() -def linear_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr): - pid_m = tile_id // tl.cdiv(N, BLOCK_N) - pid_n = tile_id % tl.cdiv(N, BLOCK_N) - return pid_m, pid_n - - -@triton.jit() -def streamk_gemm( - A, - B, - C, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - total_full_tiles_streamk, - total_partial_tiles_streamk, - iters_per_tile, - total_tiles_streamk, - total_programs_streamk, - ACC_TYPE: tl.constexpr, - GROUP_M: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, -): - pid = tl.program_id(0) - - # Determine whether we are in the first wave or full_tiles phase based on pid - is_first_wave = pid < total_programs_streamk and total_programs_streamk > 0 - - # Calculate starting and ending iterations for first wave - if not is_first_wave: - tile_id = tl.program_id(0) + total_tiles_streamk - total_programs_streamk - if GROUP_M > 0: - pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - else: - pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - - # do matrix multiplication - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rk = tl.arange(0, BLOCK_K) - # pointers - A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak - B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(0, tl.cdiv(K, BLOCK_K)): 
- a = tl.load(A_BASE) - b = tl.load(B_BASE) - acc += tl.dot(a, b) - A_BASE += BLOCK_K * stride_ak - B_BASE += BLOCK_K * stride_bk - # acc = acc.to(tl.float16) # restore C.dtype.element_ty - # rematerialize rm and rn to save registers -# rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) -# rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn - tl.store(C_, acc) - else: - # start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk) - start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk) - last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum(pid + 1, total_partial_tiles_streamk) - while start_iter < last_iter: - remainder = start_iter % iters_per_tile - end_iter = tl.minimum(start_iter + (iters_per_tile - remainder), last_iter) - # where are we in the grid - tile_id = start_iter // iters_per_tile - if GROUP_M > 0: - pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - else: - pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rk = tl.arange(0, BLOCK_K) - A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_K * stride_ak * remainder - B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_K * stride_bk * remainder - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for current_iter in range(start_iter, end_iter): - a = tl.load(A_BASE) - b = tl.load(B_BASE) - acc += tl.dot(a, b) - A_BASE += BLOCK_K * stride_ak - B_BASE += BLOCK_K * stride_bk - - if remainder == 0 and end_iter % iters_per_tile == 0: - C_ = C + rm[:, - None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! - tl.store(C_, acc) - else: - C_ = C + rm[:, - None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! 
- tl.atomic_add(C_, acc) - - start_iter = end_iter - - -# --------------------------------------------------------------------------- -# Wrapper -# --------------------------------------------------------------------------- - - -class matmul(torch.autograd.Function): - - _debug = True - - @staticmethod - def set_debug(debug: bool): - matmul._debug = debug - - @staticmethod - def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M: int, BLK_N: int, BLK_K: int, - two_tiles: bool, num_stages: int, num_warps: int, waves_per_eu: int, mfmaInstrSize: int, kpack: int): - device = a.device - - assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported" - # checks constraints - assert a.shape[1] == b.shape[0], "incompatible dimensions" - M, K = a.shape - _, N = b.shape - # accumulator types - ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - # compute grid (work to do per SM on the first wave) - total_blocks_M = triton.cdiv(M, BLK_M) - total_blocks_N = triton.cdiv(N, BLK_N) - iters_per_tile = triton.cdiv(K, BLK_K) - GROUP_M = 4 # 0 to disable swizzling - total_tiles = total_blocks_M * total_blocks_N - - if total_programs_streamk > 0: # Stream-K - # last wave may occupy less than total_programs_streamk SMs - total_tiles_streamk = total_tiles % total_programs_streamk - # for two-tile Stream-K + data-parallel from original paper - if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: - total_tiles_streamk += total_programs_streamk - # remaining tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk - total_iters_streamk = total_tiles_streamk * iters_per_tile - # iterations related to full waves - total_full_tiles_streamk = total_iters_streamk // total_programs_streamk - # iterations related to last (partial) wave - total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk - - else: # all tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk = 0 - total_full_tiles_streamk = 0 - total_partial_tiles_streamk = 0 - total_iters_streamk = 0 - - if matmul._debug: - print(f"M,N,K={M},{N},{K} ; BLK_M,N,K={BLK_M},{BLK_N},{BLK_K}") - print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}") - print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}") - print(f"{total_programs_streamk=}") - print(f"{total_blocking_tiles=}") - print(f"{total_full_tiles_streamk=}") - print(f"{total_partial_tiles_streamk=}") - print(f"{iters_per_tile=}") - print(f"{total_iters_streamk=}") - - # allocates output - c = torch.zeros((M, N), device=device, dtype=a.dtype) - # allocates locks to sync work accross SMs - grids = total_programs_streamk + total_blocking_tiles - kk = streamk_gemm[(grids, )]( - a, - b, - c, - M, - N, - K, - a.stride(0), - a.stride(1), - b.stride(0), - b.stride(1), - c.stride(0), - c.stride(1), - total_full_tiles_streamk=total_full_tiles_streamk, - total_partial_tiles_streamk=total_partial_tiles_streamk, - iters_per_tile=iters_per_tile, - total_tiles_streamk=total_tiles_streamk, - total_programs_streamk=total_programs_streamk, - ACC_TYPE=ACC_TYPE, - GROUP_M=GROUP_M, - BLOCK_M=BLK_M, - BLOCK_N=BLK_N, - BLOCK_K=BLK_K, - num_stages=num_stages, - num_warps=num_warps, - waves_per_eu=waves_per_eu, - matrix_instr_nonkdim=mfmaInstrSize, - kpack=kpack, - ) - if matmul._debug: - print(f"{kk.n_regs} registers used, {kk.n_spills} spills") - - # print(kk.asm['ttgir']) - 
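# ---------------------------------------------------------------------------
# Illustrative sketch (editorial, not part of the deleted file): the grid
# arithmetic the wrapper above performs before launching streamk_gemm,
# rewritten as plain Python so the work split is easy to check by hand.
# The helper name and the example sizes (the script's 6912x768x256 / BLK 64 /
# 304 CUs case) are illustrative only.
def streamk_split_example(M=6912, N=768, K=256, BLK_M=64, BLK_N=64, BLK_K=64,
                          total_programs_streamk=304, two_tiles=True):
    cdiv = lambda a, b: (a + b - 1) // b
    total_tiles = cdiv(M, BLK_M) * cdiv(N, BLK_N)   # output tiles in C
    iters_per_tile = cdiv(K, BLK_K)                 # K-loop steps per tile
    # Stream-K wave: the remainder tiles that would otherwise form a partial
    # last wave, optionally plus one extra full wave (two-tile variant).
    total_tiles_streamk = total_tiles % total_programs_streamk
    if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk:
        total_tiles_streamk += total_programs_streamk
    total_blocking_tiles = total_tiles - total_tiles_streamk
    total_iters_streamk = total_tiles_streamk * iters_per_tile
    # per-program K-iteration counts; these mirror total_full_tiles_streamk
    # and total_partial_tiles_streamk in the kernel arguments above.
    full_iters_per_program = total_iters_streamk // total_programs_streamk
    programs_with_extra_iter = total_iters_streamk % total_programs_streamk
    return dict(total_tiles=total_tiles, iters_per_tile=iters_per_tile,
                streamk_tiles=total_tiles_streamk,
                blocking_tiles=total_blocking_tiles,
                full_iters_per_program=full_iters_per_program,
                programs_with_extra_iter=programs_with_extra_iter)

# Example: print(streamk_split_example())
# ---------------------------------------------------------------------------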
# print(kk.asm['amdgcn']) - - return c - - @staticmethod - def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=128, BLK_K=32, two_tiles=True, - num_stages=3, num_warps=4, waves_per_eu=2, mfmaInstrSize=16, kpack=1): - return matmul._call(a=a, b=b, total_programs_streamk=grid, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, - two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages, waves_per_eu=waves_per_eu, - mfmaInstrSize=mfmaInstrSize, kpack=kpack) - - -# --------------------------------------------------------------------------- -# Example and Benchmark -# --------------------------------------------------------------------------- - -perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) - -#m, n, k = 4864, 4096, 8256 # some problem size to test -#m, n, k = 4096, 4096, 8192 # some problem size to test -#m, n, k = 8192, 8192, 8192 # some problem size to test -m, n, k = 6912, 768, 256 # some problem size to test -A = torch.randn(m, k, device="cuda", dtype=torch.float16) -B = torch.randn(k, n, device="cuda", dtype=torch.float16) -BLK_M = 64 -BLK_N = 64 -BLK_K = 64 -two_tiles = 'True' -num_stages = 0 -num_warps = 4 -waves_per_eu = 0 -mfmaInstrSize = 16 -kpack = 2 - -matmul.set_debug(True) -C = matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, - kpack) -#exit(0) -matmul.set_debug(False) -expected = A @ B - -#assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}" -print("pass validation test") - -# for debugging, uncomment the following line -# exit(0) - -triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B)) -print(f"PyTorch: {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, - num_warps, waves_per_eu, mfmaInstrSize, kpack)) -print(f"hybrid stream-k (grid={total_sm}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, - num_warps, waves_per_eu, mfmaInstrSize, kpack)) -print(f"hybrid stream-k (grid={total_sm * 2}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, - waves_per_eu, mfmaInstrSize, kpack)) -print(f"tile matmul (grid=0): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -exit(0) -# --------------------------------------------------------------------------- -# Log-sampled benchmark -# --------------------------------------------------------------------------- - -# tried to reproduce the tests described in the paper -num_samples = 1000 # 32768 -step = 256 -values = ((torch.logspace(torch.tensor(step).log2(), - torch.tensor(8192).log2(), num_samples, base=2) / step).round() * step).unique().tolist() -shapes = [(int(m), int(n), int(k)) for m in values for n in values for k in values] -shapes = random.sample(shapes, num_samples) -assert len(shapes) == num_samples - -results = [] -for idx, (m, n, k) in enumerate(shapes): - # print progress bar - if idx % 10 == 0 and idx > 0: - speedups = [r["speedup"] for r in results] - print(f"{idx}/{num_samples} - average speedup: {sum(speedups) / len(speedups):.3f}") - - A = torch.randn(m, k, device="cuda", dtype=torch.float16) - B = torch.randn(k, n, device="cuda", dtype=torch.float16) - output: Optional[torch.Tensor] = None - - def 
wrapper_matmul(*args, **kwargs): - global output - output = matmul.apply(*args, **kwargs) - return output - - expected = A @ B - pytorch_ms = triton.testing.do_bench(lambda: A @ B) - measures = list() - for two_tiles in [True, False]: - nb_sm = [total_sm, total_sm * 2] - total_tile = (m // BLK_M) * (n // BLK_N) - if total_tile < total_sm * 2: - nb_sm.append(total_tile) - nb_sm += random.sample(range(2, total_sm * 2, 2), 10) - for sm in nb_sm: - triton_ms = triton.testing.do_bench( - lambda: wrapper_matmul(A, B, sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu)) - max_disc = (output - expected).abs().max().item() - # large tolerance to accomodate for large K (rounding due to half precision), we just want to catch bugs. - assert max_disc <= 5., f"pb size: {m}x{n}x{k} - max discrepancy: {max_disc} - sm: {sm}, 2 tiles: {two_tiles}\n{output}\n{expected}" - info = { - "2 tiles": two_tiles, - "sm": sm, - "disc": max_disc, - "triton_ms": triton_ms, - } - measures.append(info) - best_triton_ms = min([m["triton_ms"] for m in measures]) - d = { - "m": m, - "n": n, - "k": k, - "triton": measures, - "pytorch_ms": pytorch_ms, - "speedup": pytorch_ms / best_triton_ms, - } - results.append(d) - measures = list() - -results.sort(key=lambda x: x["speedup"], reverse=False) - -# --------------------------------------------------------------------------- -# Benchmark export -# --------------------------------------------------------------------------- - -with open("results.json", "w") as f: - json.dump(results, f, indent=4) - -# 32760/32768 - average speedup: 0.962 (A100) -# 990/1000 - average speedup: 1.063 (3090 RTX with while loop and 2 tiles disabled / enabled) diff --git a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-oldversion.py b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-oldversion.py deleted file mode 100644 index beb8b0df9b1f..000000000000 --- a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-oldversion.py +++ /dev/null @@ -1,485 +0,0 @@ -## matmul stream-k implementation -## Credit goes to @pommedeterresautee -## See https://github.com/openai/triton/issues/1393 - -# (echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"') | sudo tee -a /etc/modprobe.d/RestrictedProfiling.conf >/dev/null -# sudo update-initramfs -u -k all -# cat /proc/driver/nvidia/params | grep RmProfilingAdminOnly -# sudo apt-get install zlib1g-dev -# for reproductible experiments -# sudo nvidia-smi -pm 1 -i 0 -# sudo nvidia-smi -i 0 -pl 350 # 400 for A100 -# sudo nvidia-smi -i 0 -lgc 1005 -from typing import Optional - -import torch -import triton -import triton.language as tl -import random - -#from triton.runtime.driver import CudaUtils -import json - -torch.manual_seed(123) -random.seed(123) - -#device = torch.cuda.current_device() -#cuda_utils = CudaUtils() -#total_sm = cuda_utils.get_device_properties(device)["multiprocessor_count"] -#total_sm = 110 # for MI250 -total_sm = 304 # for MI300X -print(f"total SMs: {total_sm}") - -# --------------------------------------------------------------------------- -# Triton kernels -# --------------------------------------------------------------------------- - - -@triton.jit() -def swizzle_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr): - grid_m = tl.cdiv(M, BLOCK_M) - grid_n = tl.cdiv(N, BLOCK_N) - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = tile_id // width - group_size = tl.minimum(grid_m - 
group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (tile_id % group_size) - pid_n = (tile_id % width) // group_size - return pid_m, pid_n - - -@triton.jit() -def linear_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr): - pid_m = tile_id // tl.cdiv(N, BLOCK_N) - pid_n = tile_id % tl.cdiv(N, BLOCK_N) - return pid_m, pid_n - - -# iterate, multiply and accumulate over K axis -@triton.jit() -def mac_loop( - A, - B, - C, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - tile_id, - mod1, - mod2, - iters_per_tile, - start_iter, - end_iter, - pid_m, - pid_n, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - ACC_TYPE: tl.constexpr, -): - - # where are we in the grid - # tile_id = start_iter // iters_per_tile - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rk = tl.arange(0, BLOCK_K) - # A = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_K * stride_ak * (start_iter % iters_per_tile) - # B = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_K * stride_bk * (start_iter % iters_per_tile) - A = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_K * stride_ak * (mod1) - B = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_K * stride_bk * (mod1) - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - - for current_iter in range(start_iter, end_iter): - a = tl.load(A) - b = tl.load(B) - acc += tl.dot(a, b) - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - - #if end_iter % iters_per_tile == 0: # last iteration of the tile always happens before its start on another SM - - -# if mod2 == 0:# last iteration of the tile always happens before its start on another SM -# C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! -# tl.store(C_, acc) -# if start_iter % iters_per_tile != 0: # only if tile has been partially processed -# if mod1 != 0: # only if tile has been partially processed -# tl.atomic_xchg(locks + tile_id, 1) -# else: -# while tl.atomic_cas(locks + tile_id, 1, 1) != 1: -# pass -# C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! -# tl.atomic_add(C_, acc) - if mod1 == 0 and mod2 == 0: - C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! - tl.store(C_, acc) - else: - C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! 
- tl.atomic_add(C_, acc) - - -@triton.jit() -def first_wave( - A, - B, - C, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - total_full_tiles_streamk, - total_partial_tiles_streamk, - iters_per_tile, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - ACC_TYPE: tl.constexpr, - GROUP_M: tl.constexpr, -): - pid = tl.program_id(0) - start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk) - last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum(pid + 1, total_partial_tiles_streamk) - - while start_iter < last_iter: - end_iter = tl.minimum(start_iter + (iters_per_tile - start_iter % iters_per_tile), last_iter) - mod1 = start_iter % iters_per_tile - mod2 = end_iter % iters_per_tile - tile_id = start_iter // iters_per_tile - if GROUP_M > 0: - pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - else: - pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - mac_loop( - A, - B, - C, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - tile_id, - mod1, - mod2, - iters_per_tile, - start_iter, - end_iter, - pid_m, - pid_n, - BLOCK_M, - BLOCK_N, - BLOCK_K, - ACC_TYPE, - ) - - start_iter = end_iter - - -# similar to the reference matmul kernel -@triton.jit() -def full_tiles( - A, - B, - C, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - total_tiles_streamk, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - ACC_TYPE: tl.constexpr, - GROUP_M: tl.constexpr, -): - # first wave has done more tiles than there are SMs, we adjust pid - tile_id = tl.program_id(0) + total_tiles_streamk - if GROUP_M > 0: - pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - else: - pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - - # do matrix multiplication - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rk = tl.arange(0, BLOCK_K) - # pointers - A = A + rm[:, None] * stride_am + rk[None, :] * stride_ak - B = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(A) - b = tl.load(B) - acc += tl.dot(a, b) - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - acc = acc.to(tl.float16) # restore C.dtype.element_ty - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn - tl.store(C, acc) - - -# --------------------------------------------------------------------------- -# Wrapper -# --------------------------------------------------------------------------- - - -class matmul(torch.autograd.Function): - - _debug = False - - @staticmethod - def set_debug(debug: bool): - matmul._debug = debug - - @staticmethod - def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M: int, BLK_N: int, BLK_K: int, - two_tiles: bool, num_stages: int, num_warps: int): - device = a.device - - assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported" - # checks constraints - assert a.shape[1] == b.shape[0], "incompatible dimensions" - M, K = a.shape - _, N = b.shape - # accumulator types - ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, 
torch.float32] else tl.int32 - # compute grid (work to do per SM on the first wave) - total_blocks_M = triton.cdiv(M, BLK_M) - total_blocks_N = triton.cdiv(N, BLK_N) - iters_per_tile = triton.cdiv(K, BLK_K) - GROUP_M = 8 # 0 to disable swizzling - total_tiles = total_blocks_M * total_blocks_N - - if total_programs_streamk > 0: # Stream-K - # last wave may occupy less than total_programs_streamk SMs - total_tiles_streamk = total_tiles % total_programs_streamk - # for two-tile Stream-K + data-parallel from original paper - if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: - total_tiles_streamk += total_programs_streamk - # remaining tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk - total_iters_streamk = total_tiles_streamk * iters_per_tile - # iterations related to full waves - total_full_tiles_streamk = total_iters_streamk // total_programs_streamk - # iterations related to last (partial) wave - total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk - - else: # all tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk = 0 - total_full_tiles_streamk = 0 - total_partial_tiles_streamk = 0 - total_iters_streamk = 0 - - if matmul._debug: - print(f"M,N,K={M},{N},{K} ; BLK_M,N,K={BLK_M},{BLK_N},{BLK_K}") - print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}") - print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}") - print(f"{total_programs_streamk=}") - print(f"{total_blocking_tiles=}") - print(f"{iters_per_tile=}") - print(f"{total_iters_streamk=}") - - # allocates output - c = torch.zeros((M, N), device=device, dtype=a.dtype) - # allocates locks to sync work accross SMs - k1 = first_wave[(total_programs_streamk, )]( - a, - b, - c, - M, - N, - K, - a.stride(0), - a.stride(1), - b.stride(0), - b.stride(1), - c.stride(0), - c.stride(1), - total_full_tiles_streamk=total_full_tiles_streamk, - total_partial_tiles_streamk=total_partial_tiles_streamk, - iters_per_tile=iters_per_tile, - BLOCK_M=BLK_M, - BLOCK_N=BLK_N, - BLOCK_K=BLK_K, - ACC_TYPE=ACC_TYPE, - GROUP_M=GROUP_M, - num_stages=num_stages, - num_warps=num_warps, - ) - if matmul._debug: - print(f"{k1.n_regs} registers used, {k1.n_spills} spills") - k2 = full_tiles[(total_blocking_tiles, )]( - a, - b, - c, - M, - N, - K, - a.stride(0), - a.stride(1), - b.stride(0), - b.stride(1), - c.stride(0), - c.stride(1), - total_tiles_streamk=total_tiles_streamk, - BLOCK_M=BLK_M, - BLOCK_N=BLK_N, - BLOCK_K=BLK_K, - ACC_TYPE=ACC_TYPE, - GROUP_M=GROUP_M, - num_stages=num_stages, - num_warps=num_warps, - ) - if matmul._debug: - print(f"{k2.n_regs} registers used, {k2.n_spills} spills") - return c - - @staticmethod - def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=128, BLK_K=32, two_tiles=True, - num_stages=3, num_warps=4): - return matmul._call(a=a, b=b, total_programs_streamk=grid, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, - two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages) - - -# --------------------------------------------------------------------------- -# Example and Benchmark -# --------------------------------------------------------------------------- - -perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) - -m, n, k = 8192, 8192, 8192 # some problem size to test -A = torch.randn(m, k, device="cuda", dtype=torch.float16) -B = torch.randn(k, n, device="cuda", dtype=torch.float16) -BLK_M = 128 -BLK_N = 256 -BLK_K = 16 -two_tiles = 
'True' -num_stages = 0 -num_warps = 4 - -matmul.set_debug(True) -C = matmul.apply(A, B, total_sm, 128, 128, 32, 4, 4) -matmul.set_debug(False) -expected = A @ B - -assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}" - -# for debugging, uncomment the following line -# exit(0) - -triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B)) -print(f"PyTorch: {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench( - lambda: matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps)) -print(f"hybrid stream-k (grid={total_sm}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench( - lambda: matmul.apply(A, B, total_sm * 2, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps)) -print(f"hybrid stream-k (grid={total_sm * 2}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench( - lambda: matmul.apply(A, B, 0, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps)) -print(f"tile matmul (grid=0): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -exit(0) -# --------------------------------------------------------------------------- -# Log-sampled benchmark -# --------------------------------------------------------------------------- - -# tried to reproduce the tests described in the paper -num_samples = 1000 # 32768 -step = 256 -values = ((torch.logspace(torch.tensor(step).log2(), - torch.tensor(8192).log2(), num_samples, base=2) / step).round() * step).unique().tolist() -shapes = [(int(m), int(n), int(k)) for m in values for n in values for k in values] -shapes = random.sample(shapes, num_samples) -assert len(shapes) == num_samples - -results = [] -for idx, (m, n, k) in enumerate(shapes): - # print progress bar - if idx % 10 == 0 and idx > 0: - speedups = [r["speedup"] for r in results] - print(f"{idx}/{num_samples} - average speedup: {sum(speedups) / len(speedups):.3f}") - - A = torch.randn(m, k, device="cuda", dtype=torch.float16) - B = torch.randn(k, n, device="cuda", dtype=torch.float16) - output: Optional[torch.Tensor] = None - - def wrapper_matmul(*args, **kwargs): - global output - output = matmul.apply(*args, **kwargs) - return output - - expected = A @ B - pytorch_ms = triton.testing.do_bench(lambda: A @ B) - measures = list() - for two_tiles in [True, False]: - nb_sm = [total_sm, total_sm * 2] - total_tile = (m // 128) * (n // 128) - if total_tile < total_sm * 2: - nb_sm.append(total_tile) - nb_sm += random.sample(range(2, total_sm * 2, 2), 10) - for sm in nb_sm: - triton_ms = triton.testing.do_bench(lambda: wrapper_matmul(A, B, sm, 128, 128, 32, two_tiles, 4, 4)) - max_disc = (output - expected).abs().max().item() - # large tolerance to accomodate for large K (rounding due to half precision), we just want to catch bugs. 
- assert max_disc <= 5., f"pb size: {m}x{n}x{k} - max discrepancy: {max_disc} - sm: {sm}, 2 tiles: {two_tiles}\n{output}\n{expected}" - info = { - "2 tiles": two_tiles, - "sm": sm, - "disc": max_disc, - "triton_ms": triton_ms, - } - measures.append(info) - best_triton_ms = min([m["triton_ms"] for m in measures]) - d = { - "m": m, - "n": n, - "k": k, - "triton": measures, - "pytorch_ms": pytorch_ms, - "speedup": pytorch_ms / best_triton_ms, - } - results.append(d) - measures = list() - -results.sort(key=lambda x: x["speedup"], reverse=False) - -# --------------------------------------------------------------------------- -# Benchmark export -# --------------------------------------------------------------------------- - -with open("results.json", "w") as f: - json.dump(results, f, indent=4) - -# 32760/32768 - average speedup: 0.962 (A100) -# 990/1000 - average speedup: 1.063 (3090 RTX with while loop and 2 tiles disabled / enabled) diff --git a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singlekern-autotune.py b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singlekern-autotune.py deleted file mode 100644 index a35d691a0225..000000000000 --- a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singlekern-autotune.py +++ /dev/null @@ -1,563 +0,0 @@ -## matmul stream-k implementation -## Credit goes to @pommedeterresautee -## See https://github.com/openai/triton/issues/1393 - -# (echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"') | sudo tee -a /etc/modprobe.d/RestrictedProfiling.conf >/dev/null -# sudo update-initramfs -u -k all -# cat /proc/driver/nvidia/params | grep RmProfilingAdminOnly -# sudo apt-get install zlib1g-dev -# for reproductible experiments -# sudo nvidia-smi -pm 1 -i 0 -# sudo nvidia-smi -i 0 -pl 350 # 400 for A100 -# sudo nvidia-smi -i 0 -lgc 1005 -from typing import Optional - -import torch -import triton -import triton.language as tl -import random - -#from triton.runtime.driver import CudaUtils -import json - -torch.manual_seed(123) -random.seed(123) - -#device = torch.cuda.current_device() -#cuda_utils = CudaUtils() -#total_sm = cuda_utils.get_device_properties(device)["multiprocessor_count"] -#total_sm = 110 # for MI250 -total_sm = 304 # for MI300X -print(f"total SMs: {total_sm}") -# global flag to indicate whether using the full tuing space -tuning_full_space = True -# --------------------------------------------------------------------------- -# Triton kernels -# --------------------------------------------------------------------------- - - -@triton.jit() -def swizzle_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr): - grid_m = tl.cdiv(M, BLOCK_M) - grid_n = tl.cdiv(N, BLOCK_N) - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = tile_id // width - group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (tile_id % group_size) - pid_n = (tile_id % width) // group_size - return pid_m, pid_n - - -@triton.jit() -def linear_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr): - pid_m = tile_id // tl.cdiv(N, BLOCK_N) - pid_n = tile_id % tl.cdiv(N, BLOCK_N) - return pid_m, pid_n - - -@triton.jit() -def get_tile_config(M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, two_tiles, - total_programs_streamk): - total_blocks_M = tl.cdiv(M, BLOCK_M) - total_blocks_N = tl.cdiv(N, BLOCK_N) - 
iters_per_tile = tl.cdiv(K, BLOCK_K) - # GROUP_M = 0 # 0 to disable swizzling - total_tiles = total_blocks_M * total_blocks_N - if total_programs_streamk > 0: # Stream-K - # last wave may occupy less than total_programs_streamk SMs - total_tiles_streamk = total_tiles % total_programs_streamk - # for two-tile Stream-K + data-parallel from original paper - if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: - total_tiles_streamk += total_programs_streamk - # remaining tiles are computed using classical blocking - total_iters_streamk = total_tiles_streamk * iters_per_tile - # iterations related to full waves - total_full_tiles_streamk = total_iters_streamk // total_programs_streamk - # iterations related to last (partial) wave - total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk - - else: # all tiles are computed using classical blocking - total_tiles_streamk = 0 - total_full_tiles_streamk = 0 - total_partial_tiles_streamk = 0 - total_iters_streamk = 0 - - return iters_per_tile, total_tiles_streamk, total_full_tiles_streamk, total_partial_tiles_streamk, total_iters_streamk - - -# pruned some unreasonable config -def prune_configs(configs, named_args): - # call only for full tuning space - if not tuning_full_space: - return configs - - SIZE_M = named_args["A"].shape[0] - SIZE_N = named_args["B"].shape[1] - # SIZE_K = named_args["A"].shape[1] - - pruned_configs = [] - for config in configs: - kw = config.kwargs - BLOCK_M, BLOCK_N, _ =\ - kw["BLOCK_M"], kw["BLOCK_N"], kw["BLOCK_K"] - if SIZE_M <= 32 and BLOCK_M != 32: - continue - if SIZE_N <= 32 and BLOCK_N != 32: - continue - - pruned_configs.append(config) - - return pruned_configs - - -def get_full_tuning_space(): - configs = [] - if not tuning_full_space: - return configs - - block_mn_range = [64, 128, 256] - block_k_range = [16, 32, 64] - num_warps_range = [1, 2, 4, 8] - # group_m_range = [0, 1, 2, 4, 8] - group_m_range = [0, 4, 8] - # For now we see better perf with num_stages=0 for all gemm configs we care - # But keep this explicit so that we do not forget we may need to set it to - # other values in the future - num_stage_range = [0] - waves_per_eu_range = [0] - matrix_instr_nonkdim_range = [16, 32] - kpack_range = [1, 2] - - for block_m in block_mn_range: - for block_n in block_mn_range: - for block_k in block_k_range: - for num_warps in num_warps_range: - for group_m in group_m_range: - for num_stages in num_stage_range: - for num_waves_per_eu in waves_per_eu_range: - for matrix_instr_nonkdim in matrix_instr_nonkdim_range: - for kpack in kpack_range: - configs.append( - triton.Config( - { - 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, - 'GROUP_M': group_m, 'waves_per_eu': num_waves_per_eu, - 'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack - }, - num_stages=num_stages, - num_warps=num_warps, - )) - - return configs - - -#To do: we need update the default autotune configuration once we go through the whole performance test sets. 
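# ---------------------------------------------------------------------------
# Illustrative sketch (editorial): a compact, functionally equivalent way to
# enumerate the same tuning space as get_full_tuning_space above, using
# itertools.product instead of the nested for-loops. Parameter ranges are the
# ones in the deleted code; the function name is illustrative.
import itertools
import triton

def full_tuning_space_sketch():
    block_mn = [64, 128, 256]
    block_k = [16, 32, 64]
    warps = [1, 2, 4, 8]
    group_m = [0, 4, 8]
    stages = [0]
    waves_per_eu = [0]
    mfma_nonkdim = [16, 32]
    kpack = [1, 2]
    configs = []
    for bm, bn, bk, w, gm, ns, wpe, instr, kp in itertools.product(
            block_mn, block_mn, block_k, warps, group_m, stages,
            waves_per_eu, mfma_nonkdim, kpack):
        configs.append(triton.Config(
            {'BLOCK_M': bm, 'BLOCK_N': bn, 'BLOCK_K': bk, 'GROUP_M': gm,
             'waves_per_eu': wpe, 'matrix_instr_nonkdim': instr, 'kpack': kp},
            num_stages=ns, num_warps=w))
    return configs  # 1296 candidate configs with these ranges
# ---------------------------------------------------------------------------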
-@triton.autotune( - configs=get_full_tuning_space() if tuning_full_space else [ - triton.Config( - { - 'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 8, 'waves_per_eu': 0, 'matrix_instr_nonkdim': - 16, 'kpack': 1 - }, num_warps=4, num_stages=0), - triton.Config( - { - 'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 16, 'GROUP_M': 8, 'waves_per_eu': 2, 'matrix_instr_nonkdim': - 16, 'kpack': 1 - }, num_warps=4, num_stages=0), - triton.Config( - { - 'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 4, 'waves_per_eu': 0, 'matrix_instr_nonkdim': - 16, 'kpack': 1 - }, num_warps=4, num_stages=0), - triton.Config( - { - 'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 16, 'GROUP_M': 4, 'waves_per_eu': 2, 'matrix_instr_nonkdim': - 16, 'kpack': 1 - }, num_warps=4, num_stages=0), - triton.Config( - { - 'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_M': 16, 'waves_per_eu': 0, 'matrix_instr_nonkdim': - 16, 'kpack': 1 - }, num_warps=4, num_stages=0), - triton.Config( - { - 'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 16, 'GROUP_M': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': - 16, 'kpack': 1 - }, num_warps=4, num_stages=4), - ], - key=['M', 'N', 'K'], - # prune_configs_by={ - # 'early_config_prune': prune_configs, - # 'perf_model': None, - # "top_k": None - # }, - reset_to_zero=['C'], -) -@triton.jit() -def streamk_gemm( - A, - B, - C, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - # total_full_tiles_streamk, total_partial_tiles_streamk, iters_per_tile, - # total_tiles_streamk, - total_programs_streamk, - two_tiles, - ACC_TYPE: tl.constexpr, - GROUP_M: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, -): - pid = tl.program_id(0) - iters_per_tile, total_tiles_streamk, total_full_tiles_streamk, total_partial_tiles_streamk, total_iters_streamk = get_tile_config( - M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, total_programs_streamk) - - # Determine whether we are in the first wave or full_tiles phase based on pid - is_first_wave = pid < total_programs_streamk and total_programs_streamk > 0 - - # Calculate starting and ending iterations for first wave - if not is_first_wave: - tile_id = tl.program_id(0) + total_tiles_streamk - total_programs_streamk - if GROUP_M > 0: - pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - else: - pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - - # do matrix multiplication - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - precomputed_stride_ak = BLOCK_K * stride_ak - precomputed_stride_bk = BLOCK_K * stride_bk - # pointers - A_BASE = A + ram[:, None] * stride_am + rk[None, :] * stride_ak - B_BASE = B + rk[:, None] * stride_bk + rbn[None, :] * stride_bn - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(A_BASE) - b = tl.load(B_BASE) - acc += tl.dot(a, b) - A_BASE += precomputed_stride_ak - B_BASE += precomputed_stride_bk - # acc = acc.to(tl.float16) # restore C.dtype.element_ty - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn - tl.store(C_, acc) - else: - start_iter = pid * total_full_tiles_streamk + 
tl.minimum(pid, total_partial_tiles_streamk) - last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum(pid + 1, total_partial_tiles_streamk) - while start_iter < last_iter: - remainder = start_iter % iters_per_tile - end_iter = tl.minimum(start_iter + (iters_per_tile - remainder), last_iter) - # where are we in the grid - tile_id = start_iter // iters_per_tile - if GROUP_M > 0: - pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - else: - pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - A_BASE = A + ram[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_K * stride_ak * remainder - B_BASE = B + rk[:, None] * stride_bk + rbn[None, :] * stride_bn + BLOCK_K * stride_bk * remainder - precomputed_stride_ak = BLOCK_K * stride_ak - precomputed_stride_bk = BLOCK_K * stride_bk - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for current_iter in range(start_iter, end_iter): - a = tl.load(A_BASE) - b = tl.load(B_BASE) - acc += tl.dot(a, b) - A_BASE += precomputed_stride_ak - B_BASE += precomputed_stride_bk - - # acc = acc.to(tl.float16) # restore C.dtype.element_ty - if remainder == 0 and end_iter % iters_per_tile == 0: - C_ = C + rm[:, - None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! - tl.store(C_, acc) - else: - C_ = C + rm[:, - None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! - tl.atomic_add(C_, acc) - - start_iter = end_iter - - -# --------------------------------------------------------------------------- -# Wrapper -# --------------------------------------------------------------------------- - - -class matmul(torch.autograd.Function): - - _debug = True - - @staticmethod - def set_debug(debug: bool): - matmul._debug = debug - - @staticmethod - def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLOCK_M: int, BLOCK_N: int, BLOCK_K: int, - two_tiles: bool, num_stages: int, num_warps: int, waves_per_eu: int, mfmaInstrSize: int, kpack: int): - - def compute_total_blocking_tiles(M, N, BLOCK_M, BLOCK_N, two_tiles, total_programs_streamk): - total_blocks_M = triton.cdiv(M, BLOCK_M) - total_blocks_N = triton.cdiv(N, BLOCK_N) - total_tiles = total_blocks_M * total_blocks_N - - if total_programs_streamk > 0: # Stream-K - # last wave may occupy less than total_programs_streamk SMs - total_tiles_streamk = total_tiles % total_programs_streamk - # for two-tile Stream-K + data-parallel from original paper - if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: - total_tiles_streamk += total_programs_streamk - # remaining tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk - else: # all tiles are computed using classical blocking - total_blocking_tiles = total_tiles - - return total_blocking_tiles - - device = a.device - - assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported" - # checks constraints - assert a.shape[1] == b.shape[0], "incompatible dimensions" - M, K = a.shape - _, N = b.shape - # accumulator types - ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - # compute grid (work to do per SM on the first wave) - # 
GROUP_M = 8 # 0 to disable swizzling - - if matmul._debug: - total_blocks_M = triton.cdiv(M, BLOCK_M) - total_blocks_N = triton.cdiv(N, BLOCK_N) - iters_per_tile = triton.cdiv(K, BLOCK_K) - total_tiles = total_blocks_M * total_blocks_N - if total_programs_streamk > 0: # Stream-K - # last wave may occupy less than total_programs_streamk SMs - total_tiles_streamk = total_tiles % total_programs_streamk - # for two-tile Stream-K + data-parallel from original paper - if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: - total_tiles_streamk += total_programs_streamk - # remaining tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk - total_iters_streamk = total_tiles_streamk * iters_per_tile - # iterations related to full waves - # total_full_tiles_streamk = total_iters_streamk // total_programs_streamk - # iterations related to last (partial) wave - total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk - - else: # all tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk = 0 - # total_full_tiles_streamk = 0 - total_partial_tiles_streamk = 0 - total_iters_streamk = 0 - print(f"M,N,K={M},{N},{K} ; BLOCK_M,N,K={BLOCK_M},{BLOCK_N},{BLOCK_K}") - print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}") - print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}") - print(f"{total_programs_streamk=}") - print(f"{total_blocking_tiles=}") - print(f"{total_partial_tiles_streamk=}") - print(f"{iters_per_tile=}") - print(f"{total_iters_streamk=}") - - # allocates output - c = torch.zeros((M, N), device=device, dtype=a.dtype) - grids = lambda META: (total_programs_streamk + compute_total_blocking_tiles(M, N, META['BLOCK_M'], META[ - 'BLOCK_N'], two_tiles, total_programs_streamk), ) - kk = streamk_gemm[(grids)]( - a, - b, - c, - M, - N, - K, - a.stride(0), - a.stride(1), - b.stride(0), - b.stride(1), - c.stride(0), - c.stride(1), - # total_full_tiles_streamk=total_full_tiles_streamk, - # total_partial_tiles_streamk=total_partial_tiles_streamk, - # iters_per_tile=iters_per_tile, - # total_tiles_streamk=total_tiles_streamk, - total_programs_streamk=total_programs_streamk, - two_tiles=two_tiles, - ACC_TYPE=ACC_TYPE, - # GROUP_M=GROUP_M, - # BLOCK_M=BLOCK_M, - # BLOCK_N=BLOCK_N, - # BLOCK_K=BLOCK_K, - # num_stages=num_stages, - # num_warps=num_warps, - # waves_per_eu = waves_per_eu, - ) - if matmul._debug: - print(f"{kk.n_regs} registers used, {kk.n_spills} spills") - - # print(kk.asm['ttgir']) - # print(kk.asm['amdgcn']) - return c - - @staticmethod - def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLOCK_M=128, BLOCK_N=128, BLOCK_K=32, two_tiles=True, - num_stages=3, num_warps=4, waves_per_eu=2, mfmaInstrSize=16, kpack=1): - return matmul._call(a=a, b=b, total_programs_streamk=grid, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, - two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages, waves_per_eu=waves_per_eu, - mfmaInstrSize=mfmaInstrSize, kpack=kpack) - - -# --------------------------------------------------------------------------- -# Example and Benchmark -# --------------------------------------------------------------------------- - -perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) - -#m, n, k = 1792, 7424, 4864 # some problem size to test -#m, n, k = 8192, 8192, 8192 # some problem size to test -m, n, k = 4096, 4096, 8192 # some problem size to test -A = torch.randn(m, k, device="cuda", 
dtype=torch.float16) -B = torch.randn(k, n, device="cuda", dtype=torch.float16) -#A = torch.ones((m, k), device="cuda", dtype=torch.float16) -#B = torch.ones((k, n), device="cuda", dtype=torch.float16) -BLOCK_M = 256 -BLOCK_N = 256 -BLOCK_K = 64 -two_tiles = True -num_stages = 0 -num_warps = 8 -waves_per_eu = 0 -mfmaInstrSize = 16 -kpack = 1 - -matmul.set_debug(True) -C = matmul.apply(A, B, total_sm, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, num_stages, num_warps, waves_per_eu, - mfmaInstrSize, kpack) -matmul.set_debug(False) -expected = A @ B - -assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}" -print("pass validation test") - -# for debugging, uncomment the following line -#exit(0) - -triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B)) -print(f"PyTorch: {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, - num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack)) -print(f"hybrid stream-k (grid={total_sm}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") -print(f'SIZE: {m},{n},{k} Best tuning config: ({streamk_gemm.get_best_config()})') - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, - num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack)) -print(f"hybrid stream-k (grid={total_sm * 2}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") -print(f'SIZE: {m},{n},{k} Best tuning config: ({streamk_gemm.get_best_config()})') - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, BLOCK_M, BLOCK_N, BLOCK_K, two_tiles, num_stages, - num_warps, waves_per_eu, mfmaInstrSize, kpack)) -print(f"tile matmul (grid=0): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") -print(f'SIZE: {m},{n},{k} Best tuning config: ({streamk_gemm.get_best_config()})') - -exit(0) -# --------------------------------------------------------------------------- -# Log-sampled benchmark -# --------------------------------------------------------------------------- - -# tried to reproduce the tests described in the paper -perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) -num_samples = 1000 # 32768 -step = 256 -values = ((torch.logspace(torch.tensor(step).log2(), - torch.tensor(8192).log2(), num_samples, base=2) / step).round() * step).unique().tolist() -shapes = [(int(m), int(n), int(k)) for m in values for n in values for k in values] -shapes = random.sample(shapes, num_samples) -assert len(shapes) == num_samples - -results = [] -for idx, (m, n, k) in enumerate(shapes): - # print progress bar - if idx % 10 == 0 and idx > 0: - speedups = [r["speedup"] for r in results] - print(f"{idx}/{num_samples} - average speedup: {sum(speedups) / len(speedups):.3f}") - - A = torch.randn(m, k, device="cuda", dtype=torch.float16) - B = torch.randn(k, n, device="cuda", dtype=torch.float16) - output: Optional[torch.Tensor] = None - - def wrapper_matmul(*args, **kwargs): - global output - output = matmul.apply(*args, **kwargs) - return output - - expected = A @ B - pytorch_ms = triton.testing.do_bench(lambda: A @ B) - measures = list() - for two_tiles in [True, False]: - nb_sm = [total_sm, total_sm * 2] - total_tile = (m // BLOCK_M) * (n // BLOCK_N) - if total_tile < total_sm * 2: - nb_sm.append(total_tile) - nb_sm += random.sample(range(2, total_sm * 2, 2), 10) - for sm in nb_sm: - triton_ms = triton.testing.do_bench(lambda: wrapper_matmul(A, B, sm, BLOCK_M, BLOCK_N, BLOCK_K, 
two_tiles, - num_stages, num_warps, waves_per_eu)) - max_disc = (output - expected).abs().max().item() - # large tolerance to accomodate for large K (rounding due to half precision), we just want to catch bugs. - assert max_disc <= 5., f"pb size: {m}x{n}x{k} - max discrepancy: {max_disc} - sm: {sm}, 2 tiles: {two_tiles}\n{output}\n{expected}" - Best_tuning_config = f'SIZE: {m},{n},{k} Best tuning config: ({streamk_gemm.get_best_config()})' - info = { - "2 tiles": two_tiles, - "sm": sm, - "disc": max_disc, - "triton_ms": triton_ms, - "Best tuning config": Best_tuning_config, - } - measures.append(info) - best_triton_ms = min([m["triton_ms"] for m in measures]) - d = { - "m": m, - "n": n, - "k": k, - "triton": measures, - "pytorch_ms": pytorch_ms, - "speedup": pytorch_ms / best_triton_ms, - } - results.append(d) - measures = list() - -results.sort(key=lambda x: x["speedup"], reverse=False) - -# --------------------------------------------------------------------------- -# Benchmark export -# --------------------------------------------------------------------------- - -with open("results.json", "w") as f: - json.dump(results, f, indent=4) - -# 32760/32768 - average speedup: 0.962 (A100) -# 990/1000 - average speedup: 1.063 (3090 RTX with while loop and 2 tiles disabled / enabled) diff --git a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singleloop-nomod.py b/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singleloop-nomod.py deleted file mode 100644 index 2651ad59d923..000000000000 --- a/python/perf-kernels/streamk/03-matrix-multiplication-stream-k-singleloop-nomod.py +++ /dev/null @@ -1,387 +0,0 @@ -## matmul stream-k implementation -## Credit goes to @pommedeterresautee -## See https://github.com/openai/triton/issues/1393 - -# (echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"') | sudo tee -a /etc/modprobe.d/RestrictedProfiling.conf >/dev/null -# sudo update-initramfs -u -k all -# cat /proc/driver/nvidia/params | grep RmProfilingAdminOnly -# sudo apt-get install zlib1g-dev -# for reproductible experiments -# sudo nvidia-smi -pm 1 -i 0 -# sudo nvidia-smi -i 0 -pl 350 # 400 for A100 -# sudo nvidia-smi -i 0 -lgc 1005 -from typing import Optional - -import torch -import triton -import triton.language as tl -import random - -#from triton.runtime.driver import CudaUtils -import json - -torch.manual_seed(123) -random.seed(123) - -#device = torch.cuda.current_device() -#cuda_utils = CudaUtils() -#total_sm = cuda_utils.get_device_properties(device)["multiprocessor_count"] -#total_sm = 110 # for MI250 -total_sm = 304 # for MI300X -print(f"total SMs: {total_sm}") - -# --------------------------------------------------------------------------- -# Triton kernels -# --------------------------------------------------------------------------- - - -@triton.jit() -def swizzle_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr): - grid_m = tl.cdiv(M, BLOCK_M) - grid_n = tl.cdiv(N, BLOCK_N) - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = tile_id // width - group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (tile_id % group_size) - pid_n = (tile_id % width) // group_size - return pid_m, pid_n - - -@triton.jit() -def linear_tile(tile_id, M, N, K, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr): - pid_m = tile_id // tl.cdiv(N, BLOCK_N) - pid_n = tile_id % tl.cdiv(N, BLOCK_N) - 
return pid_m, pid_n - - -@triton.jit() -def first_wave( - A, - B, - C, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - total_full_tiles_streamk, - total_partial_tiles_streamk, - iters_per_tile, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - ACC_TYPE: tl.constexpr, - GROUP_M: tl.constexpr, -): - pid = tl.program_id(0) - start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk) - last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum(pid + 1, total_partial_tiles_streamk) - - while start_iter < last_iter: - remainder = start_iter % iters_per_tile - end_iter = tl.minimum(start_iter + (iters_per_tile - remainder), last_iter) - # where are we in the grid - tile_id = start_iter // iters_per_tile - if GROUP_M > 0: - pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - else: - pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rk = tl.arange(0, BLOCK_K) - A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_K * stride_ak * remainder - B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_K * stride_bk * remainder - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - - for current_iter in range(start_iter, end_iter): - a = tl.load(A_BASE) - b = tl.load(B_BASE) - acc += tl.dot(a, b) - A_BASE += BLOCK_K * stride_ak - B_BASE += BLOCK_K * stride_bk - - if remainder == 0 and end_iter % iters_per_tile == 0: - C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! - tl.store(C_, acc) - else: - C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn # compute inside the if/else to avoid spilling! 
- tl.atomic_add(C_, acc) - - start_iter = end_iter - - -# similar to the reference matmul kernel -@triton.jit() -def full_tiles( - A, - B, - C, - M, - N, - K, - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - total_tiles_streamk, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - ACC_TYPE: tl.constexpr, - GROUP_M: tl.constexpr, -): - # first wave has done more tiles than there are SMs, we adjust pid - tile_id = tl.program_id(0) + total_tiles_streamk - if GROUP_M > 0: - pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - else: - pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) - - # do matrix multiplication - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - rk = tl.arange(0, BLOCK_K) - # pointers - A = A + rm[:, None] * stride_am + rk[None, :] * stride_ak - B = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(A) - b = tl.load(B) - acc += tl.dot(a, b) - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - acc = acc.to(tl.float16) # restore C.dtype.element_ty - # rematerialize rm and rn to save registers - # rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - # rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn - tl.store(C, acc) - - -# --------------------------------------------------------------------------- -# Wrapper -# --------------------------------------------------------------------------- - - -class matmul(torch.autograd.Function): - - _debug = True - - @staticmethod - def set_debug(debug: bool): - matmul._debug = debug - - @staticmethod - def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M: int, BLK_N: int, BLK_K: int, - two_tiles: bool, num_stages: int, num_warps: int, waves_per_eu: int, mfmaInstrSize: int, kpack: int): - device = a.device - - assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported" - # checks constraints - assert a.shape[1] == b.shape[0], "incompatible dimensions" - M, K = a.shape - _, N = b.shape - # accumulator types - ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 - # compute grid (work to do per SM on the first wave) - total_blocks_M = triton.cdiv(M, BLK_M) - total_blocks_N = triton.cdiv(N, BLK_N) - iters_per_tile = triton.cdiv(K, BLK_K) - GROUP_M = 4 # 0 to disable swizzling - total_tiles = total_blocks_M * total_blocks_N - - if total_programs_streamk > 0: # Stream-K - # last wave may occupy less than total_programs_streamk SMs - total_tiles_streamk = total_tiles % total_programs_streamk - # for two-tile Stream-K + data-parallel from original paper - if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: - total_tiles_streamk += total_programs_streamk - # remaining tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk - total_iters_streamk = total_tiles_streamk * iters_per_tile - # iterations related to full waves - total_full_tiles_streamk = total_iters_streamk // total_programs_streamk - # iterations related to last (partial) wave - total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk - - else: # all tiles are computed using classical blocking - total_blocking_tiles = total_tiles - total_tiles_streamk = 0 - total_full_tiles_streamk = 
0 - total_partial_tiles_streamk = 0 - total_iters_streamk = 0 - - if matmul._debug: - print(f"M,N,K={M},{N},{K} ; BLK_M,N,K={BLK_M},{BLK_N},{BLK_K}") - print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}") - print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}") - print(f"{total_programs_streamk=}") - print(f"{total_blocking_tiles=}") - print(f"{iters_per_tile=}") - print(f"{total_iters_streamk=}") - - # allocates output - c = torch.zeros((M, N), device=device, dtype=a.dtype) - - k1 = first_wave[(total_programs_streamk, )]( - a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), - total_full_tiles_streamk=total_full_tiles_streamk, total_partial_tiles_streamk=total_partial_tiles_streamk, - iters_per_tile=iters_per_tile, BLOCK_M=BLK_M, BLOCK_N=BLK_N, BLOCK_K=BLK_K, ACC_TYPE=ACC_TYPE, - GROUP_M=GROUP_M, num_stages=num_stages, num_warps=num_warps, waves_per_eu=waves_per_eu, - matrix_instr_nonkdim=mfmaInstrSize, kpack=kpack) - if matmul._debug: - print(f"{k1.n_regs} registers used, {k1.n_spills} spills") - k2 = full_tiles[(total_blocking_tiles, )](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), - c.stride(0), c.stride(1), total_tiles_streamk=total_tiles_streamk, - BLOCK_M=BLK_M, BLOCK_N=BLK_N, BLOCK_K=BLK_K, ACC_TYPE=ACC_TYPE, - GROUP_M=GROUP_M, num_stages=num_stages, num_warps=num_warps, - waves_per_eu=waves_per_eu, matrix_instr_nonkdim=mfmaInstrSize, - kpack=kpack) - if matmul._debug: - print(f"{k2.n_regs} registers used, {k2.n_spills} spills") -# print(k2.asm['amdgcn']) - return c - - @staticmethod - def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=128, BLK_K=32, two_tiles=True, - num_stages=3, num_warps=4, waves_per_eu=2, mfmaInstrSize=16, kpack=1): - return matmul._call(a=a, b=b, total_programs_streamk=grid, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, - two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages, waves_per_eu=waves_per_eu, - mfmaInstrSize=mfmaInstrSize, kpack=kpack) - - -# --------------------------------------------------------------------------- -# Example and Benchmark -# --------------------------------------------------------------------------- - -perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3) - -#m, n, k = 4864, 4096, 8256 # some problem size to test -m, n, k = 6912, 768, 256 # some problem size to test -#m, n, k = 8192, 8192, 8192 # some problem size to test -A = torch.randn(m, k, device="cuda", dtype=torch.float16) -B = torch.randn(k, n, device="cuda", dtype=torch.float16) -#A = torch.ones((m, k), device="cuda", dtype=torch.float16) -#B = torch.ones((k, n), device="cuda", dtype=torch.float16) -BLK_M = 64 -BLK_N = 64 -BLK_K = 64 -two_tiles = 'True' -num_stages = 0 -num_warps = 4 -waves_per_eu = 0 -mfmaInstrSize = 16 -kpack = 2 - -matmul.set_debug(True) -C = matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, - kpack) -#exit(0) -matmul.set_debug(False) -expected = A @ B - -assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}" - -# for debugging, uncomment the following line - -triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B)) -print(f"PyTorch: {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, - num_warps, waves_per_eu, mfmaInstrSize, kpack)) -print(f"hybrid stream-k (grid={total_sm}): 
{triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, - num_warps, waves_per_eu, mfmaInstrSize, kpack)) -print(f"hybrid stream-k (grid={total_sm * 2}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, - waves_per_eu, mfmaInstrSize, kpack)) -print(f"tile matmul (grid=0): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops") - -exit(0) -# --------------------------------------------------------------------------- -# Log-sampled benchmark -# --------------------------------------------------------------------------- - -# tried to reproduce the tests described in the paper -num_samples = 1000 # 32768 -step = 256 -values = ((torch.logspace(torch.tensor(step).log2(), - torch.tensor(8192).log2(), num_samples, base=2) / step).round() * step).unique().tolist() -shapes = [(int(m), int(n), int(k)) for m in values for n in values for k in values] -shapes = random.sample(shapes, num_samples) -assert len(shapes) == num_samples - -results = [] -for idx, (m, n, k) in enumerate(shapes): - # print progress bar - if idx % 10 == 0 and idx > 0: - speedups = [r["speedup"] for r in results] - print(f"{idx}/{num_samples} - average speedup: {sum(speedups) / len(speedups):.3f}") - - A = torch.randn(m, k, device="cuda", dtype=torch.float16) - B = torch.randn(k, n, device="cuda", dtype=torch.float16) - output: Optional[torch.Tensor] = None - - def wrapper_matmul(*args, **kwargs): - global output - output = matmul.apply(*args, **kwargs) - return output - - expected = A @ B - pytorch_ms = triton.testing.do_bench(lambda: A @ B) - measures = list() - for two_tiles in [True, False]: - nb_sm = [total_sm, total_sm * 2] - total_tile = (m // BLK_M) * (n // BLK_N) - if total_tile < total_sm * 2: - nb_sm.append(total_tile) - nb_sm += random.sample(range(2, total_sm * 2, 2), 10) - for sm in nb_sm: - triton_ms = triton.testing.do_bench(lambda: wrapper_matmul( - A, B, sm, BLK_M, BLK_N, BLK_K, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack)) - max_disc = (output - expected).abs().max().item() - # large tolerance to accomodate for large K (rounding due to half precision), we just want to catch bugs. 
- assert max_disc <= 5., f"pb size: {m}x{n}x{k} - max discrepancy: {max_disc} - sm: {sm}, 2 tiles: {two_tiles}\n{output}\n{expected}" - info = { - "2 tiles": two_tiles, - "sm": sm, - "disc": max_disc, - "triton_ms": triton_ms, - } - measures.append(info) - best_triton_ms = min([m["triton_ms"] for m in measures]) - d = { - "m": m, - "n": n, - "k": k, - "triton": measures, - "pytorch_ms": pytorch_ms, - "speedup": pytorch_ms / best_triton_ms, - } - results.append(d) - measures = list() - -results.sort(key=lambda x: x["speedup"], reverse=False) - -# --------------------------------------------------------------------------- -# Benchmark export -# --------------------------------------------------------------------------- - -with open("results.json", "w") as f: - json.dump(results, f, indent=4) - -# 32760/32768 - average speedup: 0.962 (A100) -# 990/1000 - average speedup: 1.063 (3090 RTX with while loop and 2 tiles disabled / enabled) diff --git a/python/perf-kernels/streamk/README.md b/python/perf-kernels/streamk/README.md new file mode 100644 index 000000000000..aa0b11d41b73 --- /dev/null +++ b/python/perf-kernels/streamk/README.md @@ -0,0 +1,43 @@ +# streamk gemm script v0.1 + +The plan is to use this version as the base version for the future triton streamk gemm development. + +### Main features +- comparable performance with tune gemm + +- use the persistent loop so that a WG may work on multiple output tiles, and also allowing workgroups to do part of the work for an output tile. + +- use atomics for spinning lock to replace atomic_add for the final output. + +- pid renumbering based on chiplet structure of MI300X + +- dynamic grid setting + +- tuning script adapt from tune_gemm + +### Usage + +Go to the script dir +```bash +cd triton/python/perf_kernels/streamk +``` + +1. Tune gemm sizes given in a yaml file and check correctness on the way +```bash +python tune_streamk.py --gemm_size_file input_gemm_sizes.yaml --compare +``` + +2. Tune a single gemm size +```bash +python tune_streamk.py -m 16 -n 16 -k 16 +``` + +3. Choose the file to store tuning results +```bash +python tune_streamk.py --gemm_size_file input_gemm_sizes.yaml --o output_tuning.yaml +``` + +4. 
Only check correctness given the tuning results +```bash +python tune_streamk.py --gemm_size_file output_tuning.yaml --compare_wo_tuning +``` diff --git a/python/perf-kernels/streamk/streamk_kernel.py b/python/perf-kernels/streamk/streamk_kernel.py new file mode 100644 index 000000000000..138e6540e203 --- /dev/null +++ b/python/perf-kernels/streamk/streamk_kernel.py @@ -0,0 +1,206 @@ +import triton +import triton.language as tl + + +@triton.jit() +def get_new_pid(current_pid, num_cus): + # Number of XCDs + num_xcds = 8 + # Number of pids per XCD in the new arrangement + pids_per_xcd = num_cus // num_xcds + # Compute current XCD and local pid within the XCD + xcd = current_pid % num_xcds + local_pid = current_pid // num_xcds + + # Calculate new pid based on the new grouping + new_pid = xcd * pids_per_xcd + local_pid + return new_pid + + +@triton.jit() +def get_tiles_config( + M, + N, + K, + num_cus, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + total_blocks_M = tl.cdiv(M, BLOCK_SIZE_M) + total_blocks_N = tl.cdiv(N, BLOCK_SIZE_N) + iters_per_tile = tl.cdiv(K, BLOCK_SIZE_K) + + total_tiles = total_blocks_M * total_blocks_N + if num_cus > 0 and total_tiles > num_cus: # Stream-K + total_streamk_tiles = total_tiles % num_cus + total_full_tiles = total_tiles - total_streamk_tiles + total_streamk_iters = total_streamk_tiles * iters_per_tile + # iterations related to full waves + streamk_iters_pcu = total_streamk_iters // num_cus + # iterations related to last (partial) wave + streamk_remainder_iters = total_streamk_iters % num_cus + + else: # all tiles are computed using classical blocking + total_full_tiles = total_tiles + total_streamk_tiles = 0 + streamk_iters_pcu = 0 + streamk_remainder_iters = 0 + total_streamk_iters = 0 + + return iters_per_tile, total_full_tiles, total_streamk_tiles, streamk_iters_pcu, streamk_remainder_iters + + +@triton.jit() +def streamk_gemm( + A, + B, + C, + P, + locks, + M, + N, + K, + num_cus, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + EVEN_K: tl.constexpr, +): + pid = tl.program_id(0) + pid = get_new_pid(pid, num_cus) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + iters_per_tile, total_full_tiles, total_streamk_tiles, streamk_iters_pcu, streamk_remainder_iters = get_tiles_config( + M, N, K, num_cus, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K) + + acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 + rk = tl.arange(0, BLOCK_SIZE_K) + + for tile_id in range(pid, total_full_tiles, num_cus): + if GROUP_SIZE_M == 1: + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + else: + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), 
dtype=acc_dtype) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if EVEN_K: + a = tl.load(A_BASE) + b = tl.load(B_BASE) + else: + a = tl.load(A_BASE, mask=rk[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(B_BASE, mask=rk[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + acc += tl.dot(a, b) + A_BASE += BLOCK_SIZE_K * stride_ak + B_BASE += BLOCK_SIZE_K * stride_bk + + c = acc.to(C.type.element_ty) + + rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn + mask = (rm < M)[:, None] & (rn < N)[None, :] + tl.store(C_, c, mask=mask) + + start_iter = total_full_tiles * iters_per_tile + pid * streamk_iters_pcu + tl.minimum(pid, streamk_remainder_iters) + last_iter = total_full_tiles * iters_per_tile + (pid + 1) * streamk_iters_pcu + tl.minimum( + pid + 1, streamk_remainder_iters) + while start_iter < last_iter: + remainder = start_iter % iters_per_tile + end_iter = tl.minimum(start_iter + (iters_per_tile - remainder), last_iter) + # where are we in the grid + tile_id = start_iter // iters_per_tile + if GROUP_SIZE_M == 1: + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + else: + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + # rk = tl.arange(0, BLOCK_SIZE_K) + A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak + BLOCK_SIZE_K * stride_ak * remainder + B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + BLOCK_SIZE_K * stride_bk * remainder + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + for current_iter in range(start_iter, end_iter): + if EVEN_K: + a = tl.load(A_BASE) + b = tl.load(B_BASE) + else: + global_k_offset = (current_iter % iters_per_tile) * BLOCK_SIZE_K + k_mask = global_k_offset + rk < K + a = tl.load(A_BASE, mask=k_mask[None, :], other=0.0) + b = tl.load(B_BASE, mask=k_mask[:, None], other=0.0) + acc += tl.dot(a, b) + A_BASE += BLOCK_SIZE_K * stride_ak + B_BASE += BLOCK_SIZE_K * stride_bk + + tile_iter = tile_id * iters_per_tile + if start_iter == tile_iter: + tile_iter_end = tile_iter + iters_per_tile + next_pid = pid + 1 + end = end_iter + while (end < tile_iter_end and next_pid < num_cus): + # todo: try use tl.load once cache modifier landed upstream + while tl.atomic_cas(locks + next_pid, 1, 1) != 1: + pass + rm1 = tl.arange(0, BLOCK_SIZE_M) + rn1 = tl.arange(0, BLOCK_SIZE_N) + rm1 = tl.max_contiguous(tl.multiple_of(rm1, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn1 = tl.max_contiguous(tl.multiple_of(rn1, BLOCK_SIZE_N), BLOCK_SIZE_N) + P_ = P + next_pid * BLOCK_SIZE_M * BLOCK_SIZE_N + rm1[:, None] * BLOCK_SIZE_N + rn1[None, :] + acc += tl.load(P_) + end += streamk_iters_pcu + (next_pid < streamk_remainder_iters) + + next_pid += 1 + + c = acc.to(C.type.element_ty) + + rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, 
BLOCK_SIZE_N), BLOCK_SIZE_N) + C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn + mask = (rm < M)[:, None] & (rn < N)[None, :] + tl.store(C_, c, mask=mask) + + else: + rm1 = tl.arange(0, BLOCK_SIZE_M) + rn1 = tl.arange(0, BLOCK_SIZE_N) + rm1 = tl.max_contiguous(tl.multiple_of(rm1, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn1 = tl.max_contiguous(tl.multiple_of(rn1, BLOCK_SIZE_N), BLOCK_SIZE_N) + P_ = P + pid * BLOCK_SIZE_M * BLOCK_SIZE_N + rm1[:, None] * BLOCK_SIZE_N + rn1[None, :] + tl.store(P_, acc) + tl.atomic_xchg(locks + pid, 1) + + start_iter = end_iter diff --git a/python/perf-kernels/streamk/tune_streamk.py b/python/perf-kernels/streamk/tune_streamk.py new file mode 100644 index 000000000000..3b0fbdb960c7 --- /dev/null +++ b/python/perf-kernels/streamk/tune_streamk.py @@ -0,0 +1,847 @@ +# fp8 +import argparse +import sys +import yaml +import os +import glob +import subprocess + +import torch +import triton +import triton.language as tl + +from streamk_kernel import streamk_gemm + +from datetime import datetime +import multiprocessing +import pandas as pd + +device_oi = 650. / 3.0 + + +def get_full_tuning_space(): + configs = [] + + block_mn_range = [16, 32, 64, 128, 256] + block_k_range = [16, 32, 64, 128, 256] + num_warps_range = [1, 2, 4, 8] + group_m_range = [1, 4, 8, 16, 32] + # For now we see better perf with num_stages=0 for all gemm configs we care + # But keep this explicit so that we do not forget we may need to set it to + # other values in the future + num_stage_range = [0] + waves_per_eu_range = [0] + matrix_instr_nonkdim_range = [16, 32] + kpack_range = [1, 2] + + for block_m in block_mn_range: + for block_n in block_mn_range: + for block_k in block_k_range: + for num_warps in num_warps_range: + for group_m in group_m_range: + for num_stages in num_stage_range: + for waves_per_eu in waves_per_eu_range: + for matrix_instr_nonkdim in matrix_instr_nonkdim_range: + for kpack in kpack_range: + configs.append({ + 'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, + 'GROUP_SIZE_M': group_m, 'num_warps': num_warps, 'num_stages': num_stages, + 'waves_per_eu': waves_per_eu, 'matrix_instr_nonkdim': matrix_instr_nonkdim, + 'kpack': kpack + }) + + return configs + + +def get_gemm_oi(M, N, K): + FLOPs = 2 * M * N * K + # 4 for fp32 + # to do check dtype for bytesmoved + bytesmoved = (M * K + K * N + 2 * M * N) * 4 + return FLOPs / bytesmoved + + +def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b): + pruned_configs = [] + + if M < 32 or N < 32: + mfma = 16 + else: + mfma = 32 + + # TODO (zhanglx): figure out the boundary between large and small gemms + large_gemm = False + if M >= 2048 and N >= 2048: + large_gemm = True + + for config in configs: + BLOCK_SIZE_M = config.get("BLOCK_SIZE_M") + BLOCK_SIZE_N = config.get("BLOCK_SIZE_N") + BLOCK_SIZE_K = config.get("BLOCK_SIZE_K") + num_warps = config.get("num_warps") + matrix_instr_nonkdim = config.get("matrix_instr_nonkdim") + kpack = config.get("kpack") + if matrix_instr_nonkdim > mfma: + continue + if mfma == 4 and BLOCK_SIZE_K < 64: + continue + # some layouts could not work properly in case + # number elemens per thread is less 1 + if BLOCK_SIZE_M * BLOCK_SIZE_N < 64: + continue + GROUP_M = config.get("GROUP_SIZE_M") + if BLOCK_SIZE_M < matrix_instr_nonkdim or BLOCK_SIZE_N < matrix_instr_nonkdim: + continue + if BLOCK_SIZE_K == 16 and matrix_instr_nonkdim == 16 and kpack == 2: + continue + if M <= matrix_instr_nonkdim and BLOCK_SIZE_M != matrix_instr_nonkdim: + continue + if N <= 
matrix_instr_nonkdim and BLOCK_SIZE_N != matrix_instr_nonkdim: + continue + # Skip BLOCK_SIZE that is too large compare to M/N + # unless BLOCK_SIZE is already small enough + if BLOCK_SIZE_M > M * 2 and BLOCK_SIZE_M != 16: + continue + if BLOCK_SIZE_N > N * 2 and BLOCK_SIZE_N != 16: + continue + # skip large GROUP_M + if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1: + continue + # out of shared memory resource + # TODO (zhanglx): This does not consider the LDS usage in the epilogue + LDS = BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b + if LDS > 65536: + continue + # Skip small block sizes and num_warps for large gemm + # For fp16 and f8, we want to only use BLOCK_SIZE >= 64 + if large_gemm: + if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64: + continue + if BLOCK_SIZE_K < 64: + continue + if num_warps < 4: + continue + + pruned_configs.append(config) + + return pruned_configs + + +def run_bash_command_wrapper(commandstring, capture=True): + try: + run_bash_command(commandstring, capture) + except subprocess.CalledProcessError: + if not capture: + print(f"running {commandstring} one more time") + run_bash_command(commandstring, capture) + + +def run_bash_command(commandstring, capture=True): + if capture: + proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash', stdout=subprocess.PIPE) + return proc.stdout.splitlines() + proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash') + return None + + +def read_config(config): + block_m = config.get('BLOCK_SIZE_M') + block_n = config.get('BLOCK_SIZE_N') + block_k = config.get('BLOCK_SIZE_K') + group_m = config.get('GROUP_SIZE_M') + num_warps = config.get('num_warps') + num_stages = config.get('num_stages') + waves_per_eu = config.get('waves_per_eu') + mfma_instr_size = config.get('matrix_instr_nonkdim') + kpack = config.get('kpack') + return block_m, block_n, block_k, group_m, num_warps, num_stages, waves_per_eu, mfma_instr_size, kpack + + +def gen_kernel_and_configStr_from_config(M, N, K, num_cus, EVEN_K, config, dtype_a, dtype_b, dtype_c, dtype_p, + dtype_lock): + block_m, block_n, block_k, group_m, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(config) + torch_dtype_a = 'fp16' + torch_dtype_b = 'fp16' + torch_dtype_c = 'fp16' + torch_dtype_p = 'fp32' + torch_dtype_lock = 'int32' + if dtype_a: + torch_dtype_a = tl_to_torch_types[name_to_tl_types[dtype_a]] + if dtype_b: + torch_dtype_b = tl_to_torch_types[name_to_tl_types[dtype_b]] + if dtype_c: + torch_dtype_c = tl_to_torch_types[name_to_tl_types[dtype_c]] + if dtype_p: + torch_dtype_p = tl_to_torch_types[name_to_tl_types[dtype_p]] + if dtype_lock: + torch_dtype_lock = tl_to_torch_types[name_to_tl_types[dtype_lock]] + configStr = f"M{M}_N{N}_K{K}_BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_nW{num_warps}_nS{num_stages}_EU{waves_per_eu}_kP{kpack}_mfma{mfmaInstrSize}" + + matmul_def_str = f""" +def matmul_{configStr}(a, b, c, P, locks, M, N, K, num_cus, am, ak, bk, bn, cm, cn, warmup=False): + grid = num_cus + #print(f'config: streamk_gemm_{configStr}', flush=True) + if warmup: + streamk_gemm_{configStr}.warmup( + {torch_dtype_a}, {torch_dtype_b}, {torch_dtype_c}, {torch_dtype_p}, {torch_dtype_lock}, + M, N, K, num_cus, + am, ak, bk, bn, cm, cn, + BLOCK_SIZE_M = {block_m}, + BLOCK_SIZE_N = {block_n}, + BLOCK_SIZE_K = {block_k}, + GROUP_SIZE_M = {group_m}, + num_warps = {num_warps}, + num_stages = {num_stages}, + waves_per_eu = {waves_per_eu}, + matrix_instr_nonkdim = 
{mfmaInstrSize}, + kpack = {kpack}, + EVEN_K = {EVEN_K}, + grid=(1,) + ) + return None + else: + streamk_gemm_{configStr}[grid,]( + a, b, c, P, locks, + M, N, K, num_cus, + am, ak, bk, bn, cm, cn, + BLOCK_SIZE_M = {block_m}, + BLOCK_SIZE_N = {block_n}, + BLOCK_SIZE_K = {block_k}, + GROUP_SIZE_M = {group_m}, + num_warps = {num_warps}, + num_stages = {num_stages}, + waves_per_eu = {waves_per_eu}, + matrix_instr_nonkdim = {mfmaInstrSize}, + kpack = {kpack}, + EVEN_K = {EVEN_K} + ) + return c + +def try_config_{configStr}(M, N, K, num_cus, am, ak, bk, bn, cm, cn): + try: + matmul_{configStr}(None, None, None, None, None, M, N, K, num_cus, am, ak, bk, bn, cm, cn, True) + return True + except Exception as e: + print(f'invalid config(compilation): {configStr}: ', e, flush=True) + return False +""" + return configStr, matmul_def_str + + +def generated_kernel_name(M, N, K, gpu_id): + return f"generated_kernel{M}-{N}-{K}-{gpu_id}.py" + + +# Open {len(gpus)} files +# generated_kernelM-N-K-{gpus[0]}.py, generated_kernelM-N-K-{gpus[1]}.py, ..., generated_kernelM-N-K-{gpus[-1]}.py +# and generate +# 1. matmul kernels of all configs +# 2. wrapper function matmul to invoke all the generated kernels +# 3. Another wraper function try_config to invoke matmul function +# 4. test_gemm to invoke +# 4.1 run try_config in parallel +# 4.2 matmul in a loop of 10 iterations +def generate_kernel(M, N, K, num_cus, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock, init_type, configs, + jobs, iters, run_bench): + filenames = [] + for i in range(jobs): + filenames.append(generated_kernel_name(M, N, K, i)) + f_kernel = [open(path, 'w') for path in filenames] + + # write imports + import_str = """import torch +import triton +import triton.language as tl +import argparse +import sys +import multiprocessing +from tune_streamk import gen_input +""" + for fi in range(jobs): + f_kernel[fi].write(import_str + "\n") + + # write definitions of streamk_gemm_xxx + # and matmul_xxx and try_config + with open("streamk_kernel.py") as file: + streamk_gemm_code = file.read() + idx = 0 + for config in configs: + file_idx = idx % jobs + EVEN_K = True if K % config.get('BLOCK_SIZE_K') == 0 else False + configStr, matmul_def_str = gen_kernel_and_configStr_from_config(M, N, K, num_cus, EVEN_K, config, dtype_a, + dtype_b, dtype_c, dtype_p, dtype_lock) + # Copy the streamk_gemm with name replaced + streamk_gemm_config = streamk_gemm_code.replace("streamk_gemm", f"streamk_gemm_{configStr}") + streamk_gemm_config = streamk_gemm_config.replace("import triton.language as tl", "") + streamk_gemm_config = streamk_gemm_config.replace("import triton", "") + f_kernel[file_idx].write(streamk_gemm_config + "\n\n") + f_kernel[file_idx].write(matmul_def_str + "\n") + idx += 1 + + # write test_gemm + # pre string + block_m = config.get('BLOCK_SIZE_M') + block_n = config.get('BLOCK_SIZE_N') + test_gemm_pre_str = f"""def test_gemm(M, N, K, num_cus, num_threads): + thread_pool = multiprocessing.Pool(processes=num_threads) + a, a_fp16 = gen_input(M, K, '{dtype_a}', {col_a}, 1, '{init_type}', device='cuda') + b, b_fp16 = gen_input(K, N, '{dtype_b}', {col_b}, 2, '{init_type}', device='cuda') + c = torch.zeros((M, N), device=a.device, dtype={tl_to_torch_types[name_to_tl_types[dtype_c]]}) + task_args = (M, N, K, num_cus, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1)) + + if num_threads > 1: + results = [] + config_names = [] +""" + for fi in range(jobs): + f_kernel[fi].write(test_gemm_pre_str + "\n") + + # warm up 
call of all matmul functions in parallel + idx = 0 + for config in configs: + EVEN_K = True if K % config.get('BLOCK_SIZE_K') == 0 else False + configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, num_cus, EVEN_K, config, None, None, None, None, + None) + task_str = f" results += [thread_pool.apply_async(try_config_{configStr}, args=task_args)]\n" + \ + f" config_names += ['{configStr}']\n" + f_kernel[idx % jobs].write(task_str) + idx += 1 + + for fi in range(jobs): + threadpool_str = """ + failed_configs = [] + for i in range(len(results)): + results[i].wait() + res = results[i].get() + if not res: + failed_configs += [config_names[i]] + thread_pool.close() + thread_pool.join() + with open("{filename}.failed_configs", "w") as f: + for cfg in failed_configs: + f.write(cfg + "\\n") + else: + try: + with open("{filename}.failed_configs", "r") as f: + failed_configs = [cfg.strip() for cfg in f.readlines()] + except Exception: + failed_configs = [] + """.format(filename=filenames[fi]) + f_kernel[fi].write(threadpool_str) + # call all matmul_xxx functions + idx = 0 + runs = iters if run_bench else 200 + for config in configs: + EVEN_K = True if K % config.get('BLOCK_SIZE_K') == 0 else False + configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, num_cus, EVEN_K, config, None, None, None, None, + None) + block_m = config.get('BLOCK_SIZE_M') + block_n = config.get('BLOCK_SIZE_N') + matmul_call_str = f""" + if '{configStr}' not in failed_configs: + print(f"{configStr}") + for i in range({runs}): + locks = torch.zeros((num_cus,), device = "cuda", dtype = torch.int32) + P = torch.zeros((num_cus, {block_m}*{block_n}), device="cuda", dtype=torch.float32) + d = matmul_{configStr}(a, b, c, P, locks, M, N, K, num_cus, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1))""" + f_kernel[idx % jobs].write(matmul_call_str + "\n") + idx += 1 + # post string + for fi in range(jobs): + f_kernel[fi].write(" return d\n") + + # def main and call test_gemm + def_main_str = """ +def main(): + parser = argparse.ArgumentParser( + prog="tune a specific gemm size", + allow_abbrev=False,) + parser.add_argument("-n", type=int, default=1, help='number of threads') + args = parser.parse_args() + numThreads = args.n + num_cus = 304 + """ + test_gemm_call_str = f'test_gemm({M}, {N}, {K}, num_cus, numThreads)' + for fi in range(jobs): + f_kernel[fi].write(def_main_str) + f_kernel[fi].write(test_gemm_call_str + "\n\n") + f_kernel[fi].write("""if __name__ == '__main__': + sys.exit(main())""") + f_kernel[fi].close() + + +def extract_kernel_time(M, N, K, num_cus, EVEN_K, config, df): + # Correct the header by removing 'sig' and 'obj' to reduce number from 21 to 19 + # once the bug is fixed, we should not need below two lines + cols = [ + 'Index', 'KernelName', 'gpu-id', 'queue-id', 'queue-index', 'pid', 'tid', 'grd', 'wgr', 'lds', 'scr', + 'arch_vgpr', 'accum_vgpr', 'sgpr', 'wave_size', 'DispatchNs', 'BeginNs', 'EndNs', 'CompleteNs' + ] + df.columns = cols + + configStr, _ = gen_kernel_and_configStr_from_config(M, N, K, num_cus, EVEN_K, config, None, None, None, None, None) + + filtered_df = df[df['KernelName'].str.contains(configStr, na=False)].copy() + filtered_df['DurationNs'] = filtered_df['EndNs'] - filtered_df['BeginNs'] + meanTime = filtered_df['DurationNs'].tail(100).mean() + return config, meanTime + + +def profile_batch_kernels(M, N, K, num_cus, gpuid, gpus, jobs, verbose): + ngpus = len(gpus) + gpuIdx = gpus.index(gpuid) + if gpuIdx + 1 > jobs: + return + 
os.environ['ROCR_VISIBLE_DEVICES'] = str(gpuid) + jobId = gpuIdx + while jobId < jobs: + if verbose: + print(f"profiling {generated_kernel_name(M, N, K, jobId)} on GPU {gpuid}") + run_bash_command_wrapper( + f"rocprofv2 --plugin file --plugin-version 1 --kernel-trace -o {jobId} python {generated_kernel_name(M, N, K, jobId)}", + capture=(verbose < 2)) + jobId += ngpus + + +def tune_gemm_config(M, N, K, num_cus, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock, init_type, configs, + run_bench, jobs, iters, skipWarmup, verbose=0, num_threads=16, gpus=[0]): + # Generate kernel out of all configs + generate_kernel(M, N, K, num_cus, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock, init_type, configs, + jobs, iters, run_bench) + + # remove any compiled kernel in the cache + run_bash_command("rm -rf ~/.triton/cache") + + # precompile the kernels in parallel + start_time = datetime.now() + if not skipWarmup: + for i in range(jobs): + run_bash_command(f"python {generated_kernel_name(M, N, K, i)} -n {num_threads}", capture=(verbose < 2)) + compile_end = datetime.now() + compile_time = compile_end - start_time + if verbose: + print(f"compile time: {compile_time}", flush=True) + + # profile generated kernels + running = [ + multiprocessing.Process(target=profile_batch_kernels, args=(M, N, K, num_cus, gpu_id, gpus, jobs, verbose)) + for gpu_id in gpus + ] + for p in running: + p.start() + for p in running: + p.join() + + profile_end = datetime.now() + profile_time = profile_end - compile_end + if verbose: + print(f"profile time: {profile_time}", flush=True) + + # post process results.csv to get the best config and minTime + # TODO: process the file in parallel + minTime = 1024 * 1024 * 1024 + thread_pool = multiprocessing.Pool(processes=num_threads) + tasks = [] + idx = 0 + df_prof = [ + pd.read_csv(f"results_{i}.csv", skiprows=1, header=None, delimiter=',', quotechar='"', escapechar='\\') + for i in range(jobs) + ] + for config in configs: + EVEN_K = True if K % config.get('BLOCK_SIZE_K') == 0 else False + file_idx = idx % jobs + tasks += [ + thread_pool.apply_async(extract_kernel_time, args=(M, N, K, num_cus, EVEN_K, config, df_prof[file_idx])) + ] + idx += 1 + thread_pool.close() + thread_pool.join() + + for task in tasks: + config, myTime = task.get() + if myTime: + min_us = myTime / 1000 + if min_us < minTime: + minTime = min_us + bestConfig = config + else: + min_us = -1 + print(f"invalid config(post processing): SIZE {M} {N} {K}: {config}", flush=True) + post_end = datetime.now() + post_time = post_end - profile_end + if verbose: + print(f"post procesing time: {post_time}", flush=True) + return minTime, bestConfig, compile_time, profile_time, post_time + + +def gen_input(M, N, ty_name, needTrans, seed, init_type, device='cuda'): + d_type = name_to_tl_types[ty_name] + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @triton.jit + def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + input = tl.load(input_ptr + offsets, mask=mask) + output = input + tl.store(output_ptr + offsets, output, mask=mask) + + def init_by_size_and_type(size, dtype, init_type): + if init_type == 'hpl': + return torch.empty(size, device='cuda', dtype=dtype).uniform_(-0.5, 0.5) + # This init type has element[i] in row[j] equal to sin(i+j*N) + elif init_type == 'trig_float': + M, N = size + return torch.reshape(torch.arange(0, M * N), (M, N)).sin().to(dtype=dtype, 
device='cuda') + elif init_type == 'zeros': + return torch.zeros(size, dtype=dtype, device='cuda') + elif init_type == "randn": + temp = torch.randn(size, dtype=dtype, device='cuda') + return temp + else: + raise ValueError("Bad matrix initialization type.") + + raw_data = init_by_size_and_type((N, M) if needTrans else (M, N), torch.float32, init_type) + if needTrans: + raw_data = raw_data.T + if (d_type == tl.float8e4b8 and TORCH_HAS_FP8E4B8) or \ + (d_type == tl.float8e5b16 and TORCH_HAS_FP8E5B16) or not d_type.is_fp8(): + input = raw_data.to(tl_to_torch_types[d_type]) + input_f16 = input.to(torch.float16) + else: + f8_tensor = raw_data.to(torch.int8) + # keep only two bits of exponent to avoid overflow + f8_tensor = f8_tensor & 0b00111111 + input = triton.reinterpret(f8_tensor, d_type) + input_f16 = torch.empty_like(f8_tensor, dtype=torch.float16) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + n_elements = raw_data.numel() + copy_kernel[grid](input, input_f16, n_elements, BLOCK_SIZE=1024) + + return input, input_f16 + + +def matmul(a, b, c, P, locks, num_cus, block_m, block_n, block_k, group_m, num_warps, num_stages, waves_per_eu, + mfmaInstrSize, kpack, EVEN_K): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + #assert a.is_contiguous(), "Matrix A must be contiguous" + #assert b.is_contiguous(), "Matrix B must be contiguous" + M, K = a.shape + K, N = b.shape + # 1D launch kernel where each block gets its own program. + + grid = num_cus + + streamk_gemm[ + grid, + ](a, b, c, P, locks, M, N, K, num_cus, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), + BLOCK_SIZE_M=block_m, BLOCK_SIZE_N=block_n, BLOCK_SIZE_K=block_k, GROUP_SIZE_M=group_m, num_warps=num_warps, + num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=mfmaInstrSize, kpack=kpack, EVEN_K=EVEN_K) + return c + + +def test_correctness(M, N, K, num_cus, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, config, verbose): + block_m, block_n, block_k, group_m, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(config) + torch.manual_seed(0) + #a = torch.randn((M, K), device='cuda', dtype=datatype) + #b = torch.randn((K, N), device='cuda', dtype=datatype) + a, a_fp16 = gen_input(M, K, dtype_a, col_a, 1, init_type, device='cuda') + b, b_fp16 = gen_input(K, N, dtype_b, col_b, 2, init_type, device='cuda') + # Allocates output. 
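+ # P holds one BLOCK_SIZE_M x BLOCK_SIZE_N fp32 partial tile per CU and locks holds one int32 spin-lock flag per CU, both consumed by the kernel's stream-k epilogue.
+ # EVEN_K lets the kernel use unmasked loads along K when K is a multiple of BLOCK_SIZE_K.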
+ print(f"{block_k}") + EVEN_K = K % block_k == 0 + c = torch.zeros((M, N), device=a.device, dtype=tl_to_torch_types[name_to_tl_types[dtype_c]]) + locks = torch.zeros((num_cus, ), device="cuda", dtype=torch.int32) + P = torch.zeros((num_cus, block_m * block_n), device="cuda", dtype=torch.float32) + triton_output = matmul(a, b, c, P, locks, num_cus, block_m, block_n, block_k, group_m, num_warps, num_stages, + waves_per_eu, mfmaInstrSize, kpack, EVEN_K) + torch_output = torch.matmul(a_fp16, b_fp16) + # print(f"triton_output={triton_output}") + # print(f"torch_output={torch_output}") + rtol = 0 if torch.version.hip is None else 1e-2 + atol = 1e-3 + row_a_str = 'N' if col_a else 'T' + row_b_str = 'N' if col_b else 'T' + size_str = '' + if verbose: + size_str = f'SIZE M: {M}, N: {N}, K: {K}, trans: {row_a_str}{row_b_str}' + if torch.allclose(triton_output.to(torch.float16), torch_output, atol=atol, rtol=rtol): + print(f'{size_str} Correct✅') + else: + print(f'{size_str} Incorrect❌') + + +def get_default_tuning_result_filename(): + git_branch_name = run_bash_command("git rev-parse --abbrev-ref HEAD") + git_branch_name = git_branch_name[0].decode() + git_commit_hash = run_bash_command("git rev-parse --short HEAD") + git_commit_hash = git_commit_hash[0].decode() + + dt_string = datetime.now().strftime("%m-%d-%Y-%H:%M:%S") + defaultName = f"tuning_results_{git_branch_name}@{git_commit_hash}_{dt_string}.yaml" + return defaultName + + +def parse_args(): + parser = argparse.ArgumentParser( + prog="tune a specific gemm size", + allow_abbrev=False, + ) + + parser.add_argument("-m", type=int, default=0) + parser.add_argument("-n", type=int, default=0) + parser.add_argument("-k", type=int, default=0) + parser.add_argument("-col_a", action='store_true', default=False, help='whether matrix a is column major') + parser.add_argument("-col_b", action='store_true', default=False, help='whether matrix b is column major') + parser.add_argument("-dtype_a", type=str, default='fp16', help="matrix a element data type") + parser.add_argument("-dtype_b", type=str, default='fp16', help="matrix b element data type") + parser.add_argument("-dtype_c", type=str, default='fp16', help="output element data type") + parser.add_argument("--ngpus", type=int, default=0, help='number of GPUs used in the profiling step') + parser.add_argument("--gpu_ids", type=lambda s: [int(id) for id in s.split(',')], default=[], + help='list of gpu ids to use for tuning') + parser.add_argument("--gemm_size_file", type=str, default="", help='yaml file to indicate matrix size') + parser.add_argument("--o", type=str, default=get_default_tuning_result_filename(), + help='yaml file to store tuning results') + parser.add_argument("--keep", action='store_true', default=False, help='keep generated files') + parser.add_argument("--compare", action='store_true', default=False, help="Whether check result correctness") + parser.add_argument("--compare_wo_tuning", action='store_true', default=False, + help="Whether check result correctness") + parser.add_argument("--benchmark", action='store_true', default=False, help="Benchmark the given config") + parser.add_argument("--time_breakdown", action='store_true', default=False, + help="Show detailed time breakdown of each step during the tuning") + parser.add_argument("--verbose", action='store_true', default=False, + help="enables time_breakdown and additional logging messages") + parser.add_argument("--num_threads", type=int, default=16, + help="number of threads to use for kernel compilation and post 
processing") + parser.add_argument("--jobs", type=int, default=1, help="number of generated files") + parser.add_argument("--iters", type=int, default=1000, help="number of generated files") + parser.add_argument("--init_type", type=str, default='randn', + help="Initialization type for input matrices (default uniform rand [0, 1.0)])") + parser.add_argument("--no_warmup", action='store_true', default=False, help="Do not call the warmup kernel") + args = parser.parse_args() + + return args + + +TORCH_HAS_FP8E5B16 = hasattr(torch, 'float8_e5m2fnuz') +TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz') +tl_to_torch_types = { + tl.float16: torch.float16, + tl.bfloat16: torch.bfloat16, + tl.float32: torch.float32, + tl.int8: torch.int8, + tl.int32: torch.int32, +} +if TORCH_HAS_FP8E5B16: + tl_to_torch_types[tl.float8e5b16] = torch.float8_e5m2fnuz +if TORCH_HAS_FP8E4B8: + tl_to_torch_types[tl.float8e4b8] = torch.float8_e4m3fnuz + +name_to_tl_types = { + 'int8': tl.int8, + 'int32': tl.int32, + 'fp16': tl.float16, + 'fp32': tl.float32, + 'bf16': tl.bfloat16, + 'fp8': tl.float8e4b8, + 'bf8': tl.float8e5b16, +} + + +def process_item(item): + M = item['M'] + N = item['N'] + K = item['K'] + col_a = False if item['rowMajorA'] == 'T' else True + col_b = False if item['rowMajorB'] == 'T' else True + del item['M'] + del item['N'] + del item['K'] + del item['rowMajorA'] + del item['rowMajorB'] + return M, N, K, col_a, col_b, item + + +def type_name_to_bytes(ty_name): + if '32' in ty_name: + return 4 + if '16' in ty_name: + return 2 + if '8' in ty_name: + return 1 + else: + print(f"Unrecognized input type name {ty_name}") + sys.exit(1) + + +def format_output(unformatted): + if unformatted < 0.0001: + formatted = "{:.3e}".format(unformatted) + elif unformatted > 1000: + formatted = "{:.1f}".format(unformatted) + else: + formatted = "{:.2f}".format(unformatted) + return formatted + + +def main(): + args = parse_args() + matrix_size_file = args.gemm_size_file + tuning_output_file = args.o + keepTmp = args.keep + run_bench = args.benchmark + jobs = args.jobs + iters = args.iters + skipWarmup = args.no_warmup + num_cus = 304 + + # Get GPU ids + ngpus = args.ngpus + gpu_ids = args.gpu_ids + if ngpus != 0 and gpu_ids: + print("--ngpus and --gpu_ids are mutually exclusive options") + return os.EX_USAGE + if ngpus == 0 and not gpu_ids: + ngpus = 1 + if ngpus != 0: + gpus = range(ngpus) + if gpu_ids: + gpus = gpu_ids + + if run_bench: + gpus = [gpus[0]] + jobs = 1 + + # Get element type + dtype_a = args.dtype_a + dtype_b = args.dtype_b + dtype_c = args.dtype_c + dtype_p = 'fp32' + dtype_lock = 'int32' + if dtype_a not in name_to_tl_types or dtype_b not in name_to_tl_types or dtype_c not in name_to_tl_types: + print(f"Unsupported dtype_a {args.dtype_a} or dtype_b {args.dtype_b} or dtype_c {args.dtype_c}") + print("Supported types: ", list(name_to_tl_types.keys())) + sys.exit(1) + + mnks = [] + # TODO: make it more robust to get user input + init_type = args.init_type + if matrix_size_file == "" or not os.path.isfile(matrix_size_file): + M = args.m + N = args.n + K = args.k + col_a = args.col_a + col_b = args.col_b + mnks = [(M, N, K, col_a, col_b, None)] + else: + with open(matrix_size_file) as file: + matrix_sizes = yaml.safe_load(file) + for item in matrix_sizes: + M, N, K, col_a, col_b, item = process_item(item) + mnks.append((M, N, K, col_a, col_b, item)) + + # Check correctness from given configs + if args.compare_wo_tuning: + for (M, N, K, col_a, col_b, myConfig) in mnks: + test_correctness(M, N, K, num_cus, 
col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, myConfig, True) + return + + configs_full = get_full_tuning_space() + + start_time = datetime.now() + if run_bench: + print(f"Benchmarking gemm with {dtype_a} inputs") + print("trans M N K TFLOPS us") + else: + print(f"Tuning {len(mnks)} gemm sizes starts at: {start_time}", flush=True) + f_results = open(tuning_output_file, 'w') + + for (M, N, K, col_a, col_b, myConfig) in mnks: + start_local_time = datetime.now() + # Obtain a pruned tuning space according to gemm size + # If running benchmark, use the provided config + pruned_configs = [myConfig] if run_bench else prune_configs(M, N, K, configs_full, type_name_to_bytes(dtype_a), + type_name_to_bytes(dtype_b)) + + row_a_str = 'N' if col_a else 'T' + row_b_str = 'N' if col_b else 'T' + size_str = f'SIZE: {M} {N} {K} {row_a_str}{row_b_str}' + if not run_bench: + print(f"{size_str} nConfigs: {len(pruned_configs)}", end=" ", flush=True) + else: + print(f"{row_a_str}{row_b_str} {M:5d} {N:5d} {K:5d} ", end="") + + # The main tuning funtion for one gemm size + verbose_level = 0 + if args.time_breakdown: + verbose_level = 1 + if args.verbose: + verbose_level = 2 + minTime, bestConfig, compile_time, profile_time, post_time = tune_gemm_config( + M, N, K, num_cus, col_a, col_b, dtype_a, dtype_b, dtype_c, dtype_p, dtype_lock, init_type, pruned_configs, + run_bench, jobs, iters, skipWarmup, num_threads=args.num_threads, gpus=gpus, verbose=verbose_level) + + EVEN_K = True if K % bestConfig.get('BLOCK_SIZE_K') == 0 else False + # post processing the numbers + perf_tflops = lambda us: 2 * M * N * K * 1e-12 / (us * 1e-6) + tri_tflops = perf_tflops(minTime) + formatted_tflops = format_output(tri_tflops) + minTime = format_output(minTime) + if not run_bench: + print(f'TFLOPS: {formatted_tflops} time(us): {minTime}', end=" ", flush=True) + + bestConfig_compact_str, _ = gen_kernel_and_configStr_from_config(M, N, K, num_cus, EVEN_K, bestConfig, None, + None, None, None, None) + if not run_bench: + print(f'best_config: {bestConfig_compact_str}', end=" ", flush=True) + + # write best config to tuning_results.yaml + if run_bench: + print(f"{formatted_tflops} {minTime}") + + sizeDict = {'M': M, 'N': N, 'K': K, 'rowMajorA': row_a_str, 'rowMajorB': row_b_str} + sizeDict.update(bestConfig) + if not run_bench: + f_results.write("- " + str(sizeDict) + " ") + f_results.write(f'# TFLOPS: {formatted_tflops} time(us): {minTime}\n') + + # remove generated files if asked to + if not keepTmp: + for i in range(jobs): + generated_script = generated_kernel_name(M, N, K, i) + os.remove(generated_script) + if not skipWarmup: + os.remove(generated_script + ".failed_configs") + for f in glob.glob(f"results_{i}.*"): + os.remove(f) + + # Check correctness if asked to + if args.compare: + print("correctness: ", end=" ", flush=True) + test_correctness(M, N, K, num_cus, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, bestConfig, False) + elif not run_bench: + print("", flush=True) + + end_local_time = datetime.now() + if not run_bench: + print( + f">>> Elapsed time: {end_local_time - start_local_time} = {compile_time} (compile) + {profile_time} (profile) + {post_time} (post processing)", + flush=True) + + if not run_bench: + f_results.close() + + end_time = datetime.now() + tuning_time = end_time - start_time + if not run_bench: + print(f"Tuning ends at: {end_time}") + print(f"Total tuning time (h:m:s): {tuning_time}") + + +if __name__ == '__main__': + sys.exit(main())
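
For readers who want to try the new kernel outside the tuning flow, the sketch below mirrors the `matmul` wrapper in `tune_streamk.py`: the launch grid is one program per CU, `P` supplies one BLOCK_SIZE_M x BLOCK_SIZE_N fp32 partial tile per CU, and `locks` supplies one int32 flag per CU for the spin-lock epilogue. The problem size, block shapes, `GROUP_SIZE_M`, and the mfma/kpack settings are illustrative placeholders rather than tuned values, and `multi_processor_count` is assumed to report the CU count (304 on MI300X, matching the value hard-coded in the scripts).

```python
import torch
from streamk_kernel import streamk_gemm

M, N, K = 8192, 8192, 8192
BLK_M, BLK_N, BLK_K = 256, 256, 64

# One persistent program per CU; each program loops over its share of tiles.
num_cus = torch.cuda.get_device_properties(0).multi_processor_count

a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = torch.zeros((M, N), device="cuda", dtype=torch.float16)

# Per-CU scratch: one fp32 partial output tile and one int32 spin-lock flag.
P = torch.zeros((num_cus, BLK_M * BLK_N), device="cuda", dtype=torch.float32)
locks = torch.zeros((num_cus, ), device="cuda", dtype=torch.int32)

streamk_gemm[(num_cus, )](
    a, b, c, P, locks,
    M, N, K, num_cus,
    a.stride(0), a.stride(1),
    b.stride(0), b.stride(1),
    c.stride(0), c.stride(1),
    BLOCK_SIZE_M=BLK_M, BLOCK_SIZE_N=BLK_N, BLOCK_SIZE_K=BLK_K,
    GROUP_SIZE_M=4, EVEN_K=(K % BLK_K == 0),
    num_warps=8, num_stages=0, waves_per_eu=0,
    matrix_instr_nonkdim=16, kpack=2,
)

# fp16 inputs with K=8192 accumulate rounding error, so compare loosely.
assert torch.allclose(c, a @ b, atol=1.0, rtol=1e-2)
```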
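As a concrete check of `get_tiles_config` for that shape: with 256x256x64 blocks there are 32 x 32 = 1024 output tiles and iters_per_tile = 8192 / 64 = 128. With num_cus = 304, total_streamk_tiles = 1024 % 304 = 112, so 912 tiles (exactly 3 per CU) go through the data-parallel loop, and the remaining 112 tiles are split iteration-wise: 112 * 128 = 14336 K-iterations gives streamk_iters_pcu = 47 with a remainder of 48, so the first 48 programs each take 48 stream-k iterations and the rest take 47.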