From da2e9e0dcc5a9d21d2c97a0f85cdb3e1034900bf Mon Sep 17 00:00:00 2001 From: sanchitintel Date: Tue, 14 Jan 2025 22:47:45 -0800 Subject: [PATCH 1/6] Fix ZeroPointDomain.NONE support & make it default for da8w8 weights --- test/integration/test_integration.py | 57 +++++++++++++++++++--- test/quantization/test_observer.py | 15 +++--- test/quantization/test_quant_primitives.py | 26 ++++++++++ torchao/dtypes/affine_quantized_tensor.py | 7 ++- torchao/quantization/quant_api.py | 8 +-- torchao/quantization/quant_primitives.py | 31 +++++------- 6 files changed, 103 insertions(+), 41 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index bcd8af7ad3..7e9787f07f 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -45,6 +45,8 @@ quantize_, ) from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, dequantize_affine, ) from torchao.quantization.smoothquant import ( @@ -99,6 +101,10 @@ COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] +ACT_MAPPING_TYPES = [MappingType.ASYMMETRIC, MappingType.SYMMETRIC] + +WEIGHT_ZERO_POINT_DOMAINS = [ZeroPointDomain.NONE, ZeroPointDomain.INT] + COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy() @@ -118,9 +124,20 @@ def _int8wo_groupwise_api(mod): quantize_(mod, int8_weight_only(group_size=group_size), set_inductor_config=False) -def _int8da_int8w_api(mod): +def _int8da_int8w_api( + mod, + act_mapping_type=MappingType.SYMMETRIC, + weight_zero_point_domain=ZeroPointDomain.INT, +): if TORCH_VERSION_AT_LEAST_2_4: - quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False) + quantize_( + mod, + int8_dynamic_activation_int8_weight( + act_mapping_type=act_mapping_type, + weight_zp_domain=weight_zero_point_domain, + ), + set_inductor_config=False, + ) if not TORCH_VERSION_AT_LEAST_2_5: unwrap_tensor_subclass(mod) else: @@ -959,10 +976,11 @@ def _test_lin_weight_subclass_api_impl( mod[0].weight.tensor_impl.get_plain() test = mod(x) + self.assertGreater( SQNR(ref_f, test), min_sqnr, - f"{api.__name__} failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}", + f"API failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}", ) mod_qc = torch.compile(mod, mode="max-autotune") @@ -970,14 +988,37 @@ def _test_lin_weight_subclass_api_impl( self.assertGreater( SQNR(ref_f, test_comp), min_sqnr, - f"{api.__name__} failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}", + f"API failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}", ) - @parameterized.expand(COMMON_DEVICE_DTYPE) - def test_int8_dynamic_quant_subclass_api(self, device, dtype): - self._test_lin_weight_subclass_api_impl( - _int8da_int8w_api, device, 35, test_dtype=dtype + @parameterized.expand( + list( + itertools.product( + COMMON_DEVICES, + COMMON_DTYPES, + ACT_MAPPING_TYPES, + WEIGHT_ZERO_POINT_DOMAINS, + ) + ) + ) + def test_int8_dynamic_quant_subclass_api( + self, device, dtype, act_mapping, weight_zero_point_domain + ): + from functools import partial + + if ( + not TORCH_VERSION_AT_LEAST_2_5 + and dtype in (torch.float16, torch.bfloat16) + and act_mapping is MappingType.ASYMMETRIC + and device == "cpu" + ): + self.skipTest("Inductor-CPU codegen issue fixed in torch 2.5") + api = partial( + _int8da_int8w_api, + act_mapping_type=act_mapping, + weight_zero_point_domain=weight_zero_point_domain, ) + self._test_lin_weight_subclass_api_impl(api, device, 35, test_dtype=dtype) 
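For context, the configuration exercised by this new test reduces to the following minimal sketch. The model and tensor shapes are illustrative (not taken from the test), and the `weight_zp_domain` keyword is the one introduced by this commit; note that a later commit in this series removes the keyword again and hardcodes ZeroPointDomain.NONE for da8w8 weights:

    import torch
    from torchao.quantization import int8_dynamic_activation_int8_weight, quantize_
    from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

    # Toy model; the layer sizes are arbitrary illustration values.
    mod = torch.nn.Sequential(torch.nn.Linear(256, 512)).eval()
    x = torch.randn(16, 256)
    ref = mod(x)

    # Asymmetric per-token activations, symmetric per-channel weights with
    # no zero point (ZeroPointDomain.NONE), mirroring the new parametrization.
    quantize_(
        mod,
        int8_dynamic_activation_int8_weight(
            act_mapping_type=MappingType.ASYMMETRIC,
            weight_zp_domain=ZeroPointDomain.NONE,
        ),
        set_inductor_config=False,
    )
    test = mod(x)  # the test then checks SQNR(ref, test) against a threshold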
@parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(is_fbcode(), "broken in fbcode") diff --git a/test/quantization/test_observer.py b/test/quantization/test_observer.py index 0526ee01b2..8ec15eb201 100644 --- a/test/quantization/test_observer.py +++ b/test/quantization/test_observer.py @@ -21,6 +21,7 @@ ) from torchao.quantization.quant_primitives import ( MappingType, + ZeroPointDomain, ) @@ -74,7 +75,7 @@ def test_block_size_calc_success(self): eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) example_inputs = [ torch.randn(10, 2048), @@ -93,7 +94,7 @@ def test_block_size_calc_success(self): eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) for example_input in example_inputs: obs(example_input) @@ -108,7 +109,7 @@ def test_block_size_row_errors(self): eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) example_inputs = [ torch.randn(10, 2048), @@ -127,7 +128,7 @@ def test_block_size_row_errors(self): eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) example_inputs = [ torch.randn(10, 2048), @@ -155,7 +156,7 @@ def test_linear_observer_tensor(self, observe_weight: bool): eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) if observe_weight: weight_observer = AffineQuantizedMinMaxObserver( @@ -165,7 +166,7 @@ def test_linear_observer_tensor(self, observe_weight: bool): eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) else: weight_observer = None @@ -199,7 +200,6 @@ def test_linear_observer_tensor(self, observe_weight: bool): input_scale.item(), max_val / max_fp8, ) - self.assertIsNotNone(input_zero_point) if observe_weight: weight_observer = linear.weight.weight_observer @@ -210,7 +210,6 @@ def test_linear_observer_tensor(self, observe_weight: bool): atol=5e-5, rtol=0.0, ) - self.assertIsNotNone(weight_zero_point) else: self.assertIsNone(linear.weight.weight_observer) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 102e76cb1a..00fe300864 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -838,6 +838,32 @@ def test_fake_quantize_affine_cachemask(self): torch.testing.assert_close(dequantized, fake_quantized) torch.testing.assert_close(expected_mask, mask) + # ZeroPointDomain.NONE should work + def test_none_zero_point_domain(self): + input = torch.randn(10, 256) + mapping_type = MappingType.SYMMETRIC + dtype = torch.int8 + block_size = (1, 128) + quant_min = None + quant_max = None + eps = 1e-6 + scale_dtype = torch.float32 + zero_point_dtype = torch.int64 + _, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + quant_min, + quant_max, + eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + preserve_zero=True, + zero_point_domain=ZeroPointDomain.NONE, + ) + self.assertTrue(zero_point is None) + if __name__ == "__main__": unittest.main() diff --git 
a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index e7aca34c5f..9d3da97810 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -261,8 +261,7 @@ def from_hp_to_intx( zero_point_domain, ) # choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None - # TODO should probably consolidate ZeroPointDomain.NONE and None - if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE: + if zero_point_domain == ZeroPointDomain.NONE: zero_point = None data = quantize_affine( input_float, @@ -360,7 +359,7 @@ def from_hp_to_floatx( scale_dtype=scale_dtype, zero_point_dtype=None, preserve_zero=True, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, _layout=_layout, use_hqq=False, ) @@ -387,7 +386,7 @@ def from_hp_to_floatx_static( target_dtype=target_dtype, quant_min=math.ceil(torch.finfo(target_dtype).min), quant_max=math.ceil(torch.finfo(target_dtype).max), - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, _layout=_layout, ) else: diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index b2eff196fd..184b96334c 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -387,7 +387,7 @@ def insert_observers_( eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) # Create a linear module @@ -688,7 +688,7 @@ def int4_weight_only( group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8), use_hqq=False, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ): """ Applies uint4 weight-only asymmetric per-group quantization to linear layers, using @@ -731,7 +731,7 @@ def apply_int4_weight_only_quant(weight): assert ( type(layout) in LAYOUT_TO_ZERO_POINT_DOMAIN.keys() ), f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}" - if zero_point_domain is None: + if zero_point_domain == ZeroPointDomain.NONE: # the first value is the default one zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)][0] else: @@ -857,6 +857,7 @@ def int8_dynamic_activation_int8_weight( layout=PlainLayout(), act_mapping_type=MappingType.SYMMETRIC, weight_only_decode=False, + weight_zp_domain=ZeroPointDomain.NONE, ): """ Applies int8 dynamic symmetric per-token activation and int8 per-channel weight @@ -901,6 +902,7 @@ def get_weight_block_size(x): eps=eps, zero_point_dtype=zero_point_dtype, _layout=layout, + zero_point_domain=weight_zp_domain, ) weight = to_linear_activation_quantized(weight, input_quant_func) return weight diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index e587d4bc2b..61b508bdc0 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -360,6 +360,7 @@ def _quantize_affine( zero_point, quant_min, quant_max, + output_dtype, zero_point_domain, ).to(output_dtype) @@ -371,6 +372,7 @@ def _quantize_affine_no_dtype_cast( zero_point: Optional[torch.Tensor], quant_min: Union[int, float], quant_max: Union[int, float], + quant_dtype: Optional[torch.dtype], zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, ) -> torch.Tensor: """ @@ -415,13 +417,12 @@ def _quantize_affine_no_dtype_cast( assert ( zero_point is None ), "zero_point should be None when zero_point_domain is NONE" - quant = 
torch.clamp(torch.round(input * (1.0 / scale)), quant_min, quant_max) - elif zero_point_domain is None: - # This case handles quantization for float8 we expect no zero point and no zero point domain - assert ( - zero_point is None - ), "zero_point should be None when zero_point_domain is None" - quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max) + if _is_float8_type(quant_dtype): + quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max) + else: + quant = torch.clamp( + torch.round(input * (1.0 / scale)), quant_min, quant_max + ) else: assert zero_point_domain == ZeroPointDomain.FLOAT.name mid_point = (quant_max + quant_min + 1) / 2 @@ -564,16 +565,6 @@ def _dequantize_affine_no_dtype_check( ), "zero_point should be None when zero_point_domain is NONE" dequant = input.to(output_dtype) dequant = dequant * scale - elif zero_point_domain is None: - # This case handles dequantization for float8 we expect no zero point and no zero point domain - assert ( - zero_point is None - ), "zero_point should be None when zero_point_domain is None" - assert _is_float8_type( - input.dtype - ), f"dequantiztion with no zero point domain is only supported with FP8 types, got {input.dtype}" - dequant = input.to(output_dtype) - dequant = dequant * scale else: assert ( zero_point_domain == ZeroPointDomain.FLOAT.name @@ -700,6 +691,7 @@ def _do_fake_quantize_affine( zero_point, quant_min, quant_max, + quant_dtype, zero_point_domain.name, ) dq = _dequantize_affine_no_dtype_check( @@ -927,8 +919,11 @@ def _choose_qparams_affine( raise ValueError( "zero_point_domain should be ZeroPointDomain.INT or ZeroPointDomain.NONE for symmetric quantization" ) + if zero_point_domain == ZeroPointDomain.NONE.name: + zero_point = None + else: + zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2)) scale = torch.clamp(scale, min=eps) - zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2)) else: assert mapping_type == MappingType.ASYMMETRIC.name scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) From bbc8dcd8c0cc5973936618564d8a59514fb87b12 Mon Sep 17 00:00:00 2001 From: sanchitintel Date: Wed, 15 Jan 2025 11:14:24 -0800 Subject: [PATCH 2/6] Fix bug & apply review recommendations --- test/integration/test_integration.py | 3 +-- torchao/quantization/quant_primitives.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 7e9787f07f..4d39c4d9ae 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -10,6 +10,7 @@ import logging import os import unittest +from functools import partial import torch import torch.nn as nn @@ -1004,8 +1005,6 @@ def _test_lin_weight_subclass_api_impl( def test_int8_dynamic_quant_subclass_api( self, device, dtype, act_mapping, weight_zero_point_domain ): - from functools import partial - if ( not TORCH_VERSION_AT_LEAST_2_5 and dtype in (torch.float16, torch.bfloat16) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 61b508bdc0..949afc968f 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -372,7 +372,7 @@ def _quantize_affine_no_dtype_cast( zero_point: Optional[torch.Tensor], quant_min: Union[int, float], quant_max: Union[int, float], - quant_dtype: Optional[torch.dtype], + quant_dtype: torch.dtype, zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, ) -> torch.Tensor: """ 
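For reference, with ZeroPointDomain.NONE the affine scheme degenerates to pure scale quantization, q = clamp(round(x / scale), quant_min, quant_max), and dequantization is simply q * scale, with no zero-point tensor anywhere. A minimal round-trip sketch through the public primitives patched above (block size and shapes are illustrative choices):

    import torch
    from torchao.quantization.quant_primitives import (
        MappingType,
        ZeroPointDomain,
        choose_qparams_affine,
        dequantize_affine,
        quantize_affine,
    )

    x = torch.randn(10, 256)
    block_size = (1, 256)  # one scale per row; illustrative choice

    scale, zero_point = choose_qparams_affine(
        x,
        MappingType.SYMMETRIC,
        block_size,
        torch.int8,
        zero_point_domain=ZeroPointDomain.NONE,
    )
    assert zero_point is None  # the behavior this series fixes and then enforces

    q = quantize_affine(
        x,
        block_size,
        scale,
        zero_point,
        torch.int8,
        zero_point_domain=ZeroPointDomain.NONE,
    )
    x_hat = dequantize_affine(
        q,
        block_size,
        scale,
        zero_point,
        torch.int8,
        zero_point_domain=ZeroPointDomain.NONE,
    )
    # x_hat equals q.to(torch.float32) * scale, broadcast over each block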
From 4956a2ea77aebdf87f62a3f5c310dc40a64c76b3 Mon Sep 17 00:00:00 2001 From: sanchitintel Date: Fri, 17 Jan 2025 12:49:32 -0800 Subject: [PATCH 3/6] Throw exceptions when None zero_point_domain is used --- test/quantization/test_observer.py | 2 + test/quantization/test_quant_primitives.py | 53 +++++++++++++------ torchao/dtypes/affine_quantized_tensor.py | 11 ++-- torchao/dtypes/uintx/marlin_qqq_tensor.py | 4 +- torchao/quantization/observer.py | 5 +- .../qat/affine_fake_quantized_tensor.py | 5 ++ torchao/quantization/qat/api.py | 2 + torchao/quantization/quant_primitives.py | 35 +++++++----- 8 files changed, 81 insertions(+), 36 deletions(-) diff --git a/test/quantization/test_observer.py b/test/quantization/test_observer.py index 8ec15eb201..4567f3baef 100644 --- a/test/quantization/test_observer.py +++ b/test/quantization/test_observer.py @@ -200,6 +200,7 @@ def test_linear_observer_tensor(self, observe_weight: bool): input_scale.item(), max_val / max_fp8, ) + self.assertIsNone(input_zero_point) if observe_weight: weight_observer = linear.weight.weight_observer @@ -210,6 +211,7 @@ def test_linear_observer_tensor(self, observe_weight: bool): atol=5e-5, rtol=0.0, ) + self.assertIsNone(weight_zero_point) else: self.assertIsNone(linear.weight.weight_observer) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 00fe300864..9a97218077 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -838,8 +838,8 @@ def test_fake_quantize_affine_cachemask(self): torch.testing.assert_close(dequantized, fake_quantized) torch.testing.assert_close(expected_mask, mask) - # ZeroPointDomain.NONE should work def test_none_zero_point_domain(self): + """A None value for a ZeroPointDomain should not work, but ZeroPointDomain.NONE should""" input = torch.randn(10, 256) mapping_type = MappingType.SYMMETRIC dtype = torch.int8 @@ -849,20 +849,43 @@ def test_none_zero_point_domain(self): eps = 1e-6 scale_dtype = torch.float32 zero_point_dtype = torch.int64 - _, zero_point = choose_qparams_affine( - input, - mapping_type, - block_size, - dtype, - quant_min, - quant_max, - eps, - scale_dtype=scale_dtype, - zero_point_dtype=zero_point_dtype, - preserve_zero=True, - zero_point_domain=ZeroPointDomain.NONE, - ) - self.assertTrue(zero_point is None) + try: + _, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + quant_min, + quant_max, + eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + preserve_zero=True, + zero_point_domain=None, + ) + except ValueError: + # This exception was expected + # Now test for ZeroPointDomain.NONE + _, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + quant_min, + quant_max, + eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + preserve_zero=True, + zero_point_domain=ZeroPointDomain.NONE, + ) + self.assertTrue(zero_point is None) + else: + # An exception should have been thrown for zero_point_domain None + self.assertTrue( + False, + msg="A runtime exception should have been thrown for zero_point_domain None", + ) if __name__ == "__main__": diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 9d3da97810..8bb061ecb5 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -85,6 +85,8 @@ def __new__( dtype=None, strides=None, ): + if zero_point_domain is None: + raise ValueError("please 
use ZeroPointDomain.NONE instead of None") kwargs = {} kwargs["device"] = tensor_impl.device kwargs["layout"] = ( @@ -203,7 +205,7 @@ def from_hp_to_intx( scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, - zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, _layout: Layout = PlainLayout(), use_hqq: bool = False, ): @@ -298,13 +300,12 @@ def from_hp_to_intx_static( target_dtype: torch.dtype, quant_min: Optional[int] = None, quant_max: Optional[int] = None, - zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, _layout: Layout = PlainLayout(), ): + if zero_point_domain is None: + raise ValueError("please use ZeroPointDomain.NONE instead of None") if target_dtype not in FP8_TYPES: - assert ( - zero_point_domain is not None - ), "zero_point_domain must be specified for non-fp8 types" assert ( zero_point is not None ), "zero_point must be specified for non-fp8 types" diff --git a/torchao/dtypes/uintx/marlin_qqq_tensor.py b/torchao/dtypes/uintx/marlin_qqq_tensor.py index b75d959b41..bd28b7123d 100644 --- a/torchao/dtypes/uintx/marlin_qqq_tensor.py +++ b/torchao/dtypes/uintx/marlin_qqq_tensor.py @@ -55,9 +55,11 @@ def from_hp_to_intx( block_size: Tuple[int, ...], quant_min: Optional[int] = None, quant_max: Optional[int] = None, - zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, _layout: Optional[Layout] = None, ): + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") original_shape = input_float.shape input_float = _layout.pre_process(input_float) nbits = int(math.log2(quant_max - quant_min + 1)) diff --git a/torchao/quantization/observer.py b/torchao/quantization/observer.py index 06509c7b91..cbbe1b581d 100644 --- a/torchao/quantization/observer.py +++ b/torchao/quantization/observer.py @@ -104,11 +104,12 @@ def __init__( scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, - zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ): super().__init__() assert granularity is not None, "granularity is None" - + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") self.mapping_type = mapping_type self.target_dtype = target_dtype self.granularity = granularity diff --git a/torchao/quantization/qat/affine_fake_quantized_tensor.py b/torchao/quantization/qat/affine_fake_quantized_tensor.py index b84200ac9c..f60c858b73 100644 --- a/torchao/quantization/qat/affine_fake_quantized_tensor.py +++ b/torchao/quantization/qat/affine_fake_quantized_tensor.py @@ -41,6 +41,9 @@ def forward( preserve_zero: bool = True, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> "AffineFakeQuantizedTensor": + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") + def apply_fake_quant_fn(t: torch.Tensor): assert isinstance(t, AffineFakeQuantizedTensor) qmin, qmax = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max) @@ -158,6 +161,8 @@ def from_float( preserve_zero: bool = True, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ): + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of 
None") return _ToAffineFakeQuantized.apply( original_input, mapping_type, diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py index cd3813291f..925a0eed3c 100644 --- a/torchao/quantization/qat/api.py +++ b/torchao/quantization/qat/api.py @@ -96,6 +96,8 @@ def __init__( group_size: Optional[int] = None, is_symmetric: Optional[bool] = None, ): + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") self.dtype = dtype self.granularity = self._get_granularity(granularity, group_size) self.mapping_type = self._get_mapping_type(mapping_type, is_symmetric) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 949afc968f..12ec2360f9 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -281,7 +281,7 @@ def quantize_affine( output_dtype: torch.dtype, quant_min: Optional[Union[int, float]] = None, quant_max: Optional[Union[int, float]] = None, - zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> torch.Tensor: """ Args: @@ -316,6 +316,8 @@ def quantize_affine( Output: quantized tensor with requested dtype """ + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") return _quantize_affine( input, block_size, @@ -324,7 +326,7 @@ def quantize_affine( output_dtype, quant_min, quant_max, - zero_point_domain.name if zero_point_domain is not None else None, + zero_point_domain.name, ) @@ -337,7 +339,7 @@ def _quantize_affine( output_dtype: torch.dtype, quant_min: Optional[Union[int, float, bool]] = None, quant_max: Optional[Union[int, float, bool]] = None, - zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, + zero_point_domain: str = ZeroPointDomain.INT.name, ) -> torch.Tensor: """op definition that has compatible signatures with custom op library @@ -373,7 +375,7 @@ def _quantize_affine_no_dtype_cast( quant_min: Union[int, float], quant_max: Union[int, float], quant_dtype: torch.dtype, - zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, + zero_point_domain: str = ZeroPointDomain.INT.name, ) -> torch.Tensor: """ The op does the following: @@ -468,6 +470,8 @@ def dequantize_affine( Output: dequantized Tensor, with requested dtype or fp32 """ + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") return _dequantize_affine( input, block_size, @@ -476,7 +480,7 @@ def dequantize_affine( input_dtype, quant_min, quant_max, - zero_point_domain.name if zero_point_domain is not None else None, + zero_point_domain.name, output_dtype=output_dtype, ) @@ -612,6 +616,8 @@ def fake_quantize_affine( value during quantization default is ZeroPointDomain.INT """ + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") (_, fq) = _do_fake_quantize_affine( input, block_size, @@ -654,6 +660,8 @@ def fake_quantize_affine_cachemask( ) """ + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") (q, dq) = _do_fake_quantize_affine( input, block_size, @@ -719,7 +727,7 @@ def choose_qparams_affine( scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, - zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> Tuple[torch.Tensor, torch.Tensor]: """ 
Args: @@ -753,6 +761,8 @@ def choose_qparams_affine( Output: Tuple of scales and zero_points Tensor with requested dtype """ + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") return _choose_qparams_affine( input, mapping_type.name, @@ -764,7 +774,7 @@ def choose_qparams_affine( scale_dtype, zero_point_dtype, preserve_zero, - zero_point_domain.name if zero_point_domain is not None else None, + zero_point_domain.name, ) @@ -780,7 +790,7 @@ def choose_qparams_affine_with_min_max( scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, - zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> Tuple[torch.Tensor, torch.Tensor]: """A variant of :func:`~torchao.quantization.quant_primitives.choose_qparams_affine` operator that pass in min_val and max_val directly instead of deriving these from a single input. @@ -792,6 +802,8 @@ def choose_qparams_affine_with_min_max( difference: instead of passing in `input` Tensor and use that to calculate min_val/max_val and then scale/zero_point, we pass in min_val/max_val directly """ + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") return _choose_qparams_affine( None, mapping_type.name, @@ -803,7 +815,7 @@ def choose_qparams_affine_with_min_max( scale_dtype, zero_point_dtype, preserve_zero, - zero_point_domain.name if zero_point_domain is not None else None, + zero_point_domain.name, min_val, max_val, ) @@ -910,10 +922,7 @@ def _choose_qparams_affine( raise ValueError( "preserve_zero == False is not supported for symmetric quantization" ) - if ( - zero_point_domain is not None - and zero_point_domain == ZeroPointDomain.FLOAT.name - ): + if zero_point_domain == ZeroPointDomain.FLOAT.name: # TODO INT should not be a valid ZeroPointDomain for symmetric quantization since # symmetric quant doesn't have a zero_point raise ValueError( From 8116c0cf27fea79cec907c3de75dd3af8a0d3286 Mon Sep 17 00:00:00 2001 From: sanchitintel Date: Fri, 17 Jan 2025 13:23:15 -0800 Subject: [PATCH 4/6] Use ZeroPointDomain.NONE for weight in int8_dynamic_activation_int8_weight --- test/integration/test_integration.py | 11 +---------- torchao/quantization/quant_api.py | 4 ++-- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 4d39c4d9ae..73f57b940b 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -47,7 +47,6 @@ ) from torchao.quantization.quant_primitives import ( MappingType, - ZeroPointDomain, dequantize_affine, ) from torchao.quantization.smoothquant import ( @@ -104,8 +103,6 @@ ACT_MAPPING_TYPES = [MappingType.ASYMMETRIC, MappingType.SYMMETRIC] -WEIGHT_ZERO_POINT_DOMAINS = [ZeroPointDomain.NONE, ZeroPointDomain.INT] - COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy() @@ -128,14 +125,12 @@ def _int8wo_groupwise_api(mod): def _int8da_int8w_api( mod, act_mapping_type=MappingType.SYMMETRIC, - weight_zero_point_domain=ZeroPointDomain.INT, ): if TORCH_VERSION_AT_LEAST_2_4: quantize_( mod, int8_dynamic_activation_int8_weight( act_mapping_type=act_mapping_type, - weight_zp_domain=weight_zero_point_domain, ), set_inductor_config=False, ) @@ -998,13 +993,10 @@ def _test_lin_weight_subclass_api_impl( COMMON_DEVICES, COMMON_DTYPES, ACT_MAPPING_TYPES, - WEIGHT_ZERO_POINT_DOMAINS, ) 
) ) - def test_int8_dynamic_quant_subclass_api( - self, device, dtype, act_mapping, weight_zero_point_domain - ): + def test_int8_dynamic_quant_subclass_api(self, device, dtype, act_mapping): if ( not TORCH_VERSION_AT_LEAST_2_5 and dtype in (torch.float16, torch.bfloat16) @@ -1015,7 +1007,6 @@ def test_int8_dynamic_quant_subclass_api( api = partial( _int8da_int8w_api, act_mapping_type=act_mapping, - weight_zero_point_domain=weight_zero_point_domain, ) self._test_lin_weight_subclass_api_impl(api, device, 35, test_dtype=dtype) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 184b96334c..3408025715 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -857,7 +857,6 @@ def int8_dynamic_activation_int8_weight( layout=PlainLayout(), act_mapping_type=MappingType.SYMMETRIC, weight_only_decode=False, - weight_zp_domain=ZeroPointDomain.NONE, ): """ Applies int8 dynamic symmetric per-token activation and int8 per-channel weight @@ -876,6 +875,7 @@ def apply_int8_dynamic_activation_int8_weight_quant(weight): # weight settings mapping_type = MappingType.SYMMETRIC + weight_zero_point_domain = ZeroPointDomain.NONE def get_weight_block_size(x): return (1, x.shape[1]) @@ -902,7 +902,7 @@ def get_weight_block_size(x): eps=eps, zero_point_dtype=zero_point_dtype, _layout=layout, - zero_point_domain=weight_zp_domain, + zero_point_domain=weight_zero_point_domain, ) weight = to_linear_activation_quantized(weight, input_quant_func) return weight From 7b9477d827aa8b83b808f7ddfc3777848b0fbf60 Mon Sep 17 00:00:00 2001 From: sanchitintel Date: Sun, 26 Jan 2025 22:22:02 -0800 Subject: [PATCH 5/6] Rebase with the latest main branch --- test/quantization/test_quant_primitives.py | 4 ++-- torchao/dtypes/affine_quantized_tensor.py | 2 ++ torchao/quantization/quant_primitives.py | 8 ++++++++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index e4cda23ea8..3ca58ff996 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -939,7 +939,7 @@ def test_float8_quant_primitives(self, hp_dtype, float8_dtype): quant_min=torch.finfo(float8_dtype).min, quant_max=torch.finfo(float8_dtype).max, zero_point=None, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) expected_dequantized = dequantize_affine( expected_quantized, @@ -950,7 +950,7 @@ def test_float8_quant_primitives(self, hp_dtype, float8_dtype): quant_min=torch.finfo(float8_dtype).min, quant_max=torch.finfo(float8_dtype).max, zero_point=None, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) self.assertTrue(torch.equal(expected_scale, scale)) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index b37b3a07c8..284011442b 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -303,6 +303,8 @@ def from_hp_to_intx_static( """Create an integer AffineQuantizedTensor from a high precision tensor using static parameters.""" if zero_point_domain is None: raise ValueError("please use ZeroPointDomain.NONE instead of None") + elif zero_point_domain is None and zero_point is not None: + raise ValueError("zero_point should be None when zero_point_domain is NONE") if target_dtype not in FP8_TYPES: assert ( zero_point is not None diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py 
index f3a501cd81..1a3034ad77 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -321,6 +321,8 @@ def quantize_affine( """ if zero_point_domain is None: raise ValueError("Please use ZeroPointDomain.NONE instead of None") + elif zero_point_domain is None and zero_point is not None: + raise ValueError("zero_point should be None when zero_point_domain is NONE") return _quantize_affine( input, block_size, @@ -475,6 +477,8 @@ def dequantize_affine( """ if zero_point_domain is None: raise ValueError("Please use ZeroPointDomain.NONE instead of None") + elif zero_point_domain is None and zero_point is not None: + raise ValueError("zero_point should be None when zero_point_domain is NONE") return _dequantize_affine( input, block_size, @@ -621,6 +625,8 @@ def fake_quantize_affine( """ if zero_point_domain is None: raise ValueError("Please use ZeroPointDomain.NONE instead of None") + elif zero_point_domain is None and zero_point is not None: + raise ValueError("zero_point should be None when zero_point_domain is NONE") (_, fq) = _do_fake_quantize_affine( input, block_size, @@ -665,6 +671,8 @@ def fake_quantize_affine_cachemask( """ if zero_point_domain is None: raise ValueError("Please use ZeroPointDomain.NONE instead of None") + elif zero_point_domain is None and zero_point is not None: + raise ValueError("zero_point should be None when zero_point_domain is NONE") (q, dq) = _do_fake_quantize_affine( input, block_size, From cf909987d1c012bd2ea67d14c48efb82947ffe45 Mon Sep 17 00:00:00 2001 From: sanchitintel Date: Mon, 27 Jan 2025 11:08:47 -0800 Subject: [PATCH 6/6] Fix typo --- torchao/dtypes/affine_quantized_tensor.py | 2 +- torchao/quantization/quant_primitives.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 284011442b..715aaeb9ec 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -303,7 +303,7 @@ def from_hp_to_intx_static( """Create an integer AffineQuantizedTensor from a high precision tensor using static parameters.""" if zero_point_domain is None: raise ValueError("please use ZeroPointDomain.NONE instead of None") - elif zero_point_domain is None and zero_point is not None: + elif zero_point_domain is ZeroPointDomain.NONE and zero_point is not None: raise ValueError("zero_point should be None when zero_point_domain is NONE") if target_dtype not in FP8_TYPES: assert ( diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 1a3034ad77..05be8c5c30 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -321,7 +321,7 @@ def quantize_affine( """ if zero_point_domain is None: raise ValueError("Please use ZeroPointDomain.NONE instead of None") - elif zero_point_domain is None and zero_point is not None: + elif zero_point_domain is ZeroPointDomain.NONE and zero_point is not None: raise ValueError("zero_point should be None when zero_point_domain is NONE") return _quantize_affine( input, @@ -477,7 +477,7 @@ def dequantize_affine( """ if zero_point_domain is None: raise ValueError("Please use ZeroPointDomain.NONE instead of None") - elif zero_point_domain is None and zero_point is not None: + elif zero_point_domain is ZeroPointDomain.NONE and zero_point is not None: raise ValueError("zero_point should be None when zero_point_domain is NONE") return _dequantize_affine( input, @@ -625,7 
+625,7 @@ def fake_quantize_affine( """ if zero_point_domain is None: raise ValueError("Please use ZeroPointDomain.NONE instead of None") - elif zero_point_domain is None and zero_point is not None: + elif zero_point_domain is ZeroPointDomain.NONE and zero_point is not None: raise ValueError("zero_point should be None when zero_point_domain is NONE") (_, fq) = _do_fake_quantize_affine( input,