From 4daba09b150c6e5b75e96c7849de1850c410ad9a Mon Sep 17 00:00:00 2001 From: Paul Wells Date: Tue, 30 Apr 2024 02:52:19 -0700 Subject: [PATCH] Skip UDP address serialization in muxed conn --- udp_mux.go | 8 +- udp_mux_test.go | 68 ---------------- udp_muxed_conn.go | 201 +++++++++++++++++++++------------------------- 3 files changed, 97 insertions(+), 180 deletions(-) diff --git a/udp_mux.go b/udp_mux.go index 40e8ad9b..c78578b4 100644 --- a/udp_mux.go +++ b/udp_mux.go @@ -48,8 +48,6 @@ type UDPMuxDefault struct { localAddrsForUnspecified []net.Addr } -const maxAddrSize = 512 - // UDPMuxParams are parameters for UDPMux. type UDPMuxParams struct { Logger logging.LeveledLogger @@ -120,7 +118,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { pool: &sync.Pool{ New: func() interface{} { // Big enough buffer to fit both packet and address - return newBufferHolder(receiveMTU + maxAddrSize) + return newBufferHolder(receiveMTU) }, }, localAddrsForUnspecified: localAddrsForUnspecified, @@ -365,7 +363,9 @@ func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, o } type bufferHolder struct { - buf []byte + next *bufferHolder + buf []byte + addr *net.UDPAddr } func newBufferHolder(size int) *bufferHolder { diff --git a/udp_mux_test.go b/udp_mux_test.go index 77c575f4..d25a0705 100644 --- a/udp_mux_test.go +++ b/udp_mux_test.go @@ -126,74 +126,6 @@ func TestUDPMux(t *testing.T) { } } -func TestAddressEncoding(t *testing.T) { - cases := []struct { - name string - addr net.UDPAddr - }{ - { - name: "empty address", - }, - { - name: "ipv4", - addr: net.UDPAddr{ - IP: net.IPv4(244, 120, 0, 5), - Port: 6000, - Zone: "", - }, - }, - { - name: "ipv6", - addr: net.UDPAddr{ - IP: net.IPv6loopback, - Port: 2500, - Zone: "zone", - }, - }, - } - - for _, c := range cases { - addr := c.addr - t.Run(c.name, func(t *testing.T) { - buf := make([]byte, maxAddrSize) - n, err := encodeUDPAddr(&addr, buf) - require.NoError(t, err) - - parsedAddr, err := decodeUDPAddr(buf[:n]) - require.NoError(t, err) - require.EqualValues(t, &addr, parsedAddr) - }) - } -} - -func BenchmarkAddressEncoding(b *testing.B) { - addr := &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 1234, - } - buf := make([]byte, 64) - - b.Run("encode", func(b *testing.B) { - for i := 0; i < b.N; i++ { - if _, err := encodeUDPAddr(addr, buf); err != nil { - require.NoError(b, err) - } - } - }) - - b.Run("decode", func(b *testing.B) { - n, _ := encodeUDPAddr(addr, buf) - var addr *net.UDPAddr - var err error - for i := 0; i < b.N; i++ { - if addr, err = decodeUDPAddr(buf[:n]); err != nil { - require.NoError(b, err) - } - } - _ = addr - }) -} - func testMuxConnection(t *testing.T, udpMux *UDPMuxDefault, ufrag string, network string) { pktConn, err := udpMux.GetConn(ufrag, udpMux.LocalAddr()) require.NoError(t, err, "error retrieving muxed connection for ufrag") diff --git a/udp_muxed_conn.go b/udp_muxed_conn.go index 9c279b89..8de840e6 100644 --- a/udp_muxed_conn.go +++ b/udp_muxed_conn.go @@ -4,14 +4,20 @@ package ice import ( - "encoding/binary" "io" "net" "sync" "time" "github.com/pion/logging" - "github.com/pion/transport/v3/packetio" +) + +type udpMuxedConnState int + +const ( + udpMuxedConnOpen udpMuxedConnState = iota + udpMuxedConnWaiting + udpMuxedConnClosed ) type udpMuxedConnParams struct { @@ -28,52 +34,62 @@ type udpMuxedConn struct { // Remote addresses that we have sent to on this conn addresses []ipPort - // Channel holding incoming packets - buf *packetio.Buffer - closedChan chan struct{} - closeOnce sync.Once - mu sync.Mutex + // FIFO queue holding incoming packets + bufHead, bufTail *bufferHolder + notify chan struct{} + closedChan chan struct{} + state udpMuxedConnState + mu sync.Mutex } func newUDPMuxedConn(params *udpMuxedConnParams) *udpMuxedConn { - p := &udpMuxedConn{ + return &udpMuxedConn{ params: params, - buf: packetio.NewBuffer(), + notify: make(chan struct{}, 1), closedChan: make(chan struct{}), } - - return p } func (c *udpMuxedConn) ReadFrom(b []byte) (n int, rAddr net.Addr, err error) { - buf := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert - defer c.params.AddrPool.Put(buf) - - // Read address - total, err := c.buf.Read(buf.buf) - if err != nil { - return 0, nil, err - } - - dataLen := int(binary.LittleEndian.Uint16(buf.buf[:2])) - if dataLen > total || dataLen > len(b) { - return 0, nil, io.ErrShortBuffer - } + for { + c.mu.Lock() + for c.bufTail != nil { + pkt := c.bufTail + c.bufTail = pkt.next + + if pkt == c.bufHead { + c.bufHead = nil + } + c.mu.Unlock() + + if len(b) < len(pkt.buf) { + err = io.ErrShortBuffer + } else { + n = copy(b, pkt.buf) + rAddr = pkt.addr + } + + pkt.next = nil + pkt.addr = nil + c.params.AddrPool.Put(pkt) + + return + } - // Read data and then address - offset := 2 - copy(b, buf.buf[offset:offset+dataLen]) - offset += dataLen + if c.state == udpMuxedConnClosed { + c.mu.Unlock() + return 0, nil, io.EOF + } - // Read address len & decode address - addrLen := int(binary.LittleEndian.Uint16(buf.buf[offset : offset+2])) - offset += 2 + c.state = udpMuxedConnWaiting + c.mu.Unlock() - if rAddr, err = decodeUDPAddr(buf.buf[offset : offset+addrLen]); err != nil { - return 0, nil, err + select { + case <-c.notify: + case <-c.closedChan: + return 0, nil, io.EOF + } } - - return dataLen, rAddr, nil } func (c *udpMuxedConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) { @@ -118,21 +134,29 @@ func (c *udpMuxedConn) CloseChannel() <-chan struct{} { } func (c *udpMuxedConn) Close() error { - var err error - c.closeOnce.Do(func() { - err = c.buf.Close() + c.mu.Lock() + defer c.mu.Unlock() + if c.state != udpMuxedConnClosed { + for pkt := c.bufTail; pkt != nil; { + next := pkt.next + + pkt.next = nil + pkt.addr = nil + c.params.AddrPool.Put(pkt) + + pkt = next + } + + c.state = udpMuxedConnClosed close(c.closedChan) - }) - return err + } + return nil } func (c *udpMuxedConn) isClosed() bool { - select { - case <-c.closedChan: - return true - default: - return false - } + c.mu.Lock() + defer c.mu.Unlock() + return c.state == udpMuxedConnClosed } func (c *udpMuxedConn) getAddresses() []ipPort { @@ -178,79 +202,40 @@ func (c *udpMuxedConn) containsAddress(addr ipPort) bool { } func (c *udpMuxedConn) writePacket(data []byte, addr *net.UDPAddr) error { - // Write two packets, address and data - buf := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert - defer c.params.AddrPool.Put(buf) - - // Format of buffer | data len | data bytes | addr len | addr bytes | - if len(buf.buf) < len(data)+maxAddrSize { + pkt := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert + if cap(pkt.buf) < len(data) { + c.params.AddrPool.Put(pkt) return io.ErrShortBuffer } - // Data length - binary.LittleEndian.PutUint16(buf.buf, uint16(len(data))) - offset := 2 - // Data - copy(buf.buf[offset:], data) - offset += len(data) + pkt.buf = append(pkt.buf[:0], data...) + pkt.addr = addr - // Write address first, leaving room for its length - n, err := encodeUDPAddr(addr, buf.buf[offset+2:]) - if err != nil { - return err + c.mu.Lock() + if c.state == udpMuxedConnClosed { + c.mu.Unlock() + return io.ErrClosedPipe } - total := offset + n + 2 - - // Address len - binary.LittleEndian.PutUint16(buf.buf[offset:], uint16(n)) - if _, err := c.buf.Write(buf.buf[:total]); err != nil { - return err + if c.bufHead != nil { + c.bufHead.next = pkt } - return nil -} + c.bufHead = pkt -func encodeUDPAddr(addr *net.UDPAddr, buf []byte) (int, error) { - total := 1 + len(addr.IP) + 2 + len(addr.Zone) - if len(buf) < total { - return 0, io.ErrShortBuffer + if c.bufTail == nil { + c.bufTail = pkt } - buf[0] = uint8(len(addr.IP)) - offset := 1 - - copy(buf[offset:], addr.IP) - offset += len(addr.IP) - - binary.LittleEndian.PutUint16(buf[offset:], uint16(addr.Port)) - offset += 2 - - copy(buf[offset:], addr.Zone) - return total, nil -} - -func decodeUDPAddr(buf []byte) (*net.UDPAddr, error) { - addr := &net.UDPAddr{} - - // Basic bounds checking - if len(buf) == 0 || len(buf) < int(buf[0])+3 { - return nil, io.ErrShortBuffer - } - - ipLen := int(buf[0]) - offset := 1 + state := c.state + c.state = udpMuxedConnOpen + c.mu.Unlock() - if ipLen == 0 { - addr.IP = nil - } else { - addr.IP = append(addr.IP[:0], buf[offset:offset+ipLen]...) - offset += ipLen + if state == udpMuxedConnWaiting { + select { + case c.notify <- struct{}{}: + default: + } } - addr.Port = int(binary.LittleEndian.Uint16(buf[offset:])) - offset += 2 - - addr.Zone = string(buf[offset:]) - - return addr, nil + return nil }