From 9632636edbaabbde220f8d128f0d274cf0ea55c4 Mon Sep 17 00:00:00 2001
From: estelle
Date: Tue, 20 May 2025 15:38:04 +0200
Subject: [PATCH 01/12] Fix? VertexAILLM

---
 src/neo4j_graphrag/llm/vertexai_llm.py | 71 +++++++++++++++-----------
 tests/unit/llm/test_vertexai_llm.py    | 30 +++++------
 2 files changed, 56 insertions(+), 45 deletions(-)

diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py
index 100ff99ab..9dacb2cc2 100644
--- a/src/neo4j_graphrag/llm/vertexai_llm.py
+++ b/src/neo4j_graphrag/llm/vertexai_llm.py
@@ -35,6 +35,7 @@
     Content,
     FunctionCall,
     FunctionDeclaration,
+    GenerationConfig,
     GenerationResponse,
     GenerativeModel,
     Part,
@@ -137,20 +138,18 @@ def invoke(
         Returns:
             LLMResponse: The response from the LLM.
         """
-        system_message = [system_instruction] if system_instruction is not None else []
-        self.model = GenerativeModel(
-            model_name=self.model_name,
-            system_instruction=system_message,
-            **self.options,
+        model = self._get_model(
+            system_instruction=system_instruction,
+            tools=None,
         )
         try:
             if isinstance(message_history, MessageHistory):
                 message_history = message_history.messages
             messages = self.get_messages(input, message_history)
-            response = self.model.generate_content(messages, **self.model_params)
-            return LLMResponse(content=response.text)
+            response = model.generate_content(messages)
+            return self._parse_content_response(response)
         except ResponseValidationError as e:
-            raise LLMGenerationError(e)
+            raise LLMGenerationError("Error calling LLM") from e
 
     async def ainvoke(
         self,
@@ -172,31 +171,21 @@ async def ainvoke(
         try:
             if isinstance(message_history, MessageHistory):
                 message_history = message_history.messages
-            system_message = (
-                [system_instruction] if system_instruction is not None else []
-            )
-            self.model = GenerativeModel(
-                model_name=self.model_name,
-                system_instruction=system_message,
-                **self.options,
+            model = self._get_model(
+                system_instruction=system_instruction,
+                tools=None,
             )
             messages = self.get_messages(input, message_history)
-            response = await self.model.generate_content_async(
-                messages, **self.model_params
-            )
-            return LLMResponse(content=response.text)
+            response = await model.generate_content_async(messages)
+            return self._parse_content_response(response)
         except ResponseValidationError as e:
             raise LLMGenerationError(e)
 
-    def _to_vertexai_tool(self, tool: Tool) -> VertexAITool:
-        return VertexAITool(
-            function_declarations=[
-                FunctionDeclaration(
-                    name=tool.get_name(),
-                    description=tool.get_description(),
-                    parameters=tool.get_parameters(exclude=["additional_properties"]),
-                )
-            ]
+    def _to_vertexai_function_declaration(self, tool: Tool) -> FunctionDeclaration:
+        return FunctionDeclaration(
+            name=tool.get_name(),
+            description=tool.get_description(),
+            parameters=tool.get_parameters(exclude=["additional_properties"]),
         )
 
     def _get_llm_tools(
@@ -204,7 +193,28 @@ def _get_llm_tools(
     ) -> Optional[list[VertexAITool]]:
         if not tools:
             return None
-        return [self._to_vertexai_tool(tool) for tool in tools]
+        return [
+            VertexAITool(
+                function_declarations=[
+                    self._to_vertexai_function_declaration(tool) for tool in tools
+                ]
+            )
+        ]
+
+    def _get_options(self, tool_mode: bool = False) -> dict[str, Any]:
+        options = dict(self.options)
+        if tool_mode:
+            # remove response_mime_type from GenerationConfig
+            config = options.get("generation_config")
+            if config:
+                config_dict = config.to_dict()
+                if config_dict.get("response_mime_type"):
+                    config_dict["response_mime_type"] = None
+                options["generation_config"] = GenerationConfig.from_dict(config_dict)
+        else:
+            # no tools, drop tool_config if defined
+            options.pop("tool_config", None)
+        return options
 
     def _get_model(
         self,
@@ -213,11 +223,12 @@ def _get_model(
     ) -> GenerativeModel:
         system_message = [system_instruction] if system_instruction is not None else []
         vertex_ai_tools = self._get_llm_tools(tools)
+        options = self._get_options(tool_mode=tools is not None)
         model = GenerativeModel(
             model_name=self.model_name,
             system_instruction=system_message,
             tools=vertex_ai_tools,
-            **self.options,
+            **options,
         )
         return model
 
diff --git a/tests/unit/llm/test_vertexai_llm.py b/tests/unit/llm/test_vertexai_llm.py
index b475efcc5..663b40f53 100644
--- a/tests/unit/llm/test_vertexai_llm.py
+++ b/tests/unit/llm/test_vertexai_llm.py
@@ -50,11 +50,9 @@ def test_vertexai_invoke_happy_path(GenerativeModelMock: MagicMock) -> None:
     response = llm.invoke(input_text)
     assert response.content == "Return text"
     GenerativeModelMock.assert_called_once_with(
-        model_name=model_name, system_instruction=[]
+        model_name=model_name, system_instruction=[], tools=None
     )
-    user_message = mock.ANY
-    llm.model.generate_content.assert_called_once_with(user_message, **model_params)
-    last_call = llm.model.generate_content.call_args_list[0]
+    last_call = mock_model.generate_content.call_args_list[0]
     content = last_call.args[0]
     assert len(content) == 1
     assert content[0].role == "user"
@@ -62,7 +60,9 @@
 
 
 @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
+@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM.get_messages")
 def test_vertexai_invoke_with_system_instruction(
+    mock_get_messages: MagicMock,
     GenerativeModelMock: MagicMock,
 ) -> None:
     system_instruction = "You are a helpful assistant."
@@ -72,16 +72,18 @@ def test_vertexai_invoke_with_system_instruction(
     mock_response.text = "Return text"
     mock_model = GenerativeModelMock.return_value
     mock_model.generate_content.return_value = mock_response
+
+    mock_get_messages.return_value = [{"text": "some text"}]
+
     model_params = {"temperature": 0.5}
     llm = VertexAILLM(model_name, model_params)
     response = llm.invoke(input_text, system_instruction=system_instruction)
     assert response.content == "Return text"
     GenerativeModelMock.assert_called_once_with(
-        model_name=model_name, system_instruction=[system_instruction]
+        model_name=model_name, system_instruction=[system_instruction], tools=None
     )
-    user_message = mock.ANY
-    llm.model.generate_content.assert_called_once_with(user_message, **model_params)
+    mock_model.generate_content.assert_called_once_with([{"text": "some text"}])
 
 
 @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
@@ -110,11 +112,9 @@ def test_vertexai_invoke_with_message_history_and_system_instruction(
     )
     assert response.content == "Return text"
     GenerativeModelMock.assert_called_once_with(
-        model_name=model_name, system_instruction=[system_instruction]
+        model_name=model_name, system_instruction=[system_instruction], tools=None
     )
-    user_message = mock.ANY
-    llm.model.generate_content.assert_called_once_with(user_message, **model_params)
-    last_call = llm.model.generate_content.call_args_list[0]
+    last_call = mock_model.generate_content.call_args_list[0]
     content = last_call.args[0]
     assert len(content) == 3  # question + 2 messages in history
 
@@ -167,19 +167,19 @@ def test_vertexai_get_messages_validation_error(GenerativeModelMock: MagicMock)
 
 @pytest.mark.asyncio
 @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
-async def test_vertexai_ainvoke_happy_path(GenerativeModelMock: MagicMock) -> None:
+@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM.get_messages")
+async def test_vertexai_ainvoke_happy_path(mock_get_messages: Mock, GenerativeModelMock: MagicMock) -> None:
     mock_response = AsyncMock()
     mock_response.text = "Return text"
     mock_model = GenerativeModelMock.return_value
     mock_model.generate_content_async = AsyncMock(return_value=mock_response)
+    mock_get_messages.return_value = [{"text": "Return text"}]
     model_params = {"temperature": 0.5}
     llm = VertexAILLM("gemini-1.5-flash-001", model_params)
     input_text = "may thy knife chip and shatter"
     response = await llm.ainvoke(input_text)
     assert response.content == "Return text"
-    llm.model.generate_content_async.assert_awaited_once_with(
-        [mock.ANY], **model_params
-    )
+    mock_model.generate_content_async.assert_awaited_once_with([{"text": "Return text"}])
 
 
 def test_vertexai_get_llm_tools(test_tool: Tool) -> None:

From 5fa8d21e99f10e2bb520260204f88da1869db6ac Mon Sep 17 00:00:00 2001
From: estelle
Date: Wed, 21 May 2025 15:19:07 +0200
Subject: [PATCH 02/12] Ruff

---
 src/neo4j_graphrag/llm/vertexai_llm.py | 4 +++-
 tests/unit/llm/test_vertexai_llm.py    | 9 ++++++---
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py
index 9dacb2cc2..f8b4fdb76 100644
--- a/src/neo4j_graphrag/llm/vertexai_llm.py
+++ b/src/neo4j_graphrag/llm/vertexai_llm.py
@@ -210,7 +210,9 @@ def _get_options(self, tool_mode: bool = False) -> dict[str, Any]:
                 config_dict = config.to_dict()
                 if config_dict.get("response_mime_type"):
                     config_dict["response_mime_type"] = None
-                options["generation_config"] = GenerationConfig.from_dict(config_dict)
+                options["generation_config"] = GenerationConfig.from_dict(
+                    config_dict
+                )
         else:
             # no tools, drop tool_config if defined
             options.pop("tool_config", None)
diff --git a/tests/unit/llm/test_vertexai_llm.py b/tests/unit/llm/test_vertexai_llm.py
index 663b40f53..6a5ab74ca 100644
--- a/tests/unit/llm/test_vertexai_llm.py
+++ b/tests/unit/llm/test_vertexai_llm.py
@@ -14,7 +14,6 @@
 from __future__ import annotations
 
 from typing import cast
-from unittest import mock
 from unittest.mock import AsyncMock, MagicMock, Mock, patch
 
 import pytest
@@ -168,7 +167,9 @@ def test_vertexai_get_messages_validation_error(GenerativeModelMock: MagicMock)
 @pytest.mark.asyncio
 @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
 @patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM.get_messages")
-async def test_vertexai_ainvoke_happy_path(mock_get_messages: Mock, GenerativeModelMock: MagicMock) -> None:
+async def test_vertexai_ainvoke_happy_path(
+    mock_get_messages: Mock, GenerativeModelMock: MagicMock
+) -> None:
     mock_response = AsyncMock()
     mock_response.text = "Return text"
     mock_model = GenerativeModelMock.return_value
@@ -179,7 +180,9 @@ async def test_vertexai_ainvoke_happy_path(
     input_text = "may thy knife chip and shatter"
     response = await llm.ainvoke(input_text)
     assert response.content == "Return text"
-    mock_model.generate_content_async.assert_awaited_once_with([{"text": "Return text"}])
+    mock_model.generate_content_async.assert_awaited_once_with(
+        [{"text": "Return text"}]
+    )
 
 
 def test_vertexai_get_llm_tools(test_tool: Tool) -> None:

From 88533d7e66605ebaf90e25ab79cf2196da933b86 Mon Sep 17 00:00:00 2001
From: estelle
Date: Wed, 21 May 2025 15:46:43 +0200
Subject: [PATCH 03/12] Remove full generation config in case of tool calling

---
 src/neo4j_graphrag/llm/vertexai_llm.py | 14 ++++----------
 1 file changed, 4 insertions(+), 10 deletions(-)

diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py
index f8b4fdb76..c02f23921 100644
--- a/src/neo4j_graphrag/llm/vertexai_llm.py
+++ b/src/neo4j_graphrag/llm/vertexai_llm.py
@@ -41,6 +41,7 @@
     Part,
     ResponseValidationError,
     Tool as VertexAITool,
+    ToolConfig,
 )
 except ImportError:
     GenerativeModel = None
@@ -204,17 +205,10 @@ def _get_options(self, tool_mode: bool = False) -> dict[str, Any]:
         options = dict(self.options)
         if tool_mode:
-            # remove response_mime_type from GenerationConfig
-            config = options.get("generation_config")
-            if config:
-                config_dict = config.to_dict()
-                if config_dict.get("response_mime_type"):
-                    config_dict["response_mime_type"] = None
-                options["generation_config"] = GenerationConfig.from_dict(
-                    config_dict
-                )
+            # we want a tool back, remove generation_config if defined
+            options.pop("generation_config", None)
         else:
-            # no tools, drop tool_config if defined
+            # no tools, remove tool_config if defined
             options.pop("tool_config", None)
         return options
 
     def _get_model(

From b83fc745b0458d617247de9e0b3332cfd5718c83 Mon Sep 17 00:00:00 2001
From: estelle
Date: Wed, 21 May 2025 15:49:29 +0200
Subject: [PATCH 04/12] Ruff

---
 src/neo4j_graphrag/llm/vertexai_llm.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py
index c02f23921..bddc75e1b 100644
--- a/src/neo4j_graphrag/llm/vertexai_llm.py
+++ b/src/neo4j_graphrag/llm/vertexai_llm.py
@@ -35,13 +35,11 @@
     Content,
     FunctionCall,
     FunctionDeclaration,
-    GenerationConfig,
     GenerationResponse,
     GenerativeModel,
     Part,
     ResponseValidationError,
     Tool as VertexAITool,
-    ToolConfig,
 )
 except ImportError:
     GenerativeModel = None

From 7b694ae4a9d212e1243796840f4b6f79b0e5e15e Mon Sep 17 00:00:00 2001
From: estelle
Date: Wed, 21 May 2025 15:57:37 +0200
Subject: [PATCH 05/12] Update example

---
 CHANGELOG.md                              |  1 +
 .../customize/llms/vertexai_tool_calls.py | 53 ++++++++++++++++---
 2 files changed, 46 insertions(+), 8 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 140b36bd2..cbd83dedd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,7 @@
 ### Fixed
 
 - Fixed a bug where `spacy` and `rapidfuzz` needed to be installed even if not using the relevant entity resolvers.
+- Fixed a bug where `VertexAILLM.(a)invoke_with_tools` called with multiple tools would raise an error.
 
 ### Changed
 
diff --git a/examples/customize/llms/vertexai_tool_calls.py b/examples/customize/llms/vertexai_tool_calls.py
index b8b00da5b..8c5e9ca12 100644
--- a/examples/customize/llms/vertexai_tool_calls.py
+++ b/examples/customize/llms/vertexai_tool_calls.py
@@ -4,6 +4,7 @@
 """
 
 import asyncio
+from typing import Optional
 
 from dotenv import load_dotenv
 from vertexai.generative_models import GenerationConfig
@@ -17,7 +18,7 @@
 
 
 # Create a custom Tool implementation for person info extraction
-parameters = ObjectParameter(
+person_tool_parameters = ObjectParameter(
     description="Parameters for extracting person information",
     properties={
         "name": StringParameter(description="The person's full name"),
@@ -29,7 +30,9 @@
 )
 
 
-def run_tool(name: str, age: int, occupation: str) -> str:
+def run_person_tool(
+    name: str, age: Optional[int] = None, occupation: Optional[str] = None
+) -> str:
     """A simple function that summarizes person information from input parameters."""
     return f"Found person {name} with age {age} and occupation {occupation}"
 
@@ -37,12 +40,40 @@ def run_tool(name: str, age: int, occupation: str) -> str:
 person_info_tool = Tool(
     name="extract_person_info",
     description="Extract information about a person from text",
-    parameters=parameters,
-    execute_func=run_tool,
+    parameters=person_tool_parameters,
+    execute_func=run_person_tool,
+)
+
+company_tool_parameters = ObjectParameter(
+    description="Parameters for extracting company information",
+    properties={
+        "name": StringParameter(description="The company's full name"),
+        "industry": StringParameter(description="The company's industry"),
+        "creation_year": IntegerParameter(description="The company's creation year"),
+    },
+    required_properties=["name"],
+    additional_properties=False,
+)
+
+
+def run_company_tool(
+    name: str, industry: Optional[str] = None, creation_year: Optional[int] = None
+) -> str:
+    """A simple function that summarizes company information from input parameters."""
+    return (
+        f"Found company {name} operating in industry {industry} since {creation_year}"
+    )
+
+
+company_info_tool = Tool(
+    name="extract_company_info",
+    description="Extract information about a company from text",
+    parameters=company_tool_parameters,
+    execute_func=run_company_tool,
 )
 
 # Create the tool instance
-TOOLS = [person_info_tool]
+TOOLS = [person_info_tool, company_info_tool]
 
 
 def process_tool_call(response: ToolCallResponse) -> str:
@@ -54,7 +85,12 @@ def process_tool_call(response: ToolCallResponse) -> str:
     print(f"\nTool called: {tool_call.name}")
     print(f"Arguments: {tool_call.arguments}")
     print(f"Additional content: {response.content or 'None'}")
-    return person_info_tool.execute(**tool_call.arguments)  # type: ignore[no-any-return]
+    if tool_call.name == "extract_person_info":
+        return person_info_tool.execute(**tool_call.arguments)  # type: ignore[no-any-return]
+    elif tool_call.name == "extract_company_info":
+        return company_info_tool.execute(**tool_call.arguments)
+    else:
+        raise ValueError("Unknown tool call")
 
 
 async def main() -> None:
@@ -66,12 +102,13 @@ async def main() -> None:
     )
 
     # Example text containing information about a person
-    text = "Stella Hane is a 35-year-old software engineer who loves coding."
+    # text = "Stella Hane is a 35-year-old software engineer who loves coding."
+    text1 = "Neo4j is a software company created in 2017"
 
     print("\n=== Synchronous Tool Call ===")
     # Make a synchronous tool call
     sync_response = llm.invoke_with_tools(
-        input=f"Extract information about the person from this text: {text}",
+        input=f"Extract information about the person from this text: {text1}",
         tools=TOOLS,
     )
     sync_result = process_tool_call(sync_response)

From bd0217a3d1597038fda13de641fc43c457aa504e Mon Sep 17 00:00:00 2001
From: estelle
Date: Wed, 21 May 2025 15:59:04 +0200
Subject: [PATCH 06/12] Update example

---
 examples/customize/llms/vertexai_tool_calls.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/examples/customize/llms/vertexai_tool_calls.py b/examples/customize/llms/vertexai_tool_calls.py
index 8c5e9ca12..7b3357eb2 100644
--- a/examples/customize/llms/vertexai_tool_calls.py
+++ b/examples/customize/llms/vertexai_tool_calls.py
@@ -101,8 +101,7 @@ async def main() -> None:
         generation_config=generation_config,
     )
 
-    # Example text containing information about a person
-    # text = "Stella Hane is a 35-year-old software engineer who loves coding."
+    # Example text containing information about a company
     text1 = "Neo4j is a software company created in 2017"
 
     print("\n=== Synchronous Tool Call ===")
@@ -116,7 +115,7 @@ async def main() -> None:
     print(sync_result)
 
     print("\n=== Asynchronous Tool Call ===")
-    # Make an asynchronous tool call with a different text
+    # Make an asynchronous tool call with a different text about a person
     text2 = "Molly Hane, 32, works as a data scientist and enjoys machine learning."
     async_response = await llm.ainvoke_with_tools(
         input=f"Extract information about the person from this text: {text2}",

From 5bc2eb81a0a50fb409011957f44cbd4b63a1850c Mon Sep 17 00:00:00 2001
From: estelle
Date: Fri, 23 May 2025 17:10:08 +0200
Subject: [PATCH 07/12] Review comment part 1

---
 examples/customize/llms/vertexai_tool_calls.py |  2 +-
 src/neo4j_graphrag/llm/vertexai_llm.py         | 18 +++++++-----------
 2 files changed, 8 insertions(+), 12 deletions(-)

diff --git a/examples/customize/llms/vertexai_tool_calls.py b/examples/customize/llms/vertexai_tool_calls.py
index 7b3357eb2..1a67c36ed 100644
--- a/examples/customize/llms/vertexai_tool_calls.py
+++ b/examples/customize/llms/vertexai_tool_calls.py
@@ -102,7 +102,7 @@ async def main() -> None:
     )
 
     # Example text containing information about a company
-    text1 = "Neo4j is a software company created in 2017"
+    text1 = "Neo4j is a software company created in 2007"
 
     print("\n=== Synchronous Tool Call ===")
     # Make a synchronous tool call
diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py
index bddc75e1b..387c51bc7 100644
--- a/src/neo4j_graphrag/llm/vertexai_llm.py
+++ b/src/neo4j_graphrag/llm/vertexai_llm.py
@@ -200,16 +200,6 @@ def _get_llm_tools(
             )
         ]
 
-    def _get_options(self, tool_mode: bool = False) -> dict[str, Any]:
-        options = dict(self.options)
-        if tool_mode:
-            # we want a tool back, remove generation_config if defined
-            options.pop("generation_config", None)
-        else:
-            # no tools, remove tool_config if defined
-            options.pop("tool_config", None)
-        return options
-
     def _get_model(
         self,
@@ -217,7 +207,13 @@ def _get_model(
     ) -> GenerativeModel:
         system_message = [system_instruction] if system_instruction is not None else []
         vertex_ai_tools = self._get_llm_tools(tools)
-        options = self._get_options(tool_mode=tools is not None)
+        options = dict(self.options)
+        if tools:
+            # we want a tool back, remove generation_config if defined
+            options.pop("generation_config", None)
+        else:
+            # no tools, remove tool_config if defined
+            options.pop("tool_config", None)
         model = GenerativeModel(
             model_name=self.model_name,
             system_instruction=system_message,

From 907e8b511052189a07e5e1be4e9c118fd93fdffd Mon Sep 17 00:00:00 2001
From: estelle
Date: Mon, 26 May 2025 09:33:20 +0200
Subject: [PATCH 08/12] mypy

---
 examples/customize/llms/vertexai_tool_calls.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/customize/llms/vertexai_tool_calls.py b/examples/customize/llms/vertexai_tool_calls.py
index 1a67c36ed..d853d7377 100644
--- a/examples/customize/llms/vertexai_tool_calls.py
+++ b/examples/customize/llms/vertexai_tool_calls.py
@@ -88,7 +88,7 @@ def process_tool_call(response: ToolCallResponse) -> str:
     if tool_call.name == "extract_person_info":
         return person_info_tool.execute(**tool_call.arguments)  # type: ignore[no-any-return]
     elif tool_call.name == "extract_company_info":
-        return company_info_tool.execute(**tool_call.arguments)
+        return str(company_info_tool.execute(**tool_call.arguments))
     else:
         raise ValueError("Unknown tool call")

From c1db7569dd1a392fc4a0f0eba359b7611a640905 Mon Sep 17 00:00:00 2001
From: estelle
Date: Mon, 16 Jun 2025 15:13:49 +0200
Subject: [PATCH 09/12] Better deal with call options

---
 .../customize/llms/vertexai_tool_calls.py |  7 ++-
 src/neo4j_graphrag/llm/vertexai_llm.py    | 60 ++++++++++++-------
 tests/unit/llm/test_vertexai_llm.py       | 39 +++++-----
 3 files changed, 67 insertions(+), 39 deletions(-)

diff --git a/examples/customize/llms/vertexai_tool_calls.py b/examples/customize/llms/vertexai_tool_calls.py
index d853d7377..ebe9fec22 100644
--- a/examples/customize/llms/vertexai_tool_calls.py
+++ b/examples/customize/llms/vertexai_tool_calls.py
@@ -97,8 +97,13 @@ async def main() -> None:
     # Initialize the VertexAI LLM
     generation_config = GenerationConfig(temperature=0.0)
     llm = VertexAILLM(
-        model_name="gemini-1.5-flash-001",
+        model_name="gemini-2.0-flash-001",
         generation_config=generation_config,
+        # tool_config=ToolConfig(
+        #     function_calling_config=ToolConfig.FunctionCallingConfig(
+        #         mode=ToolConfig.FunctionCallingConfig.Mode.ANY,
+        #         # allowed_function_names=["extract_person_info"],
+        #     ))
     )
 
     # Example text containing information about a company
diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py
index 387c51bc7..bc75e8454 100644
--- a/src/neo4j_graphrag/llm/vertexai_llm.py
+++ b/src/neo4j_graphrag/llm/vertexai_llm.py
@@ -40,6 +40,7 @@
     Part,
     ResponseValidationError,
     Tool as VertexAITool,
+    ToolConfig,
 )
 except ImportError:
     GenerativeModel = None
@@ -139,16 +140,15 @@ def invoke(
         """
         model = self._get_model(
             system_instruction=system_instruction,
-            tools=None,
         )
         try:
             if isinstance(message_history, MessageHistory):
                 message_history = message_history.messages
-            messages = self.get_messages(input, message_history)
-            response = model.generate_content(messages)
+            options = self._get_call_params(input, message_history, tools=None)
+            response = model.generate_content(**options)
             return self._parse_content_response(response)
         except ResponseValidationError as e:
-            raise LLMGenerationError("Error calling LLM") from e
+            raise LLMGenerationError("Error calling VertexAILLM") from e
 
     async def ainvoke(
         self,
@@ -172,13 +172,12 @@ async def ainvoke(
                 message_history = message_history.messages
             model = self._get_model(
                 system_instruction=system_instruction,
-                tools=None,
             )
-            messages = self.get_messages(input, message_history)
-            response = await model.generate_content_async(messages)
+            options = self._get_call_params(input, message_history, tools=None)
+            response = await model.generate_content_async(**options)
             return self._parse_content_response(response)
         except ResponseValidationError as e:
-            raise LLMGenerationError(e)
+            raise LLMGenerationError("Error calling VertexAILLM") from e
 
     def _to_vertexai_function_declaration(self, tool: Tool) -> FunctionDeclaration:
         return FunctionDeclaration(
@@ -203,24 +202,38 @@ def _get_llm_tools(
     def _get_model(
         self,
         system_instruction: Optional[str] = None,
-        tools: Optional[Sequence[Tool]] = None,
     ) -> GenerativeModel:
         system_message = [system_instruction] if system_instruction is not None else []
-        vertex_ai_tools = self._get_llm_tools(tools)
+        model = GenerativeModel(
+            model_name=self.model_name,
+            system_instruction=system_message,
+        )
+        return model
+
+    def _get_call_params(
+        self,
+        input: str,
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]],
+        tools: Optional[Sequence[Tool]],
+    ):
         options = dict(self.options)
         if tools:
             # we want a tool back, remove generation_config if defined
             options.pop("generation_config", None)
+            options["tools"] = self._get_llm_tools(tools)
+            if "tool_config" not in options:
+                options["tool_config"] = ToolConfig(
+                    function_calling_config=ToolConfig.FunctionCallingConfig(
+                        mode=ToolConfig.FunctionCallingConfig.Mode.ANY,
+                    )
+                )
         else:
             # no tools, remove tool_config if defined
             options.pop("tool_config", None)
-        model = GenerativeModel(
-            model_name=self.model_name,
-            system_instruction=system_message,
-            tools=vertex_ai_tools,
-            **options,
-        )
-        return model
+
+        messages = self.get_messages(input, message_history)
+        options["contents"] = messages
+        return options
 
     async def _acall_llm(
         self,
@@ -229,9 +242,9 @@ async def _acall_llm(
         system_instruction: Optional[str] = None,
         tools: Optional[Sequence[Tool]] = None,
     ) -> GenerationResponse:
-        model = self._get_model(system_instruction=system_instruction, tools=tools)
-        messages = self.get_messages(input, message_history)
-        response = await model.generate_content_async(messages, **self.model_params)
+        model = self._get_model(system_instruction=system_instruction)
+        options = self._get_call_params(input, message_history, tools)
+        response = await model.generate_content_async(**options)
         return response
 
     def _call_llm(
@@ -241,9 +254,10 @@ def _call_llm(
         system_instruction: Optional[str] = None,
         tools: Optional[Sequence[Tool]] = None,
     ) -> GenerationResponse:
-        model = self._get_model(system_instruction=system_instruction, tools=tools)
-        messages = self.get_messages(input, message_history)
-        response = model.generate_content(messages, **self.model_params)
+        model = self._get_model(system_instruction=system_instruction)
+        options = self._get_call_params(input, message_history, tools)
+        print(options)
+        response = model.generate_content(**options)
         return response
 
     def _to_tool_call(self, function_call: FunctionCall) -> ToolCall:
diff --git a/tests/unit/llm/test_vertexai_llm.py b/tests/unit/llm/test_vertexai_llm.py
index 6a5ab74ca..4276a936e 100644
--- a/tests/unit/llm/test_vertexai_llm.py
+++ b/tests/unit/llm/test_vertexai_llm.py
@@ -14,7 +14,7 @@
 from __future__ import annotations
 
 from typing import cast
-from unittest.mock import AsyncMock, MagicMock, Mock, patch
+from unittest.mock import AsyncMock, MagicMock, Mock, patch, ANY
 
 import pytest
 from neo4j_graphrag.exceptions import LLMGenerationError
@@ -26,6 +26,7 @@
     Content,
     GenerationResponse,
     Part,
+    ToolConfig,
 )
 
 
@@ -49,10 +50,11 @@ def test_vertexai_invoke_happy_path(GenerativeModelMock: MagicMock) -> None:
     response = llm.invoke(input_text)
     assert response.content == "Return text"
     GenerativeModelMock.assert_called_once_with(
-        model_name=model_name, system_instruction=[], tools=None
+        model_name=model_name,
+        system_instruction=[],
     )
     last_call = mock_model.generate_content.call_args_list[0]
-    content = last_call.args[0]
+    content = last_call.kwargs["contents"]
     assert len(content) == 1
     assert content[0].role == "user"
     assert content[0].parts[0].text == input_text
@@ -80,9 +82,12 @@ def test_vertexai_invoke_with_system_instruction(
     response = llm.invoke(input_text, system_instruction=system_instruction)
     assert response.content == "Return text"
     GenerativeModelMock.assert_called_once_with(
-        model_name=model_name, system_instruction=[system_instruction], tools=None
+        model_name=model_name,
+        system_instruction=[system_instruction],
+    )
+    mock_model.generate_content.assert_called_once_with(
+        contents=[{"text": "some text"}]
     )
-    mock_model.generate_content.assert_called_once_with([{"text": "some text"}])
 
 
 @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
@@ -110,11 +115,11 @@ def test_vertexai_invoke_with_message_history_and_system_instruction(
     )
     assert response.content == "Return text"
     GenerativeModelMock.assert_called_once_with(
-        model_name=model_name, system_instruction=[system_instruction], tools=None
+        model_name=model_name,
+        system_instruction=[system_instruction],
     )
     last_call = mock_model.generate_content.call_args_list[0]
-    content = last_call.args[0]
+    content = last_call.kwargs["contents"]
     assert len(content) == 3  # question + 2 messages in history
 
 
@@ -186,7 +191,7 @@ async def test_vertexai_ainvoke_happy_path(
     response = await llm.ainvoke(input_text)
     assert response.content == "Return text"
     mock_model.generate_content_async.assert_awaited_once_with(
-        [{"text": "Return text"}]
+        contents=[{"text": "Return text"}]
     )
 
 
@@ -238,13 +243,17 @@ def test_vertexai_call_llm_with_tools(mock_model: Mock, test_tool: Tool) -> None:
     llm = VertexAILLM(model_name="gemini")
 
     tools = [test_tool]
-    res = llm._call_llm("my text", tools=tools)
-    assert isinstance(res, GenerationResponse)
+    with patch.object(llm, "_get_llm_tools", return_value=["my tools"]):
+        res = llm._call_llm("my text", tools=tools)
+        assert isinstance(res, GenerationResponse)
 
-    mock_model.assert_called_once_with(
-        system_instruction=None,
-        tools=tools,
-    )
+        mock_model.assert_called_once_with(
+            system_instruction=None,
+        )
+        calls = mock_generate_content.call_args_list
+        assert len(calls) == 1
+        assert calls[0][1]["tools"] == ["my tools"]
+        assert calls[0][1]["tool_config"] is not None
 
 
 @patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM._parse_tool_response")
@@ -295,6 +304,5 @@ async def test_vertexai_acall_llm_with_tools(mock_model: Mock, test_tool: Tool)
     res = await llm._acall_llm("my text", tools=tools)
     mock_model.assert_called_once_with(
         system_instruction=None,
-        tools=tools,
     )
     assert isinstance(res, GenerationResponse)

From 609e9cef00266b02c00355dbd4f09984f3dfd698 Mon Sep 17 00:00:00 2001
From: estelle
Date: Mon, 16 Jun 2025 15:15:51 +0200
Subject: [PATCH 10/12] Ruff

---
 tests/unit/llm/test_vertexai_llm.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/unit/llm/test_vertexai_llm.py b/tests/unit/llm/test_vertexai_llm.py
index 4276a936e..c937d2cb8 100644
--- a/tests/unit/llm/test_vertexai_llm.py
+++ b/tests/unit/llm/test_vertexai_llm.py
@@ -14,7 +14,7 @@
 from __future__ import annotations
 
 from typing import cast
-from unittest.mock import AsyncMock, MagicMock, Mock, patch, ANY
+from unittest.mock import AsyncMock, MagicMock, Mock, patch
 
 import pytest
 from neo4j_graphrag.exceptions import LLMGenerationError
@@ -26,7 +26,6 @@
     Content,
     GenerationResponse,
     Part,
-    ToolConfig,
 )

From 80971bffc7e302f23e126fdadc0be3adf8d2e1c3 Mon Sep 17 00:00:00 2001
From: estelle
Date: Mon, 16 Jun 2025 15:22:18 +0200
Subject: [PATCH 11/12] mypy

---
 src/neo4j_graphrag/llm/vertexai_llm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py
index bc75e8454..1fa4b566a 100644
--- a/src/neo4j_graphrag/llm/vertexai_llm.py
+++ b/src/neo4j_graphrag/llm/vertexai_llm.py
@@ -215,7 +215,7 @@ def _get_call_params(
         input: str,
         message_history: Optional[Union[List[LLMMessage], MessageHistory]],
         tools: Optional[Sequence[Tool]],
-    ):
+    ) -> dict[str, Any]:
         options = dict(self.options)
         if tools:
             # we want a tool back, remove generation_config if defined

From a3c7199915dfb3db564ea8c7f0d1a9962eda1c88 Mon Sep 17 00:00:00 2001
From: estelle
Date: Mon, 16 Jun 2025 15:38:51 +0200
Subject: [PATCH 12/12] Rm print

---
 src/neo4j_graphrag/llm/vertexai_llm.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py
index 1fa4b566a..39d483915 100644
--- a/src/neo4j_graphrag/llm/vertexai_llm.py
+++ b/src/neo4j_graphrag/llm/vertexai_llm.py
@@ -256,7 +256,6 @@ def _call_llm(
     ) -> GenerationResponse:
         model = self._get_model(system_instruction=system_instruction)
         options = self._get_call_params(input, message_history, tools)
-        print(options)
         response = model.generate_content(**options)
         return response