From 596864ec7cb382b7ea2552eb31316b611614bb3f Mon Sep 17 00:00:00 2001 From: Nick Cooper Date: Mon, 19 May 2025 02:09:01 -0700 Subject: [PATCH 1/2] Support injectable httpx client --- src/mcp/client/sse.py | 7 ++++--- src/mcp/client/streamable_http.py | 5 +++-- src/mcp/shared/_httpx_utils.py | 11 ++++++++++- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 29195cbd9..8047a6c67 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -1,6 +1,6 @@ import logging from contextlib import asynccontextmanager -from typing import Any +from typing import Any, Callable from urllib.parse import urljoin, urlparse import anyio @@ -10,7 +10,7 @@ from httpx_sse import aconnect_sse import mcp.types as types -from mcp.shared._httpx_utils import create_mcp_http_client +from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client from mcp.shared.message import SessionMessage logger = logging.getLogger(__name__) @@ -26,6 +26,7 @@ async def sse_client( headers: dict[str, Any] | None = None, timeout: float = 5, sse_read_timeout: float = 60 * 5, + httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, ): """ Client transport for SSE. @@ -45,7 +46,7 @@ async def sse_client( async with anyio.create_task_group() as tg: try: logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") - async with create_mcp_http_client(headers=headers) as client: + async with httpx_client_factory(headers=headers) as client: async with aconnect_sse( client, "GET", diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 3324dab5a..19ba6c9ba 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -19,7 +19,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import EventSource, ServerSentEvent, aconnect_sse -from mcp.shared._httpx_utils import create_mcp_http_client +from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.types import ( ErrorData, @@ -427,6 +427,7 @@ async def streamablehttp_client( timeout: timedelta = timedelta(seconds=30), sse_read_timeout: timedelta = timedelta(seconds=60 * 5), terminate_on_close: bool = True, + httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, ) -> AsyncGenerator[ tuple[ MemoryObjectReceiveStream[SessionMessage | Exception], @@ -460,7 +461,7 @@ async def streamablehttp_client( try: logger.info(f"Connecting to StreamableHTTP endpoint: {url}") - async with create_mcp_http_client( + async with httpx_client_factory( headers=transport.request_headers, timeout=httpx.Timeout( transport.timeout.seconds, read=transport.sse_read_timeout.seconds diff --git a/src/mcp/shared/_httpx_utils.py b/src/mcp/shared/_httpx_utils.py index 95080bde1..7729fe697 100644 --- a/src/mcp/shared/_httpx_utils.py +++ b/src/mcp/shared/_httpx_utils.py @@ -1,12 +1,21 @@ """Utilities for creating standardized httpx AsyncClient instances.""" -from typing import Any +from typing import Any, Protocol import httpx __all__ = ["create_mcp_http_client"] +class McpHttpClientFactory(Protocol): + def __call__( + self, + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + ) -> httpx.AsyncClient: + ... + + def create_mcp_http_client( headers: dict[str, str] | None = None, timeout: httpx.Timeout | None = None, From 63208fe08186d37327c2be04128f54205749d6a8 Mon Sep 17 00:00:00 2001 From: Nick Cooper Date: Mon, 19 May 2025 02:20:20 -0700 Subject: [PATCH 2/2] nits --- src/mcp/client/sse.py | 2 +- src/mcp/shared/_httpx_utils.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 8047a6c67..572360e5c 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -1,6 +1,6 @@ import logging from contextlib import asynccontextmanager -from typing import Any, Callable +from typing import Any from urllib.parse import urljoin, urlparse import anyio diff --git a/src/mcp/shared/_httpx_utils.py b/src/mcp/shared/_httpx_utils.py index 7729fe697..277e67d7a 100644 --- a/src/mcp/shared/_httpx_utils.py +++ b/src/mcp/shared/_httpx_utils.py @@ -12,8 +12,7 @@ def __call__( self, headers: dict[str, str] | None = None, timeout: httpx.Timeout | None = None, - ) -> httpx.AsyncClient: - ... + ) -> httpx.AsyncClient: ... def create_mcp_http_client(