diff --git a/config_example.yml b/config_example.yml new file mode 100644 index 00000000..b62d2a71 --- /dev/null +++ b/config_example.yml @@ -0,0 +1,12 @@ +postgres: + password: 123456789 +log: + level: debug +dht_server: + addr: '::' + query_timeout: 8s +dht_crawler: + bootstrap_nodes: + - 'router.utorrent.com:6881' + - 'router.bittorrent.com:6881' + - "router.silotis.us:6881" \ No newline at end of file diff --git a/internal/protocol/dht/client/server_adapter.go b/internal/protocol/dht/client/server_adapter.go index 4424d2c2..c212d8ee 100644 --- a/internal/protocol/dht/client/server_adapter.go +++ b/internal/protocol/dht/client/server_adapter.go @@ -3,10 +3,11 @@ package client import ( "context" "errors" + "net/netip" + "github.com/bitmagnet-io/bitmagnet/internal/protocol" "github.com/bitmagnet-io/bitmagnet/internal/protocol/dht" "github.com/bitmagnet-io/bitmagnet/internal/protocol/dht/server" - "net/netip" ) type serverAdapter struct { @@ -92,13 +93,16 @@ func (a serverAdapter) SampleInfoHashes(ctx context.Context, addr netip.AddrPort } func extractNodes(msg dht.Msg) []NodeInfo { - if len(msg.R.Nodes) == 0 { + if len(msg.R.Nodes)+len(msg.R.Nodes6) == 0 { return nil } - nodes := make([]NodeInfo, 0, len(msg.R.Nodes)) + nodes := make([]NodeInfo, 0, len(msg.R.Nodes)+len(msg.R.Nodes6)) for _, n := range msg.R.Nodes { nodes = append(nodes, NodeInfo{ID: n.ID, Addr: n.Addr.ToAddrPort()}) } + for _, n6 := range msg.R.Nodes6 { + nodes = append(nodes, NodeInfo{ID: n6.ID, Addr: n6.Addr.ToAddrPort()}) + } return nodes } diff --git a/internal/protocol/dht/nodeaddr.go b/internal/protocol/dht/nodeaddr.go index 09b24fd8..5d117fef 100644 --- a/internal/protocol/dht/nodeaddr.go +++ b/internal/protocol/dht/nodeaddr.go @@ -3,10 +3,11 @@ package dht import ( "bytes" "encoding/binary" - "github.com/anacrolix/torrent/bencode" "net" "net/netip" "strconv" + + "github.com/anacrolix/torrent/bencode" ) type NodeAddr struct { diff --git a/internal/protocol/dht/nodeinfo.go b/internal/protocol/dht/nodeinfo.go index a96ab04f..ce2619fc 100644 --- a/internal/protocol/dht/nodeinfo.go +++ b/internal/protocol/dht/nodeinfo.go @@ -5,13 +5,14 @@ import ( "encoding" "encoding/binary" "fmt" - "github.com/anacrolix/missinggo/v2/slices" - "github.com/anacrolix/torrent/bencode" - "github.com/bitmagnet-io/bitmagnet/internal/protocol" "math" "math/rand" "net" "reflect" + + "github.com/anacrolix/missinggo/v2/slices" + "github.com/anacrolix/torrent/bencode" + "github.com/bitmagnet-io/bitmagnet/internal/protocol" ) type NodeInfo struct { diff --git a/internal/protocol/dht/responder/responder.go b/internal/protocol/dht/responder/responder.go index 194506f0..713f15ad 100644 --- a/internal/protocol/dht/responder/responder.go +++ b/internal/protocol/dht/responder/responder.go @@ -4,10 +4,11 @@ import ( "context" "crypto/md5" "encoding/hex" + "net/netip" + "github.com/bitmagnet-io/bitmagnet/internal/protocol" "github.com/bitmagnet-io/bitmagnet/internal/protocol/dht" "github.com/bitmagnet-io/bitmagnet/internal/protocol/dht/ktable" - "net/netip" ) type Responder interface { @@ -56,7 +57,8 @@ func (r responder) Respond(_ context.Context, msg dht.RecvMsg) (ret dht.Return, return } closestNodes := r.kTable.GetClosestNodes(args.Target) - ret.Nodes = nodeInfosFromNodes(closestNodes...) + + ret.Nodes, ret.Nodes6 = nodeInfosFromNodes(closestNodes...) case dht.QGetPeers: if args.InfoHash == [20]byte{} { err = ErrMissingArguments @@ -71,7 +73,7 @@ func (r responder) Respond(_ context.Context, msg dht.RecvMsg) (ret dht.Return, } ret.Values = values } - ret.Nodes = nodeInfosFromNodes(result.ClosestNodes...) + ret.Nodes, ret.Nodes6 = nodeInfosFromNodes(result.ClosestNodes...) token := r.announceToken(args.InfoHash, args.ID, msg.From.Addr()) ret.Token = &token case dht.QAnnouncePeer: @@ -93,7 +95,7 @@ func (r responder) Respond(_ context.Context, msg dht.RecvMsg) (ret dht.Return, samples = append(samples, h.ID()) } ret.Samples = &samples - ret.Nodes = nodeInfosFromNodes(result.Nodes...) + ret.Nodes, ret.Nodes6 = nodeInfosFromNodes(result.Nodes...) numInt64 := int64(result.TotalHashes) ret.Num = &numInt64 ret.Interval = &r.sampleInfoHashesInterval @@ -122,15 +124,41 @@ func (r responder) announceToken(infoHash protocol.ID, nodeID protocol.ID, nodeA return hex.EncodeToString(tokenHash[:]) } -func nodeInfosFromNodes(ns ...ktable.Node) []dht.NodeInfo { +func nodeInfosFromNodes(ns ...ktable.Node) ([]dht.NodeInfo, []dht.NodeInfo) { if len(ns) == 0 { - return nil + return nil, nil } - nodes := make([]dht.NodeInfo, 0, len(ns)) + ns_count, ns6_count := 0, 0 for _, n := range ns { - nodes = append(nodes, nodeInfoFromNode(n)) + if n.Addr().Addr().Is4() { + ns_count += 1 + } + } + for _, n := range ns { + if n.Addr().Addr().Is6() || n.Addr().Addr().Is4In6() { + ns6_count += 1 + } + } + nodes6 := make([]dht.NodeInfo, 0, ns_count) + + nodes := make([]dht.NodeInfo, 0, ns6_count) + for _, n := range ns { + if n.Addr().Addr().Is4() { + nodes = append(nodes, nodeInfoFromNode(n)) + } + } + for _, n := range ns { + if n.Addr().Addr().Is6() || n.Addr().Addr().Is4In6() { + nodes6 = append(nodes6, nodeInfoFromNode(n)) + } + } + if len(nodes) == 0 { + nodes = nil + } + if len(nodes6) == 0 { + nodes6 = nil } - return nodes + return nodes, nodes6 } func nodeInfoFromNode(n ktable.Node) dht.NodeInfo { diff --git a/internal/protocol/dht/responder/responder_test.go b/internal/protocol/dht/responder/responder_test.go index 4267a89f..7cd957cf 100644 --- a/internal/protocol/dht/responder/responder_test.go +++ b/internal/protocol/dht/responder/responder_test.go @@ -2,15 +2,16 @@ package responder import ( "context" + "net/netip" + "testing" + "time" + "github.com/anacrolix/dht/v2/krpc" "github.com/bitmagnet-io/bitmagnet/internal/protocol" "github.com/bitmagnet-io/bitmagnet/internal/protocol/dht" "github.com/bitmagnet-io/bitmagnet/internal/protocol/dht/ktable" - "github.com/bitmagnet-io/bitmagnet/internal/protocol/dht/ktable/mocks" + ktable_mocks "github.com/bitmagnet-io/bitmagnet/internal/protocol/dht/ktable/mocks" "github.com/stretchr/testify/assert" - "net/netip" - "testing" - "time" ) type testResponderMocks struct { @@ -145,6 +146,47 @@ func TestResponder_find_node(t *testing.T) { assert.NoError(t, err) } +func TestResponder_find_node_ipv6(t *testing.T) { + mocks := newTestResponderMocks(t) + target := protocol.RandomNodeID() + msg := dht.RecvMsg{ + From: mocks.sender.Addr.ToAddrPort(), + Msg: dht.Msg{ + Q: "find_node", + A: &dht.MsgArgs{ + ID: mocks.sender.ID, + Target: target, + }, + }, + } + nodes6 := dht.CompactIPv6NodeInfo{ + dht.RandomNodeInfo(16), + dht.RandomNodeInfo(16), + dht.RandomNodeInfo(16), + } + nodes := dht.CompactIPv4NodeInfo{ + dht.RandomNodeInfo(4), + dht.RandomNodeInfo(4), + dht.RandomNodeInfo(4), + } + peers := []ktable.Node{ + mockedPeer{nodes[0]}, + mockedPeer{nodes[1]}, + mockedPeer{nodes[2]}, + mockedPeer{nodes6[0]}, + mockedPeer{nodes6[1]}, + mockedPeer{nodes6[2]}, + } + mocks.table.On("GetClosestNodes", target).Return(peers) + ret, err := mocks.responder.Respond(context.Background(), msg) + assert.Equal(t, dht.Return{ + ID: mocks.nodeID, + Nodes6: nodes6, + Nodes: nodes, + }, ret) + assert.NoError(t, err) +} + func TestResponder_find_node__missing_target(t *testing.T) { mocks := newTestResponderMocks(t) msg := dht.RecvMsg{ diff --git a/internal/protocol/dht/server/config.go b/internal/protocol/dht/server/config.go index 09b6d93b..99cfa266 100644 --- a/internal/protocol/dht/server/config.go +++ b/internal/protocol/dht/server/config.go @@ -5,11 +5,13 @@ import "time" type Config struct { Port uint16 QueryTimeout time.Duration + Addr string } func NewDefaultConfig() Config { return Config{ Port: 3334, QueryTimeout: time.Second * 4, + Addr: "0.0.0.0", } } diff --git a/internal/protocol/dht/server/factory.go b/internal/protocol/dht/server/factory.go index 812080e0..22b2c60c 100644 --- a/internal/protocol/dht/server/factory.go +++ b/internal/protocol/dht/server/factory.go @@ -3,6 +3,9 @@ package server import ( "context" "fmt" + "net/netip" + "time" + "github.com/bitmagnet-io/bitmagnet/internal/boilerplate/lazy" "github.com/bitmagnet-io/bitmagnet/internal/concurrency" "github.com/bitmagnet-io/bitmagnet/internal/protocol/dht" @@ -11,8 +14,6 @@ import ( "go.uber.org/fx" "go.uber.org/zap" "golang.org/x/time/rate" - "net/netip" - "time" ) type Params struct { @@ -39,6 +40,18 @@ const subsystem = "dht_server" func New(p Params) Result { lastResponses := &concurrency.AtomicValue[LastResponses]{} collector := newPrometheusCollector() + addr, err := netip.ParseAddr(p.Config.Addr) + socket_ip_type := 4 + + if err != nil { + addr = netip.IPv4Unspecified() + } + if addr.Is4() { + socket_ip_type = 4 + } + if addr.Is6() || addr.Is4In6() { + socket_ip_type = 6 + } ls := lazy.New(func() (Server, error) { s := queryLimiter{ server: prometheusServerWrapper{ @@ -46,8 +59,8 @@ func New(p Params) Result { server: healthCollector{ baseServer: &server{ stopped: make(chan struct{}), - localAddr: netip.AddrPortFrom(netip.IPv4Unspecified(), p.Config.Port), - socket: NewSocket(), + localAddr: netip.AddrPortFrom(addr, p.Config.Port), + socket: NewSocket(socket_ip_type), queries: make(map[string]chan dht.RecvMsg), queryTimeout: p.Config.QueryTimeout, responder: p.Responder, diff --git a/internal/protocol/dht/server/socket.go b/internal/protocol/dht/server/socket.go index 928c17b5..5570c6d3 100644 --- a/internal/protocol/dht/server/socket.go +++ b/internal/protocol/dht/server/socket.go @@ -11,6 +11,6 @@ type Socket interface { Receive([]byte) (int, netip.AddrPort, error) } -func NewSocket() Socket { - return newSocket() +func NewSocket(ip_type int) Socket { + return newSocket(ip_type) } diff --git a/internal/protocol/dht/server/socket_unix.go b/internal/protocol/dht/server/socket_unix.go index 4f9eec2f..8ac543ee 100644 --- a/internal/protocol/dht/server/socket_unix.go +++ b/internal/protocol/dht/server/socket_unix.go @@ -10,8 +10,14 @@ import ( "golang.org/x/sys/unix" ) -func newSocket() Socket { - fd, sockErr := unix.Socket(unix.SOCK_DGRAM, unix.AF_INET, 0) +func newSocket(ip_type int) Socket { + var fd int + var sockErr error + if ip_type == 4 { + fd, sockErr = unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0) + } else if ip_type == 6 { + fd, sockErr = unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, 0) + } if sockErr != nil { panic(fmt.Errorf("error creating socket: %w", sockErr)) } diff --git a/internal/protocol/metainfo/metainforequester/requester.go b/internal/protocol/metainfo/metainforequester/requester.go index af755d69..e531a77d 100644 --- a/internal/protocol/metainfo/metainforequester/requester.go +++ b/internal/protocol/metainfo/metainforequester/requester.go @@ -6,15 +6,16 @@ import ( "encoding/binary" "errors" "fmt" - "github.com/anacrolix/torrent/bencode" - "github.com/anacrolix/torrent/peer_protocol" - "github.com/bitmagnet-io/bitmagnet/internal/protocol" - "github.com/bitmagnet-io/bitmagnet/internal/protocol/metainfo" "io" "math" "net" "net/netip" "time" + + "github.com/anacrolix/torrent/bencode" + "github.com/anacrolix/torrent/peer_protocol" + "github.com/bitmagnet-io/bitmagnet/internal/protocol" + "github.com/bitmagnet-io/bitmagnet/internal/protocol/metainfo" ) type Requester interface { @@ -111,7 +112,14 @@ func (r requester) Request(ctx context.Context, infoHash protocol.ID, addr netip } func (r requester) connect(ctx context.Context, addr netip.AddrPort) (conn *net.TCPConn, err error) { - c, dialErr := r.dialer.DialContext(ctx, "tcp4", addr.String()) + tcp := "tcp6" + if addr.Addr().Is4() { + tcp = "tcp4" + } + if addr.Addr().Is6() || addr.Addr().Is4In6() { + tcp = "tcp6" + } + c, dialErr := r.dialer.DialContext(ctx, tcp, addr.String()) if dialErr != nil { err = dialErr return