Skip to content

Commit

Permalink
Add testcase for overall mean regression
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexAUT committed Oct 1, 2024
1 parent 7f61047 commit 76687cf
Showing 1 changed file with 87 additions and 75 deletions.
162 changes: 87 additions & 75 deletions python/perf-kernels/tools/tune_gemm/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,85 +4,17 @@
import pytest
import warnings
from copy import deepcopy
import statistics


@pytest.mark.parametrize('config', [
# M // BLOCK_M * N // BLOCK_N % 304 == 0
# 1 workgroup / CU
{
'M': 4864, 'N': 4096, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0,
'matrix_instr_nonkdim': 16, 'kpack': 2
},
{
'M': 4864, 'N': 4096, 'K': 4160, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0,
'matrix_instr_nonkdim': 16, 'kpack': 2
},
{
'M': 4864, 'N': 4096, 'K': 4224, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0,
'matrix_instr_nonkdim': 16, 'kpack': 2
},
{
'M': 4864, 'N': 4096, 'K': 4288, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0,
'matrix_instr_nonkdim': 16, 'kpack': 2
},
# 1 workgroup / CU masked loadK
{
'M': 4864, 'N': 4096, 'K': 4097, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0,
'matrix_instr_nonkdim': 16, 'kpack': 2
},
{
'M': 4864, 'N': 4096, 'K': 4098, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0,
'matrix_instr_nonkdim': 16, 'kpack': 2
},
{
'M': 4864, 'N': 4096, 'K': 4100, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0,
'matrix_instr_nonkdim': 16, 'kpack': 2
},
{
'M': 4864, 'N': 4096, 'K': 4104, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0,
'matrix_instr_nonkdim': 16, 'kpack': 2
},
{
'M': 4864, 'N': 4096, 'K': 4112, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0,
'matrix_instr_nonkdim': 16, 'kpack': 2
},
# 2 workgroups / CU
{
'M': 4864, 'N': 8192, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0,
'matrix_instr_nonkdim': 16, 'kpack': 2
},
{
'M': 4864, 'N': 8192, 'K': 4160, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0,
'matrix_instr_nonkdim': 16, 'kpack': 2
},
{
'M': 4864, 'N': 8192, 'K': 8192, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0,
'matrix_instr_nonkdim': 16, 'kpack': 2
},
{
'M': 4864, 'N': 8192, 'K': 8256, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0,
'matrix_instr_nonkdim': 16, 'kpack': 2
},
], ids=lambda val: f"Config: {val}")
class TestRegression:

@classmethod
def setup_class(self):
self.slowdown_threshold = 0.97

self.test_results = []
self.test_perf_ratios = []
try:
with open('gemm-performance-report-reference.yaml', 'r') as ref_file:
self.reference_data = yaml.safe_load(ref_file)
Expand All @@ -95,6 +27,78 @@ def teardown_class(self):
with open('gemm-performance-report.yaml', 'w') as out_file:
yaml.safe_dump(self.test_results, out_file)

@pytest.mark.parametrize('config', [
# (M // BLOCK_SIZE_M) * (N // BLOCK_SIZE_N) % 304 == 0
# 1 workgroup / CU
{
'M': 4864, 'N': 4096, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N':
256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu':
0, 'matrix_instr_nonkdim': 16, 'kpack': 2
},
{
'M': 4864, 'N': 4096, 'K': 4160, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N':
256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu':
0, 'matrix_instr_nonkdim': 16, 'kpack': 2
},
{
'M': 4864, 'N': 4096, 'K': 4224, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N':
256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu':
0, 'matrix_instr_nonkdim': 16, 'kpack': 2
},
{
'M': 4864, 'N': 4096, 'K': 4288, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N':
256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu':
0, 'matrix_instr_nonkdim': 16, 'kpack': 2
},
# 1 workgroup / CU masked loadK
{
'M': 4864, 'N': 4096, 'K': 4097, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N':
256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu':
0, 'matrix_instr_nonkdim': 16, 'kpack': 2
},
{
'M': 4864, 'N': 4096, 'K': 4098, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N':
256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu':
0, 'matrix_instr_nonkdim': 16, 'kpack': 2
},
{
'M': 4864, 'N': 4096, 'K': 4100, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N':
256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu':
0, 'matrix_instr_nonkdim': 16, 'kpack': 2
},
{
'M': 4864, 'N': 4096, 'K': 4104, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N':
256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu':
0, 'matrix_instr_nonkdim': 16, 'kpack': 2
},
{
'M': 4864, 'N': 4096, 'K': 4112, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N':
256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu':
0, 'matrix_instr_nonkdim': 16, 'kpack': 2
},
# 2 workgroups / CU
{
'M': 4864, 'N': 8192, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N':
256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu':
0, 'matrix_instr_nonkdim': 16, 'kpack': 2
},
{
'M': 4864, 'N': 8192, 'K': 4160, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N':
256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu':
0, 'matrix_instr_nonkdim': 16, 'kpack': 2
},
{
'M': 4864, 'N': 8192, 'K': 8192, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N':
256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu':
0, 'matrix_instr_nonkdim': 16, 'kpack': 2
},
{
'M': 4864, 'N': 8192, 'K': 8256, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N':
256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu':
0, 'matrix_instr_nonkdim': 16, 'kpack': 2
},
], ids=lambda val: f"Config: {val}")
def test_matmul_performance_regression(self, config, record_property):
# Get GPU ids
gpus = [0]
Expand Down Expand Up @@ -138,7 +142,6 @@ def test_matmul_performance_regression(self, config, record_property):
self.test_results.append({'config': config, 'tflops': float(tri_tflops)})

# Look for reference run

reference_run = None
for run in self.reference_data:
if run['config'] == config:
Expand All @@ -147,9 +150,18 @@ def test_matmul_performance_regression(self, config, record_property):

if reference_run is not None:
performance_ratio = tri_tflops / reference_run['tflops']
slowdown_threshold = 0.97
self.test_perf_ratios.append(performance_ratio)
regression_percent = (100.0 * (1.0 - performance_ratio))
record_property("Performance difference (lower is better)", f"{regression_percent:.2f}%")
assert performance_ratio > slowdown_threshold, f'Performance regressed by {regression_percent:.2f}% (threshold={((1.0 - slowdown_threshold) * 100.0 ):.2f}%)'
assert performance_ratio > self.slowdown_threshold, f'Performance regressed by {regression_percent:.2f}% (threshold={((1.0 - self.slowdown_threshold) * 100.0 ):.2f}%)'
else:
pytest.skip("No performance reference found!")

def test_overall_performance_difference(self, record_property):
    """Check the mean performance ratio over all recorded configs.

    Averages the ratios accumulated in ``self.test_perf_ratios`` (measured
    TFLOPS divided by reference TFLOPS, appended by
    ``test_matmul_performance_regression`` whenever a reference run exists),
    records the mean difference as a percentage, and fails when the mean
    ratio is not above ``self.slowdown_threshold``.

    Args:
        record_property: callable invoked as ``record_property(name, value)``
            to attach the overall percentage to the test report
            (presumably pytest's ``record_property`` fixture).
    """
    # With fewer than two ratios the mean adds nothing beyond the per-config
    # check, so skip. The message now matches the actual condition (>= 2).
    if len(self.test_perf_ratios) < 2:
        pytest.skip("Overall results will be tested if test count >= 2")

    perf_ratio_mean = statistics.mean(self.test_perf_ratios)
    # Positive percentage: slower than the reference; negative: faster.
    regression_percent = 100.0 * (1.0 - perf_ratio_mean)
    threshold_percent = (1.0 - self.slowdown_threshold) * 100.0

    record_property("Overall performance difference (mean)", f"{regression_percent:.2f}%")
    assert perf_ratio_mean > self.slowdown_threshold, (
        f'Performance regressed by {regression_percent:.2f}% (threshold={threshold_percent:.2f}%)')

0 comments on commit 76687cf

Please sign in to comment.