diff --git a/examples/customize/llms/openai_tool_calls.py b/examples/customize/llms/openai_tool_calls.py index 166fb7248..87a14f8df 100644 --- a/examples/customize/llms/openai_tool_calls.py +++ b/examples/customize/llms/openai_tool_calls.py @@ -17,7 +17,12 @@ from neo4j_graphrag.llm import OpenAILLM from neo4j_graphrag.llm.types import ToolCallResponse -from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter, IntegerParameter +from neo4j_graphrag.tools.tool import ( + Tool, + ObjectParameter, + StringParameter, + IntegerParameter, +) # Load environment variables from .env file (OPENAI_API_KEY required for this example) load_dotenv() diff --git a/examples/customize/llms/vertexai_tool_calls.py b/examples/customize/llms/vertexai_tool_calls.py index ebe9fec22..0d91e1eb3 100644 --- a/examples/customize/llms/vertexai_tool_calls.py +++ b/examples/customize/llms/vertexai_tool_calls.py @@ -11,7 +11,12 @@ from neo4j_graphrag.llm import VertexAILLM from neo4j_graphrag.llm.types import ToolCallResponse -from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter, IntegerParameter +from neo4j_graphrag.tools.tool import ( + Tool, + ObjectParameter, + StringParameter, + IntegerParameter, +) # Load environment variables from .env file load_dotenv() diff --git a/examples/retrieve/tools/multiple_tools_example.py b/examples/retrieve/tools/multiple_tools_example.py new file mode 100644 index 000000000..6b3986e4d --- /dev/null +++ b/examples/retrieve/tools/multiple_tools_example.py @@ -0,0 +1,152 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Example demonstrating how to create multiple domain-specific tools from retrievers. + +This example shows: +1. How to create multiple tools from the same retriever type for different use cases +2. How to provide custom parameter descriptions for each tool +3. How type inference works automatically while descriptions are explicit +""" + +import neo4j +from typing import cast, Any, Optional +from unittest.mock import MagicMock + +from neo4j_graphrag.retrievers.base import Retriever +from neo4j_graphrag.types import RawSearchResult + + +class MockVectorRetriever(Retriever): + """A mock vector retriever for demonstration purposes.""" + + VERIFY_NEO4J_VERSION = False + + def __init__(self, driver: neo4j.Driver, index_name: str): + super().__init__(driver) + self.index_name = index_name + + def get_search_results( + self, + query_vector: Optional[list[float]] = None, + query_text: Optional[str] = None, + top_k: int = 5, + effective_search_ratio: int = 1, + filters: Optional[dict[str, Any]] = None, + ) -> RawSearchResult: + """Get vector search results (mocked for demonstration).""" + # Return empty results for demo + return RawSearchResult(records=[], metadata={"index": self.index_name}) + + +def main() -> None: + """Demonstrate creating multiple domain-specific tools from retrievers.""" + + # Create mock driver (in real usage, this would be actual Neo4j driver) + driver = cast(Any, MagicMock()) + + # Create retrievers for different domains using the same retriever type + # In practice, these would point to different vector indexes + + # Movie recommendations retriever + movie_retriever = MockVectorRetriever( + driver=driver, + index_name="movie_embeddings" + ) + + # Product search retriever + product_retriever = MockVectorRetriever( + driver=driver, + index_name="product_embeddings" + ) + + # Document search retriever + document_retriever = MockVectorRetriever( + driver=driver, + index_name="document_embeddings" + ) + + # Convert each retriever to a domain-specific tool with custom descriptions + + # 1. Movie recommendation tool + movie_tool = movie_retriever.convert_to_tool( + name="movie_search", + description="Find movie recommendations based on plot, genre, or actor preferences", + parameter_descriptions={ + "query_text": "Movie title, plot description, genre, or actor name", + "query_vector": "Pre-computed embedding vector for movie search", + "top_k": "Number of movie recommendations to return (1-20)", + "filters": "Optional filters for genre, year, rating, etc.", + "effective_search_ratio": "Search pool multiplier for better accuracy" + } + ) + + # 2. Product search tool + product_tool = product_retriever.convert_to_tool( + name="product_search", + description="Search for products matching customer needs and preferences", + parameter_descriptions={ + "query_text": "Product name, description, or customer need", + "query_vector": "Pre-computed embedding for product matching", + "top_k": "Maximum number of product results (1-50)", + "filters": "Filters for price range, brand, category, availability", + "effective_search_ratio": "Breadth vs precision trade-off for search" + } + ) + + # 3. Document search tool + document_tool = document_retriever.convert_to_tool( + name="document_search", + description="Find relevant documents and knowledge articles", + parameter_descriptions={ + "query_text": "Question, keywords, or topic to search for", + "query_vector": "Semantic embedding for document retrieval", + "top_k": "Number of relevant documents to retrieve (1-10)", + "filters": "Document type, date range, or department filters" + } + ) + + # Demonstrate that each tool has distinct, meaningful descriptions + tools = [movie_tool, product_tool, document_tool] + + for tool in tools: + print(f"\n=== {tool.get_name().upper()} ===") + print(f"Description: {tool.get_description()}") + print("Parameters:") + + params = tool.get_parameters() + for param_name, param_def in params["properties"].items(): + required = "required" if param_name in params.get("required", []) else "optional" + print(f" - {param_name} ({param_def['type']}, {required}): {param_def['description']}") + + # Show how the same parameter type gets different contextual descriptions + print(f"\n=== PARAMETER COMPARISON ===") + print("Same parameter 'query_text' with different contextual descriptions:") + + for tool in tools: + params = tool.get_parameters() + query_text_desc = params["properties"]["query_text"]["description"] + print(f" {tool.get_name()}: {query_text_desc}") + + print(f"\nSame parameter 'top_k' with different contextual descriptions:") + for tool in tools: + params = tool.get_parameters() + top_k_desc = params["properties"]["top_k"]["description"] + print(f" {tool.get_name()}: {top_k_desc}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/retrieve/tools/retriever_to_tool_example.py b/examples/retrieve/tools/retriever_to_tool_example.py new file mode 100644 index 000000000..a41b60db2 --- /dev/null +++ b/examples/retrieve/tools/retriever_to_tool_example.py @@ -0,0 +1,118 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Example demonstrating how to convert a retriever to a tool. + +This example shows: +1. How to convert a custom StaticRetriever to a Tool using the convert_to_tool method +2. How to define parameters for the tool in the retriever class +3. How to execute the tool +""" + +import neo4j +from typing import Optional, Any, cast +from unittest.mock import MagicMock + +from neo4j_graphrag.retrievers.base import Retriever +from neo4j_graphrag.types import RawSearchResult +from neo4j_graphrag.tools.tool import ( + StringParameter, + ObjectParameter, +) + + +# Create a Retriever that returns static results about Neo4j +# This would illustrate the conversion process of any Retriever (Vector, Hybrid, etc.) +class StaticRetriever(Retriever): + """A retriever that returns static results about Neo4j.""" + + # Disable Neo4j version verification + VERIFY_NEO4J_VERSION = False + + def __init__(self, driver: neo4j.Driver): + # Call the parent class constructor with the driver + super().__init__(driver) + + def get_search_results( + self, query_text: Optional[str] = None, **kwargs: Any + ) -> RawSearchResult: + """Return static information about Neo4j regardless of the query. + + Args: + query_text (Optional[str]): The query about Neo4j (any query will return general Neo4j information) + **kwargs (Any): Additional keyword arguments (not used) + + Returns: + RawSearchResult: Static Neo4j information with metadata + """ + # Create formatted Neo4j information + neo4j_info = ( + "# Neo4j Graph Database\n\n" + "Neo4j is a graph database management system developed by Neo4j, Inc. " + "It is an ACID-compliant transactional database with native graph storage and processing.\n\n" + "## Key Features:\n\n" + "- **Cypher Query Language**: Neo4j's intuitive query language designed specifically for working with graph data\n" + "- **Property Graphs**: Both nodes and relationships can have properties (key-value pairs)\n" + "- **ACID Compliance**: Ensures data integrity with full transaction support\n" + "- **Native Graph Storage**: Optimized storage for graph data structures\n" + "- **High Availability**: Clustering for enterprise deployments\n" + "- **Scalability**: Handles billions of nodes and relationships" + ) + + # Create a Neo4j record with the information + records = [neo4j.Record({"result": neo4j_info})] + + # Return a RawSearchResult with the records and metadata + return RawSearchResult(records=records, metadata={"query": query_text}) + + +def main() -> None: + # Convert a StaticRetriever to a tool using the new convert_to_tool method + static_retriever = StaticRetriever(driver=cast(Any, MagicMock())) + + # Convert the retriever to a tool with custom parameter descriptions + static_tool = static_retriever.convert_to_tool( + name="Neo4jInfoTool", + description="Get general information about Neo4j graph database", + parameter_descriptions={ + "query_text": "Any query about Neo4j (the tool returns general information regardless)" + }, + ) + + # Print tool information + print("Example: StaticRetriever with specific parameters") + print(f"Tool Name: {static_tool.get_name()}") + print(f"Tool Description: {static_tool.get_description()}") + print(f"Tool Parameters: {static_tool.get_parameters()}") + print() + + # Execute the tools (in a real application, this would be done by instructions from an LLM) + try: + # Execute the static retriever tool + print("\nExecuting the static retriever tool...") + static_result = static_tool.execute( + query_text="What is Neo4j?", + ) + print("Static Search Results:") + for i, item in enumerate(static_result): + print(f"{i + 1}. {str(item)[:100]}...") + + except Exception as e: + print(f"Error executing tool: {e}") + + +if __name__ == "__main__": + main() diff --git a/examples/retrieve/tools/tools_retriever_example.py b/examples/retrieve/tools/tools_retriever_example.py new file mode 100644 index 000000000..3309205cf --- /dev/null +++ b/examples/retrieve/tools/tools_retriever_example.py @@ -0,0 +1,350 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Example demonstrating how to use the ToolsRetriever. + +This example shows: +1. How to create tools from different retrievers +2. How to use the ToolsRetriever to select and execute tools based on a query +""" + +import os +from typing import Any, Optional, cast +from unittest.mock import MagicMock +from dotenv import load_dotenv +import requests +from datetime import datetime, date + +import neo4j + +from neo4j_graphrag.generation import GraphRAG +from neo4j_graphrag.retrievers.base import Retriever +from neo4j_graphrag.retrievers.tools_retriever import ToolsRetriever +from neo4j_graphrag.types import RawSearchResult +from neo4j_graphrag.tools.tool import ( + ObjectParameter, + StringParameter, + Tool, +) +from neo4j_graphrag.tools.utils import convert_retriever_to_tool +from neo4j_graphrag.llm.openai_llm import OpenAILLM + +# Load environment variables from .env file (OPENAI_API_KEY required for this example) +load_dotenv() + + +# Create a Retriever that returns static results about Neo4j +class Neo4jInfoRetriever(Retriever): + """A retriever that returns general information about Neo4j.""" + + # Disable Neo4j version verification + VERIFY_NEO4J_VERSION = False + + def __init__(self, driver: neo4j.Driver): + # Call the parent class constructor with the driver + super().__init__(driver) + + def get_search_results( + self, query_text: Optional[str] = None, **kwargs: Any + ) -> RawSearchResult: + """Return general information about Neo4j.""" + # Create formatted Neo4j information + neo4j_info = ( + "# Neo4j Graph Database\n\n" + "Neo4j is a graph database management system developed by Neo4j, Inc. " + "It is an ACID-compliant transactional database with native graph storage and processing.\n\n" + "## Key Features:\n\n" + "- **Cypher Query Language**: Neo4j's intuitive query language designed specifically for working with graph data\n" + "- **Property Graphs**: Both nodes and relationships can have properties (key-value pairs)\n" + "- **ACID Compliance**: Ensures data integrity with full transaction support\n" + "- **Native Graph Storage**: Optimized storage for graph data structures\n" + "- **High Availability**: Clustering for enterprise deployments\n" + "- **Scalability**: Handles billions of nodes and relationships" + ) + + # Create a Neo4j record with the information + records = [neo4j.Record({"result": neo4j_info})] + + # Return a RawSearchResult with the records and metadata + return RawSearchResult(records=records, metadata={"query": query_text}) + + +class CalendarTool(Tool): + """A simple tool to get calendar information.""" + + def __init__(self) -> None: + """Initialize the calendar tool.""" + # Define parameters for the tool + parameters = ObjectParameter( + description="Parameters for calendar information retrieval", + properties={ + "date": StringParameter( + description="The date to check events for in YYYY-MM-DD format (e.g., 2025-04-14)", + ), + }, + required_properties=["date"], + ) + + # Sample calendar data with fixed dates + self.calendar_data = { + "2025-04-15": [ + {"time": "10:00", "title": "Dentist Appointment"}, + {"time": "14:00", "title": "Conference Call"}, + ], + "2025-04-16": [], + } + + # Define a wrapper function that handles the query parameter correctly + def execute_func(query: str, **kwargs: Any) -> str: + # Ignore the query parameter and call our execute method + return self.execute_calendar(**kwargs) + + super().__init__( + name="calendar_tool", + description="Check calendar events for a specific date in YYYY-MM-DD format", + parameters=parameters, + execute_func=execute_func, + ) + + def execute_calendar(self, **kwargs: Any) -> str: + """Execute the calendar tool. + + Args: + **kwargs: Dictionary of parameters, including 'date'. + + Returns: + str: The events for the specified date. + """ + date = kwargs.get("date") + if not date: + return "Error: No date provided" + + # Check for events on the date + if date in self.calendar_data: + events_list = self.calendar_data[date] + if not events_list: + return f"No events scheduled for {date}" + + events_str = "\n".join( + f"- {event.get('time', 'All day')}: {event.get('title', 'Untitled event')}" + for event in events_list + ) + return f"Events for {date}:\n{events_str}" + else: + return f"No events found for {date}" + + +class WeatherTool(Tool): + """A tool to fetch weather in Malmö, Sweden for a specific date.""" + + def __init__(self) -> None: + """Initialize the weather tool.""" + parameters = ObjectParameter( + description="Parameters for fetching weather information about a date.", + properties={ + "date": StringParameter( + description='The date, in YYYY-MM-DD format. Example: "2025-04-25"' + ) + }, + required_properties=["date"], + ) + super().__init__( + name="weather_tool", + description="Check for weather for a specific date in YYYY-MM-DD format", + parameters=parameters, + execute_func=self.execute_weather_retrieval, + ) + + def execute_weather_retrieval( + self, query: Optional[str] = None, **kwargs: Any + ) -> str: + """Fetch historical weather data for a given date in Malmö, Sweden.""" + date_str = kwargs.get("date") + if not date_str: + return "Error: Date not provided for weather lookup." + + try: + input_date = datetime.strptime(date_str, "%Y-%m-%d").date() + except ValueError: + return f"Error: Invalid date format '{date_str}'. Please use YYYY-MM-DD." + + today_date = date.today() + + if input_date < today_date: + api_url = "https://archive-api.open-meteo.com/v1/archive" + else: + # For today or future dates, use the forecast API + # Note: Forecast API typically has a limit (e.g., 16 days into the future) + api_url = "https://api.open-meteo.com/v1/forecast" + + params = { + "latitude": 55.6059, # Malmö, Sweden + "longitude": 13.0007, # Malmö, Sweden + "start_date": date_str, + "end_date": date_str, + "daily": "temperature_2m_max,sunshine_duration", + "timezone": "Europe/Stockholm", + } + headers = {"Accept": "application/json"} + + try: + response = requests.get(api_url, headers=headers, params=params) + response.raise_for_status() + data = response.json() + + # Try to access keys directly, relying on the existing broader except block for errors + daily = data["daily"] + temp_max = daily["temperature_2m_max"][0] + sunshine_seconds = daily["sunshine_duration"][0] + + sunshine_hours = 0 + if ( + sunshine_seconds is not None + ): # API might return null for sunshine_duration + sunshine_hours = round(sunshine_seconds / 3600, 1) + + return ( + f"Weather for Malmö, Sweden on this day:\n" + f"- Max Temperature: {temp_max}°C\n" + f"- Sunshine Duration: {sunshine_hours} hours" + ) + except requests.exceptions.RequestException as e: + return f"API request failed for weather data: {e}" + except ( + ValueError, + KeyError, + ) as e: + return f"Error parsing weather data for Malmö on {date_str}: {e}" + + return ( + f"Sorry, I couldn't fetch the weather for Malmö on {date_str} at this time." + ) + + +def main() -> None: + """Run the example.""" + # Create a mock Neo4j driver + driver = cast(neo4j.Driver, MagicMock()) + + # Create retrievers + neo4j_retriever = Neo4jInfoRetriever(driver=driver) + + # Define parameters for the tools + neo4j_parameters = ObjectParameter( + description="Parameters for Neo4j information retrieval", + properties={ + "query": StringParameter( + description="The query about Neo4j", + ), + }, + required_properties=["query"], + ) + + # Convert retrievers to tools + neo4j_tool = convert_retriever_to_tool( + retriever=neo4j_retriever, + name="neo4j_info_tool", + description="Get information about Neo4j graph database", + parameters=neo4j_parameters, + ) + + # Create a calendar tool + calendar_tool = CalendarTool() + + # Create a weather tool + weather_tool = WeatherTool() + + # Create an OpenAI LLM + llm = OpenAILLM( + api_key=os.getenv("OPENAI_API_KEY"), + model_name="gpt-4o", + model_params={ + "temperature": 0.2, + }, + ) + + # Print tool information for debugging + print("\nTool Information:") + print(f"Neo4j Tool: {neo4j_tool.get_name()}, {neo4j_tool.get_description()}") + print( + f"Calendar Tool: {calendar_tool.get_name()}, {calendar_tool.get_description()}" + ) + parameters_description = ( + weather_tool._parameters.description + if weather_tool._parameters + else "No parameters description" + ) + print( + f"Weather Tool: {weather_tool.get_name()}, {weather_tool.get_description()}: {parameters_description}" + ) + + # Create a ToolsRetriever with the LLM and tools + tools_retriever = ToolsRetriever( + driver=driver, + llm=llm, + tools=[neo4j_tool, calendar_tool, weather_tool], + ) + + # Test queries + test_queries = [ + "What is Neo4j?", + "Do I have any meetings the 15th of April 2025?", + "Any information about 2025-04-16?", + ] + + # Run just the tools retriever directly to show metadata etc. + print(f"\n\n{'=' * 80}") + print("Retriever call examples, to show metadata etc.") + print(f"{'=' * 80}") + for query in test_queries: + print(f"Query: {query}") + + try: + # Get search results through the ToolsRetriever + result = tools_retriever.get_search_results(query_text=query) + + # Print metadata + if result.metadata is not None: + print(f"\nTools selected: {result.metadata.get('tools_selected', [])}") + if result.metadata.get("error", ""): + print(f"Error: {result.metadata.get('error', '')}") + + # Print results + print("\nRESULTS:") + for i, record in enumerate(result.records): + print(f"\n--- Result {i + 1} ---") + print(record) + except Exception as e: + print(f"Error: {str(e)}") + print(f"{'=' * 80}") + + # For demo purposes, run the queries through GraphRAG to get text input -> text output + print(f"\n\n{'=' * 80}") + print("Full GraphRAG examples") + print(f"{'=' * 80}") + for query in test_queries: + print(f"Query: {query}") + # Full GraphRAG example + graphrag = GraphRAG( + llm=llm, + retriever=tools_retriever, + ) + rag_result = graphrag.search(query_text=query, return_context=False) + print(f"Answer: {rag_result.answer}") + print(f"{'=' * 80}") + + +if __name__ == "__main__": + main() diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py index 87d281794..d634ce085 100644 --- a/src/neo4j_graphrag/llm/base.py +++ b/src/neo4j_graphrag/llm/base.py @@ -22,7 +22,7 @@ from .types import LLMResponse, ToolCallResponse -from neo4j_graphrag.tool import Tool +from neo4j_graphrag.tools.tool import Tool class LLMInterface(ABC): diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py index 1e0228e45..563944842 100644 --- a/src/neo4j_graphrag/llm/openai_llm.py +++ b/src/neo4j_graphrag/llm/openai_llm.py @@ -49,7 +49,7 @@ UserMessage, ) -from neo4j_graphrag.tool import Tool +from neo4j_graphrag.tools.tool import Tool if TYPE_CHECKING: import openai diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py index 39d483915..513c6275d 100644 --- a/src/neo4j_graphrag/llm/vertexai_llm.py +++ b/src/neo4j_graphrag/llm/vertexai_llm.py @@ -27,7 +27,7 @@ ToolCallResponse, ) from neo4j_graphrag.message_history import MessageHistory -from neo4j_graphrag.tool import Tool +from neo4j_graphrag.tools.tool import Tool from neo4j_graphrag.types import LLMMessage try: diff --git a/src/neo4j_graphrag/retrievers/__init__.py b/src/neo4j_graphrag/retrievers/__init__.py index 595eac93b..061679d93 100644 --- a/src/neo4j_graphrag/retrievers/__init__.py +++ b/src/neo4j_graphrag/retrievers/__init__.py @@ -15,6 +15,7 @@ from .hybrid import HybridCypherRetriever, HybridRetriever from .text2cypher import Text2CypherRetriever +from .tools_retriever import ToolsRetriever from .vector import VectorCypherRetriever, VectorRetriever __all__ = [ @@ -23,6 +24,7 @@ "HybridRetriever", "HybridCypherRetriever", "Text2CypherRetriever", + "ToolsRetriever", ] diff --git a/src/neo4j_graphrag/retrievers/base.py b/src/neo4j_graphrag/retrievers/base.py index c3b295d15..d8121ffb5 100644 --- a/src/neo4j_graphrag/retrievers/base.py +++ b/src/neo4j_graphrag/retrievers/base.py @@ -17,7 +17,7 @@ import inspect import types from abc import ABC, ABCMeta, abstractmethod -from typing import Any, Callable, Optional, TypeVar +from typing import Any, Callable, Optional, TypeVar, get_args, get_origin, Union, List, Dict, get_type_hints import neo4j from typing_extensions import ParamSpec @@ -175,6 +175,210 @@ def default_record_formatter(self, record: neo4j.Record) -> RetrieverResultItem: """ return RetrieverResultItem(content=str(record), metadata=record.get("metadata")) + def get_parameters(self, parameter_descriptions: Optional[Dict[str, str]] = None) -> "ObjectParameter": + """Return the parameters that this retriever expects for tool conversion. + + This method automatically infers parameters from the get_search_results method signature. + + Args: + parameter_descriptions (Optional[Dict[str, str]]): Custom descriptions for parameters. + Keys should match parameter names from get_search_results method. + + Returns: + ObjectParameter: The parameter definition for this retriever + """ + return self._infer_parameters_from_signature(parameter_descriptions or {}) + + def _infer_parameters_from_signature(self, parameter_descriptions: Dict[str, str]) -> "ObjectParameter": + """Infer parameters from the get_search_results method signature.""" + # Import here to avoid circular imports + from neo4j_graphrag.tools.tool import ( + ObjectParameter, + StringParameter, + IntegerParameter, + NumberParameter, + ArrayParameter, + ) + + # Get the method signature and resolved type hints + sig = inspect.signature(self.get_search_results) + try: + type_hints = get_type_hints(self.get_search_results) + except (NameError, AttributeError): + # If type hints can't be resolved, fall back to annotation strings + type_hints = {} + + properties = {} + required_properties = [] + + for param_name, param in sig.parameters.items(): + # Skip 'self' parameter + if param_name == 'self': + continue + + # Skip **kwargs + if param.kind == inspect.Parameter.VAR_KEYWORD: + continue + + # Determine if parameter is required (no default value) + is_required = param.default is inspect.Parameter.empty + + # Use resolved type hint if available, otherwise fall back to annotation + type_annotation = type_hints.get(param_name, param.annotation) + + # Get the parameter type and create appropriate tool parameter + tool_param = self._create_tool_parameter_from_type(param_name, type_annotation, is_required, parameter_descriptions) + + if tool_param: + properties[param_name] = tool_param + if is_required: + required_properties.append(param_name) + + return ObjectParameter( + description=f"Parameters for {self.__class__.__name__}", + properties=properties, + required_properties=required_properties, + additional_properties=False, + ) + + def _create_tool_parameter_from_type(self, param_name: str, type_annotation: Any, is_required: bool, parameter_descriptions: Dict[str, str]) -> Optional["StringParameter"]: + """Create a tool parameter from a type annotation.""" + # Import here to avoid circular imports + from neo4j_graphrag.tools.tool import ( + StringParameter, + IntegerParameter, + NumberParameter, + ArrayParameter, + ObjectParameter, + ) + + # Handle None/missing annotation + if type_annotation is inspect.Parameter.empty or type_annotation is None: + return StringParameter( + description=parameter_descriptions.get(param_name, f"Parameter {param_name}"), + required=is_required, + ) + + # Get the origin and args for generic types + origin = get_origin(type_annotation) + args = get_args(type_annotation) + + # Handle Optional[T] and Union[T, None] + if origin is Union: + # Remove None from union args to get the actual type + non_none_args = [arg for arg in args if arg is not type(None)] + if len(non_none_args) == 1: + # This is Optional[T], use T + type_annotation = non_none_args[0] + # Re-calculate origin and args for the unwrapped type + origin = get_origin(type_annotation) + args = get_args(type_annotation) + elif len(non_none_args) > 1: + # This is Union[T, U, ...], treat as string for now + return StringParameter( + description=parameter_descriptions.get(param_name, f"Parameter {param_name}"), + required=is_required, + ) + + # Handle specific types + if type_annotation is str: + return StringParameter( + description=parameter_descriptions.get(param_name, f"Parameter {param_name}"), + required=is_required, + ) + elif type_annotation is int: + return IntegerParameter( + description=parameter_descriptions.get(param_name, f"Parameter {param_name}"), + minimum=1 if param_name in ['top_k', 'effective_search_ratio'] else None, + required=is_required, + ) + elif type_annotation is float: + constraints = {} + if param_name == 'alpha': + constraints.update(minimum=0.0, maximum=1.0) + return NumberParameter( + description=parameter_descriptions.get(param_name, f"Parameter {param_name}"), + required=is_required, + **constraints + ) + elif origin is list or type_annotation is list or (hasattr(type_annotation, '__origin__') and type_annotation.__origin__ is list) or str(type_annotation).startswith('list['): + # Handle list[float] for vectors + if args and args[0] is float: + return ArrayParameter( + items=NumberParameter( + description="A single vector component", + required=False, + ), + description=parameter_descriptions.get(param_name, f"Parameter {param_name}"), + required=is_required, + ) + else: + # For complex list types like List[LLMMessage], treat as object + return ObjectParameter( + description=parameter_descriptions.get(param_name, f"Parameter {param_name}"), + properties={}, + additional_properties=True, + required=is_required, + ) + elif origin is dict or (hasattr(type_annotation, '__origin__') and type_annotation.__origin__ is dict): + return ObjectParameter( + description=parameter_descriptions.get(param_name, f"Parameter {param_name}"), + properties={}, + additional_properties=True, + required=is_required, + ) + else: + # Check if it's a complex type that should be an object + type_name = str(type_annotation) + if any(keyword in type_name.lower() for keyword in ['dict', 'list', 'optional[dict', 'optional[list']): + return ObjectParameter( + description=parameter_descriptions.get(param_name, f"Parameter {param_name}"), + properties={}, + additional_properties=True, + required=is_required, + ) + # For other complex types or enums, default to string + return StringParameter( + description=parameter_descriptions.get(param_name, f"Parameter {param_name}"), + required=is_required, + ) + + + def convert_to_tool( + self, + name: str, + description: str, + parameter_descriptions: Optional[Dict[str, str]] = None + ) -> "Tool": + """Convert this retriever to a Tool object. + + Args: + name (str): Name for the tool. + description (str): Description of what the tool does. + parameter_descriptions (Optional[Dict[str, str]]): Optional descriptions for each parameter. + Keys should match parameter names from get_search_results method. + + Returns: + Tool: A Tool object configured to use this retriever's search functionality. + """ + # Import here to avoid circular imports + from neo4j_graphrag.tools.tool import Tool + + # Get parameters from the retriever with custom descriptions + parameters = self.get_parameters(parameter_descriptions or {}) + + # Define a function that matches the Callable[[str, ...], Any] signature + def execute_func(**kwargs: Any) -> Any: + return self.get_search_results(**kwargs) + + # Create a Tool object from the retriever + return Tool( + name=name, + description=description, + execute_func=execute_func, + parameters=parameters, + ) + class ExternalRetriever(Retriever, ABC): """ diff --git a/src/neo4j_graphrag/retrievers/tools_retriever.py b/src/neo4j_graphrag/retrievers/tools_retriever.py new file mode 100644 index 000000000..633334b42 --- /dev/null +++ b/src/neo4j_graphrag/retrievers/tools_retriever.py @@ -0,0 +1,158 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Any, List, Optional, Sequence + +import neo4j + +from neo4j_graphrag.llm.base import LLMInterface +from neo4j_graphrag.retrievers.base import Retriever +from neo4j_graphrag.types import RawSearchResult +from neo4j_graphrag.tools.tool import Tool +from neo4j_graphrag.types import LLMMessage + + +class ToolsRetriever(Retriever): + """A retriever that uses an LLM to select appropriate tools for retrieval based on user input. + + This retriever takes an LLM instance and a list of Tool objects as input. When a search is performed, + it uses the LLM to analyze the query and determine which tools (if any) should be used to retrieve + the necessary data. It then executes the selected tools and returns the combined results. + + Args: + driver (neo4j.Driver): Neo4j driver instance. + llm (LLMInterface): LLM instance used to select tools. + tools (Sequence[Tool]): List of tools available for selection. + neo4j_database (Optional[str], optional): Neo4j database name. Defaults to None. + system_instruction (Optional[str], optional): Custom system instruction for the LLM. Defaults to None. + """ + + # Disable Neo4j version verification since this retriever doesn't directly interact with Neo4j + VERIFY_NEO4J_VERSION = False + + def __init__( + self, + driver: neo4j.Driver, + llm: LLMInterface, + tools: Sequence[Tool], + neo4j_database: Optional[str] = None, + system_instruction: Optional[str] = None, + ): + """Initialize the ToolsRetriever with an LLM and a list of tools.""" + super().__init__(driver, neo4j_database) + self.llm = llm + self._tools = list(tools) # Make a copy to allow modification + self.system_instruction = ( + system_instruction or self._get_default_system_instruction() + ) + + def _get_default_system_instruction(self) -> str: + """Get the default system instruction for the LLM.""" + return ( + "You are an assistant that helps select the most appropriate tools to retrieve information " + "based on the user's query. Analyze the query carefully and determine which tools, if any, " + "would be most helpful in retrieving the relevant information. You can select multiple tools " + "if necessary, or none if no tools are appropriate for the query." + ) + + def get_search_results( + self, + query_text: str, + message_history: Optional[List[LLMMessage]] = None, + **kwargs: Any, + ) -> RawSearchResult: + """Use the LLM to select and execute appropriate tools based on the query. + + Args: + query_text (str): The user's query text. + message_history (Optional[Union[List[LLMMessage], MessageHistory]], optional): + Previous conversation history. Defaults to None. + **kwargs (Any): Additional arguments passed to the tool execution. + + Returns: + RawSearchResult: The combined results from the executed tools. + """ + if not self._tools: + # No tools available, return empty result + return RawSearchResult( + records=[], + metadata={"query": query_text, "error": "No tools available"}, + ) + + try: + # Use the LLM to select appropriate tools + tool_call_response = self.llm.invoke_with_tools( + input=query_text, + tools=self._tools, + message_history=message_history, + system_instruction=self.system_instruction, + ) + # If no tool calls were made, return empty result + if not tool_call_response.tool_calls: + return RawSearchResult( + records=[], + metadata={ + "query": query_text, + "llm_response": tool_call_response.content, + "tools_selected": [], + }, + ) + + # Execute each selected tool and collect results + all_records = [] + tools_selected = [] + + for tool_call in tool_call_response.tool_calls: + tool_name = tool_call.name + tools_selected.append(tool_name) + + # Find the tool by name + selected_tool = next( + (tool for tool in self._tools if tool.get_name() == tool_name), None + ) + if selected_tool is not None: + # Extract arguments from the tool call + tool_args = tool_call.arguments or {} + + # Execute the tool with the provided arguments + tool_result = selected_tool.execute(**tool_args) + # If the tool result is a RawSearchResult, extract its records + if hasattr(tool_result, "records"): + all_records.extend(tool_result.records) + else: + # Create a record from the tool result + record = neo4j.Record({"result": tool_result}) + all_records.append(record) + + # Combine metadata from all tool calls + combined_metadata = { + "query": query_text, + "llm_response": tool_call_response.content, + "tools_selected": tools_selected, + } + + return RawSearchResult(records=all_records, metadata=combined_metadata) + + except Exception as e: + # Handle any errors during tool selection or execution + return RawSearchResult( + records=[], + metadata={ + "query": query_text, + "error": str(e), + "error_type": type(e).__name__, + }, + ) diff --git a/src/neo4j_graphrag/tool.py b/src/neo4j_graphrag/tools/tool.py similarity index 95% rename from src/neo4j_graphrag/tool.py rename to src/neo4j_graphrag/tools/tool.py index 905fb663a..99b7721fa 100644 --- a/src/neo4j_graphrag/tool.py +++ b/src/neo4j_graphrag/tools/tool.py @@ -211,23 +211,28 @@ def validate_properties(self) -> "ObjectParameter": class Tool(ABC): """Abstract base class defining the interface for all tools in the neo4j-graphrag library.""" + _name: str + _description: str + _parameters: Optional[ObjectParameter] + _execute_func: Callable[..., Any] + def __init__( self, name: str, description: str, - parameters: Union[ObjectParameter, Dict[str, Any]], execute_func: Callable[..., Any], + parameters: Optional[Union[ObjectParameter, Dict[str, Any]]] = None, ): self._name = name self._description = description + self._execute_func = execute_func - # Allow parameters to be provided as a dictionary if isinstance(parameters, dict): self._parameters = ObjectParameter.model_validate(parameters) - else: + elif isinstance(parameters, ObjectParameter): self._parameters = parameters - - self._execute_func = execute_func + else: + self._parameters = None def get_name(self) -> str: """Get the name of the tool. @@ -251,7 +256,9 @@ def get_parameters(self, exclude: Optional[list[str]] = None) -> Dict[str, Any]: Returns: Dict[str, Any]: Dictionary containing parameter schema information. """ - return self._parameters.model_dump_tool(exclude) + if self._parameters: + return self._parameters.model_dump_tool(exclude) + return {} def execute(self, **kwargs: Any) -> Any: """Execute the tool with the given query and additional parameters. diff --git a/src/neo4j_graphrag/tools/utils.py b/src/neo4j_graphrag/tools/utils.py new file mode 100644 index 000000000..0df86b0cd --- /dev/null +++ b/src/neo4j_graphrag/tools/utils.py @@ -0,0 +1,76 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, Union + +from neo4j_graphrag.tools.tool import Tool, ObjectParameter + + +def convert_retriever_to_tool( + retriever: Any, + description: Optional[str] = None, + parameters: Optional[Union[ObjectParameter, Dict[str, Any]]] = None, + name: Optional[str] = None, +) -> Tool: + """Convert a retriever instance to a Tool object. + + Args: + retriever (Any): The retriever instance to convert. + description (Optional[str]): Custom description for the tool. If not provided, + an attempt will be made to infer it from the retriever or a generic description will be used. + parameters (Optional[Union[ObjectParameter, Dict[str, ToolParameter]]]): Custom parameters for the tool. + If not provided, no parameters will be included in the tool. + name (Optional[str]): Custom name for the tool. If not provided, + an attempt will be made to infer it from the retriever or a default name will be used. + + Returns: + RetrieverTool: A Tool object configured to use the retriever's search functionality. + """ + # Use provided name or infer it from the retriever + if name is None: + name = getattr(retriever, "name", None) or getattr( + retriever.__class__, "__name__", "UnnamedRetrieverTool" + ) + + # Infer description if not provided + if description is None: + description = ( + getattr(retriever, "description", None) + or f"A tool for retrieving data using {name}." + ) + + # Parameters can be None + + # Define a function that matches the Callable[[str, ...], Any] signature + def execute_func(**kwargs: Any) -> Any: + # The retriever's get_search_results method is expected to handle + # arguments like query_text, top_k, etc., passed as keyword arguments. + # The Tool's 'parameters' definition (e.g., ObjectParameter) ensures + # that these arguments are provided in kwargs when Tool.execute is called. + return retriever.get_search_results(**kwargs) + + # Ensure name is a string + tool_name = str(name) if name is not None else "UnnamedRetrieverTool" + + # Create a Tool object from the retriever + + # Pass parameters directly to the Tool constructor + # If parameters is None, the Tool class will handle it appropriately + return Tool( + name=tool_name, + description=description, + execute_func=execute_func, + parameters=parameters, + ) diff --git a/tests/unit/llm/conftest.py b/tests/unit/llm/conftest.py index 269efadec..9fc776120 100644 --- a/tests/unit/llm/conftest.py +++ b/tests/unit/llm/conftest.py @@ -1,6 +1,6 @@ import pytest -from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter +from neo4j_graphrag.tools.tool import Tool, ObjectParameter, StringParameter class TestTool(Tool): diff --git a/tests/unit/llm/test_openai_llm.py b/tests/unit/llm/test_openai_llm.py index 3c5ee1b9e..55a4f7824 100644 --- a/tests/unit/llm/test_openai_llm.py +++ b/tests/unit/llm/test_openai_llm.py @@ -20,7 +20,7 @@ from neo4j_graphrag.llm import LLMResponse from neo4j_graphrag.llm.openai_llm import AzureOpenAILLM, OpenAILLM from neo4j_graphrag.llm.types import ToolCallResponse -from neo4j_graphrag.tool import Tool +from neo4j_graphrag.tools.tool import Tool def get_mock_openai() -> MagicMock: diff --git a/tests/unit/llm/test_vertexai_llm.py b/tests/unit/llm/test_vertexai_llm.py index c937d2cb8..d914e4cb4 100644 --- a/tests/unit/llm/test_vertexai_llm.py +++ b/tests/unit/llm/test_vertexai_llm.py @@ -20,7 +20,7 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.types import ToolCallResponse from neo4j_graphrag.llm.vertexai_llm import VertexAILLM -from neo4j_graphrag.tool import Tool +from neo4j_graphrag.tools.tool import Tool from neo4j_graphrag.types import LLMMessage from vertexai.generative_models import ( Content, diff --git a/tests/unit/retrievers/test_retriever_parameter_inference.py b/tests/unit/retrievers/test_retriever_parameter_inference.py new file mode 100644 index 000000000..b57507abf --- /dev/null +++ b/tests/unit/retrievers/test_retriever_parameter_inference.py @@ -0,0 +1,456 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for retriever parameter inference and convert_to_tool functionality. +""" + +from unittest.mock import MagicMock, patch +from typing import Optional, Any, Dict, List + +import neo4j +import pytest + +from neo4j_graphrag.retrievers.base import Retriever +from neo4j_graphrag.retrievers import ( + VectorRetriever, + VectorCypherRetriever, + HybridRetriever, + HybridCypherRetriever, + Text2CypherRetriever, +) +from neo4j_graphrag.retrievers.tools_retriever import ToolsRetriever +from neo4j_graphrag.tools.tool import Tool, ParameterType +from neo4j_graphrag.types import RawSearchResult, LLMMessage +from neo4j_graphrag.embeddings.base import Embedder +from neo4j_graphrag.llm.base import LLMInterface + + +# Helper functions for creating mock objects +def create_mock_driver() -> neo4j.Driver: + driver = MagicMock(spec=neo4j.Driver) + mock_result = MagicMock() + mock_result.records = [] + driver.execute_query.return_value = mock_result + return driver + + +def create_mock_embedder() -> Embedder: + embedder = MagicMock(spec=Embedder) + embedder.embed_query.return_value = [0.1, 0.2, 0.3] + return embedder + + +def create_mock_llm() -> LLMInterface: + llm = MagicMock(spec=LLMInterface) + llm.invoke.return_value = MagicMock(content="MATCH (n) RETURN n") + return llm + + +class MockRetriever(Retriever): + """Test retriever with well-documented parameters.""" + + VERIFY_NEO4J_VERSION = False + + def get_search_results( + self, + query_text: str, + top_k: int = 5, + filters: Optional[Dict[str, Any]] = None, + score_threshold: Optional[float] = None, + ) -> RawSearchResult: + """Test search method with documented parameters. + + Args: + query_text (str): The text query to search for in the database + top_k (int): The maximum number of results to return + filters (Optional[Dict[str, Any]]): Optional metadata filters to apply + score_threshold (Optional[float]): Minimum similarity score threshold + + Returns: + RawSearchResult: The search results + """ + return RawSearchResult(records=[], metadata={}) + + +class MockRetrieverNoDocstring(Retriever): + """Test retriever without parameter documentation.""" + + VERIFY_NEO4J_VERSION = False + + def get_search_results(self, param_one: str, param_two: Optional[int] = None) -> RawSearchResult: + """No parameter documentation here.""" + return RawSearchResult(records=[], metadata={}) + + +class TestParameterInference: + """Test parameter inference from method signatures and docstrings.""" + + def test_parameter_inference_with_docstring(self): + """Test that parameters are correctly inferred from method signature and docstring.""" + driver = create_mock_driver() + retriever = MockRetriever(driver) + + # Get inferred parameters + params = retriever.get_parameters() + + # Check basic structure + assert params.type == ParameterType.OBJECT + assert params.description == "Parameters for MockRetriever" + assert not params.additional_properties + + # Check properties + properties = params.properties + assert len(properties) == 4 + + # Check query_text parameter + query_text_param = properties["query_text"] + assert query_text_param.type == ParameterType.STRING + assert query_text_param.description == "Parameter query_text" + assert query_text_param.required is True + + # Check top_k parameter + top_k_param = properties["top_k"] + assert top_k_param.type == ParameterType.INTEGER + assert top_k_param.description == "Parameter top_k" + assert top_k_param.required is False + assert top_k_param.minimum == 1 # Should be set for top_k parameters + + # Check filters parameter + filters_param = properties["filters"] + assert filters_param.type == ParameterType.OBJECT + assert filters_param.description == "Parameter filters" + assert filters_param.required is False + assert filters_param.additional_properties is True + + # Check score_threshold parameter + score_param = properties["score_threshold"] + assert score_param.type == ParameterType.NUMBER + assert score_param.description == "Parameter score_threshold" + assert score_param.required is False + + def test_parameter_inference_without_docstring(self): + """Test that parameters work with fallback descriptions when no docstring documentation.""" + driver = create_mock_driver() + retriever = MockRetrieverNoDocstring(driver) + + # Get inferred parameters + params = retriever.get_parameters() + + # Check properties + properties = params.properties + assert len(properties) == 2 + + # Check param_one with fallback description + param_one = properties["param_one"] + assert param_one.type == ParameterType.STRING + assert param_one.description == "Parameter param_one" # Simple fallback format + assert param_one.required is True + + # Check param_two with fallback description + param_two = properties["param_two"] + assert param_two.type == ParameterType.INTEGER + assert param_two.description == "Parameter param_two" # Simple fallback format + assert param_two.required is False + + def test_convert_to_tool_basic(self): + """Test basic convert_to_tool functionality.""" + driver = create_mock_driver() + retriever = MockRetriever(driver) + + # Convert to tool + tool = retriever.convert_to_tool( + name="TestTool", + description="A test tool for searching" + ) + + # Check tool properties + assert isinstance(tool, Tool) + assert tool.get_name() == "TestTool" + assert tool.get_description() == "A test tool for searching" + + # Check that parameters were inferred + params = tool.get_parameters() + assert len(params["properties"]) == 4 + assert "query_text" in params["properties"] + assert "top_k" in params["properties"] + + def test_convert_to_tool_with_custom_descriptions(self): + """Test convert_to_tool with custom parameter descriptions.""" + driver = create_mock_driver() + retriever = MockRetriever(driver) + + # Convert to tool with custom parameter descriptions + tool = retriever.convert_to_tool( + name="CustomTool", + description="A custom search tool", + parameter_descriptions={ + "query_text": "The search query to execute", + "top_k": "Maximum number of results to return", + "filters": "Optional filters to apply to the search" + } + ) + + # Check tool properties + assert tool.get_name() == "CustomTool" + assert tool.get_description() == "A custom search tool" + + # Check custom parameter descriptions + params = tool.get_parameters() + properties = params["properties"] + + assert properties["query_text"]["description"] == "The search query to execute" + assert properties["top_k"]["description"] == "Maximum number of results to return" + assert properties["filters"]["description"] == "Optional filters to apply to the search" + # Parameter without custom description should use fallback + assert properties["score_threshold"]["description"] == "Parameter score_threshold" + + +class TestRealRetrieverParameterInference: + """Test parameter inference on real retriever classes.""" + + @patch("neo4j_graphrag.retrievers.base.get_version") + def test_vector_retriever_parameters(self, mock_get_version): + """Test VectorRetriever parameter inference.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + + driver = create_mock_driver() + embedder = create_mock_embedder() + + # Patch _fetch_index_infos to avoid database calls + with patch.object(VectorRetriever, '_fetch_index_infos'): + retriever = VectorRetriever( + driver=driver, + index_name="test_index", + embedder=embedder + ) + + params = retriever.get_parameters() + properties = params.properties + + # Check expected parameters from VectorRetriever.get_search_results + expected_params = {"query_vector", "query_text", "top_k", "effective_search_ratio", "filters"} + assert set(properties.keys()) == expected_params + + # Check specific parameter types + assert properties["query_vector"].type == ParameterType.ARRAY + assert properties["query_text"].type == ParameterType.STRING + assert properties["top_k"].type == ParameterType.INTEGER + assert properties["effective_search_ratio"].type == ParameterType.INTEGER + assert properties["filters"].type == ParameterType.OBJECT + + # Check that default descriptions are used when no custom descriptions provided + assert properties["query_vector"].description == "Parameter query_vector" + assert properties["query_text"].description == "Parameter query_text" + + @patch("neo4j_graphrag.retrievers.base.get_version") + def test_vector_cypher_retriever_parameters(self, mock_get_version): + """Test VectorCypherRetriever parameter inference.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + + driver = create_mock_driver() + embedder = create_mock_embedder() + + # Patch _fetch_index_infos to avoid database calls + with patch.object(VectorCypherRetriever, '_fetch_index_infos'): + retriever = VectorCypherRetriever( + driver=driver, + index_name="test_index", + retrieval_query="RETURN node.name", + embedder=embedder + ) + + params = retriever.get_parameters() + properties = params.properties + + # Should have all VectorRetriever params plus query_params + expected_params = {"query_vector", "query_text", "top_k", "effective_search_ratio", "query_params", "filters"} + assert set(properties.keys()) == expected_params + + # Check query_params is properly typed + assert properties["query_params"].type == ParameterType.OBJECT + assert properties["query_params"].additional_properties is True + + @patch("neo4j_graphrag.retrievers.base.get_version") + def test_hybrid_retriever_parameters(self, mock_get_version): + """Test HybridRetriever parameter inference.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + + driver = create_mock_driver() + embedder = create_mock_embedder() + + # Patch _fetch_index_infos to avoid database calls + with patch.object(HybridRetriever, '_fetch_index_infos'): + retriever = HybridRetriever( + driver=driver, + vector_index_name="vector_index", + fulltext_index_name="fulltext_index", + embedder=embedder + ) + + params = retriever.get_parameters() + properties = params.properties + + # Check expected parameters from HybridRetriever.get_search_results + expected_params = {"query_text", "query_vector", "top_k", "effective_search_ratio", "ranker", "alpha"} + assert set(properties.keys()) == expected_params + + # Check that query_text is required for hybrid retriever + assert properties["query_text"].required is True + assert properties["alpha"].type == ParameterType.NUMBER + assert properties["alpha"].minimum == 0.0 + assert properties["alpha"].maximum == 1.0 + + @patch("neo4j_graphrag.retrievers.base.get_version") + def test_text2cypher_retriever_parameters(self, mock_get_version): + """Test Text2CypherRetriever parameter inference.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + + driver = create_mock_driver() + llm = create_mock_llm() + retriever = Text2CypherRetriever( + driver=driver, + llm=llm, + neo4j_schema="(Person)-[:KNOWS]->(Person)" + ) + + params = retriever.get_parameters() + properties = params.properties + + # Check expected parameters + expected_params = {"query_text", "prompt_params"} + assert set(properties.keys()) == expected_params + + # Check parameter types + assert properties["query_text"].type == ParameterType.STRING + assert properties["query_text"].required is True + assert properties["prompt_params"].type == ParameterType.OBJECT # Dict maps to object + assert properties["prompt_params"].required is False + + def test_tools_retriever_parameters(self): + """Test ToolsRetriever parameter inference.""" + driver = create_mock_driver() + llm = create_mock_llm() + retriever = ToolsRetriever( + driver=driver, + llm=llm, + tools=[] + ) + + params = retriever.get_parameters() + properties = params.properties + + # Check expected parameters from ToolsRetriever.get_search_results + expected_params = {"query_text", "message_history"} + assert set(properties.keys()) == expected_params + + # Check parameter types + assert properties["query_text"].type == ParameterType.STRING + assert properties["query_text"].required is True + assert properties["message_history"].type == ParameterType.OBJECT # List[LLMMessage] maps to Object + assert properties["message_history"].required is False + + +class TestToolExecution: + """Test that tools created from retrievers actually work.""" + + def test_tool_execution(self): + """Test that a tool created from a retriever can be executed.""" + driver = create_mock_driver() + retriever = MockRetriever(driver) + + # Convert to tool + tool = retriever.convert_to_tool( + name="TestTool", + description="A test tool" + ) + + # Execute the tool + result = tool.execute(query_text="test query", top_k=3) + + # Check that we get a result (even if empty due to mocking) + assert result is not None + assert hasattr(result, 'records') + assert hasattr(result, 'metadata') + + def test_tool_execution_with_validation(self): + """Test that tool parameter validation works.""" + driver = create_mock_driver() + retriever = MockRetriever(driver) + + # Convert to tool + tool = retriever.convert_to_tool( + name="TestTool", + description="A test tool" + ) + + # Test with missing required parameter should work due to our setup + # (the actual validation happens in the Tool class) + result = tool.execute(query_text="test query") + assert result is not None + + +class TestParameterDescriptions: + """Test parameter description functionality.""" + + def test_custom_parameter_descriptions(self): + """Test that custom parameter descriptions are used correctly.""" + + class TestRetriever(Retriever): + VERIFY_NEO4J_VERSION = False + + def get_search_results( + self, + param_a: str, + param_b: int = 5, + param_c: Optional[float] = None + ) -> RawSearchResult: + return RawSearchResult(records=[], metadata={}) + + driver = create_mock_driver() + retriever = TestRetriever(driver) + + # Test with custom descriptions + custom_descriptions = { + "param_a": "Custom description for param A", + "param_b": "Custom description for param B" + # param_c intentionally omitted to test fallback + } + + params = retriever.get_parameters(custom_descriptions) + properties = params.properties + + # Check that custom descriptions are used + assert properties["param_a"].description == "Custom description for param A" + assert properties["param_b"].description == "Custom description for param B" + # Check fallback for param without custom description + assert properties["param_c"].description == "Parameter param_c" + + def test_no_custom_descriptions(self): + """Test behavior when no custom descriptions are provided.""" + + class SimpleRetriever(Retriever): + VERIFY_NEO4J_VERSION = False + + def get_search_results(self, test_param: str) -> RawSearchResult: + return RawSearchResult(records=[], metadata={}) + + driver = create_mock_driver() + retriever = SimpleRetriever(driver) + params = retriever.get_parameters() + properties = params.properties + + # Should use fallback description + assert properties["test_param"].description == "Parameter test_param" \ No newline at end of file diff --git a/tests/unit/retrievers/test_tools_retriever.py b/tests/unit/retrievers/test_tools_retriever.py new file mode 100644 index 000000000..a5333aedc --- /dev/null +++ b/tests/unit/retrievers/test_tools_retriever.py @@ -0,0 +1,262 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard library imports +from typing import Any, List, cast +from unittest.mock import MagicMock + +import neo4j + +# Local imports +from neo4j_graphrag.llm.base import LLMInterface +from neo4j_graphrag.llm.types import ToolCall, ToolCallResponse +from neo4j_graphrag.retrievers.tools_retriever import ToolsRetriever +from neo4j_graphrag.tools.tool import Tool + + +# Mock dependencies +def create_mock_driver() -> neo4j.Driver: + driver = MagicMock(spec=neo4j.Driver) + # Create a mock result object with a records attribute + mock_result = MagicMock() + mock_result.records = [MagicMock()] + driver.execute_query.return_value = mock_result + return cast(neo4j.Driver, driver) + + +def create_mock_llm() -> Any: + llm = MagicMock(spec=LLMInterface) + return llm + + +def create_mock_tool(name: str = "MockTool", description: str = "A mock tool") -> Any: + tool = MagicMock(spec=Tool) + cast(Any, tool.get_name).return_value = name + cast(Any, tool.get_description).return_value = description + cast(Any, tool.get_parameters).return_value = { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The query to search for", + } + }, + } + # Mock the execute method to return a dictionary with records and metadata + cast(Any, tool.execute).return_value = { + "records": [neo4j.Record({"result": f"Result from {name}"})], + "metadata": {"source": name}, + } + return tool + + +class TestToolsRetriever: + """Test the ToolsRetriever class.""" + + def test_initialization(self) -> None: + """Test that the ToolsRetriever initializes correctly.""" + driver = create_mock_driver() + llm = create_mock_llm() + tools = [create_mock_tool("Tool1"), create_mock_tool("Tool2")] + + retriever = ToolsRetriever(driver=driver, llm=llm, tools=tools) + + assert retriever.llm == llm + assert len(retriever._tools) == 2 + assert retriever._tools[0].get_name() == "Tool1" + assert retriever._tools[1].get_name() == "Tool2" + + def test_get_search_results_no_tools(self) -> None: + """Test that get_search_results returns an empty result when no tools are available.""" + driver = create_mock_driver() + llm = create_mock_llm() + tools: List[Tool] = [] + + retriever = ToolsRetriever(driver=driver, llm=llm, tools=tools) + result = retriever.get_search_results(query_text="Test query") + + assert result.records == [] + assert result.metadata is not None + assert result.metadata.get("query") == "Test query" + assert "error" in result.metadata + assert result.metadata.get("error") == "No tools available" + + def test_get_search_results_no_tool_calls(self) -> None: + """Test that get_search_results returns an empty result when the LLM doesn't select any tools.""" + driver = create_mock_driver() + llm = create_mock_llm() + tools = [create_mock_tool("Tool1"), create_mock_tool("Tool2")] + + # Mock the LLM to return a response with no tool calls + cast(Any, llm.invoke_with_tools).return_value = ToolCallResponse( + content="I don't need any tools for this query.", + tool_calls=[], + ) + + retriever = ToolsRetriever(driver=driver, llm=llm, tools=tools) + result = retriever.get_search_results(query_text="Test query") + + assert result.records == [] + assert result.metadata is not None + assert result.metadata.get("query") == "Test query" + assert ( + result.metadata.get("llm_response") + == "I don't need any tools for this query." + ) + assert result.metadata.get("tools_selected") == [] + + def test_get_search_results_with_tool_calls(self) -> None: + """Test that get_search_results correctly executes selected tools and returns their results.""" + driver = create_mock_driver() + llm = create_mock_llm() + tool1 = create_mock_tool("Tool1") + tool2 = create_mock_tool("Tool2") + tools = [tool1, tool2] + + # Mock the LLM to return a response with tool calls + cast(Any, llm.invoke_with_tools).return_value = ToolCallResponse( + content="I'll use Tool1 for this query.", + tool_calls=[ + ToolCall( + name="Tool1", + arguments={"query": "Test query"}, + ) + ], + ) + + # Mock the tool execution to return a simple string value + # This is processed by the ToolsRetriever and converted to a neo4j.Record + cast(Any, tool1).execute.return_value = "Result from Tool1" + + retriever = ToolsRetriever(driver=driver, llm=llm, tools=tools) + result = retriever.get_search_results(query_text="Test query") + + # Check that the LLM was called with the right arguments + cast(Any, llm.invoke_with_tools).assert_called_once_with( + input="Test query", + tools=tools, + message_history=None, + system_instruction=retriever.system_instruction, + ) + + # Check that the tool was executed with the right arguments + tool1.execute.assert_called_once_with(query="Test query") + + # Check that the result contains the expected records and metadata + assert len(result.records) == 1 + # The record is a neo4j.Record object + assert isinstance(result.records[0], neo4j.Record) + # Access the result directly using index 0 + assert result.records[0][0] == "Result from Tool1" + assert result.metadata is not None + assert result.metadata.get("query") == "Test query" + assert result.metadata.get("llm_response") == "I'll use Tool1 for this query." + assert result.metadata.get("tools_selected") == ["Tool1"] + + def test_get_search_results_with_multiple_tool_calls(self) -> None: + """Test that get_search_results correctly executes multiple selected tools and combines their results.""" + driver = create_mock_driver() + llm = create_mock_llm() + tool1 = create_mock_tool("Tool1") + tool2 = create_mock_tool("Tool2") + tools = [tool1, tool2] + + # Mock the LLM to return a response with multiple tool calls + cast(Any, llm.invoke_with_tools).return_value = ToolCallResponse( + content="I'll use both Tool1 and Tool2 for this query.", + tool_calls=[ + ToolCall( + name="Tool1", + arguments={"query": "Test query part 1"}, + ), + ToolCall( + name="Tool2", + arguments={"query": "Test query part 2"}, + ), + ], + ) + + # Mock the tool executions to return specific records + tool1_record = neo4j.Record({"result": "Result from Tool1"}) + cast(Any, tool1.execute).return_value = { + "records": [tool1_record], + "metadata": {"source": "Tool1"}, + } + + tool2_record = neo4j.Record({"result": "Result from Tool2"}) + cast(Any, tool2.execute).return_value = { + "records": [tool2_record], + "metadata": {"source": "Tool2"}, + } + + retriever = ToolsRetriever(driver=driver, llm=llm, tools=tools) + result = retriever.get_search_results(query_text="Test query") + + # Check that both tools were executed with the right arguments + cast(Any, tool1.execute).assert_called_once_with(query="Test query part 1") + cast(Any, tool2.execute).assert_called_once_with(query="Test query part 2") + + # Check that the result contains the expected records and metadata + assert len(result.records) == 2 + assert result.metadata is not None + assert result.metadata.get("query") == "Test query" + assert ( + result.metadata.get("llm_response") + == "I'll use both Tool1 and Tool2 for this query." + ) + assert result.metadata.get("tools_selected") == ["Tool1", "Tool2"] + + def test_get_search_results_with_error(self) -> None: + """Test that get_search_results handles errors during tool execution.""" + driver = create_mock_driver() + llm = create_mock_llm() + tool = create_mock_tool("Tool1") + tools = [tool] + + # Mock the LLM to raise an exception + cast(Any, llm.invoke_with_tools).side_effect = Exception("LLM error") + + retriever = ToolsRetriever(driver=driver, llm=llm, tools=tools) + result = retriever.get_search_results(query_text="Test query") + + # Check that the result contains the error information + assert result.records == [] + assert result.metadata is not None + assert result.metadata.get("query") == "Test query" + assert result.metadata.get("error") == "LLM error" + assert result.metadata.get("error_type") == "Exception" + + def test_custom_system_instruction(self) -> None: + """Test that a custom system instruction is used when provided.""" + driver = create_mock_driver() + llm = create_mock_llm() + tools = [create_mock_tool("Tool1")] + custom_instruction = "This is a custom system instruction." + + retriever = ToolsRetriever( + driver=driver, llm=llm, tools=tools, system_instruction=custom_instruction + ) + + assert retriever.system_instruction == custom_instruction + + # Test that the custom instruction is passed to the LLM + retriever.get_search_results(query_text="Test query") + + llm.invoke_with_tools.assert_called_once_with( + input="Test query", + tools=tools, + message_history=None, + system_instruction=custom_instruction, + ) diff --git a/tests/unit/tool/test_tool.py b/tests/unit/tool/test_tool.py index 6c04a1782..e1af50210 100644 --- a/tests/unit/tool/test_tool.py +++ b/tests/unit/tool/test_tool.py @@ -1,6 +1,6 @@ import pytest from typing import Any -from neo4j_graphrag.tool import ( +from neo4j_graphrag.tools.tool import ( StringParameter, IntegerParameter, NumberParameter, diff --git a/tests/unit/tool/test_tools_utils.py b/tests/unit/tool/test_tools_utils.py new file mode 100644 index 000000000..6926c5105 --- /dev/null +++ b/tests/unit/tool/test_tools_utils.py @@ -0,0 +1,529 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from unittest.mock import MagicMock, patch +import neo4j +from neo4j_graphrag.embeddings.base import Embedder +from neo4j_graphrag.llm.base import LLMInterface +from neo4j_graphrag.retrievers import ( + HybridCypherRetriever, + HybridRetriever, + Text2CypherRetriever, + VectorCypherRetriever, + VectorRetriever, +) +from neo4j_graphrag.tools.tool import ( + Tool, + ObjectParameter, + StringParameter, + IntegerParameter, +) +from neo4j_graphrag.tools.utils import convert_retriever_to_tool + + +# Mock dependencies for retriever instances +def create_mock_driver() -> neo4j.Driver: + driver = MagicMock(spec=neo4j.Driver) + # Create a mock result object with a records attribute + mock_result = MagicMock() + mock_result.records = [MagicMock()] + driver.execute_query.return_value = mock_result + return driver + + +def create_mock_embedder() -> Embedder: + embedder = MagicMock(spec=Embedder) + embedder.embed_query.return_value = [0.1, 0.2, 0.3] + return embedder + + +def create_mock_llm() -> LLMInterface: + llm = MagicMock() + llm.invoke.return_value = "MATCH (n) RETURN n" + return llm + + +# Test conversion with VectorRetriever +@patch("neo4j_graphrag.retrievers.base.get_version") +def test_convert_vector_retriever_to_tool(mock_get_version: MagicMock) -> None: + """Test conversion of VectorRetriever to a Tool instance with correct attributes.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + driver = create_mock_driver() + embedder = create_mock_embedder() + retriever = VectorRetriever( + driver=driver, + index_name="test_index", + embedder=embedder, + return_properties=["name", "description"], + ) + parameters = ObjectParameter( + description="Parameters for vector search", + properties={ + "query_text": StringParameter( + description="The query text for vector search.", + required=True, + ), + "top_k": IntegerParameter( + description="Number of results to return.", + required=False, + ), + }, + ) + tool = convert_retriever_to_tool( + retriever, + description="A tool for vector-based retrieval from Neo4j.", + parameters=parameters, + ) + assert isinstance(tool, Tool) + assert tool.get_name() in ["VectorRetriever", "UnnamedRetrieverTool"] + assert tool.get_description() == "A tool for vector-based retrieval from Neo4j." + # Check that the parameters object has the expected properties + params = tool.get_parameters() + assert "properties" in params + assert len(params["properties"]) == 2 + + +# Test conversion with VectorCypherRetriever +@patch("neo4j_graphrag.retrievers.base.get_version") +def test_convert_vector_cypher_retriever_to_tool(mock_get_version: MagicMock) -> None: + """Test conversion of VectorCypherRetriever to a Tool instance with correct attributes.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + driver = create_mock_driver() + embedder = create_mock_embedder() + retriever = VectorCypherRetriever( + driver=driver, + index_name="test_index", + embedder=embedder, + retrieval_query="RETURN n", + ) + parameters = ObjectParameter( + description="Parameters for vector-cypher search", + properties={ + "query_text": StringParameter( + description="The query text for vector-cypher search.", + required=True, + ), + "top_k": IntegerParameter( + description="Number of results to return.", + required=False, + ), + }, + ) + tool = convert_retriever_to_tool( + retriever, + description="A tool for vector-cypher retrieval from Neo4j.", + parameters=parameters, + ) + assert isinstance(tool, Tool) + assert tool.get_name() in ["VectorCypherRetriever", "UnnamedRetrieverTool"] + assert tool.get_description() == "A tool for vector-cypher retrieval from Neo4j." + # Check that the parameters object has the expected properties + params = tool.get_parameters() + assert "properties" in params + assert len(params["properties"]) == 2 + + +# Test conversion with HybridRetriever +@patch("neo4j_graphrag.retrievers.base.get_version") +def test_convert_hybrid_retriever_to_tool(mock_get_version: MagicMock) -> None: + """Test conversion of HybridRetriever to a Tool instance with correct attributes.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + driver = create_mock_driver() + embedder = create_mock_embedder() + retriever = HybridRetriever( + driver=driver, + vector_index_name="test_vector_index", + fulltext_index_name="test_fulltext_index", + embedder=embedder, + return_properties=["name", "description"], + ) + parameters = ObjectParameter( + description="Parameters for hybrid search", + properties={ + "query_text": StringParameter( + description="The query text for hybrid search.", + required=True, + ), + "top_k": IntegerParameter( + description="Number of results to return.", + required=False, + ), + }, + ) + tool = convert_retriever_to_tool( + retriever, + description="A tool for hybrid retrieval from Neo4j.", + parameters=parameters, + ) + assert isinstance(tool, Tool) + assert tool.get_name() in ["HybridRetriever", "UnnamedRetrieverTool"] + assert tool.get_description() == "A tool for hybrid retrieval from Neo4j." + # Check that the parameters object has the expected properties + params = tool.get_parameters() + assert "properties" in params + assert len(params["properties"]) == 2 + + +# Test conversion with HybridCypherRetriever +@patch("neo4j_graphrag.retrievers.base.get_version") +def test_convert_hybrid_cypher_retriever_to_tool(mock_get_version: MagicMock) -> None: + """Test conversion of HybridCypherRetriever to a Tool instance with correct attributes.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + driver = create_mock_driver() + embedder = create_mock_embedder() + retriever = HybridCypherRetriever( + driver=driver, + vector_index_name="test_vector_index", + fulltext_index_name="test_fulltext_index", + embedder=embedder, + retrieval_query="RETURN n", + ) + parameters = ObjectParameter( + description="Parameters for hybrid-cypher search", + properties={ + "query_text": StringParameter( + description="The query text for hybrid-cypher search.", + required=True, + ), + "top_k": IntegerParameter( + description="Number of results to return.", + required=False, + ), + }, + ) + tool = convert_retriever_to_tool( + retriever, + description="A tool for hybrid-cypher retrieval from Neo4j.", + parameters=parameters, + ) + assert isinstance(tool, Tool) + assert tool.get_name() in ["HybridCypherRetriever", "UnnamedRetrieverTool"] + assert tool.get_description() == "A tool for hybrid-cypher retrieval from Neo4j." + # Check that the parameters object has the expected properties + params = tool.get_parameters() + assert "properties" in params + assert len(params["properties"]) == 2 + + +# Test conversion with Text2CypherRetriever +@patch("neo4j_graphrag.retrievers.base.get_version") +def test_convert_text2cypher_retriever_to_tool(mock_get_version: MagicMock) -> None: + """Test conversion of Text2CypherRetriever to a Tool instance with correct attributes.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + driver = create_mock_driver() + llm = create_mock_llm() + retriever = Text2CypherRetriever(driver=driver, llm=llm) + parameters = ObjectParameter( + description="Parameters for text to Cypher conversion", + properties={ + "query_text": StringParameter( + description="The query text for text to Cypher conversion.", + required=True, + ), + }, + ) + tool = convert_retriever_to_tool( + retriever, + description="A tool for text to Cypher retrieval from Neo4j.", + parameters=parameters, + ) + assert isinstance(tool, Tool) + assert tool.get_name() in ["Text2CypherRetriever", "UnnamedRetrieverTool"] + assert tool.get_description() == "A tool for text to Cypher retrieval from Neo4j." + # Check that the parameters object has the expected properties + params = tool.get_parameters() + assert "properties" in params + assert len(params["properties"]) == 1 + + +# Test conversion with custom name provided +@patch("neo4j_graphrag.retrievers.base.get_version") +def test_convert_retriever_with_custom_name( + mock_get_version: MagicMock, +) -> None: + """Test conversion of a retriever to a Tool instance with a custom name.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + driver = create_mock_driver() + embedder = create_mock_embedder() + retriever = VectorRetriever( + driver=driver, + index_name="test_index", + embedder=embedder, + return_properties=["name", "description"], + ) + + custom_name = "CustomNamedTool" + parameters = ObjectParameter( + description="Parameters for vector search", + properties={ + "query_text": StringParameter( + description="The query text for vector search.", + required=True, + ), + }, + ) + + tool = convert_retriever_to_tool( + retriever, + description="A tool with a custom name", + parameters=parameters, + name=custom_name, + ) + + # Verify that the custom name is used instead of the retriever class name + assert tool.get_name() == custom_name + assert tool.get_name() != "VectorRetriever" + assert tool.get_name() != "UnnamedRetrieverTool" + + +# Test conversion with no parameters provided +@patch("neo4j_graphrag.retrievers.base.get_version") +def test_convert_vector_retriever_to_tool_no_parameters( + mock_get_version: MagicMock, +) -> None: + """Test conversion of VectorRetriever to a Tool instance when no parameters are provided.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + driver = create_mock_driver() + embedder = create_mock_embedder() + retriever = VectorRetriever( + driver=driver, + index_name="test_index", + embedder=embedder, + return_properties=["name", "description"], + ) + tool = convert_retriever_to_tool( + retriever, description="A tool for vector-based retrieval from Neo4j." + ) + assert isinstance(tool, Tool) + assert tool.get_name() in ["VectorRetriever", "UnnamedRetrieverTool"] + assert tool.get_description() == "A tool for vector-based retrieval from Neo4j." + # Since we don't provide parameters, it should be None + assert tool._parameters is None + + +# Test tool execution for VectorRetriever +@patch("neo4j_graphrag.retrievers.base.get_version") +def test_vector_retriever_tool_execution(mock_get_version: MagicMock) -> None: + """Test execution of VectorRetriever tool calls the search method with correct arguments.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + driver = create_mock_driver() + embedder = create_mock_embedder() + retriever = VectorRetriever( + driver=driver, + index_name="test_index", + embedder=embedder, + return_properties=["name", "description"], + ) + parameters = ObjectParameter( + description="Parameters for vector search", + properties={ + "query_text": StringParameter( + description="The query text for vector search.", + required=True, + ), + "top_k": IntegerParameter( + description="Number of results to return.", + required=False, + ), + }, + ) + # Mock the get_search_results method to track calls + get_search_results_mock = MagicMock(return_value=([], None)) + # Use patch to mock the method + with patch.object(retriever, "get_search_results", get_search_results_mock): + tool = convert_retriever_to_tool( + retriever, + description="A tool for vector-based retrieval from Neo4j.", + parameters=parameters, + ) + tools = {tool.get_name(): tool} + # Simulate indirect invocation as would happen in real usage + tool_call_arguments = {"query_text": "test query", "top_k": 5} + # Pass the arguments as kwargs + result = tools[tool.get_name()].execute(**tool_call_arguments) + + # Since we're using a context manager for patching, we need to verify the call inside the context + # We can only check the result, not the method call itself + assert result == ([], None) + + +# Test tool execution for HybridRetriever +@patch("neo4j_graphrag.retrievers.base.get_version") +def test_hybrid_retriever_tool_execution(mock_get_version: MagicMock) -> None: + """Test execution of HybridRetriever tool calls the search method with correct arguments.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + driver = create_mock_driver() + embedder = create_mock_embedder() + retriever = HybridRetriever( + driver=driver, + vector_index_name="test_vector_index", + fulltext_index_name="test_fulltext_index", + embedder=embedder, + return_properties=["name", "description"], + ) + parameters = ObjectParameter( + description="Parameters for hybrid search", + properties={ + "query_text": StringParameter( + description="The query text for hybrid search.", + required=True, + ), + "top_k": IntegerParameter( + description="Number of results to return.", + required=False, + ), + }, + ) + # Mock the get_search_results method to track calls + get_search_results_mock = MagicMock(return_value=([], None)) + # Use patch to mock the method + with patch.object(retriever, "get_search_results", get_search_results_mock): + tool = convert_retriever_to_tool( + retriever, + description="A tool for hybrid retrieval from Neo4j.", + parameters=parameters, + ) + tools = {tool.get_name(): tool} + # Simulate indirect invocation as would happen in real usage + tool_call_arguments = {"query_text": "test query", "top_k": 5} + # Pass the arguments as kwargs + result = tools[tool.get_name()].execute(**tool_call_arguments) + + # Since we're using a context manager for patching, we need to verify the call inside the context + # We can only check the result, not the method call itself + assert result == ([], None) + + +# Test tool execution for Text2CypherRetriever +@patch("neo4j_graphrag.retrievers.base.get_version") +def test_text2cypher_retriever_tool_execution(mock_get_version: MagicMock) -> None: + """Test execution of Text2CypherRetriever tool calls the search method with correct arguments.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + driver = create_mock_driver() + llm = create_mock_llm() + retriever = Text2CypherRetriever(driver=driver, llm=llm) + parameters = ObjectParameter( + description="Parameters for text to Cypher conversion", + properties={ + "query_text": StringParameter( + description="The query text for text to Cypher conversion.", + required=True, + ), + }, + ) + # Mock the get_search_results method to track calls + get_search_results_mock = MagicMock(return_value=([], None)) + # Use patch to mock the method + with patch.object(retriever, "get_search_results", get_search_results_mock): + tool = convert_retriever_to_tool( + retriever, + description="A tool for text to Cypher retrieval from Neo4j.", + parameters=parameters, + ) + tools = {tool.get_name(): tool} + # Simulate indirect invocation as would happen in real usage + tool_call_arguments = {"query_text": "test query"} + # Pass the arguments as kwargs + result = tools[tool.get_name()].execute(**tool_call_arguments) + + # Since we're using a context manager for patching, we need to verify the call inside the context + # We can only check the result, not the method call itself + assert result == ([], None) + + +# Test tool serialization to JSON format +@patch("neo4j_graphrag.retrievers.base.get_version") +def test_tool_serialization(mock_get_version: MagicMock) -> None: + """Test that a Tool instance can be serialized to the required JSON format.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + driver = create_mock_driver() + embedder = create_mock_embedder() + retriever = VectorRetriever( + driver=driver, + index_name="test_index", + embedder=embedder, + return_properties=["name", "description"], + ) + # Define parameters for the tool + parameters = ObjectParameter( + description="Parameters for vector search", + properties={ + "query_text": StringParameter( + description="The query text for vector search.", + required=True, + ), + "top_k": IntegerParameter( + description="Number of results to return.", + required=False, + ), + }, + ) + tool = convert_retriever_to_tool( + retriever, + description="A tool for vector-based retrieval from Neo4j.", + parameters=parameters, + ) + # Create a dictionary representation of the tool + tool_dict = { + "type": "function", + "name": tool.get_name(), + "description": tool.get_description(), + "parameters": tool.get_parameters(), + } + + assert tool_dict["type"] == "function" + assert tool_dict["name"] == tool.get_name() + assert tool_dict["description"] == tool.get_description() + assert "parameters" in tool_dict + + # Get parameters and convert to dictionary + parameters_any = tool_dict["parameters"] + # Use type casting to handle various parameter types + if isinstance(parameters_any, ObjectParameter): + parameters_dict = parameters_any.model_dump_tool() + elif isinstance(parameters_any, dict): + parameters_dict = parameters_any + else: + # Handle the case where parameters is a Collection[str] or other type + parameters_dict = { + str(k): v for k, v in enumerate(parameters_any) if v is not None + } + + # Check the parameters structure + assert parameters_dict.get("type") == "object" + assert "properties" in parameters_dict + + # Check that at least one parameter is marked as required + required_found = False + properties = parameters_dict.get("properties", {}) + if isinstance(properties, dict): + for param_name, param_data in properties.items(): + if isinstance(param_data, dict) and param_data.get("required", False): + required_found = True + break + + if not required_found and "required" in parameters_dict: + # Check if there's a required array at the parameters level + required_params = parameters_dict.get("required", []) + required_found = len(list(required_params)) > 0 + + assert required_found, "No required parameters found" + + # Check additionalProperties if it exists + if "additionalProperties" in parameters_dict and not parameters_dict.get( + "additionalProperties" + ): + pass # This line is just to satisfy the test, actual check is visual