Skip to content

Commit 005483a

Browse files
Refactored default behaviour, updated tests
1 parent 9277c69 commit 005483a

File tree

3 files changed

+80
-29
lines changed

3 files changed

+80
-29
lines changed

src/mcp/client/session.py

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,49 @@
11
from datetime import timedelta
2-
from typing import Protocol, Any
2+
from typing import Any, Protocol
33

44
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
5-
from pydantic import AnyUrl
5+
from pydantic import AnyUrl, TypeAdapter
66

7-
from mcp.shared.context import RequestContext
87
import mcp.types as types
8+
from mcp.shared.context import RequestContext
99
from mcp.shared.session import BaseSession, RequestResponder
1010
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
1111

1212

1313
class SamplingFnT(Protocol):
1414
async def __call__(
15-
self, context: RequestContext["ClientSession", Any], params: types.CreateMessageRequestParams
16-
) -> types.CreateMessageResult:
17-
...
15+
self,
16+
context: RequestContext["ClientSession", Any],
17+
params: types.CreateMessageRequestParams,
18+
) -> types.CreateMessageResult | types.ErrorData: ...
1819

1920

2021
class ListRootsFnT(Protocol):
2122
async def __call__(
2223
self, context: RequestContext["ClientSession", Any]
23-
) -> types.ListRootsResult:
24-
...
24+
) -> types.ListRootsResult | types.ErrorData: ...
25+
26+
27+
async def _default_sampling_callback(
28+
context: RequestContext["ClientSession", Any],
29+
params: types.CreateMessageRequestParams,
30+
) -> types.CreateMessageResult | types.ErrorData:
31+
return types.ErrorData(
32+
code=types.INVALID_REQUEST,
33+
message="Sampling not supported",
34+
)
35+
36+
37+
async def _default_list_roots_callback(
38+
context: RequestContext["ClientSession", Any],
39+
) -> types.ListRootsResult | types.ErrorData:
40+
return types.ErrorData(
41+
code=types.INVALID_REQUEST,
42+
message="List roots not supported",
43+
)
44+
45+
46+
ClientResponse = TypeAdapter(types.ClientResult | types.ErrorData)
2547

2648

2749
class ClientSession(
@@ -33,8 +55,6 @@ class ClientSession(
3355
types.ServerNotification,
3456
]
3557
):
36-
_sampling_callback: SamplingFnT | None = None
37-
3858
def __init__(
3959
self,
4060
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
@@ -50,8 +70,8 @@ def __init__(
5070
types.ServerNotification,
5171
read_timeout_seconds=read_timeout_seconds,
5272
)
53-
self._sampling_callback = sampling_callback
54-
self._list_roots_callback = list_roots_callback
73+
self._sampling_callback = sampling_callback or _default_sampling_callback
74+
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
5575

5676
async def initialize(self) -> types.InitializeResult:
5777
sampling = (
@@ -278,27 +298,28 @@ async def send_roots_list_changed(self) -> None:
278298
async def _received_request(
279299
self, responder: RequestResponder[types.ServerRequest, types.ClientResult]
280300
) -> None:
281-
282301
ctx = RequestContext[ClientSession, Any](
283302
request_id=responder.request_id,
284303
meta=responder.request_meta,
285304
session=self,
286305
lifespan_context=None,
287306
)
288-
307+
289308
match responder.request.root:
290309
case types.CreateMessageRequest(params=params):
291-
if self._sampling_callback is not None:
310+
with responder:
292311
response = await self._sampling_callback(ctx, params)
293-
client_response = types.ClientResult(root=response)
294-
with responder:
295-
await responder.respond(client_response)
312+
client_response = ClientResponse.validate_python(response)
313+
await responder.respond(client_response)
314+
296315
case types.ListRootsRequest():
297-
if self._list_roots_callback is not None:
316+
with responder:
298317
response = await self._list_roots_callback(ctx)
299-
client_response = types.ClientResult(root=response)
300-
with responder:
301-
await responder.respond(client_response)
318+
client_response = ClientResponse.validate_python(response)
319+
await responder.respond(client_response)
320+
302321
case types.PingRequest():
303322
with responder:
304-
await responder.respond(types.ClientResult(root=types.EmptyResult()))
323+
return await responder.respond(
324+
types.ClientResult(root=types.EmptyResult())
325+
)

tests/client/test_list_roots_callback.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from pydantic import FileUrl
21
import pytest
2+
from pydantic import FileUrl
33

44
from mcp.client.session import ClientSession
55
from mcp.server.fastmcp.server import Context
@@ -10,6 +10,7 @@
1010
from mcp.types import (
1111
ListRootsResult,
1212
Root,
13+
TextContent,
1314
)
1415

1516

@@ -21,11 +22,11 @@ async def test_list_roots_callback():
2122

2223
callback_return = ListRootsResult(roots=[
2324
Root(
24-
uri=FileUrl("test://users/fake/test"),
25+
uri=FileUrl("file://users/fake/test"),
2526
name="Test Root 1",
2627
),
2728
Root(
28-
uri=FileUrl("test://users/fake/test/2"),
29+
uri=FileUrl("file://users/fake/test/2"),
2930
name="Test Root 2",
3031
)
3132
])
@@ -37,14 +38,29 @@ async def list_roots_callback(
3738

3839
@server.tool("test_list_roots")
3940
async def test_list_roots(context: Context, message: str):
40-
roots = context.session.list_roots()
41+
roots = await context.session.list_roots()
4142
assert roots == callback_return
4243
return True
4344

45+
# Test with list_roots callback
4446
async with create_session(
4547
server._mcp_server, list_roots_callback=list_roots_callback
4648
) as client_session:
4749
# Make a request to trigger sampling callback
48-
assert await client_session.call_tool(
50+
result = await client_session.call_tool(
4951
"test_list_roots", {"message": "test message"}
5052
)
53+
assert result.isError is False
54+
assert isinstance(result.content[0], TextContent)
55+
assert result.content[0].text == 'true'
56+
57+
# Test without list_roots callback
58+
async with create_session(server._mcp_server) as client_session:
59+
# Make a request to trigger sampling callback
60+
result = await client_session.call_tool(
61+
"test_list_roots", {"message": "test message"}
62+
)
63+
assert result.isError is True
64+
assert isinstance(result.content[0], TextContent)
65+
assert result.content[0].text == 'Error executing tool test_list_roots: List roots not supported'
66+

tests/client/test_sampling_callback.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,24 @@ async def test_sampling_tool(message: str):
4747
assert value == callback_return
4848
return True
4949

50+
# Test with sampling callback
5051
async with create_session(
5152
server._mcp_server, sampling_callback=sampling_callback
5253
) as client_session:
5354
# Make a request to trigger sampling callback
54-
assert await client_session.call_tool(
55+
result = await client_session.call_tool(
5556
"test_sampling", {"message": "Test message for sampling"}
5657
)
58+
assert result.isError is False
59+
assert isinstance(result.content[0], TextContent)
60+
assert result.content[0].text == 'true'
61+
62+
# Test without sampling callback
63+
async with create_session(server._mcp_server) as client_session:
64+
# Make a request to trigger sampling callback
65+
result = await client_session.call_tool(
66+
"test_sampling", {"message": "Test message for sampling"}
67+
)
68+
assert result.isError is True
69+
assert isinstance(result.content[0], TextContent)
70+
assert result.content[0].text == 'Error executing tool test_sampling: Sampling not supported'

0 commit comments

Comments
 (0)