Fix bias dtype issue for the TMA kernel (#3199)
Summary:
X-link: facebookresearch/FBGEMM#297

Pull Request resolved: #3199

Pass the bias dtype into the TMA kernel through a constexpr argument instead of hardcoding it to float32 inside the kernel.

Addressing https://fb.workplace.com/groups/fbgemmusers/permalink/8689681817779189/

Reviewed By: sijiac

Differential Revision: D63569991

fbshipit-source-id: 46b5621bb668493369b1752512eb9fe86a8340df
htyu authored and facebook-github-bot committed Oct 2, 2024
1 parent c50b5c8 commit 4a4d187
Showing 1 changed file with 29 additions and 7 deletions.
fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py (36 changes: 29 additions & 7 deletions)
@@ -58,6 +58,28 @@ def reinterpret_fp8_type(tensor: torch.Tensor, dtype: tl.dtype) -> TensorWrapper
     return tl_reinterpret(tensor, dtype=dtype)
 
 
+def map_dtype_to_triton(dtype: torch.dtype) -> tl.dtype:
+    """
+    Maps torch dtype to triton dtype.
+
+    Args:
+        dtype (torch.dtype): input dtype.
+
+    Returns:
+        tl.dtype: triton dtype.
+    """
+    if dtype == torch.float16:
+        return tl.float16
+    elif dtype == torch.bfloat16:
+        return tl.bfloat16
+    elif dtype == torch.float32:
+        return tl.float32
+    elif dtype == torch.int32:
+        return tl.int32
+    else:
+        raise ValueError(f"Unsupported dtype {dtype}")
+
+
 def init_to_zero(name):
     return lambda nargs: nargs[name].zero_()

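As a quick illustration of the new helper, here is a minimal sketch of how it behaves; the import path follows the file shown in this diff, and the assertions mirror the dispatch added above:

```python
import torch
import triton.language as tl

from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import map_dtype_to_triton

# Supported torch dtypes map one-to-one onto triton dtypes.
assert map_dtype_to_triton(torch.bfloat16) == tl.bfloat16
assert map_dtype_to_triton(torch.float32) == tl.float32

# Unsupported dtypes now raise instead of silently falling back
# (the old prep_matmul chain, removed below, defaulted everything
# else to tl.int32).
try:
    map_dtype_to_triton(torch.int8)
except ValueError as e:
    print(e)  # Unsupported dtype torch.int8
```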
@@ -746,6 +768,7 @@ def _kernel_matmul_fp8_row_tma_persistent(
     stride_cn,
     dot_out_dtype: tl.constexpr,
     c_dtype: tl.constexpr,
+    bias_dtype: tl.constexpr,
     allow_tf32: tl.constexpr,
     fp8_fast_accum: tl.constexpr,
     BLOCK_M: tl.constexpr,
@@ -813,7 +836,6 @@ def _kernel_matmul_fp8_row_tma_persistent(
 
     dtype_fp8 = tl.float8e4nv
     scale_dtype = tl.float32
-    bias_dtype = tl.float32
 
     for _ in range(0, k_tiles * tiles_per_SM):
         ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
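The deleted line above is where the bias dtype had been pinned to float32; it now arrives as the tl.constexpr parameter added to the kernel signature, so Triton compiles a separate specialization per dtype. A minimal sketch of that mechanism with a toy kernel (illustrative only, not the TMA kernel; the names here are made up):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _cast_copy_kernel(x_ptr, y_ptr, n, out_dtype: tl.constexpr, BLOCK: tl.constexpr):
    # out_dtype is a compile-time constant: each distinct value yields
    # its own specialized kernel, just as bias_dtype does above.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(y_ptr + offs, x.to(out_dtype), mask=mask)


x = torch.randn(1024, device="cuda")
y = torch.empty(1024, dtype=torch.bfloat16, device="cuda")
_cast_copy_kernel[(triton.cdiv(1024, 256),)](x, y, 1024, out_dtype=tl.bfloat16, BLOCK=256)
```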
@@ -1110,6 +1132,10 @@ def persistent_grid_tma(META):
     desc_b_scale = desc_helper.get_tma_descriptor_kernel_param("b_scale")
     desc_bias = desc_helper.get_tma_descriptor_kernel_param("bias")
 
+    bias_dtype_triton = None
+    if bias is not None:
+        bias_dtype_triton = map_dtype_to_triton(bias.dtype)
+
     # pyre-ignore
     torch._library.capture_triton(_kernel_matmul_fp8_row_tma_persistent)[
         persistent_grid_tma
@@ -1134,6 +1160,7 @@ def persistent_grid_tma(META):
         c.stride(1),
         dot_out_dtype=dot_out_dtype_triton,
         c_dtype=c_dtype_triton,
+        bias_dtype=bias_dtype_triton,
         allow_tf32=allow_tf32,
         fp8_fast_accum=fp8_fast_accum,
         GROUP_M=8,
@@ -1864,12 +1891,7 @@ def prep_matmul(
     assert isinstance(
         dot_out_dtype, torch.dtype
     ), f"dot_out_dtype type {type(dot_out_dtype)} must be a torch.dtype"
-    if dot_out_dtype == torch.bfloat16:
-        dot_out_dtype_triton = tl.bfloat16
-    elif dot_out_dtype == torch.float32:
-        dot_out_dtype_triton = tl.float32
-    else:
-        dot_out_dtype_triton = tl.int32
+    dot_out_dtype_triton = map_dtype_to_triton(dot_out_dtype)
 
     return M, N, K, m_key, n_key, k_key, c, c_dtype_triton, dot_out_dtype_triton, device
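End to end, a non-float32 bias should now reach the TMA persistent kernel with its real dtype. A hedged usage sketch follows; the matmul_fp8_row and quantize_fp8_row signatures below (including the bias and tma_persistent keywords) are assumptions about this file's public API, not verified here:

```python
import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    matmul_fp8_row,
    quantize_fp8_row,
)

M, N, K = 256, 512, 128
a, a_scale = quantize_fp8_row(torch.randn(M, K, device="cuda"))
b, b_scale = quantize_fp8_row(torch.randn(N, K, device="cuda"))

# A bfloat16 bias previously collided with the dtype hardcoded in the
# kernel; it is now forwarded through the bias_dtype constexpr.
bias = torch.randn(N, device="cuda", dtype=torch.bfloat16)
out = matmul_fp8_row(a, b, a_scale, b_scale, bias=bias, tma_persistent=True)
```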

