Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine TCPMux memory usage #653

Merged
merged 1 commit into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 51 additions & 12 deletions tcp_mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"net"
"strings"
"sync"
"time"

"github.com/pion/logging"
"github.com/pion/stun/v2"
Expand Down Expand Up @@ -52,6 +53,16 @@
// if the write buffer is full, the subsequent write packet will be dropped until it has enough space.
// a default 4MB is recommended.
WriteBufferSize int

// A new established connection will be removed if the first STUN binding request is not received within this timeout,
// avoiding the client with bad network or attacker to create a lot of empty connections.
// Default 30s timeout will be used if not set.
FirstStunBindTimeout time.Duration

// TCPMux will create connection from STUN binding request with an unknown username, if
// the connection is not used in the timeout, it will be removed to avoid resource leak / attack.
// Default 30s timeout will be used if not set.
AliveDurationForConnFromStun time.Duration
}

// NewTCPMuxDefault creates a new instance of TCPMuxDefault.
Expand All @@ -60,6 +71,14 @@
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
}

if params.FirstStunBindTimeout == 0 {
params.FirstStunBindTimeout = 30 * time.Second
}

if params.AliveDurationForConnFromStun == 0 {
params.AliveDurationForConnFromStun = 30 * time.Second
}

m := &TCPMuxDefault{
params: &params,

Expand Down Expand Up @@ -110,25 +129,32 @@
}

if conn, ok := m.getConn(ufrag, isIPv6, local); ok {
conn.ClearAliveTimer()
return conn, nil
}

return m.createConn(ufrag, isIPv6, local)
return m.createConn(ufrag, isIPv6, local, false)
}

func (m *TCPMuxDefault) createConn(ufrag string, isIPv6 bool, local net.IP) (*tcpPacketConn, error) {
func (m *TCPMuxDefault) createConn(ufrag string, isIPv6 bool, local net.IP, fromStun bool) (*tcpPacketConn, error) {
addr, ok := m.LocalAddr().(*net.TCPAddr)
if !ok {
return nil, ErrGetTransportAddress
}
localAddr := *addr
localAddr.IP = local

var alive time.Duration
if fromStun {
alive = m.params.AliveDurationForConnFromStun
}

conn := newTCPPacketConn(tcpPacketParams{
ReadBuffer: m.params.ReadBufferSize,
WriteBuffer: m.params.WriteBufferSize,
LocalAddr: &localAddr,
Logger: m.params.Logger,
ReadBuffer: m.params.ReadBufferSize,
WriteBuffer: m.params.WriteBufferSize,
LocalAddr: &localAddr,
Logger: m.params.Logger,
AliveDuration: alive,
})

var conns map[ipAddr]*tcpPacketConn
Expand Down Expand Up @@ -163,13 +189,26 @@
}

func (m *TCPMuxDefault) handleConn(conn net.Conn) {
buf := make([]byte, receiveMTU)
buf := make([]byte, 512)

if m.params.FirstStunBindTimeout > 0 {
if err := conn.SetReadDeadline(time.Now().Add(m.params.FirstStunBindTimeout)); err != nil {
m.params.Logger.Warnf("Failed to set read deadline for first STUN message: %s to %s , err: %s", conn.RemoteAddr(), conn.LocalAddr(), err)
}

Check warning on line 197 in tcp_mux.go

View check run for this annotation

Codecov / codecov/patch

tcp_mux.go#L196-L197

Added lines #L196 - L197 were not covered by tests
}
n, err := readStreamingPacket(conn, buf)
if err != nil {
m.params.Logger.Warnf("Error reading first packet from %s: %s", conn.RemoteAddr().String(), err)
if errors.Is(err, io.ErrShortBuffer) {
m.params.Logger.Warnf("Buffer too small for first packet from %s: %s", conn.RemoteAddr(), err)

Check warning on line 202 in tcp_mux.go

View check run for this annotation

Codecov / codecov/patch

tcp_mux.go#L202

Added line #L202 was not covered by tests
} else {
m.params.Logger.Warnf("Error reading first packet from %s: %s", conn.RemoteAddr(), err)
}
m.closeAndLogError(conn)
return
}
if err = conn.SetReadDeadline(time.Time{}); err != nil {
m.params.Logger.Warnf("Failed to reset read deadline from %s: %s", conn.RemoteAddr(), err)
}

Check warning on line 211 in tcp_mux.go

View check run for this annotation

Codecov / codecov/patch

tcp_mux.go#L210-L211

Added lines #L210 - L211 were not covered by tests

buf = buf[:n]

Expand Down Expand Up @@ -204,9 +243,6 @@
ufrag := strings.Split(string(attr), ":")[0]
m.params.Logger.Debugf("Ufrag: %s", ufrag)

m.mu.Lock()
defer m.mu.Unlock()

host, _, err := net.SplitHostPort(conn.RemoteAddr().String())
if err != nil {
m.closeAndLogError(conn)
Expand All @@ -222,15 +258,18 @@
m.params.Logger.Warnf("Failed to get local tcp address in STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
return
}
m.mu.Lock()
packetConn, ok := m.getConn(ufrag, isIPv6, localAddr.IP)
if !ok {
packetConn, err = m.createConn(ufrag, isIPv6, localAddr.IP)
packetConn, err = m.createConn(ufrag, isIPv6, localAddr.IP, true)
if err != nil {
m.mu.Unlock()

Check warning on line 266 in tcp_mux.go

View check run for this annotation

Codecov / codecov/patch

tcp_mux.go#L266

Added line #L266 was not covered by tests
m.closeAndLogError(conn)
m.params.Logger.Warnf("Failed to create packetConn for STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr())
return
}
}
m.mu.Unlock()

if err := packetConn.AddConn(conn, buf); err != nil {
m.closeAndLogError(conn)
Expand Down
141 changes: 141 additions & 0 deletions tcp_mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ package ice
import (
"io"
"net"
"os"
"testing"
"time"

"github.com/pion/logging"
"github.com/pion/stun/v2"
Expand Down Expand Up @@ -108,6 +110,10 @@ func TestTCPMux_NoDeadlockWhenClosingUnusedPacketConn(t *testing.T) {
ReadBufferSize: 20,
})

defer func() {
_ = tcpMux.Close()
}()

_, err = tcpMux.GetConnByUfrag("test", false, listener.Addr().(*net.TCPAddr).IP)
require.NoError(t, err, "error getting conn by ufrag")

Expand All @@ -117,3 +123,138 @@ func TestTCPMux_NoDeadlockWhenClosingUnusedPacketConn(t *testing.T) {
assert.Nil(t, conn, "should receive nil because mux is closed")
assert.Equal(t, io.ErrClosedPipe, err, "should receive error because mux is closed")
}

func TestTCPMux_FirstPacketTimeout(t *testing.T) {
report := test.CheckRoutines(t)
defer report()

loggerFactory := logging.NewDefaultLoggerFactory()

listener, err := net.ListenTCP("tcp", &net.TCPAddr{
IP: net.IP{127, 0, 0, 1},
Port: 0,
})
require.NoError(t, err, "error starting listener")
defer func() {
_ = listener.Close()
}()

tcpMux := NewTCPMuxDefault(TCPMuxParams{
Listener: listener,
Logger: loggerFactory.NewLogger("ice"),
ReadBufferSize: 20,
FirstStunBindTimeout: time.Second,
})

require.NotNil(t, tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")

conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr))
require.NoError(t, err, "error dialing test TCP connection")
defer func() {
_ = conn.Close()
}()

// Don't send any data, the mux should close the connection after the timeout
time.Sleep(1500 * time.Millisecond)
require.NoError(t, conn.SetReadDeadline(time.Now().Add(2*time.Second)))
buf := make([]byte, 1)
_, err = conn.Read(buf)
require.ErrorIs(t, err, io.EOF)
}

func TestTCPMux_NoLeakForConnectionFromStun(t *testing.T) {
report := test.CheckRoutines(t)
defer report()

loggerFactory := logging.NewDefaultLoggerFactory()

listener, err := net.ListenTCP("tcp", &net.TCPAddr{
IP: net.IP{127, 0, 0, 1},
Port: 0,
})
require.NoError(t, err, "error starting listener")
defer func() {
_ = listener.Close()
}()

tcpMux := NewTCPMuxDefault(TCPMuxParams{
Listener: listener,
Logger: loggerFactory.NewLogger("ice"),
ReadBufferSize: 20,
AliveDurationForConnFromStun: time.Second,
})

defer func() {
_ = tcpMux.Close()
}()

require.NotNil(t, tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")

t.Run("close connection from stun msg after timeout", func(t *testing.T) {
conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr))
require.NoError(t, err, "error dialing test TCP connection")
defer func() {
_ = conn.Close()
}()

msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
stun.NewUsername("myufrag:otherufrag"),
stun.NewShortTermIntegrity("myufrag"),
stun.Fingerprint,
)
require.NoError(t, err, "error building STUN packet")
msg.Encode()

_, err = writeStreamingPacket(conn, msg.Raw)
require.NoError(t, err, "error writing TCP STUN packet")

time.Sleep(1500 * time.Millisecond)
require.NoError(t, conn.SetReadDeadline(time.Now().Add(2*time.Second)))
buf := make([]byte, 1)
_, err = conn.Read(buf)
require.ErrorIs(t, err, io.EOF)
})

t.Run("connection keep alive if access by user", func(t *testing.T) {
conn, err := net.DialTCP("tcp", nil, tcpMux.LocalAddr().(*net.TCPAddr))
require.NoError(t, err, "error dialing test TCP connection")
defer func() {
_ = conn.Close()
}()

msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
stun.NewUsername("myufrag2:otherufrag2"),
stun.NewShortTermIntegrity("myufrag2"),
stun.Fingerprint,
)
require.NoError(t, err, "error building STUN packet")
msg.Encode()

n, err := writeStreamingPacket(conn, msg.Raw)
require.NoError(t, err, "error writing TCP STUN packet")

// wait for the connection to be created
time.Sleep(100 * time.Millisecond)

pktConn, err := tcpMux.GetConnByUfrag("myufrag2", false, listener.Addr().(*net.TCPAddr).IP)
require.NoError(t, err, "error retrieving muxed connection for ufrag")
defer func() {
_ = pktConn.Close()
}()

time.Sleep(1500 * time.Millisecond)

// timeout, not closed
buf := make([]byte, 1024)
require.NoError(t, conn.SetReadDeadline(time.Now().Add(100*time.Millisecond)))
_, err = conn.Read(buf)
require.ErrorIs(t, err, os.ErrDeadlineExceeded)

recv := make([]byte, n)
n2, rAddr, err := pktConn.ReadFrom(recv)
require.NoError(t, err, "error receiving data")
assert.Equal(t, conn.LocalAddr(), rAddr, "remote tcp address mismatch")
assert.Equal(t, n, n2, "received byte size mismatch")
assert.Equal(t, msg.Raw, recv, "received bytes mismatch")
})
}
28 changes: 24 additions & 4 deletions tcp_packet_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ type tcpPacketConn struct {
wg sync.WaitGroup
closedChan chan struct{}
closeOnce sync.Once
aliveTimer *time.Timer
}

type streamingPacket struct {
Expand All @@ -94,10 +95,11 @@ type streamingPacket struct {
}

type tcpPacketParams struct {
ReadBuffer int
LocalAddr net.Addr
Logger logging.LeveledLogger
WriteBuffer int
ReadBuffer int
LocalAddr net.Addr
Logger logging.LeveledLogger
WriteBuffer int
AliveDuration time.Duration
}

func newTCPPacketConn(params tcpPacketParams) *tcpPacketConn {
Expand All @@ -110,9 +112,24 @@ func newTCPPacketConn(params tcpPacketParams) *tcpPacketConn {
closedChan: make(chan struct{}),
}

if params.AliveDuration > 0 {
p.aliveTimer = time.AfterFunc(params.AliveDuration, func() {
p.params.Logger.Warn("close tcp packet conn by alive timeout")
_ = p.Close()
})
}

return p
}

func (t *tcpPacketConn) ClearAliveTimer() {
t.mu.Lock()
if t.aliveTimer != nil {
t.aliveTimer.Stop()
}
t.mu.Unlock()
}

func (t *tcpPacketConn) AddConn(conn net.Conn, firstPacketData []byte) error {
t.params.Logger.Infof("Added connection: %s remote %s to local %s", conn.RemoteAddr().Network(), conn.RemoteAddr(), conn.LocalAddr())

Expand Down Expand Up @@ -265,6 +282,9 @@ func (t *tcpPacketConn) Close() error {
t.closeOnce.Do(func() {
close(t.closedChan)
shouldCloseRecvChan = true
if t.aliveTimer != nil {
t.aliveTimer.Stop()
}
})

for _, conn := range t.conns {
Expand Down
Loading