Fix *tune_gemm* issue with (1, 1) bias tensors
brunomazzottiamd committed Aug 15, 2024
1 parent 4f75b0f · commit bd83f1b
Showing 1 changed file with 3 additions and 3 deletions.
python/perf-kernels/tune_gemm/tune_gemm.py: 3 additions & 3 deletions
@@ -329,7 +329,7 @@ def gen_rotating_tensors(M, N, K, dtype_a, need_Trans_a, dtype_b, need_Trans_b,
         c.append(out_c)
         if bias_size > 0:
             bs, bs_fp16 = gen_input(M, 1, dtype_b, need_Trans_b, 2, init_type, device='cuda')
-            bias.append(bs.squeeze())
+            bias.append(bs.squeeze(dim=1))

     in_outs = {"rotating_num": block_count, "input_a": a, "input_b": b, "output_c": c, "bias": bias}

@@ -369,8 +369,8 @@ def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type
     bias = None
     if use_bias:
         bias, bias_fp16 = gen_input(M, 1, dtype_b, col_b, 2, init_type, device='cuda')
-        bias = bias.squeeze()
-        bias_fp16 = bias.squeeze()
+        bias = bias.squeeze(dim=1)
+        bias_fp16 = bias.squeeze(dim=1)
     # Allocates output.
     c = torch.zeros((M, N), device=a.device, dtype=tl_to_torch_types[name_to_tl_types[dtype_c]])
     triton_output = matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages,
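Context for the change (my reading of the diff, not part of the original commit message): gen_input(M, 1, ...) presumably produces a bias of shape (M, 1). For M == 1 that is a (1, 1) tensor, and a bare Tensor.squeeze() removes every size-1 dimension, collapsing it to a 0-D scalar, while squeeze(dim=1) drops only the trailing dimension and always yields a 1-D bias of shape (M,). A minimal, self-contained sketch of the difference in plain PyTorch (the tensors here are illustrative, not taken from tune_gemm.py):

import torch

# The M == 1 case: a gen_input(M, 1, ...)-style bias of shape (1, 1).
bias = torch.randn(1, 1)

print(bias.squeeze().shape)       # torch.Size([])  -> 0-D scalar, the M dimension is gone
print(bias.squeeze(dim=1).shape)  # torch.Size([1]) -> 1-D tensor of shape (M,)

# For M > 1 both calls agree, so the problem only surfaces with (1, 1) bias tensors.
bias = torch.randn(4, 1)
assert bias.squeeze().shape == bias.squeeze(dim=1).shape == torch.Size([4])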
