From 7da176f21e754934a5f9c81037e932fde701ac64 Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Wed, 29 Jan 2025 16:41:53 -0500 Subject: [PATCH] chore(x): implement WS client with Gorilla (#361) --- x/configurl/websocket.go | 52 ++-------- x/go.mod | 2 + x/go.sum | 4 + x/websocket/endpoint.go | 196 +++++++++++++++++++++++++++++++++++ x/websocket/endpoint_test.go | 196 +++++++++++++++++++++++++++++++++++ 5 files changed, 407 insertions(+), 43 deletions(-) create mode 100644 x/websocket/endpoint.go create mode 100644 x/websocket/endpoint_test.go diff --git a/x/configurl/websocket.go b/x/configurl/websocket.go index 31ef95e4..e07a2ef7 100644 --- a/x/configurl/websocket.go +++ b/x/configurl/websocket.go @@ -23,7 +23,7 @@ import ( "strings" "github.com/Jigsaw-Code/outline-sdk/transport" - "golang.org/x/net/websocket" + "github.com/Jigsaw-Code/outline-sdk/x/websocket" ) type wsConfig struct { @@ -57,20 +57,6 @@ func parseWSConfig(configURL url.URL) (*wsConfig, error) { return &cfg, nil } -// wsToStreamConn converts a [websocket.Conn] to a [transport.StreamConn]. -type wsToStreamConn struct { - *websocket.Conn -} - -func (c *wsToStreamConn) CloseRead() error { - // Nothing to do. - return nil -} - -func (c *wsToStreamConn) CloseWrite() error { - return c.Close() -} - func registerWebsocketStreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) { r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) { sd, err := newSD(ctx, config.BaseConfig) @@ -85,22 +71,12 @@ func registerWebsocketStreamDialer(r TypeRegistry[transport.StreamDialer], typeI return nil, errors.New("must specify tcp_path") } return transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) { - wsURL := url.URL{Scheme: "ws", Host: addr, Path: wsConfig.tcpPath} - origin := url.URL{Scheme: "http", Host: addr} - wsCfg, err := websocket.NewConfig(wsURL.String(), origin.String()) - if err != nil { - return nil, fmt.Errorf("failed to create websocket config: %w", err) - } - baseConn, err := sd.DialStream(ctx, addr) - if err != nil { - return nil, fmt.Errorf("failed to connect to websocket endpoint: %w", err) - } - wsConn, err := websocket.NewClient(wsCfg, baseConn) + wsURL := url.URL{Scheme: "wss", Host: addr, Path: wsConfig.tcpPath} + connect, err := websocket.NewStreamEndpoint(wsURL.String(), sd, nil) if err != nil { - baseConn.Close() - return nil, fmt.Errorf("failed to create websocket client: %w", err) + return nil, fmt.Errorf("failed to create websocket stream endpoint: %w", err) } - return &wsToStreamConn{wsConn}, nil + return connect(ctx) }), nil }) } @@ -119,22 +95,12 @@ func registerWebsocketPacketDialer(r TypeRegistry[transport.PacketDialer], typeI return nil, errors.New("must specify udp_path") } return transport.FuncPacketDialer(func(ctx context.Context, addr string) (net.Conn, error) { - wsURL := url.URL{Scheme: "ws", Host: addr, Path: wsConfig.udpPath} - origin := url.URL{Scheme: "http", Host: addr} - wsCfg, err := websocket.NewConfig(wsURL.String(), origin.String()) - if err != nil { - return nil, fmt.Errorf("failed to create websocket config: %w", err) - } - baseConn, err := sd.DialStream(ctx, addr) - if err != nil { - return nil, fmt.Errorf("failed to connect to websocket endpoint: %w", err) - } - wsConn, err := websocket.NewClient(wsCfg, baseConn) + wsURL := url.URL{Scheme: "wss", Host: addr, Path: wsConfig.udpPath} + connect, err := websocket.NewPacketEndpoint(wsURL.String(), sd, nil) if err != nil { - baseConn.Close() - return nil, fmt.Errorf("failed to create websocket client: %w", err) + return nil, fmt.Errorf("failed to create websocket stream endpoint: %w", err) } - return wsConn, nil + return connect(ctx) }), nil }) } diff --git a/x/go.mod b/x/go.mod index 825e07f7..ea9d32a9 100644 --- a/x/go.mod +++ b/x/go.mod @@ -7,6 +7,8 @@ require ( // Use github.com/Psiphon-Labs/psiphon-tunnel-core@staging-client as per // https://github.com/Psiphon-Labs/psiphon-tunnel-core/?tab=readme-ov-file#using-psiphon-with-go-modules github.com/Psiphon-Labs/psiphon-tunnel-core v1.0.11-0.20240619172145-03cade11f647 + github.com/coder/websocket v1.8.12 + github.com/gorilla/websocket v1.5.3 github.com/lmittmann/tint v1.0.5 github.com/quic-go/quic-go v0.48.1 github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 diff --git a/x/go.sum b/x/go.sum index a7238e9a..3d19fb51 100644 --- a/x/go.sum +++ b/x/go.sum @@ -37,6 +37,8 @@ github.com/cheekybits/genny v0.0.0-20170328200008-9127e812e1e9/go.mod h1:+tQajlR github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= +github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= +github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/cognusion/go-cache-lru v0.0.0-20170419142635-f73e2280ecea h1:9C2rdYRp8Vzwhm3sbFX0yYfB+70zKFRjn7cnPCucHSw= github.com/cognusion/go-cache-lru v0.0.0-20170419142635-f73e2280ecea/go.mod h1:MdyNkAe06D7xmJsf+MsLvbZKYNXuOHLKJrvw+x4LlcQ= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -84,6 +86,8 @@ github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/pprof v0.0.0-20211214055906-6f57359322fd h1:1FjCyPC+syAzJ5/2S8fqdZK1R22vvA0J7JZKcuOIQ7Y= github.com/google/pprof v0.0.0-20211214055906-6f57359322fd/go.mod h1:KgnwoLYCZ8IQu3XUZ8Nc/bM9CCZFOyjUNOSygVozoDg= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grafov/m3u8 v0.0.0-20171211212457-6ab8f28ed427 h1:xh96CCAZTX8LJPFoOVRgTwZbn2DvJl8fyCyivohhSIg= github.com/grafov/m3u8 v0.0.0-20171211212457-6ab8f28ed427/go.mod h1:PdjzaU/pJUo4jTIn2rcgMFs+HqBGl/sPJLr8BI0Xq/I= github.com/hashicorp/golang-lru v1.0.2 h1:dV3g9Z/unq5DpblPpw+Oqcv4dU/1omnb4Ok8iPY6p1c= diff --git a/x/websocket/endpoint.go b/x/websocket/endpoint.go new file mode 100644 index 00000000..35ec9edf --- /dev/null +++ b/x/websocket/endpoint.go @@ -0,0 +1,196 @@ +// Copyright 2025 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package websocket provides the Websocket transport. +package websocket + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "runtime" + "strings" + "time" + + "github.com/Jigsaw-Code/outline-sdk/transport" + "github.com/gorilla/websocket" +) + +// NewStreamEndpoint creates a new Websocket Stream Endpoint. Streams are sent over +// Websockets, with each Write becoming a separate message. Half-close is supported: +// CloseRead will not close the Websocket connection, while CloseWrite sends a Websocket +// close but continues reading until a close is received from the server. +func NewStreamEndpoint(urlStr string, sd transport.StreamDialer, opts ...Option) (func(context.Context) (transport.StreamConn, error), error) { + return newEndpoint(urlStr, sd, func(gc *gorillaConn) transport.StreamConn { return gc }, opts...) +} + +// NewPacketEndpoint creates a new Websocket Packet Endpoint. Each packet is exchanged as a Websocket message. +func NewPacketEndpoint(urlStr string, sd transport.StreamDialer, opts ...Option) (func(context.Context) (net.Conn, error), error) { + return newEndpoint(urlStr, sd, func(gc *gorillaConn) net.Conn { return gc }, opts...) +} + +type options struct { + tlsConfig *tls.Config + headers http.Header +} + +// Option for building the Websocket endpoint. +type Option func(c *options) + +// WithTLSConfig specifies the TLS configuration to use. +// TODO(fortuna): Use Outline TLS instead. +func WithTLSConfig(tlsConfig *tls.Config) Option { + return func(c *options) { + c.tlsConfig = tlsConfig + } +} + +// WithHTTPHeaders specifies the HTTP headers to use. +func WithHTTPHeaders(headers http.Header) Option { + return func(c *options) { + c.headers = headers + } +} + +func newEndpoint[ConnType net.Conn](urlStr string, sd transport.StreamDialer, wsToConn func(*gorillaConn) ConnType, opts ...Option) (func(context.Context) (ConnType, error), error) { + _, err := url.Parse(urlStr) + if err != nil { + return nil, fmt.Errorf("url is invalid: %w", err) + } + + resolvedOpts := options{ + // By default, we use this User-Agent. + headers: http.Header(map[string][]string{"User-Agent": {fmt.Sprintf("Outline (%s; %s; %s)", runtime.GOOS, runtime.GOARCH, runtime.Version())}}), + } + for _, opt := range opts { + opt(&resolvedOpts) + } + + wsDialer := &websocket.Dialer{ + TLSClientConfig: resolvedOpts.tlsConfig, + NetDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + if !strings.HasPrefix(network, "tcp") { + return nil, fmt.Errorf("websocket dialer does not support network type %v", network) + } + return sd.DialStream(ctx, addr) + }, + } + return func(ctx context.Context) (ConnType, error) { + var zero ConnType + wsConn, _, err := wsDialer.DialContext(ctx, urlStr, resolvedOpts.headers) + if err != nil { + return zero, err + } + gConn := &gorillaConn{wsConn: wsConn} + wsConn.SetCloseHandler(func(code int, text string) error { + gConn.readErr = io.EOF + return nil + }) + return wsToConn(gConn), nil + }, nil +} + +type gorillaConn struct { + wsConn *websocket.Conn + writeErr error + readErr error + pendingReader io.Reader +} + +var _ transport.StreamConn = (*gorillaConn)(nil) + +func (c *gorillaConn) LocalAddr() net.Addr { + return c.wsConn.LocalAddr() +} + +func (c *gorillaConn) RemoteAddr() net.Addr { + return c.wsConn.RemoteAddr() +} + +func (c *gorillaConn) SetDeadline(deadline time.Time) error { + return errors.Join(c.wsConn.SetReadDeadline(deadline), c.wsConn.SetWriteDeadline(deadline)) +} + +func (c *gorillaConn) SetReadDeadline(deadline time.Time) error { + return c.wsConn.SetReadDeadline(deadline) +} + +func (c *gorillaConn) SetWriteDeadline(deadline time.Time) error { + return c.wsConn.SetWriteDeadline(deadline) +} + +func (c *gorillaConn) Read(buf []byte) (int, error) { + if c.readErr != nil { + return 0, c.readErr + } + if c.pendingReader != nil { + n, err := c.pendingReader.Read(buf) + if c.readErr != nil { + return n, c.readErr + } + if !errors.Is(err, io.EOF) { + return n, err + } + c.pendingReader = nil + } + + msgType, reader, err := c.wsConn.NextReader() + if c.readErr != nil { + return 0, c.readErr + } + if err != nil { + return 0, err + } + if msgType != websocket.BinaryMessage { + return 0, errors.New("read message is not binary") + } + + c.pendingReader = reader + return reader.Read(buf) +} + +func (c *gorillaConn) Write(buf []byte) (int, error) { + err := c.wsConn.WriteMessage(websocket.BinaryMessage, buf) + if err != nil { + if c.writeErr != nil { + return 0, c.writeErr + } + return 0, err + } + return len(buf), nil +} + +func (c *gorillaConn) CloseRead() error { + c.readErr = net.ErrClosed + c.wsConn.SetReadDeadline(time.Now()) + return nil +} + +func (c *gorillaConn) CloseWrite() error { + // Send close message. + message := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") + c.wsConn.WriteControl(websocket.CloseMessage, message, time.Now().Add(time.Second)) + c.writeErr = net.ErrClosed + c.wsConn.SetWriteDeadline(time.Now()) + return nil +} + +func (c *gorillaConn) Close() error { + return c.wsConn.Close() +} diff --git a/x/websocket/endpoint_test.go b/x/websocket/endpoint_test.go new file mode 100644 index 00000000..342dfe53 --- /dev/null +++ b/x/websocket/endpoint_test.go @@ -0,0 +1,196 @@ +// Copyright 2025 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package websocket + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Jigsaw-Code/outline-sdk/transport" + // TODO(fortuna): Implement the test with gorilla instead. + "github.com/coder/websocket" + "github.com/stretchr/testify/require" +) + +func Test_NewStreamEndpoint(t *testing.T) { + mux := http.NewServeMux() + toTargetReader, toTargetWriter := io.Pipe() + fromTargetReader, fromTargetWriter := io.Pipe() + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // TODO(fortuna): support h2 and h3 on the server. + require.Equal(t, "", r.TLS.NegotiatedProtocol) + require.Equal(t, "HTTP/1.1", r.Proto) + + t.Log("Got stream request", "request", r) + defer t.Log("Request done") + clientConn, err := websocket.Accept(w, r, nil) + if err != nil { + t.Log("Failed to accept Websocket connection", "error", err) + http.Error(w, "Failed to accept Websocket connection", http.StatusBadGateway) + return + } + clientConn.SetReadLimit(-1) + defer clientConn.CloseNow() + + // Handle client -> target. + readClientDone := make(chan struct{}) + go func() { + defer close(readClientDone) + defer clientConn.CloseRead(r.Context()) + for { + msgType, msg, err := clientConn.Read(r.Context()) + if err != nil { + if !errors.Is(err, io.EOF) { + t.Log("Failed to read from client", "error", err) + clientConn.Close(websocket.StatusInternalError, "failed to read from client") + } + break + } + require.Equal(t, websocket.MessageBinary, msgType) + if _, err := toTargetWriter.Write(msg); err != nil { + t.Log("Failed to write to target", "error", err) + clientConn.Close(websocket.StatusInternalError, "failed to write message to target") + break + } + } + }() + // Handle target -> client + func() { + // About 2 MTUs + buf := make([]byte, 3000) + for { + n, err := fromTargetReader.Read(buf) + if err != nil { + if !errors.Is(err, io.EOF) { + t.Log("Failed to read from target", "error", err) + clientConn.Close(websocket.StatusInternalError, "failed to read message from target") + } + break + } + read := buf[:n] + if err := clientConn.Write(r.Context(), websocket.MessageBinary, read); err != nil { + t.Log("Failed to write to client", "error", err) + clientConn.Close(websocket.StatusInternalError, "failed to write message to client") + break + } + } + }() + <-readClientDone + }) + mux.Handle("/tcp", http.StripPrefix("/tcp", handler)) + ts := httptest.NewUnstartedServer(mux) + ts.EnableHTTP2 = true + ts.StartTLS() + defer ts.Close() + + // Run server functionality. + go func() { + for { + // Fits "Request\n" + req := make([]byte, 8) + n, err := io.ReadFull(toTargetReader, req) + if err != nil { + if !errors.Is(err, io.EOF) { + require.NoError(t, err) + } + break + } + require.Equal(t, "Request\n", string(req[:n])) + + n, err = fromTargetWriter.Write([]byte("Response\n")) + require.NoError(t, err) + require.Equal(t, 9, n) + } + }() + + // TODO(fortuna): Support h2. We can force h2 on the client with the code below. + // client := &http.Client{ + // Transport: &http2.Transport{ + // TLSClientConfig: ts.Client().Transport.(*http.Transport).TLSClientConfig, + // }, + // } + client := ts.Client() + connect, err := NewStreamEndpoint("wss"+ts.URL[5:]+"/tcp", &transport.TCPDialer{}, WithTLSConfig(client.Transport.(*http.Transport).TLSClientConfig)) + require.NoError(t, err) + require.NotNil(t, connect) + + conn, err := connect(context.Background()) + require.NoError(t, err) + require.NotNil(t, conn) + defer conn.Close() + + for i := 0; i < 10; i++ { + n, err := conn.Write([]byte("Req")) + require.NoError(t, err) + require.Equal(t, 3, n) + n, err = conn.Write([]byte("uest\n")) + require.NoError(t, err) + require.Equal(t, 5, n) + + resp := make([]byte, 9) + n, err = conn.Read(resp) + require.NoError(t, err) + require.Equal(t, "Response\n", string(resp[:n])) + } + require.NoError(t, conn.CloseWrite()) +} + +func Test_NewPacketEndpoint(t *testing.T) { + mux := http.NewServeMux() + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // TODO(fortuna): support h2 and h3 on the server. + require.Equal(t, "", r.TLS.NegotiatedProtocol) + require.Equal(t, "HTTP/1.1", r.Proto) + clientConn, err := websocket.Accept(w, r, nil) + require.NoError(t, err) + defer clientConn.CloseNow() + + msgType, msg, err := clientConn.Read(r.Context()) + require.NoError(t, err) + require.Equal(t, websocket.MessageBinary, msgType) + require.Equal(t, []byte("Request"), msg) + + err = clientConn.Write(r.Context(), websocket.MessageBinary, []byte("Response")) + require.NoError(t, err) + + clientConn.Close(websocket.StatusNormalClosure, "") + }) + mux.Handle("/udp", http.StripPrefix("/udp", handler)) + ts := httptest.NewUnstartedServer(mux) + ts.EnableHTTP2 = true + ts.StartTLS() + defer ts.Close() + + client := ts.Client() + connect, err := NewPacketEndpoint("wss"+ts.URL[5:]+"/udp", &transport.TCPDialer{}, WithTLSConfig(client.Transport.(*http.Transport).TLSClientConfig)) + require.NoError(t, err) + require.NotNil(t, connect) + + conn, err := connect(context.Background()) + require.NoError(t, err) + require.NotNil(t, conn) + + n, err := conn.Write([]byte("Request")) + require.NoError(t, err) + require.Equal(t, 7, n) + + resp, err := io.ReadAll(conn) + require.NoError(t, err) + require.Equal(t, []byte("Response"), resp) +}