Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[quantizer] fix weight quant_min/quant_max specification failure when using tflite as backend #186

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 81 additions & 1 deletion tests/quantizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,45 @@ def check_dequantize_rewrite(model, inputs, show_rewritten=True, skip_train=Fals
float_model(inputs)


def check_quantizer_convert(model, inputs, skip_train=False):
# TODO: we need check tflite uint8->quant[0,255], int8->quant[-127, 127] for weight.
config = {
'asymmetric': True,
'per_tensor': False,
'remove_weights_after_load': True,
'ignore_layerwise_config': True,
}

if sys.platform == 'win32':
config.update({'backend': 'fbgemm', 'per_tensor': False})

with model_tracer():
quantizer = QATQuantizer(model, inputs, work_dir='out', config=config)
qat_model = quantizer.quantize()

quant_minmax_dict = {}
for name, mod in model.named_modules():
if hasattr(mod, 'weight_fake_quant'):
quant_minmax_dict[name] = [mod.weight_fake_quant.quant_min, mod.weight_fake_quant.quant_max]

if not skip_train:
for _ in range(3):
if isinstance(inputs, (list, tuple)):
qat_model(*inputs)
else:
qat_model(inputs)

with torch.no_grad():
qat_model.eval()

qat_model = quantizer.convert(qat_model)
for name, mod in qat_model.named_modules():
if name in quant_minmax_dict:
int_weight = torch.int_repr(mod.weight())
assert int_weight.min() >= quant_minmax_dict[name][0]
assert int_weight.max() <= quant_minmax_dict[name][1]


class QuantizerTester(unittest.TestCase):
def test_simple_float_model(self):
class Model(nn.Module):
Expand Down Expand Up @@ -486,7 +525,6 @@ def forward(self, x):

check_quantize_rewrite(model, inputs, skip_train=skip_train)


def test_avg_pool1d_with_one_kernel_size(self):
class Model(nn.Module):
def __init__(self):
Expand Down Expand Up @@ -1400,6 +1438,48 @@ def forward(self, x):

check_quantize_rewrite(model, inputs)

def test_quantizer_convert_fc(self):
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc = nn.Linear(3, 1)

def forward(self, x):
return self.fc(x)

model = Model()
inputs = torch.randn(1, 3)

check_quantizer_convert(model, inputs)

def test_quantizer_convert_conv(self):
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = nn.Conv2d(3, 1, 1, 1)

def forward(self, x):
return self.conv(x)

model = Model()
inputs = torch.randn(1, 3, 224, 224)

check_quantizer_convert(model, inputs)

def test_quantizer_convert_transpose(self):
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = nn.ConvTranspose2d(3, 5, 2, 2, 1)

def forward(self, x):
return self.conv(x)

model = Model()
inputs = torch.randn(1, 3, 224, 224)

check_quantizer_convert(model, inputs)


class DeQuantizerTester(unittest.TestCase):
def test_simple_q_model(self):
Expand Down
46 changes: 42 additions & 4 deletions tinynn/graph/quantization/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,10 +991,11 @@ def new_no_observer_set():
if n.endswith('.weight_fake_quant'):
observer = getattr(m, 'activation_post_process', None)
if observer is not None:
m.quant_min = -127
m.quant_max = 127
observer.quant_min = -127
observer.quant_max = 127
if m.qscheme == torch.per_channel_symmetric:
m.quant_min = -127
m.quant_max = 127
observer.quant_min = -127
observer.quant_max = 127

self.extra_qat_fusion_postprocess(graph)

Expand Down Expand Up @@ -3038,6 +3039,40 @@ def convert(self, q_model: nn.Module, backend: str = 'tflite') -> nn.Module:
if activ_name is not None:
setattr(q_model, activ_name, nn.Identity())

def gen_wrapper(origin_quantize_weight):
def new_quantize_weight(float_wt, observer, *args, **kwargs):
q_weight = origin_quantize_weight(float_wt, observer)
# Do clamp using observer's quant_min and quant_max
if hasattr(observer, 'quant_min') and hasattr(observer, 'quant_max'):
int_tensor = torch.int_repr(q_weight)
int_tensor = torch.clamp(int_tensor, observer.quant_min, observer.quant_max)

if observer.qscheme == torch.per_tensor_symmetric or observer.qscheme == torch.per_tensor_affine:
q_weight = torch._make_per_tensor_quantized_tensor(
int_tensor, q_weight.q_scale(), q_weight.q_zero_point()
)
elif (
observer.qscheme == torch.per_channel_symmetric or observer.qscheme == torch.per_channel_affine
):
q_weight = torch._make_per_channel_quantized_tensor(
int_tensor,
q_weight.q_per_channel_scales(),
q_weight.q_per_channel_zero_points(),
q_weight.q_per_channel_axis(),
)
else:
log.warning('observer qscheme error')
else:
log.warning('observer do not have quant_min and quant_max')

return q_weight

return new_quantize_weight

origin_quantize_weight = torch.nn.quantized.modules.utils._quantize_weight
torch.nn.quantized.modules.linear._quantize_weight = gen_wrapper(origin_quantize_weight)
torch.nn.quantized.modules.conv._quantize_weight = gen_wrapper(origin_quantize_weight)

if type(self).__name__ == 'QATQuantizer':

q = queue.Queue()
Expand Down Expand Up @@ -3077,6 +3112,9 @@ def convert(self, q_model: nn.Module, backend: str = 'tflite') -> nn.Module:
else:
q_model = torch.quantization.convert(q_model)

torch.nn.quantized.modules.linear._quantize_weight = origin_quantize_weight
torch.nn.quantized.modules.conv._quantize_weight = origin_quantize_weight

return q_model

def optimize_conv_bn_fusion(self, q_model, eps=1e-5):
Expand Down