@@ -12,48 +12,55 @@ import (
12
12
"sync"
13
13
)
14
14
15
+ type ConnTuple struct {
16
+ CID string // conn id
17
+ UID string // proc id
18
+ }
19
+
15
20
type ConnMapper interface {
16
21
Clear () []string
17
- Track (id string , x ... net.Conn ) int
18
- Find (dst string ) (ids []string )
19
- FindAny (csvdst string ) (ids []string )
20
- Get (id string ) []net.Conn
21
- Untrack (id string ) int
22
- UntrackBatch (ids []string ) []string
22
+ Track (t ConnTuple , x ... net.Conn ) int
23
+ Find (dst string ) (t []ConnTuple )
24
+ FindAll (csvdst string ) (t []ConnTuple )
25
+ Get (cid string ) []net.Conn
26
+ Untrack (cid string ) int
27
+ UntrackBatch (cids []string ) []string
23
28
}
24
29
25
30
type cm struct {
26
31
sync.RWMutex
27
- conntracker map [string ][]net.Conn // id -> conns
28
- dsttracker map [string ][]string // dst ipport -> ids
32
+ conntracker map [string ][]net.Conn // id -> conns
33
+ dsttracker map [string ][]ConnTuple // dst ipport -> conntuple
29
34
}
30
35
31
36
var _ ConnMapper = (* cm )(nil )
32
37
33
38
func NewConnMap () * cm {
34
39
return & cm {
35
40
conntracker : make (map [string ][]net.Conn ),
36
- dsttracker : make (map [string ][]string ),
41
+ dsttracker : make (map [string ][]ConnTuple ),
37
42
}
38
43
}
39
44
40
- func (h * cm ) Track (cid string , conns ... net.Conn ) (n int ) {
45
+ func (h * cm ) Track (t ConnTuple , conns ... net.Conn ) (n int ) {
41
46
h .Lock ()
42
47
defer h .Unlock ()
43
48
49
+ cid := t .CID
50
+
44
51
if v , ok := h .conntracker [cid ]; ! ok {
45
52
h .conntracker [cid ] = conns
46
53
n = len (conns )
47
54
} else { // should not happen?
48
55
h .conntracker [cid ] = append (v , conns ... )
49
56
n = len (v ) + len (conns )
50
57
}
51
- h .trackDstLocked (cid , conns )
58
+ h .trackDstLocked (t , conns )
52
59
53
60
return
54
61
}
55
62
56
- func (h * cm ) trackDstLocked (cid string , conns []net.Conn ) {
63
+ func (h * cm ) trackDstLocked (t ConnTuple , conns []net.Conn ) {
57
64
for _ , c := range conns {
58
65
if c == nil {
59
66
continue
@@ -63,10 +70,10 @@ func (h *cm) trackDstLocked(cid string, conns []net.Conn) {
63
70
continue
64
71
}
65
72
dst := raddr .String ()
66
- if ids , ok := h .dsttracker [dst ]; ok {
67
- h .dsttracker [dst ] = append (ids , cid )
73
+ if tups , ok := h .dsttracker [dst ]; ok {
74
+ h .dsttracker [dst ] = append (tups , t )
68
75
} else {
69
- h .dsttracker [dst ] = []string { cid }
76
+ h .dsttracker [dst ] = []ConnTuple { t }
70
77
}
71
78
}
72
79
}
@@ -92,12 +99,12 @@ func (h *cm) untrackDstLocked(cid string, c net.Conn) {
92
99
return
93
100
}
94
101
dst := raddr .String ()
95
- if ids , ok := h .dsttracker [dst ]; ok {
96
- for i , id := range ids {
97
- if id == cid {
102
+ if tups , ok := h .dsttracker [dst ]; ok {
103
+ for i , t := range tups {
104
+ if t . CID == cid {
98
105
// ids[i+1:] does not panic if i+1 is out of range
99
106
// go.dev/play/p/troeQ5djf9h
100
- h .dsttracker [dst ] = append (ids [:i ], ids [i + 1 :]... )
107
+ h .dsttracker [dst ] = append (tups [:i ], tups [i + 1 :]... )
101
108
break
102
109
}
103
110
}
@@ -125,31 +132,33 @@ func (h *cm) UntrackBatch(cids []string) (out []string) {
125
132
return
126
133
}
127
134
128
- func (h * cm ) Get (id string ) (conns []net.Conn ) {
135
+ func (h * cm ) Get (cid string ) (conns []net.Conn ) {
129
136
h .RLock ()
130
137
defer h .RUnlock ()
131
138
132
- if conns , ok := h .conntracker [id ]; ok {
139
+ if conns , ok := h .conntracker [cid ]; ok {
133
140
return conns
134
141
}
135
142
return
136
143
}
137
144
138
- func (h * cm ) Find (dst string ) (ids []string ) {
145
+ func (h * cm ) Find (dst string ) (tups []ConnTuple ) {
139
146
if len (dst ) == 0 {
140
147
return
141
148
}
142
149
143
150
h .RLock ()
144
151
defer h .RUnlock ()
145
152
146
- if ids , ok := h .dsttracker [dst ]; ok {
147
- return ids
153
+ if tups , ok := h .dsttracker [dst ]; ok {
154
+ return tups
148
155
}
149
156
return
150
157
}
151
158
152
- func (h * cm ) FindAny (csvdst string ) (ids []string ) {
159
+ func (h * cm ) FindAll (csvdst string ) (out []ConnTuple ) {
160
+ out = make ([]ConnTuple , 0 )
161
+
153
162
if len (csvdst ) == 0 {
154
163
return
155
164
}
@@ -159,25 +168,25 @@ func (h *cm) FindAny(csvdst string) (ids []string) {
159
168
160
169
dsts := strings .Split (csvdst , "," )
161
170
for _ , dst := range dsts {
162
- if ids , ok := h .dsttracker [string ( dst ) ]; ok {
163
- return ids
171
+ if tups , ok := h .dsttracker [dst ]; ok {
172
+ out = append ( out , tups ... )
164
173
}
165
174
}
166
175
return
167
176
}
168
177
169
- func (h * cm ) Clear () (ids []string ) {
178
+ func (h * cm ) Clear () (cids []string ) {
170
179
h .Lock ()
171
180
defer h .Unlock ()
172
181
173
- ids = make ([]string , 0 , len (h .conntracker ))
182
+ cids = make ([]string , 0 , len (h .conntracker ))
174
183
for k , v := range h .conntracker {
175
184
for _ , c := range v {
176
185
if c != nil {
177
186
go c .Close ()
178
187
}
179
188
}
180
- ids = append (ids , k )
189
+ cids = append (cids , k )
181
190
}
182
191
clear (h .conntracker )
183
192
clear (h .dsttracker )
0 commit comments