Skip to content

Commit

Permalink
Taskloop Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Sean-Der committed Mar 21, 2024
1 parent d17be4d commit 0c22222
Show file tree
Hide file tree
Showing 10 changed files with 243 additions and 248 deletions.
218 changes: 60 additions & 158 deletions agent.go

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions agent_on_selected_candidate_pair_change_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ func TestOnSelectedCandidatePairChange(t *testing.T) {
})
require.NoError(t, err)

err = agent.run(context.Background(), func(_ context.Context, agent *Agent) {
agent.setSelectedPair(candidatePair)
err = agent.loop.Run(func(_ context.Context) {
agent.setSelectedPair(candidatePair) // nolint: contextcheck
})
require.NoError(t, err)

Expand Down
18 changes: 9 additions & 9 deletions agent_stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ import (
// GetCandidatePairsStats returns a list of candidate pair stats
func (a *Agent) GetCandidatePairsStats() []CandidatePairStats {
var res []CandidatePairStats
err := a.run(a.context(), func(_ context.Context, agent *Agent) {
result := make([]CandidatePairStats, 0, len(agent.checklist))
for _, cp := range agent.checklist {
err := a.loop.Run(func(_ context.Context) {
result := make([]CandidatePairStats, 0, len(a.checklist))
for _, cp := range a.checklist {
stat := CandidatePairStats{
Timestamp: time.Now(),
LocalCandidateID: cp.Local.ID(),
Expand Down Expand Up @@ -57,9 +57,9 @@ func (a *Agent) GetCandidatePairsStats() []CandidatePairStats {
// GetLocalCandidatesStats returns a list of local candidates stats
func (a *Agent) GetLocalCandidatesStats() []CandidateStats {
var res []CandidateStats
err := a.run(a.context(), func(_ context.Context, agent *Agent) {
result := make([]CandidateStats, 0, len(agent.localCandidates))
for networkType, localCandidates := range agent.localCandidates {
err := a.loop.Run(func(_ context.Context) {
result := make([]CandidateStats, 0, len(a.localCandidates))
for networkType, localCandidates := range a.localCandidates {
for _, c := range localCandidates {
relayProtocol := ""
if c.Type() == CandidateTypeRelay {
Expand Down Expand Up @@ -94,9 +94,9 @@ func (a *Agent) GetLocalCandidatesStats() []CandidateStats {
// GetRemoteCandidatesStats returns a list of remote candidates stats
func (a *Agent) GetRemoteCandidatesStats() []CandidateStats {
var res []CandidateStats
err := a.run(a.context(), func(_ context.Context, agent *Agent) {
result := make([]CandidateStats, 0, len(agent.remoteCandidates))
for networkType, remoteCandidates := range agent.remoteCandidates {
err := a.loop.Run(func(_ context.Context) {
result := make([]CandidateStats, 0, len(a.remoteCandidates))
for networkType, remoteCandidates := range a.remoteCandidates {
for _, c := range remoteCandidates {
stat := CandidateStats{
Timestamp: time.Now(),
Expand Down
56 changes: 29 additions & 27 deletions agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,6 @@ func (ba *BadAddr) String() string {
return "yyy"
}

func runAgentTest(t *testing.T, config *AgentConfig, task func(ctx context.Context, a *Agent)) {
a, err := NewAgent(config)
if err != nil {
t.Fatalf("Error constructing ice.Agent")
}

if err := a.run(context.Background(), task); err != nil {
t.Fatalf("Agent run failure: %v", err)
}

assert.NoError(t, a.Close())
}

func TestHandlePeerReflexive(t *testing.T) {
report := test.CheckRoutines(t)
defer report()
Expand All @@ -56,8 +43,10 @@ func TestHandlePeerReflexive(t *testing.T) {
defer lim.Stop()

t.Run("UDP prflx candidate from handleInbound()", func(t *testing.T) {
var config AgentConfig
runAgentTest(t, &config, func(_ context.Context, a *Agent) {
a, err := NewAgent(&AgentConfig{})
assert.NoError(t, err)

assert.NoError(t, a.loop.Run(func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log}

hostConfig := CandidateHostConfig{
Expand Down Expand Up @@ -113,12 +102,15 @@ func TestHandlePeerReflexive(t *testing.T) {
if c.Port() != 999 {
t.Fatal("Port number mismatch")
}
})
}))
assert.NoError(t, a.Close())
})

t.Run("Bad network type with handleInbound()", func(t *testing.T) {
var config AgentConfig
runAgentTest(t, &config, func(_ context.Context, a *Agent) {
a, err := NewAgent(&AgentConfig{})
assert.NoError(t, err)

assert.NoError(t, a.loop.Run(func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log}

hostConfig := CandidateHostConfig{
Expand All @@ -140,12 +132,16 @@ func TestHandlePeerReflexive(t *testing.T) {
if len(a.remoteCandidates) != 0 {
t.Fatal("bad address should not be added to the remote candidate list")
}
})
}))

assert.NoError(t, a.Close())
})

t.Run("Success from unknown remote, prflx candidate MUST only be created via Binding Request", func(t *testing.T) {
var config AgentConfig
runAgentTest(t, &config, func(_ context.Context, a *Agent) {
a, err := NewAgent(&AgentConfig{})
assert.NoError(t, err)

assert.NoError(t, a.loop.Run(func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log}
tID := [stun.TransactionIDSize]byte{}
copy(tID[:], "ABC")
Expand Down Expand Up @@ -179,7 +175,9 @@ func TestHandlePeerReflexive(t *testing.T) {
if len(a.remoteCandidates) != 0 {
t.Fatal("unknown remote was able to create a candidate")
}
})
}))

assert.NoError(t, a.Close())
})
}

Expand Down Expand Up @@ -440,7 +438,7 @@ func TestInboundValidity(t *testing.T) {
t.Fatalf("Error constructing ice.Agent")
}

err = a.run(context.Background(), func(_ context.Context, a *Agent) {
err = a.loop.Run(func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log}
// nolint: contextcheck
a.handleInbound(buildMsg(stun.ClassRequest, a.localUfrag+":"+a.remoteUfrag, a.localPwd), local, remote)
Expand All @@ -454,8 +452,10 @@ func TestInboundValidity(t *testing.T) {
})

t.Run("Valid bind without fingerprint", func(t *testing.T) {
var config AgentConfig
runAgentTest(t, &config, func(_ context.Context, a *Agent) {
a, err := NewAgent(&AgentConfig{})
assert.NoError(t, err)

assert.NoError(t, a.loop.Run(func(_ context.Context) {
a.selector = &controllingSelector{agent: a, log: a.log}
msg, err := stun.Build(stun.BindingRequest, stun.TransactionID,
stun.NewUsername(a.localUfrag+":"+a.remoteUfrag),
Expand All @@ -470,7 +470,9 @@ func TestInboundValidity(t *testing.T) {
if len(a.remoteCandidates) != 1 {
t.Fatal("Binding with valid values (but no fingerprint) was unable to create prflx candidate")
}
})
}))

assert.NoError(t, a.Close())
})

t.Run("Success with invalid TransactionID", func(t *testing.T) {
Expand Down Expand Up @@ -1120,7 +1122,7 @@ func TestConnectionStateFailedDeleteAllCandidates(t *testing.T) {
<-isFailed

done := make(chan struct{})
assert.NoError(t, aAgent.run(context.Background(), func(context.Context, *Agent) {
assert.NoError(t, aAgent.loop.Run(func(context.Context) {
assert.Equal(t, len(aAgent.remoteCandidates), 0)
assert.Equal(t, len(aAgent.localCandidates), 0)
close(done)
Expand Down
2 changes: 1 addition & 1 deletion candidate_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ func (c *candidateBase) handleInboundPacket(buf []byte, srcAddr net.Addr) {
return
}

if err := a.run(c, func(_ context.Context, a *Agent) {
if err := a.loop.Run(func(_ context.Context) {
// nolint: contextcheck
a.handleInbound(m, c, srcAddr)
}); err != nil {
Expand Down
40 changes: 0 additions & 40 deletions context.go

This file was deleted.

4 changes: 2 additions & 2 deletions gather.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func closeConnAndLog(c io.Closer, log logging.LeveledLogger, msg string, args ..
func (a *Agent) GatherCandidates() error {
var gatherErr error

if runErr := a.run(a.context(), func(ctx context.Context, _ *Agent) {
if runErr := a.loop.Run(func(ctx context.Context) {
if a.gatheringState != GatheringStateNew {
gatherErr = ErrMultipleGatherAttempted
return
Expand Down Expand Up @@ -495,7 +495,7 @@ func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*stun.URI, net
select {
case <-cancelCtx.Done():
return
case <-a.done:
case <-a.loop.Done():
_ = conn.Close()
}
}()
Expand Down
131 changes: 131 additions & 0 deletions internal/taskloop/taskloop.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

// Package taskloop implements a task loop to run
// tasks sequentially in a separate Goroutine.
package taskloop

import (
"context"
"errors"
"time"

atomicx "github.com/pion/ice/v3/internal/atomic"
)

// ErrClosed indicates that the loop has been stopped
var ErrClosed = errors.New("task loop has been stopped")

type task struct {
fn func(context.Context)
done chan struct{}
}

// Loop runs submitted task serially in a dedicated Goroutine
type Loop struct {
tasks chan task
runAfter func()
done chan struct{}
taskLoopDone chan struct{}
err atomicx.Error
}

// NewLoop creates and starts a new task loop
func NewLoop(runAfter func()) *Loop {
l := &Loop{
tasks: make(chan task),
done: make(chan struct{}),
taskLoopDone: make(chan struct{}),
runAfter: runAfter,
}

go l.runLoop()

return l
}

// Close stops the loop after finishing the execution of the current task.
// Other pending tasks will not be executed.
func (l *Loop) Close() error {
l.err.Store(ErrClosed)

close(l.done)
<-l.taskLoopDone

return nil
}

// RunContext serially executes the submitted callback.
// Blocking tasks must be cancelable by context.
func (l *Loop) RunContext(ctx context.Context, cb func(context.Context)) error {
if err := l.Ok(); err != nil {
return err
}

done := make(chan struct{})
select {
case <-ctx.Done():
return ctx.Err()
case l.tasks <- task{cb, done}:
<-done
return nil
}
}

// Run serially executes the submitted callback.
func (l *Loop) Run(cb func(context.Context)) error {
return l.RunContext(l, cb)
}

// Ok waits for the next task to complete and checks then if the loop is still running.
func (l *Loop) Ok() error {
select {
case <-l.done:
return l.Err()
default:
}

return nil
}

// runLoop handles registered tasks and agent close.
func (l *Loop) runLoop() {
for {
select {
case <-l.done:
l.runAfter()
close(l.taskLoopDone)
return
case t := <-l.tasks:
t.fn(l)
close(t.done)
}
}
}

// The following methods implement context.Context for TaskLoop

// Deadline returns the no valid time as task loops have no deadline.
func (l *Loop) Deadline() (deadline time.Time, ok bool) {
return time.Time{}, false
}

// Done returns a channel that's closed when the task loop has been stopped.
func (l *Loop) Done() <-chan struct{} {
return l.done
}

// Err returns nil if the task loop is still running.
// Otherwise it return ErrClosed if the loop has been closed/stopped.
func (l *Loop) Err() error {
if err := l.err.Load(); err != nil {
return err
}

return ErrClosed
}

// Value is not supported for task loops
func (l *Loop) Value(_ interface{}) interface{} {
return nil
}
Loading

0 comments on commit 0c22222

Please sign in to comment.