From f07018bfbfa1649edaeb32bc491c9379c865e31e Mon Sep 17 00:00:00 2001 From: Manish Gupta Date: Fri, 4 Oct 2024 10:14:08 -0700 Subject: [PATCH] Rowwise F8F8BF16 GEMMs - Add to Quantize Bench (#3220) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/318 # Summary - Add the F8F8BF16 Rowwise to Quantize Bench Differential Revision: D63857729 --- .../experimental/gen_ai/bench/quantize_ops.py | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py index 65b34a956..8673bdb34 100644 --- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py @@ -500,7 +500,11 @@ def cuda(self) -> bool: return True +#################################################################################################### # CUTLASS kernel v2 +#################################################################################################### + + @register_quantize_op class CutlassFP8TensorwiseGemm_v2(QuantizeOpBase): """ @@ -535,6 +539,44 @@ def cuda(self) -> bool: return True +# CUTLASS kernel v2 +@register_quantize_op +class CutlassFP8RowwiseGemm_v2(QuantizeOpBase): + """ + FP8 matmul with rowwise scaling. + """ + + def quantize(self, x, w): + # Quantize both input tensors. + xq, x_scale = quantize_fp8_row(x) + wq, w_scale = quantize_fp8_row(w) + return xq, wq, x_scale, w_scale + + def compute(self, xq, wq, x_scale, w_scale): + return torch.ops.cutlass_extensions.f8f8bf16_rowwise(xq, wq, x_scale, w_scale) + + def quantize_and_compute(self, x, w): + xq, wq, x_scale, w_scale = self.quantize(x, w) + return self.compute(xq, wq, x_scale, w_scale) + + @property + def name(self) -> str: + return "cutlass_rowwise_v2" + + @property + def hip(self) -> bool: + # Need to add support for better quantize kernel. + # Also may have an issue with cuda graphs. + return False + + @property + def cuda(self) -> bool: + return True + + +#################################################################################################### + + @register_quantize_op class F8I4RowwiseGemm(QuantizeOpBase): """