From 83ada42b760ee4ade708dee4e22f5edcea405f1c Mon Sep 17 00:00:00 2001
From: zzzzzzzk <41361256+zk1998@users.noreply.github.com>
Date: Wed, 19 Jul 2023 09:56:21 +0800
Subject: [PATCH] [quantizer] fix broadcast bug (#236)

---
 tests/quantizer_test.py                | 10 ++++++++++
 tinynn/graph/quantization/quantizer.py |  8 ++++----
 2 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/tests/quantizer_test.py b/tests/quantizer_test.py
index 19269391..3246f595 100644
--- a/tests/quantizer_test.py
+++ b/tests/quantizer_test.py
@@ -1293,6 +1293,16 @@ def forward(self, x):
 
         check_quantize_rewrite(model, inputs)
 
+    def test_quantized_mul_different_shape_complex(self):
+        class Model(nn.Module):
+            def forward(self, x):
+                return x.transpose(0, 1) * x
+
+        model = Model()
+        inputs = torch.randn(1, 3, 224, 224)
+
+        check_quantize_rewrite(model, inputs)
+
     def test_quantized_add_relu_different_shape(self):
         class Model(nn.Module):
             def forward(self, x):
diff --git a/tinynn/graph/quantization/quantizer.py b/tinynn/graph/quantization/quantizer.py
index 8178a9ba..c0fa5f29 100644
--- a/tinynn/graph/quantization/quantizer.py
+++ b/tinynn/graph/quantization/quantizer.py
@@ -2821,17 +2821,17 @@ def _is_broadcastable_binary_quantized_op_node(node: TraceNode, custom_data) ->
 
     for l_dim, r_dim in zip(l_shape, r_shape):
         if l_dim > r_dim:
-            if ref_index in (None, 0):
+            if ref_index in (None, 0) and r_dim == 1:
                 ref_index = 0
             else:
                 ref_index = -1
-            break
+                break
         elif l_dim < r_dim:
-            if ref_index in (None, 1):
+            if ref_index in (None, 1) and l_dim == 1:
                 ref_index = 1
             else:
                 ref_index = -1
-            break
+                break
 
     if ref_index >= 0:
         src_index = 1 - ref_index
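
For context on the quantizer.py hunk: the patched _is_broadcastable_binary_quantized_op_node only treats a side as broadcast when its dimension is exactly 1, and it now stops the scan early only once a conflict is found, rather than after the first mismatching dimension. Below is a minimal standalone sketch of that check; the helper name find_reference_operand is hypothetical and not part of TinyNN's API, it just mirrors the per-dimension logic from the hunk to show why the new test case x.transpose(0, 1) * x needed the fix.

# Minimal sketch (assumed helper, not part of TinyNN) mirroring the fixed
# per-dimension check: return 0 if the left operand is the reference shape,
# 1 if the right operand is, and -1 if neither side fully contains the other.
def find_reference_operand(l_shape, r_shape):
    ref_index = None
    for l_dim, r_dim in zip(l_shape, r_shape):
        if l_dim > r_dim:
            # Left is larger: valid only if the right dim broadcasts from 1
            # and a conflicting reference (the right operand) was not already chosen.
            if ref_index in (None, 0) and r_dim == 1:
                ref_index = 0
            else:
                ref_index = -1
                break
        elif l_dim < r_dim:
            if ref_index in (None, 1) and l_dim == 1:
                ref_index = 1
            else:
                ref_index = -1
                break
    return ref_index

# x.transpose(0, 1) * x with x of shape (1, 3, 224, 224): the operands have
# shapes (3, 1, 224, 224) and (1, 3, 224, 224), so each side broadcasts the
# other in a different dimension and there is no single reference operand.
print(find_reference_operand((3, 1, 224, 224), (1, 3, 224, 224)))  # -1
# Ordinary one-sided broadcast: the left operand is the reference (index 0).
print(find_reference_operand((1, 3, 224, 224), (1, 1, 224, 224)))  # 0

With the pre-patch logic (no r_dim == 1 / l_dim == 1 guard and an unconditional break after the first mismatch), the first call above would have returned 0, so the transpose-and-multiply pattern was misclassified as a simple one-sided broadcast; the new test_quantized_mul_different_shape_complex covers exactly that case.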