Consolidate ZeroPointDomain.NONE & None zero point domains #1556

Merged 7 commits on Jan 29, 2025
57 changes: 49 additions & 8 deletions test/integration/test_integration.py
@@ -45,6 +45,8 @@
     quantize_,
 )
 from torchao.quantization.quant_primitives import (
+    MappingType,
+    ZeroPointDomain,
     dequantize_affine,
 )
 from torchao.quantization.smoothquant import (
@@ -99,6 +101,10 @@
 
 COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
 
+ACT_MAPPING_TYPES = [MappingType.ASYMMETRIC, MappingType.SYMMETRIC]
+
+WEIGHT_ZERO_POINT_DOMAINS = [ZeroPointDomain.NONE, ZeroPointDomain.INT]
+
 COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()
 
 
@@ -118,9 +124,20 @@ def _int8wo_groupwise_api(mod):
     quantize_(mod, int8_weight_only(group_size=group_size), set_inductor_config=False)
 
 
-def _int8da_int8w_api(mod):
+def _int8da_int8w_api(
+    mod,
+    act_mapping_type=MappingType.SYMMETRIC,
+    weight_zero_point_domain=ZeroPointDomain.INT,
+):
     if TORCH_VERSION_AT_LEAST_2_4:
-        quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
+        quantize_(
+            mod,
+            int8_dynamic_activation_int8_weight(
+                act_mapping_type=act_mapping_type,
+                weight_zp_domain=weight_zero_point_domain,
+            ),
+            set_inductor_config=False,
+        )
         if not TORCH_VERSION_AT_LEAST_2_5:
             unwrap_tensor_subclass(mod)
     else:
@@ -959,25 +976,49 @@ def _test_lin_weight_subclass_api_impl(
         mod[0].weight.tensor_impl.get_plain()
 
         test = mod(x)
+
         self.assertGreater(
             SQNR(ref_f, test),
             min_sqnr,
-            f"{api.__name__} failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}",
+            f"API failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}",
         )
 
         mod_qc = torch.compile(mod, mode="max-autotune")
         test_comp = mod_qc(x)
         self.assertGreater(
             SQNR(ref_f, test_comp),
             min_sqnr,
-            f"{api.__name__} failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}",
+            f"API failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}",
         )
 
-    @parameterized.expand(COMMON_DEVICE_DTYPE)
-    def test_int8_dynamic_quant_subclass_api(self, device, dtype):
-        self._test_lin_weight_subclass_api_impl(
-            _int8da_int8w_api, device, 35, test_dtype=dtype
-        )
+    @parameterized.expand(
+        list(
+            itertools.product(
+                COMMON_DEVICES,
+                COMMON_DTYPES,
+                ACT_MAPPING_TYPES,
+                WEIGHT_ZERO_POINT_DOMAINS,
+            )
+        )
+    )
+    def test_int8_dynamic_quant_subclass_api(
+        self, device, dtype, act_mapping, weight_zero_point_domain
+    ):
+        from functools import partial
+
+        if (
+            not TORCH_VERSION_AT_LEAST_2_5
+            and dtype in (torch.float16, torch.bfloat16)
+            and act_mapping is MappingType.ASYMMETRIC
+            and device == "cpu"
+        ):
+            self.skipTest("Inductor-CPU codegen issue fixed in torch 2.5")
+        api = partial(
+            _int8da_int8w_api,
+            act_mapping_type=act_mapping,
+            weight_zero_point_domain=weight_zero_point_domain,
+        )
+        self._test_lin_weight_subclass_api_impl(api, device, 35, test_dtype=dtype)
 
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
15 changes: 7 additions & 8 deletions test/quantization/test_observer.py
@@ -21,6 +21,7 @@
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
+    ZeroPointDomain,
 )
 
 
@@ -74,7 +75,7 @@ def test_block_size_calc_success(self):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         example_inputs = [
             torch.randn(10, 2048),
@@ -93,7 +94,7 @@ def test_block_size_calc_success(self):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
        )
         for example_input in example_inputs:
             obs(example_input)
@@ -108,7 +109,7 @@ def test_block_size_row_errors(self):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         example_inputs = [
             torch.randn(10, 2048),
@@ -127,7 +128,7 @@ def test_block_size_row_errors(self):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         example_inputs = [
             torch.randn(10, 2048),
@@ -155,7 +156,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         if observe_weight:
             weight_observer = AffineQuantizedMinMaxObserver(
@@ -165,7 +166,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         else:
             weight_observer = None
@@ -199,7 +200,6 @@ def test_linear_observer_tensor(self, observe_weight: bool):
             input_scale.item(),
             max_val / max_fp8,
         )
-        self.assertIsNotNone(input_zero_point)

Review thread on the removed assertion:

jerryzh168 (Contributor): Is there a change of behavior when you change zero_point_domain from None to ZeroPointDomain.NONE?

sanchitintel (Contributor Author): Yes, input_zero_point would now be None. So, instead of removing that line, I now added self.assertIsNone(input_zero_point). Thanks!

jerryzh168 (Contributor): I see, so what was the meaning of zero_point_domain == None before?

sanchitintel (Contributor Author), Jan 17, 2025: It looks like some APIs/implementations were creating or expecting a None zero_point when zero_point_domain was ZeroPointDomain.NONE or None, while choose_qparams_affine was not.

sanchitintel (Contributor Author): Hi @jerryzh168, is it possible that some torchao users' code may expect a non-None zero_point with zero_point_domain ZeroPointDomain.NONE/None, making this change BC-breaking for them? Thanks!

drisspg (Contributor): I think most usages of this function are internal to torchao, so it's okay to BC-break. You can add the label just to be sure.

sanchitintel (Contributor Author): Thanks for your advice, @drisspg! Could you please add the label? GitHub isn't displaying an option for me to add it. Thanks!

jerryzh168 (Contributor): I added the bc-breaking label; please also write a bc-breaking note similar to #1049.

sanchitintel (Contributor Author): Thanks again, @jerryzh168! I added a note & rebased the PR.
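
To make the behavior change above concrete, here is a minimal sketch mirroring the test added in this PR (the shapes and qparams are illustrative):

import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    ZeroPointDomain,
    choose_qparams_affine,
)

x = torch.randn(10, 256)
scale, zero_point = choose_qparams_affine(
    x,
    MappingType.SYMMETRIC,
    (1, 128),      # block_size
    torch.int8,    # target dtype
    None,          # quant_min (derived from the target dtype)
    None,          # quant_max
    1e-6,          # eps
    scale_dtype=torch.float32,
    zero_point_dtype=torch.int64,
    preserve_zero=True,
    zero_point_domain=ZeroPointDomain.NONE,
)
assert zero_point is None  # before this PR, a mid-range zero_point tensor was returned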


         if observe_weight:
             weight_observer = linear.weight.weight_observer
@@ -210,7 +210,6 @@ def test_linear_observer_tensor(self, observe_weight: bool):
                 atol=5e-5,
                 rtol=0.0,
             )
-            self.assertIsNotNone(weight_zero_point)
         else:
             self.assertIsNone(linear.weight.weight_observer)
 
26 changes: 26 additions & 0 deletions test/quantization/test_quant_primitives.py
@@ -838,6 +838,32 @@ def test_fake_quantize_affine_cachemask(self):
         torch.testing.assert_close(dequantized, fake_quantized)
         torch.testing.assert_close(expected_mask, mask)
 
+    # ZeroPointDomain.NONE should work
+    def test_none_zero_point_domain(self):
+        input = torch.randn(10, 256)
+        mapping_type = MappingType.SYMMETRIC
+        dtype = torch.int8
+        block_size = (1, 128)
+        quant_min = None
+        quant_max = None
+        eps = 1e-6
+        scale_dtype = torch.float32
+        zero_point_dtype = torch.int64
+        _, zero_point = choose_qparams_affine(
+            input,
+            mapping_type,
+            block_size,
+            dtype,
+            quant_min,
+            quant_max,
+            eps,
+            scale_dtype=scale_dtype,
+            zero_point_dtype=zero_point_dtype,
+            preserve_zero=True,
+            zero_point_domain=ZeroPointDomain.NONE,
+        )
+        self.assertTrue(zero_point is None)
+
 
 if __name__ == "__main__":
     unittest.main()
7 changes: 3 additions & 4 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -261,8 +261,7 @@ def from_hp_to_intx(
             zero_point_domain,
         )
         # choose_qparams_affine is a custom op that does not support returning optional Tensors, so we set the zero_point to None here if its domain is NONE
-        # TODO should probably consolidate ZeroPointDomain.NONE and None
-        if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE:
+        if zero_point_domain == ZeroPointDomain.NONE:
             zero_point = None
         data = quantize_affine(
             input_float,
@@ -360,7 +359,7 @@ def from_hp_to_floatx(
             scale_dtype=scale_dtype,
             zero_point_dtype=None,
             preserve_zero=True,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
             _layout=_layout,
             use_hqq=False,
         )
@@ -387,7 +386,7 @@ def from_hp_to_floatx_static(
             target_dtype=target_dtype,
             quant_min=math.ceil(torch.finfo(target_dtype).min),
             quant_max=math.ceil(torch.finfo(target_dtype).max),
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
             _layout=_layout,
         )
     else:
8 changes: 5 additions & 3 deletions torchao/quantization/quant_api.py
@@ -387,7 +387,7 @@ def insert_observers_(
         eps=torch.finfo(torch.float32).eps,
         scale_dtype=torch.float,
         zero_point_dtype=torch.int,
-        zero_point_domain=None,
+        zero_point_domain=ZeroPointDomain.NONE,
     )
 
     # Create a linear module
@@ -688,7 +688,7 @@ def int4_weight_only(
     group_size=128,
     layout=TensorCoreTiledLayout(inner_k_tiles=8),
     use_hqq=False,
-    zero_point_domain=None,
+    zero_point_domain=ZeroPointDomain.NONE,
 ):
     """
     Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
@@ -731,7 +731,7 @@ def apply_int4_weight_only_quant(weight):
         assert (
             type(layout) in LAYOUT_TO_ZERO_POINT_DOMAIN.keys()
         ), f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}"
-        if zero_point_domain is None:
+        if zero_point_domain == ZeroPointDomain.NONE:
             # the first value is the default one
             zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)][0]
         else:
@@ -857,6 +857,7 @@ def int8_dynamic_activation_int8_weight(
     layout=PlainLayout(),
     act_mapping_type=MappingType.SYMMETRIC,
     weight_only_decode=False,
+    weight_zp_domain=ZeroPointDomain.NONE,
 ):
     """
     Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
@@ -901,6 +902,7 @@ def get_weight_block_size(x):
             eps=eps,
             zero_point_dtype=zero_point_dtype,
             _layout=layout,
+            zero_point_domain=weight_zp_domain,
         )
         weight = to_linear_activation_quantized(weight, input_quant_func)
         return weight
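
For reference, a minimal end-to-end sketch of the new weight_zp_domain argument, assuming a torchao build that includes this PR (the toy model and shapes are illustrative):

import torch
from torchao.quantization import quantize_
from torchao.quantization.quant_api import int8_dynamic_activation_int8_weight
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

model = torch.nn.Sequential(torch.nn.Linear(256, 512)).eval()
quantize_(
    model,
    int8_dynamic_activation_int8_weight(
        act_mapping_type=MappingType.ASYMMETRIC,
        weight_zp_domain=ZeroPointDomain.NONE,  # weight qparams carry no zero point
    ),
)
out = model(torch.randn(8, 256))

This is the same combination the updated _int8da_int8w_api test helper exercises.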
31 changes: 13 additions & 18 deletions torchao/quantization/quant_primitives.py
@@ -360,6 +360,7 @@ def _quantize_affine(
         zero_point,
         quant_min,
         quant_max,
+        output_dtype,
         zero_point_domain,
     ).to(output_dtype)
 
@@ -371,6 +372,7 @@ def _quantize_affine_no_dtype_cast(
     zero_point: Optional[torch.Tensor],
     quant_min: Union[int, float],
     quant_max: Union[int, float],
+    quant_dtype: Optional[torch.dtype],
     zero_point_domain: Optional[str] = ZeroPointDomain.INT.name,
 ) -> torch.Tensor:
     """
@@ -415,13 +417,12 @@ def _quantize_affine_no_dtype_cast(
         assert (
             zero_point is None
         ), "zero_point should be None when zero_point_domain is NONE"
-        quant = torch.clamp(torch.round(input * (1.0 / scale)), quant_min, quant_max)
-    elif zero_point_domain is None:
-        # This case handles quantization for float8 we expect no zero point and no zero point domain
-        assert (
-            zero_point is None
-        ), "zero_point should be None when zero_point_domain is None"
-        quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max)
+        if _is_float8_type(quant_dtype):
+            quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max)
+        else:
+            quant = torch.clamp(
+                torch.round(input * (1.0 / scale)), quant_min, quant_max
+            )
     else:
         assert zero_point_domain == ZeroPointDomain.FLOAT.name
         mid_point = (quant_max + quant_min + 1) / 2
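
An editorial note on the branch above: a float8 target skips rounding because the quantized dtype is itself floating point, while an integer target snaps values to integer steps. A toy comparison (values and clamp bounds are illustrative; +/-448 is the float8_e4m3fn range):

import torch

x = torch.tensor([0.7, -1.3])
scale = torch.tensor(0.5)
# int8-style target: scale, round, then clamp to the integer range
int_q = torch.clamp(torch.round(x * (1.0 / scale)), -128, 127)  # tensor([ 1., -3.])
# float8-style target: scale and clamp only; fractional values survive
fp8_q = torch.clamp(x * scale.reciprocal(), -448.0, 448.0)      # tensor([ 1.4000, -2.6000])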
@@ -564,16 +565,6 @@ def _dequantize_affine_no_dtype_check(
         ), "zero_point should be None when zero_point_domain is NONE"
         dequant = input.to(output_dtype)
         dequant = dequant * scale
-    elif zero_point_domain is None:
-        # This case handles dequantization for float8 we expect no zero point and no zero point domain
-        assert (
-            zero_point is None
-        ), "zero_point should be None when zero_point_domain is None"
-        assert _is_float8_type(
-            input.dtype
-        ), f"dequantiztion with no zero point domain is only supported with FP8 types, got {input.dtype}"
-        dequant = input.to(output_dtype)
-        dequant = dequant * scale
     else:
         assert (
             zero_point_domain == ZeroPointDomain.FLOAT.name
@@ -700,6 +691,7 @@ def _do_fake_quantize_affine(
         zero_point,
         quant_min,
         quant_max,
+        quant_dtype,
         zero_point_domain.name,
     )
     dq = _dequantize_affine_no_dtype_check(
@@ -927,8 +919,11 @@ def _choose_qparams_affine(
             raise ValueError(
                 "zero_point_domain should be ZeroPointDomain.INT or ZeroPointDomain.NONE for symmetric quantization"
             )
+        if zero_point_domain == ZeroPointDomain.NONE.name:
+            zero_point = None
+        else:
+            zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
         scale = torch.clamp(scale, min=eps)
-        zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
     else:
         assert mapping_type == MappingType.ASYMMETRIC.name
         scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
Loading