diff --git a/CHANGELOG.md b/CHANGELOG.md
index a8bad948..f3571845 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,7 @@
 ### Fixed

 - Fixed a bug where `spacy` and `rapidfuzz` needed to be installed even if not using the relevant entity resolvers.
+- Fixed a bug where `VertexAILLM.(a)invoke_with_tools` called with multiple tools would raise an error.

 ### Changed

diff --git a/examples/customize/llms/vertexai_tool_calls.py b/examples/customize/llms/vertexai_tool_calls.py
index b8b00da5..d853d737 100644
--- a/examples/customize/llms/vertexai_tool_calls.py
+++ b/examples/customize/llms/vertexai_tool_calls.py
@@ -4,6 +4,7 @@
 """

 import asyncio
+from typing import Optional

 from dotenv import load_dotenv
 from vertexai.generative_models import GenerationConfig
@@ -17,7 +18,7 @@


 # Create a custom Tool implementation for person info extraction
-parameters = ObjectParameter(
+person_tool_parameters = ObjectParameter(
     description="Parameters for extracting person information",
     properties={
         "name": StringParameter(description="The person's full name"),
@@ -29,7 +30,9 @@
 )


-def run_tool(name: str, age: int, occupation: str) -> str:
+def run_person_tool(
+    name: str, age: Optional[int] = None, occupation: Optional[str] = None
+) -> str:
     """A simple function that summarizes person information from input parameters."""
     return f"Found person {name} with age {age} and occupation {occupation}"

@@ -37,12 +40,40 @@ def run_tool(name: str, age: int, occupation: str) -> str:
 person_info_tool = Tool(
     name="extract_person_info",
     description="Extract information about a person from text",
-    parameters=parameters,
-    execute_func=run_tool,
+    parameters=person_tool_parameters,
+    execute_func=run_person_tool,
+)
+
+company_tool_parameters = ObjectParameter(
+    description="Parameters for extracting company information",
+    properties={
+        "name": StringParameter(description="The company's full name"),
+        "industry": StringParameter(description="The company's industry"),
+        "creation_year": IntegerParameter(description="The company's creation year"),
+    },
+    required_properties=["name"],
+    additional_properties=False,
+)
+
+
+def run_company_tool(
+    name: str, industry: Optional[str] = None, creation_year: Optional[int] = None
+) -> str:
+    """A simple function that summarizes company information from input parameters."""
+    return (
+        f"Found company {name} operating in industry {industry} since {creation_year}"
+    )
+
+
+company_info_tool = Tool(
+    name="extract_company_info",
+    description="Extract information about a company from text",
+    parameters=company_tool_parameters,
+    execute_func=run_company_tool,
 )

 # Create the tool instance
-TOOLS = [person_info_tool]
+TOOLS = [person_info_tool, company_info_tool]


 def process_tool_call(response: ToolCallResponse) -> str:
@@ -54,7 +85,12 @@ def process_tool_call(response: ToolCallResponse) -> str:
     print(f"\nTool called: {tool_call.name}")
     print(f"Arguments: {tool_call.arguments}")
     print(f"Additional content: {response.content or 'None'}")
-    return person_info_tool.execute(**tool_call.arguments)  # type: ignore[no-any-return]
+    if tool_call.name == "extract_person_info":
+        return person_info_tool.execute(**tool_call.arguments)  # type: ignore[no-any-return]
+    elif tool_call.name == "extract_company_info":
+        return str(company_info_tool.execute(**tool_call.arguments))
+    else:
+        raise ValueError("Unknown tool call")


 async def main() -> None:
@@ -65,13 +101,13 @@ async def main() -> None:
         generation_config=generation_config,
     )

-    # Example text containing information about a person
-    text = "Stella Hane is a 35-year-old software engineer who loves coding."
+    # Example text containing information about a company
+    text1 = "Neo4j is a software company created in 2007"

     print("\n=== Synchronous Tool Call ===")
     # Make a synchronous tool call
     sync_response = llm.invoke_with_tools(
-        input=f"Extract information about the person from this text: {text}",
+        input=f"Extract information about the person from this text: {text1}",
         tools=TOOLS,
     )
     sync_result = process_tool_call(sync_response)
@@ -79,7 +115,7 @@ async def main() -> None:
     print(sync_result)

     print("\n=== Asynchronous Tool Call ===")
-    # Make an asynchronous tool call with a different text
+    # Make an asynchronous tool call with a different text about a person
     text2 = "Molly Hane, 32, works as a data scientist and enjoys machine learning."
     async_response = await llm.ainvoke_with_tools(
         input=f"Extract information about the person from this text: {text2}",
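The example above dispatches tool calls with an if/elif chain. An equivalent pattern that scales past two tools is a registry keyed on `Tool.get_name()`. A minimal sketch, assuming the example's `TOOLS` list and `Tool.execute` API; the `TOOL_REGISTRY` name is hypothetical and not part of the patch:

```python
# Hypothetical alternative to the if/elif dispatch in process_tool_call:
# look the tool up by the name reported in the model's tool call.
TOOL_REGISTRY = {tool.get_name(): tool for tool in TOOLS}


def process_tool_call(response: ToolCallResponse) -> str:
    tool_call = response.tool_calls[0]
    tool = TOOL_REGISTRY.get(tool_call.name)
    if tool is None:
        raise ValueError(f"Unknown tool call: {tool_call.name}")
    return str(tool.execute(**tool_call.arguments))
```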
- text = "Stella Hane is a 35-year-old software engineer who loves coding." + # Example text containing information about a company + text1 = "Neo4j is a software company created in 2007" print("\n=== Synchronous Tool Call ===") # Make a synchronous tool call sync_response = llm.invoke_with_tools( - input=f"Extract information about the person from this text: {text}", + input=f"Extract information about the person from this text: {text1}", tools=TOOLS, ) sync_result = process_tool_call(sync_response) @@ -79,7 +115,7 @@ async def main() -> None: print(sync_result) print("\n=== Asynchronous Tool Call ===") - # Make an asynchronous tool call with a different text + # Make an asynchronous tool call with a different text about a person text2 = "Molly Hane, 32, works as a data scientist and enjoys machine learning." async_response = await llm.ainvoke_with_tools( input=f"Extract information about the person from this text: {text2}", diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py index 100ff99a..387c51bc 100644 --- a/src/neo4j_graphrag/llm/vertexai_llm.py +++ b/src/neo4j_graphrag/llm/vertexai_llm.py @@ -137,20 +137,18 @@ def invoke( Returns: LLMResponse: The response from the LLM. """ - system_message = [system_instruction] if system_instruction is not None else [] - self.model = GenerativeModel( - model_name=self.model_name, - system_instruction=system_message, - **self.options, + model = self._get_model( + system_instruction=system_instruction, + tools=None, ) try: if isinstance(message_history, MessageHistory): message_history = message_history.messages messages = self.get_messages(input, message_history) - response = self.model.generate_content(messages, **self.model_params) - return LLMResponse(content=response.text) + response = model.generate_content(messages) + return self._parse_content_response(response) except ResponseValidationError as e: - raise LLMGenerationError(e) + raise LLMGenerationError("Error calling LLM") from e async def ainvoke( self, @@ -172,31 +170,21 @@ async def ainvoke( try: if isinstance(message_history, MessageHistory): message_history = message_history.messages - system_message = ( - [system_instruction] if system_instruction is not None else [] - ) - self.model = GenerativeModel( - model_name=self.model_name, - system_instruction=system_message, - **self.options, + model = self._get_model( + system_instruction=system_instruction, + tools=None, ) messages = self.get_messages(input, message_history) - response = await self.model.generate_content_async( - messages, **self.model_params - ) - return LLMResponse(content=response.text) + response = await model.generate_content_async(messages) + return self._parse_content_response(response) except ResponseValidationError as e: raise LLMGenerationError(e) - def _to_vertexai_tool(self, tool: Tool) -> VertexAITool: - return VertexAITool( - function_declarations=[ - FunctionDeclaration( - name=tool.get_name(), - description=tool.get_description(), - parameters=tool.get_parameters(exclude=["additional_properties"]), - ) - ] + def _to_vertexai_function_declaration(self, tool: Tool) -> FunctionDeclaration: + return FunctionDeclaration( + name=tool.get_name(), + description=tool.get_description(), + parameters=tool.get_parameters(exclude=["additional_properties"]), ) def _get_llm_tools( @@ -204,7 +192,13 @@ def _get_llm_tools( ) -> Optional[list[VertexAITool]]: if not tools: return None - return [self._to_vertexai_tool(tool) for tool in tools] + return [ + VertexAITool( + 
@@ -204,7 +192,13 @@ def _get_llm_tools(
     ) -> Optional[list[VertexAITool]]:
         if not tools:
             return None
-        return [self._to_vertexai_tool(tool) for tool in tools]
+        return [
+            VertexAITool(
+                function_declarations=[
+                    self._to_vertexai_function_declaration(tool) for tool in tools
+                ]
+            )
+        ]

     def _get_model(
         self,
@@ -213,11 +207,18 @@ def _get_model(
     ) -> GenerativeModel:
         system_message = [system_instruction] if system_instruction is not None else []
         vertex_ai_tools = self._get_llm_tools(tools)
+        options = dict(self.options)
+        if tools:
+            # we want a tool back, remove generation_config if defined
+            options.pop("generation_config", None)
+        else:
+            # no tools, remove tool_config if defined
+            options.pop("tool_config", None)
         model = GenerativeModel(
             model_name=self.model_name,
             system_instruction=system_message,
             tools=vertex_ai_tools,
-            **self.options,
+            **options,
         )
         return model
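The `_get_llm_tools` change above is the heart of the fix: previously each `Tool` became its own `VertexAITool`, and passing several of them made the request fail; now every `FunctionDeclaration` is grouped under a single `VertexAITool`. A minimal sketch of the resulting shape, with assumed illustrative schemas (real schemas come from `Tool.get_parameters()`):

```python
from vertexai.generative_models import FunctionDeclaration, Tool

# Illustrative OpenAPI-style schemas, standing in for Tool.get_parameters().
person_schema = {"type": "object", "properties": {"name": {"type": "string"}}}
company_schema = {"type": "object", "properties": {"name": {"type": "string"}}}

declarations = [
    FunctionDeclaration(
        name="extract_person_info",
        description="Extract information about a person from text",
        parameters=person_schema,
    ),
    FunctionDeclaration(
        name="extract_company_info",
        description="Extract information about a company from text",
        parameters=company_schema,
    ),
]

# One VertexAITool wrapping all declarations (the fixed behavior),
# instead of one VertexAITool per declaration (the buggy behavior).
tools = [Tool(function_declarations=declarations)]
```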
diff --git a/tests/unit/llm/test_vertexai_llm.py b/tests/unit/llm/test_vertexai_llm.py
index b475efcc..6a5ab74c 100644
--- a/tests/unit/llm/test_vertexai_llm.py
+++ b/tests/unit/llm/test_vertexai_llm.py
@@ -14,7 +14,6 @@
 from __future__ import annotations

 from typing import cast
-from unittest import mock
 from unittest.mock import AsyncMock, MagicMock, Mock, patch

 import pytest
@@ -50,11 +49,9 @@ def test_vertexai_invoke_happy_path(GenerativeModelMock: MagicMock) -> None:
     response = llm.invoke(input_text)
     assert response.content == "Return text"
     GenerativeModelMock.assert_called_once_with(
-        model_name=model_name, system_instruction=[]
+        model_name=model_name, system_instruction=[], tools=None
     )
-    user_message = mock.ANY
-    llm.model.generate_content.assert_called_once_with(user_message, **model_params)
-    last_call = llm.model.generate_content.call_args_list[0]
+    last_call = mock_model.generate_content.call_args_list[0]
     content = last_call.args[0]
     assert len(content) == 1
     assert content[0].role == "user"
@@ -62,7 +59,9 @@


 @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
+@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM.get_messages")
 def test_vertexai_invoke_with_system_instruction(
+    mock_get_messages: MagicMock,
     GenerativeModelMock: MagicMock,
 ) -> None:
     system_instruction = "You are a helpful assistant."
@@ -72,16 +71,18 @@ def test_vertexai_invoke_with_system_instruction(
     mock_response.text = "Return text"
     mock_model = GenerativeModelMock.return_value
     mock_model.generate_content.return_value = mock_response
+
+    mock_get_messages.return_value = [{"text": "some text"}]
+
     model_params = {"temperature": 0.5}
     llm = VertexAILLM(model_name, model_params)
     response = llm.invoke(input_text, system_instruction=system_instruction)
     assert response.content == "Return text"
     GenerativeModelMock.assert_called_once_with(
-        model_name=model_name, system_instruction=[system_instruction]
+        model_name=model_name, system_instruction=[system_instruction], tools=None
     )
-    user_message = mock.ANY
-    llm.model.generate_content.assert_called_once_with(user_message, **model_params)
+    mock_model.generate_content.assert_called_once_with([{"text": "some text"}])


 @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
@@ -110,11 +111,9 @@ def test_vertexai_invoke_with_message_history_and_system_instruction(
     )
     assert response.content == "Return text"
     GenerativeModelMock.assert_called_once_with(
-        model_name=model_name, system_instruction=[system_instruction]
+        model_name=model_name, system_instruction=[system_instruction], tools=None
     )
-    user_message = mock.ANY
-    llm.model.generate_content.assert_called_once_with(user_message, **model_params)
-    last_call = llm.model.generate_content.call_args_list[0]
+    last_call = mock_model.generate_content.call_args_list[0]
     content = last_call.args[0]
     assert len(content) == 3  # question + 2 messages in history

@@ -167,18 +166,22 @@ def test_vertexai_get_messages_validation_error(GenerativeModelMock: MagicMock)

 @pytest.mark.asyncio
 @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
-async def test_vertexai_ainvoke_happy_path(GenerativeModelMock: MagicMock) -> None:
+@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM.get_messages")
+async def test_vertexai_ainvoke_happy_path(
+    mock_get_messages: Mock, GenerativeModelMock: MagicMock
+) -> None:
     mock_response = AsyncMock()
     mock_response.text = "Return text"
     mock_model = GenerativeModelMock.return_value
     mock_model.generate_content_async = AsyncMock(return_value=mock_response)
+    mock_get_messages.return_value = [{"text": "Return text"}]
     model_params = {"temperature": 0.5}
     llm = VertexAILLM("gemini-1.5-flash-001", model_params)
     input_text = "may thy knife chip and shatter"
     response = await llm.ainvoke(input_text)
     assert response.content == "Return text"
-    llm.model.generate_content_async.assert_awaited_once_with(
-        [mock.ANY], **model_params
+    mock_model.generate_content_async.assert_awaited_once_with(
+        [{"text": "Return text"}]
     )
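A regression test pinning down the grouping behavior could complement the updated tests. A sketch only, not part of this PR: the fake-tool helper and test name are hypothetical, and it exercises the private `_get_llm_tools` helper shown in the diff.

```python
from unittest.mock import MagicMock, patch

from neo4j_graphrag.llm.vertexai_llm import VertexAILLM


def make_fake_tool(name: str) -> MagicMock:
    # Stand-in for a neo4j_graphrag Tool, providing just the methods
    # _to_vertexai_function_declaration needs.
    tool = MagicMock()
    tool.get_name.return_value = name
    tool.get_description.return_value = f"{name} description"
    tool.get_parameters.return_value = {
        "type": "object",
        "properties": {"name": {"type": "string"}},
    }
    return tool


@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
def test_get_llm_tools_groups_declarations_in_one_tool(
    GenerativeModelMock: MagicMock,
) -> None:
    llm = VertexAILLM("gemini-1.5-flash-001")
    vertexai_tools = llm._get_llm_tools(
        [make_fake_tool("extract_person_info"), make_fake_tool("extract_company_info")]
    )
    assert vertexai_tools is not None
    # Two input tools must collapse into a single VertexAITool.
    assert len(vertexai_tools) == 1
```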