Skip to content

Commit

Permalink
chore(x): implement WS client with Gorilla (#361)
Browse files Browse the repository at this point in the history
  • Loading branch information
fortuna authored Jan 29, 2025
1 parent efa8083 commit 7da176f
Show file tree
Hide file tree
Showing 5 changed files with 407 additions and 43 deletions.
52 changes: 9 additions & 43 deletions x/configurl/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
"strings"

"github.com/Jigsaw-Code/outline-sdk/transport"
"golang.org/x/net/websocket"
"github.com/Jigsaw-Code/outline-sdk/x/websocket"
)

type wsConfig struct {
Expand Down Expand Up @@ -57,20 +57,6 @@ func parseWSConfig(configURL url.URL) (*wsConfig, error) {
return &cfg, nil
}

// wsToStreamConn converts a [websocket.Conn] to a [transport.StreamConn].
type wsToStreamConn struct {
*websocket.Conn
}

func (c *wsToStreamConn) CloseRead() error {
// Nothing to do.
return nil
}

func (c *wsToStreamConn) CloseWrite() error {
return c.Close()
}

func registerWebsocketStreamDialer(r TypeRegistry[transport.StreamDialer], typeID string, newSD BuildFunc[transport.StreamDialer]) {
r.RegisterType(typeID, func(ctx context.Context, config *Config) (transport.StreamDialer, error) {
sd, err := newSD(ctx, config.BaseConfig)
Expand All @@ -85,22 +71,12 @@ func registerWebsocketStreamDialer(r TypeRegistry[transport.StreamDialer], typeI
return nil, errors.New("must specify tcp_path")
}
return transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) {
wsURL := url.URL{Scheme: "ws", Host: addr, Path: wsConfig.tcpPath}
origin := url.URL{Scheme: "http", Host: addr}
wsCfg, err := websocket.NewConfig(wsURL.String(), origin.String())
if err != nil {
return nil, fmt.Errorf("failed to create websocket config: %w", err)
}
baseConn, err := sd.DialStream(ctx, addr)
if err != nil {
return nil, fmt.Errorf("failed to connect to websocket endpoint: %w", err)
}
wsConn, err := websocket.NewClient(wsCfg, baseConn)
wsURL := url.URL{Scheme: "wss", Host: addr, Path: wsConfig.tcpPath}
connect, err := websocket.NewStreamEndpoint(wsURL.String(), sd, nil)
if err != nil {
baseConn.Close()
return nil, fmt.Errorf("failed to create websocket client: %w", err)
return nil, fmt.Errorf("failed to create websocket stream endpoint: %w", err)
}
return &wsToStreamConn{wsConn}, nil
return connect(ctx)
}), nil
})
}
Expand All @@ -119,22 +95,12 @@ func registerWebsocketPacketDialer(r TypeRegistry[transport.PacketDialer], typeI
return nil, errors.New("must specify udp_path")
}
return transport.FuncPacketDialer(func(ctx context.Context, addr string) (net.Conn, error) {
wsURL := url.URL{Scheme: "ws", Host: addr, Path: wsConfig.udpPath}
origin := url.URL{Scheme: "http", Host: addr}
wsCfg, err := websocket.NewConfig(wsURL.String(), origin.String())
if err != nil {
return nil, fmt.Errorf("failed to create websocket config: %w", err)
}
baseConn, err := sd.DialStream(ctx, addr)
if err != nil {
return nil, fmt.Errorf("failed to connect to websocket endpoint: %w", err)
}
wsConn, err := websocket.NewClient(wsCfg, baseConn)
wsURL := url.URL{Scheme: "wss", Host: addr, Path: wsConfig.udpPath}
connect, err := websocket.NewPacketEndpoint(wsURL.String(), sd, nil)
if err != nil {
baseConn.Close()
return nil, fmt.Errorf("failed to create websocket client: %w", err)
return nil, fmt.Errorf("failed to create websocket stream endpoint: %w", err)
}
return wsConn, nil
return connect(ctx)
}), nil
})
}
2 changes: 2 additions & 0 deletions x/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ require (
// Use github.com/Psiphon-Labs/psiphon-tunnel-core@staging-client as per
// https://github.com/Psiphon-Labs/psiphon-tunnel-core/?tab=readme-ov-file#using-psiphon-with-go-modules
github.com/Psiphon-Labs/psiphon-tunnel-core v1.0.11-0.20240619172145-03cade11f647
github.com/coder/websocket v1.8.12
github.com/gorilla/websocket v1.5.3
github.com/lmittmann/tint v1.0.5
github.com/quic-go/quic-go v0.48.1
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
Expand Down
4 changes: 4 additions & 0 deletions x/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ github.com/cheekybits/genny v0.0.0-20170328200008-9127e812e1e9/go.mod h1:+tQajlR
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
github.com/cognusion/go-cache-lru v0.0.0-20170419142635-f73e2280ecea h1:9C2rdYRp8Vzwhm3sbFX0yYfB+70zKFRjn7cnPCucHSw=
github.com/cognusion/go-cache-lru v0.0.0-20170419142635-f73e2280ecea/go.mod h1:MdyNkAe06D7xmJsf+MsLvbZKYNXuOHLKJrvw+x4LlcQ=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down Expand Up @@ -84,6 +86,8 @@ github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
github.com/google/pprof v0.0.0-20211214055906-6f57359322fd h1:1FjCyPC+syAzJ5/2S8fqdZK1R22vvA0J7JZKcuOIQ7Y=
github.com/google/pprof v0.0.0-20211214055906-6f57359322fd/go.mod h1:KgnwoLYCZ8IQu3XUZ8Nc/bM9CCZFOyjUNOSygVozoDg=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grafov/m3u8 v0.0.0-20171211212457-6ab8f28ed427 h1:xh96CCAZTX8LJPFoOVRgTwZbn2DvJl8fyCyivohhSIg=
github.com/grafov/m3u8 v0.0.0-20171211212457-6ab8f28ed427/go.mod h1:PdjzaU/pJUo4jTIn2rcgMFs+HqBGl/sPJLr8BI0Xq/I=
github.com/hashicorp/golang-lru v1.0.2 h1:dV3g9Z/unq5DpblPpw+Oqcv4dU/1omnb4Ok8iPY6p1c=
Expand Down
196 changes: 196 additions & 0 deletions x/websocket/endpoint.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
// Copyright 2025 The Outline Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package websocket provides the Websocket transport.
package websocket

import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"runtime"
"strings"
"time"

"github.com/Jigsaw-Code/outline-sdk/transport"
"github.com/gorilla/websocket"
)

// NewStreamEndpoint creates a new Websocket Stream Endpoint. Streams are sent over
// Websockets, with each Write becoming a separate message. Half-close is supported:
// CloseRead will not close the Websocket connection, while CloseWrite sends a Websocket
// close but continues reading until a close is received from the server.
func NewStreamEndpoint(urlStr string, sd transport.StreamDialer, opts ...Option) (func(context.Context) (transport.StreamConn, error), error) {
return newEndpoint(urlStr, sd, func(gc *gorillaConn) transport.StreamConn { return gc }, opts...)
}

// NewPacketEndpoint creates a new Websocket Packet Endpoint. Each packet is exchanged as a Websocket message.
func NewPacketEndpoint(urlStr string, sd transport.StreamDialer, opts ...Option) (func(context.Context) (net.Conn, error), error) {
return newEndpoint(urlStr, sd, func(gc *gorillaConn) net.Conn { return gc }, opts...)
}

type options struct {
tlsConfig *tls.Config
headers http.Header
}

// Option for building the Websocket endpoint.
type Option func(c *options)

// WithTLSConfig specifies the TLS configuration to use.
// TODO(fortuna): Use Outline TLS instead.
func WithTLSConfig(tlsConfig *tls.Config) Option {
return func(c *options) {
c.tlsConfig = tlsConfig
}
}

// WithHTTPHeaders specifies the HTTP headers to use.
func WithHTTPHeaders(headers http.Header) Option {
return func(c *options) {
c.headers = headers
}
}

func newEndpoint[ConnType net.Conn](urlStr string, sd transport.StreamDialer, wsToConn func(*gorillaConn) ConnType, opts ...Option) (func(context.Context) (ConnType, error), error) {
_, err := url.Parse(urlStr)
if err != nil {
return nil, fmt.Errorf("url is invalid: %w", err)
}

resolvedOpts := options{
// By default, we use this User-Agent.
headers: http.Header(map[string][]string{"User-Agent": {fmt.Sprintf("Outline (%s; %s; %s)", runtime.GOOS, runtime.GOARCH, runtime.Version())}}),
}
for _, opt := range opts {
opt(&resolvedOpts)
}

wsDialer := &websocket.Dialer{
TLSClientConfig: resolvedOpts.tlsConfig,
NetDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
if !strings.HasPrefix(network, "tcp") {
return nil, fmt.Errorf("websocket dialer does not support network type %v", network)
}
return sd.DialStream(ctx, addr)
},
}
return func(ctx context.Context) (ConnType, error) {
var zero ConnType
wsConn, _, err := wsDialer.DialContext(ctx, urlStr, resolvedOpts.headers)
if err != nil {
return zero, err
}
gConn := &gorillaConn{wsConn: wsConn}
wsConn.SetCloseHandler(func(code int, text string) error {
gConn.readErr = io.EOF
return nil
})
return wsToConn(gConn), nil
}, nil
}

type gorillaConn struct {
wsConn *websocket.Conn
writeErr error
readErr error
pendingReader io.Reader
}

var _ transport.StreamConn = (*gorillaConn)(nil)

func (c *gorillaConn) LocalAddr() net.Addr {
return c.wsConn.LocalAddr()
}

func (c *gorillaConn) RemoteAddr() net.Addr {
return c.wsConn.RemoteAddr()
}

func (c *gorillaConn) SetDeadline(deadline time.Time) error {
return errors.Join(c.wsConn.SetReadDeadline(deadline), c.wsConn.SetWriteDeadline(deadline))
}

func (c *gorillaConn) SetReadDeadline(deadline time.Time) error {
return c.wsConn.SetReadDeadline(deadline)
}

func (c *gorillaConn) SetWriteDeadline(deadline time.Time) error {
return c.wsConn.SetWriteDeadline(deadline)
}

func (c *gorillaConn) Read(buf []byte) (int, error) {
if c.readErr != nil {
return 0, c.readErr
}
if c.pendingReader != nil {
n, err := c.pendingReader.Read(buf)
if c.readErr != nil {
return n, c.readErr
}
if !errors.Is(err, io.EOF) {
return n, err
}
c.pendingReader = nil
}

msgType, reader, err := c.wsConn.NextReader()
if c.readErr != nil {
return 0, c.readErr
}
if err != nil {
return 0, err
}
if msgType != websocket.BinaryMessage {
return 0, errors.New("read message is not binary")
}

c.pendingReader = reader
return reader.Read(buf)
}

func (c *gorillaConn) Write(buf []byte) (int, error) {
err := c.wsConn.WriteMessage(websocket.BinaryMessage, buf)
if err != nil {
if c.writeErr != nil {
return 0, c.writeErr
}
return 0, err
}
return len(buf), nil
}

func (c *gorillaConn) CloseRead() error {
c.readErr = net.ErrClosed
c.wsConn.SetReadDeadline(time.Now())
return nil
}

func (c *gorillaConn) CloseWrite() error {
// Send close message.
message := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
c.wsConn.WriteControl(websocket.CloseMessage, message, time.Now().Add(time.Second))
c.writeErr = net.ErrClosed
c.wsConn.SetWriteDeadline(time.Now())
return nil
}

func (c *gorillaConn) Close() error {
return c.wsConn.Close()
}
Loading

0 comments on commit 7da176f

Please sign in to comment.