From f9598ad80719ec72320d6d34bca1b899613b8510 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 4 May 2025 17:53:46 +0000 Subject: [PATCH 01/31] attempt to support cancelation --- src/mcp/shared/session.py | 25 +++++++++++++++++++++++++ tests/shared/test_streamable_http.py | 2 +- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index cce8b1184..303e4f7f1 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -15,6 +15,7 @@ from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.types import ( CancelledNotification, + CancelledNotificationParams, ClientNotification, ClientRequest, ClientResult, @@ -33,6 +34,7 @@ SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest) SendResultT = TypeVar("SendResultT", ClientResult, ServerResult) SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification) +SendNotificationInternalT = TypeVar("SendNotificationInternalT", CancelledNotification, ClientNotification, ServerNotification) ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest) ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) ReceiveNotificationT = TypeVar( @@ -254,6 +256,8 @@ async def send_request( elif self._session_read_timeout_seconds is not None: timeout = self._session_read_timeout_seconds.total_seconds() + response_or_error = None + try: with anyio.fail_after(timeout): response_or_error = await response_stream_reader.receive() @@ -268,7 +272,21 @@ async def send_request( ), ) ) + except anyio.get_cancelled_exc_class(): + with anyio.CancelScope(shield=True): + notification = CancelledNotification( + method="notifications/cancelled", + params=CancelledNotificationParams( + requestId=request_id, + reason="cancelled" + ) + ) + await self._send_notification_internal(notification, request_id, ) + if response_or_error is None: + raise McpError( + ErrorData(code=32601, message="request cancelled") + ) if isinstance(response_or_error, JSONRPCError): raise McpError(response_or_error.error) else: @@ -283,6 +301,13 @@ async def send_notification( self, notification: SendNotificationT, related_request_id: RequestId | None = None, + ) -> None: + await self._send_notification_internal(notification, related_request_id) + + async def _send_notification_internal( + self, + notification: SendNotificationInternalT, + related_request_id: RequestId | None = None, ) -> None: """ Emits a notification, which is a one-way message that does not expect diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index b1dc7ea33..2fac4d5c8 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -989,7 +989,7 @@ async def test_streamablehttp_client_session_termination( # Attempt to make a request after termination with pytest.raises( McpError, - match="Session terminated", + match="request cancelled", ): await session.list_tools() From dfb3686a5d1f7385f31b714d5a420af630ae773c Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 4 May 2025 19:35:03 +0000 Subject: [PATCH 02/31] fix sending cancel --- src/mcp/shared/session.py | 36 ++++++++++++++-------------- tests/shared/test_streamable_http.py | 2 +- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 303e4f7f1..73870ed67 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -256,11 +256,25 @@ async def send_request( elif self._session_read_timeout_seconds is not None: timeout = self._session_read_timeout_seconds.total_seconds() - response_or_error = None - try: - with anyio.fail_after(timeout): + with anyio.fail_after(timeout) as scope: response_or_error = await response_stream_reader.receive() + + if scope.cancel_called: + with anyio.CancelScope(shield=True): + notification = CancelledNotification( + method="notifications/cancelled", + params=CancelledNotificationParams( + requestId=request_id, + reason="cancelled" + ) + ) + await self._send_notification_internal(notification, request_id) + + raise McpError( + ErrorData(code=32601, message="request cancelled") + ) + except TimeoutError: raise McpError( ErrorData( @@ -272,21 +286,7 @@ async def send_request( ), ) ) - except anyio.get_cancelled_exc_class(): - with anyio.CancelScope(shield=True): - notification = CancelledNotification( - method="notifications/cancelled", - params=CancelledNotificationParams( - requestId=request_id, - reason="cancelled" - ) - ) - await self._send_notification_internal(notification, request_id, ) - - if response_or_error is None: - raise McpError( - ErrorData(code=32601, message="request cancelled") - ) + if isinstance(response_or_error, JSONRPCError): raise McpError(response_or_error.error) else: diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 2fac4d5c8..b1dc7ea33 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -989,7 +989,7 @@ async def test_streamablehttp_client_session_termination( # Attempt to make a request after termination with pytest.raises( McpError, - match="request cancelled", + match="Session terminated", ): await session.list_tools() From abda067134a421878a9ce0fd6e99ed3b2c09122b Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 4 May 2025 19:54:45 +0000 Subject: [PATCH 03/31] add decorator for handling cancelation --- src/mcp/server/lowlevel/server.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 4b97b33da..dd93df49b 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -441,6 +441,22 @@ async def handler(req: types.ProgressNotification): return decorator + def cancel_notification(self): + def decorator( + func: Callable[[str | int, str | None], Awaitable[None]], + ): + logger.debug("Registering handler for ProgressNotification") + + async def handler(req: types.CancelledNotification): + await func( + req.params.requestId, req.params.reason + ) + + self.notification_handlers[types.CancelledNotification] = handler + return func + + return decorator + def completion(self): """Provides completions for prompts and resource templates""" @@ -587,12 +603,14 @@ async def _handle_notification(self, notify: Any): assert type(notify) in self.notification_handlers handler = self.notification_handlers[type(notify)] - logger.debug(f"Dispatching notification of type {type(notify).__name__}") + print(f"Dispatching notification of type {type(notify).__name__}") try: await handler(notify) except Exception as err: logger.error(f"Uncaught exception in notification handler: {err}") + else: + print(f"Not handling {notify}") async def _ping_handler(request: types.PingRequest) -> types.ServerResult: From 24553c65abd4944eeb4209955a85500ef6362ded Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 4 May 2025 19:55:18 +0000 Subject: [PATCH 04/31] fix capitalisation on error message --- src/mcp/shared/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 73870ed67..24acdc413 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -272,7 +272,7 @@ async def send_request( await self._send_notification_internal(notification, request_id) raise McpError( - ErrorData(code=32601, message="request cancelled") + ErrorData(code=32601, message="Request cancelled") ) except TimeoutError: From fe49931a390f0a2e2cda651dc230e404c470f671 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 4 May 2025 19:56:00 +0000 Subject: [PATCH 05/31] try to check notification event received --- tests/shared/test_session.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 59cb30c86..bb040a162 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -50,6 +50,7 @@ async def test_request_cancellation(): ev_tool_called = anyio.Event() ev_cancelled = anyio.Event() + ev_cancel_notified = anyio.Event() request_id = None # Start the request in a separate task so we can cancel it @@ -66,6 +67,11 @@ async def handle_call_tool(name: str, arguments: dict | None) -> list: await anyio.sleep(10) # Long enough to ensure we can cancel return [] raise ValueError(f"Unknown tool: {name}") + + @server.cancel_notification() + async def handle_cancel(requestId: str | int, reason: str | None): + nonlocal ev_cancel_notified + ev_cancel_notified.set() # Register the tool so it shows up in list_tools @server.list_tools() @@ -121,6 +127,9 @@ async def make_request(client_session): ) ) + with anyio.fail_after(1): + await ev_cancel_notified.wait() + # Give cancellation time to process with anyio.fail_after(1): await ev_cancelled.wait() From efd0ffdd8ed5e2d0bfe2a19811473312032364a4 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 4 May 2025 20:54:46 +0000 Subject: [PATCH 06/31] fix test --- tests/shared/test_session.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index bb040a162..4d0ce4553 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -51,7 +51,6 @@ async def test_request_cancellation(): ev_tool_called = anyio.Event() ev_cancelled = anyio.Event() ev_cancel_notified = anyio.Event() - request_id = None # Start the request in a separate task so we can cancel it def make_server() -> Server: @@ -60,9 +59,8 @@ def make_server() -> Server: # Register the tool handler @server.call_tool() async def handle_call_tool(name: str, arguments: dict | None) -> list: - nonlocal request_id, ev_tool_called + nonlocal ev_tool_called if name == "slow_tool": - request_id = server.request_context.request_id ev_tool_called.set() await anyio.sleep(10) # Long enough to ensure we can cancel return [] @@ -116,16 +114,7 @@ async def make_request(client_session): with anyio.fail_after(1): # Timeout after 1 second await ev_tool_called.wait() - # Send cancellation notification - assert request_id is not None - await client_session.send_notification( - ClientNotification( - CancelledNotification( - method="notifications/cancelled", - params=CancelledNotificationParams(requestId=request_id), - ) - ) - ) + tg.cancel_scope.cancel() with anyio.fail_after(1): await ev_cancel_notified.wait() From 1364b7a358421d202563032cc73cee741d389fc4 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 4 May 2025 20:58:37 +0000 Subject: [PATCH 07/31] fix debug message --- src/mcp/server/lowlevel/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index dd93df49b..a09c5fabf 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -445,7 +445,7 @@ def cancel_notification(self): def decorator( func: Callable[[str | int, str | None], Awaitable[None]], ): - logger.debug("Registering handler for ProgressNotification") + logger.debug("Registering handler for CancelledNotification") async def handler(req: types.CancelledNotification): await func( From 17ae44cec83543c7deba209702e4373b05d2da04 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 4 May 2025 21:06:27 +0000 Subject: [PATCH 08/31] prevent cancel on intialisation as per https://modelcontextprotocol.io/specification/2025-03-26/basic/utilities/cancellation --- src/mcp/client/session.py | 1 + src/mcp/shared/session.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 7bb8821f7..130fe33f2 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -140,6 +140,7 @@ async def initialize(self) -> types.InitializeResult: ) ), types.InitializeResult, + cancellable=False ) if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS: diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 24acdc413..62052ac25 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -216,6 +216,7 @@ async def send_request( result_type: type[ReceiveResultT], request_read_timeout_seconds: timedelta | None = None, metadata: MessageMetadata = None, + cancellable: bool = True, ) -> ReceiveResultT: """ Sends a request and wait for a response. Raises an McpError if the @@ -260,7 +261,7 @@ async def send_request( with anyio.fail_after(timeout) as scope: response_or_error = await response_stream_reader.receive() - if scope.cancel_called: + if cancellable and scope.cancel_called: with anyio.CancelScope(shield=True): notification = CancelledNotification( method="notifications/cancelled", From 06f4b3c5a3108ace2bd7028e562bd13d82b56dc7 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 4 May 2025 21:07:40 +0000 Subject: [PATCH 09/31] remove unused imports --- tests/shared/test_session.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 4d0ce4553..037553b0e 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -9,9 +9,6 @@ from mcp.shared.exceptions import McpError from mcp.shared.memory import create_connected_server_and_client_session from mcp.types import ( - CancelledNotification, - CancelledNotificationParams, - ClientNotification, ClientRequest, EmptyResult, ) From 07e1a525b02faff8f2fee9894daaaaeffc57afa5 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 4 May 2025 21:08:34 +0000 Subject: [PATCH 10/31] fixed long line --- src/mcp/shared/session.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 62052ac25..c55434229 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -270,7 +270,9 @@ async def send_request( reason="cancelled" ) ) - await self._send_notification_internal(notification, request_id) + await self._send_notification_internal( + notification, request_id + ) raise McpError( ErrorData(code=32601, message="Request cancelled") From 92f806b95912575673116c46511253c0756d8eb7 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 4 May 2025 21:09:51 +0000 Subject: [PATCH 11/31] fixed long line --- src/mcp/shared/session.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index c55434229..55133d5c3 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -34,7 +34,10 @@ SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest) SendResultT = TypeVar("SendResultT", ClientResult, ServerResult) SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification) -SendNotificationInternalT = TypeVar("SendNotificationInternalT", CancelledNotification, ClientNotification, ServerNotification) +SendNotificationInternalT = TypeVar( + "SendNotificationInternalT", + CancelledNotification, ClientNotification, ServerNotification +) ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest) ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) ReceiveNotificationT = TypeVar( From e46c693290afd29c178467fc7d25ebe87cecbc34 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 4 May 2025 21:11:08 +0000 Subject: [PATCH 12/31] updates from ruff format --- src/mcp/client/session.py | 2 +- src/mcp/server/lowlevel/server.py | 6 ++---- src/mcp/shared/session.py | 15 ++++++++------- tests/shared/test_session.py | 2 +- 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 130fe33f2..fc12ceb2e 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -140,7 +140,7 @@ async def initialize(self) -> types.InitializeResult: ) ), types.InitializeResult, - cancellable=False + cancellable=False, ) if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS: diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index a09c5fabf..dfbacd77b 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -448,15 +448,13 @@ def decorator( logger.debug("Registering handler for CancelledNotification") async def handler(req: types.CancelledNotification): - await func( - req.params.requestId, req.params.reason - ) + await func(req.params.requestId, req.params.reason) self.notification_handlers[types.CancelledNotification] = handler return func return decorator - + def completion(self): """Provides completions for prompts and resource templates""" diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 55133d5c3..20dfe7d70 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -35,8 +35,10 @@ SendResultT = TypeVar("SendResultT", ClientResult, ServerResult) SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification) SendNotificationInternalT = TypeVar( - "SendNotificationInternalT", - CancelledNotification, ClientNotification, ServerNotification + "SendNotificationInternalT", + CancelledNotification, + ClientNotification, + ServerNotification, ) ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest) ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) @@ -269,14 +271,13 @@ async def send_request( notification = CancelledNotification( method="notifications/cancelled", params=CancelledNotificationParams( - requestId=request_id, - reason="cancelled" - ) + requestId=request_id, reason="cancelled" + ), ) await self._send_notification_internal( notification, request_id ) - + raise McpError( ErrorData(code=32601, message="Request cancelled") ) @@ -292,7 +293,7 @@ async def send_request( ), ) ) - + if isinstance(response_or_error, JSONRPCError): raise McpError(response_or_error.error) else: diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 037553b0e..866444250 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -62,7 +62,7 @@ async def handle_call_tool(name: str, arguments: dict | None) -> list: await anyio.sleep(10) # Long enough to ensure we can cancel return [] raise ValueError(f"Unknown tool: {name}") - + @server.cancel_notification() async def handle_cancel(requestId: str | int, reason: str | None): nonlocal ev_cancel_notified From 2a24e0cba16d21e68d93e79e7d61d834a9ae9bf5 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 4 May 2025 21:32:01 +0000 Subject: [PATCH 13/31] removed dev print statements committed by mistake --- src/mcp/server/lowlevel/server.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index dfbacd77b..2e262cd3d 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -601,14 +601,12 @@ async def _handle_notification(self, notify: Any): assert type(notify) in self.notification_handlers handler = self.notification_handlers[type(notify)] - print(f"Dispatching notification of type {type(notify).__name__}") + logger.debug(f"Dispatching notification of type {type(notify).__name__}") try: await handler(notify) except Exception as err: logger.error(f"Uncaught exception in notification handler: {err}") - else: - print(f"Not handling {notify}") async def _ping_handler(request: types.PingRequest) -> types.ServerResult: From 87722f8bfdf8876400207cdcf940fd4dc1e3a5e9 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 4 May 2025 21:55:01 +0000 Subject: [PATCH 14/31] add constant for request cancelled and use it in case of cancelled task --- src/mcp/shared/session.py | 3 ++- src/mcp/types.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 20dfe7d70..e9603accb 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -29,6 +29,7 @@ ServerNotification, ServerRequest, ServerResult, + REQUEST_CANCELLED, ) SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest) @@ -279,7 +280,7 @@ async def send_request( ) raise McpError( - ErrorData(code=32601, message="Request cancelled") + ErrorData(code=REQUEST_CANCELLED, message="Request cancelled") ) except TimeoutError: diff --git a/src/mcp/types.py b/src/mcp/types.py index 6ab7fba5c..5c077ca9d 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -146,6 +146,7 @@ class JSONRPCResponse(BaseModel): METHOD_NOT_FOUND = -32601 INVALID_PARAMS = -32602 INTERNAL_ERROR = -32603 +REQUEST_CANCELLED = -32604 class ErrorData(BaseModel): From f0782d251c64275df6a248bb1a779f5d2dc24699 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 4 May 2025 21:55:41 +0000 Subject: [PATCH 15/31] fix long line --- src/mcp/shared/session.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index e9603accb..bc6388648 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -280,7 +280,9 @@ async def send_request( ) raise McpError( - ErrorData(code=REQUEST_CANCELLED, message="Request cancelled") + ErrorData( + code=REQUEST_CANCELLED, message="Request cancelled" + ) ) except TimeoutError: From a0164f989904e6f649df64ab057ac3e88cfc1bf9 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 4 May 2025 21:56:59 +0000 Subject: [PATCH 16/31] fix import order --- src/mcp/shared/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index bc6388648..f21f72132 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -14,6 +14,7 @@ from mcp.shared.exceptions import McpError from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.types import ( + REQUEST_CANCELLED, CancelledNotification, CancelledNotificationParams, ClientNotification, @@ -29,7 +30,6 @@ ServerNotification, ServerRequest, ServerResult, - REQUEST_CANCELLED, ) SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest) From bf220d58b534d9de68386bb565b19e6c8c0ee392 Mon Sep 17 00:00:00 2001 From: David Savage Date: Mon, 5 May 2025 04:44:02 +0000 Subject: [PATCH 17/31] add assert that cancel scope triggered on server, add comments for readability --- tests/shared/test_session.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 866444250..8e67940ad 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -46,6 +46,7 @@ async def test_request_cancellation(): # The tool is already registered in the fixture ev_tool_called = anyio.Event() + ev_tool_cancelled = anyio.Event() ev_cancelled = anyio.Event() ev_cancel_notified = anyio.Event() @@ -56,11 +57,17 @@ def make_server() -> Server: # Register the tool handler @server.call_tool() async def handle_call_tool(name: str, arguments: dict | None) -> list: - nonlocal ev_tool_called + nonlocal ev_tool_called, ev_tool_cancelled if name == "slow_tool": ev_tool_called.set() - await anyio.sleep(10) # Long enough to ensure we can cancel - return [] + with anyio.CancelScope(): + try: + await anyio.sleep(10) # Long enough to ensure we can cancel + return [] + except anyio.get_cancelled_exc_class() as err: + ev_tool_cancelled.set() + raise err + raise ValueError(f"Unknown tool: {name}") @server.cancel_notification() @@ -111,11 +118,17 @@ async def make_request(client_session): with anyio.fail_after(1): # Timeout after 1 second await ev_tool_called.wait() + # cancel the task via task group tg.cancel_scope.cancel() + # Give cancellation time to process + with anyio.fail_after(1): + await ev_cancelled.wait() + + # check server cancel notification received with anyio.fail_after(1): await ev_cancel_notified.wait() - # Give cancellation time to process + # Give cancellation time to process on server with anyio.fail_after(1): - await ev_cancelled.wait() + await ev_tool_cancelled.wait() From 45ac52ae33105c489aaba1f014cb27759085fee7 Mon Sep 17 00:00:00 2001 From: David Savage Date: Mon, 5 May 2025 04:45:49 +0000 Subject: [PATCH 18/31] trivial update to comment capitalisation --- tests/shared/test_session.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 8e67940ad..15607d6e0 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -118,14 +118,14 @@ async def make_request(client_session): with anyio.fail_after(1): # Timeout after 1 second await ev_tool_called.wait() - # cancel the task via task group + # Cancel the task via task group tg.cancel_scope.cancel() # Give cancellation time to process with anyio.fail_after(1): await ev_cancelled.wait() - # check server cancel notification received + # Check server cancel notification received with anyio.fail_after(1): await ev_cancel_notified.wait() From 8b7f1cda328c2119883ab5e1d8822bc1faf64154 Mon Sep 17 00:00:00 2001 From: David Savage Date: Mon, 5 May 2025 04:52:03 +0000 Subject: [PATCH 19/31] simplify test code --- tests/shared/test_session.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 15607d6e0..57a6a2848 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -9,7 +9,6 @@ from mcp.shared.exceptions import McpError from mcp.shared.memory import create_connected_server_and_client_session from mcp.types import ( - ClientRequest, EmptyResult, ) @@ -88,20 +87,10 @@ async def handle_list_tools() -> list[types.Tool]: return server - async def make_request(client_session): + async def make_request(client_session: ClientSession): nonlocal ev_cancelled try: - await client_session.send_request( - ClientRequest( - types.CallToolRequest( - method="tools/call", - params=types.CallToolRequestParams( - name="slow_tool", arguments={} - ), - ) - ), - types.CallToolResult, - ) + await client_session.call_tool("slow_tool") pytest.fail("Request should have been cancelled") except McpError as e: # Expected - request was cancelled From ac4b8226069f53dff177cfe2d80bdaa6120e1546 Mon Sep 17 00:00:00 2001 From: David Savage Date: Mon, 5 May 2025 05:29:27 +0000 Subject: [PATCH 20/31] add tests to assert cancellable=False behaves as expected --- src/mcp/client/session.py | 2 ++ src/mcp/shared/session.py | 59 ++++++++++++++++---------------- tests/shared/test_session.py | 66 +++++++++++++++++++++++++++++++++++- 3 files changed, 97 insertions(+), 30 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index fc12ceb2e..356330390 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -260,6 +260,7 @@ async def call_tool( name: str, arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, + cancellable: bool = True, ) -> types.CallToolResult: """Send a tools/call request.""" @@ -272,6 +273,7 @@ async def call_tool( ), types.CallToolResult, request_read_timeout_seconds=read_timeout_seconds, + cancellable=cancellable, ) async def list_prompts(self) -> types.ListPromptsResult: diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index f21f72132..0022c0a60 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -263,39 +263,40 @@ async def send_request( elif self._session_read_timeout_seconds is not None: timeout = self._session_read_timeout_seconds.total_seconds() - try: - with anyio.fail_after(timeout) as scope: - response_or_error = await response_stream_reader.receive() - - if cancellable and scope.cancel_called: - with anyio.CancelScope(shield=True): - notification = CancelledNotification( - method="notifications/cancelled", - params=CancelledNotificationParams( - requestId=request_id, reason="cancelled" - ), - ) - await self._send_notification_internal( - notification, request_id + with anyio.CancelScope(shield=not cancellable): + try: + with anyio.fail_after(timeout) as scope: + response_or_error = await response_stream_reader.receive() + + if scope.cancel_called: + with anyio.CancelScope(shield=True): + notification = CancelledNotification( + method="notifications/cancelled", + params=CancelledNotificationParams( + requestId=request_id, reason="cancelled" + ), + ) + await self._send_notification_internal( + notification, request_id + ) + + raise McpError( + ErrorData( + code=REQUEST_CANCELLED, message="Request cancelled" + ) ) - raise McpError( - ErrorData( - code=REQUEST_CANCELLED, message="Request cancelled" - ) + except TimeoutError: + raise McpError( + ErrorData( + code=httpx.codes.REQUEST_TIMEOUT, + message=( + f"Timed out while waiting for response to " + f"{request.__class__.__name__}. Waited " + f"{timeout} seconds." + ), ) - - except TimeoutError: - raise McpError( - ErrorData( - code=httpx.codes.REQUEST_TIMEOUT, - message=( - f"Timed out while waiting for response to " - f"{request.__class__.__name__}. Waited " - f"{timeout} seconds." - ), ) - ) if isinstance(response_or_error, JSONRPCError): raise McpError(response_or_error.error) diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 57a6a2848..80e0694c2 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -42,7 +42,6 @@ async def test_in_flight_requests_cleared_after_completion( @pytest.mark.anyio async def test_request_cancellation(): """Test that requests can be cancelled while in-flight.""" - # The tool is already registered in the fixture ev_tool_called = anyio.Event() ev_tool_cancelled = anyio.Event() @@ -121,3 +120,68 @@ async def make_request(client_session: ClientSession): # Give cancellation time to process on server with anyio.fail_after(1): await ev_tool_cancelled.wait() + +@pytest.mark.anyio +async def test_request_cancellation_uncancellable(): + """Test that asserts.""" + # The tool is already registered in the fixture + + ev_tool_called = anyio.Event() + ev_tool_commplete = anyio.Event() + ev_cancelled = anyio.Event() + + # Start the request in a separate task so we can cancel it + def make_server() -> Server: + server = Server(name="TestSessionServer") + + # Register the tool handler + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict | None) -> list: + nonlocal ev_tool_called, ev_tool_commplete + if name == "slow_tool": + ev_tool_called.set() + with anyio.CancelScope(): + with anyio.fail_after(10): # Long enough to ensure we can cancel + await ev_cancelled.wait() + ev_tool_commplete.set() + return [] + + raise ValueError(f"Unknown tool: {name}") + + # Register the tool so it shows up in list_tools + @server.list_tools() + async def handle_list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="slow_tool", + description="A slow tool that takes 10 seconds to complete", + inputSchema={}, + ) + ] + + return server + + async def make_request(client_session: ClientSession): + nonlocal ev_cancelled + try: + await client_session.call_tool("slow_tool", cancellable=False) + except McpError as e: + pytest.fail("Request should not have been cancelled") + + async with create_connected_server_and_client_session( + make_server() + ) as client_session: + async with anyio.create_task_group() as tg: + tg.start_soon(make_request, client_session) + + # Wait for the request to be in-flight + with anyio.fail_after(1): # Timeout after 1 second + await ev_tool_called.wait() + + # Cancel the task via task group + tg.cancel_scope.cancel() + ev_cancelled.set() + + # Check server completed regardless + with anyio.fail_after(1): + await ev_tool_commplete.wait() From d86b4a51934d2a5954a6116f06386fd5007959b4 Mon Sep 17 00:00:00 2001 From: David Savage Date: Mon, 5 May 2025 05:30:18 +0000 Subject: [PATCH 21/31] tidy up ruff check --- tests/shared/test_session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 80e0694c2..65e317703 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -165,7 +165,7 @@ async def make_request(client_session: ClientSession): nonlocal ev_cancelled try: await client_session.call_tool("slow_tool", cancellable=False) - except McpError as e: + except McpError: pytest.fail("Request should not have been cancelled") async with create_connected_server_and_client_session( From 6f4ae44cad46342e6c0154848b7a8b96aad2eeda Mon Sep 17 00:00:00 2001 From: David Savage Date: Mon, 5 May 2025 05:30:52 +0000 Subject: [PATCH 22/31] ruff format update --- tests/shared/test_session.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 65e317703..c4c3287d5 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -121,6 +121,7 @@ async def make_request(client_session: ClientSession): with anyio.fail_after(1): await ev_tool_cancelled.wait() + @pytest.mark.anyio async def test_request_cancellation_uncancellable(): """Test that asserts.""" @@ -141,7 +142,7 @@ async def handle_call_tool(name: str, arguments: dict | None) -> list: if name == "slow_tool": ev_tool_called.set() with anyio.CancelScope(): - with anyio.fail_after(10): # Long enough to ensure we can cancel + with anyio.fail_after(10): # Long enough to ensure we can cancel await ev_cancelled.wait() ev_tool_commplete.set() return [] From 11d2e5296c9b47af312a0abc595248a4d0aa1b1d Mon Sep 17 00:00:00 2001 From: David Savage Date: Mon, 5 May 2025 05:44:55 +0000 Subject: [PATCH 23/31] fixed test description --- tests/shared/test_session.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index c4c3287d5..1bb7b86f6 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -124,8 +124,8 @@ async def make_request(client_session: ClientSession): @pytest.mark.anyio async def test_request_cancellation_uncancellable(): - """Test that asserts.""" - # The tool is already registered in the fixture + """Test that asserts a call with cancellable=False is not cancelled on + server when cancel scope on client is set.""" ev_tool_called = anyio.Event() ev_tool_commplete = anyio.Event() From 235df35de77e039653af0e20cf48f029225ec38c Mon Sep 17 00:00:00 2001 From: David Savage Date: Mon, 5 May 2025 06:05:19 +0000 Subject: [PATCH 24/31] use defined types in decorator --- src/mcp/server/lowlevel/server.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 2e262cd3d..ffe1bf0c2 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -85,7 +85,7 @@ async def main(): from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder +from mcp.shared.session import RequestId, RequestResponder logger = logging.getLogger(__name__) @@ -427,7 +427,7 @@ async def handler(req: types.CallToolRequest): def progress_notification(self): def decorator( - func: Callable[[str | int, float, float | None], Awaitable[None]], + func: Callable[[types.ProgressToken, float, float | None], Awaitable[None]], ): logger.debug("Registering handler for ProgressNotification") @@ -443,7 +443,7 @@ async def handler(req: types.ProgressNotification): def cancel_notification(self): def decorator( - func: Callable[[str | int, str | None], Awaitable[None]], + func: Callable[[RequestId, str | None], Awaitable[None]], ): logger.debug("Registering handler for CancelledNotification") From f96aaa58dc4c7f8f894c7c300ee65e34cbadded5 Mon Sep 17 00:00:00 2001 From: David Savage Date: Mon, 5 May 2025 10:03:54 +0000 Subject: [PATCH 25/31] Add docstring info for cancellable and add TODO note for initialised which states that initialised is not cancellable as per https://modelcontextprotocol.io/specification/2025-03-26/basic/utilities/cancellation --- src/mcp/client/session.py | 5 +++++ src/mcp/shared/session.py | 12 ++++++++++++ 2 files changed, 17 insertions(+) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 356330390..8d38cd164 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -140,6 +140,11 @@ async def initialize(self) -> types.InitializeResult: ) ), types.InitializeResult, + # TODO should set a request_read_timeout_seconds as per + # guidance from BaseSession.send_request not obvious + # what subsequent process should be, refer the following + # specification for more details + # https://modelcontextprotocol.io/specification/2025-03-26/basic/utilities/cancellation cancellable=False, ) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 0022c0a60..ba510be9a 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -229,6 +229,18 @@ async def send_request( response contains an error. If a request read timeout is provided, it will take precedence over the session read timeout. + If cancellable is set to False then the request will wait + request_read_timeout_seconds to complete and ignore any attempt to + cancel via the anyio.CancelScope within which this method was called. + + If cancellable is set to True (default) if the anyio.CancelScope within + which this method was called is cancelled it will generate a + CancelationNotfication and send this to the server which should then abort + the task however this is not guaranteed. + + For further information on the CancelNotification flow refer to + https://modelcontextprotocol.io/specification/2025-03-26/basic/utilities/cancellation + Do not use this method to emit notifications! Use send_notification() instead. """ From bd73448b72897d1f18ca063e6cfd35b049c1216c Mon Sep 17 00:00:00 2001 From: David Savage Date: Mon, 5 May 2025 10:13:12 +0000 Subject: [PATCH 26/31] set read_timeout_seconds to avoid test blocking for ever in case something fails --- tests/shared/test_session.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 1bb7b86f6..c0f60da8d 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -1,4 +1,5 @@ from collections.abc import AsyncGenerator +from datetime import timedelta import anyio import pytest @@ -165,7 +166,11 @@ async def handle_list_tools() -> list[types.Tool]: async def make_request(client_session: ClientSession): nonlocal ev_cancelled try: - await client_session.call_tool("slow_tool", cancellable=False) + await client_session.call_tool( + "slow_tool", + cancellable=False, + read_timeout_seconds=timedelta(seconds=10), + ) except McpError: pytest.fail("Request should not have been cancelled") From 3817fe2f9d178ea1bea7f9354ff18b3315d805ba Mon Sep 17 00:00:00 2001 From: David Savage Date: Mon, 5 May 2025 10:34:18 +0000 Subject: [PATCH 27/31] remove unnecessary shielded cancel scope --- src/mcp/shared/session.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index ba510be9a..1221a9876 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -281,16 +281,15 @@ async def send_request( response_or_error = await response_stream_reader.receive() if scope.cancel_called: - with anyio.CancelScope(shield=True): - notification = CancelledNotification( - method="notifications/cancelled", - params=CancelledNotificationParams( - requestId=request_id, reason="cancelled" - ), - ) - await self._send_notification_internal( - notification, request_id - ) + notification = CancelledNotification( + method="notifications/cancelled", + params=CancelledNotificationParams( + requestId=request_id, reason="cancelled" + ), + ) + await self._send_notification_internal( + notification, request_id + ) raise McpError( ErrorData( From 2fd27a28e883f27ee5c6a17a5c547cc19790c3e2 Mon Sep 17 00:00:00 2001 From: David Savage Date: Mon, 5 May 2025 10:35:00 +0000 Subject: [PATCH 28/31] whitespace --- src/mcp/shared/session.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 1221a9876..992ff086e 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -290,7 +290,6 @@ async def send_request( await self._send_notification_internal( notification, request_id ) - raise McpError( ErrorData( code=REQUEST_CANCELLED, message="Request cancelled" From 2e86d32de05334e0f2cff4591d929631ca3f30db Mon Sep 17 00:00:00 2001 From: David Savage Date: Tue, 6 May 2025 06:15:37 +0000 Subject: [PATCH 29/31] clarified doc string comment --- src/mcp/shared/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 992ff086e..f6a70fc9f 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -236,7 +236,7 @@ async def send_request( If cancellable is set to True (default) if the anyio.CancelScope within which this method was called is cancelled it will generate a CancelationNotfication and send this to the server which should then abort - the task however this is not guaranteed. + the task however the server is is not guaranteed to honour this request. For further information on the CancelNotification flow refer to https://modelcontextprotocol.io/specification/2025-03-26/basic/utilities/cancellation From 1039b993ee8d11474f0d1a98668694f3cca91fc2 Mon Sep 17 00:00:00 2001 From: David Savage Date: Tue, 6 May 2025 06:21:17 +0000 Subject: [PATCH 30/31] fixed doc string and added comment on internal notification method --- src/mcp/shared/session.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index f6a70fc9f..3e203ca31 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -323,17 +323,20 @@ async def send_notification( notification: SendNotificationT, related_request_id: RequestId | None = None, ) -> None: + """ + Emits a notification, which is a one-way message that does not expect + a response. + """ await self._send_notification_internal(notification, related_request_id) + # this method is required as SendNotificationT type checking prevents + # internal use for sending cancelation - typechecking sorcery may be + # required async def _send_notification_internal( self, notification: SendNotificationInternalT, related_request_id: RequestId | None = None, ) -> None: - """ - Emits a notification, which is a one-way message that does not expect - a response. - """ # Some transport implementations may need to set the related_request_id # to attribute to the notifications to the request that triggered them. jsonrpc_notification = JSONRPCNotification( From 39f7739a8fbb2f05d2d0a0d0c8d2f8cf71f218b7 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sat, 17 May 2025 17:41:42 +0000 Subject: [PATCH 31/31] ignore type warning for internal use of send_notification --- src/mcp/shared/session.py | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index a022d9f90..6a05f74a0 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -36,12 +36,6 @@ SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest) SendResultT = TypeVar("SendResultT", ClientResult, ServerResult) SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification) -SendNotificationInternalT = TypeVar( - "SendNotificationInternalT", - CancelledNotification, - ClientNotification, - ServerNotification, -) ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest) ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) ReceiveNotificationT = TypeVar( @@ -308,7 +302,7 @@ async def send_request( requestId=request_id, reason="cancelled" ), ) - await self._send_notification_internal( + await self._send_notification( # type: ignore notification, request_id ) raise McpError( @@ -349,16 +343,6 @@ async def send_notification( Emits a notification, which is a one-way message that does not expect a response. """ - await self._send_notification_internal(notification, related_request_id) - - # this method is required as SendNotificationT type checking prevents - # internal use for sending cancelation - typechecking sorcery may be - # required - async def _send_notification_internal( - self, - notification: SendNotificationInternalT, - related_request_id: RequestId | None = None, - ) -> None: # Some transport implementations may need to set the related_request_id # to attribute to the notifications to the request that triggered them. jsonrpc_notification = JSONRPCNotification(