diff --git a/pyproject.toml b/pyproject.toml index f352de5a0..69db82c9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,6 +77,7 @@ venvPath = "." venv = ".venv" strict = [ "src/mcp/server/fastmcp/tools/base.py", + "src/mcp/client/*.py" ] [tool.ruff.lint] diff --git a/src/mcp/client/__main__.py b/src/mcp/client/__main__.py index 8ce704ff1..baf815c0e 100644 --- a/src/mcp/client/__main__.py +++ b/src/mcp/client/__main__.py @@ -5,10 +5,12 @@ from urllib.parse import urlparse import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.types import JSONRPCMessage if not sys.warnoptions: import warnings @@ -29,7 +31,10 @@ async def receive_loop(session: ClientSession): logger.info("Received message from server: %s", message) -async def run_session(read_stream, write_stream): +async def run_session( + read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], + write_stream: MemoryObjectSendStream[JSONRPCMessage], +): async with ( ClientSession(read_stream, write_stream) as session, anyio.create_task_group() as tg, diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index cde3103b6..2ac248778 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -76,18 +76,12 @@ def __init__( self._list_roots_callback = list_roots_callback or _default_list_roots_callback async def initialize(self) -> types.InitializeResult: - sampling = ( - types.SamplingCapability() if self._sampling_callback is not None else None - ) - roots = ( - types.RootsCapability( - # TODO: Should this be based on whether we - # _will_ send notifications, or only whether - # they're supported? - listChanged=True, - ) - if self._list_roots_callback is not None - else None + sampling = types.SamplingCapability() + roots = types.RootsCapability( + # TODO: Should this be based on whether we + # _will_ send notifications, or only whether + # they're supported? + listChanged=True, ) result = await self.send_request( diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index abafacb96..4f6241a72 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -98,6 +98,10 @@ async def sse_reader( continue await read_stream_writer.send(message) + case _: + logger.warning( + f"Unknown SSE event: {sse.event}" + ) except Exception as exc: logger.error(f"Error in sse_reader: {exc}") await read_stream_writer.send(exc) diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index 3e73b0204..9cf32296f 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -39,6 +39,11 @@ async def websocket_client( # Create two in-memory streams: # - One for incoming messages (read_stream, written by ws_reader) # - One for outgoing messages (write_stream, read by ws_writer) + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) diff --git a/src/mcp/shared/version.py b/src/mcp/shared/version.py index 51bf3521d..8fd13b992 100644 --- a/src/mcp/shared/version.py +++ b/src/mcp/shared/version.py @@ -1,3 +1,3 @@ from mcp.types import LATEST_PROTOCOL_VERSION -SUPPORTED_PROTOCOL_VERSIONS = [1, LATEST_PROTOCOL_VERSION] +SUPPORTED_PROTOCOL_VERSIONS: tuple[int, str] = (1, LATEST_PROTOCOL_VERSION)