Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tune gemm v3.4] Add xcd-based pid remapping and change back to rocprofv1 #630

Merged
merged 5 commits into from
Aug 19, 2024

Conversation

zhanglx13
Copy link

@zhanglx13 zhanglx13 commented Aug 17, 2024

This PR reverts #613 since there is a severe problem with rocprofv2 described in ticket#228.
The problem is that rocprofv2 will "miss" a lot of kernels in the tuning space. Therefore, sub-optimal config is picked.

We will switch back to rocprofv2 when the issue is resolved.

This PR also enabled xcd-based pid remapping. I need to run more experiments to understand the effects of xcd-based remapping and group_size_m (as described in ticket#229).
To disable xcd-based remapping, change this line from

if NUM_XCDS != 1:

to

if NUM_XCDS == 1:

- set --iters=200 as default. This is enough since the time is stable
after the first few runs.
- Filter out kernel time that is too large. We use the first kernel
time as the threshold. There must be something wrong with the kernel
if its elapsedTime is larger than the first run. We need to
investigate the reason. For now, just filter them out.
@@ -54,7 +54,7 @@ def get_full_tuning_space():
block_k_range = [16, 32, 64, 128, 256]
split_k_range = [1, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24]
num_warps_range = [1, 2, 4, 8]
group_m_range = [1, 4, 8, 16, 32]
group_m_range = [1, 2, 4, 8, 16]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any reason why we get rid of 2, and 32? If we can't enable XCD mapping, we may need 32.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me add XCD mapping and re-think about the range for group_m then.

@xiaohuguo2023
Copy link
Member

xiaohuguo2023 commented Aug 18, 2024

I think we also need enable irregular shapes tuning by removing below two lines

https://github.com/ROCm/triton/blob/e21d43cc414180bd9ccec121e164af2ae3faf290/python/perf-kernels/tune_gemm/tune_gemm.py#L160C13-L161C25

@zhanglx13 zhanglx13 changed the title [tune gemm] Change back to rocprofv1 [tune gemm] Add xcd-based pid remapping and change back to rocprofv1 Aug 18, 2024
@zhanglx13 zhanglx13 changed the title [tune gemm] Add xcd-based pid remapping and change back to rocprofv1 [tune gemm v3.4] Add xcd-based pid remapping and change back to rocprofv1 Aug 18, 2024
@@ -19,8 +30,9 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak,
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this change make an impact?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not really. This affects the swizzling in the very last group when M % GROUP_SIZE_M !=0, which is not usually in our tuning space.

@@ -157,7 +157,7 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b):
if num_warps < 4:
continue
# check if tiling is integer multiple of GEMM size because we have no boundary check
if M % BLOCK_SIZE_M != 0 or N % BLOCK_SIZE_N != 0 or K % BLOCK_SIZE_K != 0:
if M % BLOCK_SIZE_M != 0 or N % BLOCK_SIZE_N != 0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

M, N could be irregular as well ?

@@ -366,11 +358,12 @@ def matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps
grid = triton.cdiv(M, block_m) * triton.cdiv(N, block_n), split_k
stride_bias = bias.stride(0) if use_bias else 0
EVEN_K = K % block_k == 0
num_xcds = 1 if split_k > 1 else 8
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

XCD = 8 is only applicable to MI300X ?

@zhanglx13 zhanglx13 merged commit 15cb3a8 into main_perf Aug 19, 2024
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants