 """
 import json
 import subprocess
+import torch
 from typing import Any, Dict, Iterable, Optional, Union

 import dgl
 import networkx as nx
 from dgl.heterograph import DGLHeteroGraph
+from torch_geometric.data import HeteroData
 from networkx.readwrite import json_graph as nx_json

 from programl.exceptions import GraphTransformError
@@ -258,3 +260,122 @@ def _run_one(graph: ProgramGraph) -> str:
     if isinstance(graphs, ProgramGraph):
         return _run_one(graphs)
     return execute(_run_one, graphs, executor, chunksize)
+
+
+def to_pyg(
+    graphs: Union[ProgramGraph, Iterable[ProgramGraph]],
+    timeout: int = 300,
+    vocabulary: Optional[Dict[str, int]] = None,
+    executor: Optional[ExecutorLike] = None,
+    chunksize: Optional[int] = None,
+) -> Union[HeteroData, Iterable[HeteroData]]:
+    """Convert one or more Program Graphs to PyTorch Geometric's HeteroData.
+    These graphs can be used as input to any deep learning model built with
+    PyTorch Geometric:
+
+    https://pytorch-geometric.readthedocs.io/en/latest/tutorial/heterogeneous.html
+
+    :param graphs: A Program Graph, or a sequence of Program Graphs.
+
+    :param timeout: The maximum number of seconds to wait for an individual
+        graph conversion before raising an error. If multiple inputs are
+        provided, this timeout is per-input.
+
+    :param vocabulary: A dictionary containing ProGraML's vocabulary, where
+        the keys are the text attribute of the nodes and the values their
+        respective indices.
+
+    :param executor: An executor object, with method :code:`submit(callable,
+        *args, **kwargs)` and returning a Future-like object with methods
+        :code:`done() -> bool` and :code:`result() -> float`. The executor's
+        role is to dispatch the execution of the jobs locally/on a
+        cluster/with multithreading depending on the implementation. Eg:
+        :code:`concurrent.futures.ThreadPoolExecutor`. Defaults to single
+        threaded execution. This is only used when multiple inputs are given.
+
+    :param chunksize: The number of inputs to read and process at a time. A
+        larger chunksize improves parallelism but increases memory
+        consumption as more inputs must be stored in memory. This is only
+        used when multiple inputs are given.
+
+    :return: A HeteroData graph when a single input is provided, else an
+        iterable sequence of HeteroData graphs.
+    """
+
+    def _run_one(graph: ProgramGraph) -> HeteroData:
+        # 4 lists, one per edge type
+        # (control, data, call and type edges)
+        adjacencies = [[], [], [], []]
+        edge_positions = [[], [], [], []]
+
+        # Create the adjacency lists and the positions
+        for edge in graph.edge:
+            adjacencies[edge.flow].append([edge.source, edge.target])
+            edge_positions[edge.flow].append(edge.position)
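+        # The collected edge.position values encode the ordering of edges of
+        # the same flow (e.g. operand order for data edges); they are attached
+        # below as the edge_attr of each edge type.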
+
+        # Store the text attributes
+        node_text_list = []
+        node_full_text_list = []
+
+        # Store the text and full text attributes
+        for node in graph.node:
+            node_text = node_full_text = node.text
+
+            if (
+                node.features
+                and node.features.feature["full_text"].bytes_list.value
+            ):
+                node_full_text = node.features.feature["full_text"].bytes_list.value[0]
+
+            node_text_list.append(node_text)
+            node_full_text_list.append(node_full_text)
+
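+        # Node texts missing from the vocabulary are mapped to
+        # len(vocabulary), which acts as an out-of-vocabulary index.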
+        vocab_ids = None
+        if vocabulary is not None:
+            vocab_ids = [
+                vocabulary.get(node.text, len(vocabulary))
+                for node in graph.node
+            ]
+
+        # Convert the Python lists to tensors
+        adjacencies = [torch.tensor(adj_flow_type) for adj_flow_type in adjacencies]
+        edge_positions = [
+            torch.tensor(edge_pos_flow_type) for edge_pos_flow_type in edge_positions
+        ]
+
+        if vocabulary is not None:
+            vocab_ids = torch.tensor(vocab_ids)
+
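+        # The resulting HeteroData has a single node type ('nodes') and one
+        # edge type per ProGraML flow: control, data, call and type.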
+        # Create the graph structure
+        hetero_graph = HeteroData()
+
+        # Node attributes: raw text, full text and, if a vocabulary was
+        # given, the vocabulary index of each node
+        hetero_graph['nodes']['text'] = node_text_list
+        hetero_graph['nodes']['full_text'] = node_full_text_list
+        hetero_graph['nodes'].x = vocab_ids
+
+        # Add the adjacency lists
+        hetero_graph['nodes', 'control', 'nodes'].edge_index = (
+            adjacencies[0].t().contiguous()
+            if adjacencies[0].nelement() > 0
+            else torch.tensor([[], []], dtype=torch.long)
+        )
+        hetero_graph['nodes', 'data', 'nodes'].edge_index = (
+            adjacencies[1].t().contiguous()
+            if adjacencies[1].nelement() > 0
+            else torch.tensor([[], []], dtype=torch.long)
+        )
+        hetero_graph['nodes', 'call', 'nodes'].edge_index = (
+            adjacencies[2].t().contiguous()
+            if adjacencies[2].nelement() > 0
+            else torch.tensor([[], []], dtype=torch.long)
+        )
+        hetero_graph['nodes', 'type', 'nodes'].edge_index = (
+            adjacencies[3].t().contiguous()
+            if adjacencies[3].nelement() > 0
+            else torch.tensor([[], []], dtype=torch.long)
+        )
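+        # Note: PyG expects each edge_index to have shape [2, num_edges] and
+        # dtype long, hence the transpose above and the explicitly typed
+        # empty fallback for flow types with no edges in this graph.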
+
+        # Add the edge positions
+        hetero_graph['nodes', 'control', 'nodes'].edge_attr = edge_positions[0]
+        hetero_graph['nodes', 'data', 'nodes'].edge_attr = edge_positions[1]
+        hetero_graph['nodes', 'call', 'nodes'].edge_attr = edge_positions[2]
+        hetero_graph['nodes', 'type', 'nodes'].edge_attr = edge_positions[3]
+
+        return hetero_graph
+
+    if isinstance(graphs, ProgramGraph):
+        return _run_one(graphs)
+
+    return execute(_run_one, graphs, executor, chunksize)
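
A minimal usage sketch of the converter added above (not part of this commit; it assumes to_pyg is exported at the package level like the other converters, and that a clang toolchain is available for programl.from_cpp):

    import programl

    # Build a single ProgramGraph from a C++ source string.
    graph = programl.from_cpp("int main() { return 0; }")

    # Convert it to a PyTorch Geometric HeteroData object. The vocabulary
    # argument is optional; pass one to populate hetero_graph['nodes'].x
    # with vocabulary ids.
    hetero_graph = programl.to_pyg(graph)

    # One node type ('nodes') and four edge types (control/data/call/type).
    print(hetero_graph['nodes', 'control', 'nodes'].edge_index.shape)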