Add an optional post-tool function to handle tool responses #72

Merged · 4 commits · Feb 18, 2025
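
Not part of the diff, but for orientation: a minimal sketch of how a caller might use the new post_tool_function hook with chat_async. The callback must take exactly three parameters (function_name, input_args, tool_results) and may be either a plain function or a coroutine function. The model name, the arithmetic tool, and the single-Pydantic-argument tool convention are assumptions for illustration, not taken from this PR.

```python
import asyncio

from pydantic import BaseModel, Field

from defog.llm.utils import chat_async


class NumProdInput(BaseModel):
    a: int = Field(description="First number")
    b: int = Field(description="Second number")


def numprod(input: NumProdInput):
    """Illustrative tool: returns the product of two numbers."""
    return input.a * input.b


def log_tool_call(function_name, input_args, tool_results):
    # Called after every tool execution; must accept exactly these three parameters.
    print(f"{function_name}({input_args}) -> {tool_results}")


async def main():
    response = await chat_async(
        model="gpt-4o",
        messages=[{"role": "user", "content": "What is 31283 times 2323?"}],
        tools=[numprod],
        post_tool_function=log_tool_call,  # an async def callback is also accepted
    )
    print(response.content)


asyncio.run(main())
```
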
53 changes: 47 additions & 6 deletions defog/llm/utils.py
@@ -139,6 +139,7 @@ async def _process_anthropic_response(
tools,
tool_dict,
is_async,
post_tool_function: Callable = None,
):
"""
Extract content (including any tool calls) and usage info from Anthropic response.
@@ -191,6 +192,12 @@ async def _process_anthropic_response(
except Exception as e:
raise Exception(f"Error executing tool `{func_name}`: {e}")

if post_tool_function:
if inspect.iscoroutinefunction(post_tool_function):
await post_tool_function(func_name, args, result)
else:
post_tool_function(func_name, args, result)

# Store the tool call, result, and text
tools_used.append(func_name)
tool_outputs.append(
@@ -261,6 +268,7 @@ def _process_anthropic_response_handler(
tools: List[Callable],
tool_dict: Dict[str, Callable],
is_async=False,
post_tool_function: Callable = None,
):
"""
Processes Anthropic's response by determining whether to execute the response handling
@@ -287,6 +295,7 @@ def _process_anthropic_response_handler(
tools=tools,
tool_dict=tool_dict,
is_async=is_async,
post_tool_function=post_tool_function,
) # Caller must await this
else:
return asyncio.run(
@@ -297,6 +306,7 @@ def _process_anthropic_response_handler(
tools=tools,
tool_dict=tool_dict,
is_async=is_async,
post_tool_function=post_tool_function,
)
)
except Exception as e:
@@ -391,6 +401,7 @@ async def chat_anthropic_async(
timeout=100,
prediction=None,
reasoning_effort=None,
post_tool_function: Callable = None,
):
"""
Asynchronous Anthropic chat.
@@ -445,6 +456,7 @@ async def chat_anthropic_async(
tools=tools,
tool_dict=tool_dict,
is_async=True,
post_tool_function=post_tool_function,
)
)

@@ -572,12 +584,13 @@ def _build_openai_params(
async def _process_openai_response(
client,
response,
-request_params,
-tools,
-tool_dict,
+request_params: Dict[str, Any],
+tools: List[Callable],
+tool_dict: Dict[str, Callable],
response_format,
-model,
-is_async,
+model: str,
+is_async: bool,
+post_tool_function: Callable = None,
):
"""
Extract content (including any tool calls) and usage info from OpenAI response.
@@ -616,6 +629,12 @@ async def _process_openai_response(
except Exception as e:
raise Exception(f"Error executing tool `{func_name}`: {e}")

if post_tool_function:
if inspect.iscoroutinefunction(post_tool_function):
await post_tool_function(func_name, args, result)
else:
post_tool_function(func_name, args, result)

Review comment (Collaborator): TIL about inspect.iscoroutinefunction!
Review reply (Member Author): That was a find by @wendy-aw!

# Store the tool call, result, and text
tools_used.append(func_name)
tool_outputs.append(
@@ -683,7 +702,8 @@ def _process_openai_response_handler(
tool_dict: Dict[str, Callable],
response_format,
model: str,
-is_async=False,
+is_async: bool = False,
+post_tool_function: Callable = None,
):
"""
Processes OpenAI's response by determining whether to execute the response handling
@@ -712,6 +732,7 @@ def _process_openai_response_handler(
response_format=response_format,
model=model,
is_async=is_async,
post_tool_function=post_tool_function,
) # Caller must await this
else:
return asyncio.run(
@@ -724,6 +745,7 @@ def _process_openai_response_handler(
response_format=response_format,
model=model,
is_async=is_async,
post_tool_function=post_tool_function,
)
)

@@ -748,6 +770,7 @@ def chat_openai(
store: bool = True,
metadata: Dict[str, str] = None,
timeout: int = 100,
post_tool_function: Callable = None,
):
"""
Synchronous OpenAI chat.
@@ -821,6 +844,7 @@ def chat_openai(
response_format=response_format,
model=model,
is_async=False,
post_tool_function=post_tool_function,
)

return LLMResponse(
@@ -852,6 +876,7 @@ async def chat_openai_async(
api_key: str = os.environ.get("OPENAI_API_KEY", ""),
prediction: Dict[str, str] = None,
reasoning_effort: str = None,
post_tool_function: Callable = None,
):
"""
Asynchronous OpenAI chat.
@@ -925,6 +950,7 @@ async def chat_openai_async(
response_format=response_format,
model=model,
is_async=True,
post_tool_function=post_tool_function,
)

return LLMResponse(
@@ -986,6 +1012,7 @@ def chat_together(
seed: int = 0,
tools=None,
tool_choice=None,
post_tool_function: Callable = None,
):
"""Synchronous Together chat."""
from together import Together
@@ -1027,6 +1054,7 @@ async def chat_together_async(
timeout=100,
prediction=None,
reasoning_effort=None,
post_tool_function: Callable = None,
):
"""Asynchronous Together chat."""
from together import AsyncTogether
@@ -1118,6 +1146,7 @@ def chat_gemini(
tool_choice=None,
store=True,
metadata=None,
post_tool_function: Callable = None,
):
"""Synchronous Gemini chat."""
from google import genai
@@ -1173,6 +1202,7 @@ async def chat_gemini_async(
timeout=100,
prediction=None,
reasoning_effort=None,
post_tool_function: Callable = None,
):
"""Asynchronous Gemini chat."""
from google import genai
@@ -1256,6 +1286,7 @@ async def chat_async(
tools=None,
tool_choice=None,
max_retries=3,
post_tool_function: Callable = None,
) -> LLMResponse:
"""
Returns the response from the LLM API for a single model that is passed in.
@@ -1264,6 +1295,14 @@ async def chat_async(
llm_function = map_model_to_chat_fn_async(model)
base_delay = 1 # Initial delay in seconds

if post_tool_function:
# get number of input params from post_tool_function
num_params = len(inspect.signature(post_tool_function).parameters)
if num_params != 3:
raise ValueError(
"post_tool_function must have exactly three parameters: function_name, input_args, and tool_results"
)

for attempt in range(max_retries):
try:
if attempt > 0 and backup_model is not None:
@@ -1287,6 +1326,7 @@ async def chat_async(
timeout=timeout,
prediction=prediction,
reasoning_effort=reasoning_effort,
post_tool_function=post_tool_function,
)
else:
if not os.getenv("DEEPSEEK_API_KEY"):
@@ -1306,6 +1346,7 @@ async def chat_async(
reasoning_effort=reasoning_effort,
base_url="https://api.deepseek.com",
api_key=os.getenv("DEEPSEEK_API_KEY"),
post_tool_function=post_tool_function,
)
except Exception as e:
delay = base_delay * (2**attempt) # Exponential backoff
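
The hook is invoked the same way in both the Anthropic and OpenAI handlers above. A standalone sketch (the helper name run_post_tool_function is illustrative, not from the PR) of how inspect.iscoroutinefunction lets a single code path accept either a plain or an async callback:

```python
import asyncio
import inspect


def sync_logger(func_name, args, result):
    print(f"[sync] {func_name}({args}) -> {result}")


async def async_logger(func_name, args, result):
    print(f"[async] {func_name}({args}) -> {result}")


async def run_post_tool_function(post_tool_function, func_name, args, result):
    # Mirrors the dispatch in _process_anthropic_response / _process_openai_response:
    # coroutine functions are awaited, plain functions are called directly.
    if inspect.iscoroutinefunction(post_tool_function):
        await post_tool_function(func_name, args, result)
    else:
        post_tool_function(func_name, args, result)


async def demo():
    await run_post_tool_function(sync_logger, "numsum", {"a": 1, "b": 2}, 3)
    await run_post_tool_function(async_logger, "numsum", {"a": 1, "b": 2}, 3)


asyncio.run(demo())
```
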
2 changes: 1 addition & 1 deletion setup.py
@@ -26,7 +26,7 @@ def package_files(directory):
name="defog",
packages=find_packages(),
package_data={"defog": ["gcp/*", "aws/*"] + next_static_files},
version="0.67.2",
version="0.67.3",
description="Defog is a Python library that helps you generate data queries from natural language questions.",
author="Full Stack Data Pte. Ltd.",
license="MIT",
114 changes: 114 additions & 0 deletions tests/test_llm_tool_calls.py
@@ -4,11 +4,30 @@
from defog.llm.utils_function_calling import get_function_specs
from pydantic import BaseModel, Field
import aiohttp
from io import StringIO
import json

# ==================================================================================================
# Functions for function calling
# ==================================================================================================

IO_STREAM = StringIO()


def log_to_file(function_name, args, result):
"""
Simple function to test logging to a StringIO object.
Used in test_post_tool_calls_openai and test_post_tool_calls_anthropic
"""
message = {
"function_name": function_name,
"args": args,
"result": result,
}
message = json.dumps(message, indent=4)
IO_STREAM.write(message + "\n")
return IO_STREAM.getvalue()


class WeatherInput(BaseModel):
latitude: float = Field(default=0.0, description="The latitude of the location")
@@ -263,6 +282,101 @@ async def test_tool_use_weather_async_anthropic(self):
self.assertGreaterEqual(float(result.content), 21)
self.assertLessEqual(float(result.content), 38)

@pytest.mark.asyncio
async def test_post_tool_calls_openai(self):
result = await chat_async(
model="gpt-4o",
messages=[
{
"role": "user",
"content": self.arithmetic_qn,
},
],
tools=self.tools,
post_tool_function=log_to_file,
)
print(result)
self.assertEqual(result.content, self.arithmetic_answer)
for expected, actual in zip(
self.arithmetic_expected_tool_outputs, result.tool_outputs
):
self.assertEqual(expected["name"], actual["name"])
self.assertEqual(expected["args"], actual["args"])
self.assertEqual(expected["result"], actual["result"])
self.assertSetEqual(set(result.tools_used), {"numsum", "numprod"})
expected_stream_value = (
json.dumps(
{
"function_name": "numprod",
"args": {"a": 31283, "b": 2323},
"result": 72670409,
},
indent=4,
)
+ "\n"
+ json.dumps(
{
"function_name": "numsum",
"args": {"a": 72670409, "b": 5},
"result": 72670414,
},
indent=4,
)
+ "\n"
)
self.assertEqual(IO_STREAM.getvalue(), expected_stream_value)

# clear IO_STREAM
IO_STREAM.seek(0)
IO_STREAM.truncate()

async def test_post_tool_calls_anthropic(self):
result = await chat_async(
model="claude-3-5-sonnet-latest",
messages=[
{
"role": "user",
"content": self.arithmetic_qn,
},
],
tools=self.tools,
post_tool_function=log_to_file,
)
print(result)
self.assertEqual(result.content, self.arithmetic_answer)
for expected, actual in zip(
self.arithmetic_expected_tool_outputs, result.tool_outputs
):
self.assertEqual(expected["name"], actual["name"])
self.assertEqual(expected["args"], actual["args"])
self.assertEqual(expected["result"], actual["result"])
self.assertSetEqual(set(result.tools_used), {"numsum", "numprod"})
expected_stream_value = (
json.dumps(
{
"function_name": "numprod",
"args": {"a": 31283, "b": 2323},
"result": 72670409,
},
indent=4,
)
+ "\n"
+ json.dumps(
{
"function_name": "numsum",
"args": {"a": 72670409, "b": 5},
"result": 72670414,
},
indent=4,
)
+ "\n"
)
self.assertEqual(IO_STREAM.getvalue(), expected_stream_value)

# clear IO_STREAM
IO_STREAM.seek(0)
IO_STREAM.truncate()

def test_async_tool_in_sync_function_openai(self):
result = chat_openai(
model="gpt-4o",
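
The tests above cover valid three-parameter callbacks. As a complement, a small sketch (not from the PR) of the arity check that chat_async performs: in the diff, the check runs before the retry loop, so an invalid callback raises a ValueError without any API call being made.

```python
import asyncio

from defog.llm.utils import chat_async


def bad_callback(function_name, result):  # only two parameters
    pass


async def main():
    try:
        await chat_async(
            model="gpt-4o",
            messages=[{"role": "user", "content": "hi"}],
            post_tool_function=bad_callback,
        )
    except ValueError as e:
        # Expected: "post_tool_function must have exactly three parameters: ..."
        print(e)


asyncio.run(main())
```
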