Skip to content

Commit 47bf93b

Browse files
committed
tcp,udp: realips lack port to find active conns
1 parent 2b43ab1 commit 47bf93b

File tree

4 files changed

+16
-10
lines changed

4 files changed

+16
-10
lines changed

intra/common.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,12 +193,12 @@ func undoAlg(r dnsx.Resolver, algip netip.Addr) (realips, domains, probableDomai
193193
return
194194
}
195195

196-
func hasActiveConn(cm core.ConnMapper, ip, ips string) bool {
196+
func hasActiveConn(cm core.ConnMapper, ipp, ips, port string) bool {
197197
if cm == nil {
198198
return false
199199
}
200200
// TODO: filter by protocol (tcp/udp) when finding conns
201-
return !hasSelfUid(cm.Find(ip), true) || !hasSelfUid(cm.FindAll(ips), true)
201+
return !hasSelfUid(cm.Find(ipp), true) || !hasSelfUid(cm.FindAll(ips, port), true)
202202
}
203203

204204
// returns proxy-id, conn-id, user-id

intra/core/connmap.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ type ConnMapper interface {
2121
Clear() []string
2222
Track(t ConnTuple, x ...net.Conn) int
2323
Find(dst string) (t []ConnTuple)
24-
FindAll(csvdst string) (t []ConnTuple)
24+
FindAll(csvips, port string) (t []ConnTuple)
2525
Get(cid string) []net.Conn
2626
Untrack(cid string) int
2727
UntrackBatch(cids []string) []string
@@ -72,6 +72,7 @@ func (h *cm) trackDstLocked(t ConnTuple, conns []net.Conn) {
7272
}
7373
dst := raddr.String()
7474
if tups, ok := h.dsttracker[dst]; ok {
75+
// TODO: do not add dup tuples (cid)
7576
h.dsttracker[dst] = append(tups, t)
7677
} else {
7778
h.dsttracker[dst] = []ConnTuple{t}
@@ -158,18 +159,19 @@ func (h *cm) Find(dst string) (tups []ConnTuple) {
158159
return
159160
}
160161

161-
func (h *cm) FindAll(csvdst string) (out []ConnTuple) {
162+
func (h *cm) FindAll(csvips, port string) (out []ConnTuple) {
162163
out = make([]ConnTuple, 0)
163164

164-
if len(csvdst) == 0 {
165+
if len(csvips) == 0 {
165166
return
166167
}
167168

168169
h.RLock()
169170
defer h.RUnlock()
170171

171-
dsts := strings.Split(csvdst, ",")
172+
dsts := strings.Split(csvips, ",")
172173
for _, dst := range dsts {
174+
dst = net.JoinHostPort(dst, port)
173175
if tups, ok := h.dsttracker[dst]; ok {
174176
out = append(out, tups...)
175177
}

intra/tcp.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import (
3030
"fmt"
3131
"net"
3232
"net/netip"
33+
"strconv"
3334
"time"
3435

3536
"github.com/celzero/firestack/intra/dnsx"
@@ -117,7 +118,8 @@ func (h *tcpHandler) onFlow(localaddr, target netip.AddrPort, realips, domains,
117118
var proto int32 = 6 // tcp
118119
src := localaddr.String()
119120
dst := target.String()
120-
dup := hasActiveConn(h.conntracker, dst, realips)
121+
dport := strconv.Itoa(int(target.Port()))
122+
dup := hasActiveConn(h.conntracker, dst, realips, dport)
121123
res := h.listener.Flow(proto, uid, dup, src, dst, realips, domains, probableDomains, blocklists)
122124

123125
if res == nil {
@@ -278,7 +280,7 @@ func (h *tcpHandler) handle(px ipn.Proxy, src net.Conn, target netip.AddrPort, s
278280
if r := recover(); r != nil {
279281
log.W("tcp: forward: panic %v", r)
280282
}
281-
defer h.conntracker.Untrack(ct.CID)
283+
h.conntracker.Untrack(ct.CID)
282284
}()
283285
forward(src, dst, l, smm) // src always *gonet.TCPConn
284286
}()

intra/udp.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import (
3030
"io"
3131
"net"
3232
"net/netip"
33+
"strconv"
3334
"time"
3435

3536
"github.com/celzero/firestack/intra/dnsx"
@@ -118,9 +119,10 @@ func (h *udpHandler) onFlow(localaddr, target netip.AddrPort, realips, domains,
118119
}
119120

120121
src := localaddr.String()
121-
dst := "" // unconnected udp sockets may not have a valid target
122+
dst, dport := "", "" // unconnected udp sockets may not have a valid target
122123
if target.IsValid() {
123124
dst = target.String()
125+
dport = strconv.Itoa(int(target.Port()))
124126
}
125127
if len(realips) <= 0 || len(domains) <= 0 {
126128
log.V("udp: onFlow: no realips(%s) or domains(%s + %s), for src=%s dst=%s", realips, domains, probableDomains, localaddr, dst)
@@ -136,7 +138,7 @@ func (h *udpHandler) onFlow(localaddr, target netip.AddrPort, realips, domains,
136138
}
137139

138140
var proto int32 = 17 // udp
139-
dup := hasActiveConn(h.conntracker, dst, realips)
141+
dup := hasActiveConn(h.conntracker, dst, realips, dport)
140142
res := h.listener.Flow(proto, uid, dup, src, dst, realips, domains, probableDomains, blocklists)
141143

142144
if res == nil {

0 commit comments

Comments
 (0)