Skip to content

Commit 5ddf680

Browse files
committed
connmap: track [destaddr, (cid,uid)]
Track both owner uid and connection id against a destination address to later filter out connection to destination that were attributed to protect.UidSelf (which are connections owned by Rethink). protect.UidSelf is ever useful in Loopback mode (routing Rethink's traffic back within Rethink's tunnel), and so, it mostly contains, among genuine traffic originating from Rethink (like DNS queries to user-selected upstreams, for example), traffic "mirrored" or duplicated (cloned) to the same destination (as it was loopbacked in) belonging to another uid. Tracking both (cid, uid) against a destination addr lets us ignore all protect.UidSelf connections as "non-mirrored".
1 parent bfd4393 commit 5ddf680

File tree

2 files changed

+56
-60
lines changed

2 files changed

+56
-60
lines changed

intra/common.go

Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/celzero/firestack/intra/core"
2020
"github.com/celzero/firestack/intra/dnsx"
2121
"github.com/celzero/firestack/intra/log"
22+
"github.com/celzero/firestack/intra/protect"
2223
)
2324

2425
// pipe copies data from src to dst, and returns the number of bytes copied.
@@ -67,8 +68,10 @@ func download(cid string, local net.Conn, remote net.Conn) (n int64, err error)
6768
// It also sends a summary to the listener when done. Always called in a goroutine.
6869
func forward(local net.Conn, remote net.Conn, t core.ConnMapper, l SocketListener, smm *SocketSummary) {
6970
cid := smm.ID
71+
uid := smm.UID
72+
ct := core.ConnTuple{CID: cid, UID: uid}
7073

71-
t.Track(cid, local, remote)
74+
t.Track(ct, local, remote)
7275
defer t.Untrack(cid)
7376

7477
uploadch := make(chan ioinfo)
@@ -136,14 +139,6 @@ func stall(m *core.ExpMap, k string) (secs uint32) {
136139
return
137140
}
138141

139-
func netipFrom(ip net.IP) *netip.Addr {
140-
if addr, ok := netip.AddrFromSlice(ip); ok {
141-
addr = addr.Unmap()
142-
return &addr
143-
}
144-
return nil
145-
}
146-
147142
func oneRealIp(realips string, origipp netip.AddrPort) netip.AddrPort {
148143
if len(realips) <= 0 {
149144
return origipp
@@ -208,7 +203,7 @@ func hasActiveConn(cm core.ConnMapper, ip, ips string) bool {
208203
return false
209204
}
210205
// TODO: filter by protocol (tcp/udp) when finding conns
211-
return len(cm.Find(ip)) > 0 || len(cm.FindAny(ips)) > 0
206+
return !hasSelfUid(cm.Find(ip), true) || !hasSelfUid(cm.FindAll(ips), true)
212207
}
213208

214209
// returns proxy-id, conn-id, user-id
@@ -227,26 +222,6 @@ func ipp(addr net.Addr) (netip.AddrPort, error) {
227222
return netip.ParseAddrPort(addr.String())
228223
}
229224

230-
func addr2ip(a net.Addr) string {
231-
if a == nil {
232-
return ""
233-
}
234-
switch x := a.(type) {
235-
case *net.TCPAddr:
236-
return x.IP.String()
237-
case *net.UDPAddr:
238-
return x.IP.String()
239-
case *net.IPAddr:
240-
return x.IP.String()
241-
case *net.IPNet:
242-
return x.IP.String()
243-
}
244-
if b, err := netip.ParseAddrPort(a.String()); err == nil {
245-
return b.Addr().String()
246-
}
247-
return ""
248-
}
249-
250225
func conn2str(a net.Conn, b net.Conn) string {
251226
ar := a.RemoteAddr()
252227
br := b.RemoteAddr()
@@ -266,6 +241,18 @@ func closeconns(cm core.ConnMapper, cids []string) (closed []string) {
266241
return closed
267242
}
268243

244+
func hasSelfUid(t []core.ConnTuple, d bool) bool {
245+
if len(t) <= 0 {
246+
return d // default
247+
}
248+
for _, x := range t {
249+
if x.UID == protect.UidSelf {
250+
return true
251+
}
252+
}
253+
return false // regardless of d
254+
}
255+
269256
func clos(c ...net.Conn) {
270257
for _, x := range c {
271258
if x != nil {

intra/core/connmap.go

Lines changed: 39 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,48 +12,55 @@ import (
1212
"sync"
1313
)
1414

15+
type ConnTuple struct {
16+
CID string // conn id
17+
UID string // proc id
18+
}
19+
1520
type ConnMapper interface {
1621
Clear() []string
17-
Track(id string, x ...net.Conn) int
18-
Find(dst string) (ids []string)
19-
FindAny(csvdst string) (ids []string)
20-
Get(id string) []net.Conn
21-
Untrack(id string) int
22-
UntrackBatch(ids []string) []string
22+
Track(t ConnTuple, x ...net.Conn) int
23+
Find(dst string) (t []ConnTuple)
24+
FindAll(csvdst string) (t []ConnTuple)
25+
Get(cid string) []net.Conn
26+
Untrack(cid string) int
27+
UntrackBatch(cids []string) []string
2328
}
2429

2530
type cm struct {
2631
sync.RWMutex
27-
conntracker map[string][]net.Conn // id -> conns
28-
dsttracker map[string][]string // dst ipport -> ids
32+
conntracker map[string][]net.Conn // id -> conns
33+
dsttracker map[string][]ConnTuple // dst ipport -> conntuple
2934
}
3035

3136
var _ ConnMapper = (*cm)(nil)
3237

3338
func NewConnMap() *cm {
3439
return &cm{
3540
conntracker: make(map[string][]net.Conn),
36-
dsttracker: make(map[string][]string),
41+
dsttracker: make(map[string][]ConnTuple),
3742
}
3843
}
3944

40-
func (h *cm) Track(cid string, conns ...net.Conn) (n int) {
45+
func (h *cm) Track(t ConnTuple, conns ...net.Conn) (n int) {
4146
h.Lock()
4247
defer h.Unlock()
4348

49+
cid := t.CID
50+
4451
if v, ok := h.conntracker[cid]; !ok {
4552
h.conntracker[cid] = conns
4653
n = len(conns)
4754
} else { // should not happen?
4855
h.conntracker[cid] = append(v, conns...)
4956
n = len(v) + len(conns)
5057
}
51-
h.trackDstLocked(cid, conns)
58+
h.trackDstLocked(t, conns)
5259

5360
return
5461
}
5562

56-
func (h *cm) trackDstLocked(cid string, conns []net.Conn) {
63+
func (h *cm) trackDstLocked(t ConnTuple, conns []net.Conn) {
5764
for _, c := range conns {
5865
if c == nil {
5966
continue
@@ -63,10 +70,10 @@ func (h *cm) trackDstLocked(cid string, conns []net.Conn) {
6370
continue
6471
}
6572
dst := raddr.String()
66-
if ids, ok := h.dsttracker[dst]; ok {
67-
h.dsttracker[dst] = append(ids, cid)
73+
if tups, ok := h.dsttracker[dst]; ok {
74+
h.dsttracker[dst] = append(tups, t)
6875
} else {
69-
h.dsttracker[dst] = []string{cid}
76+
h.dsttracker[dst] = []ConnTuple{t}
7077
}
7178
}
7279
}
@@ -92,12 +99,12 @@ func (h *cm) untrackDstLocked(cid string, c net.Conn) {
9299
return
93100
}
94101
dst := raddr.String()
95-
if ids, ok := h.dsttracker[dst]; ok {
96-
for i, id := range ids {
97-
if id == cid {
102+
if tups, ok := h.dsttracker[dst]; ok {
103+
for i, t := range tups {
104+
if t.CID == cid {
98105
// ids[i+1:] does not panic if i+1 is out of range
99106
// go.dev/play/p/troeQ5djf9h
100-
h.dsttracker[dst] = append(ids[:i], ids[i+1:]...)
107+
h.dsttracker[dst] = append(tups[:i], tups[i+1:]...)
101108
break
102109
}
103110
}
@@ -125,31 +132,33 @@ func (h *cm) UntrackBatch(cids []string) (out []string) {
125132
return
126133
}
127134

128-
func (h *cm) Get(id string) (conns []net.Conn) {
135+
func (h *cm) Get(cid string) (conns []net.Conn) {
129136
h.RLock()
130137
defer h.RUnlock()
131138

132-
if conns, ok := h.conntracker[id]; ok {
139+
if conns, ok := h.conntracker[cid]; ok {
133140
return conns
134141
}
135142
return
136143
}
137144

138-
func (h *cm) Find(dst string) (ids []string) {
145+
func (h *cm) Find(dst string) (tups []ConnTuple) {
139146
if len(dst) == 0 {
140147
return
141148
}
142149

143150
h.RLock()
144151
defer h.RUnlock()
145152

146-
if ids, ok := h.dsttracker[dst]; ok {
147-
return ids
153+
if tups, ok := h.dsttracker[dst]; ok {
154+
return tups
148155
}
149156
return
150157
}
151158

152-
func (h *cm) FindAny(csvdst string) (ids []string) {
159+
func (h *cm) FindAll(csvdst string) (out []ConnTuple) {
160+
out = make([]ConnTuple, 0)
161+
153162
if len(csvdst) == 0 {
154163
return
155164
}
@@ -159,25 +168,25 @@ func (h *cm) FindAny(csvdst string) (ids []string) {
159168

160169
dsts := strings.Split(csvdst, ",")
161170
for _, dst := range dsts {
162-
if ids, ok := h.dsttracker[string(dst)]; ok {
163-
return ids
171+
if tups, ok := h.dsttracker[dst]; ok {
172+
out = append(out, tups...)
164173
}
165174
}
166175
return
167176
}
168177

169-
func (h *cm) Clear() (ids []string) {
178+
func (h *cm) Clear() (cids []string) {
170179
h.Lock()
171180
defer h.Unlock()
172181

173-
ids = make([]string, 0, len(h.conntracker))
182+
cids = make([]string, 0, len(h.conntracker))
174183
for k, v := range h.conntracker {
175184
for _, c := range v {
176185
if c != nil {
177186
go c.Close()
178187
}
179188
}
180-
ids = append(ids, k)
189+
cids = append(cids, k)
181190
}
182191
clear(h.conntracker)
183192
clear(h.dsttracker)

0 commit comments

Comments
 (0)