Skip to content

Commit

Permalink
add graph utils pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
HIT-cwh committed Dec 23, 2022
1 parent 297a2ab commit 5707d07
Show file tree
Hide file tree
Showing 3 changed files with 392 additions and 13 deletions.
1 change: 0 additions & 1 deletion mmrazor/models/quantizers/openvino_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def prepare(self, model, graph_module):
if target_next:
prepared = del_fakequant_after_target(
prepared, target_next, inplace=True)
print(prepared)

return prepared

Expand Down
97 changes: 85 additions & 12 deletions mmrazor/models/task_modules/tracer/fx/graph_utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,45 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Any, List, Tuple

import torch.fx
from torch.ao.quantization.fake_quantize import FakeQuantizeBase


def _get_attrs(target, attrs):
def _get_attrs(target: torch.nn.Module, attr: str) -> Any:
"""Get the attribute from target.
attrs = attrs.split('.')
Args:
target (torch.nn.Module): Get the attribute from target module.
attr (str): The target attribute.
Returns:
Any: The target attribute.
"""

attrs: List[str] = attr.split('.')

for att in attrs:
target = getattr(target, att, None)
return target


def del_fakequant_before_target(prepared_model, target_patterns, inplace=True):
def del_fakequant_before_target(prepared_model: torch.fx.GraphModule,
target_patterns: Tuple,
inplace: bool = True) -> torch.fx.GraphModule:
"""Delete useless fakequant before nodes whose target attribute
(node.target) is in `target_patterns`.
Args:
prepared_model (GraphModule): Prepared standalone module.
target_patterns (tuple): Fakequants before nodes whose target attribute
(node.target) is in `target_patterns` will be deleted.
inplace (bool): Can optionally do the operation in-place. Defaults to
True.
Returns:
GraphModule: Prepared standalone module after deletion.
"""

def recursive_find_erased_nodes(node):
"""Find FakeQuant before target node recursively.
Expand All @@ -40,22 +66,24 @@ def recursive_find_erased_nodes(node):
"""
if node is None:
return
if isinstance(
if node.op == 'call_module' and isinstance(
_get_attrs(prepared_model, node.target), FakeQuantizeBase):
nodes_to_erase.append(node)
return
for prev_node in node.args:
recursive_find_erased_nodes(prev_node)
if isinstance(prev_node, torch.fx.Node):
recursive_find_erased_nodes(prev_node)
for prev_node in node.kwargs.values():
recursive_find_erased_nodes(prev_node)
if isinstance(prev_node, torch.fx.Node):
recursive_find_erased_nodes(prev_node)
return

if not inplace:
prepared_model = copy.deepcopy(prepared_model)
new_graph = copy.deepcopy(prepared_model.graph)
for node in new_graph.nodes:
if node.target in target_patterns:
nodes_to_erase = []
nodes_to_erase: List[torch.fx.Node] = []
recursive_find_erased_nodes(node)
for to_erase in nodes_to_erase:
to_erase.replace_all_uses_with(to_erase.args[0])
Expand All @@ -66,7 +94,22 @@ def recursive_find_erased_nodes(node):
return prepared_model


def del_fakequant_after_target(prepared_model, target_patterns, inplace=True):
def del_fakequant_after_target(prepared_model: torch.fx.GraphModule,
target_patterns: Tuple,
inplace: bool = True) -> torch.fx.GraphModule:
"""Delete useless fakequant after nodes whose target attribute
(node.target) is in `target_patterns`.
Args:
prepared_model (GraphModule): Prepared standalone module.
target_patterns (tuple): Fakequants after nodes whose target attribute
(node.target) is in `target_patterns` will be deleted.
inplace (bool): Can optionally do the operation in-place. Defaults to
True.
Returns:
GraphModule: Prepared standalone module after deletion.
"""
if not inplace:
prepared_model = copy.deepcopy(prepared_model)
new_graph = copy.deepcopy(prepared_model.graph)
Expand All @@ -91,17 +134,32 @@ def del_fakequant_after_target(prepared_model, target_patterns, inplace=True):
return prepared_model


def del_fakequant_before_module(prepared_model, module_patterns, inplace=True):
def del_fakequant_before_module(prepared_model: torch.fx.GraphModule,
module_patterns: Tuple,
inplace: bool = True) -> torch.fx.GraphModule:
"""Delete useless fakequant before modules whose type are in
`module_patterns`.
Args:
prepared_model (GraphModule): Prepared standalone module.
target_patterns (tuple): Fakequants before modules whose type is in
`module_patterns` will be deleted.
inplace (bool): Can optionally do the operation in-place.
Defaults to True.
Returns:
GraphModule: Prepared standalone module after deletion.
"""
if not inplace:
prepared_model = copy.deepcopy(prepared_model)
new_graph = copy.deepcopy(prepared_model.graph)
for node in new_graph.nodes:
if node.op == 'call_module' and isinstance(
_get_attrs(prepared_model, node.target), module_patterns):
to_erase = node.args[0]
if not isinstance(
if not (to_erase.op == 'call_module' and isinstance(
_get_attrs(prepared_model, to_erase.target),
FakeQuantizeBase):
FakeQuantizeBase)):
continue
if len(to_erase.users) > 1:
continue
Expand All @@ -113,7 +171,22 @@ def del_fakequant_before_module(prepared_model, module_patterns, inplace=True):
return prepared_model


def del_fakequant_after_module(prepared_model, module_patterns, inplace=True):
def del_fakequant_after_module(prepared_model: torch.fx.GraphModule,
module_patterns: Tuple,
inplace: bool = True) -> torch.fx.GraphModule:
"""Delete useless fakequant after modules whose type are in
`module_patterns`.
Args:
prepared_model (GraphModule): Prepared standalone module.
target_patterns (tuple): Fakequants after modules whose type is in
`module_patterns` will be deleted.
inplace (bool): Can optionally do the operation in-place.
Defaults to True.
Returns:
GraphModule: Prepared standalone module after deletion.
"""
if not inplace:
prepared_model = copy.deepcopy(prepared_model)
new_graph = copy.deepcopy(prepared_model.graph)
Expand Down
Loading

0 comments on commit 5707d07

Please sign in to comment.