From d47adcede88ab79d590d0fdad27672137fd883ee Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Fri, 4 Oct 2024 11:18:51 -0500 Subject: [PATCH] Fix mfma cnt --- .../tools/tune_gemm/process_json.py | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/python/perf-kernels/tools/tune_gemm/process_json.py b/python/perf-kernels/tools/tune_gemm/process_json.py index c762da69ac79..aa2d478d7902 100644 --- a/python/perf-kernels/tools/tune_gemm/process_json.py +++ b/python/perf-kernels/tools/tune_gemm/process_json.py @@ -54,7 +54,8 @@ def gen_all_clk(code_fullname, trace_fullname): code_list = code_data['code'] found_1st_barrier = False - mfma_cnt = 0 + mfma_dsRead_cnt = 0 + mfma_cnt_total = 0 should_cnt = False ## Find the s_barriers for i in range(len(code_list)): @@ -68,10 +69,14 @@ def gen_all_clk(code_fullname, trace_fullname): ## This is barrier2 or barrier3 should_cnt = False if "mfma" in code_list[i][0] and should_cnt: - mfma_cnt += 1 + mfma_dsRead_cnt += 1 + if "mfma" in code_list[i][0]: + mfma_cnt_total += 1 - mfma_dsRead_cnt = mfma_cnt - mfma_dsWrite_cnt = 128 - mfma_cnt + ## /= 2 because the last iteration of local_load and tt.dot + ## is peeled off by stream-pipeliner + mfma_cnt_total /= 2 + mfma_dsWrite_cnt = mfma_cnt_total - mfma_dsRead_cnt if len(marker_barrier) != 3: print(f"Not 3 barriers?? Found {len(marker_barrier)}") @@ -121,7 +126,16 @@ def gen_all_clk(code_fullname, trace_fullname): if len1 == 0 or len2 == 0 or len3 == 0: incomplete = True - return firstInstr_clk, instrAfterBarrier1_clk, instrAfterBarrier2_clk, instrAfterBarrier3_clk, lastInstr_clk, mfma_dsRead_cnt, mfma_dsWrite_cnt, incomplete + #print(f"{firstInstr_clk}") + #print(f"{instrAfterBarrier1_clk}") + #print(f"{instrAfterBarrier2_clk}") + #print(f"{instrAfterBarrier3_clk}") + #print(f"{lastInstr_clk}") + #print(f"{mfma_dsRead_cnt}") + #print(f"{mfma_dsWrite_cnt}") + #print(f"{incomplete}") + + return firstInstr_clk, instrAfterBarrier1_clk, instrAfterBarrier2_clk, instrAfterBarrier3_clk, lastInstr_clk, mfma_dsRead_cnt, int(mfma_dsWrite_cnt), incomplete def gen_coarse_clk(instr0_clk, bar1_clk, bar3_clk, instr9_clk): @@ -225,6 +239,9 @@ def main(): continue trace_filename = f"se{se}_sm{sm}_sl{sl}_wv{wid}.json" trace_fullname = os.path.join(trace_dir, trace_filename) + if not os.path.isfile(trace_fullname): + #print(f"trace file not found {trace_fullname}") + return pro, loop, epi, iter_clk, lat1, lat2, lat_sum, idle1, idle2, incomplete = parse_trace(code_fullname, trace_fullname) if incomplete: continue