add separate quantization primitives for float8 (#1597)
danielvegamyhre authored Jan 25, 2025
1 parent 6b472e5 commit 47f96f1
Showing 2 changed files with 137 additions and 0 deletions.
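
For orientation, a minimal round-trip sketch of the three primitives this commit adds (function names and signatures taken from the diff below; the example itself is illustrative only):

import torch

from torchao.quantization.quant_primitives import (
    choose_qparams_affine_float8,
    dequantize_affine_float8,
    quantize_affine_float8,
)

x = torch.randn(10, 10)

# Compute a single tensorwise scale, quantize to float8, then dequantize back.
scale = choose_qparams_affine_float8(x, float8_dtype=torch.float8_e4m3fn)
x_fp8 = quantize_affine_float8(x, scale, float8_dtype=torch.float8_e4m3fn)
x_hp = dequantize_affine_float8(x_fp8, scale, output_dtype=torch.float32)

# The round trip is exact only up to float8 rounding error.
print((x - x_hp).abs().max())
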
70 changes: 70 additions & 0 deletions test/quantization/test_quant_primitives.py
@@ -9,16 +9,21 @@
import unittest

import torch
from parameterized import parameterized

from torchao.dtypes.utils import is_device
from torchao.float8.float8_utils import EPS as float8_eps
from torchao.quantization.quant_primitives import (
    MappingType,
    ZeroPointDomain,
    choose_qparams_affine,
    choose_qparams_affine_float8,
    dequantize_affine,
    dequantize_affine_float8,
    fake_quantize_affine,
    fake_quantize_affine_cachemask,
    quantize_affine,
    quantize_affine_float8,
)

# TODO: remove test for utils?
@@ -838,6 +843,71 @@ def test_fake_quantize_affine_cachemask(self):
        torch.testing.assert_close(dequantized, fake_quantized)
        torch.testing.assert_close(expected_mask, mask)

    @parameterized.expand(
        [
            (
                torch.float32,
                torch.float8_e4m3fn,
            ),
            (
                torch.float32,
                torch.float8_e5m2,
            ),
            (
                torch.bfloat16,
                torch.float8_e4m3fn,
            ),
            (
                torch.bfloat16,
                torch.float8_e5m2,
            ),
        ]
    )
    def test_float8_quant_primitives(self, hp_dtype, float8_dtype):
        input = torch.randn(10, 10)

        # float8 quantization primitives
        scale = choose_qparams_affine_float8(input, float8_dtype=float8_dtype)
        quantized = quantize_affine_float8(input, scale, float8_dtype=float8_dtype)
        dequantized = dequantize_affine_float8(quantized, scale, output_dtype=hp_dtype)

        # reference implementation using generic primitives
        expected_scale, _ = choose_qparams_affine(
            input,
            MappingType.SYMMETRIC,
            input.shape,
            float8_dtype,
            eps=float8_eps,  # use same EPS as float8 training
            scale_dtype=torch.float32,
            quant_min=torch.finfo(float8_dtype).min,
            quant_max=torch.finfo(float8_dtype).max,
        )
        expected_quantized = quantize_affine(
            input,
            input.shape,
            scale,
            output_dtype=float8_dtype,
            quant_min=torch.finfo(float8_dtype).min,
            quant_max=torch.finfo(float8_dtype).max,
            zero_point=None,
            zero_point_domain=None,
        )
        expected_dequantized = dequantize_affine(
            expected_quantized,
            input.shape,
            scale,
            input_dtype=float8_dtype,
            output_dtype=hp_dtype,
            quant_min=torch.finfo(float8_dtype).min,
            quant_max=torch.finfo(float8_dtype).max,
            zero_point=None,
            zero_point_domain=None,
        )

        self.assertTrue(torch.equal(expected_scale, scale))
        torch.testing.assert_close(expected_quantized, quantized)
        torch.testing.assert_close(expected_dequantized, dequantized)


if __name__ == "__main__":
    unittest.main()
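
For reference, the quant_min/quant_max bounds used repeatedly in the test above come directly from torch.finfo of the float8 dtypes; a quick standalone check (illustrative, not part of the commit):

import torch

for float8_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    info = torch.finfo(float8_dtype)
    print(float8_dtype, info.min, info.max)
# torch.float8_e4m3fn -448.0 448.0
# torch.float8_e5m2 -57344.0 57344.0
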
67 changes: 67 additions & 0 deletions torchao/quantization/quant_primitives.py
@@ -39,6 +39,9 @@
    "MappingType",
    "ZeroPointDomain",
    "TorchAODType",
    "choose_qparams_affine_float8",
    "quantize_affine_float8",
    "dequantize_affine_float8",
]


@@ -1300,3 +1303,67 @@ def dequantize_affine_floatx(
    tensor = tensor * scale.float().view(-1, 1)
    tensor = tensor.to(dtype=output_dtype)
    return tensor


def choose_qparams_affine_float8(
    tensor: torch.Tensor,
    float8_dtype: torch.dtype = torch.float8_e4m3fn,
) -> torch.Tensor:
    """
    Calculates the float8 scaling factor for the given high-precision tensor, using tensorwise granularity.
    Args:
        tensor (torch.Tensor): Input tensor to be quantized.
        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
    """
    # only tensorwise scaling is supported for now
    quant_min, quant_max = torch.finfo(float8_dtype).min, torch.finfo(float8_dtype).max
    min_val_neg = torch.min(tensor)
    max_val_pos = torch.max(tensor)
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scale = max_val_pos / (float(quant_max - quant_min) / 2)
    return scale.to(dtype=torch.float32)


def quantize_affine_float8(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    float8_dtype: torch.dtype = torch.float8_e4m3fn,
) -> torch.Tensor:
    """
    Quantizes the high-precision floating point tensor to a float8 tensor, using the given scaling factor.
    Args:
        tensor (torch.Tensor): Input tensor to be quantized.
        scale (torch.Tensor): Scaling factor for the quantization.
        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
    """
    # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
    # upcast to `float32` to divide by the scale, since `scale` is an fp32 tensor in float8 quantization.
    # In order to match numerics between eager and compile, we upcast manually here.
    tensor_scaled = tensor.to(torch.float32) / scale
    max_value = torch.finfo(float8_dtype).max
    tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value)
    fp8_tensor = tensor_clamped.to(float8_dtype)
    return fp8_tensor


def dequantize_affine_float8(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Dequantizes the float8 tensor to a high-precision tensor.
    Args:
        tensor (torch.Tensor): Input float8 tensor to be dequantized.
        scale (torch.Tensor): Scaling factor for the dequantization.
        output_dtype (torch.dtype): Data type of the output tensor (e.g., torch.float32).
    """
    # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
    # upcast to `float32` to multiply with the scale, since `scale` is an fp32 tensor in float8 quantization.
    # In order to match numerics between eager and compile, we upcast manually here.
    fp8_tensor = tensor.to(torch.float32)
    hp_tensor = fp8_tensor * scale
    return hp_tensor.to(output_dtype)
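
To make the scale formula above concrete, a hand-worked example for torch.float8_e4m3fn with an assumed input amax of 3.0 (illustrative only):

import torch

# finfo(torch.float8_e4m3fn): min = -448.0, max = 448.0, so (quant_max - quant_min) / 2 == 448.0
amax = torch.tensor(3.0)
scale = amax / ((448.0 - (-448.0)) / 2)  # == 3.0 / 448 ≈ 0.0067
# quantize_affine_float8 stores 3.0 / scale == 448.0, the largest e4m3fn value, and
# dequantize_affine_float8 recovers 448.0 * scale == 3.0 (up to float8 rounding).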
