Skip to content

Commit

Permalink
[fx][subgraph_rewriter] Change match_filter to be a List in replace_p…
Browse files Browse the repository at this point in the history
…attern_with_filters (pytorch#87257)

Summary:
att, this is experimental api so not marking it as bc-breaking.
The match will be accepted only if all the filters in the list passes.
Changing the filter arg to be list also allows us to pass in empty list that means no filter, which makes user code cleaner.

Test Plan:
python test/test_fx.py -k test_replace_pattern_with_filters

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: pytorch#87257
Approved by: https://github.com/SherlockNoMad
  • Loading branch information
jerryzh168 authored and pytorchmergebot committed Oct 27, 2022
1 parent 195a13f commit 5865083
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
6 changes: 3 additions & 3 deletions test/fx/test_subgraph_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,7 @@ def gemm_bias_mul_replacement_with_c(a, b, bias, c):

self.assertEqual(repalcement_node_found, 2)

def test_replace_pattern_with_filter(self):
def test_replace_pattern_with_filters(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -833,10 +833,10 @@ def num_repalcement_node_found(traced):

# match with filter, should find 1 match
traced = symbolic_trace(M())
matches = subgraph_rewriter.replace_pattern_with_filter(
matches = subgraph_rewriter.replace_pattern_with_filters(
traced,
BinaryOpScalarReLUPattern,
BinaryOpScalarReLUReplacement,
second_input_is_scalar)
[second_input_is_scalar])
self.assertEqual(len(matches), 1)
self.assertEqual(num_repalcement_node_found(traced), 1)
20 changes: 13 additions & 7 deletions torch/fx/subgraph_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Callable, Dict, List, NamedTuple, Optional, Set
import torch

__all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filter']
__all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filters']

@compatibility(is_backward_compatible=True)
class Match(NamedTuple):
Expand Down Expand Up @@ -185,11 +185,11 @@ def forward(self, x, w1, w2):

# Experimental API, not backward compatible
@compatibility(is_backward_compatible=False)
def replace_pattern_with_filter(
def replace_pattern_with_filters(
gm: GraphModule,
pattern: Callable,
replacement: Callable,
match_filter: Callable[["InternalMatch", Graph, Graph], bool], # type: ignore[name-defined]
match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]], # type: ignore[name-defined]
) -> List[Match]:
"""
See replace_pattern for documentation. This function is an overload with an additional match_filter argument.
Expand All @@ -200,18 +200,21 @@ def replace_pattern_with_filter(
definition of InternalMatch.
"""

return _replace_pattern(gm, pattern, replacement, match_filter)
return _replace_pattern(gm, pattern, replacement, match_filters)


def _replace_pattern(
gm: GraphModule,
pattern: Callable,
replacement: Callable,
match_filter: Optional[Callable[["InternalMatch", Graph, Graph], bool]] = None # type: ignore[name-defined]
match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]] = None # type: ignore[name-defined]
) -> List[Match]:

from torch.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch

if match_filters is None:
match_filters = []

# Get the graphs for `gm`, `pattern`, `replacement`
original_graph: Graph = gm.graph
pattern_graph: Graph = symbolic_trace(pattern).graph
Expand All @@ -222,8 +225,11 @@ def _replace_pattern(
_matches: List[InternalMatch] = matcher.match(original_graph)

# Filter out matches that don't match the filter
if match_filter:
_matches = [m for m in _matches if match_filter(m, original_graph, pattern_graph)]
_matches = [
m for m in _matches
if all(match_filter(m, original_graph, pattern_graph)
for match_filter in match_filters)
]

replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"]

Expand Down

0 comments on commit 5865083

Please sign in to comment.