Skip to content

Commit

Permalink
reset minmax in the minmaxobserver
Browse files Browse the repository at this point in the history
  • Loading branch information
HIT-cwh committed Dec 26, 2022
1 parent 5707d07 commit a11a465
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions mmrazor/models/algorithms/quantization/mm_architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from mmengine.runner import load_checkpoint
from mmengine.structures import BaseDataElement
from torch import nn
from torch.ao.quantization import FakeQuantizeBase

from mmrazor.models.fake_quants import BaseFakeQuantize
from mmrazor.models.observers import BaseObserver
from mmrazor.models.task_modules import build_graphmodule
from mmrazor.registry import MODEL_WRAPPERS, MODELS
from ..base import BaseAlgorithm
Expand Down Expand Up @@ -62,8 +63,14 @@ def __init__(self,
self.forward_modes = forward_modes

self.qmodels = self._build_qmodels(self.architecture)

self.sync_qparams('predict')
self.reset_min_max_vals(self.qmodels)

def reset_min_max_vals(self, model):
    """Reset the min/max statistics of every observer found in ``model``.

    Iterates over ``model.modules()`` and calls ``reset_min_max_vals()``
    on each module that is a ``BaseObserver`` instance. Other modules are
    left untouched.

    Args:
        model: Object exposing ``modules()``. The visible caller passes
            ``self.qmodels``; presumably a container of ``nn.Module``
            instances -- confirm against ``_build_qmodels``.

    Raises:
        AttributeError: If a ``BaseObserver`` instance does not provide a
            ``reset_min_max_vals`` method.
    """
    for module in model.modules():
        if not isinstance(module, BaseObserver):
            continue
        # Explicit check rather than ``assert``: assertions are stripped
        # when Python runs with -O, and a bare assert gives no hint about
        # which observer class is missing the method.
        if not hasattr(module, 'reset_min_max_vals'):
            raise AttributeError(
                f'{type(module).__name__} is a BaseObserver but does not '
                'implement `reset_min_max_vals`.')
        module.reset_min_max_vals()

def sync_qparams(self, src_mode):

Expand All @@ -72,7 +79,7 @@ def traverse(module, prefix):
if module is None:
continue
child_name = f'{prefix}{name}'
if isinstance(child, FakeQuantizeBase):
if isinstance(child, BaseFakeQuantize):
for name, param in child.named_parameters():
param_name = f'{child_name}.{name}'
src_param = src_state_dict[param_name]
Expand Down Expand Up @@ -114,10 +121,15 @@ def _build_qmodels(self, model):
graph_mopdule = build_graphmodule(model, traced_graph)
observed_module = self.quantizer.prepare(model, graph_mopdule)
qmodels[mode] = observed_module
# import pdb
# pdb.set_trace()
# dummy_input = torch.randn(self.input_shapes)
# qmodels['predict'](dummy_input, None, 'predict')

is_training = qmodels['predict'].training
# Avoid random input changing bn's statistics
qmodels['predict'].eval()
# Input a random tensor to modify the shape of scale and zero_point
# in fakequant in per-channel mode.
dummy_input = torch.randn(self.input_shapes)
qmodels['predict'](dummy_input, None, 'predict')
qmodels['predict'].train(mode=is_training)

return qmodels

Expand Down

0 comments on commit a11a465

Please sign in to comment.