Skip to content

Commit

Permalink
Selectively disable yapf for parts of one_config.py
Browse files Browse the repository at this point in the history
  • Loading branch information
brunomazzottiamd committed Aug 16, 2024
1 parent abab401 commit 2cbf137
Showing 1 changed file with 30 additions and 7 deletions.
37 changes: 30 additions & 7 deletions python/perf-kernels/tune_gemm/one_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,23 @@ def parse_args():

def parse_config(cfg_str):
values = cfg_str.split("_")
# yapf: disable
config_name = {
"M": "M", "N": "N", "K": "K", "BM": "BLOCK_SIZE_M", "BN": "BLOCK_SIZE_N", "BK": "BLOCK_SIZE_K", "GM":
"GROUP_SIZE_M", "SK": "SPLIT_K", "nW": "num_warps", "nS": "num_stages", "EU": "waves_per_eu", "kP": "kpack",
"mfma": "matrix_instr_nonkdim"
"M": "M",
"N": "N",
"K": "K",
"BM": "BLOCK_SIZE_M",
"BN": "BLOCK_SIZE_N",
"BK": "BLOCK_SIZE_K",
"GM": "GROUP_SIZE_M",
"SK": "SPLIT_K",
"nW": "num_warps",
"nS": "num_stages",
"EU": "waves_per_eu",
"kP": "kpack",
"mfma": "matrix_instr_nonkdim",
}
# yapf: enable
config = {}
for val in values:
match = re.search("([a-zA-Z]*)([0-9]*)", val)
Expand All @@ -65,12 +77,23 @@ def main():
if args.config_str:
config = parse_config(args.config_str)
else:
# yapf: disable
config = {
"M": args.m, "N": args.n, "K": args.k, "BLOCK_SIZE_M": args.block_m, "BLOCK_SIZE_N": args.block_n,
"BLOCK_SIZE_K": args.block_k, "GROUP_SIZE_M": args.group_m, "SPLIT_K": args.split_k, "num_warps":
args.num_warps, "num_stages": args.num_stages, "waves_per_eu": args.waves_per_eu, "kpack": args.kpack,
"matrix_instr_nonkdim": args.matrix_instr_nonkdim
"M": args.m,
"N": args.n,
"K": args.k,
"BLOCK_SIZE_M": args.block_m,
"BLOCK_SIZE_N": args.block_n,
"BLOCK_SIZE_K": args.block_k,
"GROUP_SIZE_M": args.group_m,
"SPLIT_K": args.split_k,
"num_warps": args.num_warps,
"num_stages": args.num_stages,
"waves_per_eu": args.waves_per_eu,
"kpack": args.kpack,
"matrix_instr_nonkdim": args.matrix_instr_nonkdim,
}
# yapf: enable
tune_gemm.test_correctness(config["M"], config["N"], config["K"], args.col_a, args.col_b, args.dtype_a,
args.dtype_b, args.dtype_c, args.init_type, config, args.bias_vector, verbose=True)

Expand Down

0 comments on commit 2cbf137

Please sign in to comment.