Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tune gemm v3.4] Add xcd-based pid remapping and change back to rocprofv1 #630

Merged
merged 5 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions python/perf-kernels/tune_gemm/README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# GEMM tuning script (current v3.3)
# GEMM tuning script (current v3.4)

## matmul kernel

The matmul kernel implementation can be found as [matmul_kernel.py](https://github.com/ROCm/triton/blob/main_perf/python/perf-kernels/tune_gemm/matmul_kernel.py), which includes the following features:
- XCD-based pid remapping
- grouping order of workgroup id, which is controlled by `GROUP_SIZE_M`, that
implements L2 cache optimization introduced in the [tutorial](https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html#l2-cache-optimizations).
- split-k algorithm, which is controlled by `SPLIT_K`.
Expand Down Expand Up @@ -144,7 +145,7 @@ The default value is 1000.

The general idea of the tuning script can be summarized as
- Compile all the kernels in the tuning space in parallel.
- Divide the tuning space into tasks and invoke `rocprofv2` once per
- Divide the tuning space into tasks and invoke `rocprof` once per
task. This will save invocation overhead of the profiler.
- Profile tasks in parallel on multiple GPUs.

Expand Down Expand Up @@ -309,6 +310,21 @@ places:
- Statically set `device` and `stream` in the [jit.py](https://github.com/triton-lang/triton/blob/fd691c67ac20958a67693358186d877790f5f48f/python/triton/runtime/jit.py#L588-L589)


# GEMM Tuning Script v3.4

## API changes

No API changes

## Implementation changes

- Now the matmul_kernel supports XCD-based pid remapping. Details with experiments
will be added later.
- Switched back to rocprofv1. Check [ticket#228](https://github.com/ROCm/triton-internal/issues/228) for more details.
- Improved the post-procesing logic to filter out the "spikes" in the profiling results.
- Reduced the number of iterations in both tuning and benchmark mode (120 and 200).


# One config running script

`one_config.py` is a script that runs one given matmul config.
Expand Down
16 changes: 14 additions & 2 deletions python/perf-kernels/tune_gemm/matmul_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,22 @@
def matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm,
stride_cn, stride_bias, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, SPLIT_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, BIAS: tl.constexpr,
EVEN_K: tl.constexpr):
EVEN_K: tl.constexpr, GRID_MN: tl.constexpr, NUM_XCDS: tl.constexpr):
pid = tl.program_id(axis=0)
pid_z = tl.program_id(1)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

if NUM_XCDS != 1:
## pid remapping on xcds
# Number of pids per XCD in the new arrangement
pids_per_xcd = GRID_MN // NUM_XCDS
# Compute current XCD and local pid within the XCD
xcd = pid % NUM_XCDS
local_pid = pid // NUM_XCDS
# Calculate new pid based on the new grouping
pid = xcd * pids_per_xcd + local_pid

if GROUP_SIZE_M == 1:
pid_m = pid // num_pid_n
pid_n = pid % num_pid_n
Expand All @@ -19,8 +30,9 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak,
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this change make an impact?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not really. This affects the swizzling in the very last group when M % GROUP_SIZE_M !=0, which is not usually in our tuning space.

pid_n = (pid % num_pid_in_group) // group_size_m

if SPLIT_K == 1:
offs_k = tl.arange(0, BLOCK_SIZE_K)
else:
Expand Down
39 changes: 16 additions & 23 deletions python/perf-kernels/tune_gemm/tune_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def get_full_tuning_space():
block_k_range = [16, 32, 64, 128, 256]
split_k_range = [1, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24]
num_warps_range = [1, 2, 4, 8]
group_m_range = [1, 4, 8, 16, 32]
group_m_range = [1, 2, 4, 8, 16, 32]
# For now we see better perf with num_stages=0 for all gemm configs we care
# But keep this explicit so that we do not forget we may need to set it to
# other values in the future
Expand Down Expand Up @@ -157,7 +157,7 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b):
if num_warps < 4:
continue
# check if tiling is integer multiple of GEMM size because we have no boundary check
if M % BLOCK_SIZE_M != 0 or N % BLOCK_SIZE_N != 0 or K % BLOCK_SIZE_K != 0:
if M % BLOCK_SIZE_M != 0 or N % BLOCK_SIZE_N != 0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

M, N could be irregular as well ?

continue

pruned_configs.append(config)
Expand All @@ -169,20 +169,15 @@ def need_split_k(SIZE_M, SIZE_N, SIZE_K):
return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024


def extract_kernel_time(M, N, K, config, df, bias_size):
# Correct the header by removing 'sig' and 'obj' to reduce number from 21 to 19
# once the bug(https://github.com/ROCm/rocprofiler/issues/144) fixed, we should
# not need below two lines
cols = [
'Index', 'KernelName', 'gpu-id', 'queue-id', 'queue-index', 'pid', 'tid', 'grd', 'wgr', 'lds', 'scr',
'arch_vgpr', 'accum_vgpr', 'sgpr', 'wave_size', 'DispatchNs', 'BeginNs', 'EndNs', 'CompleteNs'
]
df.columns = cols
def extract_kernel_time(M, N, K, config, df):
configStr = gen_configStr(config)
filtered_df = df[df['KernelName'].str.contains(configStr, na=False)].copy()
filtered_df['DurationNs'] = filtered_df['EndNs'] - filtered_df['BeginNs']
meanTime = filtered_df['DurationNs'].tail(100).mean()
return config, meanTime
df = df[df['KernelName'].str.contains(configStr)]

first_value = df['DurationNs'].iloc[0]
filtered_data = df['DurationNs'][df['DurationNs'] <= first_value]
new_meanTime = filtered_data.tail(100).mean()

return config, new_meanTime


def profile_batch_kernels(M, N, K, gpuid, gpus, jobs, verbose):
Expand All @@ -197,7 +192,7 @@ def profile_batch_kernels(M, N, K, gpuid, gpus, jobs, verbose):
if verbose:
print(f"profiling {kernel_name} on GPU {gpuid}")
run_bash_command_wrapper(
f"rocprofv2 --plugin file --plugin-version 1 --kernel-trace -o {jobId} python {get_filename_profile_driver(M, N, K, jobId)}",
f"rocprof --stats -o results_{jobId}.csv python {get_filename_profile_driver(M, N, K, jobId)}",
capture=(verbose < 2))
jobId += ngpus

Expand Down Expand Up @@ -244,13 +239,10 @@ def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type
thread_pool = multiprocessing.Pool(processes=num_threads)
tasks = []
idx = 0
df_prof = [
pd.read_csv(f"results_{i}.csv", skiprows=1, header=None, delimiter=',', quotechar='"', escapechar='\\')
for i in range(jobs)
]
df_prof = [pd.read_csv(f"results_{i}.csv") for i in range(jobs)]
for config in configs:
file_idx = idx % jobs
tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, df_prof[file_idx], bias_size))]
tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, df_prof[file_idx]))]
idx += 1
thread_pool.close()
thread_pool.join()
Expand Down Expand Up @@ -366,11 +358,12 @@ def matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps
grid = triton.cdiv(M, block_m) * triton.cdiv(N, block_n), split_k
stride_bias = bias.stride(0) if use_bias else 0
EVEN_K = K % block_k == 0
num_xcds = 1 if split_k > 1 else 8
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

XCD = 8 is only applicable to MI300X ?

matmul_kernel[grid](a, b, c, bias, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0),
c.stride(1), stride_bias=stride_bias, BLOCK_SIZE_M=block_m, BLOCK_SIZE_N=block_n,
BLOCK_SIZE_K=block_k, GROUP_SIZE_M=group_m, SPLIT_K=split_k, num_warps=num_warps,
num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=mfmaInstrSize,
kpack=kpack, BIAS=use_bias, EVEN_K=EVEN_K)
kpack=kpack, BIAS=use_bias, EVEN_K=EVEN_K, GRID_MN=grid[0], NUM_XCDS=num_xcds)
return c


Expand Down Expand Up @@ -441,7 +434,7 @@ def parse_args():
parser.add_argument("--num_threads", type=int, default=32,
help="number of threads to use for kernel compilation and post processing")
parser.add_argument("--jobs", type=int, default=1, help="number of tasks during the profiling process")
parser.add_argument("--iters", type=int, default=1000, help="number of iterations used in --benchmark mode")
parser.add_argument("--iters", type=int, default=200, help="number of iterations used in --benchmark mode")
parser.add_argument("--init_type", type=str, default='randn', choices=['randn', 'hpl', 'trig_float', 'zeros'],
help="Input tensor initialization (default normal distribution)")
parser.add_argument(
Expand Down
17 changes: 13 additions & 4 deletions python/perf-kernels/tune_gemm/utils/file_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ def gen_kernel_and_configStr_from_config(config, EVEN_K, dtype_a, dtype_b, dtype

use_bias = bias_size > 0

## Let's enable xcd-based pid remapping only when split-K is NOT used
## Also #xcd is fixed to 8. If we are tuning for MI308, please change it to 4
num_xcds = 1 if split_k > 1 else 8

if warmup:
torch_dtype_a = 'fp16'
torch_dtype_b = 'fp16'
Expand All @@ -89,6 +93,7 @@ def gen_kernel_and_configStr_from_config(config, EVEN_K, dtype_a, dtype_b, dtype

matmul_def_str = f"""
def matmul_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, biasn):
grid_mn = triton.cdiv(M, {block_m}) * triton.cdiv(N, {block_n})
matmul_kernel_{configStr}.warmup(
{torch_dtype_a}, {torch_dtype_b}, {torch_dtype_c}, {torch_dtype_c},
M, N, K,
Expand All @@ -103,8 +108,10 @@ def matmul_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, biasn):
waves_per_eu = {waves_per_eu},
matrix_instr_nonkdim = {mfmaInstrSize},
kpack = {kpack},
BIAS={use_bias},
EVEN_K={EVEN_K},
BIAS = {use_bias},
EVEN_K = {EVEN_K},
GRID_MN = grid_mn,
NUM_XCDS = {num_xcds},
grid=(1,),
)
return None
Expand Down Expand Up @@ -136,7 +143,9 @@ def matmul_{configStr}(a, b, c, bias, M, N, K, am, ak, bk, bn, cm, cn, biasn):
matrix_instr_nonkdim = {mfmaInstrSize},
kpack = {kpack},
BIAS = {use_bias},
EVEN_K = {EVEN_K}
EVEN_K = {EVEN_K},
GRID_MN = grid[0],
NUM_XCDS = {num_xcds}
)
return c
"""
Expand Down Expand Up @@ -310,7 +319,7 @@ def generate_profile_tasks(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, ini

# call all matmul_xxx functions
idx = 0
runs = iters if run_bench else 200
runs = iters if run_bench else 120
for config in configs:
configStr = gen_configStr(config)
matmul_call_str = f"""
Expand Down
Loading