Fix VertexAILLM #342

Status: Open. Wants to merge 8 commits into base: main.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@
 ### Fixed

 - Fixed a bug where `spacy` and `rapidfuzz` needed to be installed even if not using the relevant entity resolvers.
+- Fixed a bug where `VertexAILLM.(a)invoke_with_tools` called with multiple tools would raise an error.

 ### Changed

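For context on the fix: with more than one tool, the previous code built one VertexAITool wrapper per Tool, which the Vertex AI SDK rejects; the new code bundles every function declaration into a single VertexAITool (see src/neo4j_graphrag/llm/vertexai_llm.py below). A minimal usage sketch of the now-working multi-tool call; this is illustrative only and assumes valid GCP credentials, that VertexAILLM is importable from neo4j_graphrag.llm, and the person_info_tool/company_info_tool definitions from the example file below:

# Illustrative sketch, not part of the PR diff.
from neo4j_graphrag.llm import VertexAILLM

# Both tools are passed in one call; before this PR this raised an error.
llm = VertexAILLM(model_name="gemini-1.5-flash-001")
response = llm.invoke_with_tools(
    input="Neo4j is a software company created in 2007",
    tools=[person_info_tool, company_info_tool],
)
for tool_call in response.tool_calls:
    print(tool_call.name, tool_call.arguments)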
56 changes: 46 additions & 10 deletions examples/customize/llms/vertexai_tool_calls.py
@@ -4,6 +4,7 @@
 """

 import asyncio
+from typing import Optional

 from dotenv import load_dotenv
 from vertexai.generative_models import GenerationConfig
@@ -17,7 +18,7 @@


 # Create a custom Tool implementation for person info extraction
-parameters = ObjectParameter(
+person_tool_parameters = ObjectParameter(
     description="Parameters for extracting person information",
     properties={
         "name": StringParameter(description="The person's full name"),
@@ -29,20 +30,50 @@
 )


-def run_tool(name: str, age: int, occupation: str) -> str:
+def run_person_tool(
+    name: str, age: Optional[int] = None, occupation: Optional[str] = None
+) -> str:
     """A simple function that summarizes person information from input parameters."""
     return f"Found person {name} with age {age} and occupation {occupation}"


 person_info_tool = Tool(
     name="extract_person_info",
     description="Extract information about a person from text",
-    parameters=parameters,
-    execute_func=run_tool,
+    parameters=person_tool_parameters,
+    execute_func=run_person_tool,
 )

+company_tool_parameters = ObjectParameter(
+    description="Parameters for extracting company information",
+    properties={
+        "name": StringParameter(description="The company's full name"),
+        "industry": StringParameter(description="The company's industry"),
+        "creation_year": IntegerParameter(description="The company's creation year"),
+    },
+    required_properties=["name"],
+    additional_properties=False,
+)
+
+
+def run_company_tool(
+    name: str, industry: Optional[str] = None, creation_year: Optional[int] = None
+) -> str:
+    """A simple function that summarizes company information from input parameters."""
+    return (
+        f"Found company {name} operating in industry {industry} since {creation_year}"
+    )
+
+
+company_info_tool = Tool(
+    name="extract_company_info",
+    description="Extract information about a company from text",
+    parameters=company_tool_parameters,
+    execute_func=run_company_tool,
+)

 # Create the tool instance
-TOOLS = [person_info_tool]
+TOOLS = [person_info_tool, company_info_tool]


 def process_tool_call(response: ToolCallResponse) -> str:
@@ -54,7 +85,12 @@ def process_tool_call(response: ToolCallResponse) -> str:
     print(f"\nTool called: {tool_call.name}")
     print(f"Arguments: {tool_call.arguments}")
     print(f"Additional content: {response.content or 'None'}")
-    return person_info_tool.execute(**tool_call.arguments)  # type: ignore[no-any-return]
+    if tool_call.name == "extract_person_info":
+        return person_info_tool.execute(**tool_call.arguments)  # type: ignore[no-any-return]
+    elif tool_call.name == "extract_company_info":
+        return str(company_info_tool.execute(**tool_call.arguments))
+    else:
+        raise ValueError("Unknown tool call")


 async def main() -> None:
@@ -65,21 +101,21 @@ async def main() -> None:
         generation_config=generation_config,
     )

-    # Example text containing information about a person
-    text = "Stella Hane is a 35-year-old software engineer who loves coding."
+    # Example text containing information about a company
+    text1 = "Neo4j is a software company created in 2007"

     print("\n=== Synchronous Tool Call ===")
     # Make a synchronous tool call
     sync_response = llm.invoke_with_tools(
-        input=f"Extract information about the person from this text: {text}",
+        input=f"Extract information about the person from this text: {text1}",
         tools=TOOLS,
     )
     sync_result = process_tool_call(sync_response)
     print("\n=== Synchronous Tool Call Result ===")
     print(sync_result)

     print("\n=== Asynchronous Tool Call ===")
-    # Make an asynchronous tool call with a different text
+    # Make an asynchronous tool call with a different text about a person
     text2 = "Molly Hane, 32, works as a data scientist and enjoys machine learning."
     async_response = await llm.ainvoke_with_tools(
         input=f"Extract information about the person from this text: {text2}",
61 changes: 31 additions & 30 deletions src/neo4j_graphrag/llm/vertexai_llm.py
@@ -137,20 +137,18 @@ def invoke(
         Returns:
             LLMResponse: The response from the LLM.
         """
-        system_message = [system_instruction] if system_instruction is not None else []
-        self.model = GenerativeModel(
-            model_name=self.model_name,
-            system_instruction=system_message,
-            **self.options,
+        model = self._get_model(
+            system_instruction=system_instruction,
+            tools=None,
         )
         try:
             if isinstance(message_history, MessageHistory):
                 message_history = message_history.messages
             messages = self.get_messages(input, message_history)
-            response = self.model.generate_content(messages, **self.model_params)
-            return LLMResponse(content=response.text)
+            response = model.generate_content(messages)
[Inline review comment from a Contributor on the line above]
Why are we removing model_params here?

+            return self._parse_content_response(response)
         except ResponseValidationError as e:
-            raise LLMGenerationError(e)
+            raise LLMGenerationError("Error calling LLM") from e

     async def ainvoke(
         self,
@@ -172,39 +170,35 @@ async def ainvoke(
         try:
             if isinstance(message_history, MessageHistory):
                 message_history = message_history.messages
-            system_message = (
-                [system_instruction] if system_instruction is not None else []
-            )
-            self.model = GenerativeModel(
-                model_name=self.model_name,
-                system_instruction=system_message,
-                **self.options,
+            model = self._get_model(
+                system_instruction=system_instruction,
+                tools=None,
             )
             messages = self.get_messages(input, message_history)
-            response = await self.model.generate_content_async(
-                messages, **self.model_params
-            )
-            return LLMResponse(content=response.text)
+            response = await model.generate_content_async(messages)
+            return self._parse_content_response(response)
         except ResponseValidationError as e:
             raise LLMGenerationError(e)

-    def _to_vertexai_tool(self, tool: Tool) -> VertexAITool:
-        return VertexAITool(
-            function_declarations=[
-                FunctionDeclaration(
-                    name=tool.get_name(),
-                    description=tool.get_description(),
-                    parameters=tool.get_parameters(exclude=["additional_properties"]),
-                )
-            ]
+    def _to_vertexai_function_declaration(self, tool: Tool) -> FunctionDeclaration:
+        return FunctionDeclaration(
+            name=tool.get_name(),
+            description=tool.get_description(),
+            parameters=tool.get_parameters(exclude=["additional_properties"]),
         )

     def _get_llm_tools(
         self, tools: Optional[Sequence[Tool]]
     ) -> Optional[list[VertexAITool]]:
         if not tools:
             return None
-        return [self._to_vertexai_tool(tool) for tool in tools]
+        return [
+            VertexAITool(
+                function_declarations=[
+                    self._to_vertexai_function_declaration(tool) for tool in tools
+                ]
+            )
+        ]

     def _get_model(
         self,
@@ -213,11 +207,18 @@
     ) -> GenerativeModel:
         system_message = [system_instruction] if system_instruction is not None else []
         vertex_ai_tools = self._get_llm_tools(tools)
+        options = dict(self.options)
+        if tools:
+            # we want a tool back, remove generation_config if defined
+            options.pop("generation_config", None)
+        else:
+            # no tools, remove tool_config if defined
+            options.pop("tool_config", None)
         model = GenerativeModel(
             model_name=self.model_name,
             system_instruction=system_message,
             tools=vertex_ai_tools,
-            **self.options,
+            **options,
         )
         return model

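In short, the fix above replaces the per-tool wrapping (one VertexAITool per Tool) with a single VertexAITool carrying all function declarations. A standalone sketch of the two shapes, using only the vertexai SDK; schemas are shortened for illustration and assume FunctionDeclaration accepts an OpenAPI-style dict:

# Illustrative sketch, not part of the PR diff.
from vertexai.generative_models import FunctionDeclaration, Tool

person_decl = FunctionDeclaration(
    name="extract_person_info",
    description="Extract information about a person from text",
    parameters={"type": "object", "properties": {"name": {"type": "string"}}},
)
company_decl = FunctionDeclaration(
    name="extract_company_info",
    description="Extract information about a company from text",
    parameters={"type": "object", "properties": {"name": {"type": "string"}}},
)

# Before the fix: one Tool wrapper per declaration, rejected by Vertex AI.
# tools = [Tool(function_declarations=[person_decl]),
#          Tool(function_declarations=[company_decl])]

# After the fix: a single Tool wrapper bundling every declaration,
# matching what _get_llm_tools now returns.
tools = [Tool(function_declarations=[person_decl, company_decl])]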
33 changes: 18 additions & 15 deletions tests/unit/llm/test_vertexai_llm.py
@@ -14,7 +14,6 @@
 from __future__ import annotations

 from typing import cast
-from unittest import mock
 from unittest.mock import AsyncMock, MagicMock, Mock, patch

 import pytest
@@ -50,19 +49,19 @@ def test_vertexai_invoke_happy_path(GenerativeModelMock: MagicMock) -> None:
     response = llm.invoke(input_text)
     assert response.content == "Return text"
     GenerativeModelMock.assert_called_once_with(
-        model_name=model_name, system_instruction=[]
+        model_name=model_name, system_instruction=[], tools=None
     )
-    user_message = mock.ANY
-    llm.model.generate_content.assert_called_once_with(user_message, **model_params)
-    last_call = llm.model.generate_content.call_args_list[0]
+    last_call = mock_model.generate_content.call_args_list[0]
     content = last_call.args[0]
     assert len(content) == 1
     assert content[0].role == "user"
     assert content[0].parts[0].text == input_text


 @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
+@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM.get_messages")
 def test_vertexai_invoke_with_system_instruction(
+    mock_get_messages: MagicMock,
     GenerativeModelMock: MagicMock,
 ) -> None:
     system_instruction = "You are a helpful assistant."
@@ -72,16 +71,18 @@ def test_vertexai_invoke_with_system_instruction(
     mock_response.text = "Return text"
     mock_model = GenerativeModelMock.return_value
     mock_model.generate_content.return_value = mock_response

+    mock_get_messages.return_value = [{"text": "some text"}]

     model_params = {"temperature": 0.5}
     llm = VertexAILLM(model_name, model_params)

     response = llm.invoke(input_text, system_instruction=system_instruction)
     assert response.content == "Return text"
     GenerativeModelMock.assert_called_once_with(
-        model_name=model_name, system_instruction=[system_instruction]
+        model_name=model_name, system_instruction=[system_instruction], tools=None
     )
-    user_message = mock.ANY
-    llm.model.generate_content.assert_called_once_with(user_message, **model_params)
+    mock_model.generate_content.assert_called_once_with([{"text": "some text"}])


 @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
@@ -110,11 +111,9 @@ def test_vertexai_invoke_with_message_history_and_system_instruction(
     )
     assert response.content == "Return text"
     GenerativeModelMock.assert_called_once_with(
-        model_name=model_name, system_instruction=[system_instruction]
+        model_name=model_name, system_instruction=[system_instruction], tools=None
     )
-    user_message = mock.ANY
-    llm.model.generate_content.assert_called_once_with(user_message, **model_params)
-    last_call = llm.model.generate_content.call_args_list[0]
+    last_call = mock_model.generate_content.call_args_list[0]
     content = last_call.args[0]
     assert len(content) == 3  # question + 2 messages in history

@@ -167,18 +166,22 @@ def test_vertexai_get_messages_validation_error(GenerativeModelMock: MagicMock)

 @pytest.mark.asyncio
 @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
-async def test_vertexai_ainvoke_happy_path(GenerativeModelMock: MagicMock) -> None:
+@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM.get_messages")
+async def test_vertexai_ainvoke_happy_path(
+    mock_get_messages: Mock, GenerativeModelMock: MagicMock
+) -> None:
     mock_response = AsyncMock()
     mock_response.text = "Return text"
     mock_model = GenerativeModelMock.return_value
     mock_model.generate_content_async = AsyncMock(return_value=mock_response)
+    mock_get_messages.return_value = [{"text": "Return text"}]
     model_params = {"temperature": 0.5}
     llm = VertexAILLM("gemini-1.5-flash-001", model_params)
     input_text = "may thy knife chip and shatter"
     response = await llm.ainvoke(input_text)
     assert response.content == "Return text"
-    llm.model.generate_content_async.assert_awaited_once_with(
-        [mock.ANY], **model_params
+    mock_model.generate_content_async.assert_awaited_once_with(
+        [{"text": "Return text"}]
     )

