diff --git a/test/fx/test_subgraph_rewriter.py b/test/fx/test_subgraph_rewriter.py index ac3498458d600..ed6d50e44b4ac 100644 --- a/test/fx/test_subgraph_rewriter.py +++ b/test/fx/test_subgraph_rewriter.py @@ -773,7 +773,7 @@ def gemm_bias_mul_replacement_with_c(a, b, bias, c): self.assertEqual(repalcement_node_found, 2) - def test_replace_pattern_with_filter(self): + def test_replace_pattern_with_filters(self): class M(torch.nn.Module): def __init__(self): super().__init__() @@ -833,10 +833,10 @@ def num_repalcement_node_found(traced): # match with filter, should find 1 match traced = symbolic_trace(M()) - matches = subgraph_rewriter.replace_pattern_with_filter( + matches = subgraph_rewriter.replace_pattern_with_filters( traced, BinaryOpScalarReLUPattern, BinaryOpScalarReLUReplacement, - second_input_is_scalar) + [second_input_is_scalar]) self.assertEqual(len(matches), 1) self.assertEqual(num_repalcement_node_found(traced), 1) diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index 09e5550c5930d..72bb7fd373516 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -8,7 +8,7 @@ from typing import Callable, Dict, List, NamedTuple, Optional, Set import torch -__all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filter'] +__all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filters'] @compatibility(is_backward_compatible=True) class Match(NamedTuple): @@ -185,11 +185,11 @@ def forward(self, x, w1, w2): # Experimental API, not backward compatible @compatibility(is_backward_compatible=False) -def replace_pattern_with_filter( +def replace_pattern_with_filters( gm: GraphModule, pattern: Callable, replacement: Callable, - match_filter: Callable[["InternalMatch", Graph, Graph], bool], # type: ignore[name-defined] + match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]], # type: ignore[name-defined] ) -> List[Match]: """ See replace_pattern for documentation. This function is an overload with an additional match_filter argument. @@ -200,18 +200,21 @@ def replace_pattern_with_filter( definition of InternalMatch. """ - return _replace_pattern(gm, pattern, replacement, match_filter) + return _replace_pattern(gm, pattern, replacement, match_filters) def _replace_pattern( gm: GraphModule, pattern: Callable, replacement: Callable, - match_filter: Optional[Callable[["InternalMatch", Graph, Graph], bool]] = None # type: ignore[name-defined] + match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]] = None # type: ignore[name-defined] ) -> List[Match]: from torch.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch + if match_filters is None: + match_filters = [] + # Get the graphs for `gm`, `pattern`, `replacement` original_graph: Graph = gm.graph pattern_graph: Graph = symbolic_trace(pattern).graph @@ -222,8 +225,11 @@ def _replace_pattern( _matches: List[InternalMatch] = matcher.match(original_graph) # Filter out matches that don't match the filter - if match_filter: - _matches = [m for m in _matches if match_filter(m, original_graph, pattern_graph)] + _matches = [ + m for m in _matches + if all(match_filter(m, original_graph, pattern_graph) + for match_filter in match_filters) + ] replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"]