diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py
index c14aae08c..33cd84b1d 100644
--- a/mmrazor/models/algorithms/quantization/mm_architecture.py
+++ b/mmrazor/models/algorithms/quantization/mm_architecture.py
@@ -7,8 +7,9 @@
 from mmengine.runner import load_checkpoint
 from mmengine.structures import BaseDataElement
 from torch import nn
-from torch.ao.quantization import FakeQuantizeBase
 
+from mmrazor.models.fake_quants import BaseFakeQuantize
+from mmrazor.models.observers import BaseObserver
 from mmrazor.models.task_modules import build_graphmodule
 from mmrazor.registry import MODEL_WRAPPERS, MODELS
 from ..base import BaseAlgorithm
@@ -62,8 +63,14 @@ def __init__(self,
         self.forward_modes = forward_modes
 
         self.qmodels = self._build_qmodels(self.architecture)
-
         self.sync_qparams('predict')
+        self.reset_min_max_vals(self.qmodels)
+
+    def reset_min_max_vals(self, model):
+        for module in model.modules():
+            if isinstance(module, BaseObserver):
+                assert hasattr(module, 'reset_min_max_vals')
+                module.reset_min_max_vals()
 
     def sync_qparams(self, src_mode):
 
@@ -72,7 +79,7 @@ def traverse(module, prefix):
                 if module is None:
                     continue
                 child_name = f'{prefix}{name}'
-                if isinstance(child, FakeQuantizeBase):
+                if isinstance(child, BaseFakeQuantize):
                     for name, param in child.named_parameters():
                         param_name = f'{child_name}.{name}'
                         src_param = src_state_dict[param_name]
@@ -114,10 +121,15 @@ def _build_qmodels(self, model):
             graph_mopdule = build_graphmodule(model, traced_graph)
             observed_module = self.quantizer.prepare(model, graph_mopdule)
             qmodels[mode] = observed_module
-        # import pdb
-        # pdb.set_trace()
-        # dummy_input = torch.randn(self.input_shapes)
-        # qmodels['predict'](dummy_input, None, 'predict')
+
+        is_training = qmodels['predict'].training
+        # Switch to eval mode so the random input cannot change BN statistics
+        qmodels['predict'].eval()
+        # Run one forward pass on a random tensor so that per-channel
+        # fake-quant modules resize their scale and zero_point.
+        dummy_input = torch.randn(self.input_shapes)
+        qmodels['predict'](dummy_input, None, 'predict')
+        qmodels['predict'].train(mode=is_training)
 
         return qmodels
 