diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index fe90716e2..ce124bd1b 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,14 +1,20 @@ from datetime import timedelta -from typing import Any, Protocol +from typing import Annotated, Any, Protocol import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import AnyUrl, TypeAdapter +from pydantic import TypeAdapter +from pydantic.networks import AnyUrl, UrlConstraints import mcp.types as types from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage -from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder +from mcp.shared.session import ( + BaseSession, + ProgressFnT, + RequestResponder, + ResourceProgressFnT, +) from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") @@ -173,6 +179,9 @@ async def send_progress_notification( progress: float, total: float | None = None, message: str | None = None, + # TODO check whether MCP spec allows clients to create resources + # for server and therefore whether resource notifications + # would be required here too ) -> None: """Send a progress notification.""" await self.send_notification( @@ -202,7 +211,10 @@ async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResul ) async def list_resources( - self, cursor: str | None = None + self, + cursor: str | None = None, + # TODO suggest in progress resources should be excluded by default? + # possibly add an optional flag to include? ) -> types.ListResourcesResult: """Send a resources/list request.""" return await self.send_request( @@ -233,7 +245,9 @@ async def list_resource_templates( types.ListResourceTemplatesResult, ) - async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: + async def read_resource( + self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + ) -> types.ReadResourceResult: """Send a resources/read request.""" return await self.send_request( types.ClientRequest( @@ -245,7 +259,9 @@ async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: types.ReadResourceResult, ) - async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + async def subscribe_resource( + self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + ) -> types.EmptyResult: """Send a resources/subscribe request.""" return await self.send_request( types.ClientRequest( @@ -257,7 +273,9 @@ async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: types.EmptyResult, ) - async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + async def unsubscribe_resource( + self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + ) -> types.EmptyResult: """Send a resources/unsubscribe request.""" return await self.send_request( types.ClientRequest( @@ -274,7 +292,7 @@ async def call_tool( name: str, arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, - progress_callback: ProgressFnT | None = None, + progress_callback: ProgressFnT | ResourceProgressFnT | None = None, ) -> types.CallToolResult: """Send a tools/call request with optional progress callback support.""" diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index a430533b3..a77dc7a1e 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -154,7 +154,6 @@ async def __aexit__( for exit_stack in self._session_exit_stacks.values(): tg.start_soon(exit_stack.aclose) - @property def sessions(self) -> list[mcp.ClientSession]: """Returns the list of sessions being managed.""" diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 3282baae6..5489f6952 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -10,12 +10,12 @@ asynccontextmanager, ) from itertools import chain -from typing import Any, Generic, Literal +from typing import Annotated, Any, Generic, Literal import anyio import pydantic_core from pydantic import BaseModel, Field -from pydantic.networks import AnyUrl +from pydantic.networks import AnyUrl, UrlConstraints from pydantic_settings import BaseSettings, SettingsConfigDict from starlette.applications import Starlette from starlette.middleware import Middleware @@ -956,7 +956,12 @@ def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]: return self._request_context async def report_progress( - self, progress: float, total: float | None = None, message: str | None = None + self, + progress: float, + total: float | None = None, + message: str | None = None, + resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + | None = None, ) -> None: """Report progress for the current operation. @@ -979,6 +984,7 @@ async def report_progress( progress=progress, total=total, message=message, + resource_uri=resource_uri, ) async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]: diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index ef5c5a3c3..147bf0326 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -38,12 +38,12 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: """ from enum import Enum -from typing import Any, TypeVar +from typing import Annotated, Any, TypeVar import anyio import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import AnyUrl +from pydantic.networks import AnyUrl, UrlConstraints import mcp.types as types from mcp.server.models import InitializationOptions @@ -288,6 +288,8 @@ async def send_progress_notification( total: float | None = None, message: str | None = None, related_request_id: str | None = None, + resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + | None = None, ) -> None: """Send a progress notification.""" await self.send_notification( @@ -299,6 +301,7 @@ async def send_progress_notification( progress=progress, total=total, message=message, + resource_uri=resource_uri, ), ) ), diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 90b4eb27c..f31657c14 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,14 +1,16 @@ +import inspect import logging from collections.abc import Callable from contextlib import AsyncExitStack from datetime import timedelta from types import TracebackType -from typing import Any, Generic, Protocol, TypeVar +from typing import Annotated, Any, Generic, Protocol, TypeVar, runtime_checkable import anyio import httpx from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import BaseModel +from pydantic.networks import AnyUrl, UrlConstraints from typing_extensions import Self from mcp.shared.exceptions import McpError @@ -43,6 +45,7 @@ RequestId = str | int +@runtime_checkable class ProgressFnT(Protocol): """Protocol for progress notification callbacks.""" @@ -51,6 +54,20 @@ async def __call__( ) -> None: ... +@runtime_checkable +class ResourceProgressFnT(Protocol): + """Protocol for progress notification callbacks with resources.""" + + async def __call__( + self, + progress: float, + total: float | None, + message: str | None, + resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + | None = None, + ) -> None: ... + + class RequestResponder(Generic[ReceiveRequestT, SendResultT]): """Handles responding to MCP requests and manages request lifecycle. @@ -179,6 +196,7 @@ class BaseSession( _request_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] _progress_callbacks: dict[RequestId, ProgressFnT] + _resource_callbacks: dict[RequestId, ResourceProgressFnT] def __init__( self, @@ -198,6 +216,7 @@ def __init__( self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} self._progress_callbacks = {} + self._resource_callbacks = {} self._exit_stack = AsyncExitStack() async def __aenter__(self) -> Self: @@ -225,7 +244,7 @@ async def send_request( result_type: type[ReceiveResultT], request_read_timeout_seconds: timedelta | None = None, metadata: MessageMetadata = None, - progress_callback: ProgressFnT | None = None, + progress_callback: ProgressFnT | ResourceProgressFnT | None = None, ) -> ReceiveResultT: """ Sends a request and wait for a response. Raises an McpError if the @@ -252,8 +271,15 @@ async def send_request( if "_meta" not in request_data["params"]: request_data["params"]["_meta"] = {} request_data["params"]["_meta"]["progressToken"] = request_id - # Store the callback for this request - self._progress_callbacks[request_id] = progress_callback + # note this is required to ensure backwards compatibility + # for previous clients + signature = inspect.signature(progress_callback.__call__) + if len(signature.parameters) == 3: + # Store the callback for this request + self._resource_callbacks[request_id] = progress_callback # type: ignore + else: + # Store the callback for this request + self._progress_callbacks[request_id] = progress_callback try: jsonrpc_request = JSONRPCRequest( @@ -397,6 +423,15 @@ async def _receive_loop(self) -> None: notification.root.params.total, notification.root.params.message, ) + elif progress_token in self._resource_callbacks: + callback = self._resource_callbacks[progress_token] + await callback( + notification.root.params.progress, + notification.root.params.total, + notification.root.params.message, + notification.root.params.resource_uri, + ) + await self._received_notification(notification) await self._handle_incoming(notification) except Exception as e: diff --git a/src/mcp/types.py b/src/mcp/types.py index 465fc6ee6..c7a6dfff4 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -346,12 +346,19 @@ class ProgressNotificationParams(NotificationParams): total is unknown. """ total: float | None = None + """Total number of items to process (or total progress required), if known.""" + message: str | None = None """ Message related to progress. This should provide relevant human readable progress information. """ - message: str | None = None - """Total number of items to process (or total progress required), if known.""" + resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] | None = None + """ + An optional reference to an ephemeral resource associated with this + progress, servers may delete these at their descretion, but are encouraged + to make them available for a reasonable time period to allow clients to + retrieve and cache the resources locally + """ model_config = ConfigDict(extra="allow") diff --git a/tests/issues/test_176_progress_token.py b/tests/issues/test_176_progress_token.py index 4ad22f294..6ba4770aa 100644 --- a/tests/issues/test_176_progress_token.py +++ b/tests/issues/test_176_progress_token.py @@ -39,11 +39,11 @@ async def test_progress_token_zero_first_call(): mock_session.send_progress_notification.call_count == 3 ), "All progress notifications should be sent" mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=0.0, total=10.0, message=None + progress_token=0, progress=0.0, total=10.0, message=None, resource_uri=None ) mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=5.0, total=10.0, message=None + progress_token=0, progress=5.0, total=10.0, message=None, resource_uri=None ) mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=10.0, total=10.0, message=None + progress_token=0, progress=10.0, total=10.0, message=None, resource_uri=None )