Skip to content

Commit

Permalink
2024-10-07 nightly release (e7f89e4)
Browse files: browse the repository at this point in the history
  • Loading branch information
pytorchbot committed Oct 7, 2024
1 parent ed75e6c commit 03f302c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
10 changes: 6 additions & 4 deletions fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,13 @@ def setUp(self) -> None:

def test_quantize_fp8_row(self) -> None:
def _test_quantize_fp8_row(
shape: Tuple[int, int],
shape: Tuple[int, ...],
use_triton: bool,
device: torch.device,
output_device: Optional[torch.device] = None,
use_scale_ub: bool = False,
) -> None:
M, K = shape
a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
a = torch.randn(shape, dtype=torch.bfloat16, device=device)

scale_ub = (
torch.tensor([1200], dtype=torch.float, device=device)
Expand All @@ -52,7 +51,8 @@ def _test_quantize_fp8_row(

# Undo scaling.
a_torch = a_fp8.to(torch.bfloat16)
a_torch *= a_scale[:, None]
broadcast_shape = list(a_torch.shape[:-1]) + [-1]
a_torch *= a_scale.view(broadcast_shape)

self.assertTrue(
torch.allclose(
Expand All @@ -61,6 +61,8 @@ def _test_quantize_fp8_row(
)

_test_quantize_fp8_row((2, 3), True, torch.device("cuda"))
# Test with batched input.
_test_quantize_fp8_row((4, 2, 3), True, torch.device("cuda"))
_test_quantize_fp8_row((2, 3), True, torch.device("cuda"), use_scale_ub=True)
_test_quantize_fp8_row((2, 3), False, torch.device("cpu"), torch.device("cuda"))
_test_quantize_fp8_row(
Expand Down
5 changes: 3 additions & 2 deletions fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2128,9 +2128,10 @@ def quantize_fp8_row_meta(
"""Shape function for torch compile."""
if output_device is None:
output_device = a.device
M, K = a.shape
# Flatten to 2D since each row of each potential batch gets a scale.
M = a.view(-1, a.shape[-1]).shape[0]
dtype = get_fp8_constants()[0]
fake_out = torch.empty((M, K), device=output_device, dtype=dtype)
fake_out = torch.empty(a.shape, device=output_device, dtype=dtype)
fake_scale = torch.empty((M), device=output_device, dtype=torch.float32)
return fake_out, fake_scale

Expand Down

0 comments on commit 03f302c

Please sign in to comment.