Skip to content

Commit c6fb822

Browse files
authored
Fix streamable http sampling (#693)
1 parent ed25167 commit c6fb822

File tree

7 files changed

+152
-23
lines changed

7 files changed

+152
-23
lines changed

src/mcp/cli/claude.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def get_claude_config_path() -> Path | None:
3131
return path
3232
return None
3333

34+
3435
def get_uv_path() -> str:
3536
"""Get the full path to the uv executable."""
3637
uv_path = shutil.which("uv")
@@ -42,6 +43,7 @@ def get_uv_path() -> str:
4243
return "uv" # Fall back to just "uv" if not found
4344
return uv_path
4445

46+
4547
def update_claude_config(
4648
file_spec: str,
4749
server_name: str,

src/mcp/client/streamable_http.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import anyio
1717
import httpx
18+
from anyio.abc import TaskGroup
1819
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1920
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
2021

@@ -239,7 +240,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
239240
break
240241

241242
async def _handle_post_request(self, ctx: RequestContext) -> None:
242-
"""Handle a POST request with response processing."""
243+
"""Handle a POST request with response processing."""
243244
headers = self._update_headers_with_session(ctx.headers)
244245
message = ctx.session_message.message
245246
is_initialization = self._is_initialization_request(message)
@@ -300,7 +301,7 @@ async def _handle_sse_response(
300301
try:
301302
event_source = EventSource(response)
302303
async for sse in event_source.aiter_sse():
303-
await self._handle_sse_event(
304+
is_complete = await self._handle_sse_event(
304305
sse,
305306
ctx.read_stream_writer,
306307
resumption_callback=(
@@ -309,6 +310,10 @@ async def _handle_sse_response(
309310
else None
310311
),
311312
)
313+
# If the SSE event indicates completion, like returning response/error
314+
# break the loop
315+
if is_complete:
316+
break
312317
except Exception as e:
313318
logger.exception("Error reading SSE stream:")
314319
await ctx.read_stream_writer.send(e)
@@ -344,6 +349,7 @@ async def post_writer(
344349
read_stream_writer: StreamWriter,
345350
write_stream: MemoryObjectSendStream[SessionMessage],
346351
start_get_stream: Callable[[], None],
352+
tg: TaskGroup,
347353
) -> None:
348354
"""Handle writing requests to the server."""
349355
try:
@@ -375,10 +381,17 @@ async def post_writer(
375381
sse_read_timeout=self.sse_read_timeout,
376382
)
377383

378-
if is_resumption:
379-
await self._handle_resumption_request(ctx)
384+
async def handle_request_async():
385+
if is_resumption:
386+
await self._handle_resumption_request(ctx)
387+
else:
388+
await self._handle_post_request(ctx)
389+
390+
# If this is a request, start a new task to handle it
391+
if isinstance(message.root, JSONRPCRequest):
392+
tg.start_soon(handle_request_async)
380393
else:
381-
await self._handle_post_request(ctx)
394+
await handle_request_async()
382395

383396
except Exception as exc:
384397
logger.error(f"Error in post_writer: {exc}")
@@ -466,6 +479,7 @@ def start_get_stream() -> None:
466479
read_stream_writer,
467480
write_stream,
468481
start_get_stream,
482+
tg,
469483
)
470484

471485
try:

src/mcp/server/session.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
4747

4848
import mcp.types as types
4949
from mcp.server.models import InitializationOptions
50-
from mcp.shared.message import SessionMessage
50+
from mcp.shared.message import ServerMessageMetadata, SessionMessage
5151
from mcp.shared.session import (
5252
BaseSession,
5353
RequestResponder,
@@ -230,10 +230,11 @@ async def create_message(
230230
stop_sequences: list[str] | None = None,
231231
metadata: dict[str, Any] | None = None,
232232
model_preferences: types.ModelPreferences | None = None,
233+
related_request_id: types.RequestId | None = None,
233234
) -> types.CreateMessageResult:
234235
"""Send a sampling/create_message request."""
235236
return await self.send_request(
236-
types.ServerRequest(
237+
request=types.ServerRequest(
237238
types.CreateMessageRequest(
238239
method="sampling/createMessage",
239240
params=types.CreateMessageRequestParams(
@@ -248,7 +249,10 @@ async def create_message(
248249
),
249250
)
250251
),
251-
types.CreateMessageResult,
252+
result_type=types.CreateMessageResult,
253+
metadata=ServerMessageMetadata(
254+
related_request_id=related_request_id,
255+
),
252256
)
253257

254258
async def list_roots(self) -> types.ListRootsResult:

src/mcp/server/streamable_http.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
ErrorData,
3434
JSONRPCError,
3535
JSONRPCMessage,
36-
JSONRPCNotification,
3736
JSONRPCRequest,
3837
JSONRPCResponse,
3938
RequestId,
@@ -849,9 +848,15 @@ async def message_router():
849848
# Determine which request stream(s) should receive this message
850849
message = session_message.message
851850
target_request_id = None
852-
if isinstance(
853-
message.root, JSONRPCNotification | JSONRPCRequest
854-
):
851+
# Check if this is a response
852+
if isinstance(message.root, JSONRPCResponse | JSONRPCError):
853+
response_id = str(message.root.id)
854+
# If this response is for an existing request stream,
855+
# send it there
856+
if response_id in self._request_streams:
857+
target_request_id = response_id
858+
859+
else:
855860
# Extract related_request_id from meta if it exists
856861
if (
857862
session_message.metadata is not None
@@ -865,10 +870,12 @@ async def message_router():
865870
target_request_id = str(
866871
session_message.metadata.related_request_id
867872
)
868-
else:
869-
target_request_id = str(message.root.id)
870873

871-
request_stream_id = target_request_id or GET_STREAM_KEY
874+
request_stream_id = (
875+
target_request_id
876+
if target_request_id is not None
877+
else GET_STREAM_KEY
878+
)
872879

873880
# Store the event if we have an event store,
874881
# regardless of whether a client is connected

src/mcp/shared/session.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,6 @@ async def send_request(
223223
Do not use this method to emit notifications! Use send_notification()
224224
instead.
225225
"""
226-
227226
request_id = self._request_id
228227
self._request_id = request_id + 1
229228

tests/client/test_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def test_absolute_uv_path(mock_config_path: Path):
5454
"""Test that the absolute path to uv is used when available."""
5555
# Mock the shutil.which function to return a fake path
5656
mock_uv_path = "/usr/local/bin/uv"
57-
57+
5858
with patch("mcp.cli.claude.get_uv_path", return_value=mock_uv_path):
5959
# Setup
6060
server_name = "test_server"
@@ -71,5 +71,5 @@ def test_absolute_uv_path(mock_config_path: Path):
7171
# Verify the command is the absolute path
7272
server_config = config["mcpServers"][server_name]
7373
command = server_config["command"]
74-
75-
assert command == mock_uv_path
74+
75+
assert command == mock_uv_path

tests/shared/test_streamable_http.py

Lines changed: 107 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import socket
99
import time
1010
from collections.abc import Generator
11+
from typing import Any
1112

1213
import anyio
1314
import httpx
@@ -33,6 +34,7 @@
3334
StreamId,
3435
)
3536
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
37+
from mcp.shared.context import RequestContext
3638
from mcp.shared.exceptions import McpError
3739
from mcp.shared.message import (
3840
ClientMessageMetadata,
@@ -139,6 +141,11 @@ async def handle_list_tools() -> list[Tool]:
139141
description="A long-running tool that sends periodic notifications",
140142
inputSchema={"type": "object", "properties": {}},
141143
),
144+
Tool(
145+
name="test_sampling_tool",
146+
description="A tool that triggers server-side sampling",
147+
inputSchema={"type": "object", "properties": {}},
148+
),
142149
]
143150

144151
@self.call_tool()
@@ -174,6 +181,34 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
174181

175182
return [TextContent(type="text", text="Completed!")]
176183

184+
elif name == "test_sampling_tool":
185+
# Test sampling by requesting the client to sample a message
186+
sampling_result = await ctx.session.create_message(
187+
messages=[
188+
types.SamplingMessage(
189+
role="user",
190+
content=types.TextContent(
191+
type="text", text="Server needs client sampling"
192+
),
193+
)
194+
],
195+
max_tokens=100,
196+
related_request_id=ctx.request_id,
197+
)
198+
199+
# Return the sampling result in the tool response
200+
response = (
201+
sampling_result.content.text
202+
if sampling_result.content.type == "text"
203+
else None
204+
)
205+
return [
206+
TextContent(
207+
type="text",
208+
text=f"Response from sampling: {response}",
209+
)
210+
]
211+
177212
return [TextContent(type="text", text=f"Called {name}")]
178213

179214

@@ -754,7 +789,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session)
754789
"""Test client tool invocation."""
755790
# First list tools
756791
tools = await initialized_client_session.list_tools()
757-
assert len(tools.tools) == 3
792+
assert len(tools.tools) == 4
758793
assert tools.tools[0].name == "test_tool"
759794

760795
# Call the tool
@@ -795,7 +830,7 @@ async def test_streamablehttp_client_session_persistence(
795830

796831
# Make multiple requests to verify session persistence
797832
tools = await session.list_tools()
798-
assert len(tools.tools) == 3
833+
assert len(tools.tools) == 4
799834

800835
# Read a resource
801836
resource = await session.read_resource(uri=AnyUrl("foobar://test-persist"))
@@ -826,7 +861,7 @@ async def test_streamablehttp_client_json_response(
826861

827862
# Check tool listing
828863
tools = await session.list_tools()
829-
assert len(tools.tools) == 3
864+
assert len(tools.tools) == 4
830865

831866
# Call a tool and verify JSON response handling
832867
result = await session.call_tool("test_tool", {})
@@ -905,7 +940,7 @@ async def test_streamablehttp_client_session_termination(
905940

906941
# Make a request to confirm session is working
907942
tools = await session.list_tools()
908-
assert len(tools.tools) == 3
943+
assert len(tools.tools) == 4
909944

910945
headers = {}
911946
if captured_session_id:
@@ -1054,3 +1089,71 @@ async def run_tool():
10541089
assert not any(
10551090
n in captured_notifications_pre for n in captured_notifications
10561091
)
1092+
1093+
1094+
@pytest.mark.anyio
1095+
async def test_streamablehttp_server_sampling(basic_server, basic_server_url):
1096+
"""Test server-initiated sampling request through streamable HTTP transport."""
1097+
print("Testing server sampling...")
1098+
# Variable to track if sampling callback was invoked
1099+
sampling_callback_invoked = False
1100+
captured_message_params = None
1101+
1102+
# Define sampling callback that returns a mock response
1103+
async def sampling_callback(
1104+
context: RequestContext[ClientSession, Any],
1105+
params: types.CreateMessageRequestParams,
1106+
) -> types.CreateMessageResult:
1107+
nonlocal sampling_callback_invoked, captured_message_params
1108+
sampling_callback_invoked = True
1109+
captured_message_params = params
1110+
message_received = (
1111+
params.messages[0].content.text
1112+
if params.messages[0].content.type == "text"
1113+
else None
1114+
)
1115+
1116+
return types.CreateMessageResult(
1117+
role="assistant",
1118+
content=types.TextContent(
1119+
type="text",
1120+
text=f"Received message from server: {message_received}",
1121+
),
1122+
model="test-model",
1123+
stopReason="endTurn",
1124+
)
1125+
1126+
# Create client with sampling callback
1127+
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
1128+
read_stream,
1129+
write_stream,
1130+
_,
1131+
):
1132+
async with ClientSession(
1133+
read_stream,
1134+
write_stream,
1135+
sampling_callback=sampling_callback,
1136+
) as session:
1137+
# Initialize the session
1138+
result = await session.initialize()
1139+
assert isinstance(result, InitializeResult)
1140+
1141+
# Call the tool that triggers server-side sampling
1142+
tool_result = await session.call_tool("test_sampling_tool", {})
1143+
1144+
# Verify the tool result contains the expected content
1145+
assert len(tool_result.content) == 1
1146+
assert tool_result.content[0].type == "text"
1147+
assert (
1148+
"Response from sampling: Received message from server"
1149+
in tool_result.content[0].text
1150+
)
1151+
1152+
# Verify sampling callback was invoked
1153+
assert sampling_callback_invoked
1154+
assert captured_message_params is not None
1155+
assert len(captured_message_params.messages) == 1
1156+
assert (
1157+
captured_message_params.messages[0].content.text
1158+
== "Server needs client sampling"
1159+
)

0 commit comments

Comments
 (0)