diff --git a/python/perf-kernels/tune_gemm/one_config.py b/python/perf-kernels/tune_gemm/one_config.py index 9b745a4d5053..5354a270f493 100644 --- a/python/perf-kernels/tune_gemm/one_config.py +++ b/python/perf-kernels/tune_gemm/one_config.py @@ -46,11 +46,23 @@ def parse_args(): def parse_config(cfg_str): values = cfg_str.split("_") + # yapf: disable config_name = { - "M": "M", "N": "N", "K": "K", "BM": "BLOCK_SIZE_M", "BN": "BLOCK_SIZE_N", "BK": "BLOCK_SIZE_K", "GM": - "GROUP_SIZE_M", "SK": "SPLIT_K", "nW": "num_warps", "nS": "num_stages", "EU": "waves_per_eu", "kP": "kpack", - "mfma": "matrix_instr_nonkdim" + "M": "M", + "N": "N", + "K": "K", + "BM": "BLOCK_SIZE_M", + "BN": "BLOCK_SIZE_N", + "BK": "BLOCK_SIZE_K", + "GM": "GROUP_SIZE_M", + "SK": "SPLIT_K", + "nW": "num_warps", + "nS": "num_stages", + "EU": "waves_per_eu", + "kP": "kpack", + "mfma": "matrix_instr_nonkdim", } + # yapf: enable config = {} for val in values: match = re.search("([a-zA-Z]*)([0-9]*)", val) @@ -65,12 +77,23 @@ def main(): if args.config_str: config = parse_config(args.config_str) else: + # yapf: disable config = { - "M": args.m, "N": args.n, "K": args.k, "BLOCK_SIZE_M": args.block_m, "BLOCK_SIZE_N": args.block_n, - "BLOCK_SIZE_K": args.block_k, "GROUP_SIZE_M": args.group_m, "SPLIT_K": args.split_k, "num_warps": - args.num_warps, "num_stages": args.num_stages, "waves_per_eu": args.waves_per_eu, "kpack": args.kpack, - "matrix_instr_nonkdim": args.matrix_instr_nonkdim + "M": args.m, + "N": args.n, + "K": args.k, + "BLOCK_SIZE_M": args.block_m, + "BLOCK_SIZE_N": args.block_n, + "BLOCK_SIZE_K": args.block_k, + "GROUP_SIZE_M": args.group_m, + "SPLIT_K": args.split_k, + "num_warps": args.num_warps, + "num_stages": args.num_stages, + "waves_per_eu": args.waves_per_eu, + "kpack": args.kpack, + "matrix_instr_nonkdim": args.matrix_instr_nonkdim, } + # yapf: enable tune_gemm.test_correctness(config["M"], config["N"], config["K"], args.col_a, args.col_b, args.dtype_a, args.dtype_b, args.dtype_c, args.init_type, config, args.bias_vector, verbose=True)