From 83ed6769be874254d8e6d8787804c1c978b3e184 Mon Sep 17 00:00:00 2001 From: Andrey Date: Wed, 11 Jun 2025 21:54:48 +0100 Subject: [PATCH] Add the ability to provide a custom HTTP client for SSE connections --- internal/mcp/sse.go | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/internal/mcp/sse.go b/internal/mcp/sse.go index f49424a1d66..b33da62fd46 100644 --- a/internal/mcp/sse.go +++ b/internal/mcp/sse.go @@ -310,19 +310,30 @@ func (s sseServerStream) Close() error { // https://modelcontextprotocol.io/specification/2024-11-05/basic/transports type SSEClientTransport struct { sseEndpoint *url.URL + httpClient *http.Client } // NewSSEClientTransport returns a new client transport that connects to the -// SSE server at the provided URL. +// SSE server at the provided URL using the default HTTP client. // // NewSSEClientTransport panics if the given URL is invalid. func NewSSEClientTransport(baseURL string) *SSEClientTransport { + // Use the default HTTP client. + return NewSSEClientTransportWithHTTPClient(baseURL, http.DefaultClient) +} + +// NewSSEClientTransportWithHTTPClient returns a new client transport that connects to the +// SSE server at the provided URL using the provided HTTP client. +// +// NewSSEClientTransportWithHTTPClient panics if the given URL is invalid. +func NewSSEClientTransportWithHTTPClient(baseURL string, httpClient *http.Client) *SSEClientTransport { url, err := url.Parse(baseURL) if err != nil { panic(fmt.Sprintf("invalid base url: %v", err)) } return &SSEClientTransport{ sseEndpoint: url, + httpClient: httpClient, } } @@ -333,7 +344,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Stream, error) { return nil, err } req.Header.Set("Accept", "text/event-stream") - resp, err := http.DefaultClient.Do(req) + resp, err := c.httpClient.Do(req) if err != nil { return nil, err } @@ -404,6 +415,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Stream, error) { // From here on, the stream takes ownership of resp.Body. s := &sseClientStream{ sseEndpoint: c.sseEndpoint, + httpClient: c.httpClient, msgEndpoint: msgEndpoint, incoming: make(chan []byte, 100), body: resp.Body, @@ -435,9 +447,10 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Stream, error) { // - Reads are SSE 'message' events, and pushes them onto a buffered channel. // - Close terminates the GET request. type sseClientStream struct { - sseEndpoint *url.URL // SSE endpoint for the GET - msgEndpoint *url.URL // session endpoint for POSTs - incoming chan []byte // queue of incoming messages + sseEndpoint *url.URL // SSE endpoint for the GET + msgEndpoint *url.URL // session endpoint for POSTs + httpClient *http.Client // HTTP client to use for requests + incoming chan []byte // queue of incoming messages mu sync.Mutex body io.ReadCloser // body of the hanging GET @@ -484,7 +497,7 @@ func (c *sseClientStream) Write(ctx context.Context, msg jsonrpc2.Message) error return err } req.Header.Set("Content-Type", "application/json") - resp, err := http.DefaultClient.Do(req) + resp, err := c.httpClient.Do(req) if err != nil { return err }