Skip to content

Commit

Permalink
Add GracefulClose
Browse files Browse the repository at this point in the history
  • Loading branch information
edaniels committed Jul 25, 2024
1 parent abf50f9 commit a0385ee
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 5 deletions.
27 changes: 23 additions & 4 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,9 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit

userBindingRequestHandler: config.BindingRequestHandler,
}
a.connectionStateNotifier = &handlerNotifier{connectionStateFunc: a.onConnectionStateChange}
a.candidateNotifier = &handlerNotifier{candidateFunc: a.onCandidate}
a.selectedCandidatePairNotifier = &handlerNotifier{candidatePairFunc: a.onSelectedCandidatePairChange}
a.connectionStateNotifier = &handlerNotifier{connectionStateFunc: a.onConnectionStateChange, done: make(chan struct{})}
a.candidateNotifier = &handlerNotifier{candidateFunc: a.onCandidate, done: make(chan struct{})}
a.selectedCandidatePairNotifier = &handlerNotifier{candidatePairFunc: a.onSelectedCandidatePairChange, done: make(chan struct{})}

if a.net == nil {
a.net, err = stdnet.NewNet()
Expand Down Expand Up @@ -849,7 +849,26 @@ func (a *Agent) removeUfragFromMux() {

// Close cleans up the Agent. It shuts down the agent loop and the handler
// notifiers but does not wait for notifier goroutines that are still
// delivering callbacks; use GracefulClose for that.
//
// It returns the error reported by the shared close routine.
func (a *Agent) Close() error {
	// Always go through close so the notifiers are shut down as well.
	// Returning a.loop.Close() directly would leave them open.
	return a.close(false)
}

// GracefulClose shuts the Agent down like Close and additionally blocks until
// the goroutines started by the Agent have finished. Do not call it directly
// from within an Agent callback; if it has to be triggered from a callback,
// run it in a separate goroutine.
func (a *Agent) GracefulClose() error {
	const waitForGoroutines = true
	return a.close(waitForGoroutines)
}

// close shuts down the agent loop and then each handler notifier. It returns
// whatever error closing the loop produced. The graceful flag is forwarded to
// the notifiers unchanged.
func (a *Agent) close(graceful bool) error {
	// Waiting on the loop is always safe, so it is closed unconditionally.
	loopErr := a.loop.Close()

	// We have less control over the notifiers, so the decision to wait on
	// them is left to the caller via `graceful`.
	a.connectionStateNotifier.Close(graceful)
	a.candidateNotifier.Close(graceful)
	a.selectedCandidatePairNotifier.Close(graceful)

	return loopErr
}

// Remove all candidates. This closes any listening sockets
Expand Down
47 changes: 46 additions & 1 deletion agent_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ func (a *Agent) onConnectionStateChange(s ConnectionState) {

type handlerNotifier struct {
sync.Mutex
running bool
running bool
notifiers sync.WaitGroup

connectionStates []ConnectionState
connectionStateFunc func(ConnectionState)
Expand All @@ -55,13 +56,40 @@ type handlerNotifier struct {

selectedCandidatePairs []*CandidatePair
candidatePairFunc func(*CandidatePair)

// State for closing
done chan struct{}
}

// Close marks the notifier as done by closing h.done. It is safe to call more
// than once: later calls see the closed channel and do not close it again.
//
// When graceful is true, Close blocks until every goroutine registered on
// h.notifiers has called Done. This holds even if the notifier was already
// closed by an earlier (possibly non-graceful) call, so a graceful Close
// never returns while a notification goroutine is still running.
func (h *handlerNotifier) Close(graceful bool) {
	if graceful {
		// Registered first so it runs last: the wait happens after the
		// mutex is released and on every return path, including the
		// already-closed one.
		defer h.notifiers.Wait()
	}

	h.Lock()
	defer h.Unlock()

	select {
	case <-h.done:
		// Already closed; closing the channel again would panic.
		return
	default:
	}
	close(h.done)
}

func (h *handlerNotifier) EnqueueConnectionState(s ConnectionState) {
h.Lock()
defer h.Unlock()

select {
case <-h.done:
return
default:
}

notify := func() {
defer h.notifiers.Done()
for {
h.Lock()
if len(h.connectionStates) == 0 {
Expand All @@ -79,6 +107,7 @@ func (h *handlerNotifier) EnqueueConnectionState(s ConnectionState) {
h.connectionStates = append(h.connectionStates, s)
if !h.running {
h.running = true
h.notifiers.Add(1)
go notify()
}
}
Expand All @@ -87,7 +116,14 @@ func (h *handlerNotifier) EnqueueCandidate(c Candidate) {
h.Lock()
defer h.Unlock()

select {
case <-h.done:
return
default:
}

notify := func() {
defer h.notifiers.Done()
for {
h.Lock()
if len(h.candidates) == 0 {
Expand All @@ -105,6 +141,7 @@ func (h *handlerNotifier) EnqueueCandidate(c Candidate) {
h.candidates = append(h.candidates, c)
if !h.running {
h.running = true
h.notifiers.Add(1)
go notify()
}
}
Expand All @@ -113,7 +150,14 @@ func (h *handlerNotifier) EnqueueSelectedCandidatePair(p *CandidatePair) {
h.Lock()
defer h.Unlock()

select {
case <-h.done:
return
default:
}

notify := func() {
defer h.notifiers.Done()
for {
h.Lock()
if len(h.selectedCandidatePairs) == 0 {
Expand All @@ -131,6 +175,7 @@ func (h *handlerNotifier) EnqueueSelectedCandidatePair(p *CandidatePair) {
h.selectedCandidatePairs = append(h.selectedCandidatePairs, p)
if !h.running {
h.running = true
h.notifiers.Add(1)
go notify()
}
}
4 changes: 4 additions & 0 deletions agent_handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ func TestConnectionStateNotifier(t *testing.T) {
connectionStateFunc: func(_ ConnectionState) {
updates <- struct{}{}
},
done: make(chan struct{}),
}
// Enqueue all updates upfront to ensure that it
// doesn't block
Expand All @@ -38,6 +39,7 @@ func TestConnectionStateNotifier(t *testing.T) {
close(done)
}()
<-done
c.Close(true)
})
t.Run("TestUpdateOrdering", func(t *testing.T) {
defer test.CheckRoutines(t)()
Expand All @@ -46,6 +48,7 @@ func TestConnectionStateNotifier(t *testing.T) {
connectionStateFunc: func(cs ConnectionState) {
updates <- cs
},
done: make(chan struct{}),
}
done := make(chan struct{})
go func() {
Expand All @@ -66,5 +69,6 @@ func TestConnectionStateNotifier(t *testing.T) {
c.EnqueueConnectionState(ConnectionState(i))
}
<-done
c.Close(true)
})
}
52 changes: 52 additions & 0 deletions agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1737,3 +1737,55 @@ func TestAcceptAggressiveNomination(t *testing.T) {
require.NoError(t, wan.Stop())
closePipe(t, aConn, bConn)
}

// Close can deadlock but GracefulClose must not.
//
// TestAgentGracefulCloseDeadlock connects two agents, blocks inside each
// agent's connection-state callback until both report connected, and then has
// each callback spawn a goroutine that calls GracefulClose on its own agent.
// The test passes when both closes return within the timeout and a later
// Close on each agent reports an error because it is already closed.
func TestAgentGracefulCloseDeadlock(t *testing.T) {
	defer test.CheckRoutinesStrict(t)()
	defer test.TimeOut(time.Second * 5).Stop()

	config := &AgentConfig{
		NetworkTypes: supportedNetworkTypes(),
	}
	aAgent, err := NewAgent(config)
	require.NoError(t, err)

	bAgent, err := NewAgent(config)
	require.NoError(t, err)

	var connected, closeNow, closed sync.WaitGroup
	connected.Add(2)
	closeNow.Add(1)
	closed.Add(2)
	closeHdlr := func(agent *Agent) {
		check(agent.OnConnectionStateChange(func(cs ConnectionState) {
			if cs == ConnectionStateConnected {
				connected.Done()
				closeNow.Wait()

				// GracefulClose has to run in its own goroutine because
				// we are inside an agent callback here.
				go func() {
					// Deferred so the main test goroutine is always
					// released, whatever GracefulClose returns.
					defer closed.Done()
					if err := agent.GracefulClose(); err != nil {
						// t.Errorf instead of require.NoError: FailNow
						// must only be called from the goroutine running
						// the test function.
						t.Errorf("GracefulClose failed: %v", err)
					}
				}()
			}
		}))
	}

	closeHdlr(aAgent)
	closeHdlr(bAgent)

	t.Log("connecting agents")
	_, _ = connect(aAgent, bAgent)

	t.Log("waiting for them to confirm connection in callback")
	connected.Wait()

	t.Log("tell them to close themselves in the same callback and wait")
	closeNow.Done()
	closed.Wait()

	// already closed
	require.Error(t, aAgent.Close())
	require.Error(t, bAgent.Close())
}

0 comments on commit a0385ee

Please sign in to comment.