Skip to content

Commit fd4cbc9

Browse files
committed
connmap: track [destaddr, (cid,uid)]
1 parent bfd4393 commit fd4cbc9

File tree

2 files changed

+56
-60
lines changed

2 files changed

+56
-60
lines changed

intra/common.go

Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/celzero/firestack/intra/core"
2020
"github.com/celzero/firestack/intra/dnsx"
2121
"github.com/celzero/firestack/intra/log"
22+
"github.com/celzero/firestack/intra/protect"
2223
)
2324

2425
// pipe copies data from src to dst, and returns the number of bytes copied.
@@ -67,8 +68,10 @@ func download(cid string, local net.Conn, remote net.Conn) (n int64, err error)
6768
// It also sends a summary to the listener when done. Always called in a goroutine.
6869
func forward(local net.Conn, remote net.Conn, t core.ConnMapper, l SocketListener, smm *SocketSummary) {
6970
cid := smm.ID
71+
uid := smm.UID
72+
ct := core.ConnTuple{CID: cid, UID: uid}
7073

71-
t.Track(cid, local, remote)
74+
t.Track(ct, local, remote)
7275
defer t.Untrack(cid)
7376

7477
uploadch := make(chan ioinfo)
@@ -136,14 +139,6 @@ func stall(m *core.ExpMap, k string) (secs uint32) {
136139
return
137140
}
138141

139-
func netipFrom(ip net.IP) *netip.Addr {
140-
if addr, ok := netip.AddrFromSlice(ip); ok {
141-
addr = addr.Unmap()
142-
return &addr
143-
}
144-
return nil
145-
}
146-
147142
func oneRealIp(realips string, origipp netip.AddrPort) netip.AddrPort {
148143
if len(realips) <= 0 {
149144
return origipp
@@ -208,7 +203,7 @@ func hasActiveConn(cm core.ConnMapper, ip, ips string) bool {
208203
return false
209204
}
210205
// TODO: filter by protocol (tcp/udp) when finding conns
211-
return len(cm.Find(ip)) > 0 || len(cm.FindAny(ips)) > 0
206+
return !hasSelfUid(cm.Find(ip), true) || !hasSelfUid(cm.FindAll(ips), true)
212207
}
213208

214209
// returns proxy-id, conn-id, user-id
@@ -227,26 +222,6 @@ func ipp(addr net.Addr) (netip.AddrPort, error) {
227222
return netip.ParseAddrPort(addr.String())
228223
}
229224

230-
func addr2ip(a net.Addr) string {
231-
if a == nil {
232-
return ""
233-
}
234-
switch x := a.(type) {
235-
case *net.TCPAddr:
236-
return x.IP.String()
237-
case *net.UDPAddr:
238-
return x.IP.String()
239-
case *net.IPAddr:
240-
return x.IP.String()
241-
case *net.IPNet:
242-
return x.IP.String()
243-
}
244-
if b, err := netip.ParseAddrPort(a.String()); err == nil {
245-
return b.Addr().String()
246-
}
247-
return ""
248-
}
249-
250225
func conn2str(a net.Conn, b net.Conn) string {
251226
ar := a.RemoteAddr()
252227
br := b.RemoteAddr()
@@ -266,6 +241,18 @@ func closeconns(cm core.ConnMapper, cids []string) (closed []string) {
266241
return closed
267242
}
268243

244+
func hasSelfUid(t []core.ConnTuple, d bool) bool {
245+
if len(t) <= 0 {
246+
return d // default
247+
}
248+
for _, x := range t {
249+
if x.UID == protect.UidSelf {
250+
return true
251+
}
252+
}
253+
return false // regardless of d
254+
}
255+
269256
func clos(c ...net.Conn) {
270257
for _, x := range c {
271258
if x != nil {

intra/core/connmap.go

Lines changed: 39 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,48 +12,55 @@ import (
1212
"sync"
1313
)
1414

15+
type ConnTuple struct {
16+
CID string // conn id
17+
UID string // proc id
18+
}
19+
1520
type ConnMapper interface {
1621
Clear() []string
17-
Track(id string, x ...net.Conn) int
18-
Find(dst string) (ids []string)
19-
FindAny(csvdst string) (ids []string)
20-
Get(id string) []net.Conn
21-
Untrack(id string) int
22-
UntrackBatch(ids []string) []string
22+
Track(t ConnTuple, x ...net.Conn) int
23+
Find(dst string) (t []ConnTuple)
24+
FindAll(csvdst string) (t []ConnTuple)
25+
Get(cid string) []net.Conn
26+
Untrack(cid string) int
27+
UntrackBatch(cids []string) []string
2328
}
2429

2530
type cm struct {
2631
sync.RWMutex
27-
conntracker map[string][]net.Conn // id -> conns
28-
dsttracker map[string][]string // dst ipport -> ids
32+
conntracker map[string][]net.Conn // id -> conns
33+
dsttracker map[string][]ConnTuple // dst ipport -> conntuple
2934
}
3035

3136
var _ ConnMapper = (*cm)(nil)
3237

3338
func NewConnMap() *cm {
3439
return &cm{
3540
conntracker: make(map[string][]net.Conn),
36-
dsttracker: make(map[string][]string),
41+
dsttracker: make(map[string][]ConnTuple),
3742
}
3843
}
3944

40-
func (h *cm) Track(cid string, conns ...net.Conn) (n int) {
45+
func (h *cm) Track(t ConnTuple, conns ...net.Conn) (n int) {
4146
h.Lock()
4247
defer h.Unlock()
4348

49+
cid := t.CID
50+
4451
if v, ok := h.conntracker[cid]; !ok {
4552
h.conntracker[cid] = conns
4653
n = len(conns)
4754
} else { // should not happen?
4855
h.conntracker[cid] = append(v, conns...)
4956
n = len(v) + len(conns)
5057
}
51-
h.trackDstLocked(cid, conns)
58+
h.trackDstLocked(t, conns)
5259

5360
return
5461
}
5562

56-
func (h *cm) trackDstLocked(cid string, conns []net.Conn) {
63+
func (h *cm) trackDstLocked(t ConnTuple, conns []net.Conn) {
5764
for _, c := range conns {
5865
if c == nil {
5966
continue
@@ -63,10 +70,10 @@ func (h *cm) trackDstLocked(cid string, conns []net.Conn) {
6370
continue
6471
}
6572
dst := raddr.String()
66-
if ids, ok := h.dsttracker[dst]; ok {
67-
h.dsttracker[dst] = append(ids, cid)
73+
if tups, ok := h.dsttracker[dst]; ok {
74+
h.dsttracker[dst] = append(tups, t)
6875
} else {
69-
h.dsttracker[dst] = []string{cid}
76+
h.dsttracker[dst] = []ConnTuple{t}
7077
}
7178
}
7279
}
@@ -92,12 +99,12 @@ func (h *cm) untrackDstLocked(cid string, c net.Conn) {
9299
return
93100
}
94101
dst := raddr.String()
95-
if ids, ok := h.dsttracker[dst]; ok {
96-
for i, id := range ids {
97-
if id == cid {
102+
if tups, ok := h.dsttracker[dst]; ok {
103+
for i, t := range tups {
104+
if t.CID == cid {
98105
// ids[i+1:] does not panic if i+1 is out of range
99106
// go.dev/play/p/troeQ5djf9h
100-
h.dsttracker[dst] = append(ids[:i], ids[i+1:]...)
107+
h.dsttracker[dst] = append(tups[:i], tups[i+1:]...)
101108
break
102109
}
103110
}
@@ -125,31 +132,33 @@ func (h *cm) UntrackBatch(cids []string) (out []string) {
125132
return
126133
}
127134

128-
func (h *cm) Get(id string) (conns []net.Conn) {
135+
func (h *cm) Get(cid string) (conns []net.Conn) {
129136
h.RLock()
130137
defer h.RUnlock()
131138

132-
if conns, ok := h.conntracker[id]; ok {
139+
if conns, ok := h.conntracker[cid]; ok {
133140
return conns
134141
}
135142
return
136143
}
137144

138-
func (h *cm) Find(dst string) (ids []string) {
145+
func (h *cm) Find(dst string) (tups []ConnTuple) {
139146
if len(dst) == 0 {
140147
return
141148
}
142149

143150
h.RLock()
144151
defer h.RUnlock()
145152

146-
if ids, ok := h.dsttracker[dst]; ok {
147-
return ids
153+
if tups, ok := h.dsttracker[dst]; ok {
154+
return tups
148155
}
149156
return
150157
}
151158

152-
func (h *cm) FindAny(csvdst string) (ids []string) {
159+
func (h *cm) FindAll(csvdst string) (out []ConnTuple) {
160+
out = make([]ConnTuple, 0)
161+
153162
if len(csvdst) == 0 {
154163
return
155164
}
@@ -159,25 +168,25 @@ func (h *cm) FindAny(csvdst string) (ids []string) {
159168

160169
dsts := strings.Split(csvdst, ",")
161170
for _, dst := range dsts {
162-
if ids, ok := h.dsttracker[string(dst)]; ok {
163-
return ids
171+
if tups, ok := h.dsttracker[dst]; ok {
172+
out = append(out, tups...)
164173
}
165174
}
166175
return
167176
}
168177

169-
func (h *cm) Clear() (ids []string) {
178+
func (h *cm) Clear() (cids []string) {
170179
h.Lock()
171180
defer h.Unlock()
172181

173-
ids = make([]string, 0, len(h.conntracker))
182+
cids = make([]string, 0, len(h.conntracker))
174183
for k, v := range h.conntracker {
175184
for _, c := range v {
176185
if c != nil {
177186
go c.Close()
178187
}
179188
}
180-
ids = append(ids, k)
189+
cids = append(cids, k)
181190
}
182191
clear(h.conntracker)
183192
clear(h.dsttracker)

0 commit comments

Comments
 (0)