Skip to content

Feature/resource progress #800

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
from datetime import timedelta
from typing import Any, Protocol
from typing import Annotated, Any, Protocol

import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl, TypeAdapter
from pydantic import TypeAdapter
from pydantic.networks import AnyUrl, UrlConstraints

import mcp.types as types
from mcp.shared.context import RequestContext
from mcp.shared.message import SessionMessage
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
from mcp.shared.session import (
BaseSession,
ProgressFnT,
RequestResponder,
ResourceProgressFnT,
)
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS

DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
Expand Down Expand Up @@ -173,6 +179,7 @@ async def send_progress_notification(
progress: float,
total: float | None = None,
message: str | None = None,
# TODO decide whether clients can send resource progress too?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hey i'm working on resources stuff right now so i hope you don't mind if I'm commenting as you work, what is the notion of a client resource in general? i thought they were owned by the server

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the protocol indicates resources are offered by servers to clients

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to me a client progress notification would not be about a resource, but some other process.

Copy link
Author

@davemssavage davemssavage May 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep I think that makes, I think that's what I guessed too but hadn't got around to validating my hunch, I might tidy this TODO up if there really is no use case for a client resource, it was more a reminder to think about it before finishing off the patch

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well client progress notifications are within the spec, it's just they're not about resources, they're about long-running process:

The Model Context Protocol (MCP) supports optional progress tracking for long-running operations through notification messages. Either side can send progress notifications to provide updates about operation status.

Not sure what the use case is or if it's relevant to your goal, but yeah.

) -> None:
"""Send a progress notification."""
await self.send_notification(
Expand Down Expand Up @@ -202,7 +209,10 @@ async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResul
)

async def list_resources(
self, cursor: str | None = None
self,
cursor: str | None = None,
# TODO suggest in progress resources should be excluded by default?
# possibly add an optional flag to include?
) -> types.ListResourcesResult:
"""Send a resources/list request."""
return await self.send_request(
Expand Down Expand Up @@ -233,7 +243,9 @@ async def list_resource_templates(
types.ListResourceTemplatesResult,
)

async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
async def read_resource(
self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
) -> types.ReadResourceResult:
"""Send a resources/read request."""
return await self.send_request(
types.ClientRequest(
Expand All @@ -245,7 +257,9 @@ async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
types.ReadResourceResult,
)

async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
async def subscribe_resource(
self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
) -> types.EmptyResult:
"""Send a resources/subscribe request."""
return await self.send_request(
types.ClientRequest(
Expand All @@ -257,7 +271,9 @@ async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
types.EmptyResult,
)

async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
async def unsubscribe_resource(
self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
) -> types.EmptyResult:
"""Send a resources/unsubscribe request."""
return await self.send_request(
types.ClientRequest(
Expand All @@ -274,7 +290,7 @@ async def call_tool(
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
progress_callback: ProgressFnT | None = None,
progress_callback: ProgressFnT | ResourceProgressFnT | None = None,
) -> types.CallToolResult:
"""Send a tools/call request with optional progress callback support."""

Expand Down
1 change: 0 additions & 1 deletion src/mcp/client/session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ async def __aexit__(
for exit_stack in self._session_exit_stacks.values():
tg.start_soon(exit_stack.aclose)


@property
def sessions(self) -> list[mcp.ClientSession]:
"""Returns the list of sessions being managed."""
Expand Down
12 changes: 9 additions & 3 deletions src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
asynccontextmanager,
)
from itertools import chain
from typing import Any, Generic, Literal
from typing import Annotated, Any, Generic, Literal

import anyio
import pydantic_core
from pydantic import BaseModel, Field
from pydantic.networks import AnyUrl
from pydantic.networks import AnyUrl, UrlConstraints
from pydantic_settings import BaseSettings, SettingsConfigDict
from starlette.applications import Starlette
from starlette.middleware import Middleware
Expand Down Expand Up @@ -956,7 +956,12 @@ def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]:
return self._request_context

async def report_progress(
self, progress: float, total: float | None = None, message: str | None = None
self,
progress: float,
total: float | None = None,
message: str | None = None,
resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
| None = None,
) -> None:
"""Report progress for the current operation.

Expand All @@ -979,6 +984,7 @@ async def report_progress(
progress=progress,
total=total,
message=message,
resource_uri=resource_uri,
)

async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]:
Expand Down
7 changes: 5 additions & 2 deletions src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
"""

from enum import Enum
from typing import Any, TypeVar
from typing import Annotated, Any, TypeVar

import anyio
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
from pydantic.networks import AnyUrl, UrlConstraints

import mcp.types as types
from mcp.server.models import InitializationOptions
Expand Down Expand Up @@ -288,6 +288,8 @@ async def send_progress_notification(
total: float | None = None,
message: str | None = None,
related_request_id: str | None = None,
resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
| None = None,
) -> None:
"""Send a progress notification."""
await self.send_notification(
Expand All @@ -299,6 +301,7 @@ async def send_progress_notification(
progress=progress,
total=total,
message=message,
resource_uri=resource_uri,
),
)
),
Expand Down
43 changes: 39 additions & 4 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import inspect
import logging
from collections.abc import Callable
from contextlib import AsyncExitStack
from datetime import timedelta
from types import TracebackType
from typing import Any, Generic, Protocol, TypeVar
from typing import Annotated, Any, Generic, Protocol, TypeVar, runtime_checkable

import anyio
import httpx
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import BaseModel
from pydantic.networks import AnyUrl, UrlConstraints
from typing_extensions import Self

from mcp.shared.exceptions import McpError
Expand Down Expand Up @@ -43,6 +45,7 @@
RequestId = str | int


@runtime_checkable
class ProgressFnT(Protocol):
"""Protocol for progress notification callbacks."""

Expand All @@ -51,6 +54,20 @@ async def __call__(
) -> None: ...


@runtime_checkable
class ResourceProgressFnT(Protocol):
"""Protocol for progress notification callbacks with resources."""

async def __call__(
self,
progress: float,
total: float | None,
message: str | None,
resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
| None = None,
) -> None: ...


class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
"""Handles responding to MCP requests and manages request lifecycle.

Expand Down Expand Up @@ -179,6 +196,7 @@ class BaseSession(
_request_id: int
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
_progress_callbacks: dict[RequestId, ProgressFnT]
_resource_callbacks: dict[RequestId, ResourceProgressFnT]

def __init__(
self,
Expand All @@ -198,6 +216,7 @@ def __init__(
self._session_read_timeout_seconds = read_timeout_seconds
self._in_flight = {}
self._progress_callbacks = {}
self._resource_callbacks = {}
self._exit_stack = AsyncExitStack()

async def __aenter__(self) -> Self:
Expand Down Expand Up @@ -225,7 +244,7 @@ async def send_request(
result_type: type[ReceiveResultT],
request_read_timeout_seconds: timedelta | None = None,
metadata: MessageMetadata = None,
progress_callback: ProgressFnT | None = None,
progress_callback: ProgressFnT | ResourceProgressFnT | None = None,
) -> ReceiveResultT:
"""
Sends a request and wait for a response. Raises an McpError if the
Expand All @@ -252,8 +271,15 @@ async def send_request(
if "_meta" not in request_data["params"]:
request_data["params"]["_meta"] = {}
request_data["params"]["_meta"]["progressToken"] = request_id
# Store the callback for this request
self._progress_callbacks[request_id] = progress_callback
# note this is required to ensure backwards compatibility
# for previous clients
signature = inspect.signature(progress_callback.__call__)
if len(signature.parameters) == 3:
# Store the callback for this request
self._resource_callbacks[request_id] = progress_callback # type: ignore
else:
# Store the callback for this request
self._progress_callbacks[request_id] = progress_callback

try:
jsonrpc_request = JSONRPCRequest(
Expand Down Expand Up @@ -397,6 +423,15 @@ async def _receive_loop(self) -> None:
notification.root.params.total,
notification.root.params.message,
)
elif progress_token in self._resource_callbacks:
callback = self._resource_callbacks[progress_token]
await callback(
notification.root.params.progress,
notification.root.params.total,
notification.root.params.message,
notification.root.params.resource_uri,
)

await self._received_notification(notification)
await self._handle_incoming(notification)
except Exception as e:
Expand Down
11 changes: 9 additions & 2 deletions src/mcp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,12 +346,19 @@ class ProgressNotificationParams(NotificationParams):
total is unknown.
"""
total: float | None = None
"""Total number of items to process (or total progress required), if known."""
message: str | None = None
"""
Message related to progress. This should provide relevant human readable
progress information.
"""
message: str | None = None
"""Total number of items to process (or total progress required), if known."""
resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] | None = None
"""
An optional reference to an ephemeral resource associated with this
progress, servers may delete these at their descretion, but are encouraged
to make them available for a reasonable time period to allow clients to
retrieve and cache the resources locally
"""
model_config = ConfigDict(extra="allow")


Expand Down
6 changes: 3 additions & 3 deletions tests/issues/test_176_progress_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ async def test_progress_token_zero_first_call():
mock_session.send_progress_notification.call_count == 3
), "All progress notifications should be sent"
mock_session.send_progress_notification.assert_any_call(
progress_token=0, progress=0.0, total=10.0, message=None
progress_token=0, progress=0.0, total=10.0, message=None, resource_uri=None
)
mock_session.send_progress_notification.assert_any_call(
progress_token=0, progress=5.0, total=10.0, message=None
progress_token=0, progress=5.0, total=10.0, message=None, resource_uri=None
)
mock_session.send_progress_notification.assert_any_call(
progress_token=0, progress=10.0, total=10.0, message=None
progress_token=0, progress=10.0, total=10.0, message=None, resource_uri=None
)
Loading