Skip to content

Commit

Permalink
fix: prevent conn leakage in transport, tls and tlsfrag (#195)
Browse files Browse the repository at this point in the history
  • Loading branch information
jyyi1 authored Mar 12, 2024
1 parent 142376e commit 07b7d40
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 14 deletions.
8 changes: 4 additions & 4 deletions transport/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,14 @@ var _ net.Conn = (*boundPacketConn)(nil)
// For example, a [net.UDPConn] only supports IP addresses, not domain names.
// If the host is a domain name, consider pre-resolving it to avoid resolution calls.
func (e PacketListenerDialer) DialPacket(ctx context.Context, address string) (net.Conn, error) {
packetConn, err := e.Listener.ListenPacket(ctx)
if err != nil {
return nil, fmt.Errorf("could not create PacketConn: %w", err)
}
netAddr, err := MakeNetAddr("udp", address)
if err != nil {
return nil, err
}
packetConn, err := e.Listener.ListenPacket(ctx)
if err != nil {
return nil, fmt.Errorf("could not create PacketConn: %w", err)
}
return &boundPacketConn{
PacketConn: packetConn,
remoteAddr: netAddr,
Expand Down
36 changes: 36 additions & 0 deletions transport/packet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,16 @@ func TestPacketListenerDialer(t *testing.T) {
running.Wait()
}

// Make sure there are no connection leakage in DialPacket
func TestPacketListenerDialerDialPacketCloseInnerConnOnError(t *testing.T) {
inner := &connCounterListener{base: UDPListener{Address: "127.0.0.1:0"}}
pd := PacketListenerDialer{inner}
conn, err := pd.DialPacket(context.Background(), "invalid-address?987654321")
require.Error(t, err)
require.Nil(t, conn)
require.Zero(t, inner.activeConns)
}

// PacketConn assertions

func TestPacketConnInvalidArgument(t *testing.T) {
Expand All @@ -241,3 +251,29 @@ func TestPacketConnInvalidArgument(t *testing.T) {
// This returns Invalid Argument because netAddr is not a *UDPAddr
require.ErrorIs(t, err, syscall.EINVAL)
}

// Private test helpers

// connCounterListener is a PacketListener that counts the number of active PacketConns.
type connCounterListener struct {
base PacketListener
activeConns int
}

type countedPacketConn struct {
net.PacketConn
counter *connCounterListener
}

func (l *connCounterListener) ListenPacket(ctx context.Context) (net.PacketConn, error) {
conn, err := l.base.ListenPacket(ctx)
if conn != nil {
l.activeConns++
}
return countedPacketConn{conn, l}, err
}

func (c countedPacketConn) Close() error {
c.counter.activeConns--
return c.PacketConn.Close()
}
8 changes: 4 additions & 4 deletions transport/tls/stream_dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,14 @@ func (c streamConn) CloseRead() error {

// DialStream implements [transport.StreamDialer].DialStream.
func (d *StreamDialer) DialStream(ctx context.Context, remoteAddr string) (transport.StreamConn, error) {
innerConn, err := d.dialer.DialStream(ctx, remoteAddr)
if err != nil {
return nil, err
}
host, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return nil, fmt.Errorf("invalid address: %w", err)
}
innerConn, err := d.dialer.DialStream(ctx, remoteAddr)
if err != nil {
return nil, err
}
conn, err := WrapConn(ctx, innerConn, host, d.options...)
if err != nil {
innerConn.Close()
Expand Down
37 changes: 37 additions & 0 deletions transport/tls/stream_dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,40 @@ func TestWithALPN(t *testing.T) {
WithALPN([]string{"h2", "http/1.1"})("", &cfg)
require.Equal(t, []string{"h2", "http/1.1"}, cfg.NextProtos)
}

// Make sure there are no connection leakage in DialStream
func TestDialStreamCloseInnerConnOnError(t *testing.T) {
inner := &connCounterDialer{base: &transport.TCPDialer{}}
sd, err := NewStreamDialer(inner)
require.NoError(t, err)
conn, err := sd.DialStream(context.Background(), "invalid-address?987654321")
require.Error(t, err)
require.Nil(t, conn)
require.Zero(t, inner.activeConns)
}

// Private test helpers

// connCounterDialer is a StreamDialer that counts the number of active StreamConns.
type connCounterDialer struct {
base transport.StreamDialer
activeConns int
}

type countedStreamConn struct {
transport.StreamConn
counter *connCounterDialer
}

func (d *connCounterDialer) DialStream(ctx context.Context, raddr string) (transport.StreamConn, error) {
conn, err := d.base.DialStream(ctx, raddr)
if conn != nil {
d.activeConns++
}
return countedStreamConn{conn, d}, err
}

func (c countedStreamConn) Close() error {
c.counter.activeConns--
return c.StreamConn.Close()
}
13 changes: 9 additions & 4 deletions transport/tlsfrag/stream_dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,17 @@ func NewStreamDialerFunc(base transport.StreamDialer, frag FragFunc) (transport.

// DialStream implements [transport.StreamConn].DialStream. It establishes a connection to raddr in the format "host-or-ip:port".
// The initial TLS Client Hello record sent through the connection will be fragmented.
func (d *tlsFragDialer) DialStream(ctx context.Context, raddr string) (conn transport.StreamConn, err error) {
conn, err = d.dialer.DialStream(ctx, raddr)
func (d *tlsFragDialer) DialStream(ctx context.Context, raddr string) (transport.StreamConn, error) {
baseConn, err := d.dialer.DialStream(ctx, raddr)
if err != nil {
return
return nil, err
}
conn, err := WrapConnFunc(baseConn, d.frag)
if err != nil {
baseConn.Close()
return nil, err
}
return WrapConnFunc(conn, d.frag)
return conn, nil
}

// WrapConnFunc wraps the base [transport.StreamConn] and splits the first TLS Client Hello packet into two records
Expand Down
16 changes: 14 additions & 2 deletions transport/tlsfrag/stream_dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,16 @@ func TestNestedFixedLenStreamDialerSplitsClientHello(t *testing.T) {
require.Equal(t, expected, inner.bufs)
}

// Make sure there are no connection leakage in DialStream
func TestDialStreamCloseInnerConnOnError(t *testing.T) {
inner := &collectStreamDialer{}
d := &tlsFragDialer{inner, nil}
conn, err := d.DialStream(context.Background(), "127.0.0.1:8888")
require.Error(t, err)
require.Nil(t, conn)
require.Zero(t, inner.activeConns)
}

// test assertions

func assertCanDialFragFunc(t *testing.T, inner transport.StreamDialer, raddr string, frag FragFunc) transport.StreamConn {
Expand Down Expand Up @@ -228,10 +238,12 @@ func constructTLSRecord(t *testing.T, typ layers.TLSType, ver layers.TLSVersion,

// collectStreamDialer collects all writes to this stream dialer and append it to bufs
type collectStreamDialer struct {
bufs net.Buffers
bufs net.Buffers
activeConns int
}

func (d *collectStreamDialer) DialStream(ctx context.Context, raddr string) (transport.StreamConn, error) {
d.activeConns++
return d, nil
}

Expand All @@ -241,7 +253,7 @@ func (c *collectStreamDialer) Write(p []byte) (int, error) {
}

func (c *collectStreamDialer) Read(p []byte) (int, error) { return 0, errors.New("not supported") }
func (c *collectStreamDialer) Close() error { return nil }
func (c *collectStreamDialer) Close() error { c.activeConns--; return nil }
func (c *collectStreamDialer) CloseRead() error { return nil }
func (c *collectStreamDialer) CloseWrite() error { return nil }
func (c *collectStreamDialer) LocalAddr() net.Addr { return nil }
Expand Down

0 comments on commit 07b7d40

Please sign in to comment.