Skip to content

x/tools/internal/mcp: add custom HTTP client overload #580

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions internal/mcp/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,19 +310,30 @@ func (s sseServerStream) Close() error {
// https://modelcontextprotocol.io/specification/2024-11-05/basic/transports
type SSEClientTransport struct {
sseEndpoint *url.URL
httpClient *http.Client
}

// NewSSEClientTransport returns a new client transport that connects to the
// SSE server at the provided URL.
// SSE server at the provided URL using the default HTTP client.
//
// NewSSEClientTransport panics if the given URL is invalid.
func NewSSEClientTransport(baseURL string) *SSEClientTransport {
// Use the default HTTP client.
return NewSSEClientTransportWithHTTPClient(baseURL, http.DefaultClient)
}

// NewSSEClientTransportWithHTTPClient returns a new client transport that connects to the
// SSE server at the provided URL using the provided HTTP client.
//
// NewSSEClientTransportWithHTTPClient panics if the given URL is invalid.
func NewSSEClientTransportWithHTTPClient(baseURL string, httpClient *http.Client) *SSEClientTransport {
url, err := url.Parse(baseURL)
if err != nil {
panic(fmt.Sprintf("invalid base url: %v", err))
}
return &SSEClientTransport{
sseEndpoint: url,
httpClient: httpClient,
}
}

Expand All @@ -333,7 +344,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Stream, error) {
return nil, err
}
req.Header.Set("Accept", "text/event-stream")
resp, err := http.DefaultClient.Do(req)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -404,6 +415,7 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Stream, error) {
// From here on, the stream takes ownership of resp.Body.
s := &sseClientStream{
sseEndpoint: c.sseEndpoint,
httpClient: c.httpClient,
msgEndpoint: msgEndpoint,
incoming: make(chan []byte, 100),
body: resp.Body,
Expand Down Expand Up @@ -435,9 +447,10 @@ func (c *SSEClientTransport) Connect(ctx context.Context) (Stream, error) {
// - Reads are SSE 'message' events, and pushes them onto a buffered channel.
// - Close terminates the GET request.
type sseClientStream struct {
sseEndpoint *url.URL // SSE endpoint for the GET
msgEndpoint *url.URL // session endpoint for POSTs
incoming chan []byte // queue of incoming messages
sseEndpoint *url.URL // SSE endpoint for the GET
msgEndpoint *url.URL // session endpoint for POSTs
httpClient *http.Client // HTTP client to use for requests
incoming chan []byte // queue of incoming messages

mu sync.Mutex
body io.ReadCloser // body of the hanging GET
Expand Down Expand Up @@ -484,7 +497,7 @@ func (c *sseClientStream) Write(ctx context.Context, msg jsonrpc2.Message) error
return err
}
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
resp, err := c.httpClient.Do(req)
if err != nil {
return err
}
Expand Down