diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index c926cee060..56bcaf17df 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -10,6 +10,7 @@ import logging import os import unittest +from functools import partial import torch import torch.nn as nn @@ -48,6 +49,7 @@ quantize_, ) from torchao.quantization.quant_primitives import ( + MappingType, dequantize_affine, ) from torchao.quantization.smoothquant import ( @@ -102,6 +104,8 @@ COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] +ACT_MAPPING_TYPES = [MappingType.ASYMMETRIC, MappingType.SYMMETRIC] + COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy() @@ -121,9 +125,18 @@ def _int8wo_groupwise_api(mod): quantize_(mod, int8_weight_only(group_size=group_size), set_inductor_config=False) -def _int8da_int8w_api(mod): +def _int8da_int8w_api( + mod, + act_mapping_type=MappingType.SYMMETRIC, +): if TORCH_VERSION_AT_LEAST_2_4: - quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False) + quantize_( + mod, + int8_dynamic_activation_int8_weight( + act_mapping_type=act_mapping_type, + ), + set_inductor_config=False, + ) if not TORCH_VERSION_AT_LEAST_2_5: unwrap_tensor_subclass(mod) else: @@ -962,10 +975,11 @@ def _test_lin_weight_subclass_api_impl( mod[0].weight.tensor_impl.get_plain() test = mod(x) + self.assertGreater( SQNR(ref_f, test), min_sqnr, - f"{api.__name__} failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}", + f"API failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}", ) mod_qc = torch.compile(mod, mode="max-autotune") @@ -973,14 +987,31 @@ def _test_lin_weight_subclass_api_impl( self.assertGreater( SQNR(ref_f, test_comp), min_sqnr, - f"{api.__name__} failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}", + f"API failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}", ) - @parameterized.expand(COMMON_DEVICE_DTYPE) - def test_int8_dynamic_quant_subclass_api(self, device, dtype): - self._test_lin_weight_subclass_api_impl( - _int8da_int8w_api, device, 35, test_dtype=dtype + @parameterized.expand( + list( + itertools.product( + COMMON_DEVICES, + COMMON_DTYPES, + ACT_MAPPING_TYPES, + ) + ) + ) + def test_int8_dynamic_quant_subclass_api(self, device, dtype, act_mapping): + if ( + not TORCH_VERSION_AT_LEAST_2_5 + and dtype in (torch.float16, torch.bfloat16) + and act_mapping is MappingType.ASYMMETRIC + and device == "cpu" + ): + self.skipTest("Inductor-CPU codegen issue fixed in torch 2.5") + api = partial( + _int8da_int8w_api, + act_mapping_type=act_mapping, ) + self._test_lin_weight_subclass_api_impl(api, device, 35, test_dtype=dtype) @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(is_fbcode(), "broken in fbcode") diff --git a/test/quantization/test_observer.py b/test/quantization/test_observer.py index 0526ee01b2..4567f3baef 100644 --- a/test/quantization/test_observer.py +++ b/test/quantization/test_observer.py @@ -21,6 +21,7 @@ ) from torchao.quantization.quant_primitives import ( MappingType, + ZeroPointDomain, ) @@ -74,7 +75,7 @@ def test_block_size_calc_success(self): eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) example_inputs = [ torch.randn(10, 2048), @@ -93,7 +94,7 @@ def test_block_size_calc_success(self): eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) for example_input in example_inputs: obs(example_input) @@ -108,7 +109,7 @@ def test_block_size_row_errors(self): eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) example_inputs = [ torch.randn(10, 2048), @@ -127,7 +128,7 @@ def test_block_size_row_errors(self): eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) example_inputs = [ torch.randn(10, 2048), @@ -155,7 +156,7 @@ def test_linear_observer_tensor(self, observe_weight: bool): eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) if observe_weight: weight_observer = AffineQuantizedMinMaxObserver( @@ -165,7 +166,7 @@ def test_linear_observer_tensor(self, observe_weight: bool): eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) else: weight_observer = None @@ -199,7 +200,7 @@ def test_linear_observer_tensor(self, observe_weight: bool): input_scale.item(), max_val / max_fp8, ) - self.assertIsNotNone(input_zero_point) + self.assertIsNone(input_zero_point) if observe_weight: weight_observer = linear.weight.weight_observer @@ -210,7 +211,7 @@ def test_linear_observer_tensor(self, observe_weight: bool): atol=5e-5, rtol=0.0, ) - self.assertIsNotNone(weight_zero_point) + self.assertIsNone(weight_zero_point) else: self.assertIsNone(linear.weight.weight_observer) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 77616c1c6a..3ca58ff996 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -843,6 +843,55 @@ def test_fake_quantize_affine_cachemask(self): torch.testing.assert_close(dequantized, fake_quantized) torch.testing.assert_close(expected_mask, mask) + def test_none_zero_point_domain(self): + """A None value for a ZeroPointDomain should not work, but ZeroPointDomain.NONE should""" + input = torch.randn(10, 256) + mapping_type = MappingType.SYMMETRIC + dtype = torch.int8 + block_size = (1, 128) + quant_min = None + quant_max = None + eps = 1e-6 + scale_dtype = torch.float32 + zero_point_dtype = torch.int64 + try: + _, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + quant_min, + quant_max, + eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + preserve_zero=True, + zero_point_domain=None, + ) + except ValueError: + # This exception was expected + # Now test for ZeroPointDomain.NONE + _, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + quant_min, + quant_max, + eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + preserve_zero=True, + zero_point_domain=ZeroPointDomain.NONE, + ) + self.assertTrue(zero_point is None) + else: + # An exception should have been thrown for zero_point_domain None + self.assertTrue( + False, + msg="A runtime exception should have been thrown for zero_point_domain None", + ) + @parameterized.expand( [ ( @@ -890,7 +939,7 @@ def test_float8_quant_primitives(self, hp_dtype, float8_dtype): quant_min=torch.finfo(float8_dtype).min, quant_max=torch.finfo(float8_dtype).max, zero_point=None, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) expected_dequantized = dequantize_affine( expected_quantized, @@ -901,7 +950,7 @@ def test_float8_quant_primitives(self, hp_dtype, float8_dtype): quant_min=torch.finfo(float8_dtype).min, quant_max=torch.finfo(float8_dtype).max, zero_point=None, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) self.assertTrue(torch.equal(expected_scale, scale)) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index e3ac420de7..715aaeb9ec 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -81,6 +81,8 @@ def __new__( dtype=None, strides=None, ): + if zero_point_domain is None: + raise ValueError("please use ZeroPointDomain.NONE instead of None") kwargs = {} kwargs["device"] = tensor_impl.device kwargs["layout"] = ( @@ -199,7 +201,7 @@ def from_hp_to_intx( scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, - zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, _layout: Layout = PlainLayout(), use_hqq: bool = False, ): @@ -258,8 +260,7 @@ def from_hp_to_intx( zero_point_domain, ) # choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None - # TODO should probably consolidate ZeroPointDomain.NONE and None - if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE: + if zero_point_domain == ZeroPointDomain.NONE: zero_point = None data = quantize_affine( input_float, @@ -296,14 +297,15 @@ def from_hp_to_intx_static( target_dtype: torch.dtype, quant_min: Optional[int] = None, quant_max: Optional[int] = None, - zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, _layout: Layout = PlainLayout(), ): """Create an integer AffineQuantizedTensor from a high precision tensor using static parameters.""" + if zero_point_domain is None: + raise ValueError("please use ZeroPointDomain.NONE instead of None") + elif zero_point_domain is ZeroPointDomain.NONE and zero_point is not None: + raise ValueError("zero_point should be None when zero_point_domain is NONE") if target_dtype not in FP8_TYPES: - assert ( - zero_point_domain is not None - ), "zero_point_domain must be specified for non-fp8 types" assert ( zero_point is not None ), "zero_point must be specified for non-fp8 types" @@ -359,7 +361,7 @@ def from_hp_to_floatx( scale_dtype=scale_dtype, zero_point_dtype=None, preserve_zero=True, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, _layout=_layout, use_hqq=False, ) @@ -387,7 +389,7 @@ def from_hp_to_floatx_static( target_dtype=target_dtype, quant_min=math.ceil(torch.finfo(target_dtype).min), quant_max=math.ceil(torch.finfo(target_dtype).max), - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, _layout=_layout, ) else: diff --git a/torchao/dtypes/uintx/marlin_qqq_tensor.py b/torchao/dtypes/uintx/marlin_qqq_tensor.py index 3a4253bb3f..95175caacf 100644 --- a/torchao/dtypes/uintx/marlin_qqq_tensor.py +++ b/torchao/dtypes/uintx/marlin_qqq_tensor.py @@ -54,10 +54,12 @@ def from_hp_to_intx( block_size: Tuple[int, ...], quant_min: Optional[int] = None, quant_max: Optional[int] = None, - zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, _layout: Optional[Layout] = None, ): """Converts a floating point tensor to a Marlin QQQ quantized tensor.""" + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") original_shape = input_float.shape input_float = _layout.pre_process(input_float) nbits = int(math.log2(quant_max - quant_min + 1)) diff --git a/torchao/quantization/observer.py b/torchao/quantization/observer.py index 06509c7b91..cbbe1b581d 100644 --- a/torchao/quantization/observer.py +++ b/torchao/quantization/observer.py @@ -104,11 +104,12 @@ def __init__( scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, - zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ): super().__init__() assert granularity is not None, "granularity is None" - + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") self.mapping_type = mapping_type self.target_dtype = target_dtype self.granularity = granularity diff --git a/torchao/quantization/qat/affine_fake_quantized_tensor.py b/torchao/quantization/qat/affine_fake_quantized_tensor.py index b84200ac9c..f60c858b73 100644 --- a/torchao/quantization/qat/affine_fake_quantized_tensor.py +++ b/torchao/quantization/qat/affine_fake_quantized_tensor.py @@ -41,6 +41,9 @@ def forward( preserve_zero: bool = True, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> "AffineFakeQuantizedTensor": + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") + def apply_fake_quant_fn(t: torch.Tensor): assert isinstance(t, AffineFakeQuantizedTensor) qmin, qmax = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max) @@ -158,6 +161,8 @@ def from_float( preserve_zero: bool = True, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ): + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") return _ToAffineFakeQuantized.apply( original_input, mapping_type, diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py index cd3813291f..925a0eed3c 100644 --- a/torchao/quantization/qat/api.py +++ b/torchao/quantization/qat/api.py @@ -96,6 +96,8 @@ def __init__( group_size: Optional[int] = None, is_symmetric: Optional[bool] = None, ): + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") self.dtype = dtype self.granularity = self._get_granularity(granularity, group_size) self.mapping_type = self._get_mapping_type(mapping_type, is_symmetric) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 3a73b97ad1..02af4ced91 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -387,7 +387,7 @@ def insert_observers_( eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) # Create a linear module @@ -688,7 +688,7 @@ def int4_weight_only( group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8), use_hqq=False, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ): """ Applies uint4 weight-only asymmetric per-group quantization to linear layers, using @@ -733,7 +733,7 @@ def apply_int4_weight_only_quant(weight): assert ( type(layout) in LAYOUT_TO_ZERO_POINT_DOMAIN.keys() ), f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}" - if zero_point_domain is None: + if zero_point_domain == ZeroPointDomain.NONE: # the first value is the default one zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)][0] else: @@ -877,6 +877,7 @@ def apply_int8_dynamic_activation_int8_weight_quant(weight): # weight settings mapping_type = MappingType.SYMMETRIC + weight_zero_point_domain = ZeroPointDomain.NONE def get_weight_block_size(x): return (1, x.shape[1]) @@ -903,6 +904,7 @@ def get_weight_block_size(x): eps=eps, zero_point_dtype=zero_point_dtype, _layout=layout, + zero_point_domain=weight_zero_point_domain, ) weight = to_linear_activation_quantized(weight, input_quant_func) return weight diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 8b0ce28434..05be8c5c30 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -284,7 +284,7 @@ def quantize_affine( output_dtype: torch.dtype, quant_min: Optional[Union[int, float]] = None, quant_max: Optional[Union[int, float]] = None, - zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> torch.Tensor: """ Args: @@ -319,6 +319,10 @@ def quantize_affine( Output: quantized tensor with requested dtype """ + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") + elif zero_point_domain is ZeroPointDomain.NONE and zero_point is not None: + raise ValueError("zero_point should be None when zero_point_domain is NONE") return _quantize_affine( input, block_size, @@ -327,7 +331,7 @@ def quantize_affine( output_dtype, quant_min, quant_max, - zero_point_domain.name if zero_point_domain is not None else None, + zero_point_domain.name, ) @@ -340,7 +344,7 @@ def _quantize_affine( output_dtype: torch.dtype, quant_min: Optional[Union[int, float, bool]] = None, quant_max: Optional[Union[int, float, bool]] = None, - zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, + zero_point_domain: str = ZeroPointDomain.INT.name, ) -> torch.Tensor: """op definition that has compatible signatures with custom op library @@ -363,6 +367,7 @@ def _quantize_affine( zero_point, quant_min, quant_max, + output_dtype, zero_point_domain, ).to(output_dtype) @@ -374,7 +379,8 @@ def _quantize_affine_no_dtype_cast( zero_point: Optional[torch.Tensor], quant_min: Union[int, float], quant_max: Union[int, float], - zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, + quant_dtype: torch.dtype, + zero_point_domain: str = ZeroPointDomain.INT.name, ) -> torch.Tensor: """ The op does the following: @@ -418,13 +424,12 @@ def _quantize_affine_no_dtype_cast( assert ( zero_point is None ), "zero_point should be None when zero_point_domain is NONE" - quant = torch.clamp(torch.round(input * (1.0 / scale)), quant_min, quant_max) - elif zero_point_domain is None: - # This case handles quantization for float8 we expect no zero point and no zero point domain - assert ( - zero_point is None - ), "zero_point should be None when zero_point_domain is None" - quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max) + if _is_float8_type(quant_dtype): + quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max) + else: + quant = torch.clamp( + torch.round(input * (1.0 / scale)), quant_min, quant_max + ) else: assert zero_point_domain == ZeroPointDomain.FLOAT.name mid_point = (quant_max + quant_min + 1) / 2 @@ -470,6 +475,10 @@ def dequantize_affine( Output: dequantized Tensor, with requested dtype or fp32 """ + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") + elif zero_point_domain is ZeroPointDomain.NONE and zero_point is not None: + raise ValueError("zero_point should be None when zero_point_domain is NONE") return _dequantize_affine( input, block_size, @@ -478,7 +487,7 @@ def dequantize_affine( input_dtype, quant_min, quant_max, - zero_point_domain.name if zero_point_domain is not None else None, + zero_point_domain.name, output_dtype=output_dtype, ) @@ -567,16 +576,6 @@ def _dequantize_affine_no_dtype_check( ), "zero_point should be None when zero_point_domain is NONE" dequant = input.to(output_dtype) dequant = dequant * scale - elif zero_point_domain is None: - # This case handles dequantization for float8 we expect no zero point and no zero point domain - assert ( - zero_point is None - ), "zero_point should be None when zero_point_domain is None" - assert _is_float8_type( - input.dtype - ), f"dequantiztion with no zero point domain is only supported with FP8 types, got {input.dtype}" - dequant = input.to(output_dtype) - dequant = dequant * scale else: assert ( zero_point_domain == ZeroPointDomain.FLOAT.name @@ -624,6 +623,10 @@ def fake_quantize_affine( value during quantization default is ZeroPointDomain.INT """ + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") + elif zero_point_domain is ZeroPointDomain.NONE and zero_point is not None: + raise ValueError("zero_point should be None when zero_point_domain is NONE") (_, fq) = _do_fake_quantize_affine( input, block_size, @@ -666,6 +669,10 @@ def fake_quantize_affine_cachemask( ) """ + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") + elif zero_point_domain is None and zero_point is not None: + raise ValueError("zero_point should be None when zero_point_domain is NONE") (q, dq) = _do_fake_quantize_affine( input, block_size, @@ -703,6 +710,7 @@ def _do_fake_quantize_affine( zero_point, quant_min, quant_max, + quant_dtype, zero_point_domain.name, ) dq = _dequantize_affine_no_dtype_check( @@ -730,7 +738,7 @@ def choose_qparams_affine( scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, - zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -764,6 +772,8 @@ def choose_qparams_affine( Output: Tuple of scales and zero_points Tensor with requested dtype """ + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") return _choose_qparams_affine( input, mapping_type.name, @@ -775,7 +785,7 @@ def choose_qparams_affine( scale_dtype, zero_point_dtype, preserve_zero, - zero_point_domain.name if zero_point_domain is not None else None, + zero_point_domain.name, ) @@ -791,7 +801,7 @@ def choose_qparams_affine_with_min_max( scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, - zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> Tuple[torch.Tensor, torch.Tensor]: """A variant of :func:`~torchao.quantization.quant_primitives.choose_qparams_affine` operator that pass in min_val and max_val directly instead of deriving these from a single input. @@ -803,6 +813,8 @@ def choose_qparams_affine_with_min_max( difference: instead of passing in `input` Tensor and use that to calculate min_val/max_val and then scale/zero_point, we pass in min_val/max_val directly """ + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") return _choose_qparams_affine( None, mapping_type.name, @@ -814,7 +826,7 @@ def choose_qparams_affine_with_min_max( scale_dtype, zero_point_dtype, preserve_zero, - zero_point_domain.name if zero_point_domain is not None else None, + zero_point_domain.name, min_val, max_val, ) @@ -921,17 +933,17 @@ def _choose_qparams_affine( raise ValueError( "preserve_zero == False is not supported for symmetric quantization" ) - if ( - zero_point_domain is not None - and zero_point_domain == ZeroPointDomain.FLOAT.name - ): + if zero_point_domain == ZeroPointDomain.FLOAT.name: # TODO INT should not be a valid ZeroPointDomain for symmetric quantization since # symmetric quant doesn't have a zero_point raise ValueError( "zero_point_domain should be ZeroPointDomain.INT or ZeroPointDomain.NONE for symmetric quantization" ) + if zero_point_domain == ZeroPointDomain.NONE.name: + zero_point = None + else: + zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2)) scale = torch.clamp(scale, min=eps) - zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2)) else: assert mapping_type == MappingType.ASYMMETRIC.name scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)