Skip to content

Commit c0965c3

Browse files
committed
Add tests
1 parent 1f046fd commit c0965c3

File tree

6 files changed

+176
-40
lines changed

6 files changed

+176
-40
lines changed

src/neo4j_graphrag/llm/vertexai_llm.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -194,12 +194,12 @@ def _to_vertexai_tool(self, tool: Tool) -> VertexAITool:
194194
FunctionDeclaration(
195195
name=tool.get_name(),
196196
description=tool.get_description(),
197-
parameters=tool.get_parameters(),
197+
parameters=tool.get_parameters(exclude=["additional_properties"]),
198198
)
199199
]
200200
)
201201

202-
def get_tools(
202+
def _get_llm_tools(
203203
self, tools: Optional[Sequence[Tool]]
204204
) -> Optional[list[VertexAITool]]:
205205
if not tools:
@@ -212,7 +212,7 @@ def _get_model(
212212
tools: Optional[Sequence[Tool]] = None,
213213
) -> GenerativeModel:
214214
system_message = [system_instruction] if system_instruction is not None else []
215-
vertex_ai_tools = self.get_tools(tools)
215+
vertex_ai_tools = self._get_llm_tools(tools)
216216
model = GenerativeModel(
217217
model_name=self.model_name,
218218
system_instruction=system_message,
@@ -228,7 +228,7 @@ async def _acall_llm(
228228
system_instruction: Optional[str] = None,
229229
tools: Optional[Sequence[Tool]] = None,
230230
) -> GenerationResponse:
231-
model = self._get_model(system_instruction, tools)
231+
model = self._get_model(system_instruction=system_instruction, tools=tools)
232232
messages = self.get_messages(input, message_history)
233233
response = await model.generate_content_async(messages, **self.model_params)
234234
return response
@@ -240,7 +240,7 @@ def _call_llm(
240240
system_instruction: Optional[str] = None,
241241
tools: Optional[Sequence[Tool]] = None,
242242
) -> GenerationResponse:
243-
model = self._get_model(system_instruction, tools)
243+
model = self._get_model(system_instruction=system_instruction, tools=tools)
244244
messages = self.get_messages(input, message_history)
245245
response = model.generate_content(messages, **self.model_params)
246246
return response

src/neo4j_graphrag/tool.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -169,18 +169,21 @@ def _preprocess_properties(cls, values: dict[str, Any]) -> dict[str, Any]:
169169
values["properties"] = new_props
170170
return values
171171

172-
def model_dump_tool(self) -> Dict[str, Any]:
172+
def model_dump_tool(self, exclude: Optional[list[str]] = None) -> Dict[str, Any]:
173+
exclude = exclude or []
173174
properties_dict: Dict[str, Any] = {}
174175
for name, param in self.properties.items():
176+
if name in exclude:
177+
continue
175178
properties_dict[name] = param.model_dump_tool()
176179

177180
result = super().model_dump_tool()
178181
result["properties"] = properties_dict
179182

180-
if self.required_properties:
183+
if self.required_properties and "required" not in exclude:
181184
result["required"] = self.required_properties
182185

183-
if not self.additional_properties:
186+
if not self.additional_properties and "additional_properties" not in exclude:
184187
result["additionalProperties"] = False
185188

186189
return result
@@ -242,13 +245,13 @@ def get_description(self) -> str:
242245
"""
243246
return self._description
244247

245-
def get_parameters(self) -> Dict[str, Any]:
248+
def get_parameters(self, exclude: Optional[list[str]] = None) -> Dict[str, Any]:
246249
"""Get the parameters the tool accepts in a dictionary format suitable for LLM providers.
247250
248251
Returns:
249252
Dict[str, Any]: Dictionary containing parameter schema information.
250253
"""
251-
return self._parameters.model_dump_tool()
254+
return self._parameters.model_dump_tool(exclude)
252255

253256
def execute(self, query: str, **kwargs: Any) -> Any:
254257
"""Execute the tool with the given query and additional parameters.

tests/unit/llm/conftest.py

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
import pytest

from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter


class TestTool(Tool):
    """Minimal Tool implementation shared by the LLM unit tests.

    Exposes a single required string parameter (`param1`) with
    `additional_properties=False`, so provider adapters that translate the
    parameter schema (required / additionalProperties handling) can be
    exercised against it.
    """

    # The class name matches pytest's default `Test*` collection pattern;
    # opt out explicitly so pytest does not emit a PytestCollectionWarning
    # about the __init__ constructor — this is a fixture helper, not a test.
    __test__ = False

    def __init__(self, name: str = "test_tool", description: str = "A test tool"):
        parameters = ObjectParameter(
            description="Test parameters",
            properties={"param1": StringParameter(description="Test parameter")},
            required_properties=["param1"],
            additional_properties=False,
        )

        super().__init__(
            name=name,
            description=description,
            parameters=parameters,
            # Echo the keyword arguments back so tests can inspect the call.
            execute_func=lambda **kwargs: kwargs,
        )


@pytest.fixture
def test_tool() -> Tool:
    """Return a fresh TestTool instance for each test."""
    return TestTool()

tests/unit/llm/test_openai_llm.py

+15-28
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from neo4j_graphrag.llm import LLMResponse
2121
from neo4j_graphrag.llm.openai_llm import AzureOpenAILLM, OpenAILLM
2222
from neo4j_graphrag.llm.types import ToolCallResponse
23-
from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter
23+
from neo4j_graphrag.tool import Tool
2424

2525

2626
def get_mock_openai() -> MagicMock:
@@ -29,25 +29,6 @@ def get_mock_openai() -> MagicMock:
2929
return mock
3030

3131

32-
class TestTool(Tool):
33-
"""Test tool for unit tests."""
34-
35-
def __init__(self, name: str = "test_tool", description: str = "A test tool"):
36-
parameters = ObjectParameter(
37-
description="Test parameters",
38-
properties={"param1": StringParameter(description="Test parameter")},
39-
required_properties=["param1"],
40-
additional_properties=False,
41-
)
42-
43-
super().__init__(
44-
name=name,
45-
description=description,
46-
parameters=parameters,
47-
execute_func=lambda **kwargs: kwargs,
48-
)
49-
50-
5132
@patch("builtins.__import__", side_effect=ImportError)
5233
def test_openai_llm_missing_dependency(mock_import: Mock) -> None:
5334
with pytest.raises(ImportError):
@@ -156,7 +137,9 @@ def test_openai_llm_with_message_history_validation_error(mock_import: Mock) ->
156137
@patch("builtins.__import__")
157138
@patch("json.loads")
158139
def test_openai_llm_invoke_with_tools_happy_path(
159-
mock_json_loads: Mock, mock_import: Mock
140+
mock_json_loads: Mock,
141+
mock_import: Mock,
142+
test_tool: Tool,
160143
) -> None:
161144
# Set up json.loads to return a dictionary
162145
mock_json_loads.return_value = {"param1": "value1"}
@@ -183,7 +166,7 @@ def test_openai_llm_invoke_with_tools_happy_path(
183166
)
184167

185168
llm = OpenAILLM(api_key="my key", model_name="gpt")
186-
tools = [TestTool()]
169+
tools = [test_tool]
187170

188171
res = llm.invoke_with_tools("my text", tools)
189172
assert isinstance(res, ToolCallResponse)
@@ -196,7 +179,9 @@ def test_openai_llm_invoke_with_tools_happy_path(
196179
@patch("builtins.__import__")
197180
@patch("json.loads")
198181
def test_openai_llm_invoke_with_tools_with_message_history(
199-
mock_json_loads: Mock, mock_import: Mock
182+
mock_json_loads: Mock,
183+
mock_import: Mock,
184+
test_tool: Tool,
200185
) -> None:
201186
# Set up json.loads to return a dictionary
202187
mock_json_loads.return_value = {"param1": "value1"}
@@ -223,7 +208,7 @@ def test_openai_llm_invoke_with_tools_with_message_history(
223208
)
224209

225210
llm = OpenAILLM(api_key="my key", model_name="gpt")
226-
tools = [TestTool()]
211+
tools = [test_tool]
227212

228213
message_history = [
229214
{"role": "user", "content": "When does the sun come up in the summer?"},
@@ -259,7 +244,9 @@ def test_openai_llm_invoke_with_tools_with_message_history(
259244
@patch("builtins.__import__")
260245
@patch("json.loads")
261246
def test_openai_llm_invoke_with_tools_with_system_instruction(
262-
mock_json_loads: Mock, mock_import: Mock
247+
mock_json_loads: Mock,
248+
mock_import: Mock,
249+
test_tool: Mock,
263250
) -> None:
264251
# Set up json.loads to return a dictionary
265252
mock_json_loads.return_value = {"param1": "value1"}
@@ -286,7 +273,7 @@ def test_openai_llm_invoke_with_tools_with_system_instruction(
286273
)
287274

288275
llm = OpenAILLM(api_key="my key", model_name="gpt")
289-
tools = [TestTool()]
276+
tools = [test_tool]
290277

291278
system_instruction = "You are a helpful assistant."
292279

@@ -314,7 +301,7 @@ def test_openai_llm_invoke_with_tools_with_system_instruction(
314301

315302

316303
@patch("builtins.__import__")
317-
def test_openai_llm_invoke_with_tools_error(mock_import: Mock) -> None:
304+
def test_openai_llm_invoke_with_tools_error(mock_import: Mock, test_tool: Tool) -> None:
318305
mock_openai = get_mock_openai()
319306
mock_import.return_value = mock_openai
320307

@@ -324,7 +311,7 @@ def test_openai_llm_invoke_with_tools_error(mock_import: Mock) -> None:
324311
)
325312

326313
llm = OpenAILLM(api_key="my key", model_name="gpt")
327-
tools = [TestTool()]
314+
tools = [test_tool]
328315

329316
with pytest.raises(LLMGenerationError):
330317
llm.invoke_with_tools("my text", tools)

tests/unit/llm/test_vertexai_llm.py

+121-2
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,15 @@
1919

2020
import pytest
2121
from neo4j_graphrag.exceptions import LLMGenerationError
22+
from neo4j_graphrag.llm.types import ToolCallResponse
2223
from neo4j_graphrag.llm.vertexai_llm import VertexAILLM
24+
from neo4j_graphrag.tool import Tool
2325
from neo4j_graphrag.types import LLMMessage
24-
from vertexai.generative_models import Content, Part
26+
from vertexai.generative_models import (
27+
Content,
28+
GenerationResponse,
29+
Part,
30+
)
2531

2632

2733
@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel", None)
@@ -171,4 +177,117 @@ async def test_vertexai_ainvoke_happy_path(GenerativeModelMock: MagicMock) -> No
171177
input_text = "may thy knife chip and shatter"
172178
response = await llm.ainvoke(input_text)
173179
assert response.content == "Return text"
174-
llm.model.generate_content_async.assert_called_once_with([mock.ANY], **model_params)
180+
llm.model.generate_content_async.assert_awaited_once_with(
181+
[mock.ANY], **model_params
182+
)
183+
184+
185+
def test_vertexai_get_llm_tools(test_tool: Tool) -> None:
    """_get_llm_tools should wrap the tool in exactly one VertexAI declaration."""
    llm = VertexAILLM(model_name="gemini")
    vertex_tools = llm._get_llm_tools(tools=[test_tool])
    assert vertex_tools is not None
    assert len(vertex_tools) == 1
    declarations = vertex_tools[0].to_dict()["function_declarations"]
    assert len(declarations) == 1
    declaration = declarations[0]
    assert declaration["name"] == "test_tool"
    assert declaration["description"] == "A test tool"
195+
196+
197+
@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM._parse_tool_response")
@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM._call_llm")
def test_vertexai_invoke_with_tools(
    mock_call_llm: Mock,
    mock_parse_tool: Mock,
    test_tool: Tool,
) -> None:
    """invoke_with_tools should delegate to _call_llm and parse its response."""
    # Fake a single function call coming back from the model.
    function_call = MagicMock()
    function_call.name = "function"
    function_call.args = {}
    mock_call_llm.return_value = MagicMock(
        candidates=[MagicMock(function_calls=[function_call])]
    )
    mock_parse_tool.return_value = ToolCallResponse(tool_calls=[])

    llm = VertexAILLM(model_name="gemini")
    tools = [test_tool]

    result = llm.invoke_with_tools("my text", tools)

    mock_call_llm.assert_called_once_with(
        "my text",
        message_history=None,
        system_instruction=None,
        tools=tools,
    )
    mock_parse_tool.assert_called_once()
    assert isinstance(result, ToolCallResponse)
223+
224+
225+
@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM._get_model")
def test_vertexai_call_llm_with_tools(mock_model: Mock, test_tool: Tool) -> None:
    """_call_llm should build the model with the given tools and return its response."""
    # The patched _get_model returns a mock model whose generate_content
    # yields a GenerationResponse-spec'd stub.
    response_stub = MagicMock(spec=GenerationResponse)
    mock_model.return_value.generate_content.return_value = response_stub

    llm = VertexAILLM(model_name="gemini")
    tools = [test_tool]

    result = llm._call_llm("my text", tools=tools)

    assert isinstance(result, GenerationResponse)
    mock_model.assert_called_once_with(
        system_instruction=None,
        tools=tools,
    )
243+
244+
245+
@pytest.mark.asyncio
@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM._parse_tool_response")
@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM._acall_llm")
async def test_vertexai_ainvoke_with_tools(
    mock_acall_llm: Mock,
    mock_parse_tool: Mock,
    test_tool: Tool,
) -> None:
    """ainvoke_with_tools should await _acall_llm and parse its response.

    The original test duplicated the sync test verbatim: it was not async,
    called the sync ``invoke_with_tools`` and patched ``_call_llm``, so the
    async code path was never exercised. Patching the ``async def``
    ``_acall_llm`` yields an AsyncMock automatically.
    """
    # Fake a single function call coming back from the model.
    tool_call_mock = MagicMock()
    tool_call_mock.name = "function"
    tool_call_mock.args = {}
    mock_acall_llm.return_value = MagicMock(
        candidates=[MagicMock(function_calls=[tool_call_mock])]
    )
    mock_parse_tool.return_value = ToolCallResponse(tool_calls=[])

    llm = VertexAILLM(model_name="gemini")
    tools = [test_tool]

    res = await llm.ainvoke_with_tools("my text", tools)
    mock_acall_llm.assert_awaited_once_with(
        "my text",
        message_history=None,
        system_instruction=None,
        tools=tools,
    )
    mock_parse_tool.assert_called_once()
    assert isinstance(res, ToolCallResponse)
271+
272+
273+
@pytest.mark.asyncio
@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM._get_model")
async def test_vertexai_acall_llm_with_tools(
    mock_model: Mock, test_tool: Tool
) -> None:
    """_acall_llm should build the model with the tools and await its async call."""
    # The patched _get_model returns an async-capable mock whose
    # generate_content_async resolves to a GenerationResponse-spec'd stub.
    response_stub = MagicMock(spec=GenerationResponse)
    mock_model.return_value = AsyncMock(
        generate_content_async=AsyncMock(return_value=response_stub)
    )

    llm = VertexAILLM(model_name="gemini")
    tools = [test_tool]

    result = await llm._acall_llm("my text", tools=tools)

    mock_model.assert_called_once_with(
        system_instruction=None,
        tools=tools,
    )
    assert isinstance(result, GenerationResponse)

tests/unit/tool/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)