Skip to content

Commit ac378bc

Browse files
authored
Merge pull request #216 from igabirondo16/feature/pytorch_geometric
Feature/pytorch geometric
2 parents d5df3cc + a066ffb commit ac378bc

File tree

5 files changed

+257
-1
lines changed

5 files changed

+257
-1
lines changed

programl/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
to_bytes,
5353
to_string,
5454
)
55-
from programl.transform_ops import to_dgl, to_dot, to_json, to_networkx
55+
from programl.transform_ops import to_dgl, to_dot, to_json, to_networkx, to_pyg
5656
from programl.util.py.runfiles_path import runfiles_path
5757
from programl.version import PROGRAML_VERSION
5858

@@ -84,6 +84,7 @@
8484
"to_dot",
8585
"to_json",
8686
"to_networkx",
87+
"to_pyg",
8788
"to_string",
8889
"UnsupportedCompiler",
8990
]

programl/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ networkx>=2.4
55
numpy>=1.19.3
66
protobuf>=3.13.0,<4.21.0
77
torch>=1.8.0
8+
torch_geometric==2.4.0
89
tqdm>=4.38.0

programl/transform_ops.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@
1818
"""
1919
import json
2020
import subprocess
21+
import torch
2122
from typing import Any, Dict, Iterable, Optional, Union
2223

2324
import dgl
2425
import networkx as nx
2526
from dgl.heterograph import DGLHeteroGraph
27+
from torch_geometric.data import HeteroData
2628
from networkx.readwrite import json_graph as nx_json
2729

2830
from programl.exceptions import GraphTransformError
@@ -258,3 +260,122 @@ def _run_one(graph: ProgramGraph) -> str:
258260
if isinstance(graphs, ProgramGraph):
259261
return _run_one(graphs)
260262
return execute(_run_one, graphs, executor, chunksize)
263+
264+
265+
def to_pyg(
266+
graphs: Union[ProgramGraph, Iterable[ProgramGraph]],
267+
timeout: int = 300,
268+
vocabulary: Optional[Dict[str, int]] = None,
269+
executor: Optional[ExecutorLike] = None,
270+
chunksize: Optional[int] = None,
271+
) -> Union[HeteroData, Iterable[HeteroData]]:
272+
"""Convert one or more Program Graphs to Pytorch-Geometrics's HeteroData.
273+
This graphs can be used as input for any deep learning model built with
274+
Pytorch-Geometric:
275+
276+
https://pytorch-geometric.readthedocs.io/en/latest/tutorial/heterogeneous.html
277+
278+
:param graphs: A Program Graph, or a sequence of Program Graphs.
279+
280+
:param timeout: The maximum number of seconds to wait for an individual
281+
graph conversion before raising an error. If multiple inputs are
282+
provided, this timeout is per-input.
283+
284+
:param vocabulary: A dictionary containing ProGraML's vocabulary, where the
285+
keys are the text attribute of the nodes and the values their respective
286+
indexes.
287+
288+
:param executor: An executor object, with method :code:`submit(callable,
289+
*args, **kwargs)` and returning a Future-like object with methods
290+
:code:`done() -> bool` and :code:`result() -> float`. The executor role
291+
is to dispatch the execution of the jobs locally/on a cluster/with
292+
multithreading depending on the implementation. Eg:
293+
:code:`concurrent.futures.ThreadPoolExecutor`. Defaults to single
294+
threaded execution. This is only used when multiple inputs are given.
295+
296+
:param chunksize: The number of inputs to read and process at a time. A
297+
larger chunksize improves parallelism but increases memory consumption
298+
as more inputs must be stored in memory. This is only used when multiple
299+
inputs are given.
300+
301+
:return: A HeteroData graph when a single input is provided, else an
302+
iterable sequence of HeteroData graphs.
303+
"""
304+
305+
def _run_one(graph: ProgramGraph) -> HeteroData:
306+
# 4 lists, one per edge type
307+
# (control, data, call and type edges)
308+
adjacencies = [[], [], [], []]
309+
edge_positions = [[], [], [], []]
310+
311+
# Create the adjacency lists and the positions
312+
for edge in graph.edge:
313+
adjacencies[edge.flow].append([edge.source, edge.target])
314+
edge_positions[edge.flow].append(edge.position)
315+
316+
# Store the text attributes
317+
node_text_list = []
318+
node_full_text_list = []
319+
320+
# Store the text and full text attributes
321+
for node in graph.node:
322+
node_text = node_full_text = node.text
323+
324+
if (
325+
node.features
326+
and node.features.feature["full_text"].bytes_list.value
327+
):
328+
node_full_text = node.features.feature["full_text"].bytes_list.value[0]
329+
330+
node_text_list.append(node_text)
331+
node_full_text_list.append(node_full_text)
332+
333+
334+
vocab_ids = None
335+
if vocabulary is not None:
336+
vocab_ids = [
337+
vocabulary.get(node.text, len(vocabulary.keys()))
338+
for node in graph.node
339+
]
340+
341+
# Pass from list to tensor
342+
adjacencies = [torch.tensor(adj_flow_type) for adj_flow_type in adjacencies]
343+
edge_positions = [torch.tensor(edge_pos_flow_type) for edge_pos_flow_type in edge_positions]
344+
345+
if vocabulary is not None:
346+
vocab_ids = torch.tensor(vocab_ids)
347+
348+
# Create the graph structure
349+
hetero_graph = HeteroData()
350+
351+
# Vocabulary index of each node
352+
hetero_graph['nodes']['text'] = node_text_list
353+
hetero_graph['nodes']['full_text'] = node_full_text_list
354+
hetero_graph['nodes'].x = vocab_ids
355+
356+
# Add the adjacency lists
357+
hetero_graph['nodes', 'control', 'nodes'].edge_index = (
358+
adjacencies[0].t().contiguous() if adjacencies[0].nelement() > 0 else torch.tensor([[], []])
359+
)
360+
hetero_graph['nodes', 'data', 'nodes'].edge_index = (
361+
adjacencies[1].t().contiguous() if adjacencies[1].nelement() > 0 else torch.tensor([[], []])
362+
)
363+
hetero_graph['nodes', 'call', 'nodes'].edge_index = (
364+
adjacencies[2].t().contiguous() if adjacencies[2].nelement() > 0 else torch.tensor([[], []])
365+
)
366+
hetero_graph['nodes', 'type', 'nodes'].edge_index = (
367+
adjacencies[3].t().contiguous() if adjacencies[3].nelement() > 0 else torch.tensor([[], []])
368+
)
369+
370+
# Add the edge positions
371+
hetero_graph['nodes', 'control', 'nodes'].edge_attr = edge_positions[0]
372+
hetero_graph['nodes', 'data', 'nodes'].edge_attr = edge_positions[1]
373+
hetero_graph['nodes', 'call', 'nodes'].edge_attr = edge_positions[2]
374+
hetero_graph['nodes', 'type', 'nodes'].edge_attr = edge_positions[3]
375+
376+
return hetero_graph
377+
378+
if isinstance(graphs, ProgramGraph):
379+
return _run_one(graphs)
380+
381+
return execute(_run_one, graphs, executor, chunksize)

tests/BUILD

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,14 @@ py_test(
134134
"//tests/plugins",
135135
],
136136
)
137+
138+
py_test(
139+
name = "to_pyg_test",
140+
srcs = ["to_pyg_test.py"],
141+
shard_count = 8,
142+
deps = [
143+
"//programl",
144+
"//tests:test_main",
145+
"//tests/plugins",
146+
],
147+
)

tests/to_pyg_test.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright 2019-2020 the ProGraML authors.
2+
#
3+
# Contact Chris Cummins <chrisc.101@gmail.com>.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
from concurrent.futures.thread import ThreadPoolExecutor
17+
18+
import networkx as nx
19+
import pytest
20+
21+
import programl as pg
22+
from torch_geometric.data import HeteroData
23+
from tests.test_main import main
24+
25+
pytest_plugins = ["tests.plugins.llvm_program_graph"]
26+
27+
28+
@pytest.fixture(scope="session")
29+
def graph() -> pg.ProgramGraph:
30+
return pg.from_cpp("int A() { return 0; }")
31+
32+
@pytest.fixture(scope="session")
33+
def graph2() -> pg.ProgramGraph:
34+
return pg.from_cpp("int B() { return 1; }")
35+
36+
@pytest.fixture(scope="session")
37+
def graph3() -> pg.ProgramGraph:
38+
return pg.from_cpp("int B(int x) { return x + 1; }")
39+
40+
def graphs_are_equal(
41+
graph1: HeteroData,
42+
graph2: HeteroData,
43+
):
44+
return (
45+
(graph1['nodes']['full_text'] == graph2['nodes']['full_text'])
46+
and (graph1['nodes', 'control', 'nodes'].edge_index.equal(graph2['nodes', 'control', 'nodes'].edge_index))
47+
and (graph1['nodes', 'data', 'nodes'].edge_index.equal(graph2['nodes', 'data', 'nodes'].edge_index))
48+
and (graph1['nodes', 'call', 'nodes'].edge_index.equal(graph2['nodes', 'call', 'nodes'].edge_index))
49+
and (graph1['nodes', 'type', 'nodes'].edge_index.equal(graph2['nodes', 'type', 'nodes'].edge_index))
50+
)
51+
52+
def test_to_pyg_simple_graph(graph: pg.ProgramGraph):
53+
graphs = list(pg.to_pyg([graph]))
54+
assert len(graphs) == 1
55+
assert isinstance(graphs[0], HeteroData)
56+
57+
def test_to_pyg_simple_graph_single_input(graph: pg.ProgramGraph):
58+
pyg_graph = pg.to_pyg(graph)
59+
assert isinstance(pyg_graph, HeteroData)
60+
61+
def test_to_pyg_different_two_different_inputs(
62+
graph: pg.ProgramGraph,
63+
graph2: pg.ProgramGraph,
64+
):
65+
pyg_graph = pg.to_pyg(graph)
66+
pyg_graph2 = pg.to_pyg(graph2)
67+
68+
# Ensure that the graphs are different
69+
assert not graphs_are_equal(pyg_graph, pyg_graph2)
70+
71+
def test_to_pyg_different_inputs(
72+
graph: pg.ProgramGraph,
73+
graph2: pg.ProgramGraph,
74+
graph3: pg.ProgramGraph
75+
):
76+
pyg_graph = pg.to_pyg(graph)
77+
pyg_graph2 = pg.to_pyg(graph2)
78+
pyg_graph3 = pg.to_pyg(graph3)
79+
80+
# Ensure that the graphs are different
81+
assert not graphs_are_equal(pyg_graph, pyg_graph2)
82+
assert not graphs_are_equal(pyg_graph, pyg_graph3)
83+
assert not graphs_are_equal(pyg_graph2, pyg_graph3)
84+
85+
def test_to_pyg_two_inputs(graph: pg.ProgramGraph):
86+
graphs = list(pg.to_pyg([graph, graph]))
87+
assert len(graphs) == 2
88+
assert graphs_are_equal(graphs[0], graphs[1])
89+
90+
def test_to_pyg_generator(graph: pg.ProgramGraph):
91+
graphs = list(pg.to_pyg((graph for _ in range(10)), chunksize=3))
92+
assert len(graphs) == 10
93+
for x in graphs[1:]:
94+
assert graphs_are_equal(graphs[0], x)
95+
96+
def test_to_pyg_generator_parallel_executor(graph: pg.ProgramGraph):
97+
with ThreadPoolExecutor() as executor:
98+
graphs = list(
99+
pg.to_pyg((graph for _ in range(10)), chunksize=3, executor=executor)
100+
)
101+
assert len(graphs) == 10
102+
for x in graphs[1:]:
103+
assert graphs_are_equal(graphs[0], x)
104+
105+
def test_to_pyg_smoke_test(llvm_program_graph: pg.ProgramGraph):
106+
graphs = list(pg.to_pyg([llvm_program_graph]))
107+
108+
num_nodes = len(graphs[0]['nodes']['text'])
109+
num_control_edges = graphs[0]['nodes', 'control', 'nodes'].edge_index.size(1)
110+
num_data_edges = graphs[0]['nodes', 'data', 'nodes'].edge_index.size(1)
111+
num_call_edges = graphs[0]['nodes', 'call', 'nodes'].edge_index.size(1)
112+
num_type_edges = graphs[0]['nodes', 'type', 'nodes'].edge_index.size(1)
113+
num_edges = num_control_edges + num_data_edges + num_call_edges + num_type_edges
114+
115+
assert len(graphs) == 1
116+
assert isinstance(graphs[0], HeteroData)
117+
assert num_nodes == len(llvm_program_graph.node)
118+
assert num_edges <= len(llvm_program_graph.edge)
119+
120+
121+
if __name__ == "__main__":
122+
main()

0 commit comments

Comments
 (0)