From a4fe5f4d8020553c406b240a4932751ed70a7dff Mon Sep 17 00:00:00 2001 From: Eric Daniels Date: Mon, 25 Mar 2024 18:28:30 -0400 Subject: [PATCH] Support IPv6 from mDNS This also adds support for IPv6 with zones, making use of the netip package. This breaks the TCPMux and AllConnsGetter interfaces to get more accurate Mux matching. --- active_tcp.go | 9 +- active_tcp_test.go | 39 ++++--- addr.go | 99 ++++++++++++++---- agent.go | 59 +++++++---- agent_test.go | 1 + candidate.go | 1 + candidate_base.go | 7 +- candidate_host.go | 34 +++--- candidate_peer_reflexive.go | 14 +-- candidate_relay.go | 25 +++-- candidate_server_reflexive.go | 29 ++++-- gather.go | 106 +++++++++++++++---- gather_test.go | 6 +- gather_vnet_test.go | 24 ++--- mdns.go | 87 ++++++++++++++-- mdns_test.go | 191 +++++++++++++++++++++------------- net.go | 81 ++++++++------ net_test.go | 52 ++++++--- networktype.go | 12 +-- networktype_test.go | 4 +- tcp_mux.go | 44 +++++--- tcp_mux_multi.go | 11 +- tcp_mux_multi_test.go | 16 ++- tcp_mux_test.go | 20 +++- transport_test.go | 2 + udp_mux.go | 24 +++-- udp_mux_multi.go | 12 ++- udp_mux_multi_test.go | 7 +- udp_mux_test.go | 22 +++- udp_muxed_conn.go | 2 +- 30 files changed, 715 insertions(+), 325 deletions(-) diff --git a/active_tcp.go b/active_tcp.go index 4ffcb6e7..56d78219 100644 --- a/active_tcp.go +++ b/active_tcp.go @@ -47,7 +47,6 @@ func newActiveTCPConn(ctx context.Context, localAddress, remoteAddress string, l log.Infof("Failed to dial TCP address %s: %v", remoteAddress, err) return } - a.remoteAddr.Store(conn.RemoteAddr()) go func() { @@ -95,8 +94,9 @@ func (a *activeTCPConn) ReadFrom(buff []byte) (n int, srcAddr net.Addr, err erro return 0, nil, io.ErrClosedPipe } - srcAddr = a.RemoteAddr() n, err = a.readBuffer.Read(buff) + // RemoteAddr is assuredly set *after* we can read from the buffer + srcAddr = a.RemoteAddr() return } @@ -123,6 +123,11 @@ func (a *activeTCPConn) LocalAddr() net.Addr { return &net.TCPAddr{} } +// RemoteAddr returns the remote address of the connection which is only +// set once a background goroutine has successfully dialed. That means +// this may return ":0" for the address prior to that happening. If this +// becomes an issue, we can introduce a synchronization point between Dial +// and these methods. func (a *activeTCPConn) RemoteAddr() net.Addr { if v, ok := a.remoteAddr.Load().(*net.TCPAddr); ok { return v diff --git a/active_tcp_test.go b/active_tcp_test.go index 7b696d1c..064559ed 100644 --- a/active_tcp_test.go +++ b/active_tcp_test.go @@ -8,6 +8,7 @@ package ice import ( "net" + "net/netip" "testing" "time" @@ -17,21 +18,21 @@ import ( "github.com/stretchr/testify/require" ) -func getLocalIPAddress(t *testing.T, networkType NetworkType) net.IP { +func getLocalIPAddress(t *testing.T, networkType NetworkType) netip.Addr { net, err := stdnet.NewNet() require.NoError(t, err) - localIPs, err := localInterfaces(net, nil, nil, []NetworkType{networkType}, false) + _, localAddrs, err := localInterfaces(net, problematicNetworkInterfaces, nil, []NetworkType{networkType}, false) require.NoError(t, err) - require.NotEmpty(t, localIPs) - return localIPs[0] + require.NotEmpty(t, localAddrs) + return localAddrs[0] } func ipv6Available(t *testing.T) bool { net, err := stdnet.NewNet() require.NoError(t, err) - localIPs, err := localInterfaces(net, nil, nil, []NetworkType{NetworkTypeTCP6}, false) + _, localAddrs, err := localInterfaces(net, problematicNetworkInterfaces, nil, []NetworkType{NetworkTypeTCP6}, false) require.NoError(t, err) - return len(localIPs) > 0 + return len(localAddrs) > 0 } func TestActiveTCP(t *testing.T) { @@ -43,8 +44,9 @@ func TestActiveTCP(t *testing.T) { type testCase struct { name string networkTypes []NetworkType - listenIPAddress net.IP + listenIPAddress netip.Addr selectedPairNetworkType string + useMDNS bool } testCases := []testCase{ @@ -69,12 +71,16 @@ func TestActiveTCP(t *testing.T) { networkTypes: []NetworkType{NetworkTypeTCP6}, listenIPAddress: getLocalIPAddress(t, NetworkTypeTCP6), selectedPairNetworkType: tcp, + // if we don't use mDNS, we will be filtering out location tracked ips. + useMDNS: true, }, testCase{ - name: "UDP is preferred over TCP6", // This fails some time + name: "UDP is preferred over TCP6", networkTypes: supportedNetworkTypes(), listenIPAddress: getLocalIPAddress(t, NetworkTypeTCP6), selectedPairNetworkType: udp, + // if we don't use mDNS, we will be filtering out location tracked ips. + useMDNS: true, }, ) } @@ -84,8 +90,9 @@ func TestActiveTCP(t *testing.T) { r := require.New(t) listener, err := net.ListenTCP("tcp", &net.TCPAddr{ - IP: testCase.listenIPAddress, + IP: testCase.listenIPAddress.AsSlice(), Port: listenPort, + Zone: testCase.listenIPAddress.Zone(), }) r.NoError(err) defer func() { @@ -107,14 +114,18 @@ func TestActiveTCP(t *testing.T) { r.NotNil(tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil") hostAcceptanceMinWait := 100 * time.Millisecond - passiveAgent, err := NewAgent(&AgentConfig{ + cfg := &AgentConfig{ TCPMux: tcpMux, CandidateTypes: []CandidateType{CandidateTypeHost}, NetworkTypes: testCase.networkTypes, LoggerFactory: loggerFactory, - IncludeLoopback: true, HostAcceptanceMinWait: &hostAcceptanceMinWait, - }) + InterfaceFilter: problematicNetworkInterfaces, + } + if testCase.useMDNS { + cfg.MulticastDNSMode = MulticastDNSModeQueryAndGather + } + passiveAgent, err := NewAgent(cfg) r.NoError(err) r.NotNil(passiveAgent) @@ -123,6 +134,7 @@ func TestActiveTCP(t *testing.T) { NetworkTypes: testCase.networkTypes, LoggerFactory: loggerFactory, HostAcceptanceMinWait: &hostAcceptanceMinWait, + InterfaceFilter: problematicNetworkInterfaces, }) r.NoError(err) r.NotNil(activeAgent) @@ -166,7 +178,8 @@ func TestActiveTCP_NonBlocking(t *testing.T) { defer test.TimeOut(time.Second * 5).Stop() cfg := &AgentConfig{ - NetworkTypes: supportedNetworkTypes(), + NetworkTypes: supportedNetworkTypes(), + InterfaceFilter: problematicNetworkInterfaces, } aAgent, err := NewAgent(cfg) diff --git a/addr.go b/addr.go index 1d70025b..18e3b877 100644 --- a/addr.go +++ b/addr.go @@ -4,52 +4,111 @@ package ice import ( + "fmt" "net" + "net/netip" ) -func parseMulticastAnswerAddr(in net.Addr) (net.IP, bool) { +func addrWithOptionalZone(addr netip.Addr, zone string) netip.Addr { + if zone == "" { + return addr + } + if addr.Is6() && (addr.IsLinkLocalUnicast() || addr.IsLinkLocalMulticast()) { + return addr.WithZone(zone) + } + return addr +} + +// parseAddrFromIface should only be used when it's known the address belongs to that interface. +// e.g. it's LocalAddress on a listener. +func parseAddrFromIface(in net.Addr, ifcName string) (netip.Addr, int, NetworkType, error) { + addr, port, nt, err := parseAddr(in) + if err != nil { + return netip.Addr{}, 0, 0, err + } + switch in.(type) { + case *net.IPNet: + // net.IPNet does not have a Zone but we provide it from the interface + addr = addrWithOptionalZone(addr, ifcName) + } + return addr, port, nt, nil +} + +func parseAddr(in net.Addr) (netip.Addr, int, NetworkType, error) { switch addr := in.(type) { + case *net.IPNet: + ipAddr, err := ipAddrToNetIP(addr.IP, "") + if err != nil { + return netip.Addr{}, 0, 0, err + } + return ipAddr, 0, 0, nil case *net.IPAddr: - return addr.IP, true + ipAddr, err := ipAddrToNetIP(addr.IP, addr.Zone) + if err != nil { + return netip.Addr{}, 0, 0, err + } + return ipAddr, 0, 0, nil case *net.UDPAddr: - return addr.IP, true + ipAddr, err := ipAddrToNetIP(addr.IP, addr.Zone) + if err != nil { + return netip.Addr{}, 0, 0, err + } + var nt NetworkType + if ipAddr.Is4() { + nt = NetworkTypeUDP4 + } else { + nt = NetworkTypeUDP6 + } + return ipAddr, addr.Port, nt, nil case *net.TCPAddr: - return addr.IP, true + ipAddr, err := ipAddrToNetIP(addr.IP, addr.Zone) + if err != nil { + return netip.Addr{}, 0, 0, err + } + var nt NetworkType + if ipAddr.Is4() { + nt = NetworkTypeTCP4 + } else { + nt = NetworkTypeTCP6 + } + return ipAddr, addr.Port, nt, nil + default: + return netip.Addr{}, 0, 0, fmt.Errorf("do not know how to parse address type %T", in) } - return nil, false } -func parseAddr(in net.Addr) (net.IP, int, NetworkType, bool) { - switch addr := in.(type) { - case *net.UDPAddr: - return addr.IP, addr.Port, NetworkTypeUDP4, true - case *net.TCPAddr: - return addr.IP, addr.Port, NetworkTypeTCP4, true +func ipAddrToNetIP(ip []byte, zone string) (netip.Addr, error) { + netIPAddr, ok := netip.AddrFromSlice(ip) + if !ok { + return netip.Addr{}, fmt.Errorf("failed to convert IP '%s' to netip.Addr", ip) } - return nil, 0, 0, false + // we'd rather have an IPv4-mapped IPv6 become IPv4 so that it is usable. + netIPAddr = netIPAddr.Unmap() + netIPAddr = addrWithOptionalZone(netIPAddr, zone) + return netIPAddr, nil } -func createAddr(network NetworkType, ip net.IP, port int) net.Addr { +func createAddr(network NetworkType, ip netip.Addr, port int) net.Addr { switch { case network.IsTCP(): - return &net.TCPAddr{IP: ip, Port: port} + return &net.TCPAddr{IP: ip.AsSlice(), Port: port, Zone: ip.Zone()} default: - return &net.UDPAddr{IP: ip, Port: port} + return &net.UDPAddr{IP: ip.AsSlice(), Port: port, Zone: ip.Zone()} } } func addrEqual(a, b net.Addr) bool { - aIP, aPort, aType, aOk := parseAddr(a) - if !aOk { + aIP, aPort, aType, aErr := parseAddr(a) + if aErr != nil { return false } - bIP, bPort, bType, bOk := parseAddr(b) - if !bOk { + bIP, bPort, bType, bErr := parseAddr(b) + if bErr != nil { return false } - return aType == bType && aIP.Equal(bIP) && aPort == bPort + return aType == bType && aIP.Compare(bIP) == 0 && aPort == bPort } // AddrPort is an IP and a port number. diff --git a/agent.go b/agent.go index 27d429d3..8fb52f6b 100644 --- a/agent.go +++ b/agent.go @@ -228,9 +228,22 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit } } + localIfcs, _, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, a.networkTypes, a.includeLoopback) + if err != nil { + return nil, fmt.Errorf("error getting local interfaces: %w", err) + } + // Opportunistic mDNS: If we can't open the connection, that's ok: we // can continue without it. - if a.mDNSConn, a.mDNSMode, err = createMulticastDNS(a.net, mDNSMode, mDNSName, log); err != nil { + if a.mDNSConn, a.mDNSMode, err = createMulticastDNS( + a.net, + a.networkTypes, + localIfcs, + a.includeLoopback, + mDNSMode, + mDNSName, + log, + ); err != nil { log.Warnf("Failed to initialize mDNS %s: %v", mDNSName, err) } @@ -592,19 +605,14 @@ func (a *Agent) resolveAndAddMulticastCandidate(c *CandidateHost) { if a.mDNSConn == nil { return } - _, src, err := a.mDNSConn.Query(c.context(), c.Address()) + + _, src, err := a.mDNSConn.QueryAddr(c.context(), c.Address()) if err != nil { a.log.Warnf("Failed to discover mDNS candidate %s: %v", c.Address(), err) return } - ip, ipOk := parseMulticastAnswerAddr(src) - if !ipOk { - a.log.Warnf("Failed to discover mDNS candidate %s: failed to parse IP", c.Address()) - return - } - - if err = c.setIP(ip); err != nil { + if err = c.setIPAddr(src); err != nil { a.log.Warnf("Failed to discover mDNS candidate %s: %v", c.Address(), err) return } @@ -626,7 +634,7 @@ func (a *Agent) requestConnectivityCheck() { } func (a *Agent) addRemotePassiveTCPCandidate(remoteCandidate Candidate) { - localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{remoteCandidate.NetworkType()}, a.includeLoopback) + _, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{remoteCandidate.NetworkType()}, a.includeLoopback) if err != nil { a.log.Warnf("Failed to iterate local interfaces, host candidates will not be gathered %s", err) return @@ -730,7 +738,9 @@ func (a *Agent) addCandidate(ctx context.Context, c Candidate, candidateConn net a.requestConnectivityCheck() - a.candidateNotifier.EnqueueCandidate(c) + if !c.filterForLocationTracking() { + a.candidateNotifier.EnqueueCandidate(c) + } }) } @@ -759,7 +769,12 @@ func (a *Agent) GetLocalCandidates() ([]Candidate, error) { err := a.loop.Run(a.loop, func(_ context.Context) { var candidates []Candidate for _, set := range a.localCandidates { - candidates = append(candidates, set...) + for _, c := range set { + if c.filterForLocationTracking() { + continue + } + candidates = append(candidates, c) + } } res = candidates }) @@ -841,9 +856,9 @@ func (a *Agent) deleteAllCandidates() { } func (a *Agent) findRemoteCandidate(networkType NetworkType, addr net.Addr) Candidate { - ip, port, _, ok := parseAddr(addr) - if !ok { - a.log.Warnf("Failed to parse address: %s", addr) + ip, port, _, err := parseAddr(addr) + if err != nil { + a.log.Warnf("Failed to parse address: %s; error: %s", addr, err) return nil } @@ -873,15 +888,15 @@ func (a *Agent) sendBindingRequest(m *stun.Message, local, remote Candidate) { func (a *Agent) sendBindingSuccess(m *stun.Message, local, remote Candidate) { base := remote - ip, port, _, ok := parseAddr(base.addr()) - if !ok { - a.log.Warnf("Failed to parse address: %s", base.addr()) + ip, port, _, err := parseAddr(base.addr()) + if err != nil { + a.log.Warnf("Failed to parse address: %s; error: %s", base.addr(), err) return } if out, err := stun.Build(m, stun.BindingSuccess, &stun.XORMappedAddress{ - IP: ip, + IP: ip.AsSlice(), Port: port, }, stun.NewShortTermIntegrity(a.localPwd), @@ -983,9 +998,9 @@ func (a *Agent) handleInbound(m *stun.Message, local Candidate, remote net.Addr) } if remoteCandidate == nil { - ip, port, networkType, ok := parseAddr(remote) - if !ok { - a.log.Errorf("Failed to create parse remote net.Addr when creating remote prflx candidate") + ip, port, networkType, err := parseAddr(remote) + if err != nil { + a.log.Errorf("Failed to create parse remote net.Addr when creating remote prflx candidate: %s", err) return } diff --git a/agent_test.go b/agent_test.go index bbb44cf2..552d4215 100644 --- a/agent_test.go +++ b/agent_test.go @@ -536,6 +536,7 @@ func TestConnectionStateCallback(t *testing.T) { DisconnectedTimeout: &disconnectedDuration, FailedTimeout: &failedDuration, KeepaliveInterval: &KeepaliveInterval, + InterfaceFilter: problematicNetworkInterfaces, } aAgent, err := NewAgent(cfg) diff --git a/candidate.go b/candidate.go index 92a00768..4324159b 100644 --- a/candidate.go +++ b/candidate.go @@ -61,6 +61,7 @@ type Candidate interface { Marshal() string addr() net.Addr + filterForLocationTracking() bool agent() *Agent context() context.Context diff --git a/candidate_base.go b/candidate_base.go index fd0f292a..192eb112 100644 --- a/candidate_base.go +++ b/candidate_base.go @@ -43,6 +43,7 @@ type candidateBase struct { priorityOverride uint32 remoteCandidateCaches map[AddrPort]Candidate + isLocationTracked bool } // Done implements context.Context @@ -435,6 +436,10 @@ func (c *candidateBase) addr() net.Addr { return c.resolvedAddr } +func (c *candidateBase) filterForLocationTracking() bool { + return c.isLocationTracked +} + func (c *candidateBase) agent() *Agent { return c.currAgent } @@ -551,7 +556,7 @@ func UnmarshalCandidate(raw string) (Candidate, error) { switch typ { case "host": - return NewCandidateHost(&CandidateHostConfig{"", protocol, address, port, component, priority, foundation, tcpType}) + return NewCandidateHost(&CandidateHostConfig{"", protocol, address, port, component, priority, foundation, tcpType, false}) case "srflx": return NewCandidateServerReflexive(&CandidateServerReflexiveConfig{"", protocol, address, port, component, priority, foundation, relatedAddress, relatedPort}) case "prflx": diff --git a/candidate_host.go b/candidate_host.go index 5d207dd9..c14b6a7c 100644 --- a/candidate_host.go +++ b/candidate_host.go @@ -4,7 +4,7 @@ package ice import ( - "net" + "net/netip" "strings" ) @@ -17,14 +17,15 @@ type CandidateHost struct { // CandidateHostConfig is the config required to create a new CandidateHost type CandidateHostConfig struct { - CandidateID string - Network string - Address string - Port int - Component uint16 - Priority uint32 - Foundation string - TCPType TCPType + CandidateID string + Network string + Address string + Port int + Component uint16 + Priority uint32 + Foundation string + TCPType TCPType + IsLocationTracked bool } // NewCandidateHost creates a new host candidate @@ -46,17 +47,18 @@ func NewCandidateHost(config *CandidateHostConfig) (*CandidateHost, error) { foundationOverride: config.Foundation, priorityOverride: config.Priority, remoteCandidateCaches: map[AddrPort]Candidate{}, + isLocationTracked: config.IsLocationTracked, }, network: config.Network, } if !strings.HasSuffix(config.Address, ".local") { - ip := net.ParseIP(config.Address) - if ip == nil { - return nil, ErrAddressParseFailed + ipAddr, err := netip.ParseAddr(config.Address) + if err != nil { + return nil, err } - if err := c.setIP(ip); err != nil { + if err := c.setIPAddr(ipAddr); err != nil { return nil, err } } else { @@ -67,14 +69,14 @@ func NewCandidateHost(config *CandidateHostConfig) (*CandidateHost, error) { return c, nil } -func (c *CandidateHost) setIP(ip net.IP) error { - networkType, err := determineNetworkType(c.network, ip) +func (c *CandidateHost) setIPAddr(addr netip.Addr) error { + networkType, err := determineNetworkType(c.network, addr) if err != nil { return err } c.candidateBase.networkType = networkType - c.candidateBase.resolvedAddr = createAddr(networkType, ip, c.port) + c.candidateBase.resolvedAddr = createAddr(networkType, addr, c.port) return nil } diff --git a/candidate_peer_reflexive.go b/candidate_peer_reflexive.go index bbcfe335..b28e9a76 100644 --- a/candidate_peer_reflexive.go +++ b/candidate_peer_reflexive.go @@ -6,7 +6,9 @@ //nolint:dupl package ice -import "net" +import ( + "net/netip" +) // CandidatePeerReflexive ... type CandidatePeerReflexive struct { @@ -28,12 +30,12 @@ type CandidatePeerReflexiveConfig struct { // NewCandidatePeerReflexive creates a new peer reflective candidate func NewCandidatePeerReflexive(config *CandidatePeerReflexiveConfig) (*CandidatePeerReflexive, error) { - ip := net.ParseIP(config.Address) - if ip == nil { - return nil, ErrAddressParseFailed + ipAddr, err := netip.ParseAddr(config.Address) + if err != nil { + return nil, err } - networkType, err := determineNetworkType(config.Network, ip) + networkType, err := determineNetworkType(config.Network, ipAddr) if err != nil { return nil, err } @@ -50,7 +52,7 @@ func NewCandidatePeerReflexive(config *CandidatePeerReflexiveConfig) (*Candidate candidateType: CandidateTypePeerReflexive, address: config.Address, port: config.Port, - resolvedAddr: createAddr(networkType, ip, config.Port), + resolvedAddr: createAddr(networkType, ipAddr, config.Port), component: config.Component, foundationOverride: config.Foundation, priorityOverride: config.Priority, diff --git a/candidate_relay.go b/candidate_relay.go index fa5297b2..faf281b0 100644 --- a/candidate_relay.go +++ b/candidate_relay.go @@ -5,6 +5,7 @@ package ice import ( "net" + "net/netip" ) // CandidateRelay ... @@ -38,24 +39,28 @@ func NewCandidateRelay(config *CandidateRelayConfig) (*CandidateRelay, error) { candidateID = globalCandidateIDGenerator.Generate() } - ip := net.ParseIP(config.Address) - if ip == nil { - return nil, ErrAddressParseFailed + ipAddr, err := netip.ParseAddr(config.Address) + if err != nil { + return nil, err } - networkType, err := determineNetworkType(config.Network, ip) + networkType, err := determineNetworkType(config.Network, ipAddr) if err != nil { return nil, err } return &CandidateRelay{ candidateBase: candidateBase{ - id: candidateID, - networkType: networkType, - candidateType: CandidateTypeRelay, - address: config.Address, - port: config.Port, - resolvedAddr: &net.UDPAddr{IP: ip, Port: config.Port}, + id: candidateID, + networkType: networkType, + candidateType: CandidateTypeRelay, + address: config.Address, + port: config.Port, + resolvedAddr: &net.UDPAddr{ + IP: ipAddr.AsSlice(), + Port: config.Port, + Zone: ipAddr.Zone(), + }, component: config.Component, foundationOverride: config.Foundation, priorityOverride: config.Priority, diff --git a/candidate_server_reflexive.go b/candidate_server_reflexive.go index 3a8ac0ff..85d613e2 100644 --- a/candidate_server_reflexive.go +++ b/candidate_server_reflexive.go @@ -3,7 +3,10 @@ package ice -import "net" +import ( + "net" + "net/netip" +) // CandidateServerReflexive ... type CandidateServerReflexive struct { @@ -25,12 +28,12 @@ type CandidateServerReflexiveConfig struct { // NewCandidateServerReflexive creates a new server reflective candidate func NewCandidateServerReflexive(config *CandidateServerReflexiveConfig) (*CandidateServerReflexive, error) { - ip := net.ParseIP(config.Address) - if ip == nil { - return nil, ErrAddressParseFailed + ipAddr, err := netip.ParseAddr(config.Address) + if err != nil { + return nil, err } - networkType, err := determineNetworkType(config.Network, ip) + networkType, err := determineNetworkType(config.Network, ipAddr) if err != nil { return nil, err } @@ -42,12 +45,16 @@ func NewCandidateServerReflexive(config *CandidateServerReflexiveConfig) (*Candi return &CandidateServerReflexive{ candidateBase: candidateBase{ - id: candidateID, - networkType: networkType, - candidateType: CandidateTypeServerReflexive, - address: config.Address, - port: config.Port, - resolvedAddr: &net.UDPAddr{IP: ip, Port: config.Port}, + id: candidateID, + networkType: networkType, + candidateType: CandidateTypeServerReflexive, + address: config.Address, + port: config.Port, + resolvedAddr: &net.UDPAddr{ + IP: ipAddr.AsSlice(), + Port: config.Port, + Zone: ipAddr.Zone(), + }, component: config.Component, foundationOverride: config.Foundation, priorityOverride: config.Priority, diff --git a/gather.go b/gather.go index 258e7fcb..ca614777 100644 --- a/gather.go +++ b/gather.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "net" + "net/netip" "reflect" "sync" "time" @@ -133,25 +134,37 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ delete(networks, udp) } - localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, networkTypes, a.includeLoopback) + _, localAddrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, networkTypes, a.includeLoopback) if err != nil { a.log.Warnf("Failed to iterate local interfaces, host candidates will not be gathered %s", err) return } - for _, ip := range localIPs { - mappedIP := ip + for _, addr := range localAddrs { + mappedIP := addr if a.mDNSMode != MulticastDNSModeQueryAndGather && a.extIPMapper != nil && a.extIPMapper.candidateType == CandidateTypeHost { - if _mappedIP, innerErr := a.extIPMapper.findExternalIP(ip.String()); innerErr == nil { - mappedIP = _mappedIP + if _mappedIP, innerErr := a.extIPMapper.findExternalIP(addr.String()); innerErr == nil { + conv, ok := netip.AddrFromSlice(_mappedIP) + if !ok { + a.log.Warnf("failed to convert mapped external IP to netip.Addr'%s'", addr.String()) + continue + } + // we'd rather have an IPv4-mapped IPv6 become IPv4 so that it is usable + mappedIP = conv.Unmap() } else { - a.log.Warnf("1:1 NAT mapping is enabled but no external IP is found for %s", ip.String()) + a.log.Warnf("1:1 NAT mapping is enabled but no external IP is found for %s", addr.String()) } } address := mappedIP.String() + var isLocationTracked bool if a.mDNSMode == MulticastDNSModeQueryAndGather { address = a.mDNSName + } else { + // Here, we are not doing multicast gathering, so we will need to skip this address so + // that we don't accidentally reveal location tracking information. Otherwise, the + // case above hides the IP behind an mDNS address. + isLocationTracked = shouldFilterLocationTrackedIP(mappedIP) } for network := range networks { @@ -174,16 +187,16 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ var muxConns []net.PacketConn if multi, ok := a.tcpMux.(AllConnsGetter); ok { a.log.Debugf("GetAllConns by ufrag: %s", a.localUfrag) - muxConns, err = multi.GetAllConns(a.localUfrag, mappedIP.To4() == nil, ip) + muxConns, err = multi.GetAllConns(a.localUfrag, mappedIP.Is6(), addr) if err != nil { - a.log.Warnf("Failed to get all TCP connections by ufrag: %s %s %s", network, ip, a.localUfrag) + a.log.Warnf("Failed to get all TCP connections by ufrag: %s %s %s", network, addr, a.localUfrag) continue } } else { a.log.Debugf("GetConn by ufrag: %s", a.localUfrag) - conn, err := a.tcpMux.GetConnByUfrag(a.localUfrag, mappedIP.To4() == nil, ip) + conn, err := a.tcpMux.GetConnByUfrag(a.localUfrag, mappedIP.Is6(), addr) if err != nil { - a.log.Warnf("Failed to get TCP connections by ufrag: %s %s %s", network, ip, a.localUfrag) + a.log.Warnf("Failed to get TCP connections by ufrag: %s %s %s", network, addr, a.localUfrag) continue } muxConns = []net.PacketConn{conn} @@ -194,7 +207,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ if tcpConn, ok := conn.LocalAddr().(*net.TCPAddr); ok { conns = append(conns, connAndPort{conn, tcpConn.Port}) } else { - a.log.Warnf("Failed to get port of connection from TCPMux: %s %s %s", network, ip, a.localUfrag) + a.log.Warnf("Failed to get port of connection from TCPMux: %s %s %s", network, addr, a.localUfrag) } } if len(conns) == 0 { @@ -205,16 +218,20 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ // Is there a way to verify that the listen address is even // accessible from the current interface. case udp: - conn, err := listenUDPInPortRange(a.net, a.log, int(a.portMax), int(a.portMin), network, &net.UDPAddr{IP: ip, Port: 0}) + conn, err := listenUDPInPortRange(a.net, a.log, int(a.portMax), int(a.portMin), network, &net.UDPAddr{ + IP: addr.AsSlice(), + Port: 0, + Zone: addr.Zone(), + }) if err != nil { - a.log.Warnf("Failed to listen %s %s", network, ip) + a.log.Warnf("Failed to listen %s %s", network, addr) continue } if udpConn, ok := conn.LocalAddr().(*net.UDPAddr); ok { conns = append(conns, connAndPort{conn, udpConn.Port}) } else { - a.log.Warnf("Failed to get port of UDPAddr from ListenUDPInPortRange: %s %s %s", network, ip, a.localUfrag) + a.log.Warnf("Failed to get port of UDPAddr from ListenUDPInPortRange: %s %s %s", network, addr, a.localUfrag) continue } } @@ -226,6 +243,9 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ Port: connAndPort.port, Component: ComponentRTP, TCPType: tcpType, + // we will still process this candidate so that we start up the right + // listeners. + IsLocationTracked: isLocationTracked, } c, err := NewCandidateHost(&hostConfig) @@ -235,7 +255,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ } if a.mDNSMode == MulticastDNSModeQueryAndGather { - if err = c.setIP(ip); err != nil { + if err = c.setIPAddr(addr); err != nil { closeConnAndLog(connAndPort.conn, a.log, "failed to create host candidate: %s %s %d: %v", network, mappedIP, connAndPort.port, err) continue } @@ -252,6 +272,27 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ } } +// shouldFilterLocationTrackedIP returns if this candidate IP should be filtered out from +// any candidate publishing/notification for location tracking reasons. +func shouldFilterLocationTrackedIP(candidateIP netip.Addr) bool { + // https://tools.ietf.org/html/rfc8445#section-5.1.1.1 + // Similarly, when host candidates corresponding to + // an IPv6 address generated using a mechanism that prevents location + // tracking are gathered, then host candidates corresponding to IPv6 + // link-local addresses [RFC4291] MUST NOT be gathered. + return candidateIP.Is6() && (candidateIP.IsLinkLocalUnicast() || candidateIP.IsLinkLocalMulticast()) +} + +// shouldFilterLocationTracked returns if this candidate IP should be filtered out from +// any candidate publishing/notification for location tracking reasons. +func shouldFilterLocationTracked(candidateIP net.IP) bool { + addr, ok := netip.AddrFromSlice(candidateIP) + if !ok { + return false + } + return shouldFilterLocationTrackedIP(addr) +} + func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolint:gocognit if a.udpMux == nil { return errUDPMuxDisabled @@ -286,17 +327,23 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolin } var address string + var isLocationTracked bool if a.mDNSMode == MulticastDNSModeQueryAndGather { address = a.mDNSName } else { address = candidateIP.String() + // Here, we are not doing multicast gathering, so we will need to skip this address so + // that we don't accidentally reveal location tracking information. Otherwise, the + // case above hides the IP behind an mDNS address. + isLocationTracked = shouldFilterLocationTracked(candidateIP) } hostConfig := CandidateHostConfig{ - Network: udp, - Address: address, - Port: udpAddr.Port, - Component: ComponentRTP, + Network: udp, + Address: address, + Port: udpAddr.Port, + Component: ComponentRTP, + IsLocationTracked: isLocationTracked, } // Detect a duplicate candidate before calling addCandidate(). @@ -365,6 +412,11 @@ func (a *Agent) gatherCandidatesSrflxMapped(ctx context.Context, networkTypes [] return } + if shouldFilterLocationTracked(mappedIP) { + closeConnAndLog(conn, a.log, "external IP is somehow filtered for location tracking reasons %s", mappedIP) + return + } + srflxConfig := CandidateServerReflexiveConfig{ Network: network, Address: mappedIP.String(), @@ -420,6 +472,11 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*stun.UR return } + if shouldFilterLocationTracked(serverAddr.IP) { + a.log.Warnf("STUN host %s is somehow filtered for location tracking reasons", hostPort) + return + } + xorAddr, err := a.udpMuxSrflx.GetXORMappedAddr(serverAddr, stunGatherTimeout) if err != nil { a.log.Warnf("Failed get server reflexive address %s %s: %v", network, url, err) @@ -482,6 +539,11 @@ func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, net return } + if shouldFilterLocationTracked(serverAddr.IP) { + a.log.Warnf("STUN host %s is somehow filtered for location tracking reasons", hostPort) + return + } + conn, err := listenUDPInPortRange(a.net, a.log, int(a.portMax), int(a.portMin), network, &net.UDPAddr{IP: nil, Port: 0}) if err != nil { closeConnAndLog(conn, a.log, "failed to listen for %s: %v", serverAddr.String(), err) @@ -696,6 +758,12 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { / } rAddr := relayConn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert + + if shouldFilterLocationTracked(rAddr.IP) { + a.log.Warnf("TURN address %s is somehow filtered for location tracking reasons", rAddr.IP) + return + } + relayConfig := CandidateRelayConfig{ Network: network, Component: ComponentRTP, diff --git a/gather_test.go b/gather_test.go index 54c8a8db..1e4b8960 100644 --- a/gather_test.go +++ b/gather_test.go @@ -34,11 +34,11 @@ func TestListenUDP(t *testing.T) { a, err := NewAgent(&AgentConfig{}) require.NoError(t, err) - localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) - require.NotEqual(t, len(localIPs), 0, "localInterfaces found no interfaces, unable to test") + _, localAddrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) + require.NotEqual(t, len(localAddrs), 0, "localInterfaces found no interfaces, unable to test") require.NoError(t, err) - ip := localIPs[0] + ip := localAddrs[0].AsSlice() conn, err := listenUDPInPortRange(a.net, a.log, 0, 0, udp, &net.UDPAddr{IP: ip, Port: 0}) require.NoError(t, err, "listenUDP error with no port restriction") diff --git a/gather_vnet_test.go b/gather_vnet_test.go index ce11e839..9fda3fdc 100644 --- a/gather_vnet_test.go +++ b/gather_vnet_test.go @@ -34,7 +34,7 @@ func TestVNetGather(t *testing.T) { }) require.NoError(t, err) - localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) + _, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) if len(localIPs) > 0 { t.Fatal("should return no local IP") } @@ -73,17 +73,17 @@ func TestVNetGather(t *testing.T) { }) require.NoError(t, err) - localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) - if len(localIPs) == 0 { + _, localAddrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) + if len(localAddrs) == 0 { t.Fatal("should have one local IP") } require.NoError(t, err) - for _, ip := range localIPs { - if ip.IsLoopback() { + for _, addr := range localAddrs { + if addr.IsLoopback() { t.Fatal("should not return loopback IP") } - if !ipNet.Contains(ip) { + if !ipNet.Contains(addr.AsSlice()) { t.Fatal("should be contained in the CIDR") } } @@ -115,13 +115,13 @@ func TestVNetGather(t *testing.T) { t.Fatalf("Failed to create agent: %s", err) } - localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) - if len(localIPs) == 0 { + _, localAddrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) + if len(localAddrs) == 0 { t.Fatal("localInterfaces found no interfaces, unable to test") } require.NoError(t, err) - ip := localIPs[0] + ip := localAddrs[0].AsSlice() conn, err := listenUDPInPortRange(a.net, a.log, 0, 0, udp, &net.UDPAddr{IP: ip, Port: 0}) if err != nil { @@ -385,7 +385,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) { }) require.NoError(t, err) - localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) + _, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) require.NoError(t, err) if len(localIPs) != 0 { @@ -405,7 +405,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) { }) require.NoError(t, err) - localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) + _, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) require.NoError(t, err) if len(localIPs) != 0 { @@ -425,7 +425,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) { }) require.NoError(t, err) - localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) + _, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) require.NoError(t, err) if len(localIPs) == 0 { diff --git a/mdns.go b/mdns.go index a3ad899b..1308d2a7 100644 --- a/mdns.go +++ b/mdns.go @@ -4,11 +4,14 @@ package ice import ( + "net" + "github.com/google/uuid" "github.com/pion/logging" "github.com/pion/mdns/v2" "github.com/pion/transport/v3" "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" ) // MulticastDNSMode represents the different Multicast modes ICE can run in @@ -33,30 +36,98 @@ func generateMulticastDNSName() (string, error) { return u.String() + ".local", err } -func createMulticastDNS(n transport.Net, mDNSMode MulticastDNSMode, mDNSName string, log logging.LeveledLogger) (*mdns.Conn, MulticastDNSMode, error) { +func createMulticastDNS( + n transport.Net, + networkTypes []NetworkType, + interfaces []*transport.Interface, + includeLoopback bool, + mDNSMode MulticastDNSMode, + mDNSName string, + log logging.LeveledLogger, +) (*mdns.Conn, MulticastDNSMode, error) { if mDNSMode == MulticastDNSModeDisabled { return nil, mDNSMode, nil } - addr, mdnsErr := n.ResolveUDPAddr("udp4", mdns.DefaultAddressIPv4) + var useV4, useV6 bool + if len(networkTypes) == 0 { + useV4 = true + useV6 = true + } else { + for _, nt := range networkTypes { + if nt.IsIPv4() { + useV4 = true + continue + } + if nt.IsIPv6() { + useV6 = true + } + } + } + + addr4, mdnsErr := n.ResolveUDPAddr("udp4", mdns.DefaultAddressIPv4) if mdnsErr != nil { return nil, mDNSMode, mdnsErr } - - l, mdnsErr := n.ListenUDP("udp4", addr) + addr6, mdnsErr := n.ResolveUDPAddr("udp6", mdns.DefaultAddressIPv6) if mdnsErr != nil { + return nil, mDNSMode, mdnsErr + } + + var pktConnV4 *ipv4.PacketConn + var mdns4Err error + if useV4 { + var l transport.UDPConn + l, mdns4Err = n.ListenUDP("udp4", addr4) + if mdns4Err != nil { + // If ICE fails to start MulticastDNS server just warn the user and continue + log.Errorf("Failed to enable mDNS over IPv4: (%s)", mdns4Err) + return nil, MulticastDNSModeDisabled, nil + } else { + pktConnV4 = ipv4.NewPacketConn(l) + } + } + + var pktConnV6 *ipv6.PacketConn + var mdns6Err error + if useV6 { + var l transport.UDPConn + l, mdns6Err = n.ListenUDP("udp6", addr6) + if mdns6Err != nil { + log.Errorf("Failed to enable mDNS over IPv6: (%s)", mdns6Err) + return nil, MulticastDNSModeDisabled, nil + } else { + pktConnV6 = ipv6.NewPacketConn(l) + } + } + + if mdns4Err != nil && mdns6Err != nil { // If ICE fails to start MulticastDNS server just warn the user and continue - log.Errorf("Failed to enable mDNS, continuing in mDNS disabled mode: (%s)", mdnsErr) + log.Errorf("Failed to enable mDNS, continuing in mDNS disabled mode") return nil, MulticastDNSModeDisabled, nil + + } + + var ifcs []net.Interface + if interfaces != nil { + ifcs = make([]net.Interface, 0, len(ifcs)) + for _, ifc := range interfaces { + ifcs = append(ifcs, ifc.Interface) + } } switch mDNSMode { case MulticastDNSModeQueryOnly: - conn, err := mdns.Server(ipv4.NewPacketConn(l), nil, &mdns.Config{}) + conn, err := mdns.Server(pktConnV4, pktConnV6, &mdns.Config{ + Interfaces: ifcs, + IncludeLoopback: includeLoopback, + }) return conn, mDNSMode, err case MulticastDNSModeQueryAndGather: - conn, err := mdns.Server(ipv4.NewPacketConn(l), nil, &mdns.Config{ - LocalNames: []string{mDNSName}, + conn, err := mdns.Server(pktConnV4, pktConnV6, &mdns.Config{ + Interfaces: ifcs, + IncludeLoopback: includeLoopback, + LocalNames: []string{mDNSName}, }) return conn, mDNSMode, err default: diff --git a/mdns_test.go b/mdns_test.go index a82d6b1d..04fdf916 100644 --- a/mdns_test.go +++ b/mdns_test.go @@ -22,30 +22,43 @@ func TestMulticastDNSOnlyConnection(t *testing.T) { // Limit runtime in case of deadlocks defer test.TimeOut(time.Second * 30).Stop() - cfg := &AgentConfig{ - NetworkTypes: []NetworkType{NetworkTypeUDP4}, - CandidateTypes: []CandidateType{CandidateTypeHost}, - MulticastDNSMode: MulticastDNSModeQueryAndGather, + for _, tc := range []struct { + Name string + NetworkTypes []NetworkType + }{ + {Name: "UDP4", NetworkTypes: []NetworkType{NetworkTypeUDP4}}, + {Name: "UDP6", NetworkTypes: []NetworkType{NetworkTypeUDP6}}, + {Name: "UDP46", NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}}, + } { + t.Run(tc.Name, func(t *testing.T) { + + cfg := &AgentConfig{ + NetworkTypes: tc.NetworkTypes, + CandidateTypes: []CandidateType{CandidateTypeHost}, + MulticastDNSMode: MulticastDNSModeQueryAndGather, + InterfaceFilter: problematicNetworkInterfaces, + } + + aAgent, err := NewAgent(cfg) + require.NoError(t, err) + + aNotifier, aConnected := onConnected() + require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) + + bAgent, err := NewAgent(cfg) + require.NoError(t, err) + + bNotifier, bConnected := onConnected() + require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) + + connect(aAgent, bAgent) + <-aConnected + <-bConnected + + require.NoError(t, aAgent.Close()) + require.NoError(t, bAgent.Close()) + }) } - - aAgent, err := NewAgent(cfg) - require.NoError(t, err) - - aNotifier, aConnected := onConnected() - require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) - - bAgent, err := NewAgent(cfg) - require.NoError(t, err) - - bNotifier, bConnected := onConnected() - require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) - - connect(aAgent, bAgent) - <-aConnected - <-bConnected - - require.NoError(t, aAgent.Close()) - require.NoError(t, bAgent.Close()) } func TestMulticastDNSMixedConnection(t *testing.T) { @@ -54,32 +67,46 @@ func TestMulticastDNSMixedConnection(t *testing.T) { // Limit runtime in case of deadlocks defer test.TimeOut(time.Second * 30).Stop() - aAgent, err := NewAgent(&AgentConfig{ - NetworkTypes: []NetworkType{NetworkTypeUDP4}, - CandidateTypes: []CandidateType{CandidateTypeHost}, - MulticastDNSMode: MulticastDNSModeQueryAndGather, - }) - require.NoError(t, err) - - aNotifier, aConnected := onConnected() - require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) - - bAgent, err := NewAgent(&AgentConfig{ - NetworkTypes: []NetworkType{NetworkTypeUDP4}, - CandidateTypes: []CandidateType{CandidateTypeHost}, - MulticastDNSMode: MulticastDNSModeQueryOnly, - }) - require.NoError(t, err) - - bNotifier, bConnected := onConnected() - require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) - - connect(aAgent, bAgent) - <-aConnected - <-bConnected - - require.NoError(t, aAgent.Close()) - require.NoError(t, bAgent.Close()) + for _, tc := range []struct { + Name string + NetworkTypes []NetworkType + }{ + {Name: "UDP4", NetworkTypes: []NetworkType{NetworkTypeUDP4}}, + {Name: "UDP6", NetworkTypes: []NetworkType{NetworkTypeUDP6}}, + {Name: "UDP46", NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}}, + } { + t.Run(tc.Name, func(t *testing.T) { + + aAgent, err := NewAgent(&AgentConfig{ + NetworkTypes: tc.NetworkTypes, + CandidateTypes: []CandidateType{CandidateTypeHost}, + MulticastDNSMode: MulticastDNSModeQueryAndGather, + InterfaceFilter: problematicNetworkInterfaces, + }) + require.NoError(t, err) + + aNotifier, aConnected := onConnected() + require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) + + bAgent, err := NewAgent(&AgentConfig{ + NetworkTypes: tc.NetworkTypes, + CandidateTypes: []CandidateType{CandidateTypeHost}, + MulticastDNSMode: MulticastDNSModeQueryOnly, + InterfaceFilter: problematicNetworkInterfaces, + }) + require.NoError(t, err) + + bNotifier, bConnected := onConnected() + require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) + + connect(aAgent, bAgent) + <-aConnected + <-bConnected + + require.NoError(t, aAgent.Close()) + require.NoError(t, bAgent.Close()) + }) + } } func TestMulticastDNSStaticHostName(t *testing.T) { @@ -87,32 +114,46 @@ func TestMulticastDNSStaticHostName(t *testing.T) { defer test.TimeOut(time.Second * 30).Stop() - _, err := NewAgent(&AgentConfig{ - NetworkTypes: []NetworkType{NetworkTypeUDP4}, - CandidateTypes: []CandidateType{CandidateTypeHost}, - MulticastDNSMode: MulticastDNSModeQueryAndGather, - MulticastDNSHostName: "invalidHostName", - }) - require.Equal(t, err, ErrInvalidMulticastDNSHostName) - - agent, err := NewAgent(&AgentConfig{ - NetworkTypes: []NetworkType{NetworkTypeUDP4}, - CandidateTypes: []CandidateType{CandidateTypeHost}, - MulticastDNSMode: MulticastDNSModeQueryAndGather, - MulticastDNSHostName: "validName.local", - }) - require.NoError(t, err) - - correctHostName, resolveFunc := context.WithCancel(context.Background()) - require.NoError(t, agent.OnCandidate(func(c Candidate) { - if c != nil && c.Address() == "validName.local" { - resolveFunc() - } - })) - - require.NoError(t, agent.GatherCandidates()) - <-correctHostName.Done() - require.NoError(t, agent.Close()) + for _, tc := range []struct { + Name string + NetworkTypes []NetworkType + }{ + {Name: "UDP4", NetworkTypes: []NetworkType{NetworkTypeUDP4}}, + {Name: "UDP6", NetworkTypes: []NetworkType{NetworkTypeUDP6}}, + {Name: "UDP46", NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}}, + } { + t.Run(tc.Name, func(t *testing.T) { + + _, err := NewAgent(&AgentConfig{ + NetworkTypes: tc.NetworkTypes, + CandidateTypes: []CandidateType{CandidateTypeHost}, + MulticastDNSMode: MulticastDNSModeQueryAndGather, + MulticastDNSHostName: "invalidHostName", + InterfaceFilter: problematicNetworkInterfaces, + }) + require.Equal(t, err, ErrInvalidMulticastDNSHostName) + + agent, err := NewAgent(&AgentConfig{ + NetworkTypes: tc.NetworkTypes, + CandidateTypes: []CandidateType{CandidateTypeHost}, + MulticastDNSMode: MulticastDNSModeQueryAndGather, + MulticastDNSHostName: "validName.local", + InterfaceFilter: problematicNetworkInterfaces, + }) + require.NoError(t, err) + + correctHostName, resolveFunc := context.WithCancel(context.Background()) + require.NoError(t, agent.OnCandidate(func(c Candidate) { + if c != nil && c.Address() == "validName.local" { + resolveFunc() + } + })) + + require.NoError(t, agent.GatherCandidates()) + <-correctHostName.Done() + require.NoError(t, agent.Close()) + }) + } } func TestGenerateMulticastDNSName(t *testing.T) { diff --git a/net.go b/net.go index f686d411..7261a1d9 100644 --- a/net.go +++ b/net.go @@ -5,6 +5,7 @@ package ice import ( "net" + "net/netip" "github.com/pion/logging" "github.com/pion/transport/v3" @@ -12,12 +13,15 @@ import ( // The conditions of invalidation written below are defined in // https://tools.ietf.org/html/rfc8445#section-5.1.1.1 -func isSupportedIPv6(ip net.IP) bool { +// It is partial because the link-local check is done later in various gather local +// candidate methods which conditionally accept IPv6 based on usage of mDNS or not. +func isSupportedIPv6Partial(ip net.IP) bool { if len(ip) != net.IPv6len || + // Deprecated IPv4-compatible IPv6 addresses [RFC4291] and IPv6 site- + // local unicast addresses [RFC3879] MUST NOT be included in the + // address candidates. isZeros(ip[0:12]) || // !(IPv4-compatible IPv6) - ip[0] == 0xfe && ip[1]&0xc0 == 0xc0 || // !(IPv6 site-local unicast) - ip.IsLinkLocalUnicast() || - ip.IsLinkLocalMulticast() { + ip[0] == 0xfe && ip[1]&0xc0 == 0xc0 { // !(IPv6 site-local unicast) return false } return true @@ -32,21 +36,34 @@ func isZeros(ip net.IP) bool { return true } -func localInterfaces(n transport.Net, interfaceFilter func(string) bool, ipFilter func(net.IP) bool, networkTypes []NetworkType, includeLoopback bool) ([]net.IP, error) { //nolint:gocognit - ips := []net.IP{} +func localInterfaces( + n transport.Net, + interfaceFilter func(string) bool, + ipFilter func(net.IP) bool, + networkTypes []NetworkType, + includeLoopback bool, +) ([]*transport.Interface, []netip.Addr, error) { //nolint:gocognit + ipAddrs := []netip.Addr{} ifaces, err := n.Interfaces() if err != nil { - return ips, err + return nil, ipAddrs, err } - var IPv4Requested, IPv6Requested bool - for _, typ := range networkTypes { - if typ.IsIPv4() { - IPv4Requested = true - } + filteredIfaces := make([]*transport.Interface, 0, len(ifaces)) + + var ipV4Requested, ipv6Requested bool + if len(networkTypes) == 0 { + ipV4Requested = true + ipv6Requested = true + } else { + for _, typ := range networkTypes { + if typ.IsIPv4() { + ipV4Requested = true + } - if typ.IsIPv6() { - IPv6Requested = true + if typ.IsIPv6() { + ipv6Requested = true + } } } @@ -62,41 +79,41 @@ func localInterfaces(n transport.Net, interfaceFilter func(string) bool, ipFilte continue } - addrs, err := iface.Addrs() + ifaceAddrs, err := iface.Addrs() if err != nil { continue } - for _, addr := range addrs { - var ip net.IP - switch addr := addr.(type) { - case *net.IPNet: - ip = addr.IP - case *net.IPAddr: - ip = addr.IP - } - if ip == nil || (ip.IsLoopback() && !includeLoopback) { + atLeastOneAddr := false + for _, addr := range ifaceAddrs { + ipAddr, _, _, err := parseAddrFromIface(addr, iface.Name) + if err != nil || (ipAddr.IsLoopback() && !includeLoopback) { continue } - - if ipv4 := ip.To4(); ipv4 == nil { - if !IPv6Requested { + if ipAddr.Is6() { + if !ipv6Requested { continue - } else if !isSupportedIPv6(ip) { + } else if !isSupportedIPv6Partial(ipAddr.AsSlice()) { continue } - } else if !IPv4Requested { + } else if !ipV4Requested { continue } - if ipFilter != nil && !ipFilter(ip) { + if ipFilter != nil && !ipFilter(ipAddr.AsSlice()) { continue } - ips = append(ips, ip) + atLeastOneAddr = true + ipAddrs = append(ipAddrs, ipAddr) + } + + if atLeastOneAddr { + ifaceCopy := iface + filteredIfaces = append(filteredIfaces, ifaceCopy) } } - return ips, nil + return filteredIfaces, ipAddrs, nil } func listenUDPInPortRange(n transport.Net, log logging.LeveledLogger, portMax, portMin int, network string, lAddr *net.UDPAddr) (transport.UDPConn, error) { diff --git a/net_test.go b/net_test.go index 949fd813..1adf80c2 100644 --- a/net_test.go +++ b/net_test.go @@ -5,40 +5,58 @@ package ice import ( "net" + "strings" "testing" "github.com/stretchr/testify/require" ) -func TestIsSupportedIPv6(t *testing.T) { - if isSupportedIPv6(net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1}) { - t.Errorf("isSupportedIPv6 return true with IPv4-compatible IPv6 address") +func TestIsSupportedIPv6Partial(t *testing.T) { + if isSupportedIPv6Partial(net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1}) { + t.Errorf("isSupportedIPv6Partial returned true with IPv4-compatible IPv6 address") } - if isSupportedIPv6(net.ParseIP("fec0::2333")) { - t.Errorf("isSupportedIPv6 return true with IPv6 site-local unicast address") + if isSupportedIPv6Partial(net.ParseIP("fec0::2333")) { + t.Errorf("isSupportedIPv6Partial returned true with IPv6 site-local unicast address") } - if isSupportedIPv6(net.ParseIP("fe80::2333")) { - t.Errorf("isSupportedIPv6 return true with IPv6 link-local address") + if !isSupportedIPv6Partial(net.ParseIP("fe80::2333")) { + t.Errorf("isSupportedIPv6Partial returned false with IPv6 link-local address") } - if isSupportedIPv6(net.ParseIP("ff02::2333")) { - t.Errorf("isSupportedIPv6 return true with IPv6 link-local multicast address") + if !isSupportedIPv6Partial(net.ParseIP("ff02::2333")) { + t.Errorf("isSupportedIPv6Partial returned false with IPv6 link-local multicast address") } - if !isSupportedIPv6(net.ParseIP("2001::1")) { - t.Errorf("isSupportedIPv6 return false with IPv6 global unicast address") + if !isSupportedIPv6Partial(net.ParseIP("2001::1")) { + t.Errorf("isSupportedIPv6Partial returned false with IPv6 global unicast address") } } func TestCreateAddr(t *testing.T) { - ipv4 := net.IP{127, 0, 0, 1} - ipv6 := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} + ipv4 := mustAddr(net.IP{127, 0, 0, 1}) + ipv6 := mustAddr(net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}) port := 9000 - require.Equal(t, &net.UDPAddr{IP: ipv4, Port: port}, createAddr(NetworkTypeUDP4, ipv4, port)) - require.Equal(t, &net.UDPAddr{IP: ipv6, Port: port}, createAddr(NetworkTypeUDP6, ipv6, port)) - require.Equal(t, &net.TCPAddr{IP: ipv4, Port: port}, createAddr(NetworkTypeTCP4, ipv4, port)) - require.Equal(t, &net.TCPAddr{IP: ipv6, Port: port}, createAddr(NetworkTypeTCP6, ipv6, port)) + require.Equal(t, &net.UDPAddr{IP: ipv4.AsSlice(), Port: port}, createAddr(NetworkTypeUDP4, ipv4, port)) + require.Equal(t, &net.UDPAddr{IP: ipv6.AsSlice(), Port: port}, createAddr(NetworkTypeUDP6, ipv6, port)) + require.Equal(t, &net.TCPAddr{IP: ipv4.AsSlice(), Port: port}, createAddr(NetworkTypeTCP4, ipv4, port)) + require.Equal(t, &net.TCPAddr{IP: ipv6.AsSlice(), Port: port}, createAddr(NetworkTypeTCP6, ipv6, port)) +} + +func problematicNetworkInterfaces(s string) bool { + defaultDockerBridgeNetwork := strings.Contains(s, "docker") + customDockerBridgeNetwork := strings.Contains(s, "br-") + + // Apple filters + accessPoint := strings.Contains(s, "ap") + appleWirelessDirectLink := strings.Contains(s, "awdl") + appleLowLatencyWLANInterface := strings.Contains(s, "llw") + appleTunnelingInterface := strings.Contains(s, "utun") + return !defaultDockerBridgeNetwork && + !customDockerBridgeNetwork && + !accessPoint && + !appleWirelessDirectLink && + !appleLowLatencyWLANInterface && + !appleTunnelingInterface } diff --git a/networktype.go b/networktype.go index 57df1863..376b56c7 100644 --- a/networktype.go +++ b/networktype.go @@ -5,7 +5,7 @@ package ice import ( "fmt" - "net" + "net/netip" "strings" ) @@ -116,18 +116,18 @@ func (t NetworkType) IsIPv6() bool { // determineNetworkType determines the type of network based on // the short network string and an IP address. -func determineNetworkType(network string, ip net.IP) (NetworkType, error) { - ipv4 := ip.To4() != nil - +func determineNetworkType(network string, ip netip.Addr) (NetworkType, error) { + // we'd rather have an IPv4-mapped IPv6 become IPv4 so that it is usable. + ip = ip.Unmap() switch { case strings.HasPrefix(strings.ToLower(network), udp): - if ipv4 { + if ip.Is4() { return NetworkTypeUDP4, nil } return NetworkTypeUDP6, nil case strings.HasPrefix(strings.ToLower(network), tcp): - if ipv4 { + if ip.Is4() { return NetworkTypeTCP4, nil } return NetworkTypeTCP6, nil diff --git a/networktype_test.go b/networktype_test.go index eb4a2e47..5954f404 100644 --- a/networktype_test.go +++ b/networktype_test.go @@ -45,7 +45,7 @@ func TestNetworkTypeParsing_Success(t *testing.T) { NetworkTypeUDP6, }, } { - actual, err := determineNetworkType(test.inNetwork, test.inIP) + actual, err := determineNetworkType(test.inNetwork, mustAddr(test.inIP)) if err != nil { t.Errorf("NetworkTypeParsing failed: %v", err) } @@ -70,7 +70,7 @@ func TestNetworkTypeParsing_Failure(t *testing.T) { ipv6, }, } { - actual, err := determineNetworkType(test.inNetwork, test.inIP) + actual, err := determineNetworkType(test.inNetwork, mustAddr(test.inIP)) if err == nil { t.Errorf("NetworkTypeParsing should fail: '%s' -- input:%s actual:%s", test.name, test.inNetwork, actual) diff --git a/tcp_mux.go b/tcp_mux.go index c5608b39..1501557f 100644 --- a/tcp_mux.go +++ b/tcp_mux.go @@ -8,6 +8,7 @@ import ( "errors" "io" "net" + "net/netip" "strings" "sync" "time" @@ -24,7 +25,7 @@ var ErrGetTransportAddress = errors.New("failed to get local transport address") // interface exists to allow mocking in tests. type TCPMux interface { io.Closer - GetConnByUfrag(ufrag string, isIPv6 bool, local net.IP) (net.PacketConn, error) + GetConnByUfrag(ufrag string, isIPv6 bool, local netip.Addr) (net.PacketConn, error) RemoveConnByUfrag(ufrag string) } @@ -120,7 +121,7 @@ func (m *TCPMuxDefault) LocalAddr() net.Addr { } // GetConnByUfrag retrieves an existing or creates a new net.PacketConn. -func (m *TCPMuxDefault) GetConnByUfrag(ufrag string, isIPv6 bool, local net.IP) (net.PacketConn, error) { +func (m *TCPMuxDefault) GetConnByUfrag(ufrag string, isIPv6 bool, local netip.Addr) (net.PacketConn, error) { m.mu.Lock() defer m.mu.Unlock() @@ -136,13 +137,14 @@ func (m *TCPMuxDefault) GetConnByUfrag(ufrag string, isIPv6 bool, local net.IP) return m.createConn(ufrag, isIPv6, local, false) } -func (m *TCPMuxDefault) createConn(ufrag string, isIPv6 bool, local net.IP, fromStun bool) (*tcpPacketConn, error) { +func (m *TCPMuxDefault) createConn(ufrag string, isIPv6 bool, local netip.Addr, fromStun bool) (*tcpPacketConn, error) { addr, ok := m.LocalAddr().(*net.TCPAddr) if !ok { return nil, ErrGetTransportAddress } localAddr := *addr - localAddr.IP = local + localAddr.IP = local.AsSlice() + localAddr.Zone = local.Zone() var alive time.Duration if fromStun { @@ -169,13 +171,14 @@ func (m *TCPMuxDefault) createConn(ufrag string, isIPv6 bool, local net.IP, from m.connsIPv4[ufrag] = conns } } - conns[ipAddr(local.String())] = conn + connKey := ipAddr(local.String()) + conns[connKey] = conn m.wg.Add(1) go func() { defer m.wg.Done() <-conn.CloseChannel() - m.removeConnByUfragAndLocalHost(ufrag, local) + m.removeConnByUfragAndLocalHost(ufrag, connKey) }() return conn, nil @@ -259,15 +262,24 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) { return } m.mu.Lock() - packetConn, ok := m.getConn(ufrag, isIPv6, localAddr.IP) + + ipAddr, err := ipAddrToNetIP(localAddr.IP, localAddr.Zone) + if err != nil { + m.closeAndLogError(conn) + m.params.Logger.Warnf("Failed to get ip address for STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr()) + return + } + + packetConn, ok := m.getConn(ufrag, isIPv6, ipAddr) if !ok { - packetConn, err = m.createConn(ufrag, isIPv6, localAddr.IP, true) + packetConn, err = m.createConn(ufrag, isIPv6, ipAddr, true) if err != nil { m.mu.Unlock() m.closeAndLogError(conn) m.params.Logger.Warnf("Failed to create packetConn for STUN message from %s to %s", conn.RemoteAddr(), conn.LocalAddr()) return } + } else { } m.mu.Unlock() @@ -334,15 +346,14 @@ func (m *TCPMuxDefault) RemoveConnByUfrag(ufrag string) { } } -func (m *TCPMuxDefault) removeConnByUfragAndLocalHost(ufrag string, local net.IP) { +func (m *TCPMuxDefault) removeConnByUfragAndLocalHost(ufrag string, localIPAddr ipAddr) { removedConns := make([]*tcpPacketConn, 0, 4) - localIP := ipAddr(local.String()) // Keep lock section small to avoid deadlock with conn lock m.mu.Lock() if conns, ok := m.connsIPv4[ufrag]; ok { - if conn, ok := conns[localIP]; ok { - delete(conns, localIP) + if conn, ok := conns[localIPAddr]; ok { + delete(conns, localIPAddr) if len(conns) == 0 { delete(m.connsIPv4, ufrag) } @@ -350,8 +361,8 @@ func (m *TCPMuxDefault) removeConnByUfragAndLocalHost(ufrag string, local net.IP } } if conns, ok := m.connsIPv6[ufrag]; ok { - if conn, ok := conns[localIP]; ok { - delete(conns, localIP) + if conn, ok := conns[localIPAddr]; ok { + delete(conns, localIPAddr) if len(conns) == 0 { delete(m.connsIPv6, ufrag) } @@ -367,7 +378,7 @@ func (m *TCPMuxDefault) removeConnByUfragAndLocalHost(ufrag string, local net.IP } } -func (m *TCPMuxDefault) getConn(ufrag string, isIPv6 bool, local net.IP) (val *tcpPacketConn, ok bool) { +func (m *TCPMuxDefault) getConn(ufrag string, isIPv6 bool, local netip.Addr) (val *tcpPacketConn, ok bool) { var conns map[ipAddr]*tcpPacketConn if isIPv6 { conns, ok = m.connsIPv6[ufrag] @@ -375,7 +386,8 @@ func (m *TCPMuxDefault) getConn(ufrag string, isIPv6 bool, local net.IP) (val *t conns, ok = m.connsIPv4[ufrag] } if conns != nil { - val, ok = conns[ipAddr(local.String())] + connKey := ipAddr(local.String()) + val, ok = conns[connKey] } return diff --git a/tcp_mux_multi.go b/tcp_mux_multi.go index e32acbf3..9ff62f25 100644 --- a/tcp_mux_multi.go +++ b/tcp_mux_multi.go @@ -3,14 +3,17 @@ package ice -import "net" +import ( + "net" + "net/netip" +) // AllConnsGetter allows multiple fixed TCP ports to be used, // each of which is multiplexed like TCPMux. AllConnsGetter also acts as // a TCPMux, in which case it will return a single connection for one // of the ports. type AllConnsGetter interface { - GetAllConns(ufrag string, isIPv6 bool, localIP net.IP) ([]net.PacketConn, error) + GetAllConns(ufrag string, isIPv6 bool, localIP netip.Addr) ([]net.PacketConn, error) } // MultiTCPMuxDefault implements both TCPMux and AllConnsGetter, @@ -32,7 +35,7 @@ func NewMultiTCPMuxDefault(muxes ...TCPMux) *MultiTCPMuxDefault { // creates the connection if an existing one can't be found. This, unlike // GetAllConns, will only return a single PacketConn from the first mux that was // passed in to NewMultiTCPMuxDefault. -func (m *MultiTCPMuxDefault) GetConnByUfrag(ufrag string, isIPv6 bool, local net.IP) (net.PacketConn, error) { +func (m *MultiTCPMuxDefault) GetConnByUfrag(ufrag string, isIPv6 bool, local netip.Addr) (net.PacketConn, error) { // NOTE: We always use the first element here in order to maintain the // behavior of using an existing connection if one exists. if len(m.muxes) == 0 { @@ -50,7 +53,7 @@ func (m *MultiTCPMuxDefault) RemoveConnByUfrag(ufrag string) { } // GetAllConns returns a PacketConn for each underlying TCPMux -func (m *MultiTCPMuxDefault) GetAllConns(ufrag string, isIPv6 bool, local net.IP) ([]net.PacketConn, error) { +func (m *MultiTCPMuxDefault) GetAllConns(ufrag string, isIPv6 bool, local netip.Addr) ([]net.PacketConn, error) { if len(m.muxes) == 0 { // Make sure that we either return at least one connection or an error. return nil, errNoTCPMuxAvailable diff --git a/tcp_mux_multi_test.go b/tcp_mux_multi_test.go index 56306890..038f09ec 100644 --- a/tcp_mux_multi_test.go +++ b/tcp_mux_multi_test.go @@ -7,8 +7,10 @@ package ice import ( + "fmt" "io" "net" + "net/netip" "testing" "github.com/pion/logging" @@ -54,7 +56,7 @@ func TestMultiTCPMux_Recv(t *testing.T) { _ = multiMux.Close() }() - pktConns, err := multiMux.GetAllConns("myufrag", false, net.IP{127, 0, 0, 1}) + pktConns, err := multiMux.GetAllConns("myufrag", false, mustAddr(net.IP{127, 0, 0, 1})) require.NoError(t, err, "error retrieving muxed connection for ufrag") for _, pktConn := range pktConns { @@ -117,12 +119,20 @@ func TestMultiTCPMux_NoDeadlockWhenClosingUnusedPacketConn(t *testing.T) { } muxMulti := NewMultiTCPMuxDefault(tcpMuxInstances...) - _, err := muxMulti.GetAllConns("test", false, net.IP{127, 0, 0, 1}) + _, err := muxMulti.GetAllConns("test", false, mustAddr(net.IP{127, 0, 0, 1})) require.NoError(t, err, "error getting conn by ufrag") require.NoError(t, muxMulti.Close(), "error closing tcpMux") - conn, err := muxMulti.GetAllConns("test", false, net.IP{127, 0, 0, 1}) + conn, err := muxMulti.GetAllConns("test", false, mustAddr(net.IP{127, 0, 0, 1})) require.Nil(t, conn, "should receive nil because mux is closed") require.Equal(t, io.ErrClosedPipe, err, "should receive error because mux is closed") } + +func mustAddr(ip net.IP) netip.Addr { + addr, ok := netip.AddrFromSlice(ip) + if !ok { + panic(fmt.Errorf("failed to convert ip '%s' to netip.Addr", ip)) + } + return addr +} diff --git a/tcp_mux_test.go b/tcp_mux_test.go index c6d9748c..77d5a3f0 100644 --- a/tcp_mux_test.go +++ b/tcp_mux_test.go @@ -62,7 +62,11 @@ func TestTCPMux_Recv(t *testing.T) { n, err := writeStreamingPacket(conn, msg.Raw) require.NoError(t, err, "error writing TCP STUN packet") - pktConn, err := tcpMux.GetConnByUfrag("myufrag", false, listener.Addr().(*net.TCPAddr).IP) + listenerAddr := listener.Addr().(*net.TCPAddr) + ipAddr, err := ipAddrToNetIP(listenerAddr.IP, listenerAddr.Zone) + require.NoError(t, err, "error getting listener ipAddr") + + pktConn, err := tcpMux.GetConnByUfrag("myufrag", false, ipAddr) require.NoError(t, err, "error retrieving muxed connection for ufrag") defer func() { _ = pktConn.Close() @@ -111,12 +115,16 @@ func TestTCPMux_NoDeadlockWhenClosingUnusedPacketConn(t *testing.T) { _ = tcpMux.Close() }() - _, err = tcpMux.GetConnByUfrag("test", false, listener.Addr().(*net.TCPAddr).IP) + listenerAddr := listener.Addr().(*net.TCPAddr) + ipAddr, err := ipAddrToNetIP(listenerAddr.IP, listenerAddr.Zone) + require.NoError(t, err, "error getting listener ipAddr") + + _, err = tcpMux.GetConnByUfrag("test", false, ipAddr) require.NoError(t, err, "error getting conn by ufrag") require.NoError(t, tcpMux.Close(), "error closing tcpMux") - conn, err := tcpMux.GetConnByUfrag("test", false, listener.Addr().(*net.TCPAddr).IP) + conn, err := tcpMux.GetConnByUfrag("test", false, ipAddr) require.Nil(t, conn, "should receive nil because mux is closed") require.Equal(t, io.ErrClosedPipe, err, "should receive error because mux is closed") } @@ -231,7 +239,11 @@ func TestTCPMux_NoLeakForConnectionFromStun(t *testing.T) { // wait for the connection to be created time.Sleep(100 * time.Millisecond) - pktConn, err := tcpMux.GetConnByUfrag("myufrag2", false, listener.Addr().(*net.TCPAddr).IP) + listenerAddr := listener.Addr().(*net.TCPAddr) + ipAddr, err := ipAddrToNetIP(listenerAddr.IP, listenerAddr.Zone) + require.NoError(t, err, "error getting listener ipAddr") + + pktConn, err := tcpMux.GetConnByUfrag("myufrag2", false, ipAddr) require.NoError(t, err, "error retrieving muxed connection for ufrag") defer func() { _ = pktConn.Close() diff --git a/transport_test.go b/transport_test.go index 216ee8b7..89fe375a 100644 --- a/transport_test.go +++ b/transport_test.go @@ -175,6 +175,7 @@ func gatherAndExchangeCandidates(aAgent, bAgent *Agent) { candidates, err := aAgent.GetLocalCandidates() check(err) + for _, c := range candidates { candidateCopy, copyErr := c.copy() check(copyErr) @@ -182,6 +183,7 @@ func gatherAndExchangeCandidates(aAgent, bAgent *Agent) { } candidates, err = bAgent.GetLocalCandidates() + check(err) for _, c := range candidates { candidateCopy, copyErr := c.copy() diff --git a/udp_mux.go b/udp_mux.go index dc45458f..40e8ad9b 100644 --- a/udp_mux.go +++ b/udp_mux.go @@ -69,19 +69,19 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { } var localAddrsForUnspecified []net.Addr - if addr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok { + if udpAddr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok { params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr()) - } else if ok && addr.IP.IsUnspecified() { + } else if ok && udpAddr.IP.IsUnspecified() { // For unspecified addresses, the correct behavior is to return errListenUnspecified, but // it will break the applications that are already using unspecified UDP connection // with UDPMuxDefault, so print a warn log and create a local address list for mux. params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") var networks []NetworkType switch { - case addr.IP.To4() != nil: + case udpAddr.IP.To4() != nil: networks = []NetworkType{NetworkTypeUDP4} - case addr.IP.To16() != nil: + case udpAddr.IP.To16() != nil: networks = []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6} default: @@ -95,10 +95,14 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { } } - ips, err := localInterfaces(params.Net, nil, nil, networks, true) + _, addrs, err := localInterfaces(params.Net, nil, nil, networks, true) if err == nil { - for _, ip := range ips { - localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port}) + for _, addr := range addrs { + localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{ + IP: addr.AsSlice(), + Port: udpAddr.Port, + Zone: addr.Zone(), + }) } } else { params.Logger.Errorf("Failed to get local interfaces for unspecified addr: %v", err) @@ -304,7 +308,7 @@ func (m *UDPMuxDefault) connWorker() { logger.Errorf("Underlying PacketConn did not return a UDPAddr") return } - udpAddr, err := newIPPort(netUDPAddr.IP, uint16(netUDPAddr.Port)) + udpAddr, err := newIPPort(netUDPAddr.IP, netUDPAddr.Zone, uint16(netUDPAddr.Port)) if err != nil { logger.Errorf("Failed to create a new IP/Port host pair") return @@ -378,14 +382,14 @@ type ipPort struct { // newIPPort create a custom type of address based on netip.Addr and // port. The underlying ip address passed is converted to IPv6 format // to simplify ip address handling -func newIPPort(ip net.IP, port uint16) (ipPort, error) { +func newIPPort(ip net.IP, zone string, port uint16) (ipPort, error) { n, ok := netip.AddrFromSlice(ip.To16()) if !ok { return ipPort{}, errInvalidIPAddress } return ipPort{ - addr: n, + addr: n.WithZone(zone), port: port, }, nil } diff --git a/udp_mux_multi.go b/udp_mux_multi.go index c46db9bf..2594cb02 100644 --- a/udp_mux_multi.go +++ b/udp_mux_multi.go @@ -90,14 +90,18 @@ func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMu } } - ips, err := localInterfaces(params.net, params.ifFilter, params.ipFilter, params.networks, params.includeLoopback) + _, addrs, err := localInterfaces(params.net, params.ifFilter, params.ipFilter, params.networks, params.includeLoopback) if err != nil { return nil, err } - conns := make([]net.PacketConn, 0, len(ips)) - for _, ip := range ips { - conn, listenErr := params.net.ListenUDP("udp", &net.UDPAddr{IP: ip, Port: port}) + conns := make([]net.PacketConn, 0, len(addrs)) + for _, addr := range addrs { + conn, listenErr := params.net.ListenUDP("udp", &net.UDPAddr{ + IP: addr.AsSlice(), + Port: port, + Zone: addr.Zone(), + }) if listenErr != nil { err = listenErr break diff --git a/udp_mux_multi_test.go b/udp_mux_multi_test.go index bb12022a..f5611be5 100644 --- a/udp_mux_multi_test.go +++ b/udp_mux_multi_test.go @@ -8,7 +8,6 @@ package ice import ( "net" - "strings" "sync" "testing" "time" @@ -114,11 +113,7 @@ func TestUnspecifiedUDPMux(t *testing.T) { defer test.TimeOut(time.Second * 30).Stop() muxPort := 7778 - udpMuxMulti, err := NewMultiUDPMuxFromPort(muxPort, UDPMuxFromPortWithInterfaceFilter(func(s string) bool { - defaultDockerBridgeNetwork := strings.Contains(s, "docker") - customDockerBridgeNetwork := strings.Contains(s, "br-") - return !defaultDockerBridgeNetwork && !customDockerBridgeNetwork - })) + udpMuxMulti, err := NewMultiUDPMuxFromPort(muxPort, UDPMuxFromPortWithInterfaceFilter(problematicNetworkInterfaces)) require.NoError(t, err) require.GreaterOrEqual(t, len(udpMuxMulti.muxes), 1, "at least have 1 muxes") diff --git a/udp_mux_test.go b/udp_mux_test.go index 5f8b1e0d..bc8d36b0 100644 --- a/udp_mux_test.go +++ b/udp_mux_test.go @@ -50,13 +50,31 @@ func TestUDPMux(t *testing.T) { network string } - for _, subTest := range []testCase{ + testCases := []testCase{ {name: "IPv4loopback", conn: conn4, network: udp4}, {name: "IPv6loopback", conn: conn6, network: udp6}, {name: "Unspecified", conn: connUnspecified, network: udp}, {name: "IPv4Unspecified", conn: conn4Unspecified, network: udp4}, {name: "IPv6Unspecified", conn: conn6Unspecified, network: udp6}, - } { + } + + if ipv6Available(t) { + addr6 := getLocalIPAddress(t, NetworkTypeUDP6) + + conn6Unspecified, err := net.ListenUDP(udp, &net.UDPAddr{ + IP: addr6.AsSlice(), + Zone: addr6.Zone(), + }) + if err != nil { + t.Log("IPv6 is not supported on this machine") + } + + testCases = append(testCases, + testCase{name: "IPv6Specified", conn: conn6Unspecified, network: udp6}, + ) + } + + for _, subTest := range testCases { network, conn := subTest.network, subTest.conn if udpConn, ok := conn.(*net.UDPConn); !ok || udpConn == nil { continue diff --git a/udp_muxed_conn.go b/udp_muxed_conn.go index fb05e231..0244d130 100644 --- a/udp_muxed_conn.go +++ b/udp_muxed_conn.go @@ -86,7 +86,7 @@ func (c *udpMuxedConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) { return 0, errFailedToCastUDPAddr } - ipAndPort, err := newIPPort(netUDPAddr.IP, uint16(netUDPAddr.Port)) + ipAndPort, err := newIPPort(netUDPAddr.IP, netUDPAddr.Zone, uint16(netUDPAddr.Port)) if err != nil { return 0, err }