From d487609c5e8d01a3b2146060737ecb68deff619b Mon Sep 17 00:00:00 2001
From: zhoukun
Date: Fri, 24 Feb 2023 16:01:56 +0800
Subject: [PATCH 1/2] [quantizer] fix weight quant_min/quant_max specification failure when using tflite as backend

---
 tests/quantizer_test.py                | 79 +++++++++++++++++++++++++-
 tinynn/graph/quantization/quantizer.py | 46 +++++++++++++--
 2 files changed, 120 insertions(+), 5 deletions(-)

diff --git a/tests/quantizer_test.py b/tests/quantizer_test.py
index 05d5bd51..09ee4a3d 100644
--- a/tests/quantizer_test.py
+++ b/tests/quantizer_test.py
@@ -76,6 +76,42 @@ def check_dequantize_rewrite(model, inputs, show_rewritten=True, skip_train=Fals
         float_model(inputs)
 
 
+def check_quantizer_convert(model, inputs, skip_train=False):
+    # TODO: check that tflite weights are quantized to [0, 255] for uint8 and [-127, 127] for int8.
+    config = {
+        'asymmetric': True,
+        'per_tensor': False,
+        'remove_weights_after_load': True,
+        'ignore_layerwise_config': True,
+    }
+
+    with model_tracer():
+        quantizer = QATQuantizer(model, inputs, work_dir='out', config=config)
+        qat_model = quantizer.quantize()
+
+    quant_minmax_dict = {}
+    for name, mod in qat_model.named_modules():
+        if hasattr(mod, 'weight_fake_quant'):
+            quant_minmax_dict[name] = [mod.weight_fake_quant.quant_min, mod.weight_fake_quant.quant_max]
+
+    if not skip_train:
+        for _ in range(3):
+            if isinstance(inputs, (list, tuple)):
+                qat_model(*inputs)
+            else:
+                qat_model(inputs)
+
+    with torch.no_grad():
+        qat_model.eval()
+
+        qat_model = quantizer.convert(qat_model)
+        for name, mod in qat_model.named_modules():
+            if name in quant_minmax_dict:
+                int_weight = torch.int_repr(mod.weight())
+                assert int_weight.min() >= quant_minmax_dict[name][0]
+                assert int_weight.max() <= quant_minmax_dict[name][1]
+
+
 class QuantizerTester(unittest.TestCase):
     def test_simple_float_model(self):
         class Model(nn.Module):
@@ -486,7 +522,6 @@ def forward(self, x):
 
         check_quantize_rewrite(model, inputs, skip_train=skip_train)
 
-
     def test_avg_pool1d_with_one_kernel_size(self):
         class Model(nn.Module):
             def __init__(self):
@@ -1400,6 +1435,48 @@ def forward(self, x):
 
         check_quantize_rewrite(model, inputs)
 
+    def test_quantizer_convert_fc(self):
+        class Model(nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.fc = nn.Linear(3, 1)
+
+            def forward(self, x):
+                return self.fc(x)
+
+        model = Model()
+        inputs = torch.randn(1, 3)
+
+        check_quantizer_convert(model, inputs)
+
+    def test_quantizer_convert_conv(self):
+        class Model(nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.conv = nn.Conv2d(3, 1, 1, 1)
+
+            def forward(self, x):
+                return self.conv(x)
+
+        model = Model()
+        inputs = torch.randn(1, 3, 224, 224)
+
+        check_quantizer_convert(model, inputs)
+
+    def test_quantizer_convert_transpose(self):
+        class Model(nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.conv = nn.ConvTranspose2d(3, 5, 2, 2, 1)
+
+            def forward(self, x):
+                return self.conv(x)
+
+        model = Model()
+        inputs = torch.randn(1, 3, 224, 224)
+
+        check_quantizer_convert(model, inputs)
+
 
 class DeQuantizerTester(unittest.TestCase):
     def test_simple_q_model(self):
diff --git a/tinynn/graph/quantization/quantizer.py b/tinynn/graph/quantization/quantizer.py
index 784030ef..9642d319 100644
--- a/tinynn/graph/quantization/quantizer.py
+++ b/tinynn/graph/quantization/quantizer.py
@@ -991,10 +991,11 @@ def new_no_observer_set():
             if n.endswith('.weight_fake_quant'):
                 observer = getattr(m, 'activation_post_process', None)
                 if observer is not None:
-                    m.quant_min = -127
-                    m.quant_max = 127
-                    observer.quant_min = -127
-                    observer.quant_max = 127
+                    if m.qscheme == torch.per_channel_symmetric:
+                        m.quant_min = -127
+                        m.quant_max = 127
+                        observer.quant_min = -127
+                        observer.quant_max = 127
 
         self.extra_qat_fusion_postprocess(graph)
 
@@ -3038,6 +3039,40 @@ def convert(self, q_model: nn.Module, backend: str = 'tflite') -> nn.Module:
                 if activ_name is not None:
                     setattr(q_model, activ_name, nn.Identity())
 
+        def gen_wrapper(origin_quantize_weight):
+            def new_quantize_weight(float_wt, observer, *args, **kwargs):
+                q_weight = origin_quantize_weight(float_wt, observer)
+                # Clamp the integer weights to the observer's [quant_min, quant_max] range
+                if hasattr(observer, 'quant_min') and hasattr(observer, 'quant_max'):
+                    int_tensor = torch.int_repr(q_weight)
+                    int_tensor = torch.clamp(int_tensor, observer.quant_min, observer.quant_max)
+
+                    if observer.qscheme == torch.per_tensor_symmetric or observer.qscheme == torch.per_tensor_affine:
+                        q_weight = torch._make_per_tensor_quantized_tensor(
+                            int_tensor, q_weight.q_scale(), q_weight.q_zero_point()
+                        )
+                    elif (
+                        observer.qscheme == torch.per_channel_symmetric or observer.qscheme == torch.per_channel_affine
+                    ):
+                        q_weight = torch._make_per_channel_quantized_tensor(
+                            int_tensor,
+                            q_weight.q_per_channel_scales(),
+                            q_weight.q_per_channel_zero_points(),
+                            q_weight.q_per_channel_axis(),
+                        )
+                    else:
+                        log.warning('unsupported observer qscheme')
+                else:
+                    log.warning('observer does not have quant_min and quant_max')
+
+                return q_weight
+
+            return new_quantize_weight
+
+        origin_quantize_weight = torch.nn.quantized.modules.utils._quantize_weight
+        torch.nn.quantized.modules.linear._quantize_weight = gen_wrapper(origin_quantize_weight)
+        torch.nn.quantized.modules.conv._quantize_weight = gen_wrapper(origin_quantize_weight)
+
         if type(self).__name__ == 'QATQuantizer':
             q = queue.Queue()
 
@@ -3077,6 +3112,9 @@ def convert(self, q_model: nn.Module, backend: str = 'tflite') -> nn.Module:
         else:
             q_model = torch.quantization.convert(q_model)
 
+        torch.nn.quantized.modules.linear._quantize_weight = origin_quantize_weight
+        torch.nn.quantized.modules.conv._quantize_weight = origin_quantize_weight
+
         return q_model
 
     def optimize_conv_bn_fusion(self, q_model, eps=1e-5):

From 210900c476dbd9f3baeb051365d6aee23fdab609 Mon Sep 17 00:00:00 2001
From: zhoukun
Date: Fri, 24 Feb 2023 16:28:41 +0800
Subject: [PATCH 2/2] [tests] add win32 quantizer config

---
 tests/quantizer_test.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tests/quantizer_test.py b/tests/quantizer_test.py
index 09ee4a3d..9b0e0666 100644
--- a/tests/quantizer_test.py
+++ b/tests/quantizer_test.py
@@ -85,6 +85,9 @@ def check_quantizer_convert(model, inputs, skip_train=False):
         'ignore_layerwise_config': True,
     }
 
+    if sys.platform == 'win32':
+        config.update({'backend': 'fbgemm', 'per_tensor': False})
+
     with model_tracer():
        quantizer = QATQuantizer(model, inputs, work_dir='out', config=config)
        qat_model = quantizer.quantize()
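
Note (reviewer sketch): the gen_wrapper code in patch 1 clamps the integer weight representation to the observer's [quant_min, quant_max] before rebuilding the quantized tensor, because the TFLite quantization spec restricts int8 weights to [-127, 127] while plain PyTorch symmetric quantization may emit the -128 code. Below is a minimal, self-contained illustration of that clamping step, runnable with only PyTorch. The helper name clamp_quantized_weight and the toy tensor are invented for this sketch and are not part of tinynn or the patch; the patch itself wraps torch.nn.quantized.modules.utils._quantize_weight instead.

import torch


def clamp_quantized_weight(q_weight, quant_min, quant_max):
    # Clamp the int8 representation, then rebuild a quantized tensor that keeps
    # the original scale and zero-point, mirroring new_quantize_weight in the patch.
    int_tensor = torch.int_repr(q_weight).clamp(quant_min, quant_max)
    if q_weight.qscheme() in (torch.per_tensor_symmetric, torch.per_tensor_affine):
        return torch._make_per_tensor_quantized_tensor(
            int_tensor, q_weight.q_scale(), q_weight.q_zero_point()
        )
    return torch._make_per_channel_quantized_tensor(
        int_tensor,
        q_weight.q_per_channel_scales(),
        q_weight.q_per_channel_zero_points(),
        q_weight.q_per_channel_axis(),
    )


if __name__ == '__main__':
    # scale = 1/128 and zero_point = 0, so the -1.0 entry lands exactly on the -128 code.
    w = torch.tensor([[-1.0, 0.5, 0.25, 1.0 - 1.0 / 128]])
    q_w = torch.quantize_per_tensor(w, scale=1.0 / 128, zero_point=0, dtype=torch.qint8)
    assert int(torch.int_repr(q_w).min()) == -128      # unclamped weight uses the -128 code
    clamped = clamp_quantized_weight(q_w, -127, 127)
    assert int(torch.int_repr(clamped).min()) == -127  # clamped into the TFLite-friendly range

In the patch this only bites for weight fake-quant observers whose quant_min was pinned to -127 at QAT time (the per-channel symmetric case); wrapping _quantize_weight during convert makes the stored integer weights honor the same bound, which is what check_quantizer_convert asserts.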