Skip to content

Commit 2642e47

Browse files
authored
Make the intermediate output capturer more extensible
Differential Revision: D75699297 Pull Request resolved: #11289
1 parent 1a11267 commit 2642e47

File tree

2 files changed

+36
-5
lines changed

2 files changed

+36
-5
lines changed

devtools/inspector/_intermediate_output_capturer.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,57 @@
77
# pyre-unsafe
88

99

10-
from typing import Any, Dict, Tuple
10+
from typing import Any, Dict, List, Tuple
1111

1212
import torch
1313
from torch.fx import GraphModule
1414
from torch.fx.interpreter import Interpreter
1515

1616

17+
class NodeFilter:
18+
"""
19+
A class used to filter nodes based on extensible criteria.
20+
Attributes:
21+
metadata_key (str): The key to look for in the node's metadata.
22+
op_type (str): The operation code to match.
23+
exclude_ops (List[str]): A list of operations to exclude from the filter.
24+
"""
25+
26+
def __init__(self, metadata_key: str, op_type: str, exclude_ops: List[str] = None):
27+
self.metadata_key = metadata_key
28+
self.op_type = op_type
29+
self.exclude_ops = exclude_ops
30+
31+
def matches(self, node: torch.fx.Node) -> bool:
32+
return (
33+
node.meta.get(self.metadata_key) is not None
34+
and node.op == self.op_type
35+
and all(exclude_name not in node.name for exclude_name in self.exclude_ops)
36+
)
37+
38+
1739
class IntermediateOutputCapturer(Interpreter):
40+
"""
41+
A class that captures intermediate outputs from a PyTorch graph module.
42+
Attributes:
43+
module (GraphModule): The graph module to capture outputs from.
44+
node_filters (List[NodeFilter]): A list of filters to apply to the nodes.
45+
"""
46+
1847
def __init__(self, module: GraphModule):
1948
super().__init__(module)
49+
self.node_filters = [
50+
NodeFilter("debug_handle", "call_function", exclude_ops=["getitem"])
51+
]
2052

53+
# Runs the graph module and captures the intermediate outputs.
2154
def run_and_capture(self, *args, **kwargs) -> Dict[Tuple[int, ...], Any]:
2255
captured_outputs = {}
2356

2457
def capture_run_node(n: torch.fx.Node) -> Any:
2558
result = super(IntermediateOutputCapturer, self).run_node(n)
26-
debug_handle = n.meta.get("debug_handle", None)
27-
if debug_handle is not None and n.op == "call_function":
59+
if all(filter.matches(n) for filter in self.node_filters):
60+
debug_handle = n.meta["debug_handle"]
2861
# Convert the debug handle to a tuple to use as a dictionary key
2962
key = (
3063
(debug_handle,)

devtools/inspector/tests/intermediate_output_capturer_test.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,6 @@ def test_capture_correct_outputs(self):
111111
(19,): torch.tensor([[3.6000, 4.5067]]),
112112
(20,): torch.tensor([[0.9734, 0.9891]]),
113113
(21,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
114-
(22,): torch.tensor([[0.9734]]),
115-
(23,): torch.tensor([[0.9891]]),
116114
}
117115
self.assertEqual(
118116
len(self.intermediate_outputs), len(expected_outputs_with_handles)

0 commit comments

Comments
 (0)