From bd83f1bc64dfc8bfc0e3d17750ea8134f27750b6 Mon Sep 17 00:00:00 2001
From: Bruno Mazzotti
Date: Thu, 15 Aug 2024 19:52:24 +0000
Subject: [PATCH] Fix *tune_gemm* issue with (1, 1) bias tensors

---
 python/perf-kernels/tune_gemm/tune_gemm.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/perf-kernels/tune_gemm/tune_gemm.py b/python/perf-kernels/tune_gemm/tune_gemm.py
index 7ddcc929c5e4..ebf70a861dc5 100755
--- a/python/perf-kernels/tune_gemm/tune_gemm.py
+++ b/python/perf-kernels/tune_gemm/tune_gemm.py
@@ -329,7 +329,7 @@ def gen_rotating_tensors(M, N, K, dtype_a, need_Trans_a, dtype_b, need_Trans_b,
         c.append(out_c)
         if bias_size > 0:
             bs, bs_fp16 = gen_input(M, 1, dtype_b, need_Trans_b, 2, init_type, device='cuda')
-            bias.append(bs.squeeze())
+            bias.append(bs.squeeze(dim=1))
 
     in_outs = {"rotating_num": block_count, "input_a": a, "input_b": b, "output_c": c, "bias": bias}
 
@@ -369,8 +369,8 @@ def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type
     bias = None
     if use_bias:
         bias, bias_fp16 = gen_input(M, 1, dtype_b, col_b, 2, init_type, device='cuda')
-        bias = bias.squeeze()
-        bias_fp16 = bias.squeeze()
+        bias = bias.squeeze(dim=1)
+        bias_fp16 = bias.squeeze(dim=1)
     # Allocates output.
     c = torch.zeros((M, N), device=a.device, dtype=tl_to_torch_types[name_to_tl_types[dtype_c]])
     triton_output = matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages,
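
Note on the change (reviewer sketch, not part of the patch): torch.Tensor.squeeze() with no
arguments removes every dimension of size 1, so when M == 1 the (1, 1) bias produced by
gen_input(M, 1, ...) collapses into a 0-d scalar, while squeeze(dim=1) always yields a 1-D bias
of length M. A minimal PyTorch-only reproduction of the behavior, independent of tune_gemm:

import torch

# Stand-in for the (M, 1) bias returned by gen_input() when M == 1.
bias = torch.randn(1, 1)

print(bias.squeeze().shape)       # torch.Size([])  -> 0-d scalar, not a usable bias vector
print(bias.squeeze(dim=1).shape)  # torch.Size([1]) -> 1-D bias of length M, as intended

# For M > 1 both forms agree, which is why the issue only shows up with (1, 1) tensors.
bias_m = torch.randn(4, 1)
assert bias_m.squeeze().shape == bias_m.squeeze(dim=1).shape == torch.Size([4])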