Skip to content

Commit c58adfe

Browse files
authored
Merge pull request #167 from modelcontextprotocol/davidsp/88-v2
feat: add request cancellation and cleanup
2 parents 27bfde9 + 733db0c commit c58adfe

File tree

5 files changed

+364
-28
lines changed

5 files changed

+364
-28
lines changed

src/mcp/server/lowlevel/server.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -453,10 +453,15 @@ async def run(
453453
logger.debug(f"Received message: {message}")
454454

455455
match message:
456-
case RequestResponder(request=types.ClientRequest(root=req)):
457-
await self._handle_request(
458-
message, req, session, raise_exceptions
459-
)
456+
case (
457+
RequestResponder(
458+
request=types.ClientRequest(root=req)
459+
) as responder
460+
):
461+
with responder:
462+
await self._handle_request(
463+
message, req, session, raise_exceptions
464+
)
460465
case types.ClientNotification(root=notify):
461466
await self._handle_notification(notify)
462467

src/mcp/server/session.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -126,19 +126,20 @@ async def _received_request(
126126
case types.InitializeRequest(params=params):
127127
self._initialization_state = InitializationState.Initializing
128128
self._client_params = params
129-
await responder.respond(
130-
types.ServerResult(
131-
types.InitializeResult(
132-
protocolVersion=types.LATEST_PROTOCOL_VERSION,
133-
capabilities=self._init_options.capabilities,
134-
serverInfo=types.Implementation(
135-
name=self._init_options.server_name,
136-
version=self._init_options.server_version,
137-
),
138-
instructions=self._init_options.instructions,
129+
with responder:
130+
await responder.respond(
131+
types.ServerResult(
132+
types.InitializeResult(
133+
protocolVersion=types.LATEST_PROTOCOL_VERSION,
134+
capabilities=self._init_options.capabilities,
135+
serverInfo=types.Implementation(
136+
name=self._init_options.server_name,
137+
version=self._init_options.server_version,
138+
),
139+
instructions=self._init_options.instructions,
140+
)
139141
)
140142
)
141-
)
142143
case _:
143144
if self._initialization_state != InitializationState.Initialized:
144145
raise RuntimeError(

src/mcp/shared/session.py

Lines changed: 106 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import logging
12
from contextlib import AbstractAsyncContextManager
23
from datetime import timedelta
3-
from typing import Generic, TypeVar
4+
from typing import Any, Callable, Generic, TypeVar
45

56
import anyio
67
import anyio.lowlevel
@@ -10,6 +11,7 @@
1011

1112
from mcp.shared.exceptions import McpError
1213
from mcp.types import (
14+
CancelledNotification,
1315
ClientNotification,
1416
ClientRequest,
1517
ClientResult,
@@ -38,27 +40,98 @@
3840

3941

4042
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
43+
"""Handles responding to MCP requests and manages request lifecycle.
44+
45+
This class MUST be used as a context manager to ensure proper cleanup and
46+
cancellation handling:
47+
48+
Example:
49+
with request_responder as resp:
50+
await resp.respond(result)
51+
52+
The context manager ensures:
53+
1. Proper cancellation scope setup and cleanup
54+
2. Request completion tracking
55+
3. Cleanup of in-flight requests
56+
"""
57+
4158
def __init__(
4259
self,
4360
request_id: RequestId,
4461
request_meta: RequestParams.Meta | None,
4562
request: ReceiveRequestT,
4663
session: "BaseSession",
64+
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
4765
) -> None:
4866
self.request_id = request_id
4967
self.request_meta = request_meta
5068
self.request = request
5169
self._session = session
52-
self._responded = False
70+
self._completed = False
71+
self._cancel_scope = anyio.CancelScope()
72+
self._on_complete = on_complete
73+
self._entered = False # Track if we're in a context manager
74+
75+
def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
76+
"""Enter the context manager, enabling request cancellation tracking."""
77+
self._entered = True
78+
self._cancel_scope = anyio.CancelScope()
79+
self._cancel_scope.__enter__()
80+
return self
81+
82+
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
83+
"""Exit the context manager, performing cleanup and notifying completion."""
84+
try:
85+
if self._completed:
86+
self._on_complete(self)
87+
finally:
88+
self._entered = False
89+
if not self._cancel_scope:
90+
raise RuntimeError("No active cancel scope")
91+
self._cancel_scope.__exit__(exc_type, exc_val, exc_tb)
5392

5493
async def respond(self, response: SendResultT | ErrorData) -> None:
55-
assert not self._responded, "Request already responded to"
56-
self._responded = True
94+
"""Send a response for this request.
95+
96+
Must be called within a context manager block.
97+
Raises:
98+
RuntimeError: If not used within a context manager
99+
AssertionError: If request was already responded to
100+
"""
101+
if not self._entered:
102+
raise RuntimeError("RequestResponder must be used as a context manager")
103+
assert not self._completed, "Request already responded to"
104+
105+
if not self.cancelled:
106+
self._completed = True
107+
108+
await self._session._send_response(
109+
request_id=self.request_id, response=response
110+
)
111+
112+
async def cancel(self) -> None:
113+
"""Cancel this request and mark it as completed."""
114+
if not self._entered:
115+
raise RuntimeError("RequestResponder must be used as a context manager")
116+
if not self._cancel_scope:
117+
raise RuntimeError("No active cancel scope")
57118

119+
self._cancel_scope.cancel()
120+
self._completed = True # Mark as completed so it's removed from in_flight
121+
# Send an error response to indicate cancellation
58122
await self._session._send_response(
59-
request_id=self.request_id, response=response
123+
request_id=self.request_id,
124+
response=ErrorData(code=0, message="Request cancelled", data=None),
60125
)
61126

127+
@property
128+
def in_flight(self) -> bool:
129+
return not self._completed and not self.cancelled
130+
131+
@property
132+
def cancelled(self) -> bool:
133+
return self._cancel_scope is not None and self._cancel_scope.cancel_called
134+
62135

63136
class BaseSession(
64137
AbstractAsyncContextManager,
@@ -82,6 +155,7 @@ class BaseSession(
82155
RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]
83156
]
84157
_request_id: int
158+
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
85159

86160
def __init__(
87161
self,
@@ -99,6 +173,7 @@ def __init__(
99173
self._receive_request_type = receive_request_type
100174
self._receive_notification_type = receive_notification_type
101175
self._read_timeout_seconds = read_timeout_seconds
176+
self._in_flight = {}
102177

103178
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
104179
anyio.create_memory_object_stream[
@@ -219,27 +294,45 @@ async def _receive_loop(self) -> None:
219294
by_alias=True, mode="json", exclude_none=True
220295
)
221296
)
297+
222298
responder = RequestResponder(
223299
request_id=message.root.id,
224300
request_meta=validated_request.root.params.meta
225301
if validated_request.root.params
226302
else None,
227303
request=validated_request,
228304
session=self,
305+
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
229306
)
230307

308+
self._in_flight[responder.request_id] = responder
231309
await self._received_request(responder)
232-
if not responder._responded:
310+
if not responder._completed:
233311
await self._incoming_message_stream_writer.send(responder)
312+
234313
elif isinstance(message.root, JSONRPCNotification):
235-
notification = self._receive_notification_type.model_validate(
236-
message.root.model_dump(
237-
by_alias=True, mode="json", exclude_none=True
314+
try:
315+
notification = self._receive_notification_type.model_validate(
316+
message.root.model_dump(
317+
by_alias=True, mode="json", exclude_none=True
318+
)
319+
)
320+
# Handle cancellation notifications
321+
if isinstance(notification.root, CancelledNotification):
322+
cancelled_id = notification.root.params.requestId
323+
if cancelled_id in self._in_flight:
324+
await self._in_flight[cancelled_id].cancel()
325+
else:
326+
await self._received_notification(notification)
327+
await self._incoming_message_stream_writer.send(
328+
notification
329+
)
330+
except Exception as e:
331+
# For other validation errors, log and continue
332+
logging.warning(
333+
f"Failed to validate notification: {e}. "
334+
f"Message was: {message.root}"
238335
)
239-
)
240-
241-
await self._received_notification(notification)
242-
await self._incoming_message_stream_writer.send(notification)
243336
else: # Response or error
244337
stream = self._response_streams.pop(message.root.id, None)
245338
if stream:

tests/issues/test_88_random_error.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""Test to reproduce issue #88: Random error thrown on response."""
2+
3+
from datetime import timedelta
4+
from pathlib import Path
5+
from typing import Sequence
6+
7+
import anyio
8+
import pytest
9+
10+
from mcp.client.session import ClientSession
11+
from mcp.server.lowlevel import Server
12+
from mcp.shared.exceptions import McpError
13+
from mcp.types import (
14+
EmbeddedResource,
15+
ImageContent,
16+
TextContent,
17+
)
18+
19+
20+
@pytest.mark.anyio
21+
async def test_notification_validation_error(tmp_path: Path):
22+
"""Test that timeouts are handled gracefully and don't break the server.
23+
24+
This test verifies that when a client request times out:
25+
1. The server task stays alive
26+
2. The server can still handle new requests
27+
3. The client can make new requests
28+
4. No resources are leaked
29+
"""
30+
31+
server = Server(name="test")
32+
request_count = 0
33+
slow_request_started = anyio.Event()
34+
slow_request_complete = anyio.Event()
35+
36+
@server.call_tool()
37+
async def slow_tool(
38+
name: str, arg
39+
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
40+
nonlocal request_count
41+
request_count += 1
42+
43+
if name == "slow":
44+
# Signal that slow request has started
45+
slow_request_started.set()
46+
# Long enough to ensure timeout
47+
await anyio.sleep(0.2)
48+
# Signal completion
49+
slow_request_complete.set()
50+
return [TextContent(type="text", text=f"slow {request_count}")]
51+
elif name == "fast":
52+
# Fast enough to complete before timeout
53+
await anyio.sleep(0.01)
54+
return [TextContent(type="text", text=f"fast {request_count}")]
55+
return [TextContent(type="text", text=f"unknown {request_count}")]
56+
57+
async def server_handler(read_stream, write_stream):
58+
await server.run(
59+
read_stream,
60+
write_stream,
61+
server.create_initialization_options(),
62+
raise_exceptions=True,
63+
)
64+
65+
async def client(read_stream, write_stream):
66+
# Use a timeout that's:
67+
# - Long enough for fast operations (>10ms)
68+
# - Short enough for slow operations (<200ms)
69+
# - Not too short to avoid flakiness
70+
async with ClientSession(
71+
read_stream, write_stream, read_timeout_seconds=timedelta(milliseconds=50)
72+
) as session:
73+
await session.initialize()
74+
75+
# First call should work (fast operation)
76+
result = await session.call_tool("fast")
77+
assert result.content == [TextContent(type="text", text="fast 1")]
78+
assert not slow_request_complete.is_set()
79+
80+
# Second call should timeout (slow operation)
81+
with pytest.raises(McpError) as exc_info:
82+
await session.call_tool("slow")
83+
assert "Timed out while waiting" in str(exc_info.value)
84+
85+
# Wait for slow request to complete in the background
86+
with anyio.fail_after(1): # Timeout after 1 second
87+
await slow_request_complete.wait()
88+
89+
# Third call should work (fast operation),
90+
# proving server is still responsive
91+
result = await session.call_tool("fast")
92+
assert result.content == [TextContent(type="text", text="fast 3")]
93+
94+
# Run server and client in separate task groups to avoid cancellation
95+
server_writer, server_reader = anyio.create_memory_object_stream(1)
96+
client_writer, client_reader = anyio.create_memory_object_stream(1)
97+
98+
server_ready = anyio.Event()
99+
100+
async def wrapped_server_handler(read_stream, write_stream):
101+
server_ready.set()
102+
await server_handler(read_stream, write_stream)
103+
104+
async with anyio.create_task_group() as tg:
105+
tg.start_soon(wrapped_server_handler, server_reader, client_writer)
106+
# Wait for server to start and initialize
107+
with anyio.fail_after(1): # Timeout after 1 second
108+
await server_ready.wait()
109+
# Run client in a separate task to avoid cancellation
110+
async with anyio.create_task_group() as client_tg:
111+
client_tg.start_soon(client, client_reader, server_writer)

0 commit comments

Comments
 (0)