Skip to content

Commit 7814c21

Browse files
authored
Replace pygraphviz with neo4j-viz for graph visualization (#306)
1 parent 85eaa5b commit 7814c21

File tree

13 files changed

+1397
-1006
lines changed

13 files changed

+1397
-1006
lines changed

.github/workflows/pr-e2e-tests.yaml

-2
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,6 @@ jobs:
4949
- 6333:6333
5050

5151
steps:
52-
- name: Install graphviz package
53-
run: sudo apt install graphviz graphviz-dev
5452
- name: Check out repository code
5553
uses: actions/checkout@v4
5654
- name: Docker Prune

.github/workflows/pr.yaml

-2
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ jobs:
88
matrix:
99
python-version: [ '3.9', '3.10', '3.11', '3.12' ]
1010
steps:
11-
- name: Install graphviz package
12-
run: sudo apt install graphviz graphviz-dev
1311
- name: Check out repository code
1412
uses: actions/checkout@v4
1513
- name: Set up Python ${{ matrix.python-version }}

.github/workflows/scheduled-e2e-tests.yaml

-2
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,6 @@ jobs:
5757
- 6333:6333
5858

5959
steps:
60-
- name: Install graphviz package
61-
run: sudo apt install graphviz graphviz-dev
6260
- name: Check out repository code
6361
uses: actions/checkout@v4
6462
- name: Docker Prune

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
### Changed
1414

1515
- Improved log output readability in Retrievers and GraphRAG and added embedded vector to retriever result metadata for debugging.
16+
- Switched from pygraphviz to neo4j-viz
17+
- Renders interactive graph now on HTML instead of PNG
18+
- Removed `get_pygraphviz_graph` method
1619

1720
### Fixed
1821

README.md

-6
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ the extra dependencies described below:
5959
- **pinecone**: store vectors in Pinecone
6060
- **qdrant**: store vectors in Qdrant
6161
- **experimental**: experimental features mainly related to the Knowledge Graph creation pipelines.
62-
- Warning: this dependency group requires `pygraphviz`. See below for installation instructions.
6362

6463

6564
Install package with optional dependencies with (for instance):
@@ -68,11 +67,6 @@ Install package with optional dependencies with (for instance):
6867
pip install "neo4j-graphrag[openai]"
6968
```
7069

71-
#### pygraphviz
72-
73-
`pygraphviz` is used for visualizing pipelines.
74-
Installation instructions can be found [here](https://pygraphviz.github.io/documentation/stable/install.html).
75-
7670
## 💻 Example Usage
7771

7872
The scripts below demonstrate how to get started with the package and make use of its key features.

docs/source/index.rst

-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ List of extra dependencies:
9898
- **pinecone**: store vectors in Pinecone
9999
- **qdrant**: store vectors in Qdrant
100100
- **experimental**: experimental features mainly from the Knowledge Graph creation pipelines.
101-
- Warning: this requires `pygraphviz`. Installation instructions can be found `here <https://pygraphviz.github.io/documentation/stable/install.html>`_.
102101
- nlp:
103102
- **spaCy**: load spaCy trained models for nlp pipelines, used by `SpaCySemanticMatchResolver` component from the Knowledge Graph creation pipelines.
104103
- fuzzy-matching:

docs/source/user_guide_pipeline.rst

+12-11
Original file line numberDiff line numberDiff line change
@@ -111,25 +111,26 @@ Pipelines can be visualized using the `draw` method:
111111
pipe = Pipeline()
112112
# ... define components and connections
113113
114-
pipe.draw("pipeline.png")
114+
pipe.draw("pipeline.html")
115115
116-
Here is an example pipeline rendering:
116+
Here is an example pipeline rendering as an interactive HTML visualization:
117117

118-
.. image:: images/pipeline_no_unused_outputs.png
119-
:alt: Pipeline visualisation with hidden outputs if unused
118+
.. code:: python
120119
120+
# To view the visualization in a browser
121+
import webbrowser
122+
webbrowser.open("pipeline.html")
121123
122124
By default, output fields which are not mapped to any component are hidden. They
123-
can be added to the canvas by setting `hide_unused_outputs` to `False`:
125+
can be added to the visualization by setting `hide_unused_outputs` to `False`:
124126

125127
.. code:: python
126128
127-
pipe.draw("pipeline.png", hide_unused_outputs=False)
128-
129-
Here is an example of final result:
130-
131-
.. image:: images/pipeline_full.png
132-
:alt: Pipeline visualisation
129+
pipe.draw("pipeline_full.html", hide_unused_outputs=False)
130+
131+
# To view the full visualization in a browser
132+
import webbrowser
133+
webbrowser.open("pipeline_full.html")
133134
134135
135136
************************

examples/customize/build_graph/pipeline/visualization.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -54,5 +54,5 @@ async def run(self, number: IntDataModel) -> IntDataModel:
5454
pipe.connect("times_two", "addition", {"a": "times_two.value"})
5555
pipe.connect("times_ten", "addition", {"b": "times_ten.value"})
5656
pipe.connect("addition", "save", {"number": "addition"})
57-
pipe.draw("graph.png")
58-
pipe.draw("graph_full.png", hide_unused_outputs=False)
57+
pipe.draw("graph.html")
58+
pipe.draw("graph_full.html", hide_unused_outputs=False)

poetry.lock

+1,205-932
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+4-6
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,7 @@ pyyaml = "^6.0.2"
3838
types-pyyaml = "^6.0.12.20240917"
3939
# optional deps
4040
langchain-text-splitters = {version = "^0.3.0", optional = true }
41-
pygraphviz = [
42-
{version = "^1.13.0", python = ">=3.10,<4.0.0", optional = true},
43-
{version = "^1.0.0", python = "<3.10", optional = true}
44-
]
41+
neo4j-viz = {version = "^0.2.2", optional = true }
4542
weaviate-client = {version = "^4.6.1", optional = true }
4643
pinecone-client = {version = "^4.1.0", optional = true }
4744
google-cloud-aiplatform = {version = "^1.66.0", optional = true }
@@ -68,6 +65,7 @@ sphinx = { version = "^7.2.6", python = "^3.9" }
6865
langchain-openai = {version = "^0.2.2", optional = true }
6966
langchain-huggingface = {version = "^0.1.0", optional = true }
7067
enum-tools = {extras = ["sphinx"], version = "^0.12.0"}
68+
neo4j-viz = "^0.2.2"
7169

7270
[tool.poetry.extras]
7371
weaviate = ["weaviate-client"]
@@ -79,9 +77,9 @@ ollama = ["ollama"]
7977
openai = ["openai"]
8078
mistralai = ["mistralai"]
8179
qdrant = ["qdrant-client"]
82-
kg_creation_tools = ["pygraphviz"]
80+
kg_creation_tools = ["neo4j-viz"]
8381
sentence-transformers = ["sentence-transformers"]
84-
experimental = ["langchain-text-splitters", "pygraphviz", "llama-index"]
82+
experimental = ["langchain-text-splitters", "neo4j-viz", "llama-index"]
8583
examples = ["langchain-openai", "langchain-huggingface"]
8684
nlp = ["spacy"]
8785
fuzzy-matching = ["rapidfuzz"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from typing import Any, Dict, List, Optional, Union
17+
18+
class Node:
19+
id: Union[str, int]
20+
caption: Optional[str] = None
21+
size: Optional[float] = None
22+
properties: Optional[Dict[str, Any]] = None
23+
24+
def __init__(
25+
self,
26+
id: Union[str, int],
27+
caption: Optional[str] = None,
28+
size: Optional[float] = None,
29+
properties: Optional[Dict[str, Any]] = None,
30+
**kwargs: Any,
31+
) -> None: ...
32+
33+
class Relationship:
34+
source: Union[str, int]
35+
target: Union[str, int]
36+
caption: Optional[str] = None
37+
properties: Optional[Dict[str, Any]] = None
38+
39+
def __init__(
40+
self,
41+
source: Union[str, int],
42+
target: Union[str, int],
43+
caption: Optional[str] = None,
44+
properties: Optional[Dict[str, Any]] = None,
45+
**kwargs: Any,
46+
) -> None: ...
47+
48+
class VisualizationGraph:
49+
nodes: List[Node]
50+
relationships: List[Relationship]
51+
52+
def __init__(
53+
self, nodes: List[Node], relationships: List[Relationship]
54+
) -> None: ...

src/neo4j_graphrag/experimental/pipeline/pipeline.py

+108-33
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@
2424
from neo4j_graphrag.utils.logging import prettify
2525

2626
try:
27-
import pygraphviz as pgv
27+
from neo4j_viz import Node, Relationship, VisualizationGraph
28+
29+
neo4j_viz_available = True
2830
except ImportError:
29-
pgv = None
31+
neo4j_viz_available = False
3032

3133
from pydantic import BaseModel
3234

@@ -198,53 +200,126 @@ def show_as_dict(self) -> dict[str, Any]:
198200
def draw(
199201
self, path: str, layout: str = "dot", hide_unused_outputs: bool = True
200202
) -> Any:
201-
G = self.get_pygraphviz_graph(hide_unused_outputs)
202-
G.layout(layout)
203-
G.draw(path)
203+
"""Render the pipeline graph to an HTML file at the specified path"""
204+
G = self._get_neo4j_viz_graph(hide_unused_outputs)
205+
206+
# Write the visualization to an HTML file
207+
with open(path, "w") as f:
208+
f.write(G.render().data)
209+
210+
return G
204211

205-
def get_pygraphviz_graph(self, hide_unused_outputs: bool = True) -> pgv.AGraph:
206-
if pgv is None:
212+
def _get_neo4j_viz_graph(
213+
self, hide_unused_outputs: bool = True
214+
) -> VisualizationGraph:
215+
"""Generate a neo4j-viz visualization of the pipeline graph"""
216+
if not neo4j_viz_available:
207217
raise ImportError(
208-
"Could not import pygraphviz. "
209-
"Follow installation instruction in pygraphviz documentation "
210-
"to get it up and running on your system."
218+
"Could not import neo4j-viz. Install it with 'pip install \"neo4j-graphrag[experimental]\"'"
211219
)
220+
212221
self.validate_parameter_mapping()
213-
G = pgv.AGraph(strict=False, directed=True)
214-
# create a node for each component
215-
for n, node in self._nodes.items():
216-
comp_inputs = ",".join(
222+
223+
nodes = []
224+
relationships = []
225+
node_ids = {} # Map node names to their numeric IDs
226+
next_id = 0
227+
228+
# Create nodes for each component
229+
for n, pipeline_node in self._nodes.items():
230+
comp_inputs = ", ".join(
217231
f"{i}: {d['annotation']}"
218-
for i, d in node.component.component_inputs.items()
232+
for i, d in pipeline_node.component.component_inputs.items()
219233
)
220-
G.add_node(
221-
n,
222-
node_type="component",
223-
shape="rectangle",
224-
label=f"{node.component.__class__.__name__}: {n}({comp_inputs})",
234+
235+
node_ids[n] = next_id
236+
label = f"{pipeline_node.component.__class__.__name__}: {n}({comp_inputs})"
237+
238+
# Create Node with properties parameter
239+
viz_node = Node( # type: ignore
240+
id=next_id,
241+
caption=label,
242+
size=20,
243+
properties={"node_type": "component"},
225244
)
226-
# create a node for each output field and connect them it to its component
227-
for o in node.component.component_outputs:
245+
nodes.append(viz_node)
246+
next_id += 1
247+
248+
# Create nodes for each output field
249+
for o in pipeline_node.component.component_outputs:
228250
param_node_name = f"{n}.{o}"
229-
G.add_node(param_node_name, label=o, node_type="output")
230-
G.add_edge(n, param_node_name)
231-
# then we create the edges between a component output
232-
# and the component it gets added to
251+
252+
# Skip if we're hiding unused outputs and it's not used
253+
if hide_unused_outputs:
254+
# Check if this output is used as a source in any parameter mapping
255+
is_used = False
256+
for params in self.param_mapping.values():
257+
for mapping in params.values():
258+
source_component = mapping["component"]
259+
source_param_name = mapping.get("param")
260+
if source_component == n and source_param_name == o:
261+
is_used = True
262+
break
263+
if is_used:
264+
break
265+
266+
if not is_used:
267+
continue
268+
269+
node_ids[param_node_name] = next_id
270+
# Create Node with properties parameter
271+
output_node = Node( # type: ignore
272+
id=next_id,
273+
caption=o,
274+
size=15,
275+
properties={"node_type": "output"},
276+
)
277+
nodes.append(output_node)
278+
279+
# Connect component to its output
280+
# Connect component to its output
281+
rel = Relationship( # type: ignore
282+
source=node_ids[n],
283+
target=node_ids[param_node_name],
284+
properties={"type": "HAS_OUTPUT"},
285+
)
286+
relationships.append(rel)
287+
next_id += 1
288+
289+
# Create edges between components based on parameter mapping
233290
for component_name, params in self.param_mapping.items():
234291
for param, mapping in params.items():
235292
source_component = mapping["component"]
236293
source_param_name = mapping.get("param")
294+
237295
if source_param_name:
238296
source_output_node = f"{source_component}.{source_param_name}"
239297
else:
240298
source_output_node = source_component
241-
G.add_edge(source_output_node, component_name, label=param)
242-
# remove outputs that are not mapped
243-
if hide_unused_outputs:
244-
for n in G.nodes():
245-
if n.attr["node_type"] == "output" and G.out_degree(n) == 0: # type: ignore
246-
G.remove_node(n)
247-
return G
299+
300+
if source_output_node in node_ids and component_name in node_ids:
301+
rel = Relationship( # type: ignore
302+
source=node_ids[source_output_node],
303+
target=node_ids[component_name],
304+
caption=param,
305+
properties={"type": "CONNECTS_TO"},
306+
)
307+
relationships.append(rel)
308+
309+
# Create the visualization graph
310+
viz_graph = VisualizationGraph(nodes=nodes, relationships=relationships)
311+
return viz_graph
312+
313+
def get_pygraphviz_graph(self, hide_unused_outputs: bool = True) -> Any:
314+
"""Legacy method for backward compatibility.
315+
Uses neo4j-viz instead of pygraphviz.
316+
"""
317+
warnings.warn(
318+
"get_pygraphviz_graph is deprecated, use draw instead",
319+
DeprecationWarning,
320+
stacklevel=2,
321+
)
322+
return self._get_neo4j_viz_graph(hide_unused_outputs)
248323

249324
def add_component(self, component: Component, name: str) -> None:
250325
"""Add a new component. Components are uniquely identified

0 commit comments

Comments
 (0)