diff --git a/mmrazor/models/quantizers/openvino_quantizer.py b/mmrazor/models/quantizers/openvino_quantizer.py index bac432baa..25f76bbb1 100644 --- a/mmrazor/models/quantizers/openvino_quantizer.py +++ b/mmrazor/models/quantizers/openvino_quantizer.py @@ -75,7 +75,6 @@ def prepare(self, model, graph_module): if target_next: prepared = del_fakequant_after_target( prepared, target_next, inplace=True) - print(prepared) return prepared diff --git a/mmrazor/models/task_modules/tracer/fx/graph_utils.py b/mmrazor/models/task_modules/tracer/fx/graph_utils.py index 952f31b4b..94a163145 100644 --- a/mmrazor/models/task_modules/tracer/fx/graph_utils.py +++ b/mmrazor/models/task_modules/tracer/fx/graph_utils.py @@ -1,19 +1,45 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy +from typing import Any, List, Tuple +import torch.fx from torch.ao.quantization.fake_quantize import FakeQuantizeBase -def _get_attrs(target, attrs): +def _get_attrs(target: torch.nn.Module, attr: str) -> Any: + """Get the attribute from target. - attrs = attrs.split('.') + Args: + target (torch.nn.Module): Get the attribute from target module. + attr (str): The target attribute. + + Returns: + Any: The target attribute. + """ + + attrs: List[str] = attr.split('.') for att in attrs: target = getattr(target, att, None) return target -def del_fakequant_before_target(prepared_model, target_patterns, inplace=True): +def del_fakequant_before_target(prepared_model: torch.fx.GraphModule, + target_patterns: Tuple, + inplace: bool = True) -> torch.fx.GraphModule: + """Delete useless fakequant before nodes whose target attribute + (node.target) is in `target_patterns`. + + Args: + prepared_model (GraphModule): Prepared standalone module. + target_patterns (tuple): Fakequants before nodes whose target attribute + (node.target) is in `target_patterns` will be deleted. + inplace (bool): Can optionally do the operation in-place. Defaults to + True. 
+ + Returns: + GraphModule: Prepared standalone module after deletion. + """ def recursive_find_erased_nodes(node): """Find FakeQuant before target node recursively. @@ -40,14 +66,16 @@ def recursive_find_erased_nodes(node): """ if node is None: return - if isinstance( + if node.op == 'call_module' and isinstance( _get_attrs(prepared_model, node.target), FakeQuantizeBase): nodes_to_erase.append(node) return for prev_node in node.args: - recursive_find_erased_nodes(prev_node) + if isinstance(prev_node, torch.fx.Node): + recursive_find_erased_nodes(prev_node) for prev_node in node.kwargs.values(): - recursive_find_erased_nodes(prev_node) + if isinstance(prev_node, torch.fx.Node): + recursive_find_erased_nodes(prev_node) return if not inplace: @@ -55,7 +83,7 @@ def recursive_find_erased_nodes(node): new_graph = copy.deepcopy(prepared_model.graph) for node in new_graph.nodes: if node.target in target_patterns: - nodes_to_erase = [] + nodes_to_erase: List[torch.fx.Node] = [] recursive_find_erased_nodes(node) for to_erase in nodes_to_erase: to_erase.replace_all_uses_with(to_erase.args[0]) @@ -66,7 +94,22 @@ def recursive_find_erased_nodes(node): return prepared_model -def del_fakequant_after_target(prepared_model, target_patterns, inplace=True): +def del_fakequant_after_target(prepared_model: torch.fx.GraphModule, + target_patterns: Tuple, + inplace: bool = True) -> torch.fx.GraphModule: + """Delete useless fakequant after nodes whose target attribute + (node.target) is in `target_patterns`. + + Args: + prepared_model (GraphModule): Prepared standalone module. + target_patterns (tuple): Fakequants after nodes whose target attribute + (node.target) is in `target_patterns` will be deleted. + inplace (bool): Can optionally do the operation in-place. Defaults to + True. + + Returns: + GraphModule: Prepared standalone module after deletion. 
+ """ if not inplace: prepared_model = copy.deepcopy(prepared_model) new_graph = copy.deepcopy(prepared_model.graph) @@ -91,7 +134,22 @@ def del_fakequant_after_target(prepared_model, target_patterns, inplace=True): return prepared_model -def del_fakequant_before_module(prepared_model, module_patterns, inplace=True): +def del_fakequant_before_module(prepared_model: torch.fx.GraphModule, + module_patterns: Tuple, + inplace: bool = True) -> torch.fx.GraphModule: + """Delete useless fakequant before modules whose type is in + `module_patterns`. + + Args: + prepared_model (GraphModule): Prepared standalone module. + module_patterns (tuple): Fakequants before modules whose type is in + `module_patterns` will be deleted. + inplace (bool): Can optionally do the operation in-place. + Defaults to True. + + Returns: + GraphModule: Prepared standalone module after deletion. + """ if not inplace: prepared_model = copy.deepcopy(prepared_model) new_graph = copy.deepcopy(prepared_model.graph) @@ -99,9 +157,9 @@ def del_fakequant_before_module(prepared_model, module_patterns, inplace=True): if node.op == 'call_module' and isinstance( _get_attrs(prepared_model, node.target), module_patterns): to_erase = node.args[0] - if not isinstance( + if not (to_erase.op == 'call_module' and isinstance( _get_attrs(prepared_model, to_erase.target), - FakeQuantizeBase): + FakeQuantizeBase)): continue if len(to_erase.users) > 1: continue @@ -113,7 +171,22 @@ def del_fakequant_before_module(prepared_model, module_patterns, inplace=True): return prepared_model -def del_fakequant_after_module(prepared_model, module_patterns, inplace=True): +def del_fakequant_after_module(prepared_model: torch.fx.GraphModule, + module_patterns: Tuple, + inplace: bool = True) -> torch.fx.GraphModule: + """Delete useless fakequant after modules whose type is in + `module_patterns`. + + Args: + prepared_model (GraphModule): Prepared standalone module. 
+ module_patterns (tuple): Fakequants after modules whose type is in + `module_patterns` will be deleted. + inplace (bool): Can optionally do the operation in-place. + Defaults to True. + + Returns: + GraphModule: Prepared standalone module after deletion. + """ if not inplace: prepared_model = copy.deepcopy(prepared_model) new_graph = copy.deepcopy(prepared_model.graph) diff --git a/tests/test_models/test_task_modules/test_graph_utils.py b/tests/test_models/test_task_modules/test_graph_utils.py new file mode 100644 index 000000000..0bdae7cf3 --- /dev/null +++ b/tests/test_models/test_task_modules/test_graph_utils.py @@ -0,0 +1,307 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch +import torch.nn as nn +from torch.ao.quantization import QConfigMapping +from torch.ao.quantization.fake_quantize import FakeQuantizeBase +from torch.ao.quantization.fx import prepare +from torch.ao.quantization.quantize_fx import _fuse_fx + +from mmrazor.models.task_modules import build_graphmodule +from mmrazor.models.task_modules.tracer import CustomTracer +from mmrazor.models.task_modules.tracer.fx import (del_fakequant_after_module, + del_fakequant_after_target, + del_fakequant_before_module, + del_fakequant_before_target) +from mmrazor.models.utils import str2class +from mmrazor.structures.quantization import BackendConfigs, QConfigHander + + +def _get_attrs(target, attrs): + attrs = attrs.split('.') + + for att in attrs: + target = getattr(target, att, None) + return target + + +class BasicBlock(nn.Module): + + def __init__(self, in_channels, out_channels): + super(BasicBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.mid_channels = out_channels + + self.norm1 = nn.BatchNorm2d(self.mid_channels) + self.norm2 = nn.BatchNorm2d(out_channels) + self.conv1 = nn.Conv2d(in_channels, self.mid_channels, 1) + self.conv2 = nn.Conv2d(self.mid_channels, out_channels, 1) + + self.relu = nn.ReLU6() 
+ self.drop_path = nn.Identity() + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + out = self.drop_path(out) + + out += identity + + return out + + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class ToyModel(nn.Module): + + def __init__(self): + super().__init__() + self.stem_layer = nn.Sequential( + nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3), nn.ReLU()) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.block = BasicBlock(3, 3) + self.gap = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(3, 4) + + def forward(self, x): + x = self.stem_layer(x) + x = self.maxpool(x) + x = self.block(x) + x = self.gap(x) + x = x.flatten(1) + x = self.fc(x) + return x + + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict( + qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1), +) + + +class TestGraphUtils(TestCase): + + def setUp(self): + self.tracer = CustomTracer() + self.backend_config = BackendConfigs['native'] + self.qconfig = QConfigHander(global_qconfig) + self.qconfig_mapping = QConfigMapping().set_global( + self.qconfig.convert()) + self.example_inputs = (torch.randn(1, 3, 224, 224), ) + + def swap_ff_with_fxff(self, model): + modules_to_swap = [] + for name, module in model.named_children(): + if isinstance(module, torch.ao.nn.quantized.FloatFunctional): + modules_to_swap.append(name) + else: + self.swap_ff_with_fxff(module) + + for name in modules_to_swap: + del model._modules[name] + model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional() + + def 
test_del_fakequant_before_target(self): + + model_to_quantize = ToyModel() + model_to_quantize.eval() + + self.swap_ff_with_fxff(model_to_quantize) + traced_graph = self.tracer.trace(model_to_quantize) + graph_module = build_graphmodule(model_to_quantize, traced_graph) + + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, + backend_config=self.backend_config) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + backend_config=self.backend_config) + + target_prev = ('output', ) + + prepared_after_del = del_fakequant_before_target( + prepared, target_prev, inplace=False) + for node in prepared.graph.nodes: + if node.target == 'output': + args = node.args + self.assertIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + + for node in prepared_after_del.graph.nodes: + if node.target == 'output': + args = node.args + self.assertEqual(len(args), 1) + self.assertEqual(args[0].target, 'fc') + self.assertIsInstance( + _get_attrs(prepared, args[0].target), nn.Linear) + + prepared_after_del = del_fakequant_before_target( + prepared, target_prev, inplace=True) + for node in prepared_after_del.graph.nodes: + if node.target == 'output': + args = node.args + self.assertEqual(len(args), 1) + self.assertEqual(args[0].target, 'fc') + + def test_del_fakequant_after_target(self): + model_to_quantize = ToyModel() + model_to_quantize.eval() + + self.swap_ff_with_fxff(model_to_quantize) + traced_graph = self.tracer.trace(model_to_quantize) + graph_module = build_graphmodule(model_to_quantize, traced_graph) + + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, + backend_config=self.backend_config) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + 
backend_config=self.backend_config) + + target_next = ('flatten', ) + + prepared_after_del = del_fakequant_after_target( + prepared, target_next, inplace=False) + for node in prepared.graph.nodes: + if node.target == 'fc': + args = node.args + self.assertIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + + for node in prepared_after_del.graph.nodes: + if node.target == 'fc': + args = node.args + self.assertEqual(len(args), 1) + self.assertEqual(args[0].target, 'flatten') + + prepared_after_del = del_fakequant_after_target( + prepared, target_next, inplace=True) + for node in prepared_after_del.graph.nodes: + if node.target == 'fc': + args = node.args + self.assertEqual(len(args), 1) + self.assertEqual(args[0].target, 'flatten') + + def test_del_fakequant_before_module(self): + model_to_quantize = ToyModel() + model_to_quantize.eval() + + self.swap_ff_with_fxff(model_to_quantize) + traced_graph = self.tracer.trace(model_to_quantize) + graph_module = build_graphmodule(model_to_quantize, traced_graph) + + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, + backend_config=self.backend_config) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + backend_config=self.backend_config) + + module_prev = ('torch.nn.ReLU6', 'torch.nn.Identity') + + prepared_after_del = del_fakequant_before_module( + prepared, str2class(module_prev), inplace=False) + for node in prepared.graph.nodes: + if node.op == 'call_module' and isinstance( + _get_attrs(prepared, node.target), str2class(module_prev)): + args = node.args + self.assertIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + + for node in prepared_after_del.graph.nodes: + if node.op == 'call_module' and isinstance( + _get_attrs(prepared, node.target), str2class(module_prev)): + args = node.args + if args[0].op == 'call_module': + 
self.assertNotIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + + prepared_after_del = del_fakequant_before_module( + prepared, str2class(module_prev), inplace=True) + for node in prepared_after_del.graph.nodes: + if node.op == 'call_module' and isinstance( + _get_attrs(prepared, node.target), str2class(module_prev)): + args = node.args + if args[0].op == 'call_module': + self.assertNotIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + + def test_del_fakequant_after_module(self): + model_to_quantize = ToyModel() + model_to_quantize.eval() + + self.swap_ff_with_fxff(model_to_quantize) + traced_graph = self.tracer.trace(model_to_quantize) + graph_module = build_graphmodule(model_to_quantize, traced_graph) + + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, + backend_config=self.backend_config) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + backend_config=self.backend_config) + + module_next = ('torch.nn.MaxPool2d', ) + + prepared_after_del = del_fakequant_after_module( + prepared, str2class(module_next), inplace=False) + for node in prepared.graph.nodes: + if node.target == 'block_conv1': + args = node.args + self.assertIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + + for node in prepared_after_del.graph.nodes: + if node.target == 'block_conv1': + args = node.args + self.assertIsInstance( + _get_attrs(prepared, args[0].target), nn.MaxPool2d) + + prepared_after_del = del_fakequant_after_module( + prepared, str2class(module_next), inplace=True) + for node in prepared_after_del.graph.nodes: + if node.target == 'block_conv1': + args = node.args + self.assertIsInstance( + _get_attrs(prepared, args[0].target), nn.MaxPool2d)