From c3ef7deb613e9df8ab7316bd0b5ed39120e61bf5 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sat, 24 May 2025 15:35:10 +0000 Subject: [PATCH 01/31] tinkering with resource progress --- src/mcp/client/session.py | 17 +++++++----- src/mcp/server/fastmcp/server.py | 11 ++++++-- src/mcp/server/session.py | 6 ++-- src/mcp/shared/session.py | 37 +++++++++++++++++++++---- src/mcp/types.py | 10 +++++-- tests/issues/test_176_progress_token.py | 6 ++-- 6 files changed, 65 insertions(+), 22 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index fe90716e2..719dc2808 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,14 +1,15 @@ from datetime import timedelta -from typing import Any, Protocol +from typing import Annotated, Any, Protocol import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import AnyUrl, TypeAdapter +from pydantic import TypeAdapter +from pydantic.networks import AnyUrl, UrlConstraints import mcp.types as types from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage -from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder +from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder, ResourceProgressFnT from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") @@ -173,6 +174,7 @@ async def send_progress_notification( progress: float, total: float | None = None, message: str | None = None, + # TODO decide whether clients can send resource progress too? ) -> None: """Send a progress notification.""" await self.send_notification( @@ -203,6 +205,7 @@ async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResul async def list_resources( self, cursor: str | None = None + # TODO suggest in progress resources should be excluded by default? possibly add an optional flag to include? 
) -> types.ListResourcesResult: """Send a resources/list request.""" return await self.send_request( @@ -233,7 +236,7 @@ async def list_resource_templates( types.ListResourceTemplatesResult, ) - async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: + async def read_resource(self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]) -> types.ReadResourceResult: """Send a resources/read request.""" return await self.send_request( types.ClientRequest( @@ -245,7 +248,7 @@ async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: types.ReadResourceResult, ) - async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + async def subscribe_resource(self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]) -> types.EmptyResult: """Send a resources/subscribe request.""" return await self.send_request( types.ClientRequest( @@ -257,7 +260,7 @@ async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: types.EmptyResult, ) - async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + async def unsubscribe_resource(self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]) -> types.EmptyResult: """Send a resources/unsubscribe request.""" return await self.send_request( types.ClientRequest( @@ -274,7 +277,7 @@ async def call_tool( name: str, arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, - progress_callback: ProgressFnT | None = None, + progress_callback: ProgressFnT | ResourceProgressFnT | None = None, ) -> types.CallToolResult: """Send a tools/call request with optional progress callback support.""" diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 3282baae6..1b4ab7428 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -10,12 +10,12 @@ asynccontextmanager, ) from itertools import chain -from typing import Any, Generic, Literal +from typing import Annotated, Any, Generic, 
Literal import anyio import pydantic_core from pydantic import BaseModel, Field -from pydantic.networks import AnyUrl +from pydantic.networks import AnyUrl, UrlConstraints from pydantic_settings import BaseSettings, SettingsConfigDict from starlette.applications import Starlette from starlette.middleware import Middleware @@ -956,7 +956,11 @@ def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]: return self._request_context async def report_progress( - self, progress: float, total: float | None = None, message: str | None = None + self, + progress: float, + total: float | None = None, + message: str | None = None, + resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] | None = None, ) -> None: """Report progress for the current operation. @@ -979,6 +983,7 @@ async def report_progress( progress=progress, total=total, message=message, + resource_uri=resource_uri, ) async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]: diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index ef5c5a3c3..37115a0c6 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -38,12 +38,12 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: """ from enum import Enum -from typing import Any, TypeVar +from typing import Annotated, Any, TypeVar import anyio import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import AnyUrl +from pydantic.networks import AnyUrl, UrlConstraints import mcp.types as types from mcp.server.models import InitializationOptions @@ -288,6 +288,7 @@ async def send_progress_notification( total: float | None = None, message: str | None = None, related_request_id: str | None = None, + resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] | None = None, ) -> None: """Send a progress notification.""" await self.send_notification( @@ -299,6 +300,7 @@ async def 
send_progress_notification( progress=progress, total=total, message=message, + resource_uri=resource_uri, ), ) ), diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 90b4eb27c..8f0d2fe2d 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -3,12 +3,14 @@ from contextlib import AsyncExitStack from datetime import timedelta from types import TracebackType -from typing import Any, Generic, Protocol, TypeVar +from typing import Annotated, Any, Generic, Protocol, TypeVar, runtime_checkable import anyio import httpx +import inspect from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import BaseModel +from pydantic.networks import AnyUrl, UrlConstraints from typing_extensions import Self from mcp.shared.exceptions import McpError @@ -43,6 +45,7 @@ RequestId = str | int +@runtime_checkable class ProgressFnT(Protocol): """Protocol for progress notification callbacks.""" @@ -50,6 +53,13 @@ async def __call__( self, progress: float, total: float | None, message: str | None ) -> None: ... +@runtime_checkable +class ResourceProgressFnT(Protocol): + """Protocol for progress notification callbacks with resources.""" + + async def __call__( + self, progress: float, total: float | None, message: str | None, resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] | None = None + ) -> None: ... class RequestResponder(Generic[ReceiveRequestT, SendResultT]): """Handles responding to MCP requests and manages request lifecycle. 
@@ -178,7 +188,8 @@ class BaseSession( ] _request_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] - _progress_callbacks: dict[RequestId, ProgressFnT] + _progress_callbacks: dict[RequestId, ProgressFnT ] + _resource_progress_callbacks: dict[RequestId, ResourceProgressFnT] def __init__( self, @@ -198,6 +209,7 @@ def __init__( self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} self._progress_callbacks = {} + self._resource_progress_callbacks = {} self._exit_stack = AsyncExitStack() async def __aenter__(self) -> Self: @@ -225,7 +237,7 @@ async def send_request( result_type: type[ReceiveResultT], request_read_timeout_seconds: timedelta | None = None, metadata: MessageMetadata = None, - progress_callback: ProgressFnT | None = None, + progress_callback: ProgressFnT | ResourceProgressFnT | None = None, ) -> ReceiveResultT: """ Sends a request and wait for a response. Raises an McpError if the @@ -252,8 +264,14 @@ async def send_request( if "_meta" not in request_data["params"]: request_data["params"]["_meta"] = {} request_data["params"]["_meta"]["progressToken"] = request_id - # Store the callback for this request - self._progress_callbacks[request_id] = progress_callback + # note this is required to ensure backwards compatibility for previous clients + signature = inspect.signature(progress_callback.__call__) + if 'resource_uri' in signature.parameters: + # Store the callback for this request + self._resource_progress_callbacks[request_id] = progress_callback # type: ignore + else: + # Store the callback for this request + self._progress_callbacks[request_id] = progress_callback try: jsonrpc_request = JSONRPCRequest( @@ -397,6 +415,15 @@ async def _receive_loop(self) -> None: notification.root.params.total, notification.root.params.message, ) + elif progress_token in self._resource_progress_callbacks: + callback = self._resource_progress_callbacks[progress_token] + await callback( + 
notification.root.params.progress, + notification.root.params.total, + notification.root.params.message, + notification.root.params.resource_uri, + ) + await self._received_notification(notification) await self._handle_incoming(notification) except Exception as e: diff --git a/src/mcp/types.py b/src/mcp/types.py index 465fc6ee6..ac83deda3 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -346,12 +346,18 @@ class ProgressNotificationParams(NotificationParams): total is unknown. """ total: float | None = None + """Total number of items to process (or total progress required), if known.""" + message: str | None = None """ Message related to progress. This should provide relevant human readable progress information. """ - message: str | None = None - """Total number of items to process (or total progress required), if known.""" + resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] | None = None + """ + An optional reference to an ephemeral resource associated with this progress, servers + may delete these at their descretion, but are encouraged to make them available for + a reasonable time period to allow clients to retrieve and cache the resources locally + """ model_config = ConfigDict(extra="allow") diff --git a/tests/issues/test_176_progress_token.py b/tests/issues/test_176_progress_token.py index 4ad22f294..6ba4770aa 100644 --- a/tests/issues/test_176_progress_token.py +++ b/tests/issues/test_176_progress_token.py @@ -39,11 +39,11 @@ async def test_progress_token_zero_first_call(): mock_session.send_progress_notification.call_count == 3 ), "All progress notifications should be sent" mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=0.0, total=10.0, message=None + progress_token=0, progress=0.0, total=10.0, message=None, resource_uri=None ) mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=5.0, total=10.0, message=None + progress_token=0, progress=5.0, total=10.0, 
message=None, resource_uri=None ) mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=10.0, total=10.0, message=None + progress_token=0, progress=10.0, total=10.0, message=None, resource_uri=None ) From b72bbc42dee94406a518b0c858bc4cd2d8927e9f Mon Sep 17 00:00:00 2001 From: David Savage Date: Sat, 24 May 2025 15:44:36 +0000 Subject: [PATCH 02/31] ruff check and format fixes --- src/mcp/client/session.py | 25 +++++++++++++++++++------ src/mcp/client/session_group.py | 1 - src/mcp/server/fastmcp/server.py | 3 ++- src/mcp/server/session.py | 3 ++- src/mcp/shared/session.py | 28 ++++++++++++++++++---------- src/mcp/types.py | 7 ++++--- 6 files changed, 45 insertions(+), 22 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 719dc2808..e1db957e6 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -9,7 +9,12 @@ import mcp.types as types from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage -from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder, ResourceProgressFnT +from mcp.shared.session import ( + BaseSession, + ProgressFnT, + RequestResponder, + ResourceProgressFnT, +) from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") @@ -204,8 +209,10 @@ async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResul ) async def list_resources( - self, cursor: str | None = None - # TODO suggest in progress resources should be excluded by default? possibly add an optional flag to include? + self, + cursor: str | None = None, + # TODO suggest in progress resources should be excluded by default? + # possibly add an optional flag to include? 
) -> types.ListResourcesResult: """Send a resources/list request.""" return await self.send_request( @@ -236,7 +243,9 @@ async def list_resource_templates( types.ListResourceTemplatesResult, ) - async def read_resource(self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]) -> types.ReadResourceResult: + async def read_resource( + self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + ) -> types.ReadResourceResult: """Send a resources/read request.""" return await self.send_request( types.ClientRequest( @@ -248,7 +257,9 @@ async def read_resource(self, uri: Annotated[AnyUrl, UrlConstraints(host_require types.ReadResourceResult, ) - async def subscribe_resource(self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]) -> types.EmptyResult: + async def subscribe_resource( + self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + ) -> types.EmptyResult: """Send a resources/subscribe request.""" return await self.send_request( types.ClientRequest( @@ -260,7 +271,9 @@ async def subscribe_resource(self, uri: Annotated[AnyUrl, UrlConstraints(host_re types.EmptyResult, ) - async def unsubscribe_resource(self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]) -> types.EmptyResult: + async def unsubscribe_resource( + self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + ) -> types.EmptyResult: """Send a resources/unsubscribe request.""" return await self.send_request( types.ClientRequest( diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index a430533b3..a77dc7a1e 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -154,7 +154,6 @@ async def __aexit__( for exit_stack in self._session_exit_stacks.values(): tg.start_soon(exit_stack.aclose) - @property def sessions(self) -> list[mcp.ClientSession]: """Returns the list of sessions being managed.""" diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 
1b4ab7428..5489f6952 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -960,7 +960,8 @@ async def report_progress( progress: float, total: float | None = None, message: str | None = None, - resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] | None = None, + resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + | None = None, ) -> None: """Report progress for the current operation. diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 37115a0c6..147bf0326 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -288,7 +288,8 @@ async def send_progress_notification( total: float | None = None, message: str | None = None, related_request_id: str | None = None, - resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] | None = None, + resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + | None = None, ) -> None: """Send a progress notification.""" await self.send_notification( diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 8f0d2fe2d..94ae3651d 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,3 +1,4 @@ +import inspect import logging from collections.abc import Callable from contextlib import AsyncExitStack @@ -7,7 +8,6 @@ import anyio import httpx -import inspect from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import BaseModel from pydantic.networks import AnyUrl, UrlConstraints @@ -53,14 +53,21 @@ async def __call__( self, progress: float, total: float | None, message: str | None ) -> None: ... 
+ @runtime_checkable class ResourceProgressFnT(Protocol): """Protocol for progress notification callbacks with resources.""" async def __call__( - self, progress: float, total: float | None, message: str | None, resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] | None = None + self, + progress: float, + total: float | None, + message: str | None, + resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + | None = None, ) -> None: ... + class RequestResponder(Generic[ReceiveRequestT, SendResultT]): """Handles responding to MCP requests and manages request lifecycle. @@ -188,8 +195,8 @@ class BaseSession( ] _request_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] - _progress_callbacks: dict[RequestId, ProgressFnT ] - _resource_progress_callbacks: dict[RequestId, ResourceProgressFnT] + _progress_callbacks: dict[RequestId, ProgressFnT] + _resource_callbacks: dict[RequestId, ResourceProgressFnT] def __init__( self, @@ -209,7 +216,7 @@ def __init__( self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} self._progress_callbacks = {} - self._resource_progress_callbacks = {} + self._resource_callbacks = {} self._exit_stack = AsyncExitStack() async def __aenter__(self) -> Self: @@ -264,11 +271,12 @@ async def send_request( if "_meta" not in request_data["params"]: request_data["params"]["_meta"] = {} request_data["params"]["_meta"]["progressToken"] = request_id - # note this is required to ensure backwards compatibility for previous clients + # note this is required to ensure backwards compatibility + # for previous clients signature = inspect.signature(progress_callback.__call__) - if 'resource_uri' in signature.parameters: + if "resource_uri" in signature.parameters: # Store the callback for this request - self._resource_progress_callbacks[request_id] = progress_callback # type: ignore + self._resource_callbacks[request_id] = progress_callback # type: ignore else: # Store the 
callback for this request self._progress_callbacks[request_id] = progress_callback @@ -415,8 +423,8 @@ async def _receive_loop(self) -> None: notification.root.params.total, notification.root.params.message, ) - elif progress_token in self._resource_progress_callbacks: - callback = self._resource_progress_callbacks[progress_token] + elif progress_token in self._resource_callbacks: + callback = self._resource_callbacks[progress_token] await callback( notification.root.params.progress, notification.root.params.total, diff --git a/src/mcp/types.py b/src/mcp/types.py index ac83deda3..c7a6dfff4 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -354,9 +354,10 @@ class ProgressNotificationParams(NotificationParams): """ resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] | None = None """ - An optional reference to an ephemeral resource associated with this progress, servers - may delete these at their descretion, but are encouraged to make them available for - a reasonable time period to allow clients to retrieve and cache the resources locally + An optional reference to an ephemeral resource associated with this + progress, servers may delete these at their descretion, but are encouraged + to make them available for a reasonable time period to allow clients to + retrieve and cache the resources locally """ model_config = ConfigDict(extra="allow") From 202e92273bd70ed12e266c425165be6058b06b4a Mon Sep 17 00:00:00 2001 From: David Savage Date: Sat, 24 May 2025 17:52:16 +0000 Subject: [PATCH 03/31] use len of parameters to check progress variant, can't rely on clients naming params consistently --- src/mcp/shared/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 94ae3651d..f31657c14 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -274,7 +274,7 @@ async def send_request( # note this is required to ensure backwards compatibility # for 
previous clients signature = inspect.signature(progress_callback.__call__) - if "resource_uri" in signature.parameters: + if len(signature.parameters) == 3: # Store the callback for this request self._resource_callbacks[request_id] = progress_callback # type: ignore else: From d634e6a4228a175a984d08301068dba894b6feb2 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 25 May 2025 09:05:32 +0000 Subject: [PATCH 04/31] clarify TODO message --- src/mcp/client/session.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index e1db957e6..ce124bd1b 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -179,7 +179,9 @@ async def send_progress_notification( progress: float, total: float | None = None, message: str | None = None, - # TODO decide whether clients can send resource progress too? + # TODO check whether MCP spec allows clients to create resources + # for server and therefore whether resource notifications + # would be required here too ) -> None: """Send a progress notification.""" await self.send_notification( From 22bd2eb9295897a03b936b578463e800e889d667 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sat, 31 May 2025 18:45:23 +0000 Subject: [PATCH 05/31] implement types as per 617 proposal --- src/mcp/server/session.py | 2 +- src/mcp/types.py | 66 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 65 insertions(+), 3 deletions(-) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 147bf0326..9c839326c 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -301,7 +301,7 @@ async def send_progress_notification( progress=progress, total=total, message=message, - resource_uri=resource_uri, + resourceUri=resource_uri, ), ) ), diff --git a/src/mcp/types.py b/src/mcp/types.py index 485cea2de..19c46cd9f 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -31,6 +31,7 @@ LATEST_PROTOCOL_VERSION = "2025-03-26" +AsyncToken = 
str | int ProgressToken = str | int Cursor = str Role = Literal["user", "assistant"] @@ -260,6 +261,12 @@ class ToolsCapability(BaseModel): """Whether this server supports notifications for changes to the tool list.""" model_config = ConfigDict(extra="allow") +class AsyncCapability(BaseModel): + """Capability for async operations.""" + + maxKeepAliveTime: int | None = None + """The maximum keep alive time in seconds for async requests.""" + model_config = ConfigDict(extra="allow") class LoggingCapability(BaseModel): """Capability for logging operations.""" @@ -279,6 +286,9 @@ class ServerCapabilities(BaseModel): resources: ResourcesCapability | None = None """Present if the server offers any resources to read.""" tools: ToolsCapability | None = None + """Present if the server offers async tool calling support.""" + async_: AsyncCapability | None = Field(alias='async', default=None) + """Present if the server offers any tools to call.""" model_config = ConfigDict(extra="allow") @@ -356,7 +366,7 @@ class ProgressNotificationParams(NotificationParams): Message related to progress. This should provide relevant human readable progress information. """ - resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] | None = None + resourceUri: Annotated[AnyUrl, UrlConstraints(host_required=False)] | None = None """ An optional reference to an ephemeral resource associated with this progress, servers may delete these at their descretion, but are encouraged @@ -767,6 +777,12 @@ class ToolAnnotations(BaseModel): of a memory tool is not. Default: true """ + preferAsync: bool | None = None + """ + If true, should ideally be called using the async protocol + as requests are expected to be long running. 
+ Default: false + """ model_config = ConfigDict(extra="allow") @@ -797,19 +813,61 @@ class CallToolRequestParams(RequestParams): arguments: dict[str, Any] | None = None model_config = ConfigDict(extra="allow") - class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]): """Used by the client to invoke a tool provided by the server.""" method: Literal["tools/call"] params: CallToolRequestParams +class CallToolAsyncRequestParams(CallToolRequestParams): + """Parameters for calling a tool asynchronously.""" + + keepAlive: int | None = None + model_config = ConfigDict(extra="allow") + +class CallToolAsyncRequest(Request[CallToolAsyncRequestParams, Literal["tools/async/call"]]): + """Used by the client to invoke a tool provided by the server asynchronously.""" + method: Literal["tools/async/call"] + params: CallToolAsyncRequestParams + +class JoinCallToolRequestParams(RequestParams): + """Parameters for joining an asynchronous tool call.""" + token: AsyncToken + keepAlive: int | None = None + model_config = ConfigDict(extra="allow") + +class JoinCallToolAsyncRequest(Request[JoinCallToolRequestParams, Literal["tools/async/join"]]): + """Used by the client to join an tool call executing on the server asynchronously.""" + method: Literal["tools/async/join"] + params: JoinCallToolRequestParams + +class CancelToolAsyncNotificationParams(NotificationParams): + token: AsyncToken + +class CancelToolAsyncNotification(Notification[CancelToolAsyncNotificationParams, Literal["tools/async/cancel"]]): + method: Literal["tools/async/cancel"] + params: CancelToolAsyncNotificationParams + +class GetToolAsyncResultRequestParams(RequestParams): + token: AsyncToken + +class GetToolAsyncResultRequest(Request[GetToolAsyncResultRequestParams, Literal["tools/async/get"]]): + method: Literal["tools/async/get"] + params: GetToolAsyncResultRequestParams class CallToolResult(Result): """The server's response to a tool call.""" content: list[TextContent | ImageContent | 
EmbeddedResource] isError: bool = False + isPending: bool = False + +class CallToolAsyncResult(Result): + """The servers response to an async tool call""" + token: AsyncToken | None = None + recieved: int | None = None + keepAlive: int | None = None + accepted: bool class ToolListChangedNotification( @@ -1141,6 +1199,8 @@ class ClientRequest( | SubscribeRequest | UnsubscribeRequest | CallToolRequest + | CallToolAsyncRequest + | JoinCallToolAsyncRequest | ListToolsRequest ] ): @@ -1152,6 +1212,7 @@ class ClientNotification( CancelledNotification | ProgressNotification | InitializedNotification + | CancelToolAsyncNotification | RootsListChangedNotification ] ): @@ -1191,6 +1252,7 @@ class ServerResult( | ListResourceTemplatesResult | ReadResourceResult | CallToolResult + | CallToolAsyncResult | ListToolsResult ] ): From 9d3dfb292869319bdb78257a33ba0ad0a574fe19 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sat, 31 May 2025 18:48:33 +0000 Subject: [PATCH 06/31] ruff format/check fixes --- src/mcp/types.py | 38 ++++++++++++++++++++++++++++++++------ 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/src/mcp/types.py b/src/mcp/types.py index 19c46cd9f..ea1f1fab2 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -261,6 +261,7 @@ class ToolsCapability(BaseModel): """Whether this server supports notifications for changes to the tool list.""" model_config = ConfigDict(extra="allow") + class AsyncCapability(BaseModel): """Capability for async operations.""" @@ -268,6 +269,7 @@ class AsyncCapability(BaseModel): """The maximum keep alive time in seconds for async requests.""" model_config = ConfigDict(extra="allow") + class LoggingCapability(BaseModel): """Capability for logging operations.""" @@ -287,7 +289,7 @@ class ServerCapabilities(BaseModel): """Present if the server offers any resources to read.""" tools: ToolsCapability | None = None """Present if the server offers async tool calling support.""" - async_: AsyncCapability | None = 
Field(alias='async', default=None) + async_: AsyncCapability | None = Field(alias="async", default=None) """Present if the server offers any tools to call.""" model_config = ConfigDict(extra="allow") @@ -813,48 +815,70 @@ class CallToolRequestParams(RequestParams): arguments: dict[str, Any] | None = None model_config = ConfigDict(extra="allow") + class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]): """Used by the client to invoke a tool provided by the server.""" method: Literal["tools/call"] params: CallToolRequestParams + class CallToolAsyncRequestParams(CallToolRequestParams): """Parameters for calling a tool asynchronously.""" keepAlive: int | None = None model_config = ConfigDict(extra="allow") -class CallToolAsyncRequest(Request[CallToolAsyncRequestParams, Literal["tools/async/call"]]): + +class CallToolAsyncRequest( + Request[CallToolAsyncRequestParams, Literal["tools/async/call"]] +): """Used by the client to invoke a tool provided by the server asynchronously.""" + method: Literal["tools/async/call"] params: CallToolAsyncRequestParams + class JoinCallToolRequestParams(RequestParams): """Parameters for joining an asynchronous tool call.""" + token: AsyncToken keepAlive: int | None = None model_config = ConfigDict(extra="allow") -class JoinCallToolAsyncRequest(Request[JoinCallToolRequestParams, Literal["tools/async/join"]]): - """Used by the client to join an tool call executing on the server asynchronously.""" + +class JoinCallToolAsyncRequest( + Request[JoinCallToolRequestParams, Literal["tools/async/join"]] +): + """Used by the client to join an tool call executing on the server + asynchronously.""" + method: Literal["tools/async/join"] params: JoinCallToolRequestParams + class CancelToolAsyncNotificationParams(NotificationParams): token: AsyncToken -class CancelToolAsyncNotification(Notification[CancelToolAsyncNotificationParams, Literal["tools/async/cancel"]]): + +class CancelToolAsyncNotification( + 
Notification[CancelToolAsyncNotificationParams, Literal["tools/async/cancel"]] +): method: Literal["tools/async/cancel"] params: CancelToolAsyncNotificationParams + class GetToolAsyncResultRequestParams(RequestParams): token: AsyncToken -class GetToolAsyncResultRequest(Request[GetToolAsyncResultRequestParams, Literal["tools/async/get"]]): + +class GetToolAsyncResultRequest( + Request[GetToolAsyncResultRequestParams, Literal["tools/async/get"]] +): method: Literal["tools/async/get"] params: GetToolAsyncResultRequestParams + class CallToolResult(Result): """The server's response to a tool call.""" @@ -862,8 +886,10 @@ class CallToolResult(Result): isError: bool = False isPending: bool = False + class CallToolAsyncResult(Result): """The servers response to an async tool call""" + token: AsyncToken | None = None recieved: int | None = None keepAlive: int | None = None From ffdb9501d9485af70bb63f6a439d4b89109c1377 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sat, 31 May 2025 18:49:23 +0000 Subject: [PATCH 07/31] fixed updated field name --- src/mcp/shared/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index c57c64237..5a321b017 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -433,7 +433,7 @@ async def _receive_loop(self) -> None: notification.root.params.progress, notification.root.params.total, notification.root.params.message, - notification.root.params.resource_uri, + notification.root.params.resourceUri, ) await self._received_notification(notification) From a2c0ade414d75e7248abb5e9021da1986dda8342 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sat, 31 May 2025 18:52:27 +0000 Subject: [PATCH 08/31] ignore type error --- src/mcp/server/fastmcp/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index b182decae..79c9b3a1d 100644 --- 
a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -147,7 +147,7 @@ def __init__( tools: list[Tool] | None = None, **settings: Any, ): - self.settings = Settings(**settings) + self.settings = Settings(**settings) # type: ignore self._mcp_server = MCPServer( name=name or "FastMCP", From 9b1fd8193403ffb459ac5bb54b79bab48ae1c376 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 1 Jun 2025 06:15:47 +0000 Subject: [PATCH 09/31] work in progress towards async protocol - simple call and get work --- pyproject.toml | 1 + src/mcp/server/fastmcp/server.py | 2 +- src/mcp/server/lowlevel/result_cache.py | 122 ++++++++++++++++++++++++ src/mcp/server/lowlevel/server.py | 29 ++++++ src/mcp/types.py | 3 + tests/shared/test_session.py | 71 ++++++++++++++ uv.lock | 11 +++ 7 files changed, 238 insertions(+), 1 deletion(-) create mode 100644 src/mcp/server/lowlevel/result_cache.py diff --git a/pyproject.toml b/pyproject.toml index 0a11a3b15..b08cd0ddc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "sse-starlette>=1.6.1", "pydantic-settings>=2.5.2", "uvicorn>=0.23.1; sys_platform != 'emscripten'", + "cachetools==6", ] [project.optional-dependencies] diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 79c9b3a1d..746baf290 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -147,7 +147,7 @@ def __init__( tools: list[Tool] | None = None, **settings: Any, ): - self.settings = Settings(**settings) # type: ignore + self.settings = Settings(**settings) # type: ignore self._mcp_server = MCPServer( name=name or "FastMCP", diff --git a/src/mcp/server/lowlevel/result_cache.py b/src/mcp/server/lowlevel/result_cache.py new file mode 100644 index 000000000..e6d2e8d28 --- /dev/null +++ b/src/mcp/server/lowlevel/result_cache.py @@ -0,0 +1,122 @@ +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field +from time import time +from 
typing import Any +from uuid import uuid4 + +from anyio import Lock, create_task_group, move_on_after +from anyio.abc import TaskGroup +from cachetools import TTLCache + +from mcp import types +from mcp.shared.context import BaseSession, RequestContext, SessionT + + +@dataclass +class InProgress: + token: str + task_group: TaskGroup | None = None + sessions: list[BaseSession[Any, Any, Any, Any, Any]] = field( + default_factory=lambda: [] + ) + + +class ResultCache: + _in_progress: dict[types.AsyncToken, InProgress] + + def __init__(self, max_size: int, max_keep_alive: int): + self._max_size = max_size + self._max_keep_alive = max_keep_alive + self._result_cache = TTLCache[types.AsyncToken, types.CallToolResult]( + self._max_size, self._max_keep_alive + ) + self._in_progress = {} + self._lock = Lock() + + async def add_call( + self, + call: Callable[[types.CallToolRequest], Awaitable[types.ServerResult]], + req: types.CallToolAsyncRequest, + ctx: RequestContext[SessionT, Any, Any], + ) -> types.CallToolAsyncResult: + in_progress = await self._new_in_progress() + timeout = min( + req.params.keepAlive or self._max_keep_alive, self._max_keep_alive + ) + + async def call_tool(): + with move_on_after(timeout) as scope: + result = await call( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams( + name=req.params.name, arguments=req.params.arguments + ), + ) + ) + if not scope.cancel_called: + async with self._lock: + assert type(result.root) is types.CallToolResult + self._result_cache[in_progress.token] = result.root + + async with create_task_group() as tg: + tg.start_soon(call_tool) + in_progress.task_group = tg + in_progress.sessions.append(ctx.session) + result = types.CallToolAsyncResult( + token=in_progress.token, + recieved=round(time()), + keepAlive=timeout, + accepted=True, + ) + return result + + async def join_call( + self, + req: types.JoinCallToolAsyncRequest, + ctx: RequestContext[SessionT, Any, Any], + ) -> 
types.CallToolAsyncResult: + async with self._lock: + in_progress = self._in_progress.get(req.params.token) + if in_progress is None: + # TODO consider creating new token to allow client + # to get message describing why it wasn't accepted + return types.CallToolAsyncResult(accepted=False) + else: + in_progress.sessions.append(ctx.session) + return types.CallToolAsyncResult(accepted=True) + + return + + async def cancel(self, notification: types.CancelToolAsyncNotification) -> None: + async with self._lock: + in_progress = self._in_progress.get(notification.params.token) + if in_progress is not None and in_progress.task_group is not None: + in_progress.task_group.cancel_scope.cancel() + del self._in_progress[notification.params.token] + + async def get_result(self, req: types.GetToolAsyncResultRequest): + async with self._lock: + in_progress = self._in_progress.get(req.params.token) + if in_progress is None: + return types.CallToolResult( + content=[ + types.TextContent(type="text", text="Unknown progress token") + ], + isError=True, + ) + else: + result = self._result_cache.get(in_progress.token) + if result is None: + return types.CallToolResult(content=[], isPending=True) + else: + return result + + async def _new_in_progress(self) -> InProgress: + async with self._lock: + while True: + token = str(uuid4()) + if token not in self._in_progress: + new_in_progress = InProgress(token) + self._in_progress[token] = new_in_progress + return new_in_progress diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index b98e3dd1a..03610ab23 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -81,6 +81,7 @@ async def main(): import mcp.types as types from mcp.server.lowlevel.helper_types import ReadResourceContents +from mcp.server.lowlevel.result_cache import ResultCache from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.server.stdio import stdio_server 
as stdio_server @@ -135,6 +136,8 @@ def __init__( [Server[LifespanResultT, RequestT]], AbstractAsyncContextManager[LifespanResultT], ] = lifespan, + max_cache_size: int = 1000, + max_cache_ttl: int = 60, ): self.name = name self.version = version @@ -145,6 +148,7 @@ def __init__( ] = { types.PingRequest: _ping_handler, } + self.result_cache = ResultCache(max_cache_size, max_cache_ttl) self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} self.notification_options = NotificationOptions() logger.debug(f"Initializing server '{name}'") @@ -426,7 +430,32 @@ async def handler(req: types.CallToolRequest): ) ) + async def async_call_handler(req: types.CallToolAsyncRequest): + ctx = request_ctx.get() + result = await self.result_cache.add_call(handler, req, ctx) + return types.ServerResult(result) + + async def async_join_handler(req: types.JoinCallToolAsyncRequest): + ctx = request_ctx.get() + result = await self.result_cache.join_call(req, ctx) + return types.ServerResult(result) + + async def async_cancel_handler(req: types.CancelToolAsyncNotification): + await self.result_cache.cancel(req) + + async def async_result_handler(req: types.GetToolAsyncResultRequest): + result = await self.result_cache.get_result(req) + return types.ServerResult(result) + self.request_handlers[types.CallToolRequest] = handler + self.request_handlers[types.CallToolAsyncRequest] = async_call_handler + self.request_handlers[types.JoinCallToolAsyncRequest] = async_join_handler + self.request_handlers[types.GetToolAsyncResultRequest] = ( + async_result_handler + ) + self.notification_handlers[types.CancelToolAsyncNotification] = ( + async_cancel_handler + ) return func return decorator diff --git a/src/mcp/types.py b/src/mcp/types.py index ea1f1fab2..d5370f035 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -799,6 +799,8 @@ class Tool(BaseModel): """A JSON Schema object defining the expected parameters for the tool.""" annotations: ToolAnnotations | None = None 
"""Optional additional tool information.""" + preferAsync: bool | None = None + """Optional flag to suggest to client async calls should be preferred""" model_config = ConfigDict(extra="allow") @@ -1227,6 +1229,7 @@ class ClientRequest( | CallToolRequest | CallToolAsyncRequest | JoinCallToolAsyncRequest + | GetToolAsyncResultRequest | ListToolsRequest ] ): diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index eb4e004ae..5d27c9b2b 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -129,6 +129,77 @@ async def make_request(client_session): await ev_cancelled.wait() +@pytest.mark.anyio +async def test_request_async(): + """Test that requests can be run asynchronously.""" + # The tool is already registered in the fixture + + ev_tool_called = anyio.Event() + + # Start the request in a separate task so we can cancel it + def make_server() -> Server: + server = Server(name="TestSessionServer") + + # Register the tool handler + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict | None) -> list: + nonlocal ev_tool_called + if name == "async_tool": + ev_tool_called.set() + return [types.TextContent(type="text", text="test")] + raise ValueError(f"Unknown tool: {name}") + + # Register the tool so it shows up in list_tools + @server.list_tools() + async def handle_list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="async_tool", + description="A tool that does things asynchronously", + inputSchema={}, + preferAsync=True, + ) + ] + + return server + + async def make_request(client_session: ClientSession): + return await client_session.send_request( + ClientRequest( + types.CallToolAsyncRequest( + method="tools/async/call", + params=types.CallToolAsyncRequestParams( + name="async_tool", arguments={} + ), + ) + ), + types.CallToolAsyncResult, + ) + + async def get_result(client_session: ClientSession, async_token: types.AsyncToken): + return await client_session.send_request( + 
ClientRequest( + types.GetToolAsyncResultRequest( + method="tools/async/get", + params=types.GetToolAsyncResultRequestParams(token=async_token), + ) + ), + types.CallToolResult, + ) + + async with create_connected_server_and_client_session( + make_server() + ) as client_session: + async_result = await make_request(client_session) + assert async_result is not None + assert async_result.token is not None + with anyio.fail_after(1): # Timeout after 1 second + await ev_tool_called.wait() + result = await get_result(client_session, async_result.token) + assert type(result.content[0]) is types.TextContent + assert result.content[0].text == "test" + + @pytest.mark.anyio async def test_connection_closed(): """ diff --git a/uv.lock b/uv.lock index 180d5a9c1..3d013ed53 100644 --- a/uv.lock +++ b/uv.lock @@ -104,6 +104,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/09/71/54e999902aed72baf26bca0d50781b01838251a462612966e9fc4891eadd/black-25.1.0-py3-none-any.whl", hash = "sha256:95e8176dae143ba9097f351d174fdaf0ccd29efb414b362ae3fd72bf0f710717", size = 207646, upload-time = "2025-01-29T04:15:38.082Z" }, ] +[[package]] +name = "cachetools" +version = "6.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c0/b0/f539a1ddff36644c28a61490056e5bae43bd7386d9f9c69beae2d7e7d6d1/cachetools-6.0.0.tar.gz", hash = "sha256:f225782b84438f828328fc2ad74346522f27e5b1440f4e9fd18b20ebfd1aa2cf", size = 30160, upload-time = "2025-05-23T20:01:13.076Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6a/c3/8bb087c903c95a570015ce84e0c23ae1d79f528c349cbc141b5c4e250293/cachetools-6.0.0-py3-none-any.whl", hash = "sha256:82e73ba88f7b30228b5507dce1a1f878498fc669d972aef2dde4f3a3c24f103e", size = 10964, upload-time = "2025-05-23T20:01:11.323Z" }, +] + [[package]] name = "cairocffi" version = "1.7.1" @@ -530,6 +539,7 @@ name = "mcp" source = { editable = "." 
} dependencies = [ { name = "anyio" }, + { name = "cachetools" }, { name = "httpx" }, { name = "httpx-sse" }, { name = "pydantic" }, @@ -574,6 +584,7 @@ docs = [ [package.metadata] requires-dist = [ { name = "anyio", specifier = ">=4.5" }, + { name = "cachetools", specifier = "==6" }, { name = "httpx", specifier = ">=0.27" }, { name = "httpx-sse", specifier = ">=0.4" }, { name = "pydantic", specifier = ">=2.7.2,<3.0.0" }, From 5c03c557dc180636e309ce1666be55f813999400 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 1 Jun 2025 06:20:08 +0000 Subject: [PATCH 10/31] remove TODO --- src/mcp/client/session.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 5e0af34ea..8dfda8c16 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -185,9 +185,6 @@ async def send_progress_notification( progress: float, total: float | None = None, message: str | None = None, - # TODO check whether MCP spec allows clients to create resources - # for server and therefore whether resource notifications - # would be required here too ) -> None: """Send a progress notification.""" await self.send_notification( From 3e933e489727fac2dcf8747379004b82f808e57d Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 1 Jun 2025 06:27:44 +0000 Subject: [PATCH 11/31] add notes on TODO items for result cache --- src/mcp/server/lowlevel/result_cache.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/mcp/server/lowlevel/result_cache.py b/src/mcp/server/lowlevel/result_cache.py index e6d2e8d28..79d849850 100644 --- a/src/mcp/server/lowlevel/result_cache.py +++ b/src/mcp/server/lowlevel/result_cache.py @@ -22,6 +22,20 @@ class InProgress: class ResultCache: + """ + Note this class is a work in progress + TODO externalise cachetools to allow for other implementations + e.g. 
redis etal for production scenarios + TODO properly support join nothing actually happens at the moment + TODO intercept progress notifications from original session and pass to joined + sessions + TODO handle session closure gracefully - + at the moment old connections will hang around and cause problems later + TODO keep_alive logic is not correct as per spec - results are cached for too long, + probably better than too short + TODO needs a lot more testing around edge cases/failure scenarios + """ + _in_progress: dict[types.AsyncToken, InProgress] def __init__(self, max_size: int, max_keep_alive: int): @@ -79,7 +93,7 @@ async def join_call( async with self._lock: in_progress = self._in_progress.get(req.params.token) if in_progress is None: - # TODO consider creating new token to allow client + # TODO consider creating new token to allow client # to get message describing why it wasn't accepted return types.CallToolAsyncResult(accepted=False) else: From 4316bec60afac93fcd08bd2fcf6f78458f3f9650 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 1 Jun 2025 06:43:26 +0000 Subject: [PATCH 12/31] More TODO notes --- src/mcp/server/lowlevel/result_cache.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/mcp/server/lowlevel/result_cache.py b/src/mcp/server/lowlevel/result_cache.py index 79d849850..840445103 100644 --- a/src/mcp/server/lowlevel/result_cache.py +++ b/src/mcp/server/lowlevel/result_cache.py @@ -24,6 +24,14 @@ class InProgress: class ResultCache: """ Note this class is a work in progress + Its purpose is to act as a central point for managing in progress + async calls, allowing multiple clients to join and receive progress + updates, get results and/or cancel in progress calls + TODO CRITICAL!! 
Decide how to limit Async tokens for security purposes + suggest use authentication protocol for identity - may need to add an + authorisation layer to decide if a user is allowed to join an existing + async call + TODO name is probably not quite right, more of a result broker? TODO externalise cachetools to allow for other implementations e.g. redis etal for production scenarios TODO properly support join nothing actually happens at the moment @@ -34,6 +42,8 @@ class ResultCache: TODO keep_alive logic is not correct as per spec - results are cached for too long, probably better than too short TODO needs a lot more testing around edge cases/failure scenarios + TODO might look into more fine grained locks, one global lock is a bottleneck + though this could be delegated to other cache impls if external """ _in_progress: dict[types.AsyncToken, InProgress] From 58fc23945bdefc79aeddd33b02a30bc0217fcbd8 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 1 Jun 2025 06:44:19 +0000 Subject: [PATCH 13/31] format fixes --- src/mcp/server/lowlevel/result_cache.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mcp/server/lowlevel/result_cache.py b/src/mcp/server/lowlevel/result_cache.py index 840445103..9c7433f32 100644 --- a/src/mcp/server/lowlevel/result_cache.py +++ b/src/mcp/server/lowlevel/result_cache.py @@ -28,8 +28,8 @@ class ResultCache: async calls, allowing multiple clients to join and receive progress updates, get results and/or cancel in progress calls TODO CRITICAL!! Decide how to limit Async tokens for security purposes - suggest use authentication protocol for identity - may need to add an - authorisation layer to decide if a user is allowed to join an existing + suggest use authentication protocol for identity - may need to add an + authorisation layer to decide if a user is allowed to join an existing async call TODO name is probably not quite right, more of a result broker? 
TODO externalise cachetools to allow for other implementations From 93b3c665e9aa948ffd94259606e4275a965e0c3a Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 1 Jun 2025 07:17:49 +0000 Subject: [PATCH 14/31] Add trivial authorisation check - confirms if user is the same --- src/mcp/server/lowlevel/result_cache.py | 52 ++++++++++++++++++------- 1 file changed, 38 insertions(+), 14 deletions(-) diff --git a/src/mcp/server/lowlevel/result_cache.py b/src/mcp/server/lowlevel/result_cache.py index 9c7433f32..57e98409d 100644 --- a/src/mcp/server/lowlevel/result_cache.py +++ b/src/mcp/server/lowlevel/result_cache.py @@ -1,5 +1,6 @@ from collections.abc import Awaitable, Callable from dataclasses import dataclass, field +from logging import getLogger from time import time from typing import Any from uuid import uuid4 @@ -9,12 +10,17 @@ from cachetools import TTLCache from mcp import types +from mcp.server.auth.middleware.auth_context import auth_context_var as user_context +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser from mcp.shared.context import BaseSession, RequestContext, SessionT +logger = getLogger(__name__) + @dataclass class InProgress: token: str + user: AuthenticatedUser | None = None task_group: TaskGroup | None = None sessions: list[BaseSession[Any, Any, Any, Any, Any]] = field( default_factory=lambda: [] @@ -27,10 +33,9 @@ class ResultCache: Its purpose is to act as a central point for managing in progress async calls, allowing multiple clients to join and receive progress updates, get results and/or cancel in progress calls - TODO CRITICAL!! Decide how to limit Async tokens for security purposes - suggest use authentication protocol for identity - may need to add an - authorisation layer to decide if a user is allowed to join an existing - async call + TODO IMPORTANT! 
may need to add an authorisation layer to decide if + a user is allowed to get/join/cancel an existing async call current + simple logic only allows same user to perform these tasks TODO name is probably not quite right, more of a result broker? TODO externalise cachetools to allow for other implementations e.g. redis etal for production scenarios @@ -86,6 +91,7 @@ async def call_tool(): async with create_task_group() as tg: tg.start_soon(call_tool) in_progress.task_group = tg + in_progress.user = user_context.get() in_progress.sessions.append(ctx.session) result = types.CallToolAsyncResult( token=in_progress.token, @@ -107,17 +113,27 @@ async def join_call( # to get message describing why it wasn't accepted return types.CallToolAsyncResult(accepted=False) else: - in_progress.sessions.append(ctx.session) - return types.CallToolAsyncResult(accepted=True) - - return + # TODO consider adding authorisation layer to make this decision + if in_progress.user == user_context.get(): + in_progress.sessions.append(ctx.session) + return types.CallToolAsyncResult(accepted=True) + else: + # TODO consider creating new token to allow client + # to get message describing why it wasn't accepted + return types.CallToolAsyncResult(accepted=False) async def cancel(self, notification: types.CancelToolAsyncNotification) -> None: async with self._lock: in_progress = self._in_progress.get(notification.params.token) if in_progress is not None and in_progress.task_group is not None: - in_progress.task_group.cancel_scope.cancel() - del self._in_progress[notification.params.token] + if in_progress.user == user_context.get(): + in_progress.task_group.cancel_scope.cancel() + del self._in_progress[notification.params.token] + else: + logger.warning( + "Permission denied for cancel notification received" + f"from {user_context.get()}" + ) async def get_result(self, req: types.GetToolAsyncResultRequest): async with self._lock: @@ -130,11 +146,19 @@ async def get_result(self, req: 
types.GetToolAsyncResultRequest): isError=True, ) else: - result = self._result_cache.get(in_progress.token) - if result is None: - return types.CallToolResult(content=[], isPending=True) + if in_progress.user == user_context.get(): + result = self._result_cache.get(in_progress.token) + if result is None: + return types.CallToolResult(content=[], isPending=True) + else: + return result else: - return result + return types.CallToolResult( + content=[ + types.TextContent(type="text", text="Permission denied") + ], + isError=True, + ) async def _new_in_progress(self) -> InProgress: async with self._lock: From fc6ee15dcf22afdb81714df8ab68bc4706ffe13e Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 1 Jun 2025 07:23:44 +0000 Subject: [PATCH 15/31] sort TODOs in terms of relative priority --- src/mcp/server/lowlevel/result_cache.py | 26 ++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/mcp/server/lowlevel/result_cache.py b/src/mcp/server/lowlevel/result_cache.py index 57e98409d..00e2c47d1 100644 --- a/src/mcp/server/lowlevel/result_cache.py +++ b/src/mcp/server/lowlevel/result_cache.py @@ -33,22 +33,22 @@ class ResultCache: Its purpose is to act as a central point for managing in progress async calls, allowing multiple clients to join and receive progress updates, get results and/or cancel in progress calls - TODO IMPORTANT! may need to add an authorisation layer to decide if - a user is allowed to get/join/cancel an existing async call current - simple logic only allows same user to perform these tasks - TODO name is probably not quite right, more of a result broker? - TODO externalise cachetools to allow for other implementations - e.g. 
redis etal for production scenarios - TODO properly support join nothing actually happens at the moment - TODO intercept progress notifications from original session and pass to joined - sessions - TODO handle session closure gracefully - + TODO CRITICAL properly support join nothing actually happens at the moment + TODO CRITICAL intercept progress notifications from original session and + pass to joined sessions + TODO MAJOR handle session closure gracefully - at the moment old connections will hang around and cause problems later - TODO keep_alive logic is not correct as per spec - results are cached for too long, + TODO MAJOR needs a lot more testing around edge cases/failure scenarios + TODO MINOR keep_alive logic is not correct as per spec - results are cached for too long, probably better than too short - TODO needs a lot more testing around edge cases/failure scenarios - TODO might look into more fine grained locks, one global lock is a bottleneck + TODO ENHANCEMENT might look into more fine grained locks, one global lock is a bottleneck though this could be delegated to other cache impls if external + TODO ENHANCEMENT externalise cachetools to allow for other implementations + e.g. redis etal for production scenarios + TODO ENHANCEMENT may need to add an authorisation layer to decide if + a user is allowed to get/join/cancel an existing async call current + simple logic only allows same user to perform these tasks + TODO TRIVIAL name is probably not quite right, more of a result broker? 
""" _in_progress: dict[types.AsyncToken, InProgress] From 646dd6303385cd3587e95086e239d1e9854e2e87 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 1 Jun 2025 07:24:26 +0000 Subject: [PATCH 16/31] ruff check/format --- src/mcp/server/lowlevel/result_cache.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/mcp/server/lowlevel/result_cache.py b/src/mcp/server/lowlevel/result_cache.py index 00e2c47d1..04597501e 100644 --- a/src/mcp/server/lowlevel/result_cache.py +++ b/src/mcp/server/lowlevel/result_cache.py @@ -34,15 +34,15 @@ class ResultCache: async calls, allowing multiple clients to join and receive progress updates, get results and/or cancel in progress calls TODO CRITICAL properly support join nothing actually happens at the moment - TODO CRITICAL intercept progress notifications from original session and + TODO CRITICAL intercept progress notifications from original session and pass to joined sessions TODO MAJOR handle session closure gracefully - at the moment old connections will hang around and cause problems later TODO MAJOR needs a lot more testing around edge cases/failure scenarios - TODO MINOR keep_alive logic is not correct as per spec - results are cached for too long, - probably better than too short - TODO ENHANCEMENT might look into more fine grained locks, one global lock is a bottleneck - though this could be delegated to other cache impls if external + TODO MINOR keep_alive logic is not correct as per spec - results are + cached for too long, probably better than too short + TODO ENHANCEMENT might look into more fine grained locks, one global lock + is a bottleneck though this could be delegated to other cache impls if external TODO ENHANCEMENT externalise cachetools to allow for other implementations e.g. 
redis etal for production scenarios TODO ENHANCEMENT may need to add an authorisation layer to decide if From ea0048cad9280a940ec05bd8283b3bc905456323 Mon Sep 17 00:00:00 2001 From: David Savage Date: Mon, 2 Jun 2025 06:27:52 +0000 Subject: [PATCH 17/31] Updates to support join --- src/mcp/server/lowlevel/result_cache.py | 257 +++++++++++++++--------- src/mcp/server/lowlevel/server.py | 3 + src/mcp/server/session.py | 17 +- src/mcp/shared/session.py | 20 +- tests/shared/test_session.py | 241 ++++++++++++++++++++-- 5 files changed, 427 insertions(+), 111 deletions(-) diff --git a/src/mcp/server/lowlevel/result_cache.py b/src/mcp/server/lowlevel/result_cache.py index 04597501e..b9bc3f55d 100644 --- a/src/mcp/server/lowlevel/result_cache.py +++ b/src/mcp/server/lowlevel/result_cache.py @@ -1,18 +1,21 @@ from collections.abc import Awaitable, Callable +from concurrent.futures import Future from dataclasses import dataclass, field from logging import getLogger from time import time +from types import TracebackType from typing import Any from uuid import uuid4 -from anyio import Lock, create_task_group, move_on_after -from anyio.abc import TaskGroup -from cachetools import TTLCache +import anyio +import anyio.to_thread +from anyio.from_thread import BlockingPortal, BlockingPortalProvider from mcp import types from mcp.server.auth.middleware.auth_context import auth_context_var as user_context from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser -from mcp.shared.context import BaseSession, RequestContext, SessionT +from mcp.server.session import ServerSession +from mcp.shared.context import RequestContext logger = getLogger(__name__) @@ -21,10 +24,8 @@ class InProgress: token: str user: AuthenticatedUser | None = None - task_group: TaskGroup | None = None - sessions: list[BaseSession[Any, Any, Any, Any, Any]] = field( - default_factory=lambda: [] - ) + future: Future[types.CallToolResult] | None = None + sessions: dict[int, ServerSession] = 
field(default_factory=lambda: {}) class ResultCache: @@ -33,16 +34,11 @@ class ResultCache: Its purpose is to act as a central point for managing in progress async calls, allowing multiple clients to join and receive progress updates, get results and/or cancel in progress calls - TODO CRITICAL properly support join nothing actually happens at the moment - TODO CRITICAL intercept progress notifications from original session and - pass to joined sessions - TODO MAJOR handle session closure gracefully - - at the moment old connections will hang around and cause problems later + TODO CRITICAL keep_alive logic is not correct as per spec - results currently + only kept for as long as longest session reintroduce TTL cache TODO MAJOR needs a lot more testing around edge cases/failure scenarios - TODO MINOR keep_alive logic is not correct as per spec - results are - cached for too long, probably better than too short - TODO ENHANCEMENT might look into more fine grained locks, one global lock - is a bottleneck though this could be delegated to other cache impls if external + TODO MAJOR decide if async.Locks are required for integrity of internal + data structures TODO ENHANCEMENT externalise cachetools to allow for other implementations e.g. 
redis etal for production scenarios TODO ENHANCEMENT may need to add an authorisation layer to decide if @@ -52,21 +48,35 @@ class ResultCache: """ _in_progress: dict[types.AsyncToken, InProgress] + _session_lookup: dict[int, types.AsyncToken] + _portal: BlockingPortal def __init__(self, max_size: int, max_keep_alive: int): self._max_size = max_size self._max_keep_alive = max_keep_alive - self._result_cache = TTLCache[types.AsyncToken, types.CallToolResult]( - self._max_size, self._max_keep_alive - ) self._in_progress = {} - self._lock = Lock() + self._session_lookup = {} + self._portal_provider = BlockingPortalProvider() + + async def __aenter__(self): + def create_portal(): + self._portal = self._portal_provider.__enter__() + + await anyio.to_thread.run_sync(create_portal) + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + await anyio.to_thread.run_sync(lambda: self._portal_provider.__exit__) async def add_call( self, call: Callable[[types.CallToolRequest], Awaitable[types.ServerResult]], req: types.CallToolAsyncRequest, - ctx: RequestContext[SessionT, Any, Any], + ctx: RequestContext[ServerSession, Any, Any], ) -> types.CallToolAsyncResult: in_progress = await self._new_in_progress() timeout = min( @@ -74,97 +84,152 @@ async def add_call( ) async def call_tool(): - with move_on_after(timeout) as scope: - result = await call( - types.CallToolRequest( - method="tools/call", - params=types.CallToolRequestParams( - name=req.params.name, arguments=req.params.arguments - ), - ) + result = await call( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams( + name=req.params.name, + arguments=req.params.arguments, + _meta=req.params.meta, + ), ) - if not scope.cancel_called: - async with self._lock: - assert type(result.root) is types.CallToolResult - self._result_cache[in_progress.token] = result.root - - async with 
create_task_group() as tg: - tg.start_soon(call_tool) - in_progress.task_group = tg - in_progress.user = user_context.get() - in_progress.sessions.append(ctx.session) - result = types.CallToolAsyncResult( - token=in_progress.token, - recieved=round(time()), - keepAlive=timeout, - accepted=True, ) - return result + # async with self._lock: + assert type(result.root) is types.CallToolResult + logger.debug(f"Got result {result}") + return result.root + + in_progress.user = user_context.get() + in_progress.sessions[id(ctx.session)] = ctx.session + self._session_lookup[id(ctx.session)] = in_progress.token + in_progress.future = self._portal.start_task_soon(call_tool) + result = types.CallToolAsyncResult( + token=in_progress.token, + recieved=round(time()), + keepAlive=timeout, + accepted=True, + ) + return result async def join_call( self, req: types.JoinCallToolAsyncRequest, - ctx: RequestContext[SessionT, Any, Any], + ctx: RequestContext[ServerSession, Any, Any], ) -> types.CallToolAsyncResult: - async with self._lock: - in_progress = self._in_progress.get(req.params.token) - if in_progress is None: - # TODO consider creating new token to allow client - # to get message describing why it wasn't accepted - return types.CallToolAsyncResult(accepted=False) + # async with self._lock: + in_progress = self._in_progress.get(req.params.token) + if in_progress is None: + # TODO consider creating new token to allow client + # to get message describing why it wasn't accepted + logger.warning("Discarding join request for unknown async token") + return types.CallToolAsyncResult(accepted=False) + else: + # TODO consider adding authorisation layer to make this decision + if in_progress.user == user_context.get(): + logger.debug(f"Received join from {id(ctx.session)}") + self._session_lookup[id(ctx.session)] = req.params.token + in_progress.sessions[id(ctx.session)] = ctx.session + return types.CallToolAsyncResult(token=req.params.token, accepted=True) else: - # TODO consider adding 
authorisation layer to make this decision - if in_progress.user == user_context.get(): - in_progress.sessions.append(ctx.session) - return types.CallToolAsyncResult(accepted=True) - else: - # TODO consider creating new token to allow client - # to get message describing why it wasn't accepted - return types.CallToolAsyncResult(accepted=False) + # TODO consider sending error via get result + return types.CallToolAsyncResult(accepted=False) async def cancel(self, notification: types.CancelToolAsyncNotification) -> None: - async with self._lock: - in_progress = self._in_progress.get(notification.params.token) - if in_progress is not None and in_progress.task_group is not None: - if in_progress.user == user_context.get(): - in_progress.task_group.cancel_scope.cancel() - del self._in_progress[notification.params.token] - else: - logger.warning( - "Permission denied for cancel notification received" - f"from {user_context.get()}" - ) + # async with self._lock: + in_progress = self._in_progress.get(notification.params.token) + if in_progress is not None: + if in_progress.user == user_context.get(): + # in_progress.task_group.cancel_scope.cancel() + del self._in_progress[notification.params.token] + else: + logger.warning( + "Permission denied for cancel notification received" + f"from {user_context.get()}" + ) async def get_result(self, req: types.GetToolAsyncResultRequest): - async with self._lock: - in_progress = self._in_progress.get(req.params.token) - if in_progress is None: - return types.CallToolResult( - content=[ - types.TextContent(type="text", text="Unknown progress token") - ], - isError=True, - ) - else: - if in_progress.user == user_context.get(): - result = self._result_cache.get(in_progress.token) - if result is None: - return types.CallToolResult(content=[], isPending=True) - else: - return result - else: + logger.debug("Getting result") + in_progress = self._in_progress.get(req.params.token) + logger.debug(f"Found in progress {in_progress}") + if 
in_progress is None: + return types.CallToolResult( + content=[types.TextContent(type="text", text="Unknown progress token")], + isError=True, + ) + else: + if in_progress.user == user_context.get(): + if in_progress.future is None: return types.CallToolResult( content=[ types.TextContent(type="text", text="Permission denied") ], isError=True, ) + else: + # TODO add timeout to get async result + # return isPending=True if timesout + result = in_progress.future.result() + logger.debug(f"Found result {result}") + return result + else: + return types.CallToolResult( + content=[types.TextContent(type="text", text="Permission denied")], + isError=True, + ) + + async def notification_hook( + self, session: ServerSession, notification: types.ServerNotification + ): + if type(notification.root) is types.ProgressNotification: + # async with self._lock: + async_token = self._session_lookup.get(id(session)) + if async_token is None: + # not all sessions are async so just debug + logger.debug("Discarding progress notification from unknown session") + else: + in_progress = self._in_progress.get(async_token) + if in_progress is None: + # this should not happen + logger.error("Discarding progress notification, not async") + else: + for session_id, other_session in in_progress.sessions.items(): + logger.debug(f"Checking {session_id} == {id(session)}") + if not session_id == id(session): + logger.debug(f"Sending progress to {id(other_session)}") + await other_session.send_progress_notification( + progress_token=1, + progress=notification.root.params.progress, + total=notification.root.params.total, + message=notification.root.params.message, + resource_uri=notification.root.params.resourceUri, + ) + + async def session_close_hook(self, session: ServerSession): + logger.debug(f"Closing {id(session)}") + dropped = self._session_lookup.pop(id(session), None) + if dropped is None: + logger.warning(f"Discarding callback from unknown session {id(session)}") + else: + in_progress = 
self._in_progress.get(dropped) + if in_progress is None: + logger.warning("In progress not found") + else: + found = in_progress.sessions.pop(id(session), None) + if found is None: + logger.warning("No session found") + if len(in_progress.sessions) == 0: + self._in_progress.pop(dropped, None) + logger.debug("In progress found") + if in_progress.future is None: + logger.warning("In progress future is none") + else: + logger.debug("Cancelled in progress future") + in_progress.future.cancel() async def _new_in_progress(self) -> InProgress: - async with self._lock: - while True: - token = str(uuid4()) - if token not in self._in_progress: - new_in_progress = InProgress(token) - self._in_progress[token] = new_in_progress - return new_in_progress + while True: + token = str(uuid4()) + if token not in self._in_progress: + new_in_progress = InProgress(token) + self._in_progress[token] = new_in_progress + return new_in_progress diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 03610ab23..bb4f9b45b 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -534,8 +534,11 @@ async def run( write_stream, initialization_options, stateless=stateless, + notification_hook=self.result_cache.notification_hook, + session_close_hook=self.result_cache.session_close_hook, ) ) + await stack.enter_async_context(self.result_cache) async with anyio.create_task_group() as tg: async for message in session.incoming_messages: diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 9c839326c..d5b08ee58 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -37,6 +37,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: be instantiated directly by users of the MCP framework. 
""" +from collections.abc import Awaitable, Callable from enum import Enum from typing import Annotated, Any, TypeVar @@ -44,6 +45,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic.networks import AnyUrl, UrlConstraints +from typing_extensions import Self import mcp.types as types from mcp.server.models import InitializationOptions @@ -88,9 +90,16 @@ def __init__( write_stream: MemoryObjectSendStream[SessionMessage], init_options: InitializationOptions, stateless: bool = False, + notification_hook: Callable[[Self, types.ServerNotification], Awaitable[None]] + | None = None, + session_close_hook: Callable[[Self], Awaitable[None]] | None = None, ) -> None: super().__init__( - read_stream, write_stream, types.ClientRequest, types.ClientNotification + read_stream, + write_stream, + types.ClientRequest, + types.ClientNotification, + notification_hook=notification_hook, ) self._initialization_state = ( InitializationState.Initialized @@ -106,6 +115,12 @@ def __init__( lambda: self._incoming_message_stream_reader.aclose() ) + async def call_session_close(): + if session_close_hook is not None: + await session_close_hook(self) + + self._exit_stack.push_async_callback(call_session_close) + @property def client_params(self) -> types.InitializeRequestParams | None: return self._client_params diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 5a321b017..4cbbe2b71 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,10 +1,17 @@ import inspect import logging -from collections.abc import Callable +from collections.abc import Awaitable, Callable from contextlib import AsyncExitStack from datetime import timedelta from types import TracebackType -from typing import Annotated, Any, Generic, Protocol, TypeVar, runtime_checkable +from typing import ( + Annotated, + Any, + Generic, + Protocol, + 
TypeVar, + runtime_checkable, +) import anyio import httpx @@ -209,6 +216,8 @@ def __init__( receive_notification_type: type[ReceiveNotificationT], # If none, reading will never time out read_timeout_seconds: timedelta | None = None, + notification_hook: Callable[[Self, SendNotificationT], Awaitable[None]] + | None = None, ) -> None: self._read_stream = read_stream self._write_stream = write_stream @@ -221,6 +230,7 @@ def __init__( self._progress_callbacks = {} self._resource_callbacks = {} self._exit_stack = AsyncExitStack() + self._notification_hook = notification_hook async def __aenter__(self) -> Self: self._task_group = anyio.create_task_group() @@ -339,6 +349,12 @@ async def send_notification( Emits a notification, which is a one-way message that does not expect a response. """ + if self._notification_hook: + try: + await self._notification_hook(self, notification) + except Exception: + logging.exception("Notification hook failed") + # Some transport implementations may need to set the related_request_id # to attribute to the notifications to the request that triggered them. 
jsonrpc_notification = JSONRPCNotification( diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 5d27c9b2b..415041614 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -177,29 +177,246 @@ async def make_request(client_session: ClientSession): ) async def get_result(client_session: ClientSession, async_token: types.AsyncToken): - return await client_session.send_request( - ClientRequest( - types.GetToolAsyncResultRequest( - method="tools/async/get", - params=types.GetToolAsyncResultRequestParams(token=async_token), + with anyio.fail_after(1): + while True: + print("getting results") + result = await client_session.send_request( + ClientRequest( + types.GetToolAsyncResultRequest( + method="tools/async/get", + params=types.GetToolAsyncResultRequestParams( + token=async_token + ), + ) + ), + types.CallToolResult, ) - ), - types.CallToolResult, - ) + print(f"retrieved {result}") + if result.isPending: + await anyio.sleep(1) + elif result.isError: + print("wibble") + raise RuntimeError(str(result)) + else: + return result async with create_connected_server_and_client_session( make_server() ) as client_session: - async_result = await make_request(client_session) - assert async_result is not None - assert async_result.token is not None + async_call = await make_request(client_session) + assert async_call is not None + assert async_call.token is not None with anyio.fail_after(1): # Timeout after 1 second await ev_tool_called.wait() - result = await get_result(client_session, async_result.token) + result = await get_result(client_session, async_call.token) assert type(result.content[0]) is types.TextContent assert result.content[0].text == "test" +@pytest.mark.anyio +async def test_request_async_join(): + """Test that requests can be run asynchronously.""" + # The tool is already registered in the fixture + + # TODO note these events are not working as expected + # test code below uses move_on_after rather than + # 
fail_after as events are not triggered as expected + # this effectively makes the test lots of sleep + # calls, needs further investigation + ev_client_1_started = anyio.Event() + ev_client2_joined = anyio.Event() + ev_client1_progressed_1 = anyio.Event() + ev_client1_progressed_2 = anyio.Event() + ev_client2_progressed_1 = anyio.Event() + ev_done = anyio.Event() + + # Start the request in a separate task so we can cancel it + def make_server() -> Server: + server = Server(name="TestSessionServer") + + # Register the tool handler + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict | None) -> list: + nonlocal ev_client2_joined + if name == "async_tool": + try: + print("sending 1/2") + await server.request_context.session.send_progress_notification( + progress_token=server.request_context.request_id, + progress=1, + total=2, + ) + print("sent 1/2") + with anyio.move_on_after(10): # Timeout after 1 second + # TODO this is not working for some unknown reason + print("waiting for client 2 joined") + await ev_client2_joined.wait() + # await anyio.sleep(1) + + print("sending 2/2") + await server.request_context.session.send_progress_notification( + progress_token=server.request_context.request_id, + progress=2, + total=2, + ) + print("sent 2/2") + result = [types.TextContent(type="text", text="test")] + print("sending result") + return result + except Exception as e: + print(f"Caught: {str(e)}") + raise e + else: + raise ValueError(f"Unknown tool: {name}") + + # Register the tool so it shows up in list_tools + @server.list_tools() + async def handle_list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="async_tool", + description="A tool that does things asynchronously", + inputSchema={}, + preferAsync=True, + ) + ] + + return server + + async def progress_callback_initial( + progress: float, total: float | None, message: str | None + ): + nonlocal ev_client1_progressed_1 + nonlocal ev_client1_progressed_2 + print(f"progress initial 
started: {progress}/{total}") + if progress == 1.0: + ev_client1_progressed_1.set() + print("progress 1 set") + else: + ev_client1_progressed_2.set() + print("progress 1 set") + print(f"progress initial done: {progress}/{total}") + + async def make_request(client_session: ClientSession): + return await client_session.send_request( + ClientRequest( + types.CallToolAsyncRequest( + method="tools/async/call", + params=types.CallToolAsyncRequestParams( + name="async_tool", + arguments={}, + ), + ) + ), + types.CallToolAsyncResult, + progress_callback=progress_callback_initial, + ) + + async def progress_callback_joined( + progress: float, total: float | None, message: str | None + ): + nonlocal ev_client2_progressed_1 + print(f"progress joined started: {progress}/{total}") + ev_client2_progressed_1.set() + print(f"progress joined done: {progress}/{total}") + + async def join_request( + client_session: ClientSession, async_token: types.AsyncToken + ): + return await client_session.send_request( + ClientRequest( + types.JoinCallToolAsyncRequest( + method="tools/async/join", + params=types.JoinCallToolRequestParams(token=async_token), + ) + ), + types.CallToolAsyncResult, + progress_callback=progress_callback_joined, + ) + + async def get_result(client_session: ClientSession, async_token: types.AsyncToken): + while True: + result = await client_session.send_request( + ClientRequest( + types.GetToolAsyncResultRequest( + method="tools/async/get", + params=types.GetToolAsyncResultRequestParams(token=async_token), + ) + ), + types.CallToolResult, + ) + if result.isPending: + print("Result is pending, sleeping") + await anyio.sleep(1) + elif result.isError: + raise RuntimeError(str(result)) + else: + return result + + server = make_server() + token = None + + async with anyio.create_task_group() as tg: + + async def client_1_submit(): + async with create_connected_server_and_client_session( + server + ) as client_session: + nonlocal token + nonlocal ev_client_1_started + 
nonlocal ev_client2_progressed_1 + nonlocal ev_done + async_call = await make_request(client_session) + assert async_call is not None + assert async_call.token is not None + token = async_call.token + ev_client_1_started.set() + print("Got token") + with anyio.move_on_after(1): # Timeout after 1 second + print("waiting for client 2 progress") + await ev_client2_progressed_1.wait() + + print("Getting result") + result = await get_result(client_session, token) + assert type(result.content[0]) is types.TextContent + assert result.content[0].text == "test" + ev_done.set() + + async def client_2_join(): + async with create_connected_server_and_client_session( + server + ) as client_session: + nonlocal token + nonlocal ev_client_1_started + nonlocal ev_client1_progressed_1 + nonlocal ev_client2_joined + nonlocal ev_done + + with anyio.move_on_after(1): # Timeout after 1 second + print("waiting for token") + await ev_client_1_started.wait() + print("waiting for progress 1") + await ev_client1_progressed_1.wait() + + with anyio.move_on_after(1): # Timeout after 1 second + assert token is not None + print("joining") + join_async = await join_request(client_session, token) + assert join_async is not None + assert join_async.token is not None + print("joined") + ev_client2_joined.set() + print("client 2 joined") + + with anyio.move_on_after(1): # Timeout after 1 second + print("client 2 waiting for done") + await ev_done.wait() + print("client 2 done") + + tg.start_soon(client_1_submit) + tg.start_soon(client_2_join) + + @pytest.mark.anyio async def test_connection_closed(): """ From bef1d72ae9e79db028961ddc73810691c811b4f7 Mon Sep 17 00:00:00 2001 From: David Savage Date: Tue, 3 Jun 2025 05:36:08 +0000 Subject: [PATCH 18/31] Further updates, create lowlevel test for result caching and improve test logic for higher level session join test (currently skipped due to subtle bug in test) --- src/mcp/server/lowlevel/result_cache.py | 26 +++-- src/mcp/server/lowlevel/server.py | 2 
+- tests/server/lowlevel/test_result_cache.py | 121 +++++++++++++++++++ tests/shared/test_session.py | 129 ++++++++++++--------- 4 files changed, 212 insertions(+), 66 deletions(-) create mode 100644 tests/server/lowlevel/test_result_cache.py diff --git a/src/mcp/server/lowlevel/result_cache.py b/src/mcp/server/lowlevel/result_cache.py index b9bc3f55d..56ec4a183 100644 --- a/src/mcp/server/lowlevel/result_cache.py +++ b/src/mcp/server/lowlevel/result_cache.py @@ -26,7 +26,7 @@ class InProgress: user: AuthenticatedUser | None = None future: Future[types.CallToolResult] | None = None sessions: dict[int, ServerSession] = field(default_factory=lambda: {}) - + session_progress: dict[int, types.ProgressToken | None] = field(default_factory=lambda: {}) class ResultCache: """ @@ -72,7 +72,7 @@ async def __aexit__( ) -> bool | None: await anyio.to_thread.run_sync(lambda: self._portal_provider.__exit__) - async def add_call( + async def start_call( self, call: Callable[[types.CallToolRequest], Awaitable[types.ServerResult]], req: types.CallToolAsyncRequest, @@ -101,6 +101,7 @@ async def call_tool(): in_progress.user = user_context.get() in_progress.sessions[id(ctx.session)] = ctx.session + in_progress.session_progress[id(ctx.session)] = None if req.params.meta is None else req.params.meta.progressToken self._session_lookup[id(ctx.session)] = in_progress.token in_progress.future = self._portal.start_task_soon(call_tool) result = types.CallToolAsyncResult( @@ -129,6 +130,7 @@ async def join_call( logger.debug(f"Received join from {id(ctx.session)}") self._session_lookup[id(ctx.session)] = req.params.token in_progress.sessions[id(ctx.session)] = ctx.session + in_progress.session_progress[id(ctx.session)] = None if req.params.meta is None else req.params.meta.progressToken return types.CallToolAsyncResult(token=req.params.token, accepted=True) else: # TODO consider sending error via get result @@ -167,10 +169,15 @@ async def get_result(self, req: 
types.GetToolAsyncResultRequest): ) else: # TODO add timeout to get async result - # return isPending=True if timesout - result = in_progress.future.result() - logger.debug(f"Found result {result}") - return result + try: + result = in_progress.future.result(1) + logger.debug(f"Found result {result}") + return result + except TimeoutError: + return types.CallToolResult( + content=[], + isPending=True, + ) else: return types.CallToolResult( content=[types.TextContent(type="text", text="Permission denied")], @@ -180,6 +187,7 @@ async def get_result(self, req: types.GetToolAsyncResultRequest): async def notification_hook( self, session: ServerSession, notification: types.ServerNotification ): + logger.debug(f"received {notification} from {id(session)}") if type(notification.root) is types.ProgressNotification: # async with self._lock: async_token = self._session_lookup.get(id(session)) @@ -196,8 +204,12 @@ async def notification_hook( logger.debug(f"Checking {session_id} == {id(session)}") if not session_id == id(session): logger.debug(f"Sending progress to {id(other_session)}") + progress_token = in_progress.session_progress.get(id(other_session)) + assert progress_token is not None await other_session.send_progress_notification( - progress_token=1, + # TODO this token is incorrect + # it needs to be collected from original request + progress_token=progress_token, progress=notification.root.params.progress, total=notification.root.params.total, message=notification.root.params.message, diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index bb4f9b45b..1455e8ac0 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -432,7 +432,7 @@ async def handler(req: types.CallToolRequest): async def async_call_handler(req: types.CallToolAsyncRequest): ctx = request_ctx.get() - result = await self.result_cache.add_call(handler, req, ctx) + result = await self.result_cache.start_call(handler, req, ctx) return 
types.ServerResult(result) async def async_join_handler(req: types.JoinCallToolAsyncRequest): diff --git a/tests/server/lowlevel/test_result_cache.py b/tests/server/lowlevel/test_result_cache.py new file mode 100644 index 000000000..093f211da --- /dev/null +++ b/tests/server/lowlevel/test_result_cache.py @@ -0,0 +1,121 @@ +import pytest +from mcp import types +from mcp.server.lowlevel.result_cache import ResultCache +from unittest.mock import AsyncMock, Mock, patch +from contextlib import AsyncExitStack + +@pytest.mark.anyio +async def test_async_call(): + """Tests basic async call""" + async def test_call(call: types.CallToolRequest) -> types.ServerResult: + return types.ServerResult(types.CallToolResult( + content=[types.TextContent( + type="text", + text="test" + )] + )) + async_call = types.CallToolAsyncRequest( + method="tools/async/call", + params=types.CallToolAsyncRequestParams( + name="test" + ) + ) + + mock_session = AsyncMock() + mock_context = Mock() + mock_context.session = mock_session + result_cache = ResultCache(max_size=1, max_keep_alive=1) + async with AsyncExitStack() as stack: + await stack.enter_async_context(result_cache) + async_call_ref = await result_cache.start_call(test_call, async_call, mock_context) + assert async_call_ref.token is not None + + result = await result_cache.get_result(types.GetToolAsyncResultRequest( + method="tools/async/get", + params=types.GetToolAsyncResultRequestParams( + token = async_call_ref.token + ) + )) + + assert not result.isError + assert not result.isPending + assert len(result.content) == 1 + assert type(result.content[0]) is types.TextContent + assert result.content[0].text == "test" + +@pytest.mark.anyio +async def test_async_join_call_progress(): + """Tests basic async call""" + async def test_call(call: types.CallToolRequest) -> types.ServerResult: + return types.ServerResult(types.CallToolResult( + content=[types.TextContent( + type="text", + text="test" + )] + )) + async_call = 
types.CallToolAsyncRequest( + method="tools/async/call", + params=types.CallToolAsyncRequestParams( + name="test" + ) + ) + + mock_session_1 = AsyncMock() + mock_context_1 = Mock() + mock_context_1.session = mock_session_1 + + mock_session_2 = AsyncMock() + mock_context_2 = Mock() + + mock_context_2.session = mock_session_2 + mock_session_2.send_progress_notification.result = None + + result_cache = ResultCache(max_size=1, max_keep_alive=1) + async with AsyncExitStack() as stack: + await stack.enter_async_context(result_cache) + async_call_ref = await result_cache.start_call(test_call, async_call, mock_context_1) + assert async_call_ref.token is not None + + await result_cache.join_call( + req=types.JoinCallToolAsyncRequest( + method="tools/async/join", + params=types.JoinCallToolRequestParams( + token=async_call_ref.token, + _meta = types.RequestParams.Meta( + progressToken="test" + ) + ) + ), + ctx=mock_context_2 + ) + assert async_call_ref.token is not None + await result_cache.notification_hook( + session=mock_session_1, + notification=types.ServerNotification(types.ProgressNotification( + method="notifications/progress", + params=types.ProgressNotificationParams( + progressToken="test", + progress=1 + ) + ))) + + result = await result_cache.get_result(types.GetToolAsyncResultRequest( + method="tools/async/get", + params=types.GetToolAsyncResultRequestParams( + token = async_call_ref.token + ) + )) + + assert not result.isError + assert not result.isPending + assert len(result.content) == 1 + assert type(result.content[0]) is types.TextContent + assert result.content[0].text == "test" + mock_context_1.send_progress_notification.assert_not_called() + mock_session_2.send_progress_notification.assert_called_with( + progress_token="test", + progress=1.0, + total=None, + message=None, + resource_uri = None + ) diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 415041614..08feb9da4 100644 --- a/tests/shared/test_session.py +++ 
b/tests/shared/test_session.py @@ -195,7 +195,6 @@ async def get_result(client_session: ClientSession, async_token: types.AsyncToke if result.isPending: await anyio.sleep(1) elif result.isError: - print("wibble") raise RuntimeError(str(result)) else: return result @@ -212,10 +211,17 @@ async def get_result(client_session: ClientSession, async_token: types.AsyncToke assert type(result.content[0]) is types.TextContent assert result.content[0].text == "test" +from logging import getLogger + +logger = getLogger(__name__) @pytest.mark.anyio +@pytest.mark.skip(reason="This test does not work, there is a subtle " + "bug with event.wait, lower level test_result_cache " + "tests underlying behaviour, revisit with feedback " \ + "from someone who cah help debug") async def test_request_async_join(): - """Test that requests can be run asynchronously.""" + """Test that requests can be joined from external sessions.""" # The tool is already registered in the fixture # TODO note these events are not working as expected @@ -224,12 +230,13 @@ async def test_request_async_join(): # this effectively makes the test lots of sleep # calls, needs further investigation ev_client_1_started = anyio.Event() - ev_client2_joined = anyio.Event() - ev_client1_progressed_1 = anyio.Event() - ev_client1_progressed_2 = anyio.Event() - ev_client2_progressed_1 = anyio.Event() + ev_client_2_joined = anyio.Event() + ev_client_1_progressed_1 = anyio.Event() + ev_client_1_progressed_2 = anyio.Event() + ev_client_2_progressed_1 = anyio.Event() ev_done = anyio.Event() + # Start the request in a separate task so we can cancel it def make_server() -> Server: server = Server(name="TestSessionServer") @@ -237,34 +244,34 @@ def make_server() -> Server: # Register the tool handler @server.call_tool() async def handle_call_tool(name: str, arguments: dict | None) -> list: - nonlocal ev_client2_joined + nonlocal ev_client_2_joined if name == "async_tool": try: - print("sending 1/2") + logger.info("tool: sending 
1/2") await server.request_context.session.send_progress_notification( progress_token=server.request_context.request_id, progress=1, total=2, ) - print("sent 1/2") - with anyio.move_on_after(10): # Timeout after 1 second + logger.info("tool: sent 1/2") + with anyio.fail_after(10): # Timeout after 1 second # TODO this is not working for some unknown reason - print("waiting for client 2 joined") - await ev_client2_joined.wait() - # await anyio.sleep(1) + logger.info("tool: waiting for client 2 joined") + await ev_client_2_joined.wait() - print("sending 2/2") + logger.info("tool: sending 2/2") await server.request_context.session.send_progress_notification( progress_token=server.request_context.request_id, progress=2, total=2, ) - print("sent 2/2") + logger.info("tool: sent 2/2") result = [types.TextContent(type="text", text="test")] - print("sending result") + logger.info("tool: sending result") return result except Exception as e: - print(f"Caught: {str(e)}") + logger.exception(e) + logger.info(f"tool: caught: {str(e)}") raise e else: raise ValueError(f"Unknown tool: {name}") @@ -283,19 +290,19 @@ async def handle_list_tools() -> list[types.Tool]: return server - async def progress_callback_initial( + async def client_1_progress_callback( progress: float, total: float | None, message: str | None ): - nonlocal ev_client1_progressed_1 - nonlocal ev_client1_progressed_2 - print(f"progress initial started: {progress}/{total}") + nonlocal ev_client_1_progressed_1 + nonlocal ev_client_1_progressed_2 + logger.info(f"client1: progress started: {progress}/{total}") if progress == 1.0: - ev_client1_progressed_1.set() - print("progress 1 set") + ev_client_1_progressed_1.set() + logger.info("client1: progress 1 set") else: - ev_client1_progressed_2.set() - print("progress 1 set") - print(f"progress initial done: {progress}/{total}") + ev_client_1_progressed_2.set() + logger.info("client1: progress 2 set") + logger.info(f"client1: progress done: {progress}/{total}") async def 
make_request(client_session: ClientSession): return await client_session.send_request( @@ -309,19 +316,20 @@ async def make_request(client_session: ClientSession): ) ), types.CallToolAsyncResult, - progress_callback=progress_callback_initial, + progress_callback=client_1_progress_callback, ) - async def progress_callback_joined( + async def client_2_progress_callback( progress: float, total: float | None, message: str | None ): - nonlocal ev_client2_progressed_1 - print(f"progress joined started: {progress}/{total}") - ev_client2_progressed_1.set() - print(f"progress joined done: {progress}/{total}") + nonlocal ev_client_2_progressed_1 + logger.info(f"client2: progress started: {progress}/{total}") + ev_client_2_progressed_1.set() + logger.info(f"client2: progress done: {progress}/{total}") async def join_request( - client_session: ClientSession, async_token: types.AsyncToken + client_session: ClientSession, + async_token: types.AsyncToken ): return await client_session.send_request( ClientRequest( @@ -331,7 +339,7 @@ async def join_request( ) ), types.CallToolAsyncResult, - progress_callback=progress_callback_joined, + progress_callback=client_2_progress_callback, ) async def get_result(client_session: ClientSession, async_token: types.AsyncToken): @@ -346,7 +354,7 @@ async def get_result(client_session: ClientSession, async_token: types.AsyncToke types.CallToolResult, ) if result.isPending: - print("Result is pending, sleeping") + logger.info("client1: result is pending, sleeping") await anyio.sleep(1) elif result.isError: raise RuntimeError(str(result)) @@ -357,30 +365,30 @@ async def get_result(client_session: ClientSession, async_token: types.AsyncToke token = None async with anyio.create_task_group() as tg: - async def client_1_submit(): async with create_connected_server_and_client_session( server ) as client_session: nonlocal token nonlocal ev_client_1_started - nonlocal ev_client2_progressed_1 + nonlocal ev_client_2_progressed_1 nonlocal ev_done async_call 
= await make_request(client_session) assert async_call is not None assert async_call.token is not None token = async_call.token ev_client_1_started.set() - print("Got token") - with anyio.move_on_after(1): # Timeout after 1 second - print("waiting for client 2 progress") - await ev_client2_progressed_1.wait() + logger.info("client1: got token") + with anyio.fail_after(1): # Timeout after 1 second + logger.info("client1: waiting for client 2 progress") + await ev_client_2_progressed_1.wait() - print("Getting result") + logger.info("client1: getting result") result = await get_result(client_session, token) + ev_done.set() + assert type(result.content[0]) is types.TextContent assert result.content[0].text == "test" - ev_done.set() async def client_2_join(): async with create_connected_server_and_client_session( @@ -388,34 +396,39 @@ async def client_2_join(): ) as client_session: nonlocal token nonlocal ev_client_1_started - nonlocal ev_client1_progressed_1 - nonlocal ev_client2_joined + nonlocal ev_client_1_progressed_1 + nonlocal ev_client_2_joined nonlocal ev_done - with anyio.move_on_after(1): # Timeout after 1 second - print("waiting for token") + with anyio.fail_after(1): # Timeout after 1 second + logger.info("client2: waiting for token") await ev_client_1_started.wait() - print("waiting for progress 1") - await ev_client1_progressed_1.wait() - - with anyio.move_on_after(1): # Timeout after 1 second assert token is not None - print("joining") + logger.info("client2: got token") + logger.info("client2: waiting for client 1 progress 1") + await ev_client_1_progressed_1.wait() + + with anyio.fail_after(1): # Timeout after 1 second + logger.info("client2: joining") join_async = await join_request(client_session, token) assert join_async is not None assert join_async.token is not None - print("joined") - ev_client2_joined.set() - print("client 2 joined") + ev_client_2_joined.set() + ("client2: joined") - with anyio.move_on_after(1): # Timeout after 1 second - 
print("client 2 waiting for done") + with anyio.fail_after(10): # Timeout after 1 second + logger.info("client2: waiting for done") await ev_done.wait() - print("client 2 done") + logger.info("client2: done") tg.start_soon(client_1_submit) tg.start_soon(client_2_join) + assert ev_client_1_started.is_set() + assert ev_client_2_joined.is_set() + assert ev_client_1_progressed_1.is_set() + assert ev_client_1_progressed_2.is_set() + assert ev_client_2_progressed_1.is_set() @pytest.mark.anyio async def test_connection_closed(): From ba0f9ee466a98e3e11c49854750d09ddfd5d1701 Mon Sep 17 00:00:00 2001 From: David Savage Date: Tue, 3 Jun 2025 05:47:50 +0000 Subject: [PATCH 19/31] formatting and other misc tidy up highlighted by ruff check --- src/mcp/server/lowlevel/result_cache.py | 64 +++++++------ tests/server/lowlevel/test_result_cache.py | 106 +++++++++++---------- tests/shared/test_session.py | 22 +++-- 3 files changed, 104 insertions(+), 88 deletions(-) diff --git a/src/mcp/server/lowlevel/result_cache.py b/src/mcp/server/lowlevel/result_cache.py index 56ec4a183..5870704a5 100644 --- a/src/mcp/server/lowlevel/result_cache.py +++ b/src/mcp/server/lowlevel/result_cache.py @@ -26,7 +26,10 @@ class InProgress: user: AuthenticatedUser | None = None future: Future[types.CallToolResult] | None = None sessions: dict[int, ServerSession] = field(default_factory=lambda: {}) - session_progress: dict[int, types.ProgressToken | None] = field(default_factory=lambda: {}) + session_progress: dict[int, types.ProgressToken | None] = field( + default_factory=lambda: {} + ) + class ResultCache: """ @@ -100,9 +103,14 @@ async def call_tool(): return result.root in_progress.user = user_context.get() - in_progress.sessions[id(ctx.session)] = ctx.session - in_progress.session_progress[id(ctx.session)] = None if req.params.meta is None else req.params.meta.progressToken - self._session_lookup[id(ctx.session)] = in_progress.token + session_id = id(ctx.session) + in_progress.sessions[session_id] 
= ctx.session + if req.params.meta is not None: + progress_token = req.params.meta.progressToken + else: + progress_token = None + in_progress.session_progress[session_id] = progress_token + self._session_lookup[session_id] = in_progress.token in_progress.future = self._portal.start_task_soon(call_tool) result = types.CallToolAsyncResult( token=in_progress.token, @@ -127,10 +135,15 @@ async def join_call( else: # TODO consider adding authorisation layer to make this decision if in_progress.user == user_context.get(): - logger.debug(f"Received join from {id(ctx.session)}") - self._session_lookup[id(ctx.session)] = req.params.token - in_progress.sessions[id(ctx.session)] = ctx.session - in_progress.session_progress[id(ctx.session)] = None if req.params.meta is None else req.params.meta.progressToken + session_id = id(ctx.session) + logger.debug(f"Received join from {session_id}") + self._session_lookup[session_id] = req.params.token + in_progress.sessions[session_id] = ctx.session + if req.params.meta is not None: + progress_token = req.params.meta.progressToken + else: + progress_token = None + in_progress.session_progress[session_id] = progress_token return types.CallToolAsyncResult(token=req.params.token, accepted=True) else: # TODO consider sending error via get result @@ -196,25 +209,22 @@ async def notification_hook( logger.debug("Discarding progress notification from unknown session") else: in_progress = self._in_progress.get(async_token) - if in_progress is None: - # this should not happen - logger.error("Discarding progress notification, not async") - else: - for session_id, other_session in in_progress.sessions.items(): - logger.debug(f"Checking {session_id} == {id(session)}") - if not session_id == id(session): - logger.debug(f"Sending progress to {id(other_session)}") - progress_token = in_progress.session_progress.get(id(other_session)) - assert progress_token is not None - await other_session.send_progress_notification( - # TODO this token is incorrect 
- # it needs to be collected from original request - progress_token=progress_token, - progress=notification.root.params.progress, - total=notification.root.params.total, - message=notification.root.params.message, - resource_uri=notification.root.params.resourceUri, - ) + assert in_progress is not None + for other_id, other_session in in_progress.sessions.items(): + logger.debug(f"Checking {other_id} == {id(session)}") + if not other_id == id(session): + logger.debug(f"Sending progress to {other_id}") + progress_token = in_progress.session_progress.get(other_id) + assert progress_token is not None + await other_session.send_progress_notification( + # TODO this token is incorrect + # it needs to be collected from original request + progress_token=progress_token, + progress=notification.root.params.progress, + total=notification.root.params.total, + message=notification.root.params.message, + resource_uri=notification.root.params.resourceUri, + ) async def session_close_hook(self, session: ServerSession): logger.debug(f"Closing {id(session)}") diff --git a/tests/server/lowlevel/test_result_cache.py b/tests/server/lowlevel/test_result_cache.py index 093f211da..8ff9828be 100644 --- a/tests/server/lowlevel/test_result_cache.py +++ b/tests/server/lowlevel/test_result_cache.py @@ -1,24 +1,23 @@ +from contextlib import AsyncExitStack +from unittest.mock import AsyncMock, Mock + import pytest + from mcp import types from mcp.server.lowlevel.result_cache import ResultCache -from unittest.mock import AsyncMock, Mock, patch -from contextlib import AsyncExitStack + @pytest.mark.anyio async def test_async_call(): """Tests basic async call""" + async def test_call(call: types.CallToolRequest) -> types.ServerResult: - return types.ServerResult(types.CallToolResult( - content=[types.TextContent( - type="text", - text="test" - )] - )) - async_call = types.CallToolAsyncRequest( - method="tools/async/call", - params=types.CallToolAsyncRequestParams( - name="test" + return 
types.ServerResult( + types.CallToolResult(content=[types.TextContent(type="text", text="test")]) ) + + async_call = types.CallToolAsyncRequest( + method="tools/async/call", params=types.CallToolAsyncRequestParams(name="test") ) mock_session = AsyncMock() @@ -27,15 +26,19 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult: result_cache = ResultCache(max_size=1, max_keep_alive=1) async with AsyncExitStack() as stack: await stack.enter_async_context(result_cache) - async_call_ref = await result_cache.start_call(test_call, async_call, mock_context) + async_call_ref = await result_cache.start_call( + test_call, async_call, mock_context + ) assert async_call_ref.token is not None - result = await result_cache.get_result(types.GetToolAsyncResultRequest( - method="tools/async/get", - params=types.GetToolAsyncResultRequestParams( - token = async_call_ref.token + result = await result_cache.get_result( + types.GetToolAsyncResultRequest( + method="tools/async/get", + params=types.GetToolAsyncResultRequestParams( + token=async_call_ref.token + ), ) - )) + ) assert not result.isError assert not result.isPending @@ -43,21 +46,18 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult: assert type(result.content[0]) is types.TextContent assert result.content[0].text == "test" + @pytest.mark.anyio async def test_async_join_call_progress(): """Tests basic async call""" + async def test_call(call: types.CallToolRequest) -> types.ServerResult: - return types.ServerResult(types.CallToolResult( - content=[types.TextContent( - type="text", - text="test" - )] - )) - async_call = types.CallToolAsyncRequest( - method="tools/async/call", - params=types.CallToolAsyncRequestParams( - name="test" + return types.ServerResult( + types.CallToolResult(content=[types.TextContent(type="text", text="test")]) ) + + async_call = types.CallToolAsyncRequest( + method="tools/async/call", params=types.CallToolAsyncRequestParams(name="test") ) mock_session_1 = 
AsyncMock() @@ -73,7 +73,9 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult: result_cache = ResultCache(max_size=1, max_keep_alive=1) async with AsyncExitStack() as stack: await stack.enter_async_context(result_cache) - async_call_ref = await result_cache.start_call(test_call, async_call, mock_context_1) + async_call_ref = await result_cache.start_call( + test_call, async_call, mock_context_1 + ) assert async_call_ref.token is not None await result_cache.join_call( @@ -81,30 +83,32 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult: method="tools/async/join", params=types.JoinCallToolRequestParams( token=async_call_ref.token, - _meta = types.RequestParams.Meta( - progressToken="test" - ) - ) + _meta=types.RequestParams.Meta(progressToken="test"), + ), ), - ctx=mock_context_2 + ctx=mock_context_2, ) assert async_call_ref.token is not None await result_cache.notification_hook( - session=mock_session_1, - notification=types.ServerNotification(types.ProgressNotification( - method="notifications/progress", - params=types.ProgressNotificationParams( - progressToken="test", - progress=1 + session=mock_session_1, + notification=types.ServerNotification( + types.ProgressNotification( + method="notifications/progress", + params=types.ProgressNotificationParams( + progressToken="test", progress=1 + ), ) - ))) + ), + ) - result = await result_cache.get_result(types.GetToolAsyncResultRequest( - method="tools/async/get", - params=types.GetToolAsyncResultRequestParams( - token = async_call_ref.token + result = await result_cache.get_result( + types.GetToolAsyncResultRequest( + method="tools/async/get", + params=types.GetToolAsyncResultRequestParams( + token=async_call_ref.token + ), ) - )) + ) assert not result.isError assert not result.isPending @@ -113,9 +117,9 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult: assert result.content[0].text == "test" 
mock_context_1.send_progress_notification.assert_not_called() mock_session_2.send_progress_notification.assert_called_with( - progress_token="test", - progress=1.0, - total=None, - message=None, - resource_uri = None + progress_token="test", + progress=1.0, + total=None, + message=None, + resource_uri=None, ) diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 08feb9da4..f01688f5d 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -1,4 +1,5 @@ from collections.abc import AsyncGenerator +from logging import getLogger import anyio import pytest @@ -19,6 +20,8 @@ EmptyResult, ) +logger = getLogger(__name__) + @pytest.fixture def mcp_server() -> Server: @@ -211,15 +214,14 @@ async def get_result(client_session: ClientSession, async_token: types.AsyncToke assert type(result.content[0]) is types.TextContent assert result.content[0].text == "test" -from logging import getLogger - -logger = getLogger(__name__) @pytest.mark.anyio -@pytest.mark.skip(reason="This test does not work, there is a subtle " - "bug with event.wait, lower level test_result_cache " - "tests underlying behaviour, revisit with feedback " \ - "from someone who cah help debug") +@pytest.mark.skip( + reason="This test does not work, there is a subtle " + "bug with event.wait, lower level test_result_cache " + "tests underlying behaviour, revisit with feedback " + "from someone who cah help debug" +) async def test_request_async_join(): """Test that requests can be joined from external sessions.""" # The tool is already registered in the fixture @@ -236,7 +238,6 @@ async def test_request_async_join(): ev_client_2_progressed_1 = anyio.Event() ev_done = anyio.Event() - # Start the request in a separate task so we can cancel it def make_server() -> Server: server = Server(name="TestSessionServer") @@ -328,8 +329,7 @@ async def client_2_progress_callback( logger.info(f"client2: progress done: {progress}/{total}") async def join_request( - client_session: 
ClientSession, - async_token: types.AsyncToken + client_session: ClientSession, async_token: types.AsyncToken ): return await client_session.send_request( ClientRequest( @@ -365,6 +365,7 @@ async def get_result(client_session: ClientSession, async_token: types.AsyncToke token = None async with anyio.create_task_group() as tg: + async def client_1_submit(): async with create_connected_server_and_client_session( server @@ -430,6 +431,7 @@ async def client_2_join(): assert ev_client_1_progressed_2.is_set() assert ev_client_2_progressed_1.is_set() + @pytest.mark.anyio async def test_connection_closed(): """ From ff58c07efef5a20f0d93bc9e0f5755a6b70b38d5 Mon Sep 17 00:00:00 2001 From: David Savage Date: Tue, 3 Jun 2025 06:02:01 +0000 Subject: [PATCH 20/31] add comment to discuss uuid generation --- src/mcp/server/lowlevel/result_cache.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/mcp/server/lowlevel/result_cache.py b/src/mcp/server/lowlevel/result_cache.py index 5870704a5..4df133011 100644 --- a/src/mcp/server/lowlevel/result_cache.py +++ b/src/mcp/server/lowlevel/result_cache.py @@ -250,6 +250,13 @@ async def session_close_hook(self, session: ServerSession): async def _new_in_progress(self) -> InProgress: while True: + # this nonsense is required to protect against the + # ridiculously unlikely scenario that two v4 uuids + # are generated with the same value + # uuidv7 would fix this but it is not yet included + # in python standard library + # see https://github.com/python/cpython/issues/89083 + # for context token = str(uuid4()) if token not in self._in_progress: new_in_progress = InProgress(token) From 234ed4c3d6229e9f921a009d2d105bef098163a0 Mon Sep 17 00:00:00 2001 From: David Savage Date: Tue, 3 Jun 2025 06:50:01 +0000 Subject: [PATCH 21/31] various minor tidy up --- src/mcp/server/lowlevel/result_cache.py | 78 +++++++++++-------------- 1 file changed, 35 insertions(+), 43 deletions(-) diff --git a/src/mcp/server/lowlevel/result_cache.py 
b/src/mcp/server/lowlevel/result_cache.py index 4df133011..5cfa272bf 100644 --- a/src/mcp/server/lowlevel/result_cache.py +++ b/src/mcp/server/lowlevel/result_cache.py @@ -164,33 +164,27 @@ async def cancel(self, notification: types.CancelToolAsyncNotification) -> None: async def get_result(self, req: types.GetToolAsyncResultRequest): logger.debug("Getting result") - in_progress = self._in_progress.get(req.params.token) - logger.debug(f"Found in progress {in_progress}") + async_token = req.params.token + in_progress = self._in_progress.get(async_token) if in_progress is None: return types.CallToolResult( - content=[types.TextContent(type="text", text="Unknown progress token")], + content=[types.TextContent(type="text", text="Unknown async token")], isError=True, ) else: + logger.debug(f"Found in progress {in_progress}") if in_progress.user == user_context.get(): - if in_progress.future is None: + assert in_progress.future is not None + # TODO add timeout to get async result + try: + result = in_progress.future.result(1) + logger.debug(f"Found result {result}") + return result + except TimeoutError: return types.CallToolResult( - content=[ - types.TextContent(type="text", text="Permission denied") - ], - isError=True, + content=[], + isPending=True, ) - else: - # TODO add timeout to get async result - try: - result = in_progress.future.result(1) - logger.debug(f"Found result {result}") - return result - except TimeoutError: - return types.CallToolResult( - content=[], - isPending=True, - ) else: return types.CallToolResult( content=[types.TextContent(type="text", text="Permission denied")], @@ -200,19 +194,20 @@ async def get_result(self, req: types.GetToolAsyncResultRequest): async def notification_hook( self, session: ServerSession, notification: types.ServerNotification ): - logger.debug(f"received {notification} from {id(session)}") + session_id = id(session) + logger.debug(f"received {notification} from {session_id}") if type(notification.root) is 
types.ProgressNotification: # async with self._lock: - async_token = self._session_lookup.get(id(session)) + async_token = self._session_lookup.get(session_id) if async_token is None: # not all sessions are async so just debug logger.debug("Discarding progress notification from unknown session") else: in_progress = self._in_progress.get(async_token) - assert in_progress is not None + assert in_progress is not None, "lost in progress for {async_token}" for other_id, other_session in in_progress.sessions.items(): - logger.debug(f"Checking {other_id} == {id(session)}") - if not other_id == id(session): + logger.debug(f"Checking {other_id} == {session_id}") + if not other_id == session_id: logger.debug(f"Sending progress to {other_id}") progress_token = in_progress.session_progress.get(other_id) assert progress_token is not None @@ -227,26 +222,23 @@ async def notification_hook( ) async def session_close_hook(self, session: ServerSession): - logger.debug(f"Closing {id(session)}") - dropped = self._session_lookup.pop(id(session), None) - if dropped is None: - logger.warning(f"Discarding callback from unknown session {id(session)}") + session_id = id(session) + logger.debug(f"Closing {session_id}") + dropped = self._session_lookup.pop(session_id, None) + assert dropped is not None, f"Discarded callback, unknown session {session_id}" + + in_progress = self._in_progress.get(dropped) + if in_progress is None: + logger.warning("In progress not found") else: - in_progress = self._in_progress.get(dropped) - if in_progress is None: - logger.warning("In progress not found") - else: - found = in_progress.sessions.pop(id(session), None) - if found is None: - logger.warning("No session found") - if len(in_progress.sessions) == 0: - self._in_progress.pop(dropped, None) - logger.debug("In progress found") - if in_progress.future is None: - logger.warning("In progress future is none") - else: - logger.debug("Cancelled in progress future") - in_progress.future.cancel() + found = 
in_progress.sessions.pop(session_id, None) + if found is None: + logger.warning("No session found") + if len(in_progress.sessions) == 0: + self._in_progress.pop(dropped, None) + assert in_progress.future is not None + logger.debug("Cancelled in progress future") + in_progress.future.cancel() async def _new_in_progress(self) -> InProgress: while True: From fed9e70af079dda3e66fb8f0c36af3fb40417a1f Mon Sep 17 00:00:00 2001 From: David Savage Date: Tue, 3 Jun 2025 19:39:38 +0000 Subject: [PATCH 22/31] add test for keep_alive expiry --- src/mcp/server/lowlevel/result_cache.py | 40 ++++- tests/server/lowlevel/test_result_cache.py | 176 ++++++++++++++++++++- 2 files changed, 210 insertions(+), 6 deletions(-) diff --git a/src/mcp/server/lowlevel/result_cache.py b/src/mcp/server/lowlevel/result_cache.py index 5cfa272bf..81f559455 100644 --- a/src/mcp/server/lowlevel/result_cache.py +++ b/src/mcp/server/lowlevel/result_cache.py @@ -1,8 +1,8 @@ +import time from collections.abc import Awaitable, Callable from concurrent.futures import Future from dataclasses import dataclass, field from logging import getLogger -from time import time from types import TracebackType from typing import Any from uuid import uuid4 @@ -23,12 +23,21 @@ @dataclass class InProgress: token: str + timer: Callable[[], float] user: AuthenticatedUser | None = None future: Future[types.CallToolResult] | None = None sessions: dict[int, ServerSession] = field(default_factory=lambda: {}) session_progress: dict[int, types.ProgressToken | None] = field( default_factory=lambda: {} ) + keep_alive: int | None = None + keep_alive_start: int | None = None + + def is_expired(self): + if self.keep_alive_start is None or self.keep_alive is None: + return False + else: + return int(self.timer()) > self.keep_alive_start + self.keep_alive class ResultCache: @@ -54,11 +63,17 @@ class ResultCache: _session_lookup: dict[int, types.AsyncToken] _portal: BlockingPortal - def __init__(self, max_size: int, max_keep_alive: int): + 
def __init__( + self, + max_size: int, + max_keep_alive: int, + timer: Callable[[], float] = time.monotonic, + ): self._max_size = max_size self._max_keep_alive = max_keep_alive self._in_progress = {} self._session_lookup = {} + self._timer = timer self._portal_provider = BlockingPortalProvider() async def __aenter__(self): @@ -105,6 +120,7 @@ async def call_tool(): in_progress.user = user_context.get() session_id = id(ctx.session) in_progress.sessions[session_id] = ctx.session + in_progress.keep_alive = timeout if req.params.meta is not None: progress_token = req.params.meta.progressToken else: @@ -114,7 +130,7 @@ async def call_tool(): in_progress.future = self._portal.start_task_soon(call_tool) result = types.CallToolAsyncResult( token=in_progress.token, - recieved=round(time()), + recieved=round(self._timer()), keepAlive=timeout, accepted=True, ) @@ -176,6 +192,15 @@ async def get_result(self, req: types.GetToolAsyncResultRequest): if in_progress.user == user_context.get(): assert in_progress.future is not None # TODO add timeout to get async result + if in_progress.is_expired(): + self._portal.start_task_soon(self._expire) + return types.CallToolResult( + content=[ + types.TextContent(type="text", text="Unknown async token") + ], + isError=True, + ) + try: result = in_progress.future.result(1) logger.debug(f"Found result {result}") @@ -235,7 +260,12 @@ async def session_close_hook(self, session: ServerSession): if found is None: logger.warning("No session found") if len(in_progress.sessions) == 0: - self._in_progress.pop(dropped, None) + in_progress.keep_alive_start = int(self._timer()) + + async def _expire(self): + for in_progress in self._in_progress.values(): + if in_progress.is_expired(): + self._in_progress.pop(in_progress.token, None) assert in_progress.future is not None logger.debug("Cancelled in progress future") in_progress.future.cancel() @@ -251,6 +281,6 @@ async def _new_in_progress(self) -> InProgress: # for context token = str(uuid4()) if token 
not in self._in_progress: - new_in_progress = InProgress(token) + new_in_progress = InProgress(token, self._timer) self._in_progress[token] = new_in_progress return new_in_progress diff --git a/tests/server/lowlevel/test_result_cache.py b/tests/server/lowlevel/test_result_cache.py index 8ff9828be..3eab7215d 100644 --- a/tests/server/lowlevel/test_result_cache.py +++ b/tests/server/lowlevel/test_result_cache.py @@ -68,7 +68,6 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult: mock_context_2 = Mock() mock_context_2.session = mock_session_2 - mock_session_2.send_progress_notification.result = None result_cache = ResultCache(max_size=1, max_keep_alive=1) async with AsyncExitStack() as stack: @@ -123,3 +122,178 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult: message=None, resource_uri=None, ) + + +@pytest.mark.anyio +async def test_async_call_keep_alive(): + """Tests async call keep alive""" + + async def test_call(call: types.CallToolRequest) -> types.ServerResult: + return types.ServerResult( + types.CallToolResult(content=[types.TextContent(type="text", text="test")]) + ) + + async_call = types.CallToolAsyncRequest( + method="tools/async/call", params=types.CallToolAsyncRequestParams(name="test") + ) + + mock_session_1 = AsyncMock() + mock_context_1 = Mock() + mock_context_1.session = mock_session_1 + + mock_session_2 = AsyncMock() + mock_context_2 = Mock() + + mock_context_2.session = mock_session_2 + + result_cache = ResultCache(max_size=1, max_keep_alive=10) + async with AsyncExitStack() as stack: + await stack.enter_async_context(result_cache) + async_call_ref = await result_cache.start_call( + test_call, async_call, mock_context_1 + ) + assert async_call_ref.token is not None + + await result_cache.session_close_hook(mock_session_1) + + await result_cache.join_call( + req=types.JoinCallToolAsyncRequest( + method="tools/async/join", + params=types.JoinCallToolRequestParams( + token=async_call_ref.token, + 
_meta=types.RequestParams.Meta(progressToken="test"), + ), + ), + ctx=mock_context_2, + ) + assert async_call_ref.token is not None + await result_cache.notification_hook( + session=mock_session_1, + notification=types.ServerNotification( + types.ProgressNotification( + method="notifications/progress", + params=types.ProgressNotificationParams( + progressToken="test", progress=1 + ), + ) + ), + ) + + result = await result_cache.get_result( + types.GetToolAsyncResultRequest( + method="tools/async/get", + params=types.GetToolAsyncResultRequestParams( + token=async_call_ref.token + ), + ) + ) + + assert not result.isError, str(result) + assert not result.isPending + assert len(result.content) == 1 + assert type(result.content[0]) is types.TextContent + assert result.content[0].text == "test" + + +@pytest.mark.anyio +async def test_async_call_keep_alive_expired(): + """Tests async call keep alive expiry""" + + async def test_call(call: types.CallToolRequest) -> types.ServerResult: + return types.ServerResult( + types.CallToolResult(content=[types.TextContent(type="text", text="test")]) + ) + + async_call = types.CallToolAsyncRequest( + method="tools/async/call", params=types.CallToolAsyncRequestParams(name="test") + ) + + mock_session_1 = AsyncMock() + mock_context_1 = Mock() + mock_context_1.session = mock_session_1 + + mock_session_2 = AsyncMock() + mock_context_2 = Mock() + mock_context_2.session = mock_session_2 + + mock_session_3 = AsyncMock() + mock_context_3 = Mock() + mock_context_3.session = mock_session_3 + + time = 0.0 + + def test_timer(): + return time + + result_cache = ResultCache(max_size=1, max_keep_alive=1, timer=test_timer) + async with AsyncExitStack() as stack: + await stack.enter_async_context(result_cache) + async_call_ref = await result_cache.start_call( + test_call, async_call, mock_context_1 + ) + assert async_call_ref.token is not None + + # lose the connection + await result_cache.session_close_hook(mock_session_1) + + # reconnect before 
keep_alive_timeout + time = 0.5 + await result_cache.join_call( + req=types.JoinCallToolAsyncRequest( + method="tools/async/join", + params=types.JoinCallToolRequestParams( + token=async_call_ref.token, + _meta=types.RequestParams.Meta(progressToken="test"), + ), + ), + ctx=mock_context_2, + ) + + result = await result_cache.get_result( + types.GetToolAsyncResultRequest( + method="tools/async/get", + params=types.GetToolAsyncResultRequestParams( + token=async_call_ref.token + ), + ) + ) + + # should successfully read data + assert not result.isError, str(result) + assert len(result.content) == 1 + assert type(result.content[0]) is types.TextContent + assert result.content[0].text == "test" + + # lose connection a second time + + await result_cache.session_close_hook(mock_session_2) + + time = 2 + + # reconnect after the keep_alive_timeout + + await result_cache.join_call( + req=types.JoinCallToolAsyncRequest( + method="tools/async/join", + params=types.JoinCallToolRequestParams( + token=async_call_ref.token, + _meta=types.RequestParams.Meta(progressToken="test"), + ), + ), + ctx=mock_context_3, + ) + + result = await result_cache.get_result( + types.GetToolAsyncResultRequest( + method="tools/async/get", + params=types.GetToolAsyncResultRequestParams( + token=async_call_ref.token + ), + ) + ) + + # now token should be expired + assert result.isError, str(result) + assert len(result.content) == 1 + assert type(result.content[0]) is types.TextContent + assert result.content[0].text == "Unknown async token" From f8d9e04fee63d3dfeac281cd739187c8de427e23 Mon Sep 17 00:00:00 2001 From: David Savage Date: Wed, 4 Jun 2025 06:06:22 +0000 Subject: [PATCH 23/31] fix assertion order and tidyup code for readability --- src/mcp/server/lowlevel/result_cache.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/src/mcp/server/lowlevel/result_cache.py b/src/mcp/server/lowlevel/result_cache.py index 81f559455..73b59c96e 100644 --- 
a/src/mcp/server/lowlevel/result_cache.py +++ b/src/mcp/server/lowlevel/result_cache.py @@ -46,8 +46,6 @@ class ResultCache: Its purpose is to act as a central point for managing in progress async calls, allowing multiple clients to join and receive progress updates, get results and/or cancel in progress calls - TODO CRITICAL keep_alive logic is not correct as per spec - results currently - only kept for as long as longest session reintroduce TTL cache TODO MAJOR needs a lot more testing around edge cases/failure scenarios TODO MAJOR decide if async.Locks are required for integrity of internal data structures @@ -130,7 +128,6 @@ async def call_tool(): in_progress.future = self._portal.start_task_soon(call_tool) result = types.CallToolAsyncResult( token=in_progress.token, - recieved=round(self._timer()), keepAlive=timeout, accepted=True, ) @@ -248,19 +245,20 @@ async def notification_hook( async def session_close_hook(self, session: ServerSession): session_id = id(session) - logger.debug(f"Closing {session_id}") + logger.debug(f"Received session close for {session_id}") dropped = self._session_lookup.pop(session_id, None) - assert dropped is not None, f"Discarded callback, unknown session {session_id}" + if dropped is None: + # lots of sessions will have no async tasks debug and return + logger.debug(f"Discarded callback, unknown session {session_id}") + return in_progress = self._in_progress.get(dropped) - if in_progress is None: - logger.warning("In progress not found") - else: - found = in_progress.sessions.pop(session_id, None) - if found is None: - logger.warning("No session found") - if len(in_progress.sessions) == 0: - in_progress.keep_alive_start = int(self._timer()) + assert in_progress is not None, "In progress not found" + found = in_progress.sessions.pop(session_id, None) + if found is None: + logger.warning("No session found") + if len(in_progress.sessions) == 0: + in_progress.keep_alive_start = int(self._timer()) async def _expire(self): for 
in_progress in self._in_progress.values(): From f9262359324671521c22a158754972f85650e073 Mon Sep 17 00:00:00 2001 From: David Savage Date: Wed, 4 Jun 2025 06:07:24 +0000 Subject: [PATCH 24/31] remove received, serves no purpose in spec, update protocol merge request too --- src/mcp/types.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/mcp/types.py b/src/mcp/types.py index d5370f035..a2416bfcc 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -893,7 +893,6 @@ class CallToolAsyncResult(Result): """The servers response to an async tool call""" token: AsyncToken | None = None - recieved: int | None = None keepAlive: int | None = None accepted: bool From f667b631f6b7eef8c4ed1f0112462ef47a080908 Mon Sep 17 00:00:00 2001 From: David Savage Date: Wed, 4 Jun 2025 06:22:40 +0000 Subject: [PATCH 25/31] Add tests for cancellation behaviour and implement --- src/mcp/server/lowlevel/result_cache.py | 10 ++- tests/server/lowlevel/test_result_cache.py | 79 ++++++++++++++++++++++ 2 files changed, 87 insertions(+), 2 deletions(-) diff --git a/src/mcp/server/lowlevel/result_cache.py b/src/mcp/server/lowlevel/result_cache.py index 73b59c96e..bc8ba19d0 100644 --- a/src/mcp/server/lowlevel/result_cache.py +++ b/src/mcp/server/lowlevel/result_cache.py @@ -1,6 +1,6 @@ import time from collections.abc import Awaitable, Callable -from concurrent.futures import Future +from concurrent.futures import CancelledError, Future from dataclasses import dataclass, field from logging import getLogger from types import TracebackType @@ -168,7 +168,8 @@ async def cancel(self, notification: types.CancelToolAsyncNotification) -> None: if in_progress is not None: if in_progress.user == user_context.get(): # in_progress.task_group.cancel_scope.cancel() - del self._in_progress[notification.params.token] + assert in_progress.future is not None, "In progress future not found" + in_progress.future.cancel() else: logger.warning( "Permission denied for cancel notification received" @@ -202,6 
+203,11 @@ async def get_result(self, req: types.GetToolAsyncResultRequest): result = in_progress.future.result(1) logger.debug(f"Found result {result}") return result + except CancelledError: + return types.CallToolResult( + content=[types.TextContent(type="text", text="cancelled")], + isError=True, + ) except TimeoutError: return types.CallToolResult( content=[], diff --git a/tests/server/lowlevel/test_result_cache.py b/tests/server/lowlevel/test_result_cache.py index 3eab7215d..0d7e0f726 100644 --- a/tests/server/lowlevel/test_result_cache.py +++ b/tests/server/lowlevel/test_result_cache.py @@ -1,6 +1,7 @@ from contextlib import AsyncExitStack from unittest.mock import AsyncMock, Mock +import anyio import pytest from mcp import types @@ -124,6 +125,84 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult: ) +@pytest.mark.anyio +async def test_async_cancel_in_progress(): + """Tests basic async call""" + + async def slow_call(call: types.CallToolRequest) -> types.ServerResult: + with anyio.move_on_after(10) as scope: + await anyio.sleep(10) + + if scope.cancel_called: + return types.ServerResult( + types.CallToolResult( + content=[ + types.TextContent(type="text", text="should be discarded") + ], + isError=True, + ) + ) + else: + return types.ServerResult( + types.CallToolResult( + content=[types.TextContent(type="text", text="test")] + ) + ) + + async_call = types.CallToolAsyncRequest( + method="tools/async/call", params=types.CallToolAsyncRequestParams(name="test") + ) + + mock_session_1 = AsyncMock() + mock_context_1 = Mock() + mock_context_1.session = mock_session_1 + + result_cache = ResultCache(max_size=1, max_keep_alive=1) + async with AsyncExitStack() as stack: + await stack.enter_async_context(result_cache) + async_call_ref = await result_cache.start_call( + slow_call, async_call, mock_context_1 + ) + assert async_call_ref.token is not None + + await result_cache.cancel( + notification=types.CancelToolAsyncNotification( + 
method="tools/async/cancel", + params=types.CancelToolAsyncNotificationParams( + token=async_call_ref.token + ), + ), + ) + + assert async_call_ref.token is not None + await result_cache.notification_hook( + session=mock_session_1, + notification=types.ServerNotification( + types.ProgressNotification( + method="notifications/progress", + params=types.ProgressNotificationParams( + progressToken="test", progress=1 + ), + ) + ), + ) + + result = await result_cache.get_result( + types.GetToolAsyncResultRequest( + method="tools/async/get", + params=types.GetToolAsyncResultRequestParams( + token=async_call_ref.token + ), + ) + ) + + assert result.isError + assert not result.isPending + assert len(result.content) == 1 + assert type(result.content[0]) is types.TextContent + assert result.content[0].text == "cancelled" + + @pytest.mark.anyio async def test_async_call_keep_alive(): """Tests async call keep alive""" From 6ded316645d86ba5ce818a7261f5239ca6e1d970 Mon Sep 17 00:00:00 2001 From: David Savage Date: Wed, 4 Jun 2025 06:23:37 +0000 Subject: [PATCH 26/31] add todo note --- src/mcp/server/lowlevel/result_cache.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/mcp/server/lowlevel/result_cache.py b/src/mcp/server/lowlevel/result_cache.py index bc8ba19d0..8c9218aa3 100644 --- a/src/mcp/server/lowlevel/result_cache.py +++ b/src/mcp/server/lowlevel/result_cache.py @@ -207,6 +207,7 @@ async def get_result(self, req: types.GetToolAsyncResultRequest): return types.CallToolResult( content=[types.TextContent(type="text", text="cancelled")], isError=True, + # TODO add isCancelled state to protocol? 
) except TimeoutError: return types.CallToolResult( From 36c80a13050e5dd00e6db8853b1dcbf3d017112e Mon Sep 17 00:00:00 2001 From: David Savage Date: Wed, 4 Jun 2025 06:25:44 +0000 Subject: [PATCH 27/31] updated doc comment --- tests/server/lowlevel/test_result_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/server/lowlevel/test_result_cache.py b/tests/server/lowlevel/test_result_cache.py index 0d7e0f726..ef55a85c9 100644 --- a/tests/server/lowlevel/test_result_cache.py +++ b/tests/server/lowlevel/test_result_cache.py @@ -127,7 +127,7 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult: @pytest.mark.anyio async def test_async_cancel_in_progress(): - """Tests basic async call""" + """Tests cancelling an in progress async call""" async def slow_call(call: types.CallToolRequest) -> types.ServerResult: with anyio.move_on_after(10) as scope: From ef7944c5cf7d157465797bc5af3bfb855397a572 Mon Sep 17 00:00:00 2001 From: David Savage Date: Wed, 4 Jun 2025 07:32:54 +0000 Subject: [PATCH 28/31] add TODO on user auth context propagation --- src/mcp/server/lowlevel/result_cache.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/mcp/server/lowlevel/result_cache.py b/src/mcp/server/lowlevel/result_cache.py index 8c9218aa3..4290ac181 100644 --- a/src/mcp/server/lowlevel/result_cache.py +++ b/src/mcp/server/lowlevel/result_cache.py @@ -46,6 +46,8 @@ class ResultCache: Its purpose is to act as a central point for managing in progress async calls, allowing multiple clients to join and receive progress updates, get results and/or cancel in progress calls + TODO CRITICAL not obvious user context will be passed to background thread + add tests to assert behaviour with authenticated calls TODO MAJOR needs a lot more testing around edge cases/failure scenarios TODO MAJOR decide if async.Locks are required for integrity of internal data structures From 0e2437ddbf0fd2de998e8520a2824f723079f579 Mon Sep 17 00:00:00 2001 From: David 
Savage Date: Wed, 4 Jun 2025 07:56:16 +0000 Subject: [PATCH 29/31] Add initial test for auth context propagation in async context --- tests/server/lowlevel/test_result_cache.py | 62 +++++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/tests/server/lowlevel/test_result_cache.py b/tests/server/lowlevel/test_result_cache.py index ef55a85c9..455d83e70 100644 --- a/tests/server/lowlevel/test_result_cache.py +++ b/tests/server/lowlevel/test_result_cache.py @@ -1,10 +1,13 @@ from contextlib import AsyncExitStack -from unittest.mock import AsyncMock, Mock +from unittest.mock import AsyncMock, Mock, PropertyMock import anyio import pytest from mcp import types +from mcp.server.auth.middleware.auth_context import ( + auth_context_var as user_context, +) from mcp.server.lowlevel.result_cache import ResultCache @@ -376,3 +379,60 @@ def test_timer(): assert len(result.content) == 1 assert type(result.content[0]) is types.TextContent assert result.content[0].text == "Unknown async token" + + +@pytest.mark.anyio +async def test_async_call_pass_auth(): + """Tests async calls pass auth context to background thread""" + + mock_user = Mock() + type(mock_user).username = PropertyMock(return_value="mock_user") + + mock_session = AsyncMock() + mock_context = Mock() + mock_context.session = mock_session + result_cache = ResultCache(max_size=1, max_keep_alive=1) + + async def test_call(call: types.CallToolRequest) -> types.ServerResult: + user = user_context.get() + if user is None: + return types.ServerResult( + types.CallToolResult( + content=[types.TextContent(type="text", text="unauthorised")], + isError=True, + ) + ) + else: + return types.ServerResult( + types.CallToolResult( + content=[types.TextContent(type="text", text=str(user.username))] + ) + ) + + async_call = types.CallToolAsyncRequest( + method="tools/async/call", params=types.CallToolAsyncRequestParams(name="test") + ) + + async with AsyncExitStack() as stack: + await 
remove cachetools, didn't turn out to be that useful in this context, avoid unnecessary dependencies
size = 30160, upload-time = "2025-05-23T20:01:13.076Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/6a/c3/8bb087c903c95a570015ce84e0c23ae1d79f528c349cbc141b5c4e250293/cachetools-6.0.0-py3-none-any.whl", hash = "sha256:82e73ba88f7b30228b5507dce1a1f878498fc669d972aef2dde4f3a3c24f103e", size = 10964, upload-time = "2025-05-23T20:01:11.323Z" }, -] - [[package]] name = "cairocffi" version = "1.7.1" @@ -539,7 +530,6 @@ name = "mcp" source = { editable = "." } dependencies = [ { name = "anyio" }, - { name = "cachetools" }, { name = "httpx" }, { name = "httpx-sse" }, { name = "pydantic" }, @@ -584,7 +574,6 @@ docs = [ [package.metadata] requires-dist = [ { name = "anyio", specifier = ">=4.5" }, - { name = "cachetools", specifier = "==6" }, { name = "httpx", specifier = ">=0.27" }, { name = "httpx-sse", specifier = ">=0.4" }, { name = "pydantic", specifier = ">=2.7.2,<3.0.0" }, From d0e463e08d5faa4d992bc35b0d4fd29a58297c92 Mon Sep 17 00:00:00 2001 From: David Savage Date: Wed, 4 Jun 2025 18:44:07 +0000 Subject: [PATCH 31/31] rename result_cache class and allow other implementations for enterprise scenarios where memory cache may not be sufficient --- ...sult_cache.py => async_request_manager.py} | 41 +++++++++++++++---- src/mcp/server/lowlevel/server.py | 26 +++++++----- ...cache.py => test_async_request_manager.py} | 16 ++++---- 3 files changed, 58 insertions(+), 25 deletions(-) rename src/mcp/server/lowlevel/{result_cache.py => async_request_manager.py} (90%) rename tests/server/lowlevel/{test_result_cache.py => test_async_request_manager.py} (95%) diff --git a/src/mcp/server/lowlevel/result_cache.py b/src/mcp/server/lowlevel/async_request_manager.py similarity index 90% rename from src/mcp/server/lowlevel/result_cache.py rename to src/mcp/server/lowlevel/async_request_manager.py index 4290ac181..b550d4e05 100644 --- a/src/mcp/server/lowlevel/result_cache.py +++ b/src/mcp/server/lowlevel/async_request_manager.py @@ -20,6 +20,36 @@ logger = 
getLogger(__name__) +class AsyncRequestManager: + async def __aenter__(self): ... + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: ... + async def start_call( + self, + call: Callable[[types.CallToolRequest], Awaitable[types.ServerResult]], + req: types.CallToolAsyncRequest, + ctx: RequestContext[ServerSession, Any, Any], + ) -> types.CallToolAsyncResult: ... + async def join_call( + self, + req: types.JoinCallToolAsyncRequest, + ctx: RequestContext[ServerSession, Any, Any], + ) -> types.CallToolAsyncResult: ... + async def cancel(self, notification: types.CancelToolAsyncNotification) -> None: ... + async def get_result( + self, req: types.GetToolAsyncResultRequest + ) -> types.CallToolResult: ... + + async def notification_hook( + self, session: ServerSession, notification: types.ServerNotification + ) -> None: ... + async def session_close_hook(self, session: ServerSession): ... + + @dataclass class InProgress: token: str @@ -40,23 +70,18 @@ def is_expired(self): return int(self.timer()) > self.keep_alive_start + self.keep_alive -class ResultCache: +class SimpleInMemoryAsyncRequestManager(AsyncRequestManager): """ Note this class is a work in progress Its purpose is to act as a central point for managing in progress async calls, allowing multiple clients to join and receive progress updates, get results and/or cancel in progress calls - TODO CRITICAL not obvious user context will be passed to background thread - add tests to assert behaviour with authenticated calls TODO MAJOR needs a lot more testing around edge cases/failure scenarios TODO MAJOR decide if async.Locks are required for integrity of internal data structures - TODO ENHANCEMENT externalise cachetools to allow for other implementations - e.g. 
redis etal for production scenarios TODO ENHANCEMENT may need to add an authorisation layer to decide if a user is allowed to get/join/cancel an existing async call current simple logic only allows same user to perform these tasks - TODO TRIVIAL name is probably not quite right, more of a result broker? """ _in_progress: dict[types.AsyncToken, InProgress] @@ -178,7 +203,9 @@ async def cancel(self, notification: types.CancelToolAsyncNotification) -> None: f"from {user_context.get()}" ) - async def get_result(self, req: types.GetToolAsyncResultRequest): + async def get_result( + self, req: types.GetToolAsyncResultRequest + ) -> types.CallToolResult: logger.debug("Getting result") async_token = req.params.token in_progress = self._in_progress.get(async_token) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 6701f85de..89a218b9c 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -80,8 +80,11 @@ async def main(): from typing_extensions import TypeVar import mcp.types as types +from mcp.server.lowlevel.async_request_manager import ( + AsyncRequestManager, + SimpleInMemoryAsyncRequestManager, +) from mcp.server.lowlevel.helper_types import ReadResourceContents -from mcp.server.lowlevel.result_cache import ResultCache from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.server.stdio import stdio_server as stdio_server @@ -136,8 +139,9 @@ def __init__( [Server[LifespanResultT, RequestT]], AbstractAsyncContextManager[LifespanResultT], ] = lifespan, - max_cache_size: int = 1000, - max_cache_ttl: int = 60, + async_request_manager: AsyncRequestManager = SimpleInMemoryAsyncRequestManager( + max_size=1000, max_keep_alive=60 + ), ): self.name = name self.version = version @@ -148,7 +152,7 @@ def __init__( ] = { types.PingRequest: _ping_handler, } - self.result_cache = ResultCache(max_cache_size, max_cache_ttl) + self.async_request_manager = 
async_request_manager self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} self.notification_options = NotificationOptions() logger.debug("Initializing server %r", name) @@ -432,19 +436,19 @@ async def handler(req: types.CallToolRequest): async def async_call_handler(req: types.CallToolAsyncRequest): ctx = request_ctx.get() - result = await self.result_cache.start_call(handler, req, ctx) + result = await self.async_request_manager.start_call(handler, req, ctx) return types.ServerResult(result) async def async_join_handler(req: types.JoinCallToolAsyncRequest): ctx = request_ctx.get() - result = await self.result_cache.join_call(req, ctx) + result = await self.async_request_manager.join_call(req, ctx) return types.ServerResult(result) async def async_cancel_handler(req: types.CancelToolAsyncNotification): - await self.result_cache.cancel(req) + await self.async_request_manager.cancel(req) async def async_result_handler(req: types.GetToolAsyncResultRequest): - result = await self.result_cache.get_result(req) + result = await self.async_request_manager.get_result(req) return types.ServerResult(result) self.request_handlers[types.CallToolRequest] = handler @@ -534,11 +538,11 @@ async def run( write_stream, initialization_options, stateless=stateless, - notification_hook=self.result_cache.notification_hook, - session_close_hook=self.result_cache.session_close_hook, + notification_hook=self.async_request_manager.notification_hook, + session_close_hook=self.async_request_manager.session_close_hook, ) ) - await stack.enter_async_context(self.result_cache) + await stack.enter_async_context(self.async_request_manager) async with anyio.create_task_group() as tg: async for message in session.incoming_messages: diff --git a/tests/server/lowlevel/test_result_cache.py b/tests/server/lowlevel/test_async_request_manager.py similarity index 95% rename from tests/server/lowlevel/test_result_cache.py rename to tests/server/lowlevel/test_async_request_manager.py 
index 455d83e70..5b92cb075 100644 --- a/tests/server/lowlevel/test_result_cache.py +++ b/tests/server/lowlevel/test_async_request_manager.py @@ -8,7 +8,7 @@ from mcp.server.auth.middleware.auth_context import ( auth_context_var as user_context, ) -from mcp.server.lowlevel.result_cache import ResultCache +from mcp.server.lowlevel.async_request_manager import SimpleInMemoryAsyncRequestManager @pytest.mark.anyio @@ -27,7 +27,7 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult: mock_session = AsyncMock() mock_context = Mock() mock_context.session = mock_session - result_cache = ResultCache(max_size=1, max_keep_alive=1) + result_cache = SimpleInMemoryAsyncRequestManager(max_size=1, max_keep_alive=1) async with AsyncExitStack() as stack: await stack.enter_async_context(result_cache) async_call_ref = await result_cache.start_call( @@ -73,7 +73,7 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult: mock_context_2.session = mock_session_2 - result_cache = ResultCache(max_size=1, max_keep_alive=1) + result_cache = SimpleInMemoryAsyncRequestManager(max_size=1, max_keep_alive=1) async with AsyncExitStack() as stack: await stack.enter_async_context(result_cache) async_call_ref = await result_cache.start_call( @@ -160,7 +160,7 @@ async def slow_call(call: types.CallToolRequest) -> types.ServerResult: mock_context_1 = Mock() mock_context_1.session = mock_session_1 - result_cache = ResultCache(max_size=1, max_keep_alive=1) + result_cache = SimpleInMemoryAsyncRequestManager(max_size=1, max_keep_alive=1) async with AsyncExitStack() as stack: await stack.enter_async_context(result_cache) async_call_ref = await result_cache.start_call( @@ -228,7 +228,7 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult: mock_context_2.session = mock_session_2 - result_cache = ResultCache(max_size=1, max_keep_alive=10) + result_cache = SimpleInMemoryAsyncRequestManager(max_size=1, max_keep_alive=10) async with AsyncExitStack() as 
stack: await stack.enter_async_context(result_cache) async_call_ref = await result_cache.start_call( @@ -307,7 +307,9 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult: def test_timer(): return time - result_cache = ResultCache(max_size=1, max_keep_alive=1, timer=test_timer) + result_cache = SimpleInMemoryAsyncRequestManager( + max_size=1, max_keep_alive=1, timer=test_timer + ) async with AsyncExitStack() as stack: await stack.enter_async_context(result_cache) async_call_ref = await result_cache.start_call( @@ -391,7 +393,7 @@ async def test_async_call_pass_auth(): mock_session = AsyncMock() mock_context = Mock() mock_context.session = mock_session - result_cache = ResultCache(max_size=1, max_keep_alive=1) + result_cache = SimpleInMemoryAsyncRequestManager(max_size=1, max_keep_alive=1) async def test_call(call: types.CallToolRequest) -> types.ServerResult: user = user_context.get()