Skip to content

Commit 5865083

Browse files
jerryzh168pytorchmergebot
authored andcommitted
[fx][subgraph_rewriter] Change match_filter to be a List in replace_pattern_with_filters (pytorch#87257)
Summary: att, this is experimental api so not marking it as bc-breaking. The match will be accepted only if all the filters in the list passes. Changing the filter arg to be list also allows us to pass in empty list that means no filter, which makes user code cleaner. Test Plan: python test/test_fx.py -k test_replace_pattern_with_filters Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: pytorch#87257 Approved by: https://github.com/SherlockNoMad
1 parent 195a13f commit 5865083

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

test/fx/test_subgraph_rewriter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -773,7 +773,7 @@ def gemm_bias_mul_replacement_with_c(a, b, bias, c):
773773

774774
self.assertEqual(repalcement_node_found, 2)
775775

776-
def test_replace_pattern_with_filter(self):
776+
def test_replace_pattern_with_filters(self):
777777
class M(torch.nn.Module):
778778
def __init__(self):
779779
super().__init__()
@@ -833,10 +833,10 @@ def num_repalcement_node_found(traced):
833833

834834
# match with filter, should find 1 match
835835
traced = symbolic_trace(M())
836-
matches = subgraph_rewriter.replace_pattern_with_filter(
836+
matches = subgraph_rewriter.replace_pattern_with_filters(
837837
traced,
838838
BinaryOpScalarReLUPattern,
839839
BinaryOpScalarReLUReplacement,
840-
second_input_is_scalar)
840+
[second_input_is_scalar])
841841
self.assertEqual(len(matches), 1)
842842
self.assertEqual(num_repalcement_node_found(traced), 1)

torch/fx/subgraph_rewriter.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing import Callable, Dict, List, NamedTuple, Optional, Set
99
import torch
1010

11-
__all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filter']
11+
__all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filters']
1212

1313
@compatibility(is_backward_compatible=True)
1414
class Match(NamedTuple):
@@ -185,11 +185,11 @@ def forward(self, x, w1, w2):
185185

186186
# Experimental API, not backward compatible
187187
@compatibility(is_backward_compatible=False)
188-
def replace_pattern_with_filter(
188+
def replace_pattern_with_filters(
189189
gm: GraphModule,
190190
pattern: Callable,
191191
replacement: Callable,
192-
match_filter: Callable[["InternalMatch", Graph, Graph], bool], # type: ignore[name-defined]
192+
match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]], # type: ignore[name-defined]
193193
) -> List[Match]:
194194
"""
195195
See replace_pattern for documentation. This function is an overload with an additional match_filter argument.
@@ -200,18 +200,21 @@ def replace_pattern_with_filter(
200200
definition of InternalMatch.
201201
"""
202202

203-
return _replace_pattern(gm, pattern, replacement, match_filter)
203+
return _replace_pattern(gm, pattern, replacement, match_filters)
204204

205205

206206
def _replace_pattern(
207207
gm: GraphModule,
208208
pattern: Callable,
209209
replacement: Callable,
210-
match_filter: Optional[Callable[["InternalMatch", Graph, Graph], bool]] = None # type: ignore[name-defined]
210+
match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]] = None # type: ignore[name-defined]
211211
) -> List[Match]:
212212

213213
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch
214214

215+
if match_filters is None:
216+
match_filters = []
217+
215218
# Get the graphs for `gm`, `pattern`, `replacement`
216219
original_graph: Graph = gm.graph
217220
pattern_graph: Graph = symbolic_trace(pattern).graph
@@ -222,8 +225,11 @@ def _replace_pattern(
222225
_matches: List[InternalMatch] = matcher.match(original_graph)
223226

224227
# Filter out matches that don't match the filter
225-
if match_filter:
226-
_matches = [m for m in _matches if match_filter(m, original_graph, pattern_graph)]
228+
_matches = [
229+
m for m in _matches
230+
if all(match_filter(m, original_graph, pattern_graph)
231+
for match_filter in match_filters)
232+
]
227233

228234
replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"]
229235

0 commit comments

Comments
 (0)