From 58650835bb91d927623e6bff5cc4844fbcad6368 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 26 Oct 2022 14:43:42 -0700 Subject: [PATCH] [fx][subgraph_rewriter] Change match_filter to be a List in replace_pattern_with_filters (#87257) Summary: att, this is experimental api so not marking it as bc-breaking. The match will be accepted only if all the filters in the list passes. Changing the filter arg to be list also allows us to pass in empty list that means no filter, which makes user code cleaner. Test Plan: python test/test_fx.py -k test_replace_pattern_with_filters Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/87257 Approved by: https://github.com/SherlockNoMad --- test/fx/test_subgraph_rewriter.py | 6 +++--- torch/fx/subgraph_rewriter.py | 20 +++++++++++++------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/test/fx/test_subgraph_rewriter.py b/test/fx/test_subgraph_rewriter.py index ac3498458d6007..ed6d50e44b4ac9 100644 --- a/test/fx/test_subgraph_rewriter.py +++ b/test/fx/test_subgraph_rewriter.py @@ -773,7 +773,7 @@ def gemm_bias_mul_replacement_with_c(a, b, bias, c): self.assertEqual(repalcement_node_found, 2) - def test_replace_pattern_with_filter(self): + def test_replace_pattern_with_filters(self): class M(torch.nn.Module): def __init__(self): super().__init__() @@ -833,10 +833,10 @@ def num_repalcement_node_found(traced): # match with filter, should find 1 match traced = symbolic_trace(M()) - matches = subgraph_rewriter.replace_pattern_with_filter( + matches = subgraph_rewriter.replace_pattern_with_filters( traced, BinaryOpScalarReLUPattern, BinaryOpScalarReLUReplacement, - second_input_is_scalar) + [second_input_is_scalar]) self.assertEqual(len(matches), 1) self.assertEqual(num_repalcement_node_found(traced), 1) diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index 09e5550c5930dd..72bb7fd3735162 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -8,7 +8,7 @@ from typing import Callable, Dict, List, NamedTuple, Optional, Set import torch -__all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filter'] +__all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filters'] @compatibility(is_backward_compatible=True) class Match(NamedTuple): @@ -185,11 +185,11 @@ def forward(self, x, w1, w2): # Experimental API, not backward compatible @compatibility(is_backward_compatible=False) -def replace_pattern_with_filter( +def replace_pattern_with_filters( gm: GraphModule, pattern: Callable, replacement: Callable, - match_filter: Callable[["InternalMatch", Graph, Graph], bool], # type: ignore[name-defined] + match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]], # type: ignore[name-defined] ) -> List[Match]: """ See replace_pattern for documentation. This function is an overload with an additional match_filter argument. @@ -200,18 +200,21 @@ def replace_pattern_with_filter( definition of InternalMatch. """ - return _replace_pattern(gm, pattern, replacement, match_filter) + return _replace_pattern(gm, pattern, replacement, match_filters) def _replace_pattern( gm: GraphModule, pattern: Callable, replacement: Callable, - match_filter: Optional[Callable[["InternalMatch", Graph, Graph], bool]] = None # type: ignore[name-defined] + match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]] = None # type: ignore[name-defined] ) -> List[Match]: from torch.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch + if match_filters is None: + match_filters = [] + # Get the graphs for `gm`, `pattern`, `replacement` original_graph: Graph = gm.graph pattern_graph: Graph = symbolic_trace(pattern).graph @@ -222,8 +225,11 @@ def _replace_pattern( _matches: List[InternalMatch] = matcher.match(original_graph) # Filter out matches that don't match the filter - if match_filter: - _matches = [m for m in _matches if match_filter(m, original_graph, pattern_graph)] + _matches = [ + m for m in _matches + if all(match_filter(m, original_graph, pattern_graph) + for match_filter in match_filters) + ] replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"]