Skip to content

Commit

Permalink
Merge pull request #107 from chatchat-space/fix/_new_to_args_and_kwargs
Browse files Browse the repository at this point in the history
Fix/ new to args and kwargs
  • Loading branch information
yuehua-s authored Feb 17, 2025
2 parents 52e40c1 + b5a1605 commit 8acd35d
Show file tree
Hide file tree
Showing 3 changed files with 1,538 additions and 1,386 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import json
import re
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
from typing import Any, Callable, Optional, Type, Union

from langchain.agents import tool
from langchain_core.tools import BaseTool
Expand All @@ -14,51 +14,6 @@
_TOOLS_REGISTRY = {}


def _new_parse_input(
self,
tool_input: Union[str, Dict],
) -> Union[str, Dict[str, Any]]:
"""Convert tool input to pydantic model."""
input_args = self.args_schema
if isinstance(tool_input, str):
if input_args is not None:
key_ = next(iter(input_args.__fields__.keys()))
input_args.validate({key_: tool_input})
return tool_input
else:
if input_args is not None:
result = input_args.parse_obj(tool_input)
return result.dict()


def _new_to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
# For backwards compatibility, if run_input is a string,
# pass as a positional argument.
if isinstance(tool_input, str):
return (tool_input,), {}
else:
# for tool defined with `*args` parameters
# the args_schema has a field named `args`
# it should be expanded to actual *args
# e.g.: test_tools
# .test_named_tool_decorator_return_direct
# .search_api
if "args" in tool_input:
args = tool_input["args"]
if args is None:
tool_input.pop("args")
return (), tool_input
elif isinstance(args, tuple):
tool_input.pop("args")
return args, tool_input
return (), tool_input


BaseTool._parse_input = _new_parse_input
BaseTool._to_args_and_kwargs = _new_to_args_and_kwargs
###############################


def regist_tool(
*args: Any,
title: str = "",
Expand Down Expand Up @@ -140,4 +95,4 @@ def __str__(self) -> str:
elif callable(self.format):
return self.format(self)
else:
return str(self.data)
return str(self.data)
Loading

0 comments on commit 8acd35d

Please sign in to comment.