Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tune gemm v3.4] Add xcd-based pid remapping and change back to rocprofv1 #630

Merged
merged 5 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 18 additions & 21 deletions python/perf-kernels/tune_gemm/tune_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def get_full_tuning_space():
block_k_range = [16, 32, 64, 128, 256]
split_k_range = [1, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24]
num_warps_range = [1, 2, 4, 8]
group_m_range = [1, 4, 8, 16, 32]
group_m_range = [1, 2, 4, 8, 16]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason why we got rid of 2 and 32? If we can't enable XCD mapping, we may need 32.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me add XCD mapping and then rethink the range for group_m.

# For now we see better perf with num_stages=0 for all gemm configs we care
# But keep this explicit so that we do not forget we may need to set it to
# other values in the future
Expand Down Expand Up @@ -169,20 +169,20 @@ def need_split_k(SIZE_M, SIZE_N, SIZE_K):
return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024


def extract_kernel_time(M, N, K, config, df, bias_size):
# Correct the header by removing 'sig' and 'obj' to reduce number from 21 to 19
# once the bug(https://github.com/ROCm/rocprofiler/issues/144) fixed, we should
# not need below two lines
cols = [
'Index', 'KernelName', 'gpu-id', 'queue-id', 'queue-index', 'pid', 'tid', 'grd', 'wgr', 'lds', 'scr',
'arch_vgpr', 'accum_vgpr', 'sgpr', 'wave_size', 'DispatchNs', 'BeginNs', 'EndNs', 'CompleteNs'
]
df.columns = cols
def extract_kernel_time(M, N, K, config, df):
configStr = gen_configStr(config)
filtered_df = df[df['KernelName'].str.contains(configStr, na=False)].copy()
filtered_df['DurationNs'] = filtered_df['EndNs'] - filtered_df['BeginNs']
meanTime = filtered_df['DurationNs'].tail(100).mean()
return config, meanTime
df = df[df['KernelName'].str.contains(configStr)]
meanTime = df['DurationNs'].tail(100).mean()

first_value = df['DurationNs'].iloc[0]
filtered_data = df['DurationNs'][df['DurationNs'] <= first_value]
new_meanTime = filtered_data.tail(100).mean()

maxTime = df['DurationNs'].max()
maxTimeID = df['DurationNs'].idxmax()
#print(f"{maxTime=} {maxTimeID=} {meanTime=} {new_meanTime=} {first_value=}")
df['DurationNs'].to_csv(f"{M}-{N}-{K}.csv", index=False)
return config, new_meanTime


def profile_batch_kernels(M, N, K, gpuid, gpus, jobs, verbose):
Expand All @@ -197,7 +197,7 @@ def profile_batch_kernels(M, N, K, gpuid, gpus, jobs, verbose):
if verbose:
print(f"profiling {kernel_name} on GPU {gpuid}")
run_bash_command_wrapper(
f"rocprofv2 --plugin file --plugin-version 1 --kernel-trace -o {jobId} python {get_filename_profile_driver(M, N, K, jobId)}",
f"rocprof --stats -o results_{jobId}.csv python {get_filename_profile_driver(M, N, K, jobId)}",
capture=(verbose < 2))
jobId += ngpus

Expand Down Expand Up @@ -244,13 +244,10 @@ def tune_gemm_config(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type
thread_pool = multiprocessing.Pool(processes=num_threads)
tasks = []
idx = 0
df_prof = [
pd.read_csv(f"results_{i}.csv", skiprows=1, header=None, delimiter=',', quotechar='"', escapechar='\\')
for i in range(jobs)
]
df_prof = [pd.read_csv(f"results_{i}.csv") for i in range(jobs)]
for config in configs:
file_idx = idx % jobs
tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, df_prof[file_idx], bias_size))]
tasks += [thread_pool.apply_async(extract_kernel_time, args=(M, N, K, config, df_prof[file_idx]))]
idx += 1
thread_pool.close()
thread_pool.join()
Expand Down Expand Up @@ -441,7 +438,7 @@ def parse_args():
parser.add_argument("--num_threads", type=int, default=32,
help="number of threads to use for kernel compilation and post processing")
parser.add_argument("--jobs", type=int, default=1, help="number of tasks during the profiling process")
parser.add_argument("--iters", type=int, default=1000, help="number of iterations used in --benchmark mode")
parser.add_argument("--iters", type=int, default=200, help="number of iterations used in --benchmark mode")
parser.add_argument("--init_type", type=str, default='randn', choices=['randn', 'hpl', 'trig_float', 'zeros'],
help="Input tensor initialization (default normal distribution)")
parser.add_argument(
Expand Down
2 changes: 1 addition & 1 deletion python/perf-kernels/tune_gemm/utils/file_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def generate_profile_tasks(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, ini

# call all matmul_xxx functions
idx = 0
runs = iters if run_bench else 200
runs = iters if run_bench else 120
for config in configs:
configStr = gen_configStr(config)
matmul_call_str = f"""
Expand Down
Loading