From c3ef7deb613e9df8ab7316bd0b5ed39120e61bf5 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sat, 24 May 2025 15:35:10 +0000 Subject: [PATCH 1/4] tinkering with resource progress --- src/mcp/client/session.py | 17 +++++++----- src/mcp/server/fastmcp/server.py | 11 ++++++-- src/mcp/server/session.py | 6 ++-- src/mcp/shared/session.py | 37 +++++++++++++++++++++---- src/mcp/types.py | 10 +++++-- tests/issues/test_176_progress_token.py | 6 ++-- 6 files changed, 65 insertions(+), 22 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index fe90716e2..719dc2808 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,14 +1,15 @@ from datetime import timedelta -from typing import Any, Protocol +from typing import Annotated, Any, Protocol import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import AnyUrl, TypeAdapter +from pydantic import TypeAdapter +from pydantic.networks import AnyUrl, UrlConstraints import mcp.types as types from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage -from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder +from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder, ResourceProgressFnT from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") @@ -173,6 +174,7 @@ async def send_progress_notification( progress: float, total: float | None = None, message: str | None = None, + # TODO decide whether clients can send resource progress too? ) -> None: """Send a progress notification.""" await self.send_notification( @@ -203,6 +205,7 @@ async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResul async def list_resources( self, cursor: str | None = None + # TODO suggest in progress resources should be excluded by default? possibly add an optional flag to include? ) -> types.ListResourcesResult: """Send a resources/list request.""" return await self.send_request( @@ -233,7 +236,7 @@ async def list_resource_templates( types.ListResourceTemplatesResult, ) - async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: + async def read_resource(self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]) -> types.ReadResourceResult: """Send a resources/read request.""" return await self.send_request( types.ClientRequest( @@ -245,7 +248,7 @@ async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: types.ReadResourceResult, ) - async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + async def subscribe_resource(self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]) -> types.EmptyResult: """Send a resources/subscribe request.""" return await self.send_request( types.ClientRequest( @@ -257,7 +260,7 @@ async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: types.EmptyResult, ) - async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + async def unsubscribe_resource(self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]) -> types.EmptyResult: """Send a resources/unsubscribe request.""" return await self.send_request( types.ClientRequest( @@ -274,7 +277,7 @@ async def call_tool( name: str, arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, - progress_callback: ProgressFnT | None = None, + progress_callback: ProgressFnT | ResourceProgressFnT | None = None, ) -> types.CallToolResult: """Send a tools/call request with optional progress callback support.""" diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 3282baae6..1b4ab7428 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -10,12 +10,12 @@ asynccontextmanager, ) from itertools import chain -from typing import Any, Generic, Literal +from typing import Annotated, Any, Generic, Literal import anyio import pydantic_core from pydantic import BaseModel, Field -from pydantic.networks import AnyUrl +from pydantic.networks import AnyUrl, UrlConstraints from pydantic_settings import BaseSettings, SettingsConfigDict from starlette.applications import Starlette from starlette.middleware import Middleware @@ -956,7 +956,11 @@ def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]: return self._request_context async def report_progress( - self, progress: float, total: float | None = None, message: str | None = None + self, + progress: float, + total: float | None = None, + message: str | None = None, + resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] | None = None, ) -> None: """Report progress for the current operation. @@ -979,6 +983,7 @@ async def report_progress( progress=progress, total=total, message=message, + resource_uri=resource_uri, ) async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]: diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index ef5c5a3c3..37115a0c6 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -38,12 +38,12 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: """ from enum import Enum -from typing import Any, TypeVar +from typing import Annotated, Any, TypeVar import anyio import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import AnyUrl +from pydantic.networks import AnyUrl, UrlConstraints import mcp.types as types from mcp.server.models import InitializationOptions @@ -288,6 +288,7 @@ async def send_progress_notification( total: float | None = None, message: str | None = None, related_request_id: str | None = None, + resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] | None = None, ) -> None: """Send a progress notification.""" await self.send_notification( @@ -299,6 +300,7 @@ async def send_progress_notification( progress=progress, total=total, message=message, + resource_uri=resource_uri, ), ) ), diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 90b4eb27c..8f0d2fe2d 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -3,12 +3,14 @@ from contextlib import AsyncExitStack from datetime import timedelta from types import TracebackType -from typing import Any, Generic, Protocol, TypeVar +from typing import Annotated, Any, Generic, Protocol, TypeVar, runtime_checkable import anyio import httpx +import inspect from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import BaseModel +from pydantic.networks import AnyUrl, UrlConstraints from typing_extensions import Self from mcp.shared.exceptions import McpError @@ -43,6 +45,7 @@ RequestId = str | int +@runtime_checkable class ProgressFnT(Protocol): """Protocol for progress notification callbacks.""" @@ -50,6 +53,13 @@ async def __call__( self, progress: float, total: float | None, message: str | None ) -> None: ... +@runtime_checkable +class ResourceProgressFnT(Protocol): + """Protocol for progress notification callbacks with resources.""" + + async def __call__( + self, progress: float, total: float | None, message: str | None, resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] | None = None + ) -> None: ... class RequestResponder(Generic[ReceiveRequestT, SendResultT]): """Handles responding to MCP requests and manages request lifecycle. @@ -178,7 +188,8 @@ class BaseSession( ] _request_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] - _progress_callbacks: dict[RequestId, ProgressFnT] + _progress_callbacks: dict[RequestId, ProgressFnT ] + _resource_progress_callbacks: dict[RequestId, ResourceProgressFnT] def __init__( self, @@ -198,6 +209,7 @@ def __init__( self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} self._progress_callbacks = {} + self._resource_progress_callbacks = {} self._exit_stack = AsyncExitStack() async def __aenter__(self) -> Self: @@ -225,7 +237,7 @@ async def send_request( result_type: type[ReceiveResultT], request_read_timeout_seconds: timedelta | None = None, metadata: MessageMetadata = None, - progress_callback: ProgressFnT | None = None, + progress_callback: ProgressFnT | ResourceProgressFnT | None = None, ) -> ReceiveResultT: """ Sends a request and wait for a response. Raises an McpError if the @@ -252,8 +264,14 @@ async def send_request( if "_meta" not in request_data["params"]: request_data["params"]["_meta"] = {} request_data["params"]["_meta"]["progressToken"] = request_id - # Store the callback for this request - self._progress_callbacks[request_id] = progress_callback + # note this is required to ensure backwards compatibility for previous clients + signature = inspect.signature(progress_callback.__call__) + if 'resource_uri' in signature.parameters: + # Store the callback for this request + self._resource_progress_callbacks[request_id] = progress_callback # type: ignore + else: + # Store the callback for this request + self._progress_callbacks[request_id] = progress_callback try: jsonrpc_request = JSONRPCRequest( @@ -397,6 +415,15 @@ async def _receive_loop(self) -> None: notification.root.params.total, notification.root.params.message, ) + elif progress_token in self._resource_progress_callbacks: + callback = self._resource_progress_callbacks[progress_token] + await callback( + notification.root.params.progress, + notification.root.params.total, + notification.root.params.message, + notification.root.params.resource_uri, + ) + await self._received_notification(notification) await self._handle_incoming(notification) except Exception as e: diff --git a/src/mcp/types.py b/src/mcp/types.py index 465fc6ee6..ac83deda3 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -346,12 +346,18 @@ class ProgressNotificationParams(NotificationParams): total is unknown. """ total: float | None = None + """Total number of items to process (or total progress required), if known.""" + message: str | None = None """ Message related to progress. This should provide relevant human readable progress information. """ - message: str | None = None - """Total number of items to process (or total progress required), if known.""" + resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] | None = None + """ + An optional reference to an ephemeral resource associated with this progress, servers + may delete these at their descretion, but are encouraged to make them available for + a reasonable time period to allow clients to retrieve and cache the resources locally + """ model_config = ConfigDict(extra="allow") diff --git a/tests/issues/test_176_progress_token.py b/tests/issues/test_176_progress_token.py index 4ad22f294..6ba4770aa 100644 --- a/tests/issues/test_176_progress_token.py +++ b/tests/issues/test_176_progress_token.py @@ -39,11 +39,11 @@ async def test_progress_token_zero_first_call(): mock_session.send_progress_notification.call_count == 3 ), "All progress notifications should be sent" mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=0.0, total=10.0, message=None + progress_token=0, progress=0.0, total=10.0, message=None, resource_uri=None ) mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=5.0, total=10.0, message=None + progress_token=0, progress=5.0, total=10.0, message=None, resource_uri=None ) mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=10.0, total=10.0, message=None + progress_token=0, progress=10.0, total=10.0, message=None, resource_uri=None ) From b72bbc42dee94406a518b0c858bc4cd2d8927e9f Mon Sep 17 00:00:00 2001 From: David Savage Date: Sat, 24 May 2025 15:44:36 +0000 Subject: [PATCH 2/4] ruff check and format fixes --- src/mcp/client/session.py | 25 +++++++++++++++++++------ src/mcp/client/session_group.py | 1 - src/mcp/server/fastmcp/server.py | 3 ++- src/mcp/server/session.py | 3 ++- src/mcp/shared/session.py | 28 ++++++++++++++++++---------- src/mcp/types.py | 7 ++++--- 6 files changed, 45 insertions(+), 22 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 719dc2808..e1db957e6 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -9,7 +9,12 @@ import mcp.types as types from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage -from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder, ResourceProgressFnT +from mcp.shared.session import ( + BaseSession, + ProgressFnT, + RequestResponder, + ResourceProgressFnT, +) from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") @@ -204,8 +209,10 @@ async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResul ) async def list_resources( - self, cursor: str | None = None - # TODO suggest in progress resources should be excluded by default? possibly add an optional flag to include? + self, + cursor: str | None = None, + # TODO suggest in progress resources should be excluded by default? + # possibly add an optional flag to include? ) -> types.ListResourcesResult: """Send a resources/list request.""" return await self.send_request( @@ -236,7 +243,9 @@ async def list_resource_templates( types.ListResourceTemplatesResult, ) - async def read_resource(self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]) -> types.ReadResourceResult: + async def read_resource( + self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + ) -> types.ReadResourceResult: """Send a resources/read request.""" return await self.send_request( types.ClientRequest( @@ -248,7 +257,9 @@ async def read_resource(self, uri: Annotated[AnyUrl, UrlConstraints(host_require types.ReadResourceResult, ) - async def subscribe_resource(self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]) -> types.EmptyResult: + async def subscribe_resource( + self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + ) -> types.EmptyResult: """Send a resources/subscribe request.""" return await self.send_request( types.ClientRequest( @@ -260,7 +271,9 @@ async def subscribe_resource(self, uri: Annotated[AnyUrl, UrlConstraints(host_re types.EmptyResult, ) - async def unsubscribe_resource(self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]) -> types.EmptyResult: + async def unsubscribe_resource( + self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + ) -> types.EmptyResult: """Send a resources/unsubscribe request.""" return await self.send_request( types.ClientRequest( diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index a430533b3..a77dc7a1e 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -154,7 +154,6 @@ async def __aexit__( for exit_stack in self._session_exit_stacks.values(): tg.start_soon(exit_stack.aclose) - @property def sessions(self) -> list[mcp.ClientSession]: """Returns the list of sessions being managed.""" diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 1b4ab7428..5489f6952 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -960,7 +960,8 @@ async def report_progress( progress: float, total: float | None = None, message: str | None = None, - resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] | None = None, + resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + | None = None, ) -> None: """Report progress for the current operation. diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 37115a0c6..147bf0326 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -288,7 +288,8 @@ async def send_progress_notification( total: float | None = None, message: str | None = None, related_request_id: str | None = None, - resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] | None = None, + resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + | None = None, ) -> None: """Send a progress notification.""" await self.send_notification( diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 8f0d2fe2d..94ae3651d 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,3 +1,4 @@ +import inspect import logging from collections.abc import Callable from contextlib import AsyncExitStack @@ -7,7 +8,6 @@ import anyio import httpx -import inspect from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import BaseModel from pydantic.networks import AnyUrl, UrlConstraints @@ -53,14 +53,21 @@ async def __call__( self, progress: float, total: float | None, message: str | None ) -> None: ... + @runtime_checkable class ResourceProgressFnT(Protocol): """Protocol for progress notification callbacks with resources.""" async def __call__( - self, progress: float, total: float | None, message: str | None, resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] | None = None + self, + progress: float, + total: float | None, + message: str | None, + resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + | None = None, ) -> None: ... + class RequestResponder(Generic[ReceiveRequestT, SendResultT]): """Handles responding to MCP requests and manages request lifecycle. @@ -188,8 +195,8 @@ class BaseSession( ] _request_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] - _progress_callbacks: dict[RequestId, ProgressFnT ] - _resource_progress_callbacks: dict[RequestId, ResourceProgressFnT] + _progress_callbacks: dict[RequestId, ProgressFnT] + _resource_callbacks: dict[RequestId, ResourceProgressFnT] def __init__( self, @@ -209,7 +216,7 @@ def __init__( self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} self._progress_callbacks = {} - self._resource_progress_callbacks = {} + self._resource_callbacks = {} self._exit_stack = AsyncExitStack() async def __aenter__(self) -> Self: @@ -264,11 +271,12 @@ async def send_request( if "_meta" not in request_data["params"]: request_data["params"]["_meta"] = {} request_data["params"]["_meta"]["progressToken"] = request_id - # note this is required to ensure backwards compatibility for previous clients + # note this is required to ensure backwards compatibility + # for previous clients signature = inspect.signature(progress_callback.__call__) - if 'resource_uri' in signature.parameters: + if "resource_uri" in signature.parameters: # Store the callback for this request - self._resource_progress_callbacks[request_id] = progress_callback # type: ignore + self._resource_callbacks[request_id] = progress_callback # type: ignore else: # Store the callback for this request self._progress_callbacks[request_id] = progress_callback @@ -415,8 +423,8 @@ async def _receive_loop(self) -> None: notification.root.params.total, notification.root.params.message, ) - elif progress_token in self._resource_progress_callbacks: - callback = self._resource_progress_callbacks[progress_token] + elif progress_token in self._resource_callbacks: + callback = self._resource_callbacks[progress_token] await callback( notification.root.params.progress, notification.root.params.total, diff --git a/src/mcp/types.py b/src/mcp/types.py index ac83deda3..c7a6dfff4 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -354,9 +354,10 @@ class ProgressNotificationParams(NotificationParams): """ resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] | None = None """ - An optional reference to an ephemeral resource associated with this progress, servers - may delete these at their descretion, but are encouraged to make them available for - a reasonable time period to allow clients to retrieve and cache the resources locally + An optional reference to an ephemeral resource associated with this + progress, servers may delete these at their descretion, but are encouraged + to make them available for a reasonable time period to allow clients to + retrieve and cache the resources locally """ model_config = ConfigDict(extra="allow") From 202e92273bd70ed12e266c425165be6058b06b4a Mon Sep 17 00:00:00 2001 From: David Savage Date: Sat, 24 May 2025 17:52:16 +0000 Subject: [PATCH 3/4] use len of parameters to check progress variant, can't rely on clients naming params consistently --- src/mcp/shared/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 94ae3651d..f31657c14 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -274,7 +274,7 @@ async def send_request( # note this is required to ensure backwards compatibility # for previous clients signature = inspect.signature(progress_callback.__call__) - if "resource_uri" in signature.parameters: + if len(signature.parameters) == 3: # Store the callback for this request self._resource_callbacks[request_id] = progress_callback # type: ignore else: From d634e6a4228a175a984d08301068dba894b6feb2 Mon Sep 17 00:00:00 2001 From: David Savage Date: Sun, 25 May 2025 09:05:32 +0000 Subject: [PATCH 4/4] clarify TODO message --- src/mcp/client/session.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index e1db957e6..ce124bd1b 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -179,7 +179,9 @@ async def send_progress_notification( progress: float, total: float | None = None, message: str | None = None, - # TODO decide whether clients can send resource progress too? + # TODO check whether MCP spec allows clients to create resources + # for server and therefore whether resource notifications + # would be required here too ) -> None: """Send a progress notification.""" await self.send_notification(