diff --git a/programl/__init__.py b/programl/__init__.py index c049afb2..065068ab 100644 --- a/programl/__init__.py +++ b/programl/__init__.py @@ -52,7 +52,7 @@ to_bytes, to_string, ) -from programl.transform_ops import to_dgl, to_dot, to_json, to_networkx +from programl.transform_ops import to_dgl, to_dot, to_json, to_networkx, to_pyg from programl.util.py.runfiles_path import runfiles_path from programl.version import PROGRAML_VERSION @@ -84,6 +84,7 @@ "to_dot", "to_json", "to_networkx", + "to_pyg", "to_string", "UnsupportedCompiler", ] diff --git a/programl/requirements.txt b/programl/requirements.txt index db4d6937..3900bc9a 100644 --- a/programl/requirements.txt +++ b/programl/requirements.txt @@ -5,4 +5,5 @@ networkx>=2.4 numpy>=1.19.3 protobuf>=3.13.0,<4.21.0 torch>=1.8.0 +torch_geometric==2.4.0 tqdm>=4.38.0 diff --git a/programl/transform_ops.py b/programl/transform_ops.py index 6c76734b..bc9d86f7 100644 --- a/programl/transform_ops.py +++ b/programl/transform_ops.py @@ -18,11 +18,13 @@ """ import json import subprocess +import torch from typing import Any, Dict, Iterable, Optional, Union import dgl import networkx as nx from dgl.heterograph import DGLHeteroGraph +from torch_geometric.data import HeteroData from networkx.readwrite import json_graph as nx_json from programl.exceptions import GraphTransformError @@ -258,3 +260,122 @@ def _run_one(graph: ProgramGraph) -> str: if isinstance(graphs, ProgramGraph): return _run_one(graphs) return execute(_run_one, graphs, executor, chunksize) + + +def to_pyg( + graphs: Union[ProgramGraph, Iterable[ProgramGraph]], + timeout: int = 300, + vocabulary: Optional[Dict[str, int]] = None, + executor: Optional[ExecutorLike] = None, + chunksize: Optional[int] = None, +) -> Union[HeteroData, Iterable[HeteroData]]: + """Convert one or more Program Graphs to Pytorch-Geometrics's HeteroData. + This graphs can be used as input for any deep learning model built with + Pytorch-Geometric: + + https://pytorch-geometric.readthedocs.io/en/latest/tutorial/heterogeneous.html + + :param graphs: A Program Graph, or a sequence of Program Graphs. + + :param timeout: The maximum number of seconds to wait for an individual + graph conversion before raising an error. If multiple inputs are + provided, this timeout is per-input. + + :param vocabulary: A dictionary containing ProGraML's vocabulary, where the + keys are the text attribute of the nodes and the values their respective + indexes. + + :param executor: An executor object, with method :code:`submit(callable, + *args, **kwargs)` and returning a Future-like object with methods + :code:`done() -> bool` and :code:`result() -> float`. The executor role + is to dispatch the execution of the jobs locally/on a cluster/with + multithreading depending on the implementation. Eg: + :code:`concurrent.futures.ThreadPoolExecutor`. Defaults to single + threaded execution. This is only used when multiple inputs are given. + + :param chunksize: The number of inputs to read and process at a time. A + larger chunksize improves parallelism but increases memory consumption + as more inputs must be stored in memory. This is only used when multiple + inputs are given. + + :return: A HeteroData graph when a single input is provided, else an + iterable sequence of HeteroData graphs. + """ + + def _run_one(graph: ProgramGraph) -> HeteroData: + # 4 lists, one per edge type + # (control, data, call and type edges) + adjacencies = [[], [], [], []] + edge_positions = [[], [], [], []] + + # Create the adjacency lists and the positions + for edge in graph.edge: + adjacencies[edge.flow].append([edge.source, edge.target]) + edge_positions[edge.flow].append(edge.position) + + # Store the text attributes + node_text_list = [] + node_full_text_list = [] + + # Store the text and full text attributes + for node in graph.node: + node_text = node_full_text = node.text + + if ( + node.features + and node.features.feature["full_text"].bytes_list.value + ): + node_full_text = node.features.feature["full_text"].bytes_list.value[0] + + node_text_list.append(node_text) + node_full_text_list.append(node_full_text) + + + vocab_ids = None + if vocabulary is not None: + vocab_ids = [ + vocabulary.get(node.text, len(vocabulary.keys())) + for node in graph.node + ] + + # Pass from list to tensor + adjacencies = [torch.tensor(adj_flow_type) for adj_flow_type in adjacencies] + edge_positions = [torch.tensor(edge_pos_flow_type) for edge_pos_flow_type in edge_positions] + + if vocabulary is not None: + vocab_ids = torch.tensor(vocab_ids) + + # Create the graph structure + hetero_graph = HeteroData() + + # Vocabulary index of each node + hetero_graph['nodes']['text'] = node_text_list + hetero_graph['nodes']['full_text'] = node_full_text_list + hetero_graph['nodes'].x = vocab_ids + + # Add the adjacency lists + hetero_graph['nodes', 'control', 'nodes'].edge_index = ( + adjacencies[0].t().contiguous() if adjacencies[0].nelement() > 0 else torch.tensor([[], []]) + ) + hetero_graph['nodes', 'data', 'nodes'].edge_index = ( + adjacencies[1].t().contiguous() if adjacencies[1].nelement() > 0 else torch.tensor([[], []]) + ) + hetero_graph['nodes', 'call', 'nodes'].edge_index = ( + adjacencies[2].t().contiguous() if adjacencies[2].nelement() > 0 else torch.tensor([[], []]) + ) + hetero_graph['nodes', 'type', 'nodes'].edge_index = ( + adjacencies[3].t().contiguous() if adjacencies[3].nelement() > 0 else torch.tensor([[], []]) + ) + + # Add the edge positions + hetero_graph['nodes', 'control', 'nodes'].edge_attr = edge_positions[0] + hetero_graph['nodes', 'data', 'nodes'].edge_attr = edge_positions[1] + hetero_graph['nodes', 'call', 'nodes'].edge_attr = edge_positions[2] + hetero_graph['nodes', 'type', 'nodes'].edge_attr = edge_positions[3] + + return hetero_graph + + if isinstance(graphs, ProgramGraph): + return _run_one(graphs) + + return execute(_run_one, graphs, executor, chunksize) diff --git a/tests/BUILD b/tests/BUILD index 211ec444..e522a56d 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -134,3 +134,14 @@ py_test( "//tests/plugins", ], ) + +py_test( + name = "to_pyg_test", + srcs = ["to_pyg_test.py"], + shard_count = 8, + deps = [ + "//programl", + "//tests:test_main", + "//tests/plugins", + ], +) diff --git a/tests/to_pyg_test.py b/tests/to_pyg_test.py new file mode 100644 index 00000000..56cdfe57 --- /dev/null +++ b/tests/to_pyg_test.py @@ -0,0 +1,122 @@ +# Copyright 2019-2020 the ProGraML authors. +# +# Contact Chris Cummins . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from concurrent.futures.thread import ThreadPoolExecutor + +import networkx as nx +import pytest + +import programl as pg +from torch_geometric.data import HeteroData +from tests.test_main import main + +pytest_plugins = ["tests.plugins.llvm_program_graph"] + + +@pytest.fixture(scope="session") +def graph() -> pg.ProgramGraph: + return pg.from_cpp("int A() { return 0; }") + +@pytest.fixture(scope="session") +def graph2() -> pg.ProgramGraph: + return pg.from_cpp("int B() { return 1; }") + +@pytest.fixture(scope="session") +def graph3() -> pg.ProgramGraph: + return pg.from_cpp("int B(int x) { return x + 1; }") + +def graphs_are_equal( + graph1: HeteroData, + graph2: HeteroData, +): + return ( + (graph1['nodes']['full_text'] == graph2['nodes']['full_text']) + and (graph1['nodes', 'control', 'nodes'].edge_index.equal(graph2['nodes', 'control', 'nodes'].edge_index)) + and (graph1['nodes', 'data', 'nodes'].edge_index.equal(graph2['nodes', 'data', 'nodes'].edge_index)) + and (graph1['nodes', 'call', 'nodes'].edge_index.equal(graph2['nodes', 'call', 'nodes'].edge_index)) + and (graph1['nodes', 'type', 'nodes'].edge_index.equal(graph2['nodes', 'type', 'nodes'].edge_index)) + ) + +def test_to_pyg_simple_graph(graph: pg.ProgramGraph): + graphs = list(pg.to_pyg([graph])) + assert len(graphs) == 1 + assert isinstance(graphs[0], HeteroData) + +def test_to_pyg_simple_graph_single_input(graph: pg.ProgramGraph): + pyg_graph = pg.to_pyg(graph) + assert isinstance(pyg_graph, HeteroData) + +def test_to_pyg_different_two_different_inputs( + graph: pg.ProgramGraph, + graph2: pg.ProgramGraph, +): + pyg_graph = pg.to_pyg(graph) + pyg_graph2 = pg.to_pyg(graph2) + + # Ensure that the graphs are different + assert not graphs_are_equal(pyg_graph, pyg_graph2) + +def test_to_pyg_different_inputs( + graph: pg.ProgramGraph, + graph2: pg.ProgramGraph, + graph3: pg.ProgramGraph +): + pyg_graph = pg.to_pyg(graph) + pyg_graph2 = pg.to_pyg(graph2) + pyg_graph3 = pg.to_pyg(graph3) + + # Ensure that the graphs are different + assert not graphs_are_equal(pyg_graph, pyg_graph2) + assert not graphs_are_equal(pyg_graph, pyg_graph3) + assert not graphs_are_equal(pyg_graph2, pyg_graph3) + +def test_to_pyg_two_inputs(graph: pg.ProgramGraph): + graphs = list(pg.to_pyg([graph, graph])) + assert len(graphs) == 2 + assert graphs_are_equal(graphs[0], graphs[1]) + +def test_to_pyg_generator(graph: pg.ProgramGraph): + graphs = list(pg.to_pyg((graph for _ in range(10)), chunksize=3)) + assert len(graphs) == 10 + for x in graphs[1:]: + assert graphs_are_equal(graphs[0], x) + +def test_to_pyg_generator_parallel_executor(graph: pg.ProgramGraph): + with ThreadPoolExecutor() as executor: + graphs = list( + pg.to_pyg((graph for _ in range(10)), chunksize=3, executor=executor) + ) + assert len(graphs) == 10 + for x in graphs[1:]: + assert graphs_are_equal(graphs[0], x) + +def test_to_pyg_smoke_test(llvm_program_graph: pg.ProgramGraph): + graphs = list(pg.to_pyg([llvm_program_graph])) + + num_nodes = len(graphs[0]['nodes']['text']) + num_control_edges = graphs[0]['nodes', 'control', 'nodes'].edge_index.size(1) + num_data_edges = graphs[0]['nodes', 'data', 'nodes'].edge_index.size(1) + num_call_edges = graphs[0]['nodes', 'call', 'nodes'].edge_index.size(1) + num_type_edges = graphs[0]['nodes', 'type', 'nodes'].edge_index.size(1) + num_edges = num_control_edges + num_data_edges + num_call_edges + num_type_edges + + assert len(graphs) == 1 + assert isinstance(graphs[0], HeteroData) + assert num_nodes == len(llvm_program_graph.node) + assert num_edges <= len(llvm_program_graph.edge) + + +if __name__ == "__main__": + main() \ No newline at end of file