From 5ef09de4a4c99773800eb002a976731c2903b9cb Mon Sep 17 00:00:00 2001 From: Eric Daniels Date: Wed, 27 Mar 2024 11:08:27 -0400 Subject: [PATCH] Support IPv6 from mDNS --- active_tcp.go | 14 ++- active_tcp_test.go | 39 ++++--- addr.go | 114 ++++++++++++++++---- agent.go | 69 +++++++----- agent_test.go | 1 + candidate.go | 1 + candidate_base.go | 17 ++- candidate_host.go | 34 +++--- candidate_peer_reflexive.go | 14 +-- candidate_relay.go | 25 +++-- candidate_server_reflexive.go | 29 +++-- candidate_test.go | 159 ++++++++++++++-------------- gather.go | 108 +++++++++++++++---- gather_test.go | 6 +- gather_vnet_test.go | 24 ++--- go.mod | 2 +- go.sum | 4 +- mdns.go | 84 +++++++++++++-- mdns_test.go | 193 +++++++++++++++++++++++----------- net.go | 82 +++++++++------ net_test.go | 62 ++++++++--- networktype.go | 12 +-- networktype_test.go | 4 +- tcp_mux.go | 23 ++-- tcp_mux_multi.go | 4 +- tcp_mux_test.go | 17 ++- transport_test.go | 8 ++ udp_mux.go | 24 +++-- udp_mux_multi.go | 12 ++- udp_mux_multi_test.go | 7 +- udp_mux_test.go | 22 +++- udp_muxed_conn.go | 2 +- 32 files changed, 828 insertions(+), 388 deletions(-) diff --git a/active_tcp.go b/active_tcp.go index 4ffcb6e7..a6f8387e 100644 --- a/active_tcp.go +++ b/active_tcp.go @@ -7,6 +7,7 @@ import ( "context" "io" "net" + "net/netip" "sync/atomic" "time" @@ -20,7 +21,7 @@ type activeTCPConn struct { closed int32 } -func newActiveTCPConn(ctx context.Context, localAddress, remoteAddress string, log logging.LeveledLogger) (a *activeTCPConn) { +func newActiveTCPConn(ctx context.Context, localAddress string, remoteAddress netip.AddrPort, log logging.LeveledLogger) (a *activeTCPConn) { a = &activeTCPConn{ readBuffer: packetio.NewBuffer(), writeBuffer: packetio.NewBuffer(), @@ -42,12 +43,11 @@ func newActiveTCPConn(ctx context.Context, localAddress, remoteAddress string, l dialer := &net.Dialer{ LocalAddr: laddr, } - conn, err := dialer.DialContext(ctx, "tcp", remoteAddress) + conn, err := dialer.DialContext(ctx, "tcp", remoteAddress.String()) if err != nil { log.Infof("Failed to dial TCP address %s: %v", remoteAddress, err) return } - a.remoteAddr.Store(conn.RemoteAddr()) go func() { @@ -95,8 +95,9 @@ func (a *activeTCPConn) ReadFrom(buff []byte) (n int, srcAddr net.Addr, err erro return 0, nil, io.ErrClosedPipe } - srcAddr = a.RemoteAddr() n, err = a.readBuffer.Read(buff) + // RemoteAddr is assuredly set *after* we can read from the buffer + srcAddr = a.RemoteAddr() return } @@ -123,6 +124,11 @@ func (a *activeTCPConn) LocalAddr() net.Addr { return &net.TCPAddr{} } +// RemoteAddr returns the remote address of the connection which is only +// set once a background goroutine has successfully dialed. That means +// this may return ":0" for the address prior to that happening. If this +// becomes an issue, we can introduce a synchronization point between Dial +// and these methods. func (a *activeTCPConn) RemoteAddr() net.Addr { if v, ok := a.remoteAddr.Load().(*net.TCPAddr); ok { return v diff --git a/active_tcp_test.go b/active_tcp_test.go index 7b696d1c..44d6a47c 100644 --- a/active_tcp_test.go +++ b/active_tcp_test.go @@ -8,6 +8,7 @@ package ice import ( "net" + "net/netip" "testing" "time" @@ -17,21 +18,21 @@ import ( "github.com/stretchr/testify/require" ) -func getLocalIPAddress(t *testing.T, networkType NetworkType) net.IP { +func getLocalIPAddress(t *testing.T, networkType NetworkType) netip.Addr { net, err := stdnet.NewNet() require.NoError(t, err) - localIPs, err := localInterfaces(net, nil, nil, []NetworkType{networkType}, false) + _, localAddrs, err := localInterfaces(net, problematicNetworkInterfaces, nil, []NetworkType{networkType}, false) require.NoError(t, err) - require.NotEmpty(t, localIPs) - return localIPs[0] + require.NotEmpty(t, localAddrs) + return localAddrs[0] } func ipv6Available(t *testing.T) bool { net, err := stdnet.NewNet() require.NoError(t, err) - localIPs, err := localInterfaces(net, nil, nil, []NetworkType{NetworkTypeTCP6}, false) + _, localAddrs, err := localInterfaces(net, problematicNetworkInterfaces, nil, []NetworkType{NetworkTypeTCP6}, false) require.NoError(t, err) - return len(localIPs) > 0 + return len(localAddrs) > 0 } func TestActiveTCP(t *testing.T) { @@ -43,8 +44,9 @@ func TestActiveTCP(t *testing.T) { type testCase struct { name string networkTypes []NetworkType - listenIPAddress net.IP + listenIPAddress netip.Addr selectedPairNetworkType string + useMDNS bool } testCases := []testCase{ @@ -69,12 +71,16 @@ func TestActiveTCP(t *testing.T) { networkTypes: []NetworkType{NetworkTypeTCP6}, listenIPAddress: getLocalIPAddress(t, NetworkTypeTCP6), selectedPairNetworkType: tcp, + // if we don't use mDNS, we will very liekly be filtering out location tracked ips. + useMDNS: true, }, testCase{ - name: "UDP is preferred over TCP6", // This fails some time + name: "UDP is preferred over TCP6", networkTypes: supportedNetworkTypes(), listenIPAddress: getLocalIPAddress(t, NetworkTypeTCP6), selectedPairNetworkType: udp, + // if we don't use mDNS, we will very liekly be filtering out location tracked ips. + useMDNS: true, }, ) } @@ -84,8 +90,9 @@ func TestActiveTCP(t *testing.T) { r := require.New(t) listener, err := net.ListenTCP("tcp", &net.TCPAddr{ - IP: testCase.listenIPAddress, + IP: testCase.listenIPAddress.AsSlice(), Port: listenPort, + Zone: testCase.listenIPAddress.Zone(), }) r.NoError(err) defer func() { @@ -107,14 +114,18 @@ func TestActiveTCP(t *testing.T) { r.NotNil(tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil") hostAcceptanceMinWait := 100 * time.Millisecond - passiveAgent, err := NewAgent(&AgentConfig{ + cfg := &AgentConfig{ TCPMux: tcpMux, CandidateTypes: []CandidateType{CandidateTypeHost}, NetworkTypes: testCase.networkTypes, LoggerFactory: loggerFactory, - IncludeLoopback: true, HostAcceptanceMinWait: &hostAcceptanceMinWait, - }) + InterfaceFilter: problematicNetworkInterfaces, + } + if testCase.useMDNS { + cfg.MulticastDNSMode = MulticastDNSModeQueryAndGather + } + passiveAgent, err := NewAgent(cfg) r.NoError(err) r.NotNil(passiveAgent) @@ -123,6 +134,7 @@ func TestActiveTCP(t *testing.T) { NetworkTypes: testCase.networkTypes, LoggerFactory: loggerFactory, HostAcceptanceMinWait: &hostAcceptanceMinWait, + InterfaceFilter: problematicNetworkInterfaces, }) r.NoError(err) r.NotNil(activeAgent) @@ -166,7 +178,8 @@ func TestActiveTCP_NonBlocking(t *testing.T) { defer test.TimeOut(time.Second * 5).Stop() cfg := &AgentConfig{ - NetworkTypes: supportedNetworkTypes(), + NetworkTypes: supportedNetworkTypes(), + InterfaceFilter: problematicNetworkInterfaces, } aAgent, err := NewAgent(cfg) diff --git a/addr.go b/addr.go index 1d70025b..fb40061d 100644 --- a/addr.go +++ b/addr.go @@ -4,52 +4,126 @@ package ice import ( + "fmt" "net" + "net/netip" ) -func parseMulticastAnswerAddr(in net.Addr) (net.IP, bool) { +func addrWithOptionalZone(addr netip.Addr, zone string) netip.Addr { + if zone == "" { + return addr + } + if addr.Is6() && (addr.IsLinkLocalUnicast() || addr.IsLinkLocalMulticast()) { + return addr.WithZone(zone) + } + return addr +} + +// parseAddrFromIface should only be used when it's known the address belongs to that interface. +// e.g. it's LocalAddress on a listener. +func parseAddrFromIface(in net.Addr, ifcName string) (netip.Addr, int, NetworkType, error) { + addr, port, nt, err := parseAddr(in) + if err != nil { + return netip.Addr{}, 0, 0, err + } + if _, ok := in.(*net.IPNet); ok { + // net.IPNet does not have a Zone but we provide it from the interface + addr = addrWithOptionalZone(addr, ifcName) + } + return addr, port, nt, nil +} + +func parseAddr(in net.Addr) (netip.Addr, int, NetworkType, error) { switch addr := in.(type) { + case *net.IPNet: + ipAddr, err := ipAddrToNetIP(addr.IP, "") + if err != nil { + return netip.Addr{}, 0, 0, err + } + return ipAddr, 0, 0, nil case *net.IPAddr: - return addr.IP, true + ipAddr, err := ipAddrToNetIP(addr.IP, addr.Zone) + if err != nil { + return netip.Addr{}, 0, 0, err + } + return ipAddr, 0, 0, nil case *net.UDPAddr: - return addr.IP, true + ipAddr, err := ipAddrToNetIP(addr.IP, addr.Zone) + if err != nil { + return netip.Addr{}, 0, 0, err + } + var nt NetworkType + if ipAddr.Is4() { + nt = NetworkTypeUDP4 + } else { + nt = NetworkTypeUDP6 + } + return ipAddr, addr.Port, nt, nil case *net.TCPAddr: - return addr.IP, true + ipAddr, err := ipAddrToNetIP(addr.IP, addr.Zone) + if err != nil { + return netip.Addr{}, 0, 0, err + } + var nt NetworkType + if ipAddr.Is4() { + nt = NetworkTypeTCP4 + } else { + nt = NetworkTypeTCP6 + } + return ipAddr, addr.Port, nt, nil + default: + return netip.Addr{}, 0, 0, addrParseError{in} } - return nil, false } -func parseAddr(in net.Addr) (net.IP, int, NetworkType, bool) { - switch addr := in.(type) { - case *net.UDPAddr: - return addr.IP, addr.Port, NetworkTypeUDP4, true - case *net.TCPAddr: - return addr.IP, addr.Port, NetworkTypeTCP4, true +type addrParseError struct { + addr net.Addr +} + +func (e addrParseError) Error() string { + return fmt.Sprintf("do not know how to parse address type %T", e.addr) +} + +type ipConvertError struct { + ip []byte +} + +func (e ipConvertError) Error() string { + return fmt.Sprintf("failed to convert IP '%s' to netip.Addr", e.ip) +} + +func ipAddrToNetIP(ip []byte, zone string) (netip.Addr, error) { + netIPAddr, ok := netip.AddrFromSlice(ip) + if !ok { + return netip.Addr{}, ipConvertError{ip} } - return nil, 0, 0, false + // we'd rather have an IPv4-mapped IPv6 become IPv4 so that it is usable. + netIPAddr = netIPAddr.Unmap() + netIPAddr = addrWithOptionalZone(netIPAddr, zone) + return netIPAddr, nil } -func createAddr(network NetworkType, ip net.IP, port int) net.Addr { +func createAddr(network NetworkType, ip netip.Addr, port int) net.Addr { switch { case network.IsTCP(): - return &net.TCPAddr{IP: ip, Port: port} + return &net.TCPAddr{IP: ip.AsSlice(), Port: port, Zone: ip.Zone()} default: - return &net.UDPAddr{IP: ip, Port: port} + return &net.UDPAddr{IP: ip.AsSlice(), Port: port, Zone: ip.Zone()} } } func addrEqual(a, b net.Addr) bool { - aIP, aPort, aType, aOk := parseAddr(a) - if !aOk { + aIP, aPort, aType, aErr := parseAddr(a) + if aErr != nil { return false } - bIP, bPort, bType, bOk := parseAddr(b) - if !bOk { + bIP, bPort, bType, bErr := parseAddr(b) + if bErr != nil { return false } - return aType == bType && aIP.Equal(bIP) && aPort == bPort + return aType == bType && aIP.Compare(bIP) == 0 && aPort == bPort } // AddrPort is an IP and a port number. diff --git a/agent.go b/agent.go index 27d429d3..1a8c897c 100644 --- a/agent.go +++ b/agent.go @@ -9,7 +9,7 @@ import ( "context" "fmt" "net" - "strconv" + "net/netip" "strings" "sync" "sync/atomic" @@ -228,9 +228,22 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit } } + localIfcs, _, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, a.networkTypes, a.includeLoopback) + if err != nil { + return nil, fmt.Errorf("error getting local interfaces: %w", err) + } + // Opportunistic mDNS: If we can't open the connection, that's ok: we // can continue without it. - if a.mDNSConn, a.mDNSMode, err = createMulticastDNS(a.net, mDNSMode, mDNSName, log); err != nil { + if a.mDNSConn, a.mDNSMode, err = createMulticastDNS( + a.net, + a.networkTypes, + localIfcs, + a.includeLoopback, + mDNSMode, + mDNSName, + log, + ); err != nil { log.Warnf("Failed to initialize mDNS %s: %v", mDNSName, err) } @@ -592,19 +605,14 @@ func (a *Agent) resolveAndAddMulticastCandidate(c *CandidateHost) { if a.mDNSConn == nil { return } - _, src, err := a.mDNSConn.Query(c.context(), c.Address()) + + _, src, err := a.mDNSConn.QueryAddr(c.context(), c.Address()) if err != nil { a.log.Warnf("Failed to discover mDNS candidate %s: %v", c.Address(), err) return } - ip, ipOk := parseMulticastAnswerAddr(src) - if !ipOk { - a.log.Warnf("Failed to discover mDNS candidate %s: failed to parse IP", c.Address()) - return - } - - if err = c.setIP(ip); err != nil { + if err = c.setIPAddr(src); err != nil { a.log.Warnf("Failed to discover mDNS candidate %s: %v", c.Address(), err) return } @@ -626,17 +634,23 @@ func (a *Agent) requestConnectivityCheck() { } func (a *Agent) addRemotePassiveTCPCandidate(remoteCandidate Candidate) { - localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{remoteCandidate.NetworkType()}, a.includeLoopback) + _, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{remoteCandidate.NetworkType()}, a.includeLoopback) if err != nil { a.log.Warnf("Failed to iterate local interfaces, host candidates will not be gathered %s", err) return } for i := range localIPs { + ip, _, _, err := parseAddr(remoteCandidate.addr()) + if err != nil { + a.log.Warnf("Failed to parse address: %s; error: %s", remoteCandidate.addr(), err) + continue + } + conn := newActiveTCPConn( a.loop, net.JoinHostPort(localIPs[i].String(), "0"), - net.JoinHostPort(remoteCandidate.Address(), strconv.Itoa(remoteCandidate.Port())), + netip.AddrPortFrom(ip, uint16(remoteCandidate.Port())), a.log, ) @@ -730,7 +744,9 @@ func (a *Agent) addCandidate(ctx context.Context, c Candidate, candidateConn net a.requestConnectivityCheck() - a.candidateNotifier.EnqueueCandidate(c) + if !c.filterForLocationTracking() { + a.candidateNotifier.EnqueueCandidate(c) + } }) } @@ -759,7 +775,12 @@ func (a *Agent) GetLocalCandidates() ([]Candidate, error) { err := a.loop.Run(a.loop, func(_ context.Context) { var candidates []Candidate for _, set := range a.localCandidates { - candidates = append(candidates, set...) + for _, c := range set { + if c.filterForLocationTracking() { + continue + } + candidates = append(candidates, c) + } } res = candidates }) @@ -841,9 +862,9 @@ func (a *Agent) deleteAllCandidates() { } func (a *Agent) findRemoteCandidate(networkType NetworkType, addr net.Addr) Candidate { - ip, port, _, ok := parseAddr(addr) - if !ok { - a.log.Warnf("Failed to parse address: %s", addr) + ip, port, _, err := parseAddr(addr) + if err != nil { + a.log.Warnf("Failed to parse address: %s; error: %s", addr, err) return nil } @@ -873,15 +894,15 @@ func (a *Agent) sendBindingRequest(m *stun.Message, local, remote Candidate) { func (a *Agent) sendBindingSuccess(m *stun.Message, local, remote Candidate) { base := remote - ip, port, _, ok := parseAddr(base.addr()) - if !ok { - a.log.Warnf("Failed to parse address: %s", base.addr()) + ip, port, _, err := parseAddr(base.addr()) + if err != nil { + a.log.Warnf("Failed to parse address: %s; error: %s", base.addr(), err) return } if out, err := stun.Build(m, stun.BindingSuccess, &stun.XORMappedAddress{ - IP: ip, + IP: ip.AsSlice(), Port: port, }, stun.NewShortTermIntegrity(a.localPwd), @@ -983,9 +1004,9 @@ func (a *Agent) handleInbound(m *stun.Message, local Candidate, remote net.Addr) } if remoteCandidate == nil { - ip, port, networkType, ok := parseAddr(remote) - if !ok { - a.log.Errorf("Failed to create parse remote net.Addr when creating remote prflx candidate") + ip, port, networkType, err := parseAddr(remote) + if err != nil { + a.log.Errorf("Failed to create parse remote net.Addr when creating remote prflx candidate: %s", err) return } diff --git a/agent_test.go b/agent_test.go index bbb44cf2..552d4215 100644 --- a/agent_test.go +++ b/agent_test.go @@ -536,6 +536,7 @@ func TestConnectionStateCallback(t *testing.T) { DisconnectedTimeout: &disconnectedDuration, FailedTimeout: &failedDuration, KeepaliveInterval: &KeepaliveInterval, + InterfaceFilter: problematicNetworkInterfaces, } aAgent, err := NewAgent(cfg) diff --git a/candidate.go b/candidate.go index 92a00768..4324159b 100644 --- a/candidate.go +++ b/candidate.go @@ -61,6 +61,7 @@ type Candidate interface { Marshal() string addr() net.Addr + filterForLocationTracking() bool agent() *Agent context() context.Context diff --git a/candidate_base.go b/candidate_base.go index fd0f292a..1a60899a 100644 --- a/candidate_base.go +++ b/candidate_base.go @@ -43,6 +43,7 @@ type candidateBase struct { priorityOverride uint32 remoteCandidateCaches map[AddrPort]Candidate + isLocationTracked bool } // Done implements context.Context @@ -384,6 +385,14 @@ func (c *candidateBase) Priority() uint32 { // Equal is used to compare two candidateBases func (c *candidateBase) Equal(other Candidate) bool { + if c.addr() != other.addr() { + if c.addr() == nil || other.addr() == nil { + return false + } + if c.addr().String() != other.addr().String() { + return false + } + } return c.NetworkType() == other.NetworkType() && c.Type() == other.Type() && c.Address() == other.Address() && @@ -394,7 +403,7 @@ func (c *candidateBase) Equal(other Candidate) bool { // String makes the candidateBase printable func (c *candidateBase) String() string { - return fmt.Sprintf("%s %s %s%s", c.NetworkType(), c.Type(), net.JoinHostPort(c.Address(), strconv.Itoa(c.Port())), c.relatedAddress) + return fmt.Sprintf("%s %s %s%s (resolved: %v)", c.NetworkType(), c.Type(), net.JoinHostPort(c.Address(), strconv.Itoa(c.Port())), c.relatedAddress, c.resolvedAddr) } // LastReceived returns a time.Time indicating the last time @@ -435,6 +444,10 @@ func (c *candidateBase) addr() net.Addr { return c.resolvedAddr } +func (c *candidateBase) filterForLocationTracking() bool { + return c.isLocationTracked +} + func (c *candidateBase) agent() *Agent { return c.currAgent } @@ -551,7 +564,7 @@ func UnmarshalCandidate(raw string) (Candidate, error) { switch typ { case "host": - return NewCandidateHost(&CandidateHostConfig{"", protocol, address, port, component, priority, foundation, tcpType}) + return NewCandidateHost(&CandidateHostConfig{"", protocol, address, port, component, priority, foundation, tcpType, false}) case "srflx": return NewCandidateServerReflexive(&CandidateServerReflexiveConfig{"", protocol, address, port, component, priority, foundation, relatedAddress, relatedPort}) case "prflx": diff --git a/candidate_host.go b/candidate_host.go index 5d207dd9..c14b6a7c 100644 --- a/candidate_host.go +++ b/candidate_host.go @@ -4,7 +4,7 @@ package ice import ( - "net" + "net/netip" "strings" ) @@ -17,14 +17,15 @@ type CandidateHost struct { // CandidateHostConfig is the config required to create a new CandidateHost type CandidateHostConfig struct { - CandidateID string - Network string - Address string - Port int - Component uint16 - Priority uint32 - Foundation string - TCPType TCPType + CandidateID string + Network string + Address string + Port int + Component uint16 + Priority uint32 + Foundation string + TCPType TCPType + IsLocationTracked bool } // NewCandidateHost creates a new host candidate @@ -46,17 +47,18 @@ func NewCandidateHost(config *CandidateHostConfig) (*CandidateHost, error) { foundationOverride: config.Foundation, priorityOverride: config.Priority, remoteCandidateCaches: map[AddrPort]Candidate{}, + isLocationTracked: config.IsLocationTracked, }, network: config.Network, } if !strings.HasSuffix(config.Address, ".local") { - ip := net.ParseIP(config.Address) - if ip == nil { - return nil, ErrAddressParseFailed + ipAddr, err := netip.ParseAddr(config.Address) + if err != nil { + return nil, err } - if err := c.setIP(ip); err != nil { + if err := c.setIPAddr(ipAddr); err != nil { return nil, err } } else { @@ -67,14 +69,14 @@ func NewCandidateHost(config *CandidateHostConfig) (*CandidateHost, error) { return c, nil } -func (c *CandidateHost) setIP(ip net.IP) error { - networkType, err := determineNetworkType(c.network, ip) +func (c *CandidateHost) setIPAddr(addr netip.Addr) error { + networkType, err := determineNetworkType(c.network, addr) if err != nil { return err } c.candidateBase.networkType = networkType - c.candidateBase.resolvedAddr = createAddr(networkType, ip, c.port) + c.candidateBase.resolvedAddr = createAddr(networkType, addr, c.port) return nil } diff --git a/candidate_peer_reflexive.go b/candidate_peer_reflexive.go index bbcfe335..b28e9a76 100644 --- a/candidate_peer_reflexive.go +++ b/candidate_peer_reflexive.go @@ -6,7 +6,9 @@ //nolint:dupl package ice -import "net" +import ( + "net/netip" +) // CandidatePeerReflexive ... type CandidatePeerReflexive struct { @@ -28,12 +30,12 @@ type CandidatePeerReflexiveConfig struct { // NewCandidatePeerReflexive creates a new peer reflective candidate func NewCandidatePeerReflexive(config *CandidatePeerReflexiveConfig) (*CandidatePeerReflexive, error) { - ip := net.ParseIP(config.Address) - if ip == nil { - return nil, ErrAddressParseFailed + ipAddr, err := netip.ParseAddr(config.Address) + if err != nil { + return nil, err } - networkType, err := determineNetworkType(config.Network, ip) + networkType, err := determineNetworkType(config.Network, ipAddr) if err != nil { return nil, err } @@ -50,7 +52,7 @@ func NewCandidatePeerReflexive(config *CandidatePeerReflexiveConfig) (*Candidate candidateType: CandidateTypePeerReflexive, address: config.Address, port: config.Port, - resolvedAddr: createAddr(networkType, ip, config.Port), + resolvedAddr: createAddr(networkType, ipAddr, config.Port), component: config.Component, foundationOverride: config.Foundation, priorityOverride: config.Priority, diff --git a/candidate_relay.go b/candidate_relay.go index fa5297b2..faf281b0 100644 --- a/candidate_relay.go +++ b/candidate_relay.go @@ -5,6 +5,7 @@ package ice import ( "net" + "net/netip" ) // CandidateRelay ... @@ -38,24 +39,28 @@ func NewCandidateRelay(config *CandidateRelayConfig) (*CandidateRelay, error) { candidateID = globalCandidateIDGenerator.Generate() } - ip := net.ParseIP(config.Address) - if ip == nil { - return nil, ErrAddressParseFailed + ipAddr, err := netip.ParseAddr(config.Address) + if err != nil { + return nil, err } - networkType, err := determineNetworkType(config.Network, ip) + networkType, err := determineNetworkType(config.Network, ipAddr) if err != nil { return nil, err } return &CandidateRelay{ candidateBase: candidateBase{ - id: candidateID, - networkType: networkType, - candidateType: CandidateTypeRelay, - address: config.Address, - port: config.Port, - resolvedAddr: &net.UDPAddr{IP: ip, Port: config.Port}, + id: candidateID, + networkType: networkType, + candidateType: CandidateTypeRelay, + address: config.Address, + port: config.Port, + resolvedAddr: &net.UDPAddr{ + IP: ipAddr.AsSlice(), + Port: config.Port, + Zone: ipAddr.Zone(), + }, component: config.Component, foundationOverride: config.Foundation, priorityOverride: config.Priority, diff --git a/candidate_server_reflexive.go b/candidate_server_reflexive.go index 3a8ac0ff..85d613e2 100644 --- a/candidate_server_reflexive.go +++ b/candidate_server_reflexive.go @@ -3,7 +3,10 @@ package ice -import "net" +import ( + "net" + "net/netip" +) // CandidateServerReflexive ... type CandidateServerReflexive struct { @@ -25,12 +28,12 @@ type CandidateServerReflexiveConfig struct { // NewCandidateServerReflexive creates a new server reflective candidate func NewCandidateServerReflexive(config *CandidateServerReflexiveConfig) (*CandidateServerReflexive, error) { - ip := net.ParseIP(config.Address) - if ip == nil { - return nil, ErrAddressParseFailed + ipAddr, err := netip.ParseAddr(config.Address) + if err != nil { + return nil, err } - networkType, err := determineNetworkType(config.Network, ip) + networkType, err := determineNetworkType(config.Network, ipAddr) if err != nil { return nil, err } @@ -42,12 +45,16 @@ func NewCandidateServerReflexive(config *CandidateServerReflexiveConfig) (*Candi return &CandidateServerReflexive{ candidateBase: candidateBase{ - id: candidateID, - networkType: networkType, - candidateType: CandidateTypeServerReflexive, - address: config.Address, - port: config.Port, - resolvedAddr: &net.UDPAddr{IP: ip, Port: config.Port}, + id: candidateID, + networkType: networkType, + candidateType: CandidateTypeServerReflexive, + address: config.Address, + port: config.Port, + resolvedAddr: &net.UDPAddr{ + IP: ipAddr.AsSlice(), + Port: config.Port, + Zone: ipAddr.Zone(), + }, component: config.Component, foundationOverride: config.Foundation, priorityOverride: config.Priority, diff --git a/candidate_test.go b/candidate_test.go index 96ab87c9..04bb17c1 100644 --- a/candidate_test.go +++ b/candidate_test.go @@ -5,6 +5,7 @@ package ice import ( "net" + "strconv" "testing" "time" @@ -265,107 +266,105 @@ func TestCandidateFoundation(t *testing.T) { }).Foundation()) } +func mustCandidateHost(conf *CandidateHostConfig) Candidate { + cand, err := NewCandidateHost(conf) + if err != nil { + panic(err) + } + return cand +} + +func mustCandidateRelay(conf *CandidateRelayConfig) Candidate { + cand, err := NewCandidateRelay(conf) + if err != nil { + panic(err) + } + return cand +} + +func mustCandidateServerReflexive(conf *CandidateServerReflexiveConfig) Candidate { + cand, err := NewCandidateServerReflexive(conf) + if err != nil { + panic(err) + } + return cand +} + func TestCandidateMarshal(t *testing.T) { - for _, test := range []struct { + for idx, test := range []struct { candidate Candidate marshaled string expectError bool }{ { - &CandidateHost{ - candidateBase{ - networkType: NetworkTypeUDP6, - candidateType: CandidateTypeHost, - address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", - port: 53987, - priorityOverride: 500, - foundationOverride: "750", - }, - "", - }, + mustCandidateHost(&CandidateHostConfig{ + Network: NetworkTypeUDP6.String(), + Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", + Port: 53987, + Priority: 500, + Foundation: "750", + }), "750 1 udp 500 fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a 53987 typ host", false, }, { - &CandidateHost{ - candidateBase{ - networkType: NetworkTypeUDP4, - candidateType: CandidateTypeHost, - address: "10.0.75.1", - port: 53634, - }, - "", - }, + mustCandidateHost(&CandidateHostConfig{ + Network: NetworkTypeUDP4.String(), + Address: "10.0.75.1", + Port: 53634, + }), "4273957277 1 udp 2130706431 10.0.75.1 53634 typ host", false, }, { - &CandidateServerReflexive{ - candidateBase{ - networkType: NetworkTypeUDP4, - candidateType: CandidateTypeServerReflexive, - address: "191.228.238.68", - port: 53991, - relatedAddress: &CandidateRelatedAddress{"192.168.0.274", 53991}, - }, - }, + mustCandidateServerReflexive(&CandidateServerReflexiveConfig{ + Network: NetworkTypeUDP4.String(), + Address: "191.228.238.68", + Port: 53991, + RelAddr: "192.168.0.274", + RelPort: 53991, + }), "647372371 1 udp 1694498815 191.228.238.68 53991 typ srflx raddr 192.168.0.274 rport 53991", false, }, { - &CandidateRelay{ - candidateBase{ - networkType: NetworkTypeUDP4, - candidateType: CandidateTypeRelay, - address: "50.0.0.1", - port: 5000, - relatedAddress: &CandidateRelatedAddress{"192.168.0.1", 5001}, - }, - "", - nil, - }, + mustCandidateRelay(&CandidateRelayConfig{ + Network: NetworkTypeUDP4.String(), + Address: "50.0.0.1", + Port: 5000, + RelAddr: "192.168.0.1", + RelPort: 5001, + }), "848194626 1 udp 16777215 50.0.0.1 5000 typ relay raddr 192.168.0.1 rport 5001", false, }, { - &CandidateHost{ - candidateBase{ - networkType: NetworkTypeTCP4, - candidateType: CandidateTypeHost, - address: "192.168.0.196", - port: 0, - tcpType: TCPTypeActive, - }, - "", - }, + mustCandidateHost(&CandidateHostConfig{ + Network: NetworkTypeTCP4.String(), + Address: "192.168.0.196", + Port: 0, + TCPType: TCPTypeActive, + }), "1052353102 1 tcp 2128609279 192.168.0.196 0 typ host tcptype active", false, }, { - &CandidateHost{ - candidateBase{ - networkType: NetworkTypeUDP4, - candidateType: CandidateTypeHost, - address: "e2494022-4d9a-4c1e-a750-cc48d4f8d6ee.local", - port: 60542, - }, - "", - }, + mustCandidateHost(&CandidateHostConfig{ + Network: NetworkTypeUDP4.String(), + Address: "e2494022-4d9a-4c1e-a750-cc48d4f8d6ee.local", + Port: 60542, + }), "1380287402 1 udp 2130706431 e2494022-4d9a-4c1e-a750-cc48d4f8d6ee.local 60542 typ host", false, }, // Missing Foundation { - &CandidateHost{ - candidateBase{ - networkType: NetworkTypeUDP4, - candidateType: CandidateTypeHost, - address: localhostIPStr, - port: 80, - priorityOverride: 500, - foundationOverride: " ", - }, - "", - }, + mustCandidateHost(&CandidateHostConfig{ + Network: NetworkTypeUDP4.String(), + Address: localhostIPStr, + Port: 80, + Priority: 500, + Foundation: " ", + }), " 1 udp 500 " + localhostIPStr + " 80 typ host", false, }, @@ -384,16 +383,18 @@ func TestCandidateMarshal(t *testing.T) { {nil, "4207374051 1 udp 2130706431 10.0.75.1 53634 typ INVALID", true}, {nil, "4207374051 1 INVALID 2130706431 10.0.75.1 53634 typ host", true}, } { - actualCandidate, err := UnmarshalCandidate(test.marshaled) - if test.expectError { - require.Error(t, err) - continue - } + t.Run(strconv.Itoa(idx), func(t *testing.T) { + actualCandidate, err := UnmarshalCandidate(test.marshaled) + if test.expectError { + require.Error(t, err) + return + } - require.NoError(t, err) + require.NoError(t, err) - require.True(t, test.candidate.Equal(actualCandidate)) - require.Equal(t, test.marshaled, actualCandidate.Marshal()) + require.True(t, test.candidate.Equal(actualCandidate)) + require.Equal(t, test.marshaled, actualCandidate.Marshal()) + }) } } diff --git a/gather.go b/gather.go index 258e7fcb..fe56cc7f 100644 --- a/gather.go +++ b/gather.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "net" + "net/netip" "reflect" "sync" "time" @@ -133,25 +134,37 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ delete(networks, udp) } - localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, networkTypes, a.includeLoopback) + _, localAddrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, networkTypes, a.includeLoopback) if err != nil { a.log.Warnf("Failed to iterate local interfaces, host candidates will not be gathered %s", err) return } - for _, ip := range localIPs { - mappedIP := ip + for _, addr := range localAddrs { + mappedIP := addr if a.mDNSMode != MulticastDNSModeQueryAndGather && a.extIPMapper != nil && a.extIPMapper.candidateType == CandidateTypeHost { - if _mappedIP, innerErr := a.extIPMapper.findExternalIP(ip.String()); innerErr == nil { - mappedIP = _mappedIP + if _mappedIP, innerErr := a.extIPMapper.findExternalIP(addr.String()); innerErr == nil { + conv, ok := netip.AddrFromSlice(_mappedIP) + if !ok { + a.log.Warnf("failed to convert mapped external IP to netip.Addr'%s'", addr.String()) + continue + } + // we'd rather have an IPv4-mapped IPv6 become IPv4 so that it is usable + mappedIP = conv.Unmap() } else { - a.log.Warnf("1:1 NAT mapping is enabled but no external IP is found for %s", ip.String()) + a.log.Warnf("1:1 NAT mapping is enabled but no external IP is found for %s", addr.String()) } } address := mappedIP.String() + var isLocationTracked bool if a.mDNSMode == MulticastDNSModeQueryAndGather { address = a.mDNSName + } else { + // Here, we are not doing multicast gathering, so we will need to skip this address so + // that we don't accidentally reveal location tracking information. Otherwise, the + // case above hides the IP behind an mDNS address. + isLocationTracked = shouldFilterLocationTrackedIP(mappedIP) } for network := range networks { @@ -174,16 +187,18 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ var muxConns []net.PacketConn if multi, ok := a.tcpMux.(AllConnsGetter); ok { a.log.Debugf("GetAllConns by ufrag: %s", a.localUfrag) - muxConns, err = multi.GetAllConns(a.localUfrag, mappedIP.To4() == nil, ip) + // Note: this is missing zone for IPv6 by just grabbing the IP slice + muxConns, err = multi.GetAllConns(a.localUfrag, mappedIP.Is6(), addr.AsSlice()) if err != nil { - a.log.Warnf("Failed to get all TCP connections by ufrag: %s %s %s", network, ip, a.localUfrag) + a.log.Warnf("Failed to get all TCP connections by ufrag: %s %s %s", network, addr, a.localUfrag) continue } } else { a.log.Debugf("GetConn by ufrag: %s", a.localUfrag) - conn, err := a.tcpMux.GetConnByUfrag(a.localUfrag, mappedIP.To4() == nil, ip) + // Note: this is missing zone for IPv6 by just grabbing the IP slice + conn, err := a.tcpMux.GetConnByUfrag(a.localUfrag, mappedIP.Is6(), addr.AsSlice()) if err != nil { - a.log.Warnf("Failed to get TCP connections by ufrag: %s %s %s", network, ip, a.localUfrag) + a.log.Warnf("Failed to get TCP connections by ufrag: %s %s %s", network, addr, a.localUfrag) continue } muxConns = []net.PacketConn{conn} @@ -194,7 +209,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ if tcpConn, ok := conn.LocalAddr().(*net.TCPAddr); ok { conns = append(conns, connAndPort{conn, tcpConn.Port}) } else { - a.log.Warnf("Failed to get port of connection from TCPMux: %s %s %s", network, ip, a.localUfrag) + a.log.Warnf("Failed to get port of connection from TCPMux: %s %s %s", network, addr, a.localUfrag) } } if len(conns) == 0 { @@ -205,16 +220,20 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ // Is there a way to verify that the listen address is even // accessible from the current interface. case udp: - conn, err := listenUDPInPortRange(a.net, a.log, int(a.portMax), int(a.portMin), network, &net.UDPAddr{IP: ip, Port: 0}) + conn, err := listenUDPInPortRange(a.net, a.log, int(a.portMax), int(a.portMin), network, &net.UDPAddr{ + IP: addr.AsSlice(), + Port: 0, + Zone: addr.Zone(), + }) if err != nil { - a.log.Warnf("Failed to listen %s %s", network, ip) + a.log.Warnf("Failed to listen %s %s", network, addr) continue } if udpConn, ok := conn.LocalAddr().(*net.UDPAddr); ok { conns = append(conns, connAndPort{conn, udpConn.Port}) } else { - a.log.Warnf("Failed to get port of UDPAddr from ListenUDPInPortRange: %s %s %s", network, ip, a.localUfrag) + a.log.Warnf("Failed to get port of UDPAddr from ListenUDPInPortRange: %s %s %s", network, addr, a.localUfrag) continue } } @@ -226,6 +245,9 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ Port: connAndPort.port, Component: ComponentRTP, TCPType: tcpType, + // we will still process this candidate so that we start up the right + // listeners. + IsLocationTracked: isLocationTracked, } c, err := NewCandidateHost(&hostConfig) @@ -235,7 +257,7 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ } if a.mDNSMode == MulticastDNSModeQueryAndGather { - if err = c.setIP(ip); err != nil { + if err = c.setIPAddr(addr); err != nil { closeConnAndLog(connAndPort.conn, a.log, "failed to create host candidate: %s %s %d: %v", network, mappedIP, connAndPort.port, err) continue } @@ -252,6 +274,27 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ } } +// shouldFilterLocationTrackedIP returns if this candidate IP should be filtered out from +// any candidate publishing/notification for location tracking reasons. +func shouldFilterLocationTrackedIP(candidateIP netip.Addr) bool { + // https://tools.ietf.org/html/rfc8445#section-5.1.1.1 + // Similarly, when host candidates corresponding to + // an IPv6 address generated using a mechanism that prevents location + // tracking are gathered, then host candidates corresponding to IPv6 + // link-local addresses [RFC4291] MUST NOT be gathered. + return candidateIP.Is6() && (candidateIP.IsLinkLocalUnicast() || candidateIP.IsLinkLocalMulticast()) +} + +// shouldFilterLocationTracked returns if this candidate IP should be filtered out from +// any candidate publishing/notification for location tracking reasons. +func shouldFilterLocationTracked(candidateIP net.IP) bool { + addr, ok := netip.AddrFromSlice(candidateIP) + if !ok { + return false + } + return shouldFilterLocationTrackedIP(addr) +} + func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolint:gocognit if a.udpMux == nil { return errUDPMuxDisabled @@ -286,17 +329,23 @@ func (a *Agent) gatherCandidatesLocalUDPMux(ctx context.Context) error { //nolin } var address string + var isLocationTracked bool if a.mDNSMode == MulticastDNSModeQueryAndGather { address = a.mDNSName } else { address = candidateIP.String() + // Here, we are not doing multicast gathering, so we will need to skip this address so + // that we don't accidentally reveal location tracking information. Otherwise, the + // case above hides the IP behind an mDNS address. + isLocationTracked = shouldFilterLocationTracked(candidateIP) } hostConfig := CandidateHostConfig{ - Network: udp, - Address: address, - Port: udpAddr.Port, - Component: ComponentRTP, + Network: udp, + Address: address, + Port: udpAddr.Port, + Component: ComponentRTP, + IsLocationTracked: isLocationTracked, } // Detect a duplicate candidate before calling addCandidate(). @@ -365,6 +414,11 @@ func (a *Agent) gatherCandidatesSrflxMapped(ctx context.Context, networkTypes [] return } + if shouldFilterLocationTracked(mappedIP) { + closeConnAndLog(conn, a.log, "external IP is somehow filtered for location tracking reasons %s", mappedIP) + return + } + srflxConfig := CandidateServerReflexiveConfig{ Network: network, Address: mappedIP.String(), @@ -420,6 +474,11 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*stun.UR return } + if shouldFilterLocationTracked(serverAddr.IP) { + a.log.Warnf("STUN host %s is somehow filtered for location tracking reasons", hostPort) + return + } + xorAddr, err := a.udpMuxSrflx.GetXORMappedAddr(serverAddr, stunGatherTimeout) if err != nil { a.log.Warnf("Failed get server reflexive address %s %s: %v", network, url, err) @@ -482,6 +541,11 @@ func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, net return } + if shouldFilterLocationTracked(serverAddr.IP) { + a.log.Warnf("STUN host %s is somehow filtered for location tracking reasons", hostPort) + return + } + conn, err := listenUDPInPortRange(a.net, a.log, int(a.portMax), int(a.portMin), network, &net.UDPAddr{IP: nil, Port: 0}) if err != nil { closeConnAndLog(conn, a.log, "failed to listen for %s: %v", serverAddr.String(), err) @@ -696,6 +760,12 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*stun.URI) { / } rAddr := relayConn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert + + if shouldFilterLocationTracked(rAddr.IP) { + a.log.Warnf("TURN address %s is somehow filtered for location tracking reasons", rAddr.IP) + return + } + relayConfig := CandidateRelayConfig{ Network: network, Component: ComponentRTP, diff --git a/gather_test.go b/gather_test.go index 54c8a8db..1e4b8960 100644 --- a/gather_test.go +++ b/gather_test.go @@ -34,11 +34,11 @@ func TestListenUDP(t *testing.T) { a, err := NewAgent(&AgentConfig{}) require.NoError(t, err) - localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) - require.NotEqual(t, len(localIPs), 0, "localInterfaces found no interfaces, unable to test") + _, localAddrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) + require.NotEqual(t, len(localAddrs), 0, "localInterfaces found no interfaces, unable to test") require.NoError(t, err) - ip := localIPs[0] + ip := localAddrs[0].AsSlice() conn, err := listenUDPInPortRange(a.net, a.log, 0, 0, udp, &net.UDPAddr{IP: ip, Port: 0}) require.NoError(t, err, "listenUDP error with no port restriction") diff --git a/gather_vnet_test.go b/gather_vnet_test.go index ce11e839..9fda3fdc 100644 --- a/gather_vnet_test.go +++ b/gather_vnet_test.go @@ -34,7 +34,7 @@ func TestVNetGather(t *testing.T) { }) require.NoError(t, err) - localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) + _, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) if len(localIPs) > 0 { t.Fatal("should return no local IP") } @@ -73,17 +73,17 @@ func TestVNetGather(t *testing.T) { }) require.NoError(t, err) - localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) - if len(localIPs) == 0 { + _, localAddrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) + if len(localAddrs) == 0 { t.Fatal("should have one local IP") } require.NoError(t, err) - for _, ip := range localIPs { - if ip.IsLoopback() { + for _, addr := range localAddrs { + if addr.IsLoopback() { t.Fatal("should not return loopback IP") } - if !ipNet.Contains(ip) { + if !ipNet.Contains(addr.AsSlice()) { t.Fatal("should be contained in the CIDR") } } @@ -115,13 +115,13 @@ func TestVNetGather(t *testing.T) { t.Fatalf("Failed to create agent: %s", err) } - localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) - if len(localIPs) == 0 { + _, localAddrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) + if len(localAddrs) == 0 { t.Fatal("localInterfaces found no interfaces, unable to test") } require.NoError(t, err) - ip := localIPs[0] + ip := localAddrs[0].AsSlice() conn, err := listenUDPInPortRange(a.net, a.log, 0, 0, udp, &net.UDPAddr{IP: ip, Port: 0}) if err != nil { @@ -385,7 +385,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) { }) require.NoError(t, err) - localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) + _, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) require.NoError(t, err) if len(localIPs) != 0 { @@ -405,7 +405,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) { }) require.NoError(t, err) - localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) + _, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) require.NoError(t, err) if len(localIPs) != 0 { @@ -425,7 +425,7 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) { }) require.NoError(t, err) - localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) + _, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) require.NoError(t, err) if len(localIPs) == 0 { diff --git a/go.mod b/go.mod index 231f46c3..8be8f433 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/kr/pretty v0.1.0 // indirect github.com/pion/dtls/v2 v2.2.10 github.com/pion/logging v0.2.2 - github.com/pion/mdns/v2 v2.0.4 + github.com/pion/mdns/v2 v2.0.6 github.com/pion/randutil v0.1.0 github.com/pion/stun/v2 v2.0.0 github.com/pion/transport/v3 v3.0.1 diff --git a/go.sum b/go.sum index 1e0706d3..c31b88b3 100644 --- a/go.sum +++ b/go.sum @@ -13,8 +13,8 @@ github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA= github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= -github.com/pion/mdns/v2 v2.0.4 h1:ZdK19Yd+9iPrw95rW1tTwdjmYY5O3WwmsvfX6HBBefY= -github.com/pion/mdns/v2 v2.0.4/go.mod h1:y4Y034qALR23oAJuiElt2TP1ma7b1Q/uF1oYzIePHcM= +github.com/pion/mdns/v2 v2.0.6 h1:mrqisUnOajlMKqXXXtyiBmez/0rYMFVztHU3Mg1RETQ= +github.com/pion/mdns/v2 v2.0.6/go.mod h1:y4Y034qALR23oAJuiElt2TP1ma7b1Q/uF1oYzIePHcM= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= github.com/pion/stun/v2 v2.0.0 h1:A5+wXKLAypxQri59+tmQKVs7+l6mMM+3d+eER9ifRU0= diff --git a/mdns.go b/mdns.go index a3ad899b..2c10a327 100644 --- a/mdns.go +++ b/mdns.go @@ -4,11 +4,14 @@ package ice import ( + "net" + "github.com/google/uuid" "github.com/pion/logging" "github.com/pion/mdns/v2" "github.com/pion/transport/v3" "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" ) // MulticastDNSMode represents the different Multicast modes ICE can run in @@ -33,30 +36,95 @@ func generateMulticastDNSName() (string, error) { return u.String() + ".local", err } -func createMulticastDNS(n transport.Net, mDNSMode MulticastDNSMode, mDNSName string, log logging.LeveledLogger) (*mdns.Conn, MulticastDNSMode, error) { +func createMulticastDNS( + n transport.Net, + networkTypes []NetworkType, + interfaces []*transport.Interface, + includeLoopback bool, + mDNSMode MulticastDNSMode, + mDNSName string, + log logging.LeveledLogger, +) (*mdns.Conn, MulticastDNSMode, error) { if mDNSMode == MulticastDNSModeDisabled { return nil, mDNSMode, nil } - addr, mdnsErr := n.ResolveUDPAddr("udp4", mdns.DefaultAddressIPv4) + var useV4, useV6 bool + if len(networkTypes) == 0 { + useV4 = true + useV6 = true + } else { + for _, nt := range networkTypes { + if nt.IsIPv4() { + useV4 = true + continue + } + if nt.IsIPv6() { + useV6 = true + } + } + } + + addr4, mdnsErr := n.ResolveUDPAddr("udp4", mdns.DefaultAddressIPv4) if mdnsErr != nil { return nil, mDNSMode, mdnsErr } - - l, mdnsErr := n.ListenUDP("udp4", addr) + addr6, mdnsErr := n.ResolveUDPAddr("udp6", mdns.DefaultAddressIPv6) if mdnsErr != nil { + return nil, mDNSMode, mdnsErr + } + + var pktConnV4 *ipv4.PacketConn + var mdns4Err error + if useV4 { + var l transport.UDPConn + l, mdns4Err = n.ListenUDP("udp4", addr4) + if mdns4Err != nil { + // If ICE fails to start MulticastDNS server just warn the user and continue + log.Errorf("Failed to enable mDNS over IPv4: (%s)", mdns4Err) + return nil, MulticastDNSModeDisabled, nil + } + pktConnV4 = ipv4.NewPacketConn(l) + } + + var pktConnV6 *ipv6.PacketConn + var mdns6Err error + if useV6 { + var l transport.UDPConn + l, mdns6Err = n.ListenUDP("udp6", addr6) + if mdns6Err != nil { + log.Errorf("Failed to enable mDNS over IPv6: (%s)", mdns6Err) + return nil, MulticastDNSModeDisabled, nil + } + pktConnV6 = ipv6.NewPacketConn(l) + } + + if mdns4Err != nil && mdns6Err != nil { // If ICE fails to start MulticastDNS server just warn the user and continue - log.Errorf("Failed to enable mDNS, continuing in mDNS disabled mode: (%s)", mdnsErr) + log.Errorf("Failed to enable mDNS, continuing in mDNS disabled mode") + //nolint:nilerr return nil, MulticastDNSModeDisabled, nil } + var ifcs []net.Interface + if interfaces != nil { + ifcs = make([]net.Interface, 0, len(ifcs)) + for _, ifc := range interfaces { + ifcs = append(ifcs, ifc.Interface) + } + } switch mDNSMode { case MulticastDNSModeQueryOnly: - conn, err := mdns.Server(ipv4.NewPacketConn(l), nil, &mdns.Config{}) + conn, err := mdns.Server(pktConnV4, pktConnV6, &mdns.Config{ + Interfaces: ifcs, + IncludeLoopback: includeLoopback, + }) return conn, mDNSMode, err case MulticastDNSModeQueryAndGather: - conn, err := mdns.Server(ipv4.NewPacketConn(l), nil, &mdns.Config{ - LocalNames: []string{mDNSName}, + conn, err := mdns.Server(pktConnV4, pktConnV6, &mdns.Config{ + Interfaces: ifcs, + IncludeLoopback: includeLoopback, + LocalNames: []string{mDNSName}, }) return conn, mDNSMode, err default: diff --git a/mdns_test.go b/mdns_test.go index a82d6b1d..617492b2 100644 --- a/mdns_test.go +++ b/mdns_test.go @@ -22,30 +22,51 @@ func TestMulticastDNSOnlyConnection(t *testing.T) { // Limit runtime in case of deadlocks defer test.TimeOut(time.Second * 30).Stop() - cfg := &AgentConfig{ - NetworkTypes: []NetworkType{NetworkTypeUDP4}, - CandidateTypes: []CandidateType{CandidateTypeHost}, - MulticastDNSMode: MulticastDNSModeQueryAndGather, + type testCase struct { + Name string + NetworkTypes []NetworkType } - aAgent, err := NewAgent(cfg) - require.NoError(t, err) + testCases := []testCase{ + {Name: "UDP4", NetworkTypes: []NetworkType{NetworkTypeUDP4}}, + } + + if ipv6Available(t) { + testCases = append(testCases, + testCase{Name: "UDP6", NetworkTypes: []NetworkType{NetworkTypeUDP6}}, + testCase{Name: "UDP46", NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}}, + ) + } - aNotifier, aConnected := onConnected() - require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + cfg := &AgentConfig{ + NetworkTypes: tc.NetworkTypes, + CandidateTypes: []CandidateType{CandidateTypeHost}, + MulticastDNSMode: MulticastDNSModeQueryAndGather, + InterfaceFilter: problematicNetworkInterfaces, + } - bAgent, err := NewAgent(cfg) - require.NoError(t, err) + aAgent, err := NewAgent(cfg) + require.NoError(t, err) + + aNotifier, aConnected := onConnected() + require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) - bNotifier, bConnected := onConnected() - require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) + bAgent, err := NewAgent(cfg) + require.NoError(t, err) - connect(aAgent, bAgent) - <-aConnected - <-bConnected + bNotifier, bConnected := onConnected() + require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) - require.NoError(t, aAgent.Close()) - require.NoError(t, bAgent.Close()) + connect(aAgent, bAgent) + <-aConnected + <-bConnected + + require.NoError(t, aAgent.Close()) + require.NoError(t, bAgent.Close()) + }) + } } func TestMulticastDNSMixedConnection(t *testing.T) { @@ -54,32 +75,54 @@ func TestMulticastDNSMixedConnection(t *testing.T) { // Limit runtime in case of deadlocks defer test.TimeOut(time.Second * 30).Stop() - aAgent, err := NewAgent(&AgentConfig{ - NetworkTypes: []NetworkType{NetworkTypeUDP4}, - CandidateTypes: []CandidateType{CandidateTypeHost}, - MulticastDNSMode: MulticastDNSModeQueryAndGather, - }) - require.NoError(t, err) - - aNotifier, aConnected := onConnected() - require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) - - bAgent, err := NewAgent(&AgentConfig{ - NetworkTypes: []NetworkType{NetworkTypeUDP4}, - CandidateTypes: []CandidateType{CandidateTypeHost}, - MulticastDNSMode: MulticastDNSModeQueryOnly, - }) - require.NoError(t, err) + type testCase struct { + Name string + NetworkTypes []NetworkType + } - bNotifier, bConnected := onConnected() - require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) + testCases := []testCase{ + {Name: "UDP4", NetworkTypes: []NetworkType{NetworkTypeUDP4}}, + } - connect(aAgent, bAgent) - <-aConnected - <-bConnected + if ipv6Available(t) { + testCases = append(testCases, + testCase{Name: "UDP6", NetworkTypes: []NetworkType{NetworkTypeUDP6}}, + testCase{Name: "UDP46", NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}}, + ) + } - require.NoError(t, aAgent.Close()) - require.NoError(t, bAgent.Close()) + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + aAgent, err := NewAgent(&AgentConfig{ + NetworkTypes: tc.NetworkTypes, + CandidateTypes: []CandidateType{CandidateTypeHost}, + MulticastDNSMode: MulticastDNSModeQueryAndGather, + InterfaceFilter: problematicNetworkInterfaces, + }) + require.NoError(t, err) + + aNotifier, aConnected := onConnected() + require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) + + bAgent, err := NewAgent(&AgentConfig{ + NetworkTypes: tc.NetworkTypes, + CandidateTypes: []CandidateType{CandidateTypeHost}, + MulticastDNSMode: MulticastDNSModeQueryOnly, + InterfaceFilter: problematicNetworkInterfaces, + }) + require.NoError(t, err) + + bNotifier, bConnected := onConnected() + require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) + + connect(aAgent, bAgent) + <-aConnected + <-bConnected + + require.NoError(t, aAgent.Close()) + require.NoError(t, bAgent.Close()) + }) + } } func TestMulticastDNSStaticHostName(t *testing.T) { @@ -87,32 +130,54 @@ func TestMulticastDNSStaticHostName(t *testing.T) { defer test.TimeOut(time.Second * 30).Stop() - _, err := NewAgent(&AgentConfig{ - NetworkTypes: []NetworkType{NetworkTypeUDP4}, - CandidateTypes: []CandidateType{CandidateTypeHost}, - MulticastDNSMode: MulticastDNSModeQueryAndGather, - MulticastDNSHostName: "invalidHostName", - }) - require.Equal(t, err, ErrInvalidMulticastDNSHostName) - - agent, err := NewAgent(&AgentConfig{ - NetworkTypes: []NetworkType{NetworkTypeUDP4}, - CandidateTypes: []CandidateType{CandidateTypeHost}, - MulticastDNSMode: MulticastDNSModeQueryAndGather, - MulticastDNSHostName: "validName.local", - }) - require.NoError(t, err) + type testCase struct { + Name string + NetworkTypes []NetworkType + } - correctHostName, resolveFunc := context.WithCancel(context.Background()) - require.NoError(t, agent.OnCandidate(func(c Candidate) { - if c != nil && c.Address() == "validName.local" { - resolveFunc() - } - })) + testCases := []testCase{ + {Name: "UDP4", NetworkTypes: []NetworkType{NetworkTypeUDP4}}, + } - require.NoError(t, agent.GatherCandidates()) - <-correctHostName.Done() - require.NoError(t, agent.Close()) + if ipv6Available(t) { + testCases = append(testCases, + testCase{Name: "UDP6", NetworkTypes: []NetworkType{NetworkTypeUDP6}}, + testCase{Name: "UDP46", NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}}, + ) + } + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + _, err := NewAgent(&AgentConfig{ + NetworkTypes: tc.NetworkTypes, + CandidateTypes: []CandidateType{CandidateTypeHost}, + MulticastDNSMode: MulticastDNSModeQueryAndGather, + MulticastDNSHostName: "invalidHostName", + InterfaceFilter: problematicNetworkInterfaces, + }) + require.Equal(t, err, ErrInvalidMulticastDNSHostName) + + agent, err := NewAgent(&AgentConfig{ + NetworkTypes: tc.NetworkTypes, + CandidateTypes: []CandidateType{CandidateTypeHost}, + MulticastDNSMode: MulticastDNSModeQueryAndGather, + MulticastDNSHostName: "validName.local", + InterfaceFilter: problematicNetworkInterfaces, + }) + require.NoError(t, err) + + correctHostName, resolveFunc := context.WithCancel(context.Background()) + require.NoError(t, agent.OnCandidate(func(c Candidate) { + if c != nil && c.Address() == "validName.local" { + resolveFunc() + } + })) + + require.NoError(t, agent.GatherCandidates()) + <-correctHostName.Done() + require.NoError(t, agent.Close()) + }) + } } func TestGenerateMulticastDNSName(t *testing.T) { diff --git a/net.go b/net.go index f686d411..6e740a70 100644 --- a/net.go +++ b/net.go @@ -5,6 +5,7 @@ package ice import ( "net" + "net/netip" "github.com/pion/logging" "github.com/pion/transport/v3" @@ -12,12 +13,15 @@ import ( // The conditions of invalidation written below are defined in // https://tools.ietf.org/html/rfc8445#section-5.1.1.1 -func isSupportedIPv6(ip net.IP) bool { +// It is partial because the link-local check is done later in various gather local +// candidate methods which conditionally accept IPv6 based on usage of mDNS or not. +func isSupportedIPv6Partial(ip net.IP) bool { if len(ip) != net.IPv6len || + // Deprecated IPv4-compatible IPv6 addresses [RFC4291] and IPv6 site- + // local unicast addresses [RFC3879] MUST NOT be included in the + // address candidates. isZeros(ip[0:12]) || // !(IPv4-compatible IPv6) - ip[0] == 0xfe && ip[1]&0xc0 == 0xc0 || // !(IPv6 site-local unicast) - ip.IsLinkLocalUnicast() || - ip.IsLinkLocalMulticast() { + ip[0] == 0xfe && ip[1]&0xc0 == 0xc0 { // !(IPv6 site-local unicast) return false } return true @@ -32,21 +36,35 @@ func isZeros(ip net.IP) bool { return true } -func localInterfaces(n transport.Net, interfaceFilter func(string) bool, ipFilter func(net.IP) bool, networkTypes []NetworkType, includeLoopback bool) ([]net.IP, error) { //nolint:gocognit - ips := []net.IP{} +//nolint:gocognit +func localInterfaces( + n transport.Net, + interfaceFilter func(string) bool, + ipFilter func(net.IP) bool, + networkTypes []NetworkType, + includeLoopback bool, +) ([]*transport.Interface, []netip.Addr, error) { + ipAddrs := []netip.Addr{} ifaces, err := n.Interfaces() if err != nil { - return ips, err + return nil, ipAddrs, err } - var IPv4Requested, IPv6Requested bool - for _, typ := range networkTypes { - if typ.IsIPv4() { - IPv4Requested = true - } + filteredIfaces := make([]*transport.Interface, 0, len(ifaces)) + + var ipV4Requested, ipv6Requested bool + if len(networkTypes) == 0 { + ipV4Requested = true + ipv6Requested = true + } else { + for _, typ := range networkTypes { + if typ.IsIPv4() { + ipV4Requested = true + } - if typ.IsIPv6() { - IPv6Requested = true + if typ.IsIPv6() { + ipv6Requested = true + } } } @@ -62,41 +80,41 @@ func localInterfaces(n transport.Net, interfaceFilter func(string) bool, ipFilte continue } - addrs, err := iface.Addrs() + ifaceAddrs, err := iface.Addrs() if err != nil { continue } - for _, addr := range addrs { - var ip net.IP - switch addr := addr.(type) { - case *net.IPNet: - ip = addr.IP - case *net.IPAddr: - ip = addr.IP - } - if ip == nil || (ip.IsLoopback() && !includeLoopback) { + atLeastOneAddr := false + for _, addr := range ifaceAddrs { + ipAddr, _, _, err := parseAddrFromIface(addr, iface.Name) + if err != nil || (ipAddr.IsLoopback() && !includeLoopback) { continue } - - if ipv4 := ip.To4(); ipv4 == nil { - if !IPv6Requested { + if ipAddr.Is6() { + if !ipv6Requested { continue - } else if !isSupportedIPv6(ip) { + } else if !isSupportedIPv6Partial(ipAddr.AsSlice()) { continue } - } else if !IPv4Requested { + } else if !ipV4Requested { continue } - if ipFilter != nil && !ipFilter(ip) { + if ipFilter != nil && !ipFilter(ipAddr.AsSlice()) { continue } - ips = append(ips, ip) + atLeastOneAddr = true + ipAddrs = append(ipAddrs, ipAddr) + } + + if atLeastOneAddr { + ifaceCopy := iface + filteredIfaces = append(filteredIfaces, ifaceCopy) } } - return ips, nil + return filteredIfaces, ipAddrs, nil } func listenUDPInPortRange(n transport.Net, log logging.LeveledLogger, portMax, portMin int, network string, lAddr *net.UDPAddr) (transport.UDPConn, error) { diff --git a/net_test.go b/net_test.go index 949fd813..12f0a9e2 100644 --- a/net_test.go +++ b/net_test.go @@ -5,40 +5,68 @@ package ice import ( "net" + "net/netip" + "strings" "testing" "github.com/stretchr/testify/require" ) -func TestIsSupportedIPv6(t *testing.T) { - if isSupportedIPv6(net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1}) { - t.Errorf("isSupportedIPv6 return true with IPv4-compatible IPv6 address") +func TestIsSupportedIPv6Partial(t *testing.T) { + if isSupportedIPv6Partial(net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1}) { + t.Errorf("isSupportedIPv6Partial returned true with IPv4-compatible IPv6 address") } - if isSupportedIPv6(net.ParseIP("fec0::2333")) { - t.Errorf("isSupportedIPv6 return true with IPv6 site-local unicast address") + if isSupportedIPv6Partial(net.ParseIP("fec0::2333")) { + t.Errorf("isSupportedIPv6Partial returned true with IPv6 site-local unicast address") } - if isSupportedIPv6(net.ParseIP("fe80::2333")) { - t.Errorf("isSupportedIPv6 return true with IPv6 link-local address") + if !isSupportedIPv6Partial(net.ParseIP("fe80::2333")) { + t.Errorf("isSupportedIPv6Partial returned false with IPv6 link-local address") } - if isSupportedIPv6(net.ParseIP("ff02::2333")) { - t.Errorf("isSupportedIPv6 return true with IPv6 link-local multicast address") + if !isSupportedIPv6Partial(net.ParseIP("ff02::2333")) { + t.Errorf("isSupportedIPv6Partial returned false with IPv6 link-local multicast address") } - if !isSupportedIPv6(net.ParseIP("2001::1")) { - t.Errorf("isSupportedIPv6 return false with IPv6 global unicast address") + if !isSupportedIPv6Partial(net.ParseIP("2001::1")) { + t.Errorf("isSupportedIPv6Partial returned false with IPv6 global unicast address") } } func TestCreateAddr(t *testing.T) { - ipv4 := net.IP{127, 0, 0, 1} - ipv6 := net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} + ipv4 := mustAddr(t, net.IP{127, 0, 0, 1}) + ipv6 := mustAddr(t, net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}) port := 9000 - require.Equal(t, &net.UDPAddr{IP: ipv4, Port: port}, createAddr(NetworkTypeUDP4, ipv4, port)) - require.Equal(t, &net.UDPAddr{IP: ipv6, Port: port}, createAddr(NetworkTypeUDP6, ipv6, port)) - require.Equal(t, &net.TCPAddr{IP: ipv4, Port: port}, createAddr(NetworkTypeTCP4, ipv4, port)) - require.Equal(t, &net.TCPAddr{IP: ipv6, Port: port}, createAddr(NetworkTypeTCP6, ipv6, port)) + require.Equal(t, &net.UDPAddr{IP: ipv4.AsSlice(), Port: port}, createAddr(NetworkTypeUDP4, ipv4, port)) + require.Equal(t, &net.UDPAddr{IP: ipv6.AsSlice(), Port: port}, createAddr(NetworkTypeUDP6, ipv6, port)) + require.Equal(t, &net.TCPAddr{IP: ipv4.AsSlice(), Port: port}, createAddr(NetworkTypeTCP4, ipv4, port)) + require.Equal(t, &net.TCPAddr{IP: ipv6.AsSlice(), Port: port}, createAddr(NetworkTypeTCP6, ipv6, port)) +} + +func problematicNetworkInterfaces(s string) bool { + defaultDockerBridgeNetwork := strings.Contains(s, "docker") + customDockerBridgeNetwork := strings.Contains(s, "br-") + + // Apple filters + accessPoint := strings.Contains(s, "ap") + appleWirelessDirectLink := strings.Contains(s, "awdl") + appleLowLatencyWLANInterface := strings.Contains(s, "llw") + appleTunnelingInterface := strings.Contains(s, "utun") + return !defaultDockerBridgeNetwork && + !customDockerBridgeNetwork && + !accessPoint && + !appleWirelessDirectLink && + !appleLowLatencyWLANInterface && + !appleTunnelingInterface +} + +func mustAddr(t *testing.T, ip net.IP) netip.Addr { + t.Helper() + addr, ok := netip.AddrFromSlice(ip) + if !ok { + t.Fatal(ipConvertError{ip}) + } + return addr } diff --git a/networktype.go b/networktype.go index 57df1863..376b56c7 100644 --- a/networktype.go +++ b/networktype.go @@ -5,7 +5,7 @@ package ice import ( "fmt" - "net" + "net/netip" "strings" ) @@ -116,18 +116,18 @@ func (t NetworkType) IsIPv6() bool { // determineNetworkType determines the type of network based on // the short network string and an IP address. -func determineNetworkType(network string, ip net.IP) (NetworkType, error) { - ipv4 := ip.To4() != nil - +func determineNetworkType(network string, ip netip.Addr) (NetworkType, error) { + // we'd rather have an IPv4-mapped IPv6 become IPv4 so that it is usable. + ip = ip.Unmap() switch { case strings.HasPrefix(strings.ToLower(network), udp): - if ipv4 { + if ip.Is4() { return NetworkTypeUDP4, nil } return NetworkTypeUDP6, nil case strings.HasPrefix(strings.ToLower(network), tcp): - if ipv4 { + if ip.Is4() { return NetworkTypeTCP4, nil } return NetworkTypeTCP6, nil diff --git a/networktype_test.go b/networktype_test.go index eb4a2e47..d327af69 100644 --- a/networktype_test.go +++ b/networktype_test.go @@ -45,7 +45,7 @@ func TestNetworkTypeParsing_Success(t *testing.T) { NetworkTypeUDP6, }, } { - actual, err := determineNetworkType(test.inNetwork, test.inIP) + actual, err := determineNetworkType(test.inNetwork, mustAddr(t, test.inIP)) if err != nil { t.Errorf("NetworkTypeParsing failed: %v", err) } @@ -70,7 +70,7 @@ func TestNetworkTypeParsing_Failure(t *testing.T) { ipv6, }, } { - actual, err := determineNetworkType(test.inNetwork, test.inIP) + actual, err := determineNetworkType(test.inNetwork, mustAddr(t, test.inIP)) if err == nil { t.Errorf("NetworkTypeParsing should fail: '%s' -- input:%s actual:%s", test.name, test.inNetwork, actual) diff --git a/tcp_mux.go b/tcp_mux.go index c5608b39..dfedad1b 100644 --- a/tcp_mux.go +++ b/tcp_mux.go @@ -142,6 +142,7 @@ func (m *TCPMuxDefault) createConn(ufrag string, isIPv6 bool, local net.IP, from return nil, ErrGetTransportAddress } localAddr := *addr + // Note: this is missing zone for IPv6 localAddr.IP = local var alive time.Duration @@ -169,13 +170,15 @@ func (m *TCPMuxDefault) createConn(ufrag string, isIPv6 bool, local net.IP, from m.connsIPv4[ufrag] = conns } } - conns[ipAddr(local.String())] = conn + // Note: this is missing zone for IPv6 + connKey := ipAddr(local.String()) + conns[connKey] = conn m.wg.Add(1) go func() { defer m.wg.Done() <-conn.CloseChannel() - m.removeConnByUfragAndLocalHost(ufrag, local) + m.removeConnByUfragAndLocalHost(ufrag, connKey) }() return conn, nil @@ -259,6 +262,7 @@ func (m *TCPMuxDefault) handleConn(conn net.Conn) { return } m.mu.Lock() + packetConn, ok := m.getConn(ufrag, isIPv6, localAddr.IP) if !ok { packetConn, err = m.createConn(ufrag, isIPv6, localAddr.IP, true) @@ -334,15 +338,14 @@ func (m *TCPMuxDefault) RemoveConnByUfrag(ufrag string) { } } -func (m *TCPMuxDefault) removeConnByUfragAndLocalHost(ufrag string, local net.IP) { +func (m *TCPMuxDefault) removeConnByUfragAndLocalHost(ufrag string, localIPAddr ipAddr) { removedConns := make([]*tcpPacketConn, 0, 4) - localIP := ipAddr(local.String()) // Keep lock section small to avoid deadlock with conn lock m.mu.Lock() if conns, ok := m.connsIPv4[ufrag]; ok { - if conn, ok := conns[localIP]; ok { - delete(conns, localIP) + if conn, ok := conns[localIPAddr]; ok { + delete(conns, localIPAddr) if len(conns) == 0 { delete(m.connsIPv4, ufrag) } @@ -350,8 +353,8 @@ func (m *TCPMuxDefault) removeConnByUfragAndLocalHost(ufrag string, local net.IP } } if conns, ok := m.connsIPv6[ufrag]; ok { - if conn, ok := conns[localIP]; ok { - delete(conns, localIP) + if conn, ok := conns[localIPAddr]; ok { + delete(conns, localIPAddr) if len(conns) == 0 { delete(m.connsIPv6, ufrag) } @@ -375,7 +378,9 @@ func (m *TCPMuxDefault) getConn(ufrag string, isIPv6 bool, local net.IP) (val *t conns, ok = m.connsIPv4[ufrag] } if conns != nil { - val, ok = conns[ipAddr(local.String())] + // Note: this is missing zone for IPv6 + connKey := ipAddr(local.String()) + val, ok = conns[connKey] } return diff --git a/tcp_mux_multi.go b/tcp_mux_multi.go index e32acbf3..71fc570a 100644 --- a/tcp_mux_multi.go +++ b/tcp_mux_multi.go @@ -3,7 +3,9 @@ package ice -import "net" +import ( + "net" +) // AllConnsGetter allows multiple fixed TCP ports to be used, // each of which is multiplexed like TCPMux. AllConnsGetter also acts as diff --git a/tcp_mux_test.go b/tcp_mux_test.go index c6d9748c..dc8dd8e3 100644 --- a/tcp_mux_test.go +++ b/tcp_mux_test.go @@ -62,7 +62,10 @@ func TestTCPMux_Recv(t *testing.T) { n, err := writeStreamingPacket(conn, msg.Raw) require.NoError(t, err, "error writing TCP STUN packet") - pktConn, err := tcpMux.GetConnByUfrag("myufrag", false, listener.Addr().(*net.TCPAddr).IP) + listenerAddr, ok := listener.Addr().(*net.TCPAddr) + require.True(t, ok) + + pktConn, err := tcpMux.GetConnByUfrag("myufrag", false, listenerAddr.IP) require.NoError(t, err, "error retrieving muxed connection for ufrag") defer func() { _ = pktConn.Close() @@ -111,12 +114,15 @@ func TestTCPMux_NoDeadlockWhenClosingUnusedPacketConn(t *testing.T) { _ = tcpMux.Close() }() - _, err = tcpMux.GetConnByUfrag("test", false, listener.Addr().(*net.TCPAddr).IP) + listenerAddr, ok := listener.Addr().(*net.TCPAddr) + require.True(t, ok) + + _, err = tcpMux.GetConnByUfrag("test", false, listenerAddr.IP) require.NoError(t, err, "error getting conn by ufrag") require.NoError(t, tcpMux.Close(), "error closing tcpMux") - conn, err := tcpMux.GetConnByUfrag("test", false, listener.Addr().(*net.TCPAddr).IP) + conn, err := tcpMux.GetConnByUfrag("test", false, listenerAddr.IP) require.Nil(t, conn, "should receive nil because mux is closed") require.Equal(t, io.ErrClosedPipe, err, "should receive error because mux is closed") } @@ -231,7 +237,10 @@ func TestTCPMux_NoLeakForConnectionFromStun(t *testing.T) { // wait for the connection to be created time.Sleep(100 * time.Millisecond) - pktConn, err := tcpMux.GetConnByUfrag("myufrag2", false, listener.Addr().(*net.TCPAddr).IP) + listenerAddr, ok := listener.Addr().(*net.TCPAddr) + require.True(t, ok) + + pktConn, err := tcpMux.GetConnByUfrag("myufrag2", false, listenerAddr.IP) require.NoError(t, err, "error retrieving muxed connection for ufrag") defer func() { _ = pktConn.Close() diff --git a/transport_test.go b/transport_test.go index 216ee8b7..c7907a17 100644 --- a/transport_test.go +++ b/transport_test.go @@ -9,6 +9,7 @@ package ice import ( "context" "net" + "net/netip" "sync" "testing" "time" @@ -175,13 +176,20 @@ func gatherAndExchangeCandidates(aAgent, bAgent *Agent) { candidates, err := aAgent.GetLocalCandidates() check(err) + for _, c := range candidates { + if addr, parseErr := netip.ParseAddr(c.Address()); parseErr == nil { + if shouldFilterLocationTrackedIP(addr) { + panic(addr) + } + } candidateCopy, copyErr := c.copy() check(copyErr) check(bAgent.AddRemoteCandidate(candidateCopy)) } candidates, err = bAgent.GetLocalCandidates() + check(err) for _, c := range candidates { candidateCopy, copyErr := c.copy() diff --git a/udp_mux.go b/udp_mux.go index dc45458f..40e8ad9b 100644 --- a/udp_mux.go +++ b/udp_mux.go @@ -69,19 +69,19 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { } var localAddrsForUnspecified []net.Addr - if addr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok { + if udpAddr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok { params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr()) - } else if ok && addr.IP.IsUnspecified() { + } else if ok && udpAddr.IP.IsUnspecified() { // For unspecified addresses, the correct behavior is to return errListenUnspecified, but // it will break the applications that are already using unspecified UDP connection // with UDPMuxDefault, so print a warn log and create a local address list for mux. params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") var networks []NetworkType switch { - case addr.IP.To4() != nil: + case udpAddr.IP.To4() != nil: networks = []NetworkType{NetworkTypeUDP4} - case addr.IP.To16() != nil: + case udpAddr.IP.To16() != nil: networks = []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6} default: @@ -95,10 +95,14 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { } } - ips, err := localInterfaces(params.Net, nil, nil, networks, true) + _, addrs, err := localInterfaces(params.Net, nil, nil, networks, true) if err == nil { - for _, ip := range ips { - localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port}) + for _, addr := range addrs { + localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{ + IP: addr.AsSlice(), + Port: udpAddr.Port, + Zone: addr.Zone(), + }) } } else { params.Logger.Errorf("Failed to get local interfaces for unspecified addr: %v", err) @@ -304,7 +308,7 @@ func (m *UDPMuxDefault) connWorker() { logger.Errorf("Underlying PacketConn did not return a UDPAddr") return } - udpAddr, err := newIPPort(netUDPAddr.IP, uint16(netUDPAddr.Port)) + udpAddr, err := newIPPort(netUDPAddr.IP, netUDPAddr.Zone, uint16(netUDPAddr.Port)) if err != nil { logger.Errorf("Failed to create a new IP/Port host pair") return @@ -378,14 +382,14 @@ type ipPort struct { // newIPPort create a custom type of address based on netip.Addr and // port. The underlying ip address passed is converted to IPv6 format // to simplify ip address handling -func newIPPort(ip net.IP, port uint16) (ipPort, error) { +func newIPPort(ip net.IP, zone string, port uint16) (ipPort, error) { n, ok := netip.AddrFromSlice(ip.To16()) if !ok { return ipPort{}, errInvalidIPAddress } return ipPort{ - addr: n, + addr: n.WithZone(zone), port: port, }, nil } diff --git a/udp_mux_multi.go b/udp_mux_multi.go index c46db9bf..2594cb02 100644 --- a/udp_mux_multi.go +++ b/udp_mux_multi.go @@ -90,14 +90,18 @@ func NewMultiUDPMuxFromPort(port int, opts ...UDPMuxFromPortOption) (*MultiUDPMu } } - ips, err := localInterfaces(params.net, params.ifFilter, params.ipFilter, params.networks, params.includeLoopback) + _, addrs, err := localInterfaces(params.net, params.ifFilter, params.ipFilter, params.networks, params.includeLoopback) if err != nil { return nil, err } - conns := make([]net.PacketConn, 0, len(ips)) - for _, ip := range ips { - conn, listenErr := params.net.ListenUDP("udp", &net.UDPAddr{IP: ip, Port: port}) + conns := make([]net.PacketConn, 0, len(addrs)) + for _, addr := range addrs { + conn, listenErr := params.net.ListenUDP("udp", &net.UDPAddr{ + IP: addr.AsSlice(), + Port: port, + Zone: addr.Zone(), + }) if listenErr != nil { err = listenErr break diff --git a/udp_mux_multi_test.go b/udp_mux_multi_test.go index bb12022a..f5611be5 100644 --- a/udp_mux_multi_test.go +++ b/udp_mux_multi_test.go @@ -8,7 +8,6 @@ package ice import ( "net" - "strings" "sync" "testing" "time" @@ -114,11 +113,7 @@ func TestUnspecifiedUDPMux(t *testing.T) { defer test.TimeOut(time.Second * 30).Stop() muxPort := 7778 - udpMuxMulti, err := NewMultiUDPMuxFromPort(muxPort, UDPMuxFromPortWithInterfaceFilter(func(s string) bool { - defaultDockerBridgeNetwork := strings.Contains(s, "docker") - customDockerBridgeNetwork := strings.Contains(s, "br-") - return !defaultDockerBridgeNetwork && !customDockerBridgeNetwork - })) + udpMuxMulti, err := NewMultiUDPMuxFromPort(muxPort, UDPMuxFromPortWithInterfaceFilter(problematicNetworkInterfaces)) require.NoError(t, err) require.GreaterOrEqual(t, len(udpMuxMulti.muxes), 1, "at least have 1 muxes") diff --git a/udp_mux_test.go b/udp_mux_test.go index 5f8b1e0d..8b64f651 100644 --- a/udp_mux_test.go +++ b/udp_mux_test.go @@ -50,13 +50,31 @@ func TestUDPMux(t *testing.T) { network string } - for _, subTest := range []testCase{ + testCases := []testCase{ {name: "IPv4loopback", conn: conn4, network: udp4}, {name: "IPv6loopback", conn: conn6, network: udp6}, {name: "Unspecified", conn: connUnspecified, network: udp}, {name: "IPv4Unspecified", conn: conn4Unspecified, network: udp4}, {name: "IPv6Unspecified", conn: conn6Unspecified, network: udp6}, - } { + } + + if ipv6Available(t) { + addr6 := getLocalIPAddress(t, NetworkTypeUDP6) + + conn6Unspecified, listenEerr := net.ListenUDP(udp, &net.UDPAddr{ + IP: addr6.AsSlice(), + Zone: addr6.Zone(), + }) + if listenEerr != nil { + t.Log("IPv6 is not supported on this machine") + } + + testCases = append(testCases, + testCase{name: "IPv6Specified", conn: conn6Unspecified, network: udp6}, + ) + } + + for _, subTest := range testCases { network, conn := subTest.network, subTest.conn if udpConn, ok := conn.(*net.UDPConn); !ok || udpConn == nil { continue diff --git a/udp_muxed_conn.go b/udp_muxed_conn.go index fb05e231..0244d130 100644 --- a/udp_muxed_conn.go +++ b/udp_muxed_conn.go @@ -86,7 +86,7 @@ func (c *udpMuxedConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) { return 0, errFailedToCastUDPAddr } - ipAndPort, err := newIPPort(netUDPAddr.IP, uint16(netUDPAddr.Port)) + ipAndPort, err := newIPPort(netUDPAddr.IP, netUDPAddr.Zone, uint16(netUDPAddr.Port)) if err != nil { return 0, err }