diff --git a/agent.go b/agent.go
index 7aac7179..c8e0771b 100644
--- a/agent.go
+++ b/agent.go
@@ -15,8 +15,8 @@ import (
     "sync/atomic"
     "time"
 
-    atomicx "github.com/pion/ice/v3/internal/atomic"
     stunx "github.com/pion/ice/v3/internal/stun"
+    "github.com/pion/ice/v3/internal/taskloop"
     "github.com/pion/logging"
     "github.com/pion/mdns/v2"
     "github.com/pion/stun/v2"
@@ -36,9 +36,7 @@ type bindingRequest struct {
 
 // Agent represents the ICE agent
 type Agent struct {
-    chanTask   chan task
-    afterRunFn []func(ctx context.Context)
-    muAfterRun sync.Mutex
+    loop *taskloop.Loop
 
     onConnectionStateChangeHdlr       atomic.Value // func(ConnectionState)
     onSelectedCandidatePairChangeHdlr atomic.Value // func(Candidate, Candidate)
@@ -120,11 +118,6 @@ type Agent struct {
     // 1:1 D-NAT IP address mapping
     extIPMapper *externalIPMapper
 
-    // State for closing
-    done         chan struct{}
-    taskLoopDone chan struct{}
-    err          atomicx.Error
-
     gatherCandidateCancel func()
     gatherCandidateDone   chan struct{}
 
@@ -149,99 +142,6 @@ type Agent struct {
     proxyDialer proxy.Dialer
 }
 
-type task struct {
-    fn   func(context.Context, *Agent)
-    done chan struct{}
-}
-
-// afterRun registers function to be run after the task.
-func (a *Agent) afterRun(f func(context.Context)) {
-    a.muAfterRun.Lock()
-    a.afterRunFn = append(a.afterRunFn, f)
-    a.muAfterRun.Unlock()
-}
-
-func (a *Agent) getAfterRunFn() []func(context.Context) {
-    a.muAfterRun.Lock()
-    defer a.muAfterRun.Unlock()
-    fns := a.afterRunFn
-    a.afterRunFn = nil
-    return fns
-}
-
-func (a *Agent) ok() error {
-    select {
-    case <-a.done:
-        return a.getErr()
-    default:
-    }
-    return nil
-}
-
-func (a *Agent) getErr() error {
-    if err := a.err.Load(); err != nil {
-        return err
-    }
-    return ErrClosed
-}
-
-// Run task in serial. Blocking tasks must be cancelable by context.
-func (a *Agent) run(ctx context.Context, t func(context.Context, *Agent)) error {
-    if err := a.ok(); err != nil {
-        return err
-    }
-    done := make(chan struct{})
-    select {
-    case <-ctx.Done():
-        return ctx.Err()
-    case a.chanTask <- task{t, done}:
-        <-done
-        return nil
-    }
-}
-
-// taskLoop handles registered tasks and agent close.
-func (a *Agent) taskLoop() {
-    after := func() {
-        for {
-            // Get and run func registered by afterRun().
-            fns := a.getAfterRunFn()
-            if len(fns) == 0 {
-                break
-            }
-            for _, fn := range fns {
-                fn(a.context())
-            }
-        }
-    }
-    defer func() {
-        a.deleteAllCandidates()
-        a.startedFn()
-
-        if err := a.buf.Close(); err != nil {
-            a.log.Warnf("Failed to close buffer: %v", err)
-        }
-
-        a.closeMulticastConn()
-        a.updateConnectionState(ConnectionStateClosed)
-
-        after()
-
-        close(a.taskLoopDone)
-    }()
-
-    for {
-        select {
-        case <-a.done:
-            return
-        case t := <-a.chanTask:
-            t.fn(a.context(), a)
-            close(t.done)
-            after()
-        }
-    }
-}
-
 // NewAgent creates a new Agent
 func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
     var err error
@@ -274,7 +174,6 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
     startedCtx, startedFn := context.WithCancel(context.Background())
 
     a := &Agent{
-        chanTask:          make(chan task),
         tieBreaker:        globalMathRandomGenerator.Uint64(),
         lite:              config.Lite,
         gatheringState:    GatheringStateNew,
@@ -285,8 +184,6 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
         networkTypes:      config.NetworkTypes,
         onConnected:       make(chan struct{}),
         buf:               packetio.NewBuffer(),
-        done:              make(chan struct{}),
-        taskLoopDone:      make(chan struct{}),
         startedCh:         startedCtx.Done(),
         startedFn:         startedFn,
         portMin:           config.PortMin,
@@ -360,7 +257,24 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
         return nil, err
     }
 
-    go a.taskLoop()
+    a.loop = taskloop.NewLoop(func() {
+        a.deleteAllCandidates()
+        a.startedFn()
+
+        if err := a.buf.Close(); err != nil {
+            a.log.Warnf("Failed to close buffer: %v", err)
+        }
+
+        a.closeMulticastConn()
+        a.updateConnectionState(ConnectionStateClosed) // nolint: contextcheck
+
+        a.gatherCandidateCancel()
+        if a.gatherCandidateDone != nil {
+            <-a.gatherCandidateDone
+        }
+
+        a.removeUfragFromMux()
+    })
 
     // Restart is also used to initialize the agent for the first time
     if err := a.Restart(config.LocalUfrag, config.LocalPwd); err != nil {
@@ -386,10 +300,10 @@ func (a *Agent) startConnectivityChecks(isControlling bool, remoteUfrag, remoteP
     a.log.Debugf("Started agent: isControlling? %t, remoteUfrag: %q, remotePwd: %q", isControlling, remoteUfrag, remotePwd)
 
-    return a.run(a.context(), func(_ context.Context, agent *Agent) {
-        agent.isControlling = isControlling
-        agent.remoteUfrag = remoteUfrag
-        agent.remotePwd = remotePwd
+    return a.loop.Run(func(_ context.Context) {
+        a.isControlling = isControlling
+        a.remoteUfrag = remoteUfrag
+        a.remotePwd = remotePwd
 
         if isControlling {
             a.selector = &controllingSelector{agent: a, log: a.log}
@@ -404,7 +318,7 @@ func (a *Agent) startConnectivityChecks(isControlling bool, remoteUfrag, remoteP
         a.selector.Start()
         a.startedFn()
 
-        agent.updateConnectionState(ConnectionStateChecking)
+        a.updateConnectionState(ConnectionStateChecking) // nolint: contextcheck
 
         a.requestConnectivityCheck()
         go a.connectivityChecks() //nolint:contextcheck
@@ -416,7 +330,7 @@ func (a *Agent) connectivityChecks() {
     checkingDuration := time.Time{}
 
     contact := func() {
-        if err := a.run(a.context(), func(_ context.Context, a *Agent) {
+        if err := a.loop.Run(func(_ context.Context) {
             defer func() {
                 lastConnectionState = a.connectionState
             }()
@@ -434,7 +348,7 @@ func (a *Agent) connectivityChecks() {
 
                 // We have been in checking longer then Disconnect+Failed timeout, set the connection to Failed
                 if time.Since(checkingDuration) > a.disconnectedTimeout+a.failedTimeout {
-                    a.updateConnectionState(ConnectionStateFailed)
+                    a.updateConnectionState(ConnectionStateFailed) // nolint: contextcheck
                     return
                 }
             default:
@@ -473,7 +387,7 @@ func (a *Agent) connectivityChecks() {
             contact()
         case <-t.C:
             contact()
-        case <-a.done:
+        case <-a.loop.Done():
             t.Stop()
             return
         }
@@ -665,9 +579,9 @@ func (a *Agent) AddRemoteCandidate(c Candidate) error {
     }
 
     go func() {
-        if err := a.run(a.context(), func(_ context.Context, agent *Agent) {
+        if err := a.loop.Run(func(_ context.Context) { // nolint: contextcheck
-            agent.addRemoteCandidate(c)
+            a.addRemoteCandidate(c)
         }); err != nil {
             a.log.Warnf("Failed to add remote candidate %s: %v", c.Address(), err)
             return
@@ -697,9 +611,9 @@ func (a *Agent) resolveAndAddMulticastCandidate(c *CandidateHost) {
         return
     }
 
-    if err = a.run(a.context(), func(_ context.Context, agent *Agent) {
+    if err = a.loop.Run(func(_ context.Context) { // nolint: contextcheck
-        agent.addRemoteCandidate(c)
+        a.addRemoteCandidate(c)
     }); err != nil {
         a.log.Warnf("Failed to add mDNS candidate %s: %v", c.Address(), err)
         return
@@ -722,7 +636,7 @@ func (a *Agent) addRemotePassiveTCPCandidate(remoteCandidate Candidate) {
 
     for i := range localIPs {
         conn := newActiveTCPConn(
-            a.context(),
+            a.loop,
             net.JoinHostPort(localIPs[i].String(), "0"),
             net.JoinHostPort(remoteCandidate.Address(), strconv.Itoa(remoteCandidate.Port())),
             a.log,
@@ -790,7 +704,7 @@ func (a *Agent) addRemoteCandidate(c Candidate) {
 }
 
 func (a *Agent) addCandidate(ctx context.Context, c Candidate, candidateConn net.PacketConn) error {
-    return a.run(ctx, func(context.Context, *Agent) {
+    return a.loop.RunContext(ctx, func(context.Context) {
         set := a.localCandidates[c.NetworkType()]
         for _, candidate := range set {
             if candidate.Equal(c) {
@@ -826,9 +740,9 @@ func (a *Agent) addCandidate(ctx context.Context, c Candidate, candidateConn net
 func (a *Agent) GetRemoteCandidates() ([]Candidate, error) {
     var res []Candidate
 
-    err := a.run(a.context(), func(_ context.Context, agent *Agent) {
+    err := a.loop.Run(func(_ context.Context) {
         var candidates []Candidate
-        for _, set := range agent.remoteCandidates {
+        for _, set := range a.remoteCandidates {
             candidates = append(candidates, set...)
         }
         res = candidates
@@ -844,9 +758,9 @@ func (a *Agent) GetRemoteCandidates() ([]Candidate, error) {
 func (a *Agent) GetLocalCandidates() ([]Candidate, error) {
     var res []Candidate
 
-    err := a.run(a.context(), func(_ context.Context, agent *Agent) {
+    err := a.loop.Run(func(_ context.Context) {
         var candidates []Candidate
-        for _, set := range agent.localCandidates {
+        for _, set := range a.localCandidates {
             candidates = append(candidates, set...)
         }
         res = candidates
@@ -861,9 +775,9 @@ func (a *Agent) GetLocalCandidates() ([]Candidate, error) {
 // GetLocalUserCredentials returns the local user credentials
 func (a *Agent) GetLocalUserCredentials() (frag string, pwd string, err error) {
     valSet := make(chan struct{})
-    err = a.run(a.context(), func(_ context.Context, agent *Agent) {
-        frag = agent.localUfrag
-        pwd = agent.localPwd
+    err = a.loop.Run(func(_ context.Context) {
+        frag = a.localUfrag
+        pwd = a.localPwd
         close(valSet)
     })
 
@@ -876,9 +790,9 @@ func (a *Agent) GetLocalUserCredentials() (frag string, pwd string, err error) {
 // GetRemoteUserCredentials returns the remote user credentials
 func (a *Agent) GetRemoteUserCredentials() (frag string, pwd string, err error) {
     valSet := make(chan struct{})
-    err = a.run(a.context(), func(_ context.Context, agent *Agent) {
-        frag = agent.remoteUfrag
-        pwd = agent.remotePwd
+    err = a.loop.Run(func(_ context.Context) {
+        frag = a.remoteUfrag
+        pwd = a.remotePwd
         close(valSet)
     })
 
@@ -902,23 +816,11 @@ func (a *Agent) removeUfragFromMux() {
 
 // Close cleans up the Agent
 func (a *Agent) Close() error {
-    if err := a.ok(); err != nil {
+    if err := a.loop.Ok(); err != nil {
         return err
     }
 
-    a.afterRun(func(context.Context) {
-        a.gatherCandidateCancel()
-        if a.gatherCandidateDone != nil {
-            <-a.gatherCandidateDone
-        }
-    })
-    a.err.Store(ErrClosed)
-
-    a.removeUfragFromMux()
-
-    close(a.done)
-    <-a.taskLoopDone
-    return nil
+    return a.loop.Close()
 }
 
 // Remove all candidates. This closes any listening sockets
@@ -1125,7 +1027,7 @@ func (a *Agent) handleInbound(m *stun.Message, local Candidate, remote net.Addr)
 // and returns true if it is an actual remote candidate
 func (a *Agent) validateNonSTUNTraffic(local Candidate, remote net.Addr) (Candidate, bool) {
     var remoteCandidate Candidate
-    if err := a.run(local.context(), func(context.Context, *Agent) {
+    if err := a.loop.Run(func(context.Context) {
         remoteCandidate = a.findRemoteCandidate(local.NetworkType(), remote)
         if remoteCandidate != nil {
             remoteCandidate.seen(false)
@@ -1182,9 +1084,9 @@ func (a *Agent) SetRemoteCredentials(remoteUfrag, remotePwd string) error {
         return ErrRemotePwdEmpty
     }
 
-    return a.run(a.context(), func(_ context.Context, agent *Agent) {
-        agent.remoteUfrag = remoteUfrag
-        agent.remotePwd = remotePwd
+    return a.loop.Run(func(_ context.Context) {
+        a.remoteUfrag = remoteUfrag
+        a.remotePwd = remotePwd
     })
 }
 
@@ -1219,21 +1121,21 @@ func (a *Agent) Restart(ufrag, pwd string) error {
     }
 
     var err error
-    if runErr := a.run(a.context(), func(_ context.Context, agent *Agent) {
-        if agent.gatheringState == GatheringStateGathering {
-            agent.gatherCandidateCancel()
+    if runErr := a.loop.Run(func(_ context.Context) {
+        if a.gatheringState == GatheringStateGathering {
+            a.gatherCandidateCancel()
         }
         // Clear all agent needed to take back to fresh state
         a.removeUfragFromMux()
-        agent.localUfrag = ufrag
-        agent.localPwd = pwd
-        agent.remoteUfrag = ""
-        agent.remotePwd = ""
+        a.localUfrag = ufrag
+        a.localPwd = pwd
+        a.remoteUfrag = ""
+        a.remotePwd = ""
         a.gatheringState = GatheringStateNew
         a.checklist = make([]*CandidatePair, 0)
         a.pendingBindingRequests = make([]bindingRequest, 0)
-        a.setSelectedPair(nil)
+        a.setSelectedPair(nil) // nolint: contextcheck
         a.deleteAllCandidates()
         if a.selector != nil {
             a.selector.Start()
         }
@@ -1242,7 +1144,7 @@ func (a *Agent) Restart(ufrag, pwd string) error {
         // Restart is used by NewAgent. Accept/Connect should be used to move to checking
         // for new Agents
         if a.connectionState != ConnectionStateNew {
-            a.updateConnectionState(ConnectionStateChecking)
+            a.updateConnectionState(ConnectionStateChecking) // nolint: contextcheck
         }
     }); runErr != nil {
         return runErr
@@ -1252,7 +1154,7 @@ func (a *Agent) Restart(ufrag, pwd string) error {
 
 func (a *Agent) setGatheringState(newState GatheringState) error {
     done := make(chan struct{})
-    if err := a.run(a.context(), func(context.Context, *Agent) {
+    if err := a.loop.Run(func(context.Context) {
         if a.gatheringState != newState && newState == GatheringStateComplete {
             a.candidateNotifier.EnqueueCandidate(nil)
         }
diff --git a/agent_on_selected_candidate_pair_change_test.go b/agent_on_selected_candidate_pair_change_test.go
index 63b35ed7..23e6eed2 100644
--- a/agent_on_selected_candidate_pair_change_test.go
+++ b/agent_on_selected_candidate_pair_change_test.go
@@ -22,8 +22,8 @@ func TestOnSelectedCandidatePairChange(t *testing.T) {
     })
     require.NoError(t, err)
 
-    err = agent.run(context.Background(), func(_ context.Context, agent *Agent) {
-        agent.setSelectedPair(candidatePair)
+    err = agent.loop.Run(func(_ context.Context) {
+        agent.setSelectedPair(candidatePair) // nolint: contextcheck
     })
     require.NoError(t, err)
diff --git a/agent_stats.go b/agent_stats.go
index 9582cac0..882f201c 100644
--- a/agent_stats.go
+++ b/agent_stats.go
@@ -11,9 +11,9 @@ import (
 // GetCandidatePairsStats returns a list of candidate pair stats
 func (a *Agent) GetCandidatePairsStats() []CandidatePairStats {
     var res []CandidatePairStats
-    err := a.run(a.context(), func(_ context.Context, agent *Agent) {
-        result := make([]CandidatePairStats, 0, len(agent.checklist))
-        for _, cp := range agent.checklist {
+    err := a.loop.Run(func(_ context.Context) {
+        result := make([]CandidatePairStats, 0, len(a.checklist))
+        for _, cp := range a.checklist {
             stat := CandidatePairStats{
                 Timestamp:        time.Now(),
                 LocalCandidateID: cp.Local.ID(),
@@ -57,9 +57,9 @@ func (a *Agent) GetCandidatePairsStats() []CandidatePairStats {
 // GetLocalCandidatesStats returns a list of local candidates stats
 func (a *Agent) GetLocalCandidatesStats() []CandidateStats {
     var res []CandidateStats
-    err := a.run(a.context(), func(_ context.Context, agent *Agent) {
-        result := make([]CandidateStats, 0, len(agent.localCandidates))
-        for networkType, localCandidates := range agent.localCandidates {
+    err := a.loop.Run(func(_ context.Context) {
+        result := make([]CandidateStats, 0, len(a.localCandidates))
+        for networkType, localCandidates := range a.localCandidates {
             for _, c := range localCandidates {
                 relayProtocol := ""
                 if c.Type() == CandidateTypeRelay {
@@ -94,9 +94,9 @@ func (a *Agent) GetLocalCandidatesStats() []CandidateStats {
 // GetRemoteCandidatesStats returns a list of remote candidates stats
 func (a *Agent) GetRemoteCandidatesStats() []CandidateStats {
     var res []CandidateStats
-    err := a.run(a.context(), func(_ context.Context, agent *Agent) {
-        result := make([]CandidateStats, 0, len(agent.remoteCandidates))
-        for networkType, remoteCandidates := range agent.remoteCandidates {
+    err := a.loop.Run(func(_ context.Context) {
+        result := make([]CandidateStats, 0, len(a.remoteCandidates))
+        for networkType, remoteCandidates := range a.remoteCandidates {
             for _, c := range remoteCandidates {
                 stat := CandidateStats{
                     Timestamp:   time.Now(),
diff --git a/agent_test.go b/agent_test.go
index f5ca54e5..a6bc82af 100644
--- a/agent_test.go
+++ b/agent_test.go
@@ -34,19 +34,6 @@ func (ba *BadAddr) String() string {
return "yyy" } -func runAgentTest(t *testing.T, config *AgentConfig, task func(ctx context.Context, a *Agent)) { - a, err := NewAgent(config) - if err != nil { - t.Fatalf("Error constructing ice.Agent") - } - - if err := a.run(context.Background(), task); err != nil { - t.Fatalf("Agent run failure: %v", err) - } - - assert.NoError(t, a.Close()) -} - func TestHandlePeerReflexive(t *testing.T) { report := test.CheckRoutines(t) defer report() @@ -56,8 +43,10 @@ func TestHandlePeerReflexive(t *testing.T) { defer lim.Stop() t.Run("UDP prflx candidate from handleInbound()", func(t *testing.T) { - var config AgentConfig - runAgentTest(t, &config, func(_ context.Context, a *Agent) { + a, err := NewAgent(&AgentConfig{}) + assert.NoError(t, err) + + assert.NoError(t, a.loop.Run(func(_ context.Context) { a.selector = &controllingSelector{agent: a, log: a.log} hostConfig := CandidateHostConfig{ @@ -113,12 +102,15 @@ func TestHandlePeerReflexive(t *testing.T) { if c.Port() != 999 { t.Fatal("Port number mismatch") } - }) + })) + assert.NoError(t, a.Close()) }) t.Run("Bad network type with handleInbound()", func(t *testing.T) { - var config AgentConfig - runAgentTest(t, &config, func(_ context.Context, a *Agent) { + a, err := NewAgent(&AgentConfig{}) + assert.NoError(t, err) + + assert.NoError(t, a.loop.Run(func(_ context.Context) { a.selector = &controllingSelector{agent: a, log: a.log} hostConfig := CandidateHostConfig{ @@ -140,12 +132,16 @@ func TestHandlePeerReflexive(t *testing.T) { if len(a.remoteCandidates) != 0 { t.Fatal("bad address should not be added to the remote candidate list") } - }) + })) + + assert.NoError(t, a.Close()) }) t.Run("Success from unknown remote, prflx candidate MUST only be created via Binding Request", func(t *testing.T) { - var config AgentConfig - runAgentTest(t, &config, func(_ context.Context, a *Agent) { + a, err := NewAgent(&AgentConfig{}) + assert.NoError(t, err) + + assert.NoError(t, a.loop.Run(func(_ context.Context) { a.selector = &controllingSelector{agent: a, log: a.log} tID := [stun.TransactionIDSize]byte{} copy(tID[:], "ABC") @@ -179,7 +175,9 @@ func TestHandlePeerReflexive(t *testing.T) { if len(a.remoteCandidates) != 0 { t.Fatal("unknown remote was able to create a candidate") } - }) + })) + + assert.NoError(t, a.Close()) }) } @@ -440,7 +438,7 @@ func TestInboundValidity(t *testing.T) { t.Fatalf("Error constructing ice.Agent") } - err = a.run(context.Background(), func(_ context.Context, a *Agent) { + err = a.loop.Run(func(_ context.Context) { a.selector = &controllingSelector{agent: a, log: a.log} // nolint: contextcheck a.handleInbound(buildMsg(stun.ClassRequest, a.localUfrag+":"+a.remoteUfrag, a.localPwd), local, remote) @@ -454,8 +452,10 @@ func TestInboundValidity(t *testing.T) { }) t.Run("Valid bind without fingerprint", func(t *testing.T) { - var config AgentConfig - runAgentTest(t, &config, func(_ context.Context, a *Agent) { + a, err := NewAgent(&AgentConfig{}) + assert.NoError(t, err) + + assert.NoError(t, a.loop.Run(func(_ context.Context) { a.selector = &controllingSelector{agent: a, log: a.log} msg, err := stun.Build(stun.BindingRequest, stun.TransactionID, stun.NewUsername(a.localUfrag+":"+a.remoteUfrag), @@ -470,7 +470,9 @@ func TestInboundValidity(t *testing.T) { if len(a.remoteCandidates) != 1 { t.Fatal("Binding with valid values (but no fingerprint) was unable to create prflx candidate") } - }) + })) + + assert.NoError(t, a.Close()) }) t.Run("Success with invalid TransactionID", func(t *testing.T) { @@ -1120,7 +1122,7 @@ func 
TestConnectionStateFailedDeleteAllCandidates(t *testing.T) { <-isFailed done := make(chan struct{}) - assert.NoError(t, aAgent.run(context.Background(), func(context.Context, *Agent) { + assert.NoError(t, aAgent.loop.Run(func(context.Context) { assert.Equal(t, len(aAgent.remoteCandidates), 0) assert.Equal(t, len(aAgent.localCandidates), 0) close(done) diff --git a/candidate_base.go b/candidate_base.go index 499b8729..dde84a4a 100644 --- a/candidate_base.go +++ b/candidate_base.go @@ -267,7 +267,7 @@ func (c *candidateBase) handleInboundPacket(buf []byte, srcAddr net.Addr) { return } - if err := a.run(c, func(_ context.Context, a *Agent) { + if err := a.loop.Run(func(_ context.Context) { // nolint: contextcheck a.handleInbound(m, c, srcAddr) }); err != nil { diff --git a/context.go b/context.go deleted file mode 100644 index 36454450..00000000 --- a/context.go +++ /dev/null @@ -1,40 +0,0 @@ -// SPDX-FileCopyrightText: 2023 The Pion community -// SPDX-License-Identifier: MIT - -package ice - -import ( - "context" - "time" -) - -func (a *Agent) context() context.Context { - return agentContext(a.done) -} - -type agentContext chan struct{} - -// Done implements context.Context -func (a agentContext) Done() <-chan struct{} { - return (chan struct{})(a) -} - -// Err implements context.Context -func (a agentContext) Err() error { - select { - case <-(chan struct{})(a): - return ErrRunCanceled - default: - return nil - } -} - -// Deadline implements context.Context -func (a agentContext) Deadline() (deadline time.Time, ok bool) { - return time.Time{}, false -} - -// Value implements context.Context -func (a agentContext) Value(interface{}) interface{} { - return nil -} diff --git a/gather.go b/gather.go index 507b7fc4..d5aea89e 100644 --- a/gather.go +++ b/gather.go @@ -42,7 +42,7 @@ func closeConnAndLog(c io.Closer, log logging.LeveledLogger, msg string, args .. func (a *Agent) GatherCandidates() error { var gatherErr error - if runErr := a.run(a.context(), func(ctx context.Context, _ *Agent) { + if runErr := a.loop.Run(func(ctx context.Context) { if a.gatheringState != GatheringStateNew { gatherErr = ErrMultipleGatherAttempted return @@ -495,7 +495,7 @@ func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, net select { case <-cancelCtx.Done(): return - case <-a.done: + case <-a.loop.Done(): _ = conn.Close() } }() diff --git a/internal/taskloop/taskloop.go b/internal/taskloop/taskloop.go new file mode 100644 index 00000000..348efbf6 --- /dev/null +++ b/internal/taskloop/taskloop.go @@ -0,0 +1,131 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +// Package taskloop implements a task loop to run +// tasks sequentially in a separate Goroutine. 
+package taskloop
+
+import (
+    "context"
+    "errors"
+    "time"
+
+    atomicx "github.com/pion/ice/v3/internal/atomic"
+)
+
+// ErrClosed indicates that the loop has been stopped
+var ErrClosed = errors.New("task loop has been stopped")
+
+type task struct {
+    fn   func(context.Context)
+    done chan struct{}
+}
+
+// Loop runs submitted tasks serially in a dedicated goroutine
+type Loop struct {
+    tasks        chan task
+    runAfter     func()
+    done         chan struct{}
+    taskLoopDone chan struct{}
+    err          atomicx.Error
+}
+
+// NewLoop creates and starts a new task loop
+func NewLoop(runAfter func()) *Loop {
+    l := &Loop{
+        tasks:        make(chan task),
+        done:         make(chan struct{}),
+        taskLoopDone: make(chan struct{}),
+        runAfter:     runAfter,
+    }
+
+    go l.runLoop()
+
+    return l
+}
+
+// Close stops the loop after finishing the execution of the current task.
+// Other pending tasks will not be executed.
+func (l *Loop) Close() error {
+    l.err.Store(ErrClosed)
+
+    close(l.done)
+    <-l.taskLoopDone
+
+    return nil
+}
+
+// RunContext serially executes the submitted callback.
+// Blocking tasks must be cancelable by context.
+func (l *Loop) RunContext(ctx context.Context, cb func(context.Context)) error {
+    if err := l.Ok(); err != nil {
+        return err
+    }
+
+    done := make(chan struct{})
+    select {
+    case <-ctx.Done():
+        return ctx.Err()
+    case l.tasks <- task{cb, done}:
+        <-done
+        return nil
+    }
+}
+
+// Run serially executes the submitted callback.
+func (l *Loop) Run(cb func(context.Context)) error {
+    return l.RunContext(l, cb)
+}
+
+// Ok returns an error if the loop has already been stopped; otherwise it returns nil.
+func (l *Loop) Ok() error {
+    select {
+    case <-l.done:
+        return l.Err()
+    default:
+    }
+
+    return nil
+}
+
+// runLoop executes submitted tasks and runs the cleanup callback once the loop is closed.
+func (l *Loop) runLoop() {
+    for {
+        select {
+        case <-l.done:
+            l.runAfter()
+            close(l.taskLoopDone)
+            return
+        case t := <-l.tasks:
+            t.fn(l)
+            close(t.done)
+        }
+    }
+}
+
+// The following methods implement context.Context for Loop
+
+// Deadline returns the zero time, as task loops have no deadline.
+func (l *Loop) Deadline() (deadline time.Time, ok bool) {
+    return time.Time{}, false
+}
+
+// Done returns a channel that's closed when the task loop has been stopped.
+func (l *Loop) Done() <-chan struct{} {
+    return l.done
+}
+
+// Err returns the error stored when the loop was stopped,
+// or ErrClosed if no error was stored.
+func (l *Loop) Err() error {
+    if err := l.err.Load(); err != nil {
+        return err
+    }
+
+    return ErrClosed
+}
+
+// Value is not supported for task loops and always returns nil.
+func (l *Loop) Value(_ interface{}) interface{} {
+    return nil
+}
diff --git a/transport.go b/transport.go
index 9c30a827..d2d76f47 100644
--- a/transport.go
+++ b/transport.go
@@ -43,7 +43,7 @@ func (c *Conn) BytesReceived() uint64 {
 }
 
 func (a *Agent) connect(ctx context.Context, isControlling bool, remoteUfrag, remotePwd string) (*Conn, error) {
-    err := a.ok()
+    err := a.loop.Ok()
     if err != nil {
         return nil, err
     }
@@ -54,8 +54,8 @@ func (a *Agent) connect(ctx context.Context, isControlling bool, remoteUfrag, re
 
     // Block until pair selected
     select {
-    case <-a.done:
-        return nil, a.getErr()
+    case <-a.loop.Done():
+        return nil, a.loop.Err()
     case <-ctx.Done():
         return nil, ErrCanceledByCaller
     case <-a.onConnected:
     }
@@ -68,7 +68,7 @@
 
 // Read implements the Conn Read method.
 func (c *Conn) Read(p []byte) (int, error) {
-    err := c.agent.ok()
+    err := c.agent.loop.Ok()
     if err != nil {
         return 0, err
     }
@@ -80,7 +80,7 @@ func (c *Conn) Read(p []byte) (int, error) {
 
 // Write implements the Conn Write method.
 func (c *Conn) Write(p []byte) (int, error) {
-    err := c.agent.ok()
+    err := c.agent.loop.Ok()
     if err != nil {
         return 0, err
     }
@@ -91,8 +91,8 @@ func (c *Conn) Write(p []byte) (int, error) {
 
     pair := c.agent.getSelectedPair()
     if pair == nil {
-        if err = c.agent.run(c.agent.context(), func(_ context.Context, a *Agent) {
-            pair = a.getBestValidCandidatePair()
+        if err = c.agent.loop.Run(func(_ context.Context) {
+            pair = c.agent.getBestValidCandidatePair()
         }); err != nil {
             return 0, err
         }
diff --git a/transport_test.go b/transport_test.go
index 5e967fd2..d6aaabf3 100644
--- a/transport_test.go
+++ b/transport_test.go
@@ -49,8 +49,8 @@ func testTimeout(t *testing.T, c *Conn, timeout time.Duration) {
 
     var cs ConnectionState
 
-    err := c.agent.run(context.Background(), func(_ context.Context, agent *Agent) {
-        cs = agent.connectionState
+    err := c.agent.loop.Run(func(_ context.Context) {
+        cs = c.agent.connectionState
     })
     if err != nil {
         // We should never get here.
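
Usage sketch (reviewer aid, not part of the patch): the snippet below drives the new taskloop API end to end, assuming only the exported surface introduced above (NewLoop, Run, RunContext, Close, Done, Err). Because the package lives under internal/, it is importable only from within pion/ice itself, so the main package here is purely hypothetical.

// Reviewer aid: exercises the new taskloop API. Hypothetical main package;
// compiles only inside the pion/ice module because taskloop is internal.
package main

import (
    "context"
    "fmt"
    "time"

    "github.com/pion/ice/v3/internal/taskloop"
)

func main() {
    // The callback passed to NewLoop runs exactly once, after Close,
    // mirroring the cleanup the Agent previously did in taskLoop's defer.
    loop := taskloop.NewLoop(func() {
        fmt.Println("loop stopped, cleanup ran")
    })

    // Tasks execute one at a time on the loop's goroutine, so state shared
    // between tasks needs no additional locking.
    if err := loop.Run(func(_ context.Context) {
        fmt.Println("task ran serially")
    }); err != nil {
        fmt.Println("submit failed:", err)
    }

    // RunContext lets a caller abandon a submission that cannot be queued
    // in time; the Loop itself also satisfies context.Context and is what
    // plain Run passes to RunContext.
    ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    defer cancel()
    if err := loop.RunContext(ctx, func(_ context.Context) {}); err != nil {
        fmt.Println("submit failed:", err)
    }

    // After Close, further submissions fail with ErrClosed.
    _ = loop.Close()
    if err := loop.Run(func(_ context.Context) {}); err != nil {
        fmt.Println(err) // task loop has been stopped
    }
}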