Skip to content

Commit 9c89f0d

Browse files
author
Iñigo Gabirondo
committed
Add tests for to_pyg() method
1 parent 80b22c4 commit 9c89f0d

File tree

1 file changed

+82
-0
lines changed

1 file changed

+82
-0
lines changed

tests/to_pyg_test.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
2+
from concurrent.futures.thread import ThreadPoolExecutor
3+
4+
import networkx as nx
5+
import pytest
6+
7+
import programl as pg
8+
from torch_geometric.data import HeteroData
9+
from tests.test_main import main
10+
11+
pytest_plugins = ["tests.plugins.llvm_program_graph"]
12+
13+
14+
@pytest.fixture(scope="session")
15+
def graph() -> pg.ProgramGraph:
16+
return pg.from_cpp("int A() { return 0; }")
17+
18+
def assert_equal_graphs(
19+
graph1: HeteroData,
20+
graph2: HeteroData
21+
):
22+
assert graph1['nodes']['text'] == graph2['nodes']['text']
23+
24+
assert graph1['nodes', 'control', 'nodes'].edge_index.equal(graph2['nodes', 'control', 'nodes'].edge_index)
25+
assert graph1['nodes', 'data', 'nodes'].edge_index.equal(graph2['nodes', 'data', 'nodes'].edge_index)
26+
assert graph1['nodes', 'call', 'nodes'].edge_index.equal(graph2['nodes', 'call', 'nodes'].edge_index)
27+
assert graph1['nodes', 'type', 'nodes'].edge_index.equal(graph2['nodes', 'type', 'nodes'].edge_index)
28+
29+
def test_to_pyg_simple_graph(graph: pg.ProgramGraph):
30+
graphs = list(pg.to_pyg([graph]))
31+
assert len(graphs) == 1
32+
assert isinstance(graphs[0], HeteroData)
33+
34+
35+
def test_to_pyg_simple_graph_single_input(graph: pg.ProgramGraph):
36+
pyg_graph = pg.to_pyg(graph)
37+
assert isinstance(pyg_graph, HeteroData)
38+
39+
40+
def test_to_pyg_two_inputs(graph: pg.ProgramGraph):
41+
graphs = list(pg.to_pyg([graph, graph]))
42+
assert len(graphs) == 2
43+
assert_equal_graphs(graphs[0], graphs[1])
44+
45+
def test_to_pyg_generator(graph: pg.ProgramGraph):
46+
graphs = list(pg.to_pyg((graph for _ in range(10)), chunksize=3))
47+
assert len(graphs) == 10
48+
for x in graphs[1:]:
49+
assert_equal_graphs(graphs[0], x)
50+
51+
52+
def test_to_pyg_generator_parallel_executor(graph: pg.ProgramGraph):
53+
with ThreadPoolExecutor() as executor:
54+
graphs = list(
55+
pg.to_pyg((graph for _ in range(10)), chunksize=3, executor=executor)
56+
)
57+
assert len(graphs) == 10
58+
for x in graphs[1:]:
59+
assert_equal_graphs(graphs[0], x)
60+
61+
62+
def test_to_pyg_smoke_test(llvm_program_graph: pg.ProgramGraph):
63+
graphs = list(pg.to_pyg([llvm_program_graph]))
64+
65+
num_nodes = len(graphs[0]['nodes']['text'])
66+
num_control_edges = graphs[0]['nodes', 'control', 'nodes'].edge_index.size(1)
67+
num_data_edges = graphs[0]['nodes', 'data', 'nodes'].edge_index.size(1)
68+
num_call_edges = graphs[0]['nodes', 'call', 'nodes'].edge_index.size(1)
69+
num_type_edges = graphs[0]['nodes', 'type', 'nodes'].edge_index.size(1)
70+
num_edges = num_control_edges + num_data_edges + num_call_edges + num_type_edges
71+
72+
assert len(graphs) == 1
73+
assert isinstance(graphs[0], HeteroData)
74+
assert num_nodes == len(llvm_program_graph.node)
75+
assert num_edges <= len(llvm_program_graph.edge)
76+
77+
78+
if __name__ == "__main__":
79+
main()
80+
81+
82+

0 commit comments

Comments
 (0)