Add performance reference for important matmul kernels #642

Open · wants to merge 5 commits into base: main_perf
14 changes: 14 additions & 0 deletions python/perf-kernels/tools/tune_gemm/[email protected]
@@ -0,0 +1,14 @@
trans,M,N,K,TFLOPS,us
TN,4864,4096,4096,467.39,349.19
TN,4864,4096,4160,567.17,292.26
TN,4864,4096,4224,557.49,301.90
TN,4864,4096,4288,569.55,299.99
TN,4864,4096,4097,501.58,325.47
TN,4864,4096,4098,491.96,331.92
TN,4864,4096,4100,503.51,324.46
TN,4864,4096,4104,515.70,317.10
TN,4864,4096,4112,525.66,311.70
TN,4864,8192,4096,519.95,627.79
TN,4864,8192,4160,579.14,572.43
TN,4864,8192,8192,543.30,1201.6
TN,4864,8192,8256,563.43,1167.7
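
The TFLOPS column follows directly from the shape and the time column (a GEMM performs 2·M·N·K floating-point operations). A quick sanity check against the first row:

```python
# Recompute TFLOPS for the first ref.csv row: TN,4864,4096,4096,467.39,349.19
M, N, K, us = 4864, 4096, 4096, 349.19
tflops = 2 * M * N * K / (us * 1e-6) / 1e12  # 2*M*N*K flops over the measured time
print(f"{tflops:.2f}")  # -> 467.39, matching the TFLOPS column
```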
18 changes: 18 additions & 0 deletions python/perf-kernels/tools/tune_gemm/database.yaml
Member:

What is the purpose of this database file? Is it used to benchmark against ref.csv? If that's the case, what happens if a parameter changes, e.g. GROUP_SIZE_M changing from 4 to 8?

Member:

We may need a different database file format if this is going to be a daily or per-commit task.

Author:

This database.yaml is used as the current best perf config. If you have an optimization that improves the best perf number and requires a different config, we should update the database. The main purpose is to catch regressions.
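
For illustration, a regression gate against ref.csv could look like the sketch below; the function name and the shape of `measured` are hypothetical, not part of this PR:

```python
import csv

def check_regression(ref_path, measured, rel_tol=0.05):
    """Flag shapes whose fresh TFLOPS fall below the reference by more than rel_tol.

    measured: dict mapping (trans, M, N, K) -> TFLOPS from a fresh benchmark run.
    """
    failures = []
    with open(ref_path) as f:
        for row in csv.DictReader(f):
            key = (row['trans'], int(row['M']), int(row['N']), int(row['K']))
            got = measured.get(key)
            if got is not None and got < float(row['TFLOPS']) * (1 - rel_tol):
                failures.append((key, float(row['TFLOPS']), got))
    return failures
```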

@@ -0,0 +1,18 @@
# M // BLOCK_M * N // BLOCK_N % 304 == 0
## 1 workgroup / CU
- {'M': 4864, 'N': 4096, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
- {'M': 4864, 'N': 4096, 'K': 4160, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
- {'M': 4864, 'N': 4096, 'K': 4224, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
- {'M': 4864, 'N': 4096, 'K': 4288, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
## 1 workgroup / CU masked loadK
- {'M': 4864, 'N': 4096, 'K': 4097, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
- {'M': 4864, 'N': 4096, 'K': 4098, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
- {'M': 4864, 'N': 4096, 'K': 4100, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
- {'M': 4864, 'N': 4096, 'K': 4104, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
- {'M': 4864, 'N': 4096, 'K': 4112, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}

## 2 workgroups / CU
- {'M': 4864, 'N': 8192, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
- {'M': 4864, 'N': 8192, 'K': 4160, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
- {'M': 4864, 'N': 8192, 'K': 8192, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
- {'M': 4864, 'N': 8192, 'K': 8256, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
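
The header comment `# M // BLOCK_M * N // BLOCK_N % 304 == 0` encodes why these shapes were picked: with 256x256 blocks the launch grid is an exact multiple of the CU count, so every CU gets the same number of workgroups. A quick check, assuming a 304-CU GPU (e.g. MI300X):

```python
# Verify the workgroups-per-CU groupings used in database.yaml.
# 4864//256 * 4096//256 = 19 * 16 = 304 -> 1 workgroup per CU
# 4864//256 * 8192//256 = 19 * 32 = 608 -> 2 workgroups per CU
NUM_CU = 304
for M, N in [(4864, 4096), (4864, 8192)]:
    wgs = (M // 256) * (N // 256)
    print(M, N, wgs, wgs // NUM_CU)
```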
24 changes: 17 additions & 7 deletions python/perf-kernels/tools/tune_gemm/matmul_kernel.py
@@ -46,16 +46,26 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak,
     bias = tl.load(bias_ptrs, mask=offs_am < M, other=0.0)
     acc_dtype = tl.float32 if a_ptr.type.element_ty != tl.int8 else tl.int32
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
-    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
-        if EVEN_K:
-            a = tl.load(a_ptrs)
-            b = tl.load(b_ptrs)
-        else:
-            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
-            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+    max_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K) - 1
+    if EVEN_K:
+        max_k += 1
+    for k in range(0, max_k):
+        a = tl.load(tl.multiple_of(a_ptrs, (1, 16)))
+        b = tl.load(tl.multiple_of(b_ptrs, (16, 1)))
         accumulator += tl.dot(a, b)
         a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
+
+    if not EVEN_K:
+        k = max_k
+        offs_k = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
+        a_ptrsX = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
+        b_ptrsX = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
+        a = tl.load(a_ptrsX, mask=offs_k[None, :] < K, other=0.0)
+        b = tl.load(b_ptrsX, mask=offs_k[:, None] < K, other=0.0)
+        accumulator += tl.dot(a, b)
+
     c = accumulator.to(c_ptr.type.element_ty)
     if BIAS:
         c += bias[:, None]
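The kernel change is a loop-peeling rewrite: the main K loop now runs only over full BLOCK_SIZE_K blocks with unmasked loads (plus `tl.multiple_of` alignment hints), and when `EVEN_K` is false the single masked residual block is handled once after the loop instead of paying the mask check on every iteration. A minimal NumPy sketch of the same idea (illustrative only, not the Triton code):

```python
import numpy as np

def peeled_blocked_matmul(a, b, block_k=32):
    """Peel the K tail: unmasked full blocks in the hot loop, one padded tail block."""
    M, K = a.shape
    _, N = b.shape
    acc = np.zeros((M, N), dtype=np.float32)
    full = K // block_k                  # hot loop covers only full blocks
    for k in range(full):
        ks = slice(k * block_k, (k + 1) * block_k)
        acc += a[:, ks] @ b[ks, :]
    rem = K - full * block_k             # residual block, handled once outside the loop
    if rem:
        a_tail = np.zeros((M, block_k), dtype=a.dtype)
        b_tail = np.zeros((block_k, N), dtype=b.dtype)
        a_tail[:, :rem] = a[:, -rem:]    # zero padding mirrors the masked tl.load other=0.0
        b_tail[:rem, :] = b[-rem:, :]
        acc += a_tail @ b_tail
    return acc
```

For random float32 inputs the result matches `a @ b` up to floating-point rounding, which mirrors how the peeled Triton loop preserves the masked-load semantics.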
8 changes: 5 additions & 3 deletions python/perf-kernels/tools/tune_gemm/tune_gemm.py
@@ -108,6 +108,7 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b):
     num_warps = config.get("num_warps")
     num_stages = config.get("num_stages")
     matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
+    EVEN_K = (K % BLOCK_SIZE_K == 0)
     if matrix_instr_nonkdim > mfma:
         continue
     if mfma == 4 and BLOCK_SIZE_K < 64:
@@ -149,10 +150,11 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b):
         continue
     # Skip small block sizes and num_warps for large gemm
     # For fp16 and f8, we want to only use BLOCK_SIZE >= 64
+    # We only want to use a small BLOCK_SIZE_K if not EVEN_K
     if large_gemm:
         if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
             continue
-        if BLOCK_SIZE_K < 64:
+        if BLOCK_SIZE_K < 64 and EVEN_K:
             continue
         if num_warps < 4:
             continue
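
The relaxed rule keeps small BLOCK_SIZE_K candidates only for shapes where K does not divide evenly, which is consistent with database.yaml above: the K = 4097 to 4112 entries use BLOCK_SIZE_K = 32 while the even shapes use 64. The predicate in isolation:

```python
# How the new pruning rule behaves for a large gemm (illustrative K values):
for K, BLOCK_SIZE_K in [(4096, 32), (4096, 64), (4097, 32), (4097, 64)]:
    EVEN_K = (K % BLOCK_SIZE_K == 0)
    pruned = BLOCK_SIZE_K < 64 and EVEN_K
    print(K, BLOCK_SIZE_K, "pruned" if pruned else "kept")
# Only (4096, 32) is pruned: small BLOCK_SIZE_K survives solely for uneven K.
```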
@@ -657,14 +659,14 @@ def main():

     # write best config to tuning_results.yaml
     if run_bench:
-        print(f"{formatted_tflops} {minTime}")
+        print(f"{formatted_tflops} {minTime} {bestConfig_compact_str}")
         f_results.write(f"{formatted_tflops},{minTime}\n")
 
     sizeDict = {'M': M, 'N': N, 'K': K, 'rowMajorA': row_a_str, 'rowMajorB': row_b_str}
     sizeDict.update(bestConfig)
     if not run_bench:
         f_results.write("- " + str(sizeDict) + " ")
-        f_results.write(f'# TFLOPS: {formatted_tflops} time(us): {minTime}\n')
+        f_results.write(f'# {bestConfig_compact_str}\n')
 
     # remove generated files if asked to
     if not keepTmp: