24
24
package dialers
25
25
26
26
import (
27
+ "context"
27
28
"io"
28
29
"net"
30
+ "net/netip"
29
31
"sync"
30
32
"sync/atomic"
31
33
"syscall"
@@ -44,15 +46,22 @@ func (zeroNetAddr) String() string { return "none" }
44
46
45
47
const maxRetryCount = 3
46
48
49
+ // ippPins maintains a limited-time mapping between ip:port addresses and dialer IDs.
50
+ // TODO: invalidate cache on network changes.
51
+ // TODO: with context.TODO, expmap's reaper goroutine will leak.
52
+ var ippPins = core .NewSieve [netip.AddrPort , string ](context .TODO (), desync_cache_ttl )
53
+
47
54
// retrier implements the DuplexConn interface and must
48
55
// be typecastable to *net.TCPConn (see: xdial.DialTCP)
49
56
// inheritance: go.dev/play/p/mMiQgXsPM7Y
50
57
type retrier struct {
51
- dialers []protect.RDialer
52
- dialerOpts settings.DialerOpts
53
- multidial bool
54
- raddr net.Addr
55
- laddr net.Addr // laddr may be nil; TCPAddr.IP may be nil.
58
+ dialers []protect.RDialer
59
+ dialerOpts settings.DialerOpts
60
+ nextDialerIdx int
61
+ multidial bool
62
+
63
+ raddr net.Addr
64
+ laddr net.Addr // laddr may be nil; TCPAddr.IP may be nil.
56
65
57
66
// Flags indicating whether the caller has called CloseRead and CloseWrite.
58
67
readDone atomic.Bool
@@ -80,9 +89,8 @@ type retrier struct {
80
89
// and is cleared when the first byte is received.
81
90
tee []byte
82
91
// retryErr is set to the error from the last retry, if any.
83
- retryErr error
84
- retryCount uint8
85
- dialerCount int
92
+ retryErr error
93
+ retryCount uint8
86
94
// Flag indicating when retry is finished or unnecessary.
87
95
retryDoneCh chan struct {} // always unbuffered
88
96
}
@@ -116,7 +124,7 @@ func (r *retrier) retryCompleted() bool {
116
124
117
125
func (r * retrier ) canRetryLocked () bool {
118
126
if r .multidial {
119
- return r .dialerCount < len (r .dialers )
127
+ return r .nextDialerIdx < len (r .dialers )
120
128
} else {
121
129
return r .retryCount < maxRetryCount
122
130
}
@@ -162,9 +170,27 @@ func dialerOptsForRace() settings.DialerOpts {
162
170
}
163
171
}
164
172
173
+ func reprioritize (ds []protect.RDialer , ipp netip.AddrPort ) []protect.RDialer {
174
+ // reprioritize the dialers based on the IP:port pair
175
+ if ! ipp .IsValid () {
176
+ return ds
177
+ }
178
+ id , ok := ippPins .Get (ipp )
179
+ if ! ok || len (id ) <= 0 {
180
+ return ds
181
+ }
182
+ for i , d := range ds {
183
+ if d .ID () == id {
184
+ ds [i ], ds [0 ] = ds [0 ], ds [i ]
185
+ break
186
+ }
187
+ }
188
+ return ds
189
+ }
190
+
165
191
func DialAny (ds []protect.RDialer , laddr , raddr net.Addr ) (* retrier , error ) {
166
192
r := & retrier {
167
- dialers : ds ,
193
+ dialers : reprioritize ( ds , asAddrPort ( raddr )) ,
168
194
dialerOpts : dialerOptsForRace (),
169
195
multidial : true ,
170
196
laddr : laddr , // may be nil
@@ -274,7 +300,6 @@ func (r *retrier) dialLocked() (c core.DuplexConn, err error) {
274
300
begin := time .Now ()
275
301
c , err = r .doDialLocked (strat )
276
302
rtt := time .Since (begin )
277
- r .dialerCount ++
278
303
279
304
r .conn = c // c may be nil
280
305
r .timeout = calcTimeout (rtt )
@@ -290,7 +315,13 @@ func (r *retrier) dialLocked() (c core.DuplexConn, err error) {
290
315
func (r * retrier ) doDialLocked (dialStrat int32 ) (_ core.DuplexConn , err error ) {
291
316
var conn * net.TCPConn
292
317
293
- di := r .dialerCount % len (r .dialers )
318
+ di := r .nextDialerIdx
319
+ if r .multidial {
320
+ if di >= len (r .dialers ) {
321
+ return nil , errNoDialer
322
+ }
323
+ }
324
+ r .nextDialerIdx = di + 1
294
325
295
326
// r.raddr may be nil or laddr.IP may be nil.
296
327
switch dialStrat {
@@ -395,7 +426,7 @@ func (r *retrier) Read(buf []byte) (n int, err error) {
395
426
err = core .UniqErr (err , retryerr )
396
427
}
397
428
logeor (retryerr , log .I )("retrier: read# %d + (mult? %t / c: %d): [%s<=%s] %d; err? %v" ,
398
- r .retryCount , r .multidial , r .dialerCount , laddr (c ), r .raddr , n , retryerr )
429
+ r .retryCount , r .multidial , r .nextDialerIdx , laddr (c ), r .raddr , n , retryerr )
399
430
}
400
431
if c != nil && core .IsNotNil (c ) {
401
432
_ = c .SetReadDeadline (r .readDeadline )
@@ -404,7 +435,8 @@ func (r *retrier) Read(buf []byte) (n int, err error) {
404
435
r .tee = nil // discard teed data
405
436
return
406
437
}
407
- logeor (err , note )("retrier: read: already retried! [%s<=%s] %d; err? %v" , laddr (c ), r .raddr , n , err )
438
+ logeor (err , note )("retrier: read: already retried! [%s<=%s] %s; err? %v" ,
439
+ laddr (c ), r .raddr , n , err )
408
440
} // else: just one read is enough; no retry needed
409
441
return
410
442
}
@@ -514,11 +546,24 @@ func (r *retrier) ReadFrom(reader io.Reader) (bytes int64, err error) {
514
546
return bytes , io .ErrUnexpectedEOF
515
547
}
516
548
549
+ pinned := false
550
+ pinnedID := ""
551
+ if r .multidial {
552
+ if ipp := asAddrPort (r .raddr ); ipp .IsValid () {
553
+ // cache the dialer ID for the IP:port pair
554
+ di := max (0 , r .nextDialerIdx - 1 ) % len (r .dialers )
555
+ pinnedID = r .dialers [di ].ID ()
556
+ ippPins .Put (ipp , pinnedID )
557
+ pinned = true
558
+ }
559
+ }
560
+
517
561
var b int64
518
562
b , err = c .ReadFrom (reader )
519
563
bytes += b
520
564
521
- logeif (err )("retrier: readfrom: done; sz: %d; err: %v" , bytes , err )
565
+ logeif (err )("retrier: readfrom: done (id: %s, pinned? %t); sz: %d; err: %v" ,
566
+ pinnedID , pinned , bytes , err )
522
567
return
523
568
}
524
569
0 commit comments