Skip to content

Commit

Permalink
Rowwise F8F8BF16 GEMMs - Add to Quantize Bench (pytorch#3220)
Browse files Browse the repository at this point in the history
Summary:

X-link: facebookresearch/FBGEMM#318

# Summary

- Add the F8F8BF16 Rowwise to Quantize Bench

Differential Revision: D63857729
  • Loading branch information
manishucsd authored and facebook-github-bot committed Oct 9, 2024
1 parent f14ca4d commit a2e31be
Showing 1 changed file with 55 additions and 1 deletion.
56 changes: 55 additions & 1 deletion fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,11 @@ def cuda(self) -> bool:
return True


####################################################################################################
# CUTLASS kernel v2
####################################################################################################


@register_quantize_op
class CutlassFP8TensorwiseGemm_v2(QuantizeOpBase):
"""
Expand All @@ -514,7 +518,12 @@ def quantize(self, x, w):
return xq, wq, x_scale, w_scale

def compute(self, xq, wq, x_scale, w_scale):
return torch.ops.cutlass_extensions.f8f8bf16(xq, wq, x_scale * w_scale)
if hasattr(torch.ops.cutlass_extensions, "f8f8bf16"):
return torch.ops.cutlass_extensions.f8f8bf16(xq, wq, x_scale * w_scale)
else:
raise RuntimeError(
"Skipping cutlass_extensions_v2 runs as it is not supported"
)

def quantize_and_compute(self, x, w):
xq, wq, x_scale, w_scale = self.quantize(x, w)
Expand All @@ -535,6 +544,51 @@ def cuda(self) -> bool:
return True


# CUTLASS kernel v2
@register_quantize_op
class CutlassFP8RowwiseGemm_v2(QuantizeOpBase):
"""
FP8 matmul with rowwise scaling.
"""

def quantize(self, x, w):
# Quantize both input tensors.
xq, x_scale = quantize_fp8_row(x)
wq, w_scale = quantize_fp8_row(w)
return xq, wq, x_scale, w_scale

def compute(self, xq, wq, x_scale, w_scale):
if hasattr(torch.ops.cutlass_extensions, "f8f8bf16_rowwise"):
return torch.ops.cutlass_extensions.f8f8bf16_rowwise(
xq, wq, x_scale, w_scale
)
else:
raise RuntimeError(
"Skipping cutlass_extensions_v2 runs as it is not supported"
)

def quantize_and_compute(self, x, w):
xq, wq, x_scale, w_scale = self.quantize(x, w)
return self.compute(xq, wq, x_scale, w_scale)

@property
def name(self) -> str:
return "cutlass_rowwise_v2"

@property
def hip(self) -> bool:
# Need to add support for better quantize kernel.
# Also may have an issue with cuda graphs.
return False

@property
def cuda(self) -> bool:
return True


####################################################################################################


@register_quantize_op
class F8I4RowwiseGemm(QuantizeOpBase):
"""
Expand Down

0 comments on commit a2e31be

Please sign in to comment.