Skip to content

Commit

Permalink
Initial netip implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
mrmelon54 committed Feb 3, 2025
1 parent a9ab227 commit f197857
Show file tree
Hide file tree
Showing 12 changed files with 141 additions and 177 deletions.
41 changes: 22 additions & 19 deletions client_integration_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package wgctrl_test

import (
"bytes"
"errors"
"fmt"
"net"
"net/netip"
"os"
"sort"
"strings"
Expand Down Expand Up @@ -144,9 +144,9 @@ func testGet(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) {
func testConfigure(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) {
var (
port = 8888
ips = []net.IPNet{
wgtest.MustCIDR("192.0.2.0/32"),
wgtest.MustCIDR("2001:db8::/128"),
ips = []netip.Prefix{
netip.MustParsePrefix("192.0.2.0/32"),
netip.MustParsePrefix("2001:db8::/128"),
}

priv = wgtest.MustPrivateKey()
Expand Down Expand Up @@ -194,7 +194,7 @@ func testConfigure(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) {
for i := range dn.Peers {
ips := dn.Peers[i].AllowedIPs
sort.Slice(ips, func(i, j int) bool {
return bytes.Compare(ips[i].IP, ips[j].IP) > 0
return ips[i].Addr().Compare(ips[j].Addr()) > 0
})
}

Expand Down Expand Up @@ -229,17 +229,19 @@ func testConfigureManyIPs(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) {
t.Fatalf("failed to create cursor: %v", err)
}

var ips []net.IPNet
var ips []netip.Prefix
for pos := cur.Next(); pos != nil; pos = cur.Next() {
bits := 128
if pos.IP.To4() != nil {
bits = 32
}

ips = append(ips, net.IPNet{
IP: pos.IP,
Mask: net.CIDRMask(bits, bits),
})
addr, ok := netip.AddrFromSlice(pos.IP)
if !ok {
t.Fatalf("failed to convert net.IP to netip.Addr: %s", pos.IP)
}

ips = append(ips, netip.PrefixFrom(addr, bits))
}

peers = append(peers, wgtypes.PeerConfig{
Expand Down Expand Up @@ -291,7 +293,7 @@ func testConfigureManyPeers(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) {
PresharedKey: &pk,
ReplaceAllowedIPs: true,
Endpoint: &net.UDPAddr{
IP: ips[0].IP,
IP: ips[0].Addr().AsSlice(),
Port: 1111,
},
PersistentKeepaliveInterval: &dur,
Expand Down Expand Up @@ -370,7 +372,6 @@ func testConfigurePeersUpdateOnly(t *testing.T, c *wgctrl.Client, d *wgtypes.Dev
t.Skip("FreeBSD kernel devices do not support UpdateOnly flag")
}


t.Fatalf("failed to configure second time on %q: %v", d.Name, err)
}

Expand Down Expand Up @@ -428,7 +429,7 @@ func countPeerIPs(d *wgtypes.Device) int {
return count
}

func ipsString(ipns []net.IPNet) string {
func ipsString(ipns []netip.Prefix) string {
ss := make([]string, 0, len(ipns))
for _, ipn := range ipns {
ss = append(ss, ipn.String())
Expand All @@ -437,23 +438,25 @@ func ipsString(ipns []net.IPNet) string {
return strings.Join(ss, ", ")
}

func generateIPs(n int) []net.IPNet {
func generateIPs(n int) []netip.Prefix {
cur, err := ipaddr.Parse("2001:db8::/64")
if err != nil {
panicf("failed to create cursor: %v", err)
}

ips := make([]net.IPNet, 0, n)
ips := make([]netip.Prefix, 0, n)
for i := 0; i < n; i++ {
pos := cur.Next()
if pos == nil {
panic("hit nil IP during IP generation")
}

ips = append(ips, net.IPNet{
IP: pos.IP,
Mask: net.CIDRMask(128, 128),
})
addr, ok := netip.AddrFromSlice(pos.IP)
if !ok {
panicf("failed to convert net.IP to netip.Addr: %s", pos.IP)
}

ips = append(ips, netip.PrefixFrom(addr, 128))
}

return ips
Expand Down
4 changes: 2 additions & 2 deletions cmd/wgctrl/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"flag"
"fmt"
"log"
"net"
"net/netip"
"strings"

"golang.zx2c4.com/wireguard/wgctrl"
Expand Down Expand Up @@ -83,7 +83,7 @@ func printPeer(p wgtypes.Peer) {
)
}

func ipsString(ipns []net.IPNet) string {
func ipsString(ipns []netip.Prefix) string {
ss := make([]string, 0, len(ipns))
for _, ipn := range ipns {
ss = append(ss, ipn.String())
Expand Down
4 changes: 2 additions & 2 deletions internal/wglinux/client_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ package wglinux
import (
"errors"
"fmt"
"net"
"net/netip"
"os"
"os/user"
"syscall"
Expand Down Expand Up @@ -325,7 +325,7 @@ func diffAttrs(x, y []netlink.Attribute) string {
return cmp.Diff(xPrime, yPrime)
}

func mustAllowedIPs(ipns []net.IPNet) []byte {
func mustAllowedIPs(ipns []netip.Prefix) []byte {
ae := netlink.NewAttributeEncoder()
if err := encodeAllowedIPs(ipns)(ae); err != nil {
panicf("failed to create allowed IP attributes: %v", err)
Expand Down
46 changes: 15 additions & 31 deletions internal/wglinux/configure_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/binary"
"fmt"
"net"
"net/netip"
"unsafe"

"github.com/mdlayher/netlink"
Expand Down Expand Up @@ -101,16 +102,16 @@ func buildBatches(cfg wgtypes.Config) []wgtypes.Config {
// Iterate until no more allowed IPs.
var done bool
for !done {
var tmp []net.IPNet
var tmp []netip.Prefix
if len(p.AllowedIPs) < ipBatchChunk {
// IPs all fit within a batch; we are done.
tmp = make([]net.IPNet, len(p.AllowedIPs))
tmp = make([]netip.Prefix, len(p.AllowedIPs))
copy(tmp, p.AllowedIPs)
done = true
} else {
// IPs are larger than a single batch, copy a batch out and
// advance the cursor.
tmp = make([]net.IPNet, ipBatchChunk)
tmp = make([]netip.Prefix, ipBatchChunk)
copy(tmp, p.AllowedIPs[:ipBatchChunk])

p.AllowedIPs = p.AllowedIPs[ipBatchChunk:]
Expand Down Expand Up @@ -214,59 +215,52 @@ func encodePeer(p wgtypes.PeerConfig) func(ae *netlink.AttributeEncoder) error {
// sockaddr_in or sockaddr_in6 bytes.
func encodeSockaddr(endpoint net.UDPAddr) func() ([]byte, error) {
return func() ([]byte, error) {
if !isValidIP(endpoint.IP) {
addrPort := endpoint.AddrPort()
if !addrPort.Addr().IsValid() {
return nil, fmt.Errorf("wglinux: invalid endpoint IP: %s", endpoint.IP.String())
}

// Is this an IPv6 address?
if isIPv6(endpoint.IP) {
var addr [16]byte
copy(addr[:], endpoint.IP.To16())

if addrPort.Addr().Is6() {
sa := unix.RawSockaddrInet6{
Family: unix.AF_INET6,
Port: sockaddrPort(endpoint.Port),
Addr: addr,
Addr: addrPort.Addr().As16(),
}

return (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:], nil
}

// IPv4 address handling.
var addr [4]byte
copy(addr[:], endpoint.IP.To4())

sa := unix.RawSockaddrInet4{
Family: unix.AF_INET,
Port: sockaddrPort(endpoint.Port),
Addr: addr,
Addr: addrPort.Addr().As4(),
}

return (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:], nil
}
}

// encodeAllowedIPs returns a function to encode allowed IP nested attributes.
func encodeAllowedIPs(ipns []net.IPNet) func(ae *netlink.AttributeEncoder) error {
func encodeAllowedIPs(ipns []netip.Prefix) func(ae *netlink.AttributeEncoder) error {
return func(ae *netlink.AttributeEncoder) error {
for i, ipn := range ipns {
if !isValidIP(ipn.IP) {
return fmt.Errorf("wglinux: invalid allowed IP: %s", ipn.IP.String())
if !ipn.Addr().IsValid() {
return fmt.Errorf("wglinux: invalid allowed IP: %s", ipn.Addr())
}

family := uint16(unix.AF_INET6)
if !isIPv6(ipn.IP) {
if ipn.Addr().Is4() {
// Make sure address is 4 bytes if IPv4.
family = unix.AF_INET
ipn.IP = ipn.IP.To4()
}

// Netlink arrays use type as an array index.
ae.Nested(uint16(i), func(nae *netlink.AttributeEncoder) error {
nae.Uint16(unix.WGALLOWEDIP_A_FAMILY, family)
nae.Bytes(unix.WGALLOWEDIP_A_IPADDR, ipn.IP)
nae.Bytes(unix.WGALLOWEDIP_A_IPADDR, ipn.Addr().AsSlice())

ones, _ := ipn.Mask.Size()
ones := ipn.Bits()
nae.Uint8(unix.WGALLOWEDIP_A_CIDR_MASK, uint8(ones))
return nil
})
Expand All @@ -276,16 +270,6 @@ func encodeAllowedIPs(ipns []net.IPNet) func(ae *netlink.AttributeEncoder) error
}
}

// isValidIP determines if IP is a valid IPv4 or IPv6 address.
func isValidIP(ip net.IP) bool {
return ip.To16() != nil
}

// isIPv6 determines if IP is a valid IPv6 address.
func isIPv6(ip net.IP) bool {
return isValidIP(ip) && ip.To4() == nil
}

// sockaddrPort interprets port as a big endian uint16 for use passing sockaddr
// structures to the kernel.
func sockaddrPort(port int) uint16 {
Expand Down
49 changes: 26 additions & 23 deletions internal/wglinux/configure_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package wglinux

import (
"net"
"net/netip"
"testing"
"time"
"unsafe"
Expand Down Expand Up @@ -45,9 +46,9 @@ func TestLinuxClientConfigureDevice(t *testing.T) {
name: "bad peer allowed IP",
cfg: wgtypes.Config{
Peers: []wgtypes.PeerConfig{{
AllowedIPs: []net.IPNet{{
IP: net.IP{0xff},
}},
AllowedIPs: []netip.Prefix{
{},
},
}},
},
},
Expand All @@ -71,8 +72,8 @@ func TestLinuxClientConfigureDevice(t *testing.T) {
PresharedKey: keyPtr(wgtest.MustHexKey("188515093e952f5f22e865cef3012e72f8b5f0b598ac0309d5dacce3b70fcf52")),
Endpoint: wgtest.MustUDPAddr("[abcd:23::33%2]:51820"),
ReplaceAllowedIPs: true,
AllowedIPs: []net.IPNet{
wgtest.MustCIDR("192.168.4.4/32"),
AllowedIPs: []netip.Prefix{
netip.MustParsePrefix("192.168.4.4/32"),
},
},
{
Expand All @@ -81,17 +82,17 @@ func TestLinuxClientConfigureDevice(t *testing.T) {
Endpoint: wgtest.MustUDPAddr("182.122.22.19:3233"),
PersistentKeepaliveInterval: durPtr(111 * time.Second),
ReplaceAllowedIPs: true,
AllowedIPs: []net.IPNet{
wgtest.MustCIDR("192.168.4.6/32"),
AllowedIPs: []netip.Prefix{
netip.MustParsePrefix("192.168.4.6/32"),
},
},
{
PublicKey: wgtest.MustHexKey("662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58"),
Endpoint: wgtest.MustUDPAddr("5.152.198.39:51820"),
ReplaceAllowedIPs: true,
AllowedIPs: []net.IPNet{
wgtest.MustCIDR("192.168.4.10/32"),
wgtest.MustCIDR("192.168.4.11/32"),
AllowedIPs: []netip.Prefix{
netip.MustParsePrefix("192.168.4.10/32"),
netip.MustParsePrefix("192.168.4.11/32"),
},
},
{
Expand Down Expand Up @@ -151,8 +152,8 @@ func TestLinuxClientConfigureDevice(t *testing.T) {
},
{
Type: netlink.Nested | unix.WGPEER_A_ALLOWEDIPS,
Data: mustAllowedIPs([]net.IPNet{
wgtest.MustCIDR("192.168.4.4/32"),
Data: mustAllowedIPs([]netip.Prefix{
netip.MustParsePrefix("192.168.4.4/32"),
}),
},
}...),
Expand Down Expand Up @@ -182,8 +183,8 @@ func TestLinuxClientConfigureDevice(t *testing.T) {
},
{
Type: netlink.Nested | unix.WGPEER_A_ALLOWEDIPS,
Data: mustAllowedIPs([]net.IPNet{
wgtest.MustCIDR("192.168.4.6/32"),
Data: mustAllowedIPs([]netip.Prefix{
netip.MustParsePrefix("192.168.4.6/32"),
}),
},
}...),
Expand All @@ -209,9 +210,9 @@ func TestLinuxClientConfigureDevice(t *testing.T) {
},
{
Type: netlink.Nested | unix.WGPEER_A_ALLOWEDIPS,
Data: mustAllowedIPs([]net.IPNet{
wgtest.MustCIDR("192.168.4.10/32"),
wgtest.MustCIDR("192.168.4.11/32"),
Data: mustAllowedIPs([]netip.Prefix{
netip.MustParsePrefix("192.168.4.10/32"),
netip.MustParsePrefix("192.168.4.11/32"),
}),
},
}...),
Expand Down Expand Up @@ -513,23 +514,25 @@ func keyBytes(s string) []byte {
return k[:]
}

func generateIPs(n int) []net.IPNet {
func generateIPs(n int) []netip.Prefix {
cur, err := ipaddr.Parse("2001:db8::/64")
if err != nil {
panicf("failed to create cursor: %v", err)
}

ips := make([]net.IPNet, 0, n)
ips := make([]netip.Prefix, 0, n)
for i := 0; i < n; i++ {
pos := cur.Next()
if pos == nil {
panic("hit nil IP during IP generation")
}

ips = append(ips, net.IPNet{
IP: pos.IP,
Mask: net.CIDRMask(128, 128),
})
addr, ok := netip.AddrFromSlice(pos.IP)
if !ok {
panicf("failed to convert net.IP to netip.Addr: %s", pos.IP)
}

ips = append(ips, netip.PrefixFrom(addr, 128))
}

return ips
Expand Down
Loading

0 comments on commit f197857

Please sign in to comment.