Skip to content

Commit 2a93fe6

Browse files
author
Iñigo Gabirondo
committed
Add tests for different graphs
1 parent 37b0fe2 commit 2a93fe6

File tree

1 file changed

+67
-15
lines changed

1 file changed

+67
-15
lines changed

tests/to_pyg_test.py

Lines changed: 67 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,39 +29,94 @@
2929
def graph() -> pg.ProgramGraph:
3030
return pg.from_cpp("int A() { return 0; }")
3131

32+
@pytest.fixture(scope="session")
33+
def graph2() -> pg.ProgramGraph:
34+
return pg.from_cpp("int B() { return 1; }")
35+
36+
@pytest.fixture(scope="session")
37+
def graph3() -> pg.ProgramGraph:
38+
return pg.from_cpp("int B(int x) { return x + 1; }")
39+
3240
def assert_equal_graphs(
3341
graph1: HeteroData,
34-
graph2: HeteroData
42+
graph2: HeteroData,
43+
equality: bool = True
3544
):
36-
assert graph1['nodes']['text'] == graph2['nodes']['text']
45+
if equality:
46+
assert graph1['nodes']['full_text'] == graph2['nodes']['full_text']
47+
48+
assert graph1['nodes', 'control', 'nodes'].edge_index.equal(graph2['nodes', 'control', 'nodes'].edge_index)
49+
assert graph1['nodes', 'data', 'nodes'].edge_index.equal(graph2['nodes', 'data', 'nodes'].edge_index)
50+
assert graph1['nodes', 'call', 'nodes'].edge_index.equal(graph2['nodes', 'call', 'nodes'].edge_index)
51+
assert graph1['nodes', 'type', 'nodes'].edge_index.equal(graph2['nodes', 'type', 'nodes'].edge_index)
52+
53+
else:
54+
text_different = graph1['nodes']['full_text'] != graph2['nodes']['full_text']
3755

38-
assert graph1['nodes', 'control', 'nodes'].edge_index.equal(graph2['nodes', 'control', 'nodes'].edge_index)
39-
assert graph1['nodes', 'data', 'nodes'].edge_index.equal(graph2['nodes', 'data', 'nodes'].edge_index)
40-
assert graph1['nodes', 'call', 'nodes'].edge_index.equal(graph2['nodes', 'call', 'nodes'].edge_index)
41-
assert graph1['nodes', 'type', 'nodes'].edge_index.equal(graph2['nodes', 'type', 'nodes'].edge_index)
56+
control_edges_different = not graph1['nodes', 'control', 'nodes'].edge_index.equal(
57+
graph2['nodes', 'control', 'nodes'].edge_index
58+
)
59+
data_edges_different = not graph1['nodes', 'data', 'nodes'].edge_index.equal(
60+
graph2['nodes', 'data', 'nodes'].edge_index
61+
)
62+
call_edges_different = not graph1['nodes', 'call', 'nodes'].edge_index.equal(
63+
graph2['nodes', 'call', 'nodes'].edge_index
64+
)
65+
type_edges_different = not graph1['nodes', 'type', 'nodes'].edge_index.equal(
66+
graph2['nodes', 'type', 'nodes'].edge_index
67+
)
68+
69+
assert (
70+
text_different
71+
or control_edges_different
72+
or data_edges_different
73+
or call_edges_different
74+
or type_edges_different
75+
)
4276

4377
def test_to_pyg_simple_graph(graph: pg.ProgramGraph):
4478
graphs = list(pg.to_pyg([graph]))
4579
assert len(graphs) == 1
4680
assert isinstance(graphs[0], HeteroData)
4781

48-
4982
def test_to_pyg_simple_graph_single_input(graph: pg.ProgramGraph):
5083
pyg_graph = pg.to_pyg(graph)
5184
assert isinstance(pyg_graph, HeteroData)
5285

86+
def test_to_pyg_different_two_different_inputs(
87+
graph: pg.ProgramGraph,
88+
graph2: pg.ProgramGraph,
89+
):
90+
pyg_graph = pg.to_pyg(graph)
91+
pyg_graph2 = pg.to_pyg(graph2)
92+
93+
# Ensure that the graphs are different
94+
assert_equal_graphs(pyg_graph, pyg_graph2, equality=False)
95+
96+
def test_to_pyg_different_inputs(
97+
graph: pg.ProgramGraph,
98+
graph2: pg.ProgramGraph,
99+
graph3: pg.ProgramGraph
100+
):
101+
pyg_graph = pg.to_pyg(graph)
102+
pyg_graph2 = pg.to_pyg(graph2)
103+
pyg_graph3 = pg.to_pyg(graph3)
104+
105+
# Ensure that the graphs are different
106+
assert_equal_graphs(pyg_graph, pyg_graph2, equality=False)
107+
assert_equal_graphs(pyg_graph, pyg_graph3, equality=False)
108+
assert_equal_graphs(pyg_graph2, pyg_graph3, equality=False)
53109

54110
def test_to_pyg_two_inputs(graph: pg.ProgramGraph):
55111
graphs = list(pg.to_pyg([graph, graph]))
56112
assert len(graphs) == 2
57-
assert_equal_graphs(graphs[0], graphs[1])
113+
assert_equal_graphs(graphs[0], graphs[1], equality=True)
58114

59115
def test_to_pyg_generator(graph: pg.ProgramGraph):
60116
graphs = list(pg.to_pyg((graph for _ in range(10)), chunksize=3))
61117
assert len(graphs) == 10
62118
for x in graphs[1:]:
63-
assert_equal_graphs(graphs[0], x)
64-
119+
assert_equal_graphs(graphs[0], x, equality=True)
65120

66121
def test_to_pyg_generator_parallel_executor(graph: pg.ProgramGraph):
67122
with ThreadPoolExecutor() as executor:
@@ -70,7 +125,7 @@ def test_to_pyg_generator_parallel_executor(graph: pg.ProgramGraph):
70125
)
71126
assert len(graphs) == 10
72127
for x in graphs[1:]:
73-
assert_equal_graphs(graphs[0], x)
128+
assert_equal_graphs(graphs[0], x, equality=True)
74129

75130

76131
def test_to_pyg_smoke_test(llvm_program_graph: pg.ProgramGraph):
@@ -90,7 +145,4 @@ def test_to_pyg_smoke_test(llvm_program_graph: pg.ProgramGraph):
90145

91146

92147
if __name__ == "__main__":
93-
main()
94-
95-
96-
148+
main()

0 commit comments

Comments
 (0)