1
1
from datetime import timedelta
2
- from typing import Protocol , Any
2
+ from typing import Any , Protocol
3
3
4
4
from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
5
- from pydantic import AnyUrl
5
+ from pydantic import AnyUrl , TypeAdapter
6
6
7
- from mcp .shared .context import RequestContext
8
7
import mcp .types as types
8
+ from mcp .shared .context import RequestContext
9
9
from mcp .shared .session import BaseSession , RequestResponder
10
10
from mcp .shared .version import SUPPORTED_PROTOCOL_VERSIONS
11
11
12
12
13
13
class SamplingFnT (Protocol ):
14
14
async def __call__ (
15
- self , context : RequestContext ["ClientSession" , Any ], params : types .CreateMessageRequestParams
16
- ) -> types .CreateMessageResult :
17
- ...
15
+ self ,
16
+ context : RequestContext ["ClientSession" , Any ],
17
+ params : types .CreateMessageRequestParams ,
18
+ ) -> types .CreateMessageResult | types .ErrorData : ...
18
19
19
20
20
21
class ListRootsFnT (Protocol ):
21
22
async def __call__ (
22
23
self , context : RequestContext ["ClientSession" , Any ]
23
- ) -> types .ListRootsResult :
24
- ...
24
+ ) -> types .ListRootsResult | types .ErrorData : ...
25
+
26
+
27
+ async def _default_sampling_callback (
28
+ context : RequestContext ["ClientSession" , Any ],
29
+ params : types .CreateMessageRequestParams ,
30
+ ) -> types .CreateMessageResult | types .ErrorData :
31
+ return types .ErrorData (
32
+ code = types .INVALID_REQUEST ,
33
+ message = "Sampling not supported" ,
34
+ )
35
+
36
+
37
+ async def _default_list_roots_callback (
38
+ context : RequestContext ["ClientSession" , Any ],
39
+ ) -> types .ListRootsResult | types .ErrorData :
40
+ return types .ErrorData (
41
+ code = types .INVALID_REQUEST ,
42
+ message = "List roots not supported" ,
43
+ )
44
+
45
+
46
+ ClientResponse = TypeAdapter (types .ClientResult | types .ErrorData )
25
47
26
48
27
49
class ClientSession (
@@ -33,8 +55,6 @@ class ClientSession(
33
55
types .ServerNotification ,
34
56
]
35
57
):
36
- _sampling_callback : SamplingFnT | None = None
37
-
38
58
def __init__ (
39
59
self ,
40
60
read_stream : MemoryObjectReceiveStream [types .JSONRPCMessage | Exception ],
@@ -50,8 +70,8 @@ def __init__(
50
70
types .ServerNotification ,
51
71
read_timeout_seconds = read_timeout_seconds ,
52
72
)
53
- self ._sampling_callback = sampling_callback
54
- self ._list_roots_callback = list_roots_callback
73
+ self ._sampling_callback = sampling_callback or _default_sampling_callback
74
+ self ._list_roots_callback = list_roots_callback or _default_list_roots_callback
55
75
56
76
async def initialize (self ) -> types .InitializeResult :
57
77
sampling = (
@@ -278,27 +298,28 @@ async def send_roots_list_changed(self) -> None:
278
298
async def _received_request (
279
299
self , responder : RequestResponder [types .ServerRequest , types .ClientResult ]
280
300
) -> None :
281
-
282
301
ctx = RequestContext [ClientSession , Any ](
283
302
request_id = responder .request_id ,
284
303
meta = responder .request_meta ,
285
304
session = self ,
286
305
lifespan_context = None ,
287
306
)
288
-
307
+
289
308
match responder .request .root :
290
309
case types .CreateMessageRequest (params = params ):
291
- if self . _sampling_callback is not None :
310
+ with responder :
292
311
response = await self ._sampling_callback (ctx , params )
293
- client_response = types . ClientResult ( root = response )
294
- with responder :
295
- await responder . respond ( client_response )
312
+ client_response = ClientResponse . validate_python ( response )
313
+ await responder . respond ( client_response )
314
+
296
315
case types .ListRootsRequest ():
297
- if self . _list_roots_callback is not None :
316
+ with responder :
298
317
response = await self ._list_roots_callback (ctx )
299
- client_response = types . ClientResult ( root = response )
300
- with responder :
301
- await responder . respond ( client_response )
318
+ client_response = ClientResponse . validate_python ( response )
319
+ await responder . respond ( client_response )
320
+
302
321
case types .PingRequest ():
303
322
with responder :
304
- await responder .respond (types .ClientResult (root = types .EmptyResult ()))
323
+ return await responder .respond (
324
+ types .ClientResult (root = types .EmptyResult ())
325
+ )
0 commit comments