@@ -102,6 +102,39 @@ def create_user(user: UserInput, flag: bool) -> dict:
102
102
assert "age" in tool .parameters ["$defs" ]["UserInput" ]["properties" ]
103
103
assert "flag" in tool .parameters ["properties" ]
104
104
105
+ def test_add_callable_object (self ):
106
+ """Test registering a callable object."""
107
+
108
+ class MyTool :
109
+ def __init__ (self ):
110
+ self .__name__ = "MyTool"
111
+
112
+ def __call__ (self , x : int ) -> int :
113
+ return x * 2
114
+
115
+ manager = ToolManager ()
116
+ tool = manager .add_tool (MyTool ())
117
+ assert tool .name == "MyTool"
118
+ assert tool .is_async is False
119
+ assert tool .parameters ["properties" ]["x" ]["type" ] == "integer"
120
+
121
+ @pytest .mark .anyio
122
+ async def test_add_async_callable_object (self ):
123
+ """Test registering an async callable object."""
124
+
125
+ class MyAsyncTool :
126
+ def __init__ (self ):
127
+ self .__name__ = "MyAsyncTool"
128
+
129
+ async def __call__ (self , x : int ) -> int :
130
+ return x * 2
131
+
132
+ manager = ToolManager ()
133
+ tool = manager .add_tool (MyAsyncTool ())
134
+ assert tool .name == "MyAsyncTool"
135
+ assert tool .is_async is True
136
+ assert tool .parameters ["properties" ]["x" ]["type" ] == "integer"
137
+
105
138
def test_add_invalid_tool (self ):
106
139
manager = ToolManager ()
107
140
with pytest .raises (AttributeError ):
@@ -168,6 +201,34 @@ async def double(n: int) -> int:
168
201
result = await manager .call_tool ("double" , {"n" : 5 })
169
202
assert result == 10
170
203
204
+ @pytest .mark .anyio
205
+ async def test_call_object_tool (self ):
206
+ class MyTool :
207
+ def __init__ (self ):
208
+ self .__name__ = "MyTool"
209
+
210
+ def __call__ (self , x : int ) -> int :
211
+ return x * 2
212
+
213
+ manager = ToolManager ()
214
+ tool = manager .add_tool (MyTool ())
215
+ result = await tool .run ({"x" : 5 })
216
+ assert result == 10
217
+
218
+ @pytest .mark .anyio
219
+ async def test_call_async_object_tool (self ):
220
+ class MyAsyncTool :
221
+ def __init__ (self ):
222
+ self .__name__ = "MyAsyncTool"
223
+
224
+ async def __call__ (self , x : int ) -> int :
225
+ return x * 2
226
+
227
+ manager = ToolManager ()
228
+ tool = manager .add_tool (MyAsyncTool ())
229
+ result = await tool .run ({"x" : 5 })
230
+ assert result == 10
231
+
171
232
@pytest .mark .anyio
172
233
async def test_call_tool_with_default_args (self ):
173
234
def add (a : int , b : int = 1 ) -> int :
0 commit comments