Skip to content

Fixes for graph pruning #359

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
252 changes: 124 additions & 128 deletions poetry.lock

Large diffs are not rendered by default.

19 changes: 18 additions & 1 deletion src/neo4j_graphrag/experimental/components/graph_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Neo4jGraph,
Neo4jNode,
Neo4jRelationship,
LexicalGraphConfig,
)
from neo4j_graphrag.experimental.pipeline import Component, DataModel

Expand Down Expand Up @@ -135,9 +136,14 @@ async def run(
self,
graph: Neo4jGraph,
schema: Optional[GraphSchema] = None,
lexical_graph_config: Optional[LexicalGraphConfig] = None,
) -> GraphPruningResult:
if lexical_graph_config is None:
lexical_graph_config = LexicalGraphConfig()
if schema is not None:
new_graph, pruning_stats = self._clean_graph(graph, schema)
new_graph, pruning_stats = self._clean_graph(
graph, schema, lexical_graph_config
)
else:
new_graph = graph
pruning_stats = PruningStats()
Expand All @@ -150,6 +156,7 @@ def _clean_graph(
self,
graph: Neo4jGraph,
schema: GraphSchema,
lexical_graph_config: LexicalGraphConfig,
) -> tuple[Neo4jGraph, PruningStats]:
"""
Verify that the graph conforms to the provided schema.
Expand All @@ -162,6 +169,7 @@ def _clean_graph(
filtered_nodes = self._enforce_nodes(
graph.nodes,
schema,
lexical_graph_config,
pruning_stats,
)
if not filtered_nodes:
Expand All @@ -174,6 +182,7 @@ def _clean_graph(
graph.relationships,
filtered_nodes,
schema,
lexical_graph_config,
pruning_stats,
)

Expand Down Expand Up @@ -216,6 +225,7 @@ def _enforce_nodes(
self,
extracted_nodes: list[Neo4jNode],
schema: GraphSchema,
lexical_graph_config: LexicalGraphConfig,
pruning_stats: PruningStats,
) -> list[Neo4jNode]:
"""
Expand All @@ -228,6 +238,9 @@ def _enforce_nodes(
"""
valid_nodes = []
for node in extracted_nodes:
if node.label in lexical_graph_config.lexical_graph_node_labels:
valid_nodes.append(node)
continue
schema_entity = schema.node_type_from_label(node.label)
new_node = self._validate_node(
node,
Expand Down Expand Up @@ -319,6 +332,7 @@ def _enforce_relationships(
extracted_relationships: list[Neo4jRelationship],
filtered_nodes: list[Neo4jNode],
schema: GraphSchema,
lexical_graph_config: LexicalGraphConfig,
pruning_stats: PruningStats,
) -> list[Neo4jRelationship]:
"""
Expand All @@ -334,6 +348,9 @@ def _enforce_relationships(
valid_rels = []
valid_nodes = {node.id: node.label for node in filtered_nodes}
for rel in extracted_relationships:
if rel.type in lexical_graph_config.lexical_graph_relationship_types:
valid_rels.append(rel)
continue
schema_relation = schema.relationship_type_from_label(rel.type)
new_rel = self._validate_relationship(
rel,
Expand Down
15 changes: 12 additions & 3 deletions src/neo4j_graphrag/experimental/components/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def create_schema_model(
node_types: Sequence[NodeType],
relationship_types: Optional[Sequence[RelationshipType]] = None,
patterns: Optional[Sequence[Tuple[str, str, str]]] = None,
**kwargs: Any,
) -> GraphSchema:
"""
Creates a GraphSchema object from Lists of Entity and Relation objects
Expand All @@ -343,6 +344,7 @@ def create_schema_model(
node_types (Sequence[NodeType]): List or tuple of NodeType objects.
relationship_types (Optional[Sequence[RelationshipType]]): List or tuple of RelationshipType objects.
patterns (Optional[Sequence[Tuple[str, str, str]]]): List or tuples of triplets: (source_entity_label, relation_label, target_entity_label).
kwargs: other arguments passed to GraphSchema validator.

Returns:
GraphSchema: A configured schema object.
Expand All @@ -353,17 +355,19 @@ def create_schema_model(
node_types=node_types,
relationship_types=relationship_types or (),
patterns=patterns or (),
**kwargs,
)
)
except (ValidationError, SchemaValidationError) as e:
raise SchemaValidationError(e) from e
except ValidationError as e:
raise SchemaValidationError() from e

@validate_call
async def run(
self,
node_types: Sequence[NodeType],
relationship_types: Optional[Sequence[RelationshipType]] = None,
patterns: Optional[Sequence[Tuple[str, str, str]]] = None,
**kwargs: Any,
) -> GraphSchema:
"""
Asynchronously constructs and returns a GraphSchema object.
Expand All @@ -376,7 +380,12 @@ async def run(
Returns:
GraphSchema: A configured schema object, constructed asynchronously.
"""
return self.create_schema_model(node_types, relationship_types, patterns)
return self.create_schema_model(
node_types,
relationship_types,
patterns,
**kwargs,
)


class SchemaFromTextExtractor(Component):
Expand Down
8 changes: 8 additions & 0 deletions src/neo4j_graphrag/experimental/components/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,14 @@ class LexicalGraphConfig(BaseModel):
def lexical_graph_node_labels(self) -> tuple[str, ...]:
return self.document_node_label, self.chunk_node_label

@property
def lexical_graph_relationship_types(self) -> tuple[str, ...]:
return (
self.chunk_to_document_relationship_type,
self.next_chunk_relationship_type,
self.node_to_chunk_relationship_type,
)


class GraphResult(DataModel):
graph: Neo4jGraph
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
Optional,
Sequence,
Union,
Tuple,
)
import logging
import warnings
Expand All @@ -45,8 +44,6 @@
from neo4j_graphrag.experimental.components.schema import (
SchemaBuilder,
GraphSchema,
NodeType,
RelationshipType,
SchemaFromTextExtractor,
)
from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter
Expand Down Expand Up @@ -184,66 +181,33 @@ def _get_schema(self) -> Union[SchemaBuilder, SchemaFromTextExtractor]:
return SchemaFromTextExtractor(llm=self.get_default_llm())
return SchemaBuilder()

def _process_schema_with_precedence(
self,
) -> Tuple[
Tuple[NodeType, ...],
Tuple[RelationshipType, ...] | None,
Optional[Tuple[Tuple[str, str, str], ...]] | None,
]:
def _process_schema_with_precedence(self) -> dict[str, Any]:
"""
Process schema inputs according to precedence rules:
1. If schema is provided as GraphSchema object, use it
2. If schema is provided as dictionary, extract from it
3. Otherwise, use individual schema components

Returns:
Tuple of (node_types, relationship_types, patterns)
A dict representing the schema
"""
if self.schema_ is not None:
# schema takes precedence over individual components
node_types = self.schema_.node_types
return self.schema_.model_dump()

# handle case where relations could be None
if self.schema_.relationship_types is not None:
relationship_types = self.schema_.relationship_types
else:
relationship_types = None

patterns = self.schema_.patterns
else:
# use individual components
node_types = tuple(
[NodeType.model_validate(e) for e in self.entities]
if self.entities
else []
)
relationship_types = (
tuple([RelationshipType.model_validate(r) for r in self.relations])
if self.relations is not None
else None
)
patterns = (
tuple(self.potential_schema) if self.potential_schema else tuple()
)

return node_types, relationship_types, patterns
return dict(
node_types=self.entities,
relationship_types=self.relations,
patterns=self.potential_schema,
)

def _get_run_params_for_schema(self) -> dict[str, Any]:
if not self.has_user_provided_schema():
# for automatic extraction, the text parameter is needed (will flow through the pipeline connections)
return {}
else:
# process schema components according to precedence rules
node_types, relationship_types, patterns = (
self._process_schema_with_precedence()
)

return {
"node_types": node_types,
"relationship_types": relationship_types,
"patterns": patterns,
}
schema_dict = self._process_schema_with_precedence()
return schema_dict

def _get_extractor(self) -> EntityRelationExtractor:
return LLMEntityRelationExtractor(
Expand Down Expand Up @@ -368,7 +332,13 @@ def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]:
run_params = {}
if self.lexical_graph_config:
run_params["extractor"] = {
"lexical_graph_config": self.lexical_graph_config
"lexical_graph_config": self.lexical_graph_config,
}
run_params["writer"] = {
"lexical_graph_config": self.lexical_graph_config,
}
run_params["pruner"] = {
"lexical_graph_config": self.lexical_graph_config,
}
text = user_input.get("text")
file_path = user_input.get("file_path")
Expand Down
15 changes: 12 additions & 3 deletions src/neo4j_graphrag/experimental/pipeline/kg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,20 +56,29 @@ class SimpleKGPipeline:
llm (LLMInterface): An instance of an LLM to use for entity and relation extraction.
driver (neo4j.Driver): A Neo4j driver instance for database connection.
embedder (Embedder): An instance of an embedder used to generate chunk embeddings from text chunks.
schema (Optional[Union[GraphSchema, dict[str, list]]]): A schema configuration defining entities,
relations, and potential schema relationships.
This is the recommended way to provide schema information.
schema (Optional[Union[GraphSchema, dict[str, list]]]): A schema configuration defining node types,
relationship types, and graph patterns.
entities (Optional[List[Union[str, dict[str, str], NodeType]]]): DEPRECATED. A list of either:

- str: entity labels
- dict: following the NodeType schema, ie with label, description and properties keys

.. deprecated:: 1.7.1
Use schema instead

relations (Optional[List[Union[str, dict[str, str], RelationshipType]]]): DEPRECATED. A list of either:

- str: relation label
- dict: following the RelationshipType schema, ie with label, description and properties keys

.. deprecated:: 1.7.1
Use schema instead

potential_schema (Optional[List[tuple]]): DEPRECATED. A list of potential schema relationships.

.. deprecated:: 1.7.1
Use schema instead

from_pdf (bool): Determines whether to include the PdfLoader in the pipeline.
If True, expects `file_path` input in `run` methods.
If False, expects `text` input in `run` methods.
Expand Down
Loading