diff --git a/defog/llm/utils.py b/defog/llm/utils.py
index e388d40..ff894db 100644
--- a/defog/llm/utils.py
+++ b/defog/llm/utils.py
@@ -191,7 +191,7 @@ async def _process_anthropic_response(
                 result = execute_tool(tool_to_call, args)
             except Exception as e:
                 raise Exception(f"Error executing tool `{func_name}`: {e}")
-            
+
             if post_tool_function:
                 if inspect.iscoroutinefunction(post_tool_function):
                     await post_tool_function(func_name, args, result)
@@ -295,7 +295,7 @@ def _process_anthropic_response_handler(
                 tools=tools,
                 tool_dict=tool_dict,
                 is_async=is_async,
-                post_tool_function=post_tool_function
+                post_tool_function=post_tool_function,
             )  # Caller must await this
         else:
             return asyncio.run(
@@ -306,7 +306,7 @@ def _process_anthropic_response_handler(
                     tools=tools,
                     tool_dict=tool_dict,
                     is_async=is_async,
-                    post_tool_function=post_tool_function
+                    post_tool_function=post_tool_function,
                 )
             )
     except Exception as e:
@@ -628,7 +628,7 @@ async def _process_openai_response(
                 result = execute_tool(tool_to_call, args)
             except Exception as e:
                 raise Exception(f"Error executing tool `{func_name}`: {e}")
-            
+
             if post_tool_function:
                 if inspect.iscoroutinefunction(post_tool_function):
                     await post_tool_function(func_name, args, result)
@@ -702,7 +702,7 @@ def _process_openai_response_handler(
     tool_dict: Dict[str, Callable],
     response_format,
     model: str,
-    is_async: bool =False,
+    is_async: bool = False,
     post_tool_function: Callable = None,
 ):
     """
diff --git a/tests/test_llm_tool_calls.py b/tests/test_llm_tool_calls.py
index d498e32..fccf272 100644
--- a/tests/test_llm_tool_calls.py
+++ b/tests/test_llm_tool_calls.py
@@ -28,6 +28,7 @@ def log_to_file(function_name, args, result):
     IO_STREAM.write(message + "\n")
     return IO_STREAM.getvalue()
 
+
 class WeatherInput(BaseModel):
     latitude: float = Field(default=0.0, description="The latitude of the location")
     longitude: float = Field(default=0.0, description="The longitude of the location")
@@ -303,23 +304,28 @@ async def test_post_tool_calls_openai(self):
         self.assertEqual(expected["args"], actual["args"])
         self.assertEqual(expected["result"], actual["result"])
         self.assertSetEqual(set(result.tools_used), {"numsum", "numprod"})
-        expected_stream_value = json.dumps({
-            "function_name": "numprod",
-            "args": {
-                "a": 31283,
-                "b": 2323
+        expected_stream_value = (
+            json.dumps(
+                {
+                    "function_name": "numprod",
+                    "args": {"a": 31283, "b": 2323},
+                    "result": 72670409,
                 },
-            "result": 72670409
-        }, indent=4) + "\n" + json.dumps({
-            "function_name": "numsum",
-            "args": {
-                "a": 72670409,
-                "b": 5
+                indent=4,
+            )
+            + "\n"
+            + json.dumps(
+                {
+                    "function_name": "numsum",
+                    "args": {"a": 72670409, "b": 5},
+                    "result": 72670414,
                 },
-            "result": 72670414
-        }, indent=4) + "\n"
+                indent=4,
+            )
+            + "\n"
+        )
         self.assertEqual(IO_STREAM.getvalue(), expected_stream_value)
-        
+
         # clear IO_STREAM
         IO_STREAM.seek(0)
         IO_STREAM.truncate()
@@ -345,23 +351,28 @@ async def test_post_tool_calls_anthropic(self):
         self.assertEqual(expected["args"], actual["args"])
         self.assertEqual(expected["result"], actual["result"])
         self.assertSetEqual(set(result.tools_used), {"numsum", "numprod"})
-        expected_stream_value = json.dumps({
-            "function_name": "numprod",
-            "args": {
-                "a": 31283,
-                "b": 2323
+        expected_stream_value = (
+            json.dumps(
+                {
+                    "function_name": "numprod",
+                    "args": {"a": 31283, "b": 2323},
+                    "result": 72670409,
                 },
-            "result": 72670409
-        }, indent=4) + "\n" + json.dumps({
-            "function_name": "numsum",
-            "args": {
-                "a": 72670409,
-                "b": 5
+                indent=4,
+            )
+            + "\n"
+            + json.dumps(
+                {
+                    "function_name": "numsum",
+                    "args": {"a": 72670409, "b": 5},
+                    "result": 72670414,
+                },
-            "result": 72670414
-        }, indent=4) + "\n"
+                indent=4,
+            )
+            + "\n"
+        )
         self.assertEqual(IO_STREAM.getvalue(), expected_stream_value)
-        
+
         # clear IO_STREAM
         IO_STREAM.seek(0)
         IO_STREAM.truncate()