Skip to content

Commit 0ad557c

Browse files
committed
update visualize_graph for graph and hypergraph; bugfix: ensure g is always assigned
1 parent f73c5b2 commit 0ad557c

File tree

1 file changed

+68
-53
lines changed

1 file changed

+68
-53
lines changed

src/graphistry_mcp_server/server.py

Lines changed: 68 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
mcp = FastMCP("graphistry-mcp-server")
3333

3434
# Initialize state
35-
graph_cache = {}
35+
graph_cache: Dict[str, Any] = {}
3636

3737
# Debug: Print environment variables for Graphistry
3838
print(f"[DEBUG] GRAPHISTRY_USERNAME is set: {os.environ.get('GRAPHISTRY_USERNAME') is not None}")
@@ -55,78 +55,100 @@ async def visualize_graph(graph_data: Dict[str, Any], ctx: Optional[Context] = N
5555
Visualize a graph using Graphistry's GPU-accelerated renderer.
5656
5757
Args:
58+
graph_type (str, optional): Type of graph to visualize. Must be one of "graph" (two-way edges, default), "hypergraph" (many-to-many edges).
5859
graph_data (dict): Dictionary describing the graph to visualize. Fields:
59-
- data_format (str, required): Format of the input data. One of:
60-
* "edge_list": Use with 'edges' (list of {source, target}) and optional 'nodes' (list of {id, ...})
61-
* "pandas": Use with 'edges' (list of dicts), 'source' (str), 'destination' (str), and optional 'node_id' (str)
62-
* "networkx": Use with 'edges' as a networkx.Graph object
63-
- edges (list, required for edge_list/pandas): List of edges, each as a dict with at least 'source' and 'target' keys (e.g., [{"source": "A", "target": "B"}, ...])
64-
- nodes (list, optional): List of nodes, each as a dict with at least 'id' key (e.g., [{"id": "A"}, ...])
65-
- node_id (str, optional): Column name for node IDs (for pandas format)
66-
- source (str, optional): Column name for edge source (for pandas format)
67-
- destination (str, optional): Column name for edge destination (for pandas format)
60+
- edges (list, required): List of edges, each as a dict with at least 'source' and 'target' keys (e.g., [{"source": "A", "target": "B"}, ...]) and any other columns you want to include in the edge table
61+
- nodes (list, optional): List of nodes, each as a dict with at least 'id' key (e.g., [{"id": "A"}, ...]) and any other columns you want to include in the node table
62+
- node_id (str, optional): Column name for node IDs, if nodes are provided, must be provided.
63+
- source (str, optional): Column name for edge source (default: "source")
64+
- destination (str, optional): Column name for edge destination (default: "target")
65+
- columns (list, optional): List of column names for hypergraph edge table, use if graph_type is hypergraph.
6866
- title (str, optional): Title for the visualization
6967
- description (str, optional): Description for the visualization
7068
ctx: MCP context for progress reporting
7169
72-
Example:
70+
Example (graph):
7371
graph_data = {
74-
"data_format": "edge_list",
72+
"graph_type": "graph",
7573
"edges": [
76-
{"source": "A", "target": "B"},
77-
{"source": "A", "target": "C"},
74+
{"source": "A", "target": "B", "weight": 1},
75+
{"source": "A", "target": "C", "weight": 2},
7876
...
7977
],
8078
"nodes": [
81-
{"id": "A"}, {"id": "B"}, {"id": "C"}
79+
{"id": "A", "label": "Node A"},
80+
{"id": "B", "label": "Node B"},
81+
...
8282
],
83+
"node_id": "id",
84+
"source": "source",
85+
"destination": "target",
8386
"title": "My Graph",
8487
"description": "A simple example graph."
8588
}
89+
90+
Example (hypergraph):
91+
graph_data = {
92+
"graph_type": "hypergraph",
93+
"edges": [
94+
{"source": "A", "target": "B", "group": "G1", "weight": 1},
95+
{"source": "A", "target": "C", "group": "G1", "weight": 1},
96+
...
97+
],
98+
"columns": ["source", "target", "group"],
99+
"title": "My Hypergraph",
100+
"description": "A simple example hypergraph."
101+
}
86102
"""
87103
try:
88104
if ctx:
89105
await ctx.info("Initializing graph visualization...")
90106

91-
data_format = graph_data.get("data_format")
107+
graph_type = graph_data.get("graph_type") or "graph"
92108
edges = graph_data.get("edges")
93109
nodes = graph_data.get("nodes")
94110
node_id = graph_data.get("node_id")
95-
source = graph_data.get("source")
96-
destination = graph_data.get("destination")
111+
source = graph_data.get("source") or "source"
112+
destination = graph_data.get("destination") or "target"
97113
title = graph_data.get("title")
98114
description = graph_data.get("description")
115+
columns = graph_data.get("columns", None)
116+
117+
g = None
118+
edges_df = None
119+
nodes_df = None
99120

100-
# Handle different input formats
101-
if data_format == "edge_list":
121+
if graph_type == "graph":
102122
if not edges:
103123
raise ValueError("edges list required for edge_list format")
104-
df = pd.DataFrame(edges)
105-
# Ensure source and target columns exist
106-
if "source" not in df.columns or "target" not in df.columns:
107-
raise ValueError("edges must contain 'source' and 'target' columns")
108-
g = graphistry.bind(source="source", destination="target").edges(df)
109-
elif data_format == "pandas":
110-
if not (source and destination):
111-
raise ValueError("source and destination column names required for pandas format")
112-
df = pd.DataFrame(edges)
113-
g = graphistry.bind(source=source, destination=destination)
114-
if node_id:
115-
g = g.bind(node=node_id)
116-
g = g.edges(df)
117-
elif data_format == "networkx":
118-
g = graphistry.bind().from_networkx(edges)
124+
edges_df = pd.DataFrame(edges)
125+
if nodes:
126+
nodes_df = pd.DataFrame(nodes)
127+
g = graphistry.edges(edges_df, source=source, destination=destination).nodes(nodes_df, node=node_id)
128+
else:
129+
g = graphistry.edges(edges_df, source=source, destination=destination)
130+
nx_graph = nx.from_pandas_edgelist(edges_df, source=source, target=destination)
131+
elif graph_type == "hypergraph":
132+
if not edges:
133+
raise ValueError("edges list required for hypergraph format")
134+
edges_df = pd.DataFrame(edges)
135+
g = graphistry.hypergraph(edges_df, columns)['graph']
136+
nx_graph = None
119137
else:
120-
raise ValueError(f"Unsupported data format: {data_format}")
121-
138+
raise ValueError(f"Unsupported graph_type: {graph_type}")
139+
g = g.name(title)
122140
# Generate unique ID and cache
123141
graph_id = f"graph_{len(graph_cache)}"
124142
graph_cache[graph_id] = {
125143
"graph": g,
126144
"title": title,
127145
"description": description,
128-
"edges_df": df if data_format in ["edge_list", "pandas"] else None,
129-
"nx_graph": edges if data_format == "networkx" else None
146+
"edges_df": edges_df,
147+
"nodes_df": nodes_df,
148+
"node_id": node_id,
149+
"source": source,
150+
"destination": destination,
151+
"nx_graph": nx_graph
130152
}
131153

132154
if ctx:
@@ -143,27 +165,20 @@ async def visualize_graph(graph_data: Dict[str, Any], ctx: Optional[Context] = N
143165

144166
@mcp.tool()
145167
async def get_graph_info(graph_id: str) -> Dict[str, Any]:
146-
"""Get information about a stored graph visualization.
147-
148-
Args:
149-
graph_id: ID of the graph to retrieve information for
150-
"""
168+
"""Get information about a stored graph visualization."""
151169
try:
152170
if graph_id not in graph_cache:
153171
raise ValueError(f"Graph not found: {graph_id}")
154172

155173
graph_data = graph_cache[graph_id]
156-
g = graph_data["graph"]
157174
edges_df = graph_data["edges_df"]
158-
nx_graph = graph_data["nx_graph"]
175+
source = graph_data["source"]
176+
destination = graph_data["destination"]
159177

160178
# Get node and edge counts
161179
if edges_df is not None:
162-
node_count = len(set(edges_df["source"].unique()) | set(edges_df["target"].unique()))
180+
node_count = len(set(edges_df[source].unique()) | set(edges_df[destination].unique()))
163181
edge_count = len(edges_df)
164-
elif nx_graph is not None:
165-
node_count = len(nx_graph.nodes())
166-
edge_count = len(nx_graph.edges())
167182
else:
168183
node_count = 0
169184
edge_count = 0
@@ -236,8 +251,6 @@ async def detect_patterns(graph_id: str, ctx: Optional[Context] = None) -> Dict[
236251
- anomalies (if anomaly detection is available)
237252
- errors (dict of analysis_type -> error message)
238253
239-
Example:
240-
result = await mcp.call_tool("detect_patterns", {"graph_id": "graph_1"})
241254
"""
242255
try:
243256
if graph_id not in graph_cache:
@@ -249,10 +262,12 @@ async def detect_patterns(graph_id: str, ctx: Optional[Context] = None) -> Dict[
249262
graph_data = graph_cache[graph_id]
250263
nx_graph = graph_data["nx_graph"]
251264
edges_df = graph_data["edges_df"]
265+
source = graph_data["source"]
266+
destination = graph_data["destination"]
252267

253268
# Convert to NetworkX graph if needed
254269
if nx_graph is None and edges_df is not None:
255-
nx_graph = nx.from_pandas_edgelist(edges_df, source="source", target="target")
270+
nx_graph = nx.from_pandas_edgelist(edges_df, source=source, target=destination)
256271

257272
if nx_graph is None:
258273
raise ValueError("Graph data not available for analysis")

0 commit comments

Comments
 (0)