Skip to content

Commit f2f4dbd

Browse files
Fix async callable object tools (#568)
1 parent d187643 commit f2f4dbd

File tree

2 files changed

+72
-1
lines changed

2 files changed

+72
-1
lines changed

src/mcp/server/fastmcp/tools/base.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations as _annotations
22

3+
import functools
34
import inspect
45
from collections.abc import Callable
56
from typing import TYPE_CHECKING, Any, get_origin
@@ -53,7 +54,7 @@ def from_function(
5354
raise ValueError("You must provide a name for lambda functions")
5455

5556
func_doc = description or fn.__doc__ or ""
56-
is_async = inspect.iscoroutinefunction(fn)
57+
is_async = _is_async_callable(fn)
5758

5859
if context_kwarg is None:
5960
sig = inspect.signature(fn)
@@ -98,3 +99,12 @@ async def run(
9899
)
99100
except Exception as e:
100101
raise ToolError(f"Error executing tool {self.name}: {e}") from e
102+
103+
104+
def _is_async_callable(obj: Any) -> bool:
105+
while isinstance(obj, functools.partial):
106+
obj = obj.func
107+
108+
return inspect.iscoroutinefunction(obj) or (
109+
callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None))
110+
)

tests/server/fastmcp/test_tool_manager.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,39 @@ def create_user(user: UserInput, flag: bool) -> dict:
102102
assert "age" in tool.parameters["$defs"]["UserInput"]["properties"]
103103
assert "flag" in tool.parameters["properties"]
104104

105+
def test_add_callable_object(self):
106+
"""Test registering a callable object."""
107+
108+
class MyTool:
109+
def __init__(self):
110+
self.__name__ = "MyTool"
111+
112+
def __call__(self, x: int) -> int:
113+
return x * 2
114+
115+
manager = ToolManager()
116+
tool = manager.add_tool(MyTool())
117+
assert tool.name == "MyTool"
118+
assert tool.is_async is False
119+
assert tool.parameters["properties"]["x"]["type"] == "integer"
120+
121+
@pytest.mark.anyio
122+
async def test_add_async_callable_object(self):
123+
"""Test registering an async callable object."""
124+
125+
class MyAsyncTool:
126+
def __init__(self):
127+
self.__name__ = "MyAsyncTool"
128+
129+
async def __call__(self, x: int) -> int:
130+
return x * 2
131+
132+
manager = ToolManager()
133+
tool = manager.add_tool(MyAsyncTool())
134+
assert tool.name == "MyAsyncTool"
135+
assert tool.is_async is True
136+
assert tool.parameters["properties"]["x"]["type"] == "integer"
137+
105138
def test_add_invalid_tool(self):
106139
manager = ToolManager()
107140
with pytest.raises(AttributeError):
@@ -168,6 +201,34 @@ async def double(n: int) -> int:
168201
result = await manager.call_tool("double", {"n": 5})
169202
assert result == 10
170203

204+
@pytest.mark.anyio
205+
async def test_call_object_tool(self):
206+
class MyTool:
207+
def __init__(self):
208+
self.__name__ = "MyTool"
209+
210+
def __call__(self, x: int) -> int:
211+
return x * 2
212+
213+
manager = ToolManager()
214+
tool = manager.add_tool(MyTool())
215+
result = await tool.run({"x": 5})
216+
assert result == 10
217+
218+
@pytest.mark.anyio
219+
async def test_call_async_object_tool(self):
220+
class MyAsyncTool:
221+
def __init__(self):
222+
self.__name__ = "MyAsyncTool"
223+
224+
async def __call__(self, x: int) -> int:
225+
return x * 2
226+
227+
manager = ToolManager()
228+
tool = manager.add_tool(MyAsyncTool())
229+
result = await tool.run({"x": 5})
230+
assert result == 10
231+
171232
@pytest.mark.anyio
172233
async def test_call_tool_with_default_args(self):
173234
def add(a: int, b: int = 1) -> int:

0 commit comments

Comments
 (0)