@@ -31,6 +31,7 @@ package netstack
31
31
32
32
import (
33
33
"fmt"
34
+ "sync/atomic"
34
35
35
36
"github.com/celzero/firestack/intra/log"
36
37
"golang.org/x/sys/unix"
@@ -48,6 +49,8 @@ var _ stack.LinkEndpoint = (*endpoint)(nil)
48
49
var _ stack.LinkEndpoint = (* sniff )(nil )
49
50
var _ Swapper = (* sniff )(nil )
50
51
52
+ const invalidfd int = - 1
53
+
51
54
type Swapper interface {
52
55
// Swap closes existing FDs; uses new fd and mtu.
53
56
Swap (fd , mtu int ) error
@@ -74,10 +77,10 @@ type endpoint struct {
74
77
// fds is the set of file descriptors each identifying one inbound/outbound
75
78
// channel. The endpoint will dispatch from all inbound channels as well as
76
79
// hash outbound packets to specific channels based on the packet hash.
77
- fds fdInfo
80
+ fds atomic. Value // int
78
81
79
82
// mtu (maximum transmission unit) is the maximum size of a packet.
80
- mtu uint32
83
+ mtu atomic. Uint32
81
84
82
85
// hdrSize specifies the link-layer header size. If set to 0, no header
83
86
// is added/removed; otherwise an ethernet header is used.
@@ -189,7 +192,8 @@ func NewFdbasedInjectableEndpoint(opts *Options) (SeamlessEndpoint, error) {
189
192
}
190
193
191
194
e := & endpoint {
192
- mtu : opts .MTU ,
195
+ mtu : atomic.Uint32 {},
196
+ fds : atomic.Value {},
193
197
caps : caps ,
194
198
addr : opts .Address ,
195
199
hdrSize : hdrSize ,
@@ -231,8 +235,9 @@ func (e *endpoint) Swap(fd, mtu int) (err error) {
231
235
return fmt .Errorf ("unix.SetNonblock(%v) failed: %v" , fd , err )
232
236
}
233
237
234
- e .fds = fdInfo {fd : fd } // commence WritePackets() on fd
235
- e .mtu = uint32 (mtu )
238
+ // commence WritePackets() on fd
239
+ e .fds .Store (fd )
240
+ e .mtu .Store (uint32 (mtu ))
236
241
237
242
e .Lock ()
238
243
defer e .Unlock ()
@@ -288,7 +293,7 @@ func (e *endpoint) IsAttached() bool {
288
293
// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
289
294
// during construction.
290
295
func (e * endpoint ) MTU () uint32 {
291
- return e .mtu
296
+ return e .mtu . Load ()
292
297
}
293
298
294
299
// Capabilities implements stack.LinkEndpoint.Capabilities.
@@ -354,6 +359,14 @@ func (e *endpoint) logPacketIfNeeded(dir sniffer.Direction, pkt *stack.PacketBuf
354
359
}
355
360
}
356
361
362
+ // fd returns the file descriptor associated with the endpoint.
363
+ func (e * endpoint ) fd () int {
364
+ if fd , ok := e .fds .Load ().(int ); ok {
365
+ return fd
366
+ }
367
+ return invalidfd
368
+ }
369
+
357
370
// writePackets writes outbound packets to the file descriptor. If it is not
358
371
// currently writable, the packet is dropped.
359
372
// Way more simplified than og impl, ref: github.com/google/gvisor/issues/7125
@@ -363,7 +376,11 @@ func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error)
363
376
// segment can get split into 46 segments of 1420 bytes and a single 216
364
377
// byte segment.
365
378
const batchSz = 47
366
- fd := e .fds .fd // may have been closed
379
+ fd := e .fd () // may have been closed
380
+ if fd == invalidfd { // unlikely; panic instead?
381
+ log .E ("ns: WritePackets (to tun): fd invalid" )
382
+ return 0 , & tcpip.ErrNoSuchFile {}
383
+ }
367
384
batch := make ([]unix.Iovec , 0 , batchSz )
368
385
packets , written := 0 , 0
369
386
total := pkts .Len ()
@@ -409,7 +426,7 @@ func (e *endpoint) dispatchLoop(inbound linkDispatcher) tcpip.Error {
409
426
410
427
if inbound == nil {
411
428
log .W ("ns: dispatchLoop: inbound nil" )
412
- return & tcpip.ErrInvalidEndpointState {}
429
+ return & tcpip.ErrUnknownDevice {}
413
430
}
414
431
for {
415
432
cont , err := inbound .dispatch ()
@@ -445,5 +462,5 @@ func (e *endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stac
445
462
func (e * endpoint ) InjectOutbound (dest tcpip.Address , packet * buffer.View ) tcpip.Error {
446
463
log .V ("ns: inject-outbound (to tun) to dst(%v)" , dest )
447
464
// TODO: e.logPacketIfNeeded(sniffer.DirectionSend, packet)
448
- return rawfile .NonBlockingWrite (e .fds . fd , packet .AsSlice ())
465
+ return rawfile .NonBlockingWrite (e .fd () , packet .AsSlice ())
449
466
}
0 commit comments