Skip to content

Commit 99d256b

Browse files
committed
tcp,udp: track conn in same rountine to prevent race
1 parent 45f6c68 commit 99d256b

File tree

3 files changed

+11
-11
lines changed

3 files changed

+11
-11
lines changed

intra/common.go

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,8 @@ func download(cid string, local net.Conn, remote net.Conn) (n int64, err error)
6666

6767
// forward copies data between local and remote, and tracks the connection.
6868
// It also sends a summary to the listener when done. Always called in a goroutine.
69-
func forward(local net.Conn, remote net.Conn, t core.ConnMapper, l SocketListener, smm *SocketSummary) {
69+
func forward(local, remote net.Conn, l SocketListener, smm *SocketSummary) {
7070
cid := smm.ID
71-
uid := smm.UID
72-
ct := core.ConnTuple{CID: cid, UID: uid}
73-
74-
t.Track(ct, local, remote)
75-
defer t.Untrack(cid)
7671

7772
uploadch := make(chan ioinfo)
7873

intra/tcp.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,15 +269,18 @@ func (h *tcpHandler) handle(px ipn.Proxy, src net.Conn, target netip.AddrPort, s
269269
return err
270270
}
271271

272+
ct := core.ConnTuple{CID: smm.ID, UID: smm.UID}
273+
274+
h.conntracker.Track(ct, src, dst)
272275
go func() {
273-
cm := h.conntracker
274276
l := h.listener
275277
defer func() {
276278
if r := recover(); r != nil {
277279
log.W("tcp: forward: panic %v", r)
278280
}
281+
defer h.conntracker.Untrack(ct.CID)
279282
}()
280-
forward(src, dst, cm, l, smm) // src always *gonet.TCPConn
283+
forward(src, dst, l, smm) // src always *gonet.TCPConn
281284
}()
282285

283286
log.I("tcp: new conn %s via proxy(%s); src(%s) -> dst(%s) for %s", smm.ID, px.ID(), src.LocalAddr(), target, smm.UID)

intra/udp.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -240,15 +240,17 @@ func (h *udpHandler) proxy(gconn net.Conn, src, dst netip.AddrPort) (ok bool) {
240240
// no summary for dns queries
241241
return true // ok
242242
}
243+
244+
ct := core.ConnTuple{CID: smm.ID, UID: smm.UID}
245+
h.conntracker.Track(ct, gconn, remote)
243246
go func() {
244-
cm := h.conntracker
245247
defer func() {
246248
if r := recover(); r != nil {
247249
log.W("udp: forward: %s -> %s panic %v", src, dst, r)
248250
}
251+
h.conntracker.Untrack(ct.CID)
249252
}()
250-
251-
forward(gconn, &rwext{remote}, cm, l, smm)
253+
forward(gconn, &rwext{remote}, l, smm)
252254
}()
253255
return true // ok
254256
}

0 commit comments

Comments
 (0)