diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 3b7fc3fae..8dfda8c16 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,14 +1,20 @@ from datetime import timedelta -from typing import Any, Protocol +from typing import Annotated, Any, Protocol import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import AnyUrl, TypeAdapter +from pydantic import TypeAdapter +from pydantic.networks import AnyUrl, UrlConstraints import mcp.types as types from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage -from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder +from mcp.shared.session import ( + BaseSession, + ProgressFnT, + RequestResponder, + ResourceProgressFnT, +) from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") @@ -208,7 +214,10 @@ async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResul ) async def list_resources( - self, cursor: str | None = None + self, + cursor: str | None = None, + # TODO suggest in progress resources should be excluded by default? + # possibly add an optional flag to include? ) -> types.ListResourcesResult: """Send a resources/list request.""" return await self.send_request( @@ -239,7 +248,9 @@ async def list_resource_templates( types.ListResourceTemplatesResult, ) - async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: + async def read_resource( + self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + ) -> types.ReadResourceResult: """Send a resources/read request.""" return await self.send_request( types.ClientRequest( @@ -251,7 +262,9 @@ async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: types.ReadResourceResult, ) - async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + async def subscribe_resource( + self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + ) -> types.EmptyResult: """Send a resources/subscribe request.""" return await self.send_request( types.ClientRequest( @@ -263,7 +276,9 @@ async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: types.EmptyResult, ) - async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + async def unsubscribe_resource( + self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + ) -> types.EmptyResult: """Send a resources/unsubscribe request.""" return await self.send_request( types.ClientRequest( @@ -280,7 +295,7 @@ async def call_tool( name: str, arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, - progress_callback: ProgressFnT | None = None, + progress_callback: ProgressFnT | ResourceProgressFnT | None = None, ) -> types.CallToolResult: """Send a tools/call request with optional progress callback support.""" diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index e5b6c3acc..746baf290 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -10,12 +10,12 @@ asynccontextmanager, ) from itertools import chain -from typing import Any, Generic, Literal +from typing import Annotated, Any, Generic, Literal import anyio import pydantic_core from pydantic import BaseModel, Field -from pydantic.networks import AnyUrl +from pydantic.networks import AnyUrl, UrlConstraints from pydantic_settings import BaseSettings, SettingsConfigDict from starlette.applications import Starlette from starlette.middleware import Middleware @@ -147,7 +147,7 @@ def __init__( tools: list[Tool] | None = None, **settings: Any, ): - self.settings = Settings(**settings) + self.settings = Settings(**settings) # type: ignore self._mcp_server = MCPServer( name=name or "FastMCP", @@ -962,7 +962,12 @@ def request_context( return self._request_context async def report_progress( - self, progress: float, total: float | None = None, message: str | None = None + self, + progress: float, + total: float | None = None, + message: str | None = None, + resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + | None = None, ) -> None: """Report progress for the current operation. @@ -985,6 +990,7 @@ async def report_progress( progress=progress, total=total, message=message, + resource_uri=resource_uri, ) async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]: diff --git a/src/mcp/server/lowlevel/async_request_manager.py b/src/mcp/server/lowlevel/async_request_manager.py new file mode 100644 index 000000000..b550d4e05 --- /dev/null +++ b/src/mcp/server/lowlevel/async_request_manager.py @@ -0,0 +1,320 @@ +import time +from collections.abc import Awaitable, Callable +from concurrent.futures import CancelledError, Future +from dataclasses import dataclass, field +from logging import getLogger +from types import TracebackType +from typing import Any +from uuid import uuid4 + +import anyio +import anyio.to_thread +from anyio.from_thread import BlockingPortal, BlockingPortalProvider + +from mcp import types +from mcp.server.auth.middleware.auth_context import auth_context_var as user_context +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from mcp.server.session import ServerSession +from mcp.shared.context import RequestContext + +logger = getLogger(__name__) + + +class AsyncRequestManager: + async def __aenter__(self): ... + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: ... + async def start_call( + self, + call: Callable[[types.CallToolRequest], Awaitable[types.ServerResult]], + req: types.CallToolAsyncRequest, + ctx: RequestContext[ServerSession, Any, Any], + ) -> types.CallToolAsyncResult: ... + async def join_call( + self, + req: types.JoinCallToolAsyncRequest, + ctx: RequestContext[ServerSession, Any, Any], + ) -> types.CallToolAsyncResult: ... + async def cancel(self, notification: types.CancelToolAsyncNotification) -> None: ... + async def get_result( + self, req: types.GetToolAsyncResultRequest + ) -> types.CallToolResult: ... + + async def notification_hook( + self, session: ServerSession, notification: types.ServerNotification + ) -> None: ... + async def session_close_hook(self, session: ServerSession): ... + + +@dataclass +class InProgress: + token: str + timer: Callable[[], float] + user: AuthenticatedUser | None = None + future: Future[types.CallToolResult] | None = None + sessions: dict[int, ServerSession] = field(default_factory=lambda: {}) + session_progress: dict[int, types.ProgressToken | None] = field( + default_factory=lambda: {} + ) + keep_alive: int | None = None + keep_alive_start: int | None = None + + def is_expired(self): + if self.keep_alive_start is None or self.keep_alive is None: + return False + else: + return int(self.timer()) > self.keep_alive_start + self.keep_alive + + +class SimpleInMemoryAsyncRequestManager(AsyncRequestManager): + """ + Note this class is a work in progress + Its purpose is to act as a central point for managing in progress + async calls, allowing multiple clients to join and receive progress + updates, get results and/or cancel in progress calls + TODO MAJOR needs a lot more testing around edge cases/failure scenarios + TODO MAJOR decide if async.Locks are required for integrity of internal + data structures + TODO ENHANCEMENT may need to add an authorisation layer to decide if + a user is allowed to get/join/cancel an existing async call current + simple logic only allows same user to perform these tasks + """ + + _in_progress: dict[types.AsyncToken, InProgress] + _session_lookup: dict[int, types.AsyncToken] + _portal: BlockingPortal + + def __init__( + self, + max_size: int, + max_keep_alive: int, + timer: Callable[[], float] = time.monotonic, + ): + self._max_size = max_size + self._max_keep_alive = max_keep_alive + self._in_progress = {} + self._session_lookup = {} + self._timer = timer + self._portal_provider = BlockingPortalProvider() + + async def __aenter__(self): + def create_portal(): + self._portal = self._portal_provider.__enter__() + + await anyio.to_thread.run_sync(create_portal) + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + await anyio.to_thread.run_sync(lambda: self._portal_provider.__exit__) + + async def start_call( + self, + call: Callable[[types.CallToolRequest], Awaitable[types.ServerResult]], + req: types.CallToolAsyncRequest, + ctx: RequestContext[ServerSession, Any, Any], + ) -> types.CallToolAsyncResult: + in_progress = await self._new_in_progress() + timeout = min( + req.params.keepAlive or self._max_keep_alive, self._max_keep_alive + ) + + async def call_tool(): + result = await call( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams( + name=req.params.name, + arguments=req.params.arguments, + _meta=req.params.meta, + ), + ) + ) + # async with self._lock: + assert type(result.root) is types.CallToolResult + logger.debug(f"Got result {result}") + return result.root + + in_progress.user = user_context.get() + session_id = id(ctx.session) + in_progress.sessions[session_id] = ctx.session + in_progress.keep_alive = timeout + if req.params.meta is not None: + progress_token = req.params.meta.progressToken + else: + progress_token = None + in_progress.session_progress[session_id] = progress_token + self._session_lookup[session_id] = in_progress.token + in_progress.future = self._portal.start_task_soon(call_tool) + result = types.CallToolAsyncResult( + token=in_progress.token, + keepAlive=timeout, + accepted=True, + ) + return result + + async def join_call( + self, + req: types.JoinCallToolAsyncRequest, + ctx: RequestContext[ServerSession, Any, Any], + ) -> types.CallToolAsyncResult: + # async with self._lock: + in_progress = self._in_progress.get(req.params.token) + if in_progress is None: + # TODO consider creating new token to allow client + # to get message describing why it wasn't accepted + logger.warning("Discarding join request for unknown async token") + return types.CallToolAsyncResult(accepted=False) + else: + # TODO consider adding authorisation layer to make this decision + if in_progress.user == user_context.get(): + session_id = id(ctx.session) + logger.debug(f"Received join from {session_id}") + self._session_lookup[session_id] = req.params.token + in_progress.sessions[session_id] = ctx.session + if req.params.meta is not None: + progress_token = req.params.meta.progressToken + else: + progress_token = None + in_progress.session_progress[session_id] = progress_token + return types.CallToolAsyncResult(token=req.params.token, accepted=True) + else: + # TODO consider sending error via get result + return types.CallToolAsyncResult(accepted=False) + + async def cancel(self, notification: types.CancelToolAsyncNotification) -> None: + # async with self._lock: + in_progress = self._in_progress.get(notification.params.token) + if in_progress is not None: + if in_progress.user == user_context.get(): + # in_progress.task_group.cancel_scope.cancel() + assert in_progress.future is not None, "In progress future not found" + in_progress.future.cancel() + else: + logger.warning( + "Permission denied for cancel notification received" + f"from {user_context.get()}" + ) + + async def get_result( + self, req: types.GetToolAsyncResultRequest + ) -> types.CallToolResult: + logger.debug("Getting result") + async_token = req.params.token + in_progress = self._in_progress.get(async_token) + if in_progress is None: + return types.CallToolResult( + content=[types.TextContent(type="text", text="Unknown async token")], + isError=True, + ) + else: + logger.debug(f"Found in progress {in_progress}") + if in_progress.user == user_context.get(): + assert in_progress.future is not None + # TODO add timeout to get async result + if in_progress.is_expired(): + self._portal.start_task_soon(self._expire) + return types.CallToolResult( + content=[ + types.TextContent(type="text", text="Unknown async token") + ], + isError=True, + ) + + try: + result = in_progress.future.result(1) + logger.debug(f"Found result {result}") + return result + except CancelledError: + return types.CallToolResult( + content=[types.TextContent(type="text", text="cancelled")], + isError=True, + # TODO add isCancelled state to protocol? + ) + except TimeoutError: + return types.CallToolResult( + content=[], + isPending=True, + ) + else: + return types.CallToolResult( + content=[types.TextContent(type="text", text="Permission denied")], + isError=True, + ) + + async def notification_hook( + self, session: ServerSession, notification: types.ServerNotification + ): + session_id = id(session) + logger.debug(f"received {notification} from {session_id}") + if type(notification.root) is types.ProgressNotification: + # async with self._lock: + async_token = self._session_lookup.get(session_id) + if async_token is None: + # not all sessions are async so just debug + logger.debug("Discarding progress notification from unknown session") + else: + in_progress = self._in_progress.get(async_token) + assert in_progress is not None, "lost in progress for {async_token}" + for other_id, other_session in in_progress.sessions.items(): + logger.debug(f"Checking {other_id} == {session_id}") + if not other_id == session_id: + logger.debug(f"Sending progress to {other_id}") + progress_token = in_progress.session_progress.get(other_id) + assert progress_token is not None + await other_session.send_progress_notification( + # TODO this token is incorrect + # it needs to be collected from original request + progress_token=progress_token, + progress=notification.root.params.progress, + total=notification.root.params.total, + message=notification.root.params.message, + resource_uri=notification.root.params.resourceUri, + ) + + async def session_close_hook(self, session: ServerSession): + session_id = id(session) + logger.debug(f"Received session close for {session_id}") + dropped = self._session_lookup.pop(session_id, None) + if dropped is None: + # lots of sessions will have no async tasks debug and return + logger.debug(f"Discarded callback, unknown session {session_id}") + return + + in_progress = self._in_progress.get(dropped) + assert in_progress is not None, "In progress not found" + found = in_progress.sessions.pop(session_id, None) + if found is None: + logger.warning("No session found") + if len(in_progress.sessions) == 0: + in_progress.keep_alive_start = int(self._timer()) + + async def _expire(self): + for in_progress in self._in_progress.values(): + if in_progress.is_expired(): + self._in_progress.pop(in_progress.token, None) + assert in_progress.future is not None + logger.debug("Cancelled in progress future") + in_progress.future.cancel() + + async def _new_in_progress(self) -> InProgress: + while True: + # this nonsense is required to protect against the + # ridiculously unlikely scenario that two v4 uuids + # are generated with the same value + # uuidv7 would fix this but it is not yet included + # in python standard library + # see https://github.com/python/cpython/issues/89083 + # for context + token = str(uuid4()) + if token not in self._in_progress: + new_in_progress = InProgress(token, self._timer) + self._in_progress[token] = new_in_progress + return new_in_progress diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index f6d390c2f..89a218b9c 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -80,6 +80,10 @@ async def main(): from typing_extensions import TypeVar import mcp.types as types +from mcp.server.lowlevel.async_request_manager import ( + AsyncRequestManager, + SimpleInMemoryAsyncRequestManager, +) from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession @@ -135,6 +139,9 @@ def __init__( [Server[LifespanResultT, RequestT]], AbstractAsyncContextManager[LifespanResultT], ] = lifespan, + async_request_manager: AsyncRequestManager = SimpleInMemoryAsyncRequestManager( + max_size=1000, max_keep_alive=60 + ), ): self.name = name self.version = version @@ -145,6 +152,7 @@ def __init__( ] = { types.PingRequest: _ping_handler, } + self.async_request_manager = async_request_manager self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} self.notification_options = NotificationOptions() logger.debug("Initializing server %r", name) @@ -426,7 +434,32 @@ async def handler(req: types.CallToolRequest): ) ) + async def async_call_handler(req: types.CallToolAsyncRequest): + ctx = request_ctx.get() + result = await self.async_request_manager.start_call(handler, req, ctx) + return types.ServerResult(result) + + async def async_join_handler(req: types.JoinCallToolAsyncRequest): + ctx = request_ctx.get() + result = await self.async_request_manager.join_call(req, ctx) + return types.ServerResult(result) + + async def async_cancel_handler(req: types.CancelToolAsyncNotification): + await self.async_request_manager.cancel(req) + + async def async_result_handler(req: types.GetToolAsyncResultRequest): + result = await self.async_request_manager.get_result(req) + return types.ServerResult(result) + self.request_handlers[types.CallToolRequest] = handler + self.request_handlers[types.CallToolAsyncRequest] = async_call_handler + self.request_handlers[types.JoinCallToolAsyncRequest] = async_join_handler + self.request_handlers[types.GetToolAsyncResultRequest] = ( + async_result_handler + ) + self.notification_handlers[types.CancelToolAsyncNotification] = ( + async_cancel_handler + ) return func return decorator @@ -505,8 +538,11 @@ async def run( write_stream, initialization_options, stateless=stateless, + notification_hook=self.async_request_manager.notification_hook, + session_close_hook=self.async_request_manager.session_close_hook, ) ) + await stack.enter_async_context(self.async_request_manager) async with anyio.create_task_group() as tg: async for message in session.incoming_messages: diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index ef5c5a3c3..d5b08ee58 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -37,13 +37,15 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: be instantiated directly by users of the MCP framework. """ +from collections.abc import Awaitable, Callable from enum import Enum -from typing import Any, TypeVar +from typing import Annotated, Any, TypeVar import anyio import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import AnyUrl +from pydantic.networks import AnyUrl, UrlConstraints +from typing_extensions import Self import mcp.types as types from mcp.server.models import InitializationOptions @@ -88,9 +90,16 @@ def __init__( write_stream: MemoryObjectSendStream[SessionMessage], init_options: InitializationOptions, stateless: bool = False, + notification_hook: Callable[[Self, types.ServerNotification], Awaitable[None]] + | None = None, + session_close_hook: Callable[[Self], Awaitable[None]] | None = None, ) -> None: super().__init__( - read_stream, write_stream, types.ClientRequest, types.ClientNotification + read_stream, + write_stream, + types.ClientRequest, + types.ClientNotification, + notification_hook=notification_hook, ) self._initialization_state = ( InitializationState.Initialized @@ -106,6 +115,12 @@ def __init__( lambda: self._incoming_message_stream_reader.aclose() ) + async def call_session_close(): + if session_close_hook is not None: + await session_close_hook(self) + + self._exit_stack.push_async_callback(call_session_close) + @property def client_params(self) -> types.InitializeRequestParams | None: return self._client_params @@ -288,6 +303,8 @@ async def send_progress_notification( total: float | None = None, message: str | None = None, related_request_id: str | None = None, + resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + | None = None, ) -> None: """Send a progress notification.""" await self.send_notification( @@ -299,6 +316,7 @@ async def send_progress_notification( progress=progress, total=total, message=message, + resourceUri=resource_uri, ), ) ), diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 4b13709c6..4cbbe2b71 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,14 +1,23 @@ +import inspect import logging -from collections.abc import Callable +from collections.abc import Awaitable, Callable from contextlib import AsyncExitStack from datetime import timedelta from types import TracebackType -from typing import Any, Generic, Protocol, TypeVar +from typing import ( + Annotated, + Any, + Generic, + Protocol, + TypeVar, + runtime_checkable, +) import anyio import httpx from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import BaseModel +from pydantic.networks import AnyUrl, UrlConstraints from typing_extensions import Self from mcp.shared.exceptions import McpError @@ -44,6 +53,7 @@ RequestId = str | int +@runtime_checkable class ProgressFnT(Protocol): """Protocol for progress notification callbacks.""" @@ -52,6 +62,20 @@ async def __call__( ) -> None: ... +@runtime_checkable +class ResourceProgressFnT(Protocol): + """Protocol for progress notification callbacks with resources.""" + + async def __call__( + self, + progress: float, + total: float | None, + message: str | None, + resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + | None = None, + ) -> None: ... + + class RequestResponder(Generic[ReceiveRequestT, SendResultT]): """Handles responding to MCP requests and manages request lifecycle. @@ -182,6 +206,7 @@ class BaseSession( _request_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] _progress_callbacks: dict[RequestId, ProgressFnT] + _resource_callbacks: dict[RequestId, ResourceProgressFnT] def __init__( self, @@ -191,6 +216,8 @@ def __init__( receive_notification_type: type[ReceiveNotificationT], # If none, reading will never time out read_timeout_seconds: timedelta | None = None, + notification_hook: Callable[[Self, SendNotificationT], Awaitable[None]] + | None = None, ) -> None: self._read_stream = read_stream self._write_stream = write_stream @@ -201,7 +228,9 @@ def __init__( self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} self._progress_callbacks = {} + self._resource_callbacks = {} self._exit_stack = AsyncExitStack() + self._notification_hook = notification_hook async def __aenter__(self) -> Self: self._task_group = anyio.create_task_group() @@ -228,7 +257,7 @@ async def send_request( result_type: type[ReceiveResultT], request_read_timeout_seconds: timedelta | None = None, metadata: MessageMetadata = None, - progress_callback: ProgressFnT | None = None, + progress_callback: ProgressFnT | ResourceProgressFnT | None = None, ) -> ReceiveResultT: """ Sends a request and wait for a response. Raises an McpError if the @@ -255,8 +284,15 @@ async def send_request( if "_meta" not in request_data["params"]: request_data["params"]["_meta"] = {} request_data["params"]["_meta"]["progressToken"] = request_id - # Store the callback for this request - self._progress_callbacks[request_id] = progress_callback + # note this is required to ensure backwards compatibility + # for previous clients + signature = inspect.signature(progress_callback.__call__) + if len(signature.parameters) == 3: + # Store the callback for this request + self._resource_callbacks[request_id] = progress_callback # type: ignore + else: + # Store the callback for this request + self._progress_callbacks[request_id] = progress_callback try: jsonrpc_request = JSONRPCRequest( @@ -313,6 +349,12 @@ async def send_notification( Emits a notification, which is a one-way message that does not expect a response. """ + if self._notification_hook: + try: + await self._notification_hook(self, notification) + except Exception: + logging.exception("Notification hook failed") + # Some transport implementations may need to set the related_request_id # to attribute to the notifications to the request that triggered them. jsonrpc_notification = JSONRPCNotification( @@ -401,6 +443,15 @@ async def _receive_loop(self) -> None: notification.root.params.total, notification.root.params.message, ) + elif progress_token in self._resource_callbacks: + callback = self._resource_callbacks[progress_token] + await callback( + notification.root.params.progress, + notification.root.params.total, + notification.root.params.message, + notification.root.params.resourceUri, + ) + await self._received_notification(notification) await self._handle_incoming(notification) except Exception as e: diff --git a/src/mcp/types.py b/src/mcp/types.py index 4f5af27b9..a2416bfcc 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -31,6 +31,7 @@ LATEST_PROTOCOL_VERSION = "2025-03-26" +AsyncToken = str | int ProgressToken = str | int Cursor = str Role = Literal["user", "assistant"] @@ -261,6 +262,14 @@ class ToolsCapability(BaseModel): model_config = ConfigDict(extra="allow") +class AsyncCapability(BaseModel): + """Capability for async operations.""" + + maxKeepAliveTime: int | None = None + """The maximum keep alive time in seconds for async requests.""" + model_config = ConfigDict(extra="allow") + + class LoggingCapability(BaseModel): """Capability for logging operations.""" @@ -279,6 +288,9 @@ class ServerCapabilities(BaseModel): resources: ResourcesCapability | None = None """Present if the server offers any resources to read.""" tools: ToolsCapability | None = None + """Present if the server offers async tool calling support.""" + async_: AsyncCapability | None = Field(alias="async", default=None) + """Present if the server offers any tools to call.""" model_config = ConfigDict(extra="allow") @@ -350,12 +362,19 @@ class ProgressNotificationParams(NotificationParams): total is unknown. """ total: float | None = None + """Total number of items to process (or total progress required), if known.""" + message: str | None = None """ Message related to progress. This should provide relevant human readable progress information. """ - message: str | None = None - """Total number of items to process (or total progress required), if known.""" + resourceUri: Annotated[AnyUrl, UrlConstraints(host_required=False)] | None = None + """ + An optional reference to an ephemeral resource associated with this + progress, servers may delete these at their descretion, but are encouraged + to make them available for a reasonable time period to allow clients to + retrieve and cache the resources locally + """ model_config = ConfigDict(extra="allow") @@ -760,6 +779,12 @@ class ToolAnnotations(BaseModel): of a memory tool is not. Default: true """ + preferAsync: bool | None = None + """ + If true, should ideally be called using the async protocol + as requests are expected to be long running. + Default: false + """ model_config = ConfigDict(extra="allow") @@ -774,6 +799,8 @@ class Tool(BaseModel): """A JSON Schema object defining the expected parameters for the tool.""" annotations: ToolAnnotations | None = None """Optional additional tool information.""" + preferAsync: bool | None = None + """Optional flag to suggest to client async calls should be preferred""" model_config = ConfigDict(extra="allow") @@ -798,11 +825,76 @@ class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]): params: CallToolRequestParams +class CallToolAsyncRequestParams(CallToolRequestParams): + """Parameters for calling a tool asynchronously.""" + + keepAlive: int | None = None + model_config = ConfigDict(extra="allow") + + +class CallToolAsyncRequest( + Request[CallToolAsyncRequestParams, Literal["tools/async/call"]] +): + """Used by the client to invoke a tool provided by the server asynchronously.""" + + method: Literal["tools/async/call"] + params: CallToolAsyncRequestParams + + +class JoinCallToolRequestParams(RequestParams): + """Parameters for joining an asynchronous tool call.""" + + token: AsyncToken + keepAlive: int | None = None + model_config = ConfigDict(extra="allow") + + +class JoinCallToolAsyncRequest( + Request[JoinCallToolRequestParams, Literal["tools/async/join"]] +): + """Used by the client to join an tool call executing on the server + asynchronously.""" + + method: Literal["tools/async/join"] + params: JoinCallToolRequestParams + + +class CancelToolAsyncNotificationParams(NotificationParams): + token: AsyncToken + + +class CancelToolAsyncNotification( + Notification[CancelToolAsyncNotificationParams, Literal["tools/async/cancel"]] +): + method: Literal["tools/async/cancel"] + params: CancelToolAsyncNotificationParams + + +class GetToolAsyncResultRequestParams(RequestParams): + token: AsyncToken + + +class GetToolAsyncResultRequest( + Request[GetToolAsyncResultRequestParams, Literal["tools/async/get"]] +): + method: Literal["tools/async/get"] + params: GetToolAsyncResultRequestParams + + class CallToolResult(Result): """The server's response to a tool call.""" content: list[TextContent | ImageContent | EmbeddedResource] isError: bool = False + isPending: bool = False + + +class CallToolAsyncResult(Result): + """The servers response to an async tool call""" + + token: AsyncToken | None = None + keepAlive: int | None = None + accepted: bool class ToolListChangedNotification( @@ -1134,6 +1226,9 @@ class ClientRequest( | SubscribeRequest | UnsubscribeRequest | CallToolRequest + | CallToolAsyncRequest + | JoinCallToolAsyncRequest + | GetToolAsyncResultRequest | ListToolsRequest ] ): @@ -1145,6 +1240,7 @@ class ClientNotification( CancelledNotification | ProgressNotification | InitializedNotification + | CancelToolAsyncNotification | RootsListChangedNotification ] ): @@ -1184,6 +1280,7 @@ class ServerResult( | ListResourceTemplatesResult | ReadResourceResult | CallToolResult + | CallToolAsyncResult | ListToolsResult ] ): diff --git a/tests/issues/test_176_progress_token.py b/tests/issues/test_176_progress_token.py index 4ad22f294..6ba4770aa 100644 --- a/tests/issues/test_176_progress_token.py +++ b/tests/issues/test_176_progress_token.py @@ -39,11 +39,11 @@ async def test_progress_token_zero_first_call(): mock_session.send_progress_notification.call_count == 3 ), "All progress notifications should be sent" mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=0.0, total=10.0, message=None + progress_token=0, progress=0.0, total=10.0, message=None, resource_uri=None ) mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=5.0, total=10.0, message=None + progress_token=0, progress=5.0, total=10.0, message=None, resource_uri=None ) mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=10.0, total=10.0, message=None + progress_token=0, progress=10.0, total=10.0, message=None, resource_uri=None ) diff --git a/tests/server/lowlevel/test_async_request_manager.py b/tests/server/lowlevel/test_async_request_manager.py new file mode 100644 index 000000000..5b92cb075 --- /dev/null +++ b/tests/server/lowlevel/test_async_request_manager.py @@ -0,0 +1,440 @@ +from contextlib import AsyncExitStack +from unittest.mock import AsyncMock, Mock, PropertyMock + +import anyio +import pytest + +from mcp import types +from mcp.server.auth.middleware.auth_context import ( + auth_context_var as user_context, +) +from mcp.server.lowlevel.async_request_manager import SimpleInMemoryAsyncRequestManager + + +@pytest.mark.anyio +async def test_async_call(): + """Tests basic async call""" + + async def test_call(call: types.CallToolRequest) -> types.ServerResult: + return types.ServerResult( + types.CallToolResult(content=[types.TextContent(type="text", text="test")]) + ) + + async_call = types.CallToolAsyncRequest( + method="tools/async/call", params=types.CallToolAsyncRequestParams(name="test") + ) + + mock_session = AsyncMock() + mock_context = Mock() + mock_context.session = mock_session + result_cache = SimpleInMemoryAsyncRequestManager(max_size=1, max_keep_alive=1) + async with AsyncExitStack() as stack: + await stack.enter_async_context(result_cache) + async_call_ref = await result_cache.start_call( + test_call, async_call, mock_context + ) + assert async_call_ref.token is not None + + result = await result_cache.get_result( + types.GetToolAsyncResultRequest( + method="tools/async/get", + params=types.GetToolAsyncResultRequestParams( + token=async_call_ref.token + ), + ) + ) + + assert not result.isError + assert not result.isPending + assert len(result.content) == 1 + assert type(result.content[0]) is types.TextContent + assert result.content[0].text == "test" + + +@pytest.mark.anyio +async def test_async_join_call_progress(): + """Tests basic async call""" + + async def test_call(call: types.CallToolRequest) -> types.ServerResult: + return types.ServerResult( + types.CallToolResult(content=[types.TextContent(type="text", text="test")]) + ) + + async_call = types.CallToolAsyncRequest( + method="tools/async/call", params=types.CallToolAsyncRequestParams(name="test") + ) + + mock_session_1 = AsyncMock() + mock_context_1 = Mock() + mock_context_1.session = mock_session_1 + + mock_session_2 = AsyncMock() + mock_context_2 = Mock() + + mock_context_2.session = mock_session_2 + + result_cache = SimpleInMemoryAsyncRequestManager(max_size=1, max_keep_alive=1) + async with AsyncExitStack() as stack: + await stack.enter_async_context(result_cache) + async_call_ref = await result_cache.start_call( + test_call, async_call, mock_context_1 + ) + assert async_call_ref.token is not None + + await result_cache.join_call( + req=types.JoinCallToolAsyncRequest( + method="tools/async/join", + params=types.JoinCallToolRequestParams( + token=async_call_ref.token, + _meta=types.RequestParams.Meta(progressToken="test"), + ), + ), + ctx=mock_context_2, + ) + assert async_call_ref.token is not None + await result_cache.notification_hook( + session=mock_session_1, + notification=types.ServerNotification( + types.ProgressNotification( + method="notifications/progress", + params=types.ProgressNotificationParams( + progressToken="test", progress=1 + ), + ) + ), + ) + + result = await result_cache.get_result( + types.GetToolAsyncResultRequest( + method="tools/async/get", + params=types.GetToolAsyncResultRequestParams( + token=async_call_ref.token + ), + ) + ) + + assert not result.isError + assert not result.isPending + assert len(result.content) == 1 + assert type(result.content[0]) is types.TextContent + assert result.content[0].text == "test" + mock_context_1.send_progress_notification.assert_not_called() + mock_session_2.send_progress_notification.assert_called_with( + progress_token="test", + progress=1.0, + total=None, + message=None, + resource_uri=None, + ) + + +@pytest.mark.anyio +async def test_async_cancel_in_progress(): + """Tests cancelling an in progress async call""" + + async def slow_call(call: types.CallToolRequest) -> types.ServerResult: + with anyio.move_on_after(10) as scope: + await anyio.sleep(10) + + if scope.cancel_called: + return types.ServerResult( + types.CallToolResult( + content=[ + types.TextContent(type="text", text="should be discarded") + ], + isError=True, + ) + ) + else: + return types.ServerResult( + types.CallToolResult( + content=[types.TextContent(type="text", text="test")] + ) + ) + + async_call = types.CallToolAsyncRequest( + method="tools/async/call", params=types.CallToolAsyncRequestParams(name="test") + ) + + mock_session_1 = AsyncMock() + mock_context_1 = Mock() + mock_context_1.session = mock_session_1 + + result_cache = SimpleInMemoryAsyncRequestManager(max_size=1, max_keep_alive=1) + async with AsyncExitStack() as stack: + await stack.enter_async_context(result_cache) + async_call_ref = await result_cache.start_call( + slow_call, async_call, mock_context_1 + ) + assert async_call_ref.token is not None + + await result_cache.cancel( + notification=types.CancelToolAsyncNotification( + method="tools/async/cancel", + params=types.CancelToolAsyncNotificationParams( + token=async_call_ref.token + ), + ), + ) + + assert async_call_ref.token is not None + await result_cache.notification_hook( + session=mock_session_1, + notification=types.ServerNotification( + types.ProgressNotification( + method="notifications/progress", + params=types.ProgressNotificationParams( + progressToken="test", progress=1 + ), + ) + ), + ) + + result = await result_cache.get_result( + types.GetToolAsyncResultRequest( + method="tools/async/get", + params=types.GetToolAsyncResultRequestParams( + token=async_call_ref.token + ), + ) + ) + + assert result.isError + assert not result.isPending + assert len(result.content) == 1 + assert type(result.content[0]) is types.TextContent + assert result.content[0].text == "cancelled" + + +@pytest.mark.anyio +async def test_async_call_keep_alive(): + """Tests async call keep alive""" + + async def test_call(call: types.CallToolRequest) -> types.ServerResult: + return types.ServerResult( + types.CallToolResult(content=[types.TextContent(type="text", text="test")]) + ) + + async_call = types.CallToolAsyncRequest( + method="tools/async/call", params=types.CallToolAsyncRequestParams(name="test") + ) + + mock_session_1 = AsyncMock() + mock_context_1 = Mock() + mock_context_1.session = mock_session_1 + + mock_session_2 = AsyncMock() + mock_context_2 = Mock() + + mock_context_2.session = mock_session_2 + + result_cache = SimpleInMemoryAsyncRequestManager(max_size=1, max_keep_alive=10) + async with AsyncExitStack() as stack: + await stack.enter_async_context(result_cache) + async_call_ref = await result_cache.start_call( + test_call, async_call, mock_context_1 + ) + assert async_call_ref.token is not None + + await result_cache.session_close_hook(mock_session_1) + + await result_cache.join_call( + req=types.JoinCallToolAsyncRequest( + method="tools/async/join", + params=types.JoinCallToolRequestParams( + token=async_call_ref.token, + _meta=types.RequestParams.Meta(progressToken="test"), + ), + ), + ctx=mock_context_2, + ) + assert async_call_ref.token is not None + await result_cache.notification_hook( + session=mock_session_1, + notification=types.ServerNotification( + types.ProgressNotification( + method="notifications/progress", + params=types.ProgressNotificationParams( + progressToken="test", progress=1 + ), + ) + ), + ) + + result = await result_cache.get_result( + types.GetToolAsyncResultRequest( + method="tools/async/get", + params=types.GetToolAsyncResultRequestParams( + token=async_call_ref.token + ), + ) + ) + + assert not result.isError, str(result) + assert not result.isPending + assert len(result.content) == 1 + assert type(result.content[0]) is types.TextContent + assert result.content[0].text == "test" + + +@pytest.mark.anyio +async def test_async_call_keep_alive_expired(): + """Tests async call keep alive expiry""" + + async def test_call(call: types.CallToolRequest) -> types.ServerResult: + return types.ServerResult( + types.CallToolResult(content=[types.TextContent(type="text", text="test")]) + ) + + async_call = types.CallToolAsyncRequest( + method="tools/async/call", params=types.CallToolAsyncRequestParams(name="test") + ) + + mock_session_1 = AsyncMock() + mock_context_1 = Mock() + mock_context_1.session = mock_session_1 + + mock_session_2 = AsyncMock() + mock_context_2 = Mock() + mock_context_2.session = mock_session_2 + + mock_session_3 = AsyncMock() + mock_context_3 = Mock() + mock_context_3.session = mock_session_3 + + time = 0.0 + + def test_timer(): + return time + + result_cache = SimpleInMemoryAsyncRequestManager( + max_size=1, max_keep_alive=1, timer=test_timer + ) + async with AsyncExitStack() as stack: + await stack.enter_async_context(result_cache) + async_call_ref = await result_cache.start_call( + test_call, async_call, mock_context_1 + ) + assert async_call_ref.token is not None + + # lose the connection + await result_cache.session_close_hook(mock_session_1) + + # reconnect before keep_alive_timeout + time = 0.5 + await result_cache.join_call( + req=types.JoinCallToolAsyncRequest( + method="tools/async/join", + params=types.JoinCallToolRequestParams( + token=async_call_ref.token, + _meta=types.RequestParams.Meta(progressToken="test"), + ), + ), + ctx=mock_context_2, + ) + + result = await result_cache.get_result( + types.GetToolAsyncResultRequest( + method="tools/async/get", + params=types.GetToolAsyncResultRequestParams( + token=async_call_ref.token + ), + ) + ) + + # should successfully read data + assert not result.isError, str(result) + assert len(result.content) == 1 + assert type(result.content[0]) is types.TextContent + assert result.content[0].text == "test" + + # lose connection a second time + + await result_cache.session_close_hook(mock_session_2) + + time = 2 + + # reconnect after the keep_alive_timeout + + await result_cache.join_call( + req=types.JoinCallToolAsyncRequest( + method="tools/async/join", + params=types.JoinCallToolRequestParams( + token=async_call_ref.token, + _meta=types.RequestParams.Meta(progressToken="test"), + ), + ), + ctx=mock_context_3, + ) + + result = await result_cache.get_result( + types.GetToolAsyncResultRequest( + method="tools/async/get", + params=types.GetToolAsyncResultRequestParams( + token=async_call_ref.token + ), + ) + ) + + # now token should be expired + assert result.isError, str(result) + assert len(result.content) == 1 + assert type(result.content[0]) is types.TextContent + assert result.content[0].text == "Unknown async token" + + +@pytest.mark.anyio +async def test_async_call_pass_auth(): + """Tests async calls pass auth context to background thread""" + + mock_user = Mock() + type(mock_user).username = PropertyMock(return_value="mock_user") + + mock_session = AsyncMock() + mock_context = Mock() + mock_context.session = mock_session + result_cache = SimpleInMemoryAsyncRequestManager(max_size=1, max_keep_alive=1) + + async def test_call(call: types.CallToolRequest) -> types.ServerResult: + user = user_context.get() + if user is None: + return types.ServerResult( + types.CallToolResult( + content=[types.TextContent(type="text", text="unauthorised")], + isError=True, + ) + ) + else: + return types.ServerResult( + types.CallToolResult( + content=[types.TextContent(type="text", text=str(user.username))] + ) + ) + + async_call = types.CallToolAsyncRequest( + method="tools/async/call", params=types.CallToolAsyncRequestParams(name="test") + ) + + async with AsyncExitStack() as stack: + await stack.enter_async_context(result_cache) + + user_context.set(mock_user) + async_call_ref = await result_cache.start_call( + test_call, async_call, mock_context + ) + assert async_call_ref.token is not None + + result = await result_cache.get_result( + types.GetToolAsyncResultRequest( + method="tools/async/get", + params=types.GetToolAsyncResultRequestParams( + token=async_call_ref.token + ), + ) + ) + + assert not result.isError + assert not result.isPending + assert len(result.content) == 1 + assert type(result.content[0]) is types.TextContent + assert result.content[0].text == "mock_user" diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index eb4e004ae..f01688f5d 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -1,4 +1,5 @@ from collections.abc import AsyncGenerator +from logging import getLogger import anyio import pytest @@ -19,6 +20,8 @@ EmptyResult, ) +logger = getLogger(__name__) + @pytest.fixture def mcp_server() -> Server: @@ -129,6 +132,306 @@ async def make_request(client_session): await ev_cancelled.wait() +@pytest.mark.anyio +async def test_request_async(): + """Test that requests can be run asynchronously.""" + # The tool is already registered in the fixture + + ev_tool_called = anyio.Event() + + # Start the request in a separate task so we can cancel it + def make_server() -> Server: + server = Server(name="TestSessionServer") + + # Register the tool handler + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict | None) -> list: + nonlocal ev_tool_called + if name == "async_tool": + ev_tool_called.set() + return [types.TextContent(type="text", text="test")] + raise ValueError(f"Unknown tool: {name}") + + # Register the tool so it shows up in list_tools + @server.list_tools() + async def handle_list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="async_tool", + description="A tool that does things asynchronously", + inputSchema={}, + preferAsync=True, + ) + ] + + return server + + async def make_request(client_session: ClientSession): + return await client_session.send_request( + ClientRequest( + types.CallToolAsyncRequest( + method="tools/async/call", + params=types.CallToolAsyncRequestParams( + name="async_tool", arguments={} + ), + ) + ), + types.CallToolAsyncResult, + ) + + async def get_result(client_session: ClientSession, async_token: types.AsyncToken): + with anyio.fail_after(1): + while True: + print("getting results") + result = await client_session.send_request( + ClientRequest( + types.GetToolAsyncResultRequest( + method="tools/async/get", + params=types.GetToolAsyncResultRequestParams( + token=async_token + ), + ) + ), + types.CallToolResult, + ) + print(f"retrieved {result}") + if result.isPending: + await anyio.sleep(1) + elif result.isError: + raise RuntimeError(str(result)) + else: + return result + + async with create_connected_server_and_client_session( + make_server() + ) as client_session: + async_call = await make_request(client_session) + assert async_call is not None + assert async_call.token is not None + with anyio.fail_after(1): # Timeout after 1 second + await ev_tool_called.wait() + result = await get_result(client_session, async_call.token) + assert type(result.content[0]) is types.TextContent + assert result.content[0].text == "test" + + +@pytest.mark.anyio +@pytest.mark.skip( + reason="This test does not work, there is a subtle " + "bug with event.wait, lower level test_result_cache " + "tests underlying behaviour, revisit with feedback " + "from someone who cah help debug" +) +async def test_request_async_join(): + """Test that requests can be joined from external sessions.""" + # The tool is already registered in the fixture + + # TODO note these events are not working as expected + # test code below uses move_on_after rather than + # fail_after as events are not triggered as expected + # this effectively makes the test lots of sleep + # calls, needs further investigation + ev_client_1_started = anyio.Event() + ev_client_2_joined = anyio.Event() + ev_client_1_progressed_1 = anyio.Event() + ev_client_1_progressed_2 = anyio.Event() + ev_client_2_progressed_1 = anyio.Event() + ev_done = anyio.Event() + + # Start the request in a separate task so we can cancel it + def make_server() -> Server: + server = Server(name="TestSessionServer") + + # Register the tool handler + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict | None) -> list: + nonlocal ev_client_2_joined + if name == "async_tool": + try: + logger.info("tool: sending 1/2") + await server.request_context.session.send_progress_notification( + progress_token=server.request_context.request_id, + progress=1, + total=2, + ) + logger.info("tool: sent 1/2") + with anyio.fail_after(10): # Timeout after 1 second + # TODO this is not working for some unknown reason + logger.info("tool: waiting for client 2 joined") + await ev_client_2_joined.wait() + + logger.info("tool: sending 2/2") + await server.request_context.session.send_progress_notification( + progress_token=server.request_context.request_id, + progress=2, + total=2, + ) + logger.info("tool: sent 2/2") + result = [types.TextContent(type="text", text="test")] + logger.info("tool: sending result") + return result + except Exception as e: + logger.exception(e) + logger.info(f"tool: caught: {str(e)}") + raise e + else: + raise ValueError(f"Unknown tool: {name}") + + # Register the tool so it shows up in list_tools + @server.list_tools() + async def handle_list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="async_tool", + description="A tool that does things asynchronously", + inputSchema={}, + preferAsync=True, + ) + ] + + return server + + async def client_1_progress_callback( + progress: float, total: float | None, message: str | None + ): + nonlocal ev_client_1_progressed_1 + nonlocal ev_client_1_progressed_2 + logger.info(f"client1: progress started: {progress}/{total}") + if progress == 1.0: + ev_client_1_progressed_1.set() + logger.info("client1: progress 1 set") + else: + ev_client_1_progressed_2.set() + logger.info("client1: progress 2 set") + logger.info(f"client1: progress done: {progress}/{total}") + + async def make_request(client_session: ClientSession): + return await client_session.send_request( + ClientRequest( + types.CallToolAsyncRequest( + method="tools/async/call", + params=types.CallToolAsyncRequestParams( + name="async_tool", + arguments={}, + ), + ) + ), + types.CallToolAsyncResult, + progress_callback=client_1_progress_callback, + ) + + async def client_2_progress_callback( + progress: float, total: float | None, message: str | None + ): + nonlocal ev_client_2_progressed_1 + logger.info(f"client2: progress started: {progress}/{total}") + ev_client_2_progressed_1.set() + logger.info(f"client2: progress done: {progress}/{total}") + + async def join_request( + client_session: ClientSession, async_token: types.AsyncToken + ): + return await client_session.send_request( + ClientRequest( + types.JoinCallToolAsyncRequest( + method="tools/async/join", + params=types.JoinCallToolRequestParams(token=async_token), + ) + ), + types.CallToolAsyncResult, + progress_callback=client_2_progress_callback, + ) + + async def get_result(client_session: ClientSession, async_token: types.AsyncToken): + while True: + result = await client_session.send_request( + ClientRequest( + types.GetToolAsyncResultRequest( + method="tools/async/get", + params=types.GetToolAsyncResultRequestParams(token=async_token), + ) + ), + types.CallToolResult, + ) + if result.isPending: + logger.info("client1: result is pending, sleeping") + await anyio.sleep(1) + elif result.isError: + raise RuntimeError(str(result)) + else: + return result + + server = make_server() + token = None + + async with anyio.create_task_group() as tg: + + async def client_1_submit(): + async with create_connected_server_and_client_session( + server + ) as client_session: + nonlocal token + nonlocal ev_client_1_started + nonlocal ev_client_2_progressed_1 + nonlocal ev_done + async_call = await make_request(client_session) + assert async_call is not None + assert async_call.token is not None + token = async_call.token + ev_client_1_started.set() + logger.info("client1: got token") + with anyio.fail_after(1): # Timeout after 1 second + logger.info("client1: waiting for client 2 progress") + await ev_client_2_progressed_1.wait() + + logger.info("client1: getting result") + result = await get_result(client_session, token) + ev_done.set() + + assert type(result.content[0]) is types.TextContent + assert result.content[0].text == "test" + + async def client_2_join(): + async with create_connected_server_and_client_session( + server + ) as client_session: + nonlocal token + nonlocal ev_client_1_started + nonlocal ev_client_1_progressed_1 + nonlocal ev_client_2_joined + nonlocal ev_done + + with anyio.fail_after(1): # Timeout after 1 second + logger.info("client2: waiting for token") + await ev_client_1_started.wait() + assert token is not None + logger.info("client2: got token") + logger.info("client2: waiting for client 1 progress 1") + await ev_client_1_progressed_1.wait() + + with anyio.fail_after(1): # Timeout after 1 second + logger.info("client2: joining") + join_async = await join_request(client_session, token) + assert join_async is not None + assert join_async.token is not None + ev_client_2_joined.set() + ("client2: joined") + + with anyio.fail_after(10): # Timeout after 1 second + logger.info("client2: waiting for done") + await ev_done.wait() + logger.info("client2: done") + + tg.start_soon(client_1_submit) + tg.start_soon(client_2_join) + + assert ev_client_1_started.is_set() + assert ev_client_2_joined.is_set() + assert ev_client_1_progressed_1.is_set() + assert ev_client_1_progressed_2.is_set() + assert ev_client_2_progressed_1.is_set() + + @pytest.mark.anyio async def test_connection_closed(): """