diff --git a/agent.go b/agent.go index 344200c2..91d25f12 100644 --- a/agent.go +++ b/agent.go @@ -323,9 +323,9 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit userBindingRequestHandler: config.BindingRequestHandler, } - a.connectionStateNotifier = &handlerNotifier{connectionStateFunc: a.onConnectionStateChange} - a.candidateNotifier = &handlerNotifier{candidateFunc: a.onCandidate} - a.selectedCandidatePairNotifier = &handlerNotifier{candidatePairFunc: a.onSelectedCandidatePairChange} + a.connectionStateNotifier = &handlerNotifier{connectionStateFunc: a.onConnectionStateChange, done: make(chan struct{})} + a.candidateNotifier = &handlerNotifier{candidateFunc: a.onCandidate, done: make(chan struct{})} + a.selectedCandidatePairNotifier = &handlerNotifier{candidatePairFunc: a.onSelectedCandidatePairChange, done: make(chan struct{})} if a.net == nil { a.net, err = stdnet.NewNet() @@ -931,6 +931,9 @@ func (a *Agent) Close() error { close(a.done) <-a.taskLoopDone + a.connectionStateNotifier.Close() + a.candidateNotifier.Close() + a.selectedCandidatePairNotifier.Close() return nil } diff --git a/agent_handlers.go b/agent_handlers.go index bb0c8d30..7ebfedd1 100644 --- a/agent_handlers.go +++ b/agent_handlers.go @@ -45,7 +45,8 @@ func (a *Agent) onConnectionStateChange(s ConnectionState) { type handlerNotifier struct { sync.Mutex - running bool + running bool + notifiers sync.WaitGroup connectionStates []ConnectionState connectionStateFunc func(ConnectionState) @@ -55,13 +56,38 @@ type handlerNotifier struct { selectedCandidatePairs []*CandidatePair candidatePairFunc func(*CandidatePair) + + // State for closing + done chan struct{} +} + +func (h *handlerNotifier) Close() { + h.Lock() + + select { + case <-h.done: + h.Unlock() + return + default: + } + close(h.done) + h.Unlock() + + h.notifiers.Wait() } func (h *handlerNotifier) EnqueueConnectionState(s ConnectionState) { h.Lock() defer h.Unlock() + select { + case <-h.done: + return + default: + } + notify := func() { + defer h.notifiers.Done() for { h.Lock() if len(h.connectionStates) == 0 { @@ -79,6 +105,7 @@ func (h *handlerNotifier) EnqueueConnectionState(s ConnectionState) { h.connectionStates = append(h.connectionStates, s) if !h.running { h.running = true + h.notifiers.Add(1) go notify() } } @@ -87,7 +114,14 @@ func (h *handlerNotifier) EnqueueCandidate(c Candidate) { h.Lock() defer h.Unlock() + select { + case <-h.done: + return + default: + } + notify := func() { + defer h.notifiers.Done() for { h.Lock() if len(h.candidates) == 0 { @@ -105,6 +139,7 @@ func (h *handlerNotifier) EnqueueCandidate(c Candidate) { h.candidates = append(h.candidates, c) if !h.running { h.running = true + h.notifiers.Add(1) go notify() } } @@ -113,7 +148,14 @@ func (h *handlerNotifier) EnqueueSelectedCandidatePair(p *CandidatePair) { h.Lock() defer h.Unlock() + select { + case <-h.done: + return + default: + } + notify := func() { + defer h.notifiers.Done() for { h.Lock() if len(h.selectedCandidatePairs) == 0 { @@ -131,6 +173,7 @@ func (h *handlerNotifier) EnqueueSelectedCandidatePair(p *CandidatePair) { h.selectedCandidatePairs = append(h.selectedCandidatePairs, p) if !h.running { h.running = true + h.notifiers.Add(1) go notify() } } diff --git a/agent_handlers_test.go b/agent_handlers_test.go index 66ff8048..a4741816 100644 --- a/agent_handlers_test.go +++ b/agent_handlers_test.go @@ -19,6 +19,7 @@ func TestConnectionStateNotifier(t *testing.T) { connectionStateFunc: func(_ ConnectionState) { updates <- struct{}{} }, + done: make(chan struct{}), } // Enqueue all updates upfront to ensure that it // doesn't block @@ -38,6 +39,7 @@ func TestConnectionStateNotifier(t *testing.T) { close(done) }() <-done + c.Close() }) t.Run("TestUpdateOrdering", func(t *testing.T) { report := test.CheckRoutines(t) @@ -47,6 +49,7 @@ func TestConnectionStateNotifier(t *testing.T) { connectionStateFunc: func(cs ConnectionState) { updates <- cs }, + done: make(chan struct{}), } done := make(chan struct{}) go func() { @@ -67,5 +70,6 @@ func TestConnectionStateNotifier(t *testing.T) { c.EnqueueConnectionState(ConnectionState(i)) } <-done + c.Close() }) } diff --git a/agent_test.go b/agent_test.go index d41dc36b..369f8617 100644 --- a/agent_test.go +++ b/agent_test.go @@ -1423,11 +1423,12 @@ func TestCloseInConnectionStateCallback(t *testing.T) { isClosed := make(chan interface{}) isConnected := make(chan interface{}) + connectionStateConnectedSeen := make(chan interface{}) err = aAgent.OnConnectionStateChange(func(c ConnectionState) { switch c { case ConnectionStateConnected: <-isConnected - assert.NoError(t, aAgent.Close()) + close(connectionStateConnectedSeen) case ConnectionStateClosed: close(isClosed) default: @@ -1439,6 +1440,8 @@ func TestCloseInConnectionStateCallback(t *testing.T) { connect(aAgent, bAgent) close(isConnected) + <-connectionStateConnectedSeen + require.NoError(t, aAgent.Close()) <-isClosed assert.NoError(t, bAgent.Close())