From 62e2be3164b2f2493016d160d2494e9217218cce Mon Sep 17 00:00:00 2001 From: jhaoming-oai Date: Wed, 21 May 2025 08:18:53 -0700 Subject: [PATCH] fix taskgroup cleanup ordering --- src/mcp/client/sse.py | 3 +- src/mcp/client/stdio/__init__.py | 3 +- src/mcp/client/streamable_http.py | 6 +- src/mcp/client/websocket.py | 3 +- src/mcp/server/lowlevel/server.py | 4 +- src/mcp/server/sse.py | 3 +- src/mcp/server/stdio.py | 3 +- src/mcp/server/streamable_http.py | 5 +- src/mcp/server/streamable_http_manager.py | 3 +- src/mcp/server/websocket.py | 3 +- src/mcp/shared/session.py | 3 +- src/mcp/shared/taskgroup.py | 76 +++++++++++++++++++++++ 12 files changed, 100 insertions(+), 15 deletions(-) create mode 100644 src/mcp/shared/taskgroup.py diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index a782b58a7..b2ed3ad1a 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -12,6 +12,7 @@ import mcp.types as types from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.message import SessionMessage +from mcp.shared.taskgroup import CompatTaskGroup logger = logging.getLogger(__name__) @@ -50,7 +51,7 @@ async def sse_client( read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - async with anyio.create_task_group() as tg: + async with CompatTaskGroup() as tg: try: logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") async with create_mcp_http_client(headers=headers, auth=auth) as client: diff --git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index 6d815b43a..5e7569c9f 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -12,6 +12,7 @@ import mcp.types as types from mcp.shared.message import SessionMessage +from mcp.shared.taskgroup import CompatTaskGroup from .win32 import ( create_windows_process, @@ -168,7 +169,7 @@ async def stdin_writer(): await anyio.lowlevel.checkpoint() async with ( - anyio.create_task_group() as tg, + CompatTaskGroup() as tg, process, ): tg.start_soon(stdout_reader) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 79b2995e1..26f0405f4 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -15,12 +15,12 @@ import anyio import httpx -from anyio.abc import TaskGroup from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import EventSource, ServerSentEvent, aconnect_sse from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.message import ClientMessageMetadata, SessionMessage +from mcp.shared.taskgroup import CompatTaskGroup from mcp.types import ( ErrorData, JSONRPCError, @@ -352,7 +352,7 @@ async def post_writer( read_stream_writer: StreamWriter, write_stream: MemoryObjectSendStream[SessionMessage], start_get_stream: Callable[[], None], - tg: TaskGroup, + tg: CompatTaskGroup, ) -> None: """Handle writing requests to the server.""" try: @@ -460,7 +460,7 @@ async def streamablehttp_client( SessionMessage ](0) - async with anyio.create_task_group() as tg: + async with CompatTaskGroup() as tg: try: logger.info(f"Connecting to StreamableHTTP endpoint: {url}") diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index ac542fb3f..d401c8888 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -11,6 +11,7 @@ import mcp.types as types from mcp.shared.message import SessionMessage +from mcp.shared.taskgroup import CompatTaskGroup logger = logging.getLogger(__name__) @@ -79,7 +80,7 @@ async def ws_writer(): ) await ws.send(json.dumps(msg_dict)) - async with anyio.create_task_group() as tg: + async with CompatTaskGroup() as tg: # Start reader and writer tasks tg.start_soon(ws_reader) tg.start_soon(ws_writer) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 876aef817..16cbcbf5d 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -74,7 +74,6 @@ async def main(): from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager from typing import Any, Generic, TypeVar -import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl @@ -87,6 +86,7 @@ async def main(): from mcp.shared.exceptions import McpError from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder +from mcp.shared.taskgroup import CompatTaskGroup logger = logging.getLogger(__name__) @@ -503,7 +503,7 @@ async def run( ) ) - async with anyio.create_task_group() as tg: + async with CompatTaskGroup() as tg: async for message in session.incoming_messages: logger.debug(f"Received message: {message}") diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index bae2bbf52..6c63b9739 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -53,6 +53,7 @@ async def handle_sse(request): import mcp.types as types from mcp.shared.message import SessionMessage +from mcp.shared.taskgroup import CompatTaskGroup logger = logging.getLogger(__name__) @@ -143,7 +144,7 @@ async def sse_writer(): } ) - async with anyio.create_task_group() as tg: + async with CompatTaskGroup() as tg: async def response_wrapper(scope: Scope, receive: Receive, send: Send): """ diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index f0bbe5a31..b0e76ae52 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -28,6 +28,7 @@ async def run_server(): import mcp.types as types from mcp.shared.message import SessionMessage +from mcp.shared.taskgroup import CompatTaskGroup @asynccontextmanager @@ -84,7 +85,7 @@ async def stdout_writer(): except anyio.ClosedResourceError: await anyio.lowlevel.checkpoint() - async with anyio.create_task_group() as tg: + async with CompatTaskGroup() as tg: tg.start_soon(stdin_reader) tg.start_soon(stdout_writer) yield read_stream, write_stream diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 8f4a1f512..e467f1190 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -25,6 +25,7 @@ from starlette.types import Receive, Scope, Send from mcp.shared.message import ServerMessageMetadata, SessionMessage +from mcp.shared.taskgroup import CompatTaskGroup from mcp.types import ( INTERNAL_ERROR, INVALID_PARAMS, @@ -508,7 +509,7 @@ async def sse_writer(): # Start the SSE response (this will send headers immediately) try: # First send the response to establish the SSE connection - async with anyio.create_task_group() as tg: + async with CompatTaskGroup() as tg: tg.start_soon(response, scope, receive, send) # Then send the message to be processed by the server session_message = SessionMessage(message) @@ -840,7 +841,7 @@ async def connect( self._write_stream = write_stream # Start a task group for message routing - async with anyio.create_task_group() as tg: + async with CompatTaskGroup() as tg: # Create a message router that distributes messages to request streams async def message_router(): try: diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index e5ef8b4aa..78e690355 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -22,6 +22,7 @@ EventStore, StreamableHTTPServerTransport, ) +from mcp.shared.taskgroup import CompatTaskGroup logger = logging.getLogger(__name__) @@ -103,7 +104,7 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]: ) self._has_started = True - async with anyio.create_task_group() as tg: + async with CompatTaskGroup() as tg: # Store the task group for later use self._task_group = tg logger.info("StreamableHTTP session manager started") diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index 9dc3f2a25..05de59fb2 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -9,6 +9,7 @@ import mcp.types as types from mcp.shared.message import SessionMessage +from mcp.shared.taskgroup import CompatTaskGroup logger = logging.getLogger(__name__) @@ -58,7 +59,7 @@ async def ws_writer(): except anyio.ClosedResourceError: await websocket.close() - async with anyio.create_task_group() as tg: + async with CompatTaskGroup() as tg: tg.start_soon(ws_reader) tg.start_soon(ws_writer) yield (read_stream, write_stream) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 90b4eb27c..c56f37c9e 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -13,6 +13,7 @@ from mcp.shared.exceptions import McpError from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage +from mcp.shared.taskgroup import CompatTaskGroup from mcp.types import ( CancelledNotification, ClientNotification, @@ -201,7 +202,7 @@ def __init__( self._exit_stack = AsyncExitStack() async def __aenter__(self) -> Self: - self._task_group = anyio.create_task_group() + self._task_group = CompatTaskGroup() await self._task_group.__aenter__() self._task_group.start_soon(self._receive_loop) return self diff --git a/src/mcp/shared/taskgroup.py b/src/mcp/shared/taskgroup.py new file mode 100644 index 000000000..a4ff9283d --- /dev/null +++ b/src/mcp/shared/taskgroup.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import asyncio +import sys +from collections.abc import Awaitable, Callable +from contextlib import AbstractAsyncContextManager +from typing import Any, TypeVar + +import anyio + +_T = TypeVar("_T") + +class _AsyncioCancelScope: + def __init__(self, tasks: set[asyncio.Task[Any]]): + self._tasks = tasks + + def cancel(self) -> None: + for task in list(self._tasks): + task.cancel() + +class CompatTaskGroup(AbstractAsyncContextManager): + """Minimal compatibility layer mimicking ``anyio.TaskGroup``.""" + + def __init__(self) -> None: + self._use_asyncio = sys.version_info >= (3, 11) + if self._use_asyncio: + self._tg = asyncio.TaskGroup() + self._tasks: set[asyncio.Task[Any]] = set() + self.cancel_scope = _AsyncioCancelScope(self._tasks) + else: + self._tg = anyio.create_task_group() + self.cancel_scope = self._tg.cancel_scope # type: ignore[attr-defined] + + async def __aenter__(self) -> CompatTaskGroup: + await self._tg.__aenter__() + return self + + async def __aexit__(self, exc_type, exc, tb) -> bool | None: + return await self._tg.__aexit__(exc_type, exc, tb) + + def start_soon( + self, + func: Callable[..., Awaitable[Any]], + *args: Any, + name: Any | None = None, + ) -> None: + if self._use_asyncio: + task = self._tg.create_task(func(*args)) + self._tasks.add(task) + else: + self._tg.start_soon(func, *args, name=name) + + async def start( + self, + func: Callable[..., Awaitable[Any]], + *args: Any, + name: Any | None = None, + ) -> Any: + if self._use_asyncio: + fut: asyncio.Future[Any] = asyncio.get_running_loop().create_future() + + async def runner() -> None: + try: + result = await func(*args, task_status=fut) + if not fut.done(): + fut.set_result(result) + except BaseException as exc: + if not fut.done(): + fut.set_exception(exc) + raise + + task = self._tg.create_task(runner()) + self._tasks.add(task) + return await fut + else: + return await self._tg.start(func, *args, name=name)