|
7 | 7 | # pyre-unsafe
|
8 | 8 |
|
9 | 9 |
|
10 |
| -from typing import Any, Dict, Tuple |
| 10 | +from typing import Any, Dict, List, Tuple |
11 | 11 |
|
12 | 12 | import torch
|
13 | 13 | from torch.fx import GraphModule
|
14 | 14 | from torch.fx.interpreter import Interpreter
|
15 | 15 |
|
16 | 16 |
|
| 17 | +class NodeFilter: |
| 18 | + """ |
| 19 | + A class used to filter nodes based on extensible criteria. |
| 20 | + Attributes: |
| 21 | + metadata_key (str): The key to look for in the node's metadata. |
| 22 | + op_type (str): The operation code to match. |
| 23 | + exclude_ops (List[str]): A list of operations to exclude from the filter. |
| 24 | + """ |
| 25 | + |
| 26 | + def __init__(self, metadata_key: str, op_type: str, exclude_ops: List[str] = None): |
| 27 | + self.metadata_key = metadata_key |
| 28 | + self.op_type = op_type |
| 29 | + self.exclude_ops = exclude_ops |
| 30 | + |
| 31 | + def matches(self, node: torch.fx.Node) -> bool: |
| 32 | + return ( |
| 33 | + node.meta.get(self.metadata_key) is not None |
| 34 | + and node.op == self.op_type |
| 35 | + and all(exclude_name not in node.name for exclude_name in self.exclude_ops) |
| 36 | + ) |
| 37 | + |
| 38 | + |
17 | 39 | class IntermediateOutputCapturer(Interpreter):
|
| 40 | + """ |
| 41 | + A class that captures intermediate outputs from a PyTorch graph module. |
| 42 | + Attributes: |
| 43 | + module (GraphModule): The graph module to capture outputs from. |
| 44 | + node_filters (List[NodeFilter]): A list of filters to apply to the nodes. |
| 45 | + """ |
| 46 | + |
18 | 47 | def __init__(self, module: GraphModule):
|
19 | 48 | super().__init__(module)
|
| 49 | + self.node_filters = [ |
| 50 | + NodeFilter("debug_handle", "call_function", exclude_ops=["getitem"]) |
| 51 | + ] |
20 | 52 |
|
| 53 | + # Runs the graph module and captures the intermediate outputs. |
21 | 54 | def run_and_capture(self, *args, **kwargs) -> Dict[Tuple[int, ...], Any]:
|
22 | 55 | captured_outputs = {}
|
23 | 56 |
|
24 | 57 | def capture_run_node(n: torch.fx.Node) -> Any:
|
25 | 58 | result = super(IntermediateOutputCapturer, self).run_node(n)
|
26 |
| - debug_handle = n.meta.get("debug_handle", None) |
27 |
| - if debug_handle is not None and n.op == "call_function": |
| 59 | + if all(filter.matches(n) for filter in self.node_filters): |
| 60 | + debug_handle = n.meta["debug_handle"] |
28 | 61 | # Convert the debug handle to a tuple to use as a dictionary key
|
29 | 62 | key = (
|
30 | 63 | (debug_handle,)
|
|
0 commit comments