Skip to content

Commit 6998402

Browse files
authored
Change GraphRAG search parameter query to query_text (#89)
* Fix doc for GraphRAG (arg is called query and not query_text) * Rename GraphRAG search parameter query to query_text for consistency with Retriever interface — and leave open the possibility of adding a query_vector param later on if requested * Update CHANGELOG * Check disk space before install * Check container size * Check docker image size * Test with one single neo/python version
1 parent 337aeba commit 6998402

File tree

8 files changed

+24
-19
lines changed

8 files changed

+24
-19
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
- Corrected initialization to allow specifying the embedding model name.
99
- Removed sentence_transformers from embeddings/__init__.py to avoid ImportError when the package is not installed.
1010

11+
### Changed
12+
- `GraphRAG.search` method first parameter has been renamed `query_text` (was `query`) for consistency with the retrievers interface.
13+
1114
## 0.3.0
1215

1316
### Added

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ rag = GraphRAG(retriever=retriever, llm=llm)
128128

129129
# Query the graph
130130
query_text = "How do I do similarity search in Neo4j?"
131-
response = rag.search(query=query_text, retriever_config={"top_k": 5})
131+
response = rag.search(query_text=query_text, retriever_config={"top_k": 5})
132132
print(response.answer)
133133
```
134134

examples/graphrag_custom_prompt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def formatter(record: neo4j.Record) -> RetrieverResultItem:
6060
{context}
6161
6262
Question:
63-
{query}
63+
{query_text}
6464
6565
Answer:
6666
"""

src/neo4j_genai/generation/graphrag.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __init__(
5353

5454
def search(
5555
self,
56-
query: str,
56+
query_text: str,
5757
examples: str = "",
5858
retriever_config: Optional[dict[str, Any]] = None,
5959
return_context: bool = False,
@@ -64,7 +64,7 @@ def search(
6464
3. Generation: answer generation with LLM
6565
6666
Args:
67-
query (str): The user question
67+
query_text (str): The user question
6868
examples: Examples added to the LLM prompt.
6969
retriever_config (Optional[dict]): Parameters passed to the retriever
7070
search method; e.g.: top_k
@@ -76,20 +76,20 @@ def search(
7676
"""
7777
try:
7878
validated_data = RagSearchModel(
79-
query=query,
79+
query_text=query_text,
8080
examples=examples,
8181
retriever_config=retriever_config or {},
8282
return_context=return_context,
8383
)
8484
except ValidationError as e:
8585
raise SearchValidationError(e.errors())
86-
query = validated_data.query
86+
query_text = validated_data.query_text
8787
retriever_result: RetrieverResult = self.retriever.search(
88-
query_text=query, **validated_data.retriever_config
88+
query_text=query_text, **validated_data.retriever_config
8989
)
9090
context = "\n".join(item.content for item in retriever_result.items)
9191
prompt = self.prompt_template.format(
92-
query=query, context=context, examples=validated_data.examples
92+
query_text=query_text, context=context, examples=validated_data.examples
9393
)
9494
logger.debug(f"RAG: retriever_result={retriever_result}")
9595
logger.debug(f"RAG: prompt={prompt}")

src/neo4j_genai/generation/prompts.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,14 @@ class RagTemplate(PromptTemplate):
8787
{examples}
8888
8989
Question:
90-
{query}
90+
{query_text}
9191
9292
Answer:
9393
"""
94-
EXPECTED_INPUTS = ["context", "query", "examples"]
94+
EXPECTED_INPUTS = ["context", "query_text", "examples"]
9595

96-
def format(self, query: str, context: str, examples: str) -> str:
97-
return super().format(query=query, context=context, examples=examples)
96+
def format(self, query_text: str, context: str, examples: str) -> str:
97+
return super().format(query_text=query_text, context=context, examples=examples)
9898

9999

100100
class Text2CypherTemplate(PromptTemplate):

src/neo4j_genai/generation/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def check_llm(cls, value: Any) -> Any:
3939

4040

4141
class RagSearchModel(BaseModel):
42-
query: str
42+
query_text: str
4343
examples: str = ""
4444
retriever_config: dict[str, Any] = {}
4545
return_context: bool = False

tests/e2e/test_graphrag_e2e.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def test_graphrag_happy_path(
5252
llm.invoke.return_value = LLMResponse(content="some text")
5353

5454
result = rag.search(
55-
query="biology",
55+
query_text="biology",
5656
retriever_config={
5757
"top_k": 2,
5858
},
@@ -96,7 +96,7 @@ def test_graphrag_happy_path_return_context(
9696
llm.invoke.return_value = LLMResponse(content="some text")
9797

9898
result = rag.search(
99-
query="biology",
99+
query_text="biology",
100100
retriever_config={
101101
"top_k": 2,
102102
},
@@ -142,7 +142,7 @@ def test_graphrag_happy_path_examples(
142142
llm.invoke.return_value = LLMResponse(content="some text")
143143

144144
result = rag.search(
145-
query="biology",
145+
query_text="biology",
146146
retriever_config={
147147
"top_k": 2,
148148
},
@@ -186,7 +186,7 @@ def test_graphrag_llm_error(
186186

187187
with pytest.raises(LLMGenerationError):
188188
rag.search(
189-
query="biology",
189+
query_text="biology",
190190
)
191191

192192

@@ -203,5 +203,5 @@ def test_graphrag_retrieval_error(
203203

204204
with pytest.raises(TypeError):
205205
rag.search(
206-
query="biology",
206+
query_text="biology",
207207
)

tests/unit/test_graphrag.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525

2626
def test_graphrag_prompt_template() -> None:
2727
template = RagTemplate()
28-
prompt = template.format(context="my context", query="user's query", examples="")
28+
prompt = template.format(
29+
context="my context", query_text="user's query", examples=""
30+
)
2931
assert (
3032
prompt
3133
== """Answer the user question using the following context

0 commit comments

Comments
 (0)