Skip to content

Commit 7af28fc

Browse files
committed
netstack: atomics for tun fd
1 parent 2037fec commit 7af28fc

File tree

1 file changed

+26
-9
lines changed

1 file changed

+26
-9
lines changed

intra/netstack/fdbased.go

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ package netstack
3131

3232
import (
3333
"fmt"
34+
"sync/atomic"
3435

3536
"github.com/celzero/firestack/intra/log"
3637
"golang.org/x/sys/unix"
@@ -48,6 +49,8 @@ var _ stack.LinkEndpoint = (*endpoint)(nil)
4849
var _ stack.LinkEndpoint = (*sniff)(nil)
4950
var _ Swapper = (*sniff)(nil)
5051

52+
const invalidfd int = -1
53+
5154
type Swapper interface {
5255
// Swap closes existing FDs; uses new fd and mtu.
5356
Swap(fd, mtu int) error
@@ -74,10 +77,10 @@ type endpoint struct {
7477
// fds is the set of file descriptors each identifying one inbound/outbound
7578
// channel. The endpoint will dispatch from all inbound channels as well as
7679
// hash outbound packets to specific channels based on the packet hash.
77-
fds fdInfo
80+
fds atomic.Value // int
7881

7982
// mtu (maximum transmission unit) is the maximum size of a packet.
80-
mtu uint32
83+
mtu atomic.Uint32
8184

8285
// hdrSize specifies the link-layer header size. If set to 0, no header
8386
// is added/removed; otherwise an ethernet header is used.
@@ -189,7 +192,8 @@ func NewFdbasedInjectableEndpoint(opts *Options) (SeamlessEndpoint, error) {
189192
}
190193

191194
e := &endpoint{
192-
mtu: opts.MTU,
195+
mtu: atomic.Uint32{},
196+
fds: atomic.Value{},
193197
caps: caps,
194198
addr: opts.Address,
195199
hdrSize: hdrSize,
@@ -231,8 +235,9 @@ func (e *endpoint) Swap(fd, mtu int) (err error) {
231235
return fmt.Errorf("unix.SetNonblock(%v) failed: %v", fd, err)
232236
}
233237

234-
e.fds = fdInfo{fd: fd} // commence WritePackets() on fd
235-
e.mtu = uint32(mtu)
238+
// commence WritePackets() on fd
239+
e.fds.Store(fd)
240+
e.mtu.Store(uint32(mtu))
236241

237242
e.Lock()
238243
defer e.Unlock()
@@ -288,7 +293,7 @@ func (e *endpoint) IsAttached() bool {
288293
// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
289294
// during construction.
290295
func (e *endpoint) MTU() uint32 {
291-
return e.mtu
296+
return e.mtu.Load()
292297
}
293298

294299
// Capabilities implements stack.LinkEndpoint.Capabilities.
@@ -354,6 +359,14 @@ func (e *endpoint) logPacketIfNeeded(dir sniffer.Direction, pkt *stack.PacketBuf
354359
}
355360
}
356361

362+
// fd returns the file descriptor associated with the endpoint.
363+
func (e *endpoint) fd() int {
364+
if fd, ok := e.fds.Load().(int); ok {
365+
return fd
366+
}
367+
return invalidfd
368+
}
369+
357370
// writePackets writes outbound packets to the file descriptor. If it is not
358371
// currently writable, the packet is dropped.
359372
// Way more simplified than og impl, ref: github.com/google/gvisor/issues/7125
@@ -363,7 +376,11 @@ func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error)
363376
// segment can get split into 46 segments of 1420 bytes and a single 216
364377
// byte segment.
365378
const batchSz = 47
366-
fd := e.fds.fd // may have been closed
379+
fd := e.fd() // may have been closed
380+
if fd == invalidfd { // unlikely; panic instead?
381+
log.E("ns: WritePackets (to tun): fd invalid")
382+
return 0, &tcpip.ErrNoSuchFile{}
383+
}
367384
batch := make([]unix.Iovec, 0, batchSz)
368385
packets, written := 0, 0
369386
total := pkts.Len()
@@ -409,7 +426,7 @@ func (e *endpoint) dispatchLoop(inbound linkDispatcher) tcpip.Error {
409426

410427
if inbound == nil {
411428
log.W("ns: dispatchLoop: inbound nil")
412-
return &tcpip.ErrInvalidEndpointState{}
429+
return &tcpip.ErrUnknownDevice{}
413430
}
414431
for {
415432
cont, err := inbound.dispatch()
@@ -445,5 +462,5 @@ func (e *endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stac
445462
func (e *endpoint) InjectOutbound(dest tcpip.Address, packet *buffer.View) tcpip.Error {
446463
log.V("ns: inject-outbound (to tun) to dst(%v)", dest)
447464
// TODO: e.logPacketIfNeeded(sniffer.DirectionSend, packet)
448-
return rawfile.NonBlockingWrite(e.fds.fd, packet.AsSlice())
465+
return rawfile.NonBlockingWrite(e.fd(), packet.AsSlice())
449466
}

0 commit comments

Comments
 (0)