29
29
def graph () -> pg .ProgramGraph :
30
30
return pg .from_cpp ("int A() { return 0; }" )
31
31
32
+ @pytest .fixture (scope = "session" )
33
+ def graph2 () -> pg .ProgramGraph :
34
+ return pg .from_cpp ("int B() { return 1; }" )
35
+
36
+ @pytest .fixture (scope = "session" )
37
+ def graph3 () -> pg .ProgramGraph :
38
+ return pg .from_cpp ("int B(int x) { return x + 1; }" )
39
+
32
40
def assert_equal_graphs (
33
41
graph1 : HeteroData ,
34
- graph2 : HeteroData
42
+ graph2 : HeteroData ,
43
+ equality : bool = True
35
44
):
36
- assert graph1 ['nodes' ]['text' ] == graph2 ['nodes' ]['text' ]
45
+ if equality :
46
+ assert graph1 ['nodes' ]['full_text' ] == graph2 ['nodes' ]['full_text' ]
47
+
48
+ assert graph1 ['nodes' , 'control' , 'nodes' ].edge_index .equal (graph2 ['nodes' , 'control' , 'nodes' ].edge_index )
49
+ assert graph1 ['nodes' , 'data' , 'nodes' ].edge_index .equal (graph2 ['nodes' , 'data' , 'nodes' ].edge_index )
50
+ assert graph1 ['nodes' , 'call' , 'nodes' ].edge_index .equal (graph2 ['nodes' , 'call' , 'nodes' ].edge_index )
51
+ assert graph1 ['nodes' , 'type' , 'nodes' ].edge_index .equal (graph2 ['nodes' , 'type' , 'nodes' ].edge_index )
52
+
53
+ else :
54
+ text_different = graph1 ['nodes' ]['full_text' ] != graph2 ['nodes' ]['full_text' ]
37
55
38
- assert graph1 ['nodes' , 'control' , 'nodes' ].edge_index .equal (graph2 ['nodes' , 'control' , 'nodes' ].edge_index )
39
- assert graph1 ['nodes' , 'data' , 'nodes' ].edge_index .equal (graph2 ['nodes' , 'data' , 'nodes' ].edge_index )
40
- assert graph1 ['nodes' , 'call' , 'nodes' ].edge_index .equal (graph2 ['nodes' , 'call' , 'nodes' ].edge_index )
41
- assert graph1 ['nodes' , 'type' , 'nodes' ].edge_index .equal (graph2 ['nodes' , 'type' , 'nodes' ].edge_index )
56
+ control_edges_different = not graph1 ['nodes' , 'control' , 'nodes' ].edge_index .equal (
57
+ graph2 ['nodes' , 'control' , 'nodes' ].edge_index
58
+ )
59
+ data_edges_different = not graph1 ['nodes' , 'data' , 'nodes' ].edge_index .equal (
60
+ graph2 ['nodes' , 'data' , 'nodes' ].edge_index
61
+ )
62
+ call_edges_different = not graph1 ['nodes' , 'call' , 'nodes' ].edge_index .equal (
63
+ graph2 ['nodes' , 'call' , 'nodes' ].edge_index
64
+ )
65
+ type_edges_different = not graph1 ['nodes' , 'type' , 'nodes' ].edge_index .equal (
66
+ graph2 ['nodes' , 'type' , 'nodes' ].edge_index
67
+ )
68
+
69
+ assert (
70
+ text_different
71
+ or control_edges_different
72
+ or data_edges_different
73
+ or call_edges_different
74
+ or type_edges_different
75
+ )
42
76
43
77
def test_to_pyg_simple_graph (graph : pg .ProgramGraph ):
44
78
graphs = list (pg .to_pyg ([graph ]))
45
79
assert len (graphs ) == 1
46
80
assert isinstance (graphs [0 ], HeteroData )
47
81
48
-
49
82
def test_to_pyg_simple_graph_single_input (graph : pg .ProgramGraph ):
50
83
pyg_graph = pg .to_pyg (graph )
51
84
assert isinstance (pyg_graph , HeteroData )
52
85
86
+ def test_to_pyg_different_two_different_inputs (
87
+ graph : pg .ProgramGraph ,
88
+ graph2 : pg .ProgramGraph ,
89
+ ):
90
+ pyg_graph = pg .to_pyg (graph )
91
+ pyg_graph2 = pg .to_pyg (graph2 )
92
+
93
+ # Ensure that the graphs are different
94
+ assert_equal_graphs (pyg_graph , pyg_graph2 , equality = False )
95
+
96
+ def test_to_pyg_different_inputs (
97
+ graph : pg .ProgramGraph ,
98
+ graph2 : pg .ProgramGraph ,
99
+ graph3 : pg .ProgramGraph
100
+ ):
101
+ pyg_graph = pg .to_pyg (graph )
102
+ pyg_graph2 = pg .to_pyg (graph2 )
103
+ pyg_graph3 = pg .to_pyg (graph3 )
104
+
105
+ # Ensure that the graphs are different
106
+ assert_equal_graphs (pyg_graph , pyg_graph2 , equality = False )
107
+ assert_equal_graphs (pyg_graph , pyg_graph3 , equality = False )
108
+ assert_equal_graphs (pyg_graph2 , pyg_graph3 , equality = False )
53
109
54
110
def test_to_pyg_two_inputs (graph : pg .ProgramGraph ):
55
111
graphs = list (pg .to_pyg ([graph , graph ]))
56
112
assert len (graphs ) == 2
57
- assert_equal_graphs (graphs [0 ], graphs [1 ])
113
+ assert_equal_graphs (graphs [0 ], graphs [1 ], equality = True )
58
114
59
115
def test_to_pyg_generator (graph : pg .ProgramGraph ):
60
116
graphs = list (pg .to_pyg ((graph for _ in range (10 )), chunksize = 3 ))
61
117
assert len (graphs ) == 10
62
118
for x in graphs [1 :]:
63
- assert_equal_graphs (graphs [0 ], x )
64
-
119
+ assert_equal_graphs (graphs [0 ], x , equality = True )
65
120
66
121
def test_to_pyg_generator_parallel_executor (graph : pg .ProgramGraph ):
67
122
with ThreadPoolExecutor () as executor :
@@ -70,7 +125,7 @@ def test_to_pyg_generator_parallel_executor(graph: pg.ProgramGraph):
70
125
)
71
126
assert len (graphs ) == 10
72
127
for x in graphs [1 :]:
73
- assert_equal_graphs (graphs [0 ], x )
128
+ assert_equal_graphs (graphs [0 ], x , equality = True )
74
129
75
130
76
131
def test_to_pyg_smoke_test (llvm_program_graph : pg .ProgramGraph ):
@@ -90,7 +145,4 @@ def test_to_pyg_smoke_test(llvm_program_graph: pg.ProgramGraph):
90
145
91
146
92
147
if __name__ == "__main__" :
93
- main ()
94
-
95
-
96
-
148
+ main ()
0 commit comments