From def1670796f6af2dd505cbaf858f066e2cb8573e Mon Sep 17 00:00:00 2001 From: Eric Daniels Date: Thu, 25 Jul 2024 16:27:33 -0400 Subject: [PATCH] Clean up Agents in tests more aggressively --- active_tcp_test.go | 22 +- ..._get_best_available_candidate_pair_test.go | 17 +- agent_get_best_valid_candidate_pair_test.go | 5 +- ..._on_selected_candidate_pair_change_test.go | 4 +- agent_test.go | 249 +++++++++++++----- agent_udpmux_test.go | 16 ++ candidate_relay_test.go | 13 +- candidate_server_reflexive_test.go | 13 +- connectivity_vnet_test.go | 16 +- gather_test.go | 84 ++++-- gather_vnet_test.go | 48 ++-- go.mod | 2 +- go.sum | 4 +- mdns_test.go | 22 +- transport_test.go | 13 +- transport_vnet_test.go | 5 +- udp_mux_test.go | 4 +- 17 files changed, 361 insertions(+), 176 deletions(-) diff --git a/active_tcp_test.go b/active_tcp_test.go index 44d6a47c..fc51844b 100644 --- a/active_tcp_test.go +++ b/active_tcp_test.go @@ -71,7 +71,7 @@ func TestActiveTCP(t *testing.T) { networkTypes: []NetworkType{NetworkTypeTCP6}, listenIPAddress: getLocalIPAddress(t, NetworkTypeTCP6), selectedPairNetworkType: tcp, - // if we don't use mDNS, we will very liekly be filtering out location tracked ips. + // if we don't use mDNS, we will very likely be filtering out location tracked ips. useMDNS: true, }, testCase{ @@ -79,7 +79,7 @@ func TestActiveTCP(t *testing.T) { networkTypes: supportedNetworkTypes(), listenIPAddress: getLocalIPAddress(t, NetworkTypeTCP6), selectedPairNetworkType: udp, - // if we don't use mDNS, we will very liekly be filtering out location tracked ips. + // if we don't use mDNS, we will very likely be filtering out location tracked ips. useMDNS: true, }, ) @@ -143,6 +143,11 @@ func TestActiveTCP(t *testing.T) { r.NotNil(passiveAgentConn) r.NotNil(activeAgenConn) + defer func() { + r.NoError(activeAgenConn.Close()) + r.NoError(passiveAgentConn.Close()) + }() + pair := passiveAgent.getSelectedPair() r.NotNil(pair) r.Equal(testCase.selectedPairNetworkType, pair.Local.NetworkType().NetworkShort()) @@ -163,9 +168,6 @@ func TestActiveTCP(t *testing.T) { n, err = passiveAgentConn.Read(buffer) r.NoError(err) r.Equal(bar, buffer[:n]) - - r.NoError(activeAgenConn.Close()) - r.NoError(passiveAgentConn.Close()) }) } } @@ -185,9 +187,17 @@ func TestActiveTCP_NonBlocking(t *testing.T) { aAgent, err := NewAgent(cfg) require.NoError(t, err) + defer func() { + require.NoError(t, aAgent.Close()) + }() + bAgent, err := NewAgent(cfg) require.NoError(t, err) + defer func() { + require.NoError(t, bAgent.Close()) + }() + isConnected := make(chan interface{}) err = aAgent.OnConnectionStateChange(func(c ConnectionState) { if c == ConnectionStateConnected { @@ -205,6 +215,4 @@ func TestActiveTCP_NonBlocking(t *testing.T) { connect(aAgent, bAgent) <-isConnected - require.NoError(t, aAgent.Close()) - require.NoError(t, bAgent.Close()) } diff --git a/agent_get_best_available_candidate_pair_test.go b/agent_get_best_available_candidate_pair_test.go index 44c6a78e..c7bb2e44 100644 --- a/agent_get_best_available_candidate_pair_test.go +++ b/agent_get_best_available_candidate_pair_test.go @@ -13,19 +13,12 @@ import ( ) func TestNoBestAvailableCandidatePairAfterAgentConstruction(t *testing.T) { - agent := setupTest(t) - - require.Nil(t, agent.getBestAvailableCandidatePair()) - - tearDownTest(t, agent) -} - -func setupTest(t *testing.T) *Agent { agent, err := NewAgent(&AgentConfig{}) require.NoError(t, err) - return agent -} -func tearDownTest(t *testing.T, agent *Agent) { - require.NoError(t, agent.Close()) + defer func() { + require.NoError(t, agent.Close()) + }() + + require.Nil(t, agent.getBestAvailableCandidatePair()) } diff --git a/agent_get_best_valid_candidate_pair_test.go b/agent_get_best_valid_candidate_pair_test.go index 24e3c475..2ab20693 100644 --- a/agent_get_best_valid_candidate_pair_test.go +++ b/agent_get_best_valid_candidate_pair_test.go @@ -14,6 +14,9 @@ import ( func TestAgentGetBestValidCandidatePair(t *testing.T) { f := setupTestAgentGetBestValidCandidatePair(t) + defer func() { + require.NoError(t, f.sut.Close()) + }() remoteCandidatesFromLowestPriorityToHighest := []Candidate{f.relayRemote, f.srflxRemote, f.prflxRemote, f.hostRemote} @@ -26,8 +29,6 @@ func TestAgentGetBestValidCandidatePair(t *testing.T) { require.Equal(t, actualBestPair.String(), expectedBestPair.String()) } - - require.NoError(t, f.sut.Close()) } func setupTestAgentGetBestValidCandidatePair(t *testing.T) *TestAgentGetBestValidCandidatePairFixture { diff --git a/agent_on_selected_candidate_pair_change_test.go b/agent_on_selected_candidate_pair_change_test.go index b744749d..6ac21490 100644 --- a/agent_on_selected_candidate_pair_change_test.go +++ b/agent_on_selected_candidate_pair_change_test.go @@ -15,6 +15,9 @@ import ( func TestOnSelectedCandidatePairChange(t *testing.T) { agent, candidatePair := fixtureTestOnSelectedCandidatePairChange(t) + defer func() { + require.NoError(t, agent.Close()) + }() callbackCalled := make(chan struct{}, 1) err := agent.OnSelectedCandidatePairChange(func(_, _ Candidate) { @@ -28,7 +31,6 @@ func TestOnSelectedCandidatePairChange(t *testing.T) { require.NoError(t, err) <-callbackCalled - require.NoError(t, agent.Close()) } func fixtureTestOnSelectedCandidatePairChange(t *testing.T) (*Agent, *CandidatePair) { diff --git a/agent_test.go b/agent_test.go index 6f540350..401a7f14 100644 --- a/agent_test.go +++ b/agent_test.go @@ -42,6 +42,9 @@ func TestHandlePeerReflexive(t *testing.T) { t.Run("UDP prflx candidate from handleInbound()", func(t *testing.T) { a, err := NewAgent(&AgentConfig{}) require.NoError(t, err) + defer func() { + require.NoError(t, a.Close()) + }() require.NoError(t, a.loop.Run(a.loop, func(_ context.Context) { a.selector = &controllingSelector{agent: a, log: a.log} @@ -98,12 +101,14 @@ func TestHandlePeerReflexive(t *testing.T) { t.Fatal("Port number mismatch") } })) - require.NoError(t, a.Close()) }) t.Run("Bad network type with handleInbound()", func(t *testing.T) { a, err := NewAgent(&AgentConfig{}) require.NoError(t, err) + defer func() { + require.NoError(t, a.Close()) + }() require.NoError(t, a.loop.Run(a.loop, func(_ context.Context) { a.selector = &controllingSelector{agent: a, log: a.log} @@ -128,13 +133,14 @@ func TestHandlePeerReflexive(t *testing.T) { t.Fatal("bad address should not be added to the remote candidate list") } })) - - require.NoError(t, a.Close()) }) t.Run("Success from unknown remote, prflx candidate MUST only be created via Binding Request", func(t *testing.T) { a, err := NewAgent(&AgentConfig{}) require.NoError(t, err) + defer func() { + require.NoError(t, a.Close()) + }() require.NoError(t, a.loop.Run(a.loop, func(_ context.Context) { a.selector = &controllingSelector{agent: a, log: a.log} @@ -169,8 +175,6 @@ func TestHandlePeerReflexive(t *testing.T) { t.Fatal("unknown remote was able to create a candidate") } })) - - require.NoError(t, a.Close()) }) } @@ -216,6 +220,9 @@ func TestConnectivityOnStartup(t *testing.T) { aAgent, err := NewAgent(cfg0) require.NoError(t, err) + defer func() { + require.NoError(t, aAgent.Close()) + }() require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) cfg1 := &AgentConfig{ @@ -228,9 +235,12 @@ func TestConnectivityOnStartup(t *testing.T) { bAgent, err := NewAgent(cfg1) require.NoError(t, err) + defer func() { + require.NoError(t, bAgent.Close()) + }() require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) - aConn, bConn := func(aAgent, bAgent *Agent) (*Conn, *Conn) { + func(aAgent, bAgent *Agent) (*Conn, *Conn) { // Manual signaling aUfrag, aPwd, err := aAgent.GetLocalUserCredentials() require.NoError(t, err) @@ -280,7 +290,6 @@ func TestConnectivityOnStartup(t *testing.T) { <-bConnected require.NoError(t, wan.Stop()) - closePipe(t, aConn, bConn) } func TestConnectivityLite(t *testing.T) { @@ -315,6 +324,9 @@ func TestConnectivityLite(t *testing.T) { aAgent, err := NewAgent(cfg0) require.NoError(t, err) + defer func() { + require.NoError(t, aAgent.Close()) + }() require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) cfg1 := &AgentConfig{ @@ -328,16 +340,17 @@ func TestConnectivityLite(t *testing.T) { bAgent, err := NewAgent(cfg1) require.NoError(t, err) + defer func() { + require.NoError(t, bAgent.Close()) + }() require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) - aConn, bConn := connectWithVNet(aAgent, bAgent) + connectWithVNet(aAgent, bAgent) // Ensure pair selected // Note: this assumes ConnectionStateConnected is thrown after selecting the final pair <-aConnected <-bConnected - - closePipe(t, aConn, bConn) } func TestInboundValidity(t *testing.T) { @@ -372,6 +385,9 @@ func TestInboundValidity(t *testing.T) { if err != nil { t.Fatalf("Error constructing ice.Agent") } + defer func() { + require.NoError(t, a.Close()) + }() a.handleInbound(buildMsg(stun.ClassRequest, "invalid", a.localPwd), local, remote) if len(a.remoteCandidates) == 1 { @@ -382,8 +398,6 @@ func TestInboundValidity(t *testing.T) { if len(a.remoteCandidates) == 1 { t.Fatal("Binding with invalid MessageIntegrity was able to create prflx candidate") } - - require.NoError(t, a.Close()) }) t.Run("Invalid Binding success responses should be discarded", func(t *testing.T) { @@ -391,13 +405,14 @@ func TestInboundValidity(t *testing.T) { if err != nil { t.Fatalf("Error constructing ice.Agent") } + defer func() { + require.NoError(t, a.Close()) + }() a.handleInbound(buildMsg(stun.ClassSuccessResponse, a.localUfrag+":"+a.remoteUfrag, "Invalid"), local, remote) if len(a.remoteCandidates) == 1 { t.Fatal("Binding with invalid MessageIntegrity was able to create prflx candidate") } - - require.NoError(t, a.Close()) }) t.Run("Discard non-binding messages", func(t *testing.T) { @@ -405,13 +420,14 @@ func TestInboundValidity(t *testing.T) { if err != nil { t.Fatalf("Error constructing ice.Agent") } + defer func() { + require.NoError(t, a.Close()) + }() a.handleInbound(buildMsg(stun.ClassErrorResponse, a.localUfrag+":"+a.remoteUfrag, "Invalid"), local, remote) if len(a.remoteCandidates) == 1 { t.Fatal("non-binding message was able to create prflxRemote") } - - require.NoError(t, a.Close()) }) t.Run("Valid bind request", func(t *testing.T) { @@ -419,6 +435,9 @@ func TestInboundValidity(t *testing.T) { if err != nil { t.Fatalf("Error constructing ice.Agent") } + defer func() { + require.NoError(t, a.Close()) + }() err = a.loop.Run(a.loop, func(_ context.Context) { a.selector = &controllingSelector{agent: a, log: a.log} @@ -430,12 +449,14 @@ func TestInboundValidity(t *testing.T) { }) require.NoError(t, err) - require.NoError(t, a.Close()) }) t.Run("Valid bind without fingerprint", func(t *testing.T) { a, err := NewAgent(&AgentConfig{}) require.NoError(t, err) + defer func() { + require.NoError(t, a.Close()) + }() require.NoError(t, a.loop.Run(a.loop, func(_ context.Context) { a.selector = &controllingSelector{agent: a, log: a.log} @@ -451,8 +472,6 @@ func TestInboundValidity(t *testing.T) { t.Fatal("Binding with valid values (but no fingerprint) was unable to create prflx candidate") } })) - - require.NoError(t, a.Close()) }) t.Run("Success with invalid TransactionID", func(t *testing.T) { @@ -460,6 +479,9 @@ func TestInboundValidity(t *testing.T) { if err != nil { t.Fatalf("Error constructing ice.Agent") } + defer func() { + require.NoError(t, a.Close()) + }() hostConfig := CandidateHostConfig{ Network: "udp", @@ -486,8 +508,6 @@ func TestInboundValidity(t *testing.T) { if len(a.remoteCandidates) != 0 { t.Fatal("unknown remote was able to create a candidate") } - - require.NoError(t, a.Close()) }) } @@ -496,6 +516,9 @@ func TestInvalidAgentStarts(t *testing.T) { a, err := NewAgent(&AgentConfig{}) require.NoError(t, err) + defer func() { + require.NoError(t, a.Close()) + }() ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) @@ -516,8 +539,6 @@ func TestInvalidAgentStarts(t *testing.T) { if _, err = a.Dial(context.TODO(), "foo", "bar"); err != nil && !errors.Is(err, ErrMultipleStart) { t.Fatal(err) } - - require.NoError(t, a.Close()) } // Assert that Agent emits Connecting/Connected/Disconnected/Failed/Closed messages @@ -539,17 +560,34 @@ func TestConnectionStateCallback(t *testing.T) { InterfaceFilter: problematicNetworkInterfaces, } + isClosed := make(chan interface{}) + aAgent, err := NewAgent(cfg) require.NoError(t, err) + defer func() { + select { + case <-isClosed: + return + default: + } + require.NoError(t, aAgent.Close()) + }() bAgent, err := NewAgent(cfg) require.NoError(t, err) + defer func() { + select { + case <-isClosed: + return + default: + } + require.NoError(t, bAgent.Close()) + }() isChecking := make(chan interface{}) isConnected := make(chan interface{}) isDisconnected := make(chan interface{}) isFailed := make(chan interface{}) - isClosed := make(chan interface{}) err = aAgent.OnConnectionStateChange(func(c ConnectionState) { switch c { case ConnectionStateChecking: @@ -586,12 +624,14 @@ func TestInvalidGather(t *testing.T) { if err != nil { t.Fatalf("Error constructing ice.Agent") } + defer func() { + require.NoError(t, a.Close()) + }() err = a.GatherCandidates() if !errors.Is(err, ErrNoOnCandidateHandler) { t.Fatal("trickle GatherCandidates succeeded without OnCandidate") } - require.NoError(t, a.Close()) }) } @@ -605,6 +645,9 @@ func TestCandidatePairStats(t *testing.T) { if err != nil { t.Fatalf("Failed to create agent: %s", err) } + defer func() { + require.NoError(t, a.Close()) + }() hostConfig := &CandidateHostConfig{ Network: "udp", @@ -723,8 +766,6 @@ func TestCandidatePairStats(t *testing.T) { t.Fatalf("expected host-prflx pair to have state failed, it has state %s instead", prflxPairStat.State.String()) } - - require.NoError(t, a.Close()) } func TestLocalCandidateStats(t *testing.T) { @@ -737,6 +778,9 @@ func TestLocalCandidateStats(t *testing.T) { if err != nil { t.Fatalf("Failed to create agent: %s", err) } + defer func() { + require.NoError(t, a.Close()) + }() hostConfig := &CandidateHostConfig{ Network: "udp", @@ -803,8 +847,6 @@ func TestLocalCandidateStats(t *testing.T) { if srflxLocalStat.ID != srflxLocal.ID() { t.Fatal("missing srflx local stat") } - - require.NoError(t, a.Close()) } func TestRemoteCandidateStats(t *testing.T) { @@ -817,6 +859,9 @@ func TestRemoteCandidateStats(t *testing.T) { if err != nil { t.Fatalf("Failed to create agent: %s", err) } + defer func() { + require.NoError(t, a.Close()) + }() relayConfig := &CandidateRelayConfig{ Network: "udp", @@ -922,8 +967,6 @@ func TestRemoteCandidateStats(t *testing.T) { if hostRemoteStat.ID != hostRemote.ID() { t.Fatal("missing host remote stat") } - - require.NoError(t, a.Close()) } func TestInitExtIPMapping(t *testing.T) { @@ -935,6 +978,7 @@ func TestInitExtIPMapping(t *testing.T) { t.Fatalf("Failed to create agent: %v", err) } if a.extIPMapper != nil { + require.NoError(t, a.Close()) t.Fatal("a.extIPMapper should be nil by default") } require.NoError(t, a.Close()) @@ -945,9 +989,11 @@ func TestInitExtIPMapping(t *testing.T) { NAT1To1IPCandidateType: CandidateTypeHost, }) if err != nil { + require.NoError(t, a.Close()) t.Fatalf("Failed to create agent: %v", err) } if a.extIPMapper != nil { + require.NoError(t, a.Close()) t.Fatal("a.extIPMapper should be nil by default") } require.NoError(t, a.Close()) @@ -1002,6 +1048,9 @@ func TestBindingRequestTimeout(t *testing.T) { a, err := NewAgent(&AgentConfig{}) require.NoError(t, err) + defer func() { + require.NoError(t, a.Close()) + }() now := time.Now() a.pendingBindingRequests = append(a.pendingBindingRequests, bindingRequest{ @@ -1019,7 +1068,6 @@ func TestBindingRequestTimeout(t *testing.T) { a.invalidatePendingBindingRequests(now) require.Equal(t, expectedRemovalCount, len(a.pendingBindingRequests), "Binding invalidation due to timeout did not remove the correct number of binding requests") - require.NoError(t, a.Close()) } // TestAgentCredentials checks if local username fragments and passwords (if set) meet RFC standard @@ -1036,9 +1084,11 @@ func TestAgentCredentials(t *testing.T) { agent, err := NewAgent(&AgentConfig{LoggerFactory: log}) require.NoError(t, err) + defer func() { + require.NoError(t, agent.Close()) + }() require.GreaterOrEqual(t, len([]rune(agent.localUfrag))*8, 24) require.GreaterOrEqual(t, len([]rune(agent.localPwd))*8, 128) - require.NoError(t, agent.Close()) // Should honor RFC standards // Local values MUST be unguessable, with at least 128 bits of @@ -1071,9 +1121,15 @@ func TestConnectionStateFailedDeleteAllCandidates(t *testing.T) { aAgent, err := NewAgent(cfg) require.NoError(t, err) + defer func() { + require.NoError(t, aAgent.Close()) + }() bAgent, err := NewAgent(cfg) require.NoError(t, err) + defer func() { + require.NoError(t, bAgent.Close()) + }() isFailed := make(chan interface{}) require.NoError(t, aAgent.OnConnectionStateChange(func(c ConnectionState) { @@ -1092,9 +1148,6 @@ func TestConnectionStateFailedDeleteAllCandidates(t *testing.T) { close(done) })) <-done - - require.NoError(t, aAgent.Close()) - require.NoError(t, bAgent.Close()) } // Assert that the ICE Agent can go directly from Connecting -> Failed on both sides @@ -1114,9 +1167,15 @@ func TestConnectionStateConnectingToFailed(t *testing.T) { aAgent, err := NewAgent(cfg) require.NoError(t, err) + defer func() { + require.NoError(t, aAgent.Close()) + }() bAgent, err := NewAgent(cfg) require.NoError(t, err) + defer func() { + require.NoError(t, bAgent.Close()) + }() var isFailed sync.WaitGroup var isChecking sync.WaitGroup @@ -1151,9 +1210,6 @@ func TestConnectionStateConnectingToFailed(t *testing.T) { isChecking.Wait() isFailed.Wait() - - require.NoError(t, aAgent.Close()) - require.NoError(t, bAgent.Close()) } func TestAgentRestart(t *testing.T) { @@ -1168,6 +1224,7 @@ func TestAgentRestart(t *testing.T) { DisconnectedTimeout: &oneSecond, FailedTimeout: &oneSecond, }) + defer closePipe(t, connA, connB) ctx, cancel := context.WithCancel(context.Background()) require.NoError(t, connB.agent.OnConnectionStateChange(func(c ConnectionState) { @@ -1180,8 +1237,6 @@ func TestAgentRestart(t *testing.T) { require.NoError(t, connA.agent.Restart("", "")) <-ctx.Done() - require.NoError(t, connA.agent.Close()) - require.NoError(t, connB.agent.Close()) }) t.Run("Restart When Closed", func(t *testing.T) { @@ -1197,6 +1252,7 @@ func TestAgentRestart(t *testing.T) { DisconnectedTimeout: &oneSecond, FailedTimeout: &oneSecond, }) + defer closePipe(t, connA, connB) ctx, cancel := context.WithCancel(context.Background()) require.NoError(t, connB.agent.OnConnectionStateChange(func(c ConnectionState) { @@ -1207,8 +1263,6 @@ func TestAgentRestart(t *testing.T) { require.NoError(t, connA.agent.Restart("", "")) <-ctx.Done() - require.NoError(t, connA.agent.Close()) - require.NoError(t, connB.agent.Close()) }) t.Run("Restart Both Sides", func(t *testing.T) { @@ -1228,6 +1282,7 @@ func TestAgentRestart(t *testing.T) { DisconnectedTimeout: &oneSecond, FailedTimeout: &oneSecond, }) + defer closePipe(t, connA, connB) connAFirstCandidates := generateCandidateAddressStrings(connA.agent.GetLocalCandidates()) connBFirstCandidates := generateCandidateAddressStrings(connB.agent.GetLocalCandidates()) @@ -1259,9 +1314,6 @@ func TestAgentRestart(t *testing.T) { // Assert that we have new candidates each time require.NotEqual(t, connAFirstCandidates, generateCandidateAddressStrings(connA.agent.GetLocalCandidates())) require.NotEqual(t, connBFirstCandidates, generateCandidateAddressStrings(connB.agent.GetLocalCandidates())) - - require.NoError(t, connA.agent.Close()) - require.NoError(t, connB.agent.Close()) }) } @@ -1271,6 +1323,9 @@ func TestGetRemoteCredentials(t *testing.T) { if err != nil { t.Fatalf("Error constructing ice.Agent: %v", err) } + defer func() { + require.NoError(t, a.Close()) + }() a.remoteUfrag = "remoteUfrag" a.remotePwd = "remotePwd" @@ -1280,8 +1335,6 @@ func TestGetRemoteCredentials(t *testing.T) { require.Equal(t, actualUfrag, a.remoteUfrag) require.Equal(t, actualPwd, a.remotePwd) - - require.NoError(t, a.Close()) } func TestGetRemoteCandidates(t *testing.T) { @@ -1291,6 +1344,9 @@ func TestGetRemoteCandidates(t *testing.T) { if err != nil { t.Fatalf("Error constructing ice.Agent: %v", err) } + defer func() { + require.NoError(t, a.Close()) + }() expectedCandidates := []Candidate{} @@ -1313,8 +1369,6 @@ func TestGetRemoteCandidates(t *testing.T) { actualCandidates, err := a.GetRemoteCandidates() require.NoError(t, err) require.ElementsMatch(t, expectedCandidates, actualCandidates) - - require.NoError(t, a.Close()) } func TestGetLocalCandidates(t *testing.T) { @@ -1324,6 +1378,9 @@ func TestGetLocalCandidates(t *testing.T) { if err != nil { t.Fatalf("Error constructing ice.Agent: %v", err) } + defer func() { + require.NoError(t, a.Close()) + }() dummyConn := &net.UDPConn{} expectedCandidates := []Candidate{} @@ -1348,8 +1405,6 @@ func TestGetLocalCandidates(t *testing.T) { actualCandidates, err := a.GetLocalCandidates() require.NoError(t, err) require.ElementsMatch(t, expectedCandidates, actualCandidates) - - require.NoError(t, a.Close()) } func TestCloseInConnectionStateCallback(t *testing.T) { @@ -1373,9 +1428,19 @@ func TestCloseInConnectionStateCallback(t *testing.T) { aAgent, err := NewAgent(cfg) require.NoError(t, err) + var aAgentClosed bool + defer func() { + if aAgentClosed { + return + } + require.NoError(t, aAgent.Close()) + }() bAgent, err := NewAgent(cfg) require.NoError(t, err) + defer func() { + require.NoError(t, bAgent.Close()) + }() isClosed := make(chan interface{}) isConnected := make(chan interface{}) @@ -1384,6 +1449,7 @@ func TestCloseInConnectionStateCallback(t *testing.T) { case ConnectionStateConnected: <-isConnected require.NoError(t, aAgent.Close()) + aAgentClosed = true case ConnectionStateClosed: close(isClosed) default: @@ -1395,7 +1461,6 @@ func TestCloseInConnectionStateCallback(t *testing.T) { close(isConnected) <-isClosed - require.NoError(t, bAgent.Close()) } func TestRunTaskInConnectionStateCallback(t *testing.T) { @@ -1418,8 +1483,14 @@ func TestRunTaskInConnectionStateCallback(t *testing.T) { aAgent, err := NewAgent(cfg) check(err) + defer func() { + require.NoError(t, aAgent.Close()) + }() bAgent, err := NewAgent(cfg) check(err) + defer func() { + require.NoError(t, bAgent.Close()) + }() isComplete := make(chan interface{}) err = aAgent.OnConnectionStateChange(func(c ConnectionState) { @@ -1435,8 +1506,6 @@ func TestRunTaskInConnectionStateCallback(t *testing.T) { connect(aAgent, bAgent) <-isComplete - require.NoError(t, aAgent.Close()) - require.NoError(t, bAgent.Close()) } func TestRunTaskInSelectedCandidatePairChangeCallback(t *testing.T) { @@ -1459,8 +1528,14 @@ func TestRunTaskInSelectedCandidatePairChangeCallback(t *testing.T) { aAgent, err := NewAgent(cfg) check(err) + defer func() { + require.NoError(t, aAgent.Close()) + }() bAgent, err := NewAgent(cfg) check(err) + defer func() { + require.NoError(t, bAgent.Close()) + }() isComplete := make(chan interface{}) isTested := make(chan interface{}) @@ -1485,8 +1560,6 @@ func TestRunTaskInSelectedCandidatePairChangeCallback(t *testing.T) { <-isComplete <-isTested - require.NoError(t, aAgent.Close()) - require.NoError(t, bAgent.Close()) } // Assert that a Lite agent goes to disconnected and failed @@ -1502,6 +1575,13 @@ func TestLiteLifecycle(t *testing.T) { MulticastDNSMode: MulticastDNSModeDisabled, }) require.NoError(t, err) + var aClosed bool + defer func() { + if aClosed { + return + } + require.NoError(t, aAgent.Close()) + }() require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) disconnectedDuration := time.Second @@ -1519,6 +1599,13 @@ func TestLiteLifecycle(t *testing.T) { CheckInterval: &CheckInterval, }) require.NoError(t, err) + var bClosed bool + defer func() { + if bClosed { + return + } + require.NoError(t, bAgent.Close()) + }() bConnected := make(chan interface{}) bDisconnected := make(chan interface{}) @@ -1541,10 +1628,12 @@ func TestLiteLifecycle(t *testing.T) { <-aConnected <-bConnected require.NoError(t, aAgent.Close()) + aClosed = true <-bDisconnected <-bFailed require.NoError(t, bAgent.Close()) + bClosed = true } func TestNilCandidate(t *testing.T) { @@ -1558,9 +1647,11 @@ func TestNilCandidate(t *testing.T) { func TestNilCandidatePair(t *testing.T) { a, err := NewAgent(&AgentConfig{}) require.NoError(t, err) + defer func() { + require.NoError(t, a.Close()) + }() a.setSelectedPair(nil) - require.NoError(t, a.Close()) } func TestGetSelectedCandidatePair(t *testing.T) { @@ -1589,9 +1680,15 @@ func TestGetSelectedCandidatePair(t *testing.T) { aAgent, err := NewAgent(cfg) require.NoError(t, err) + defer func() { + require.NoError(t, aAgent.Close()) + }() bAgent, err := NewAgent(cfg) require.NoError(t, err) + defer func() { + require.NoError(t, bAgent.Close()) + }() aAgentPair, err := aAgent.GetSelectedCandidatePair() require.NoError(t, err) @@ -1615,8 +1712,6 @@ func TestGetSelectedCandidatePair(t *testing.T) { require.True(t, bAgentPair.Remote.Equal(aAgentPair.Local)) require.NoError(t, wan.Stop()) - require.NoError(t, aAgent.Close()) - require.NoError(t, bAgent.Close()) } func TestAcceptAggressiveNomination(t *testing.T) { @@ -1662,6 +1757,9 @@ func TestAcceptAggressiveNomination(t *testing.T) { var aAgent, bAgent *Agent aAgent, err = NewAgent(cfg0) require.NoError(t, err) + defer func() { + require.NoError(t, aAgent.Close()) + }() require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) cfg1 := &AgentConfig{ @@ -1674,9 +1772,12 @@ func TestAcceptAggressiveNomination(t *testing.T) { bAgent, err = NewAgent(cfg1) require.NoError(t, err) + defer func() { + require.NoError(t, bAgent.Close()) + }() require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) - aConn, bConn := connect(aAgent, bAgent) + connect(aAgent, bAgent) // Ensure pair selected // Note: this assumes ConnectionStateConnected is thrown after selecting the final pair @@ -1735,7 +1836,6 @@ func TestAcceptAggressiveNomination(t *testing.T) { } require.NoError(t, wan.Stop()) - closePipe(t, aConn, bConn) } // Close can deadlock but GracefulClose must not @@ -1748,15 +1848,29 @@ func TestAgentGracefulCloseDeadlock(t *testing.T) { } aAgent, err := NewAgent(config) require.NoError(t, err) + var aAgentClosed bool + defer func() { + if aAgentClosed { + return + } + require.NoError(t, aAgent.Close()) + }() bAgent, err := NewAgent(config) require.NoError(t, err) + var bAgentClosed bool + defer func() { + if bAgentClosed { + return + } + require.NoError(t, bAgent.Close()) + }() var connected, closeNow, closed sync.WaitGroup connected.Add(2) closeNow.Add(1) closed.Add(2) - closeHdlr := func(agent *Agent) { + closeHdlr := func(agent *Agent, agentClosed *bool) { check(agent.OnConnectionStateChange(func(cs ConnectionState) { if cs == ConnectionStateConnected { connected.Done() @@ -1766,14 +1880,15 @@ func TestAgentGracefulCloseDeadlock(t *testing.T) { if err := agent.GracefulClose(); err != nil { require.NoError(t, err) } + *agentClosed = true closed.Done() }() } })) } - closeHdlr(aAgent) - closeHdlr(bAgent) + closeHdlr(aAgent, &aAgentClosed) + closeHdlr(bAgent, &bAgentClosed) t.Log("connecting agents") _, _ = connect(aAgent, bAgent) @@ -1784,8 +1899,4 @@ func TestAgentGracefulCloseDeadlock(t *testing.T) { t.Log("tell them to close themselves in the same callback and wait") closeNow.Done() closed.Wait() - - // already closed - require.Error(t, aAgent.Close()) - require.Error(t, bAgent.Close()) } diff --git a/agent_udpmux_test.go b/agent_udpmux_test.go index 70e7d25c..8f4efe3d 100644 --- a/agent_udpmux_test.go +++ b/agent_udpmux_test.go @@ -50,12 +50,26 @@ func TestMuxAgent(t *testing.T) { IncludeLoopback: addr.IP.IsLoopback(), }) require.NoError(t, err) + var muxedAClosed bool + defer func() { + if muxedAClosed { + return + } + require.NoError(t, muxedA.Close()) + }() a, err := NewAgent(&AgentConfig{ CandidateTypes: []CandidateType{CandidateTypeHost}, NetworkTypes: supportedNetworkTypes(), }) require.NoError(t, err) + var aClosed bool + defer func() { + if aClosed { + return + } + require.NoError(t, a.Close()) + }() conn, muxedConn := connect(a, muxedA) @@ -83,7 +97,9 @@ func TestMuxAgent(t *testing.T) { // Close it down require.NoError(t, conn.Close()) + aClosed = true require.NoError(t, muxedConn.Close()) + muxedAClosed = true require.NoError(t, udpMux.Close()) // Expect error when reading from closed mux diff --git a/candidate_relay_test.go b/candidate_relay_test.go index b3ebc263..f74c0d54 100644 --- a/candidate_relay_test.go +++ b/candidate_relay_test.go @@ -43,6 +43,9 @@ func TestRelayOnlyConnection(t *testing.T) { }, }) require.NoError(t, err) + defer func() { + require.NoError(t, server.Close()) + }() cfg := &AgentConfig{ NetworkTypes: supportedNetworkTypes(), @@ -61,12 +64,18 @@ func TestRelayOnlyConnection(t *testing.T) { aAgent, err := NewAgent(cfg) require.NoError(t, err) + defer func() { + require.NoError(t, aAgent.Close()) + }() aNotifier, aConnected := onConnected() require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) bAgent, err := NewAgent(cfg) require.NoError(t, err) + defer func() { + require.NoError(t, bAgent.Close()) + }() bNotifier, bConnected := onConnected() require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) @@ -74,8 +83,4 @@ func TestRelayOnlyConnection(t *testing.T) { connect(aAgent, bAgent) <-aConnected <-bConnected - - require.NoError(t, aAgent.Close()) - require.NoError(t, bAgent.Close()) - require.NoError(t, server.Close()) } diff --git a/candidate_server_reflexive_test.go b/candidate_server_reflexive_test.go index 037c058f..54531401 100644 --- a/candidate_server_reflexive_test.go +++ b/candidate_server_reflexive_test.go @@ -39,6 +39,9 @@ func TestServerReflexiveOnlyConnection(t *testing.T) { }, }) require.NoError(t, err) + defer func() { + require.NoError(t, server.Close()) + }() cfg := &AgentConfig{ NetworkTypes: []NetworkType{NetworkTypeUDP4}, @@ -54,12 +57,18 @@ func TestServerReflexiveOnlyConnection(t *testing.T) { aAgent, err := NewAgent(cfg) require.NoError(t, err) + defer func() { + require.NoError(t, aAgent.Close()) + }() aNotifier, aConnected := onConnected() require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) bAgent, err := NewAgent(cfg) require.NoError(t, err) + defer func() { + require.NoError(t, bAgent.Close()) + }() bNotifier, bConnected := onConnected() require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) @@ -67,8 +76,4 @@ func TestServerReflexiveOnlyConnection(t *testing.T) { connect(aAgent, bAgent) <-aConnected <-bConnected - - require.NoError(t, aAgent.Close()) - require.NoError(t, bAgent.Close()) - require.NoError(t, server.Close()) } diff --git a/connectivity_vnet_test.go b/connectivity_vnet_test.go index f85ac54a..0c20e8ce 100644 --- a/connectivity_vnet_test.go +++ b/connectivity_vnet_test.go @@ -493,6 +493,9 @@ func TestDisconnectedToConnected(t *testing.T) { CheckInterval: &keepaliveInterval, }) require.NoError(t, err) + defer func() { + require.NoError(t, controllingAgent.Close()) + }() controlledAgent, err := NewAgent(&AgentConfig{ NetworkTypes: supportedNetworkTypes(), @@ -503,6 +506,9 @@ func TestDisconnectedToConnected(t *testing.T) { CheckInterval: &keepaliveInterval, }) require.NoError(t, err) + defer func() { + require.NoError(t, controlledAgent.Close()) + }() controllingStateChanges := make(chan ConnectionState, 100) require.NoError(t, controllingAgent.OnConnectionStateChange(func(c ConnectionState) { @@ -538,8 +544,6 @@ func TestDisconnectedToConnected(t *testing.T) { blockUntilStateSeen(ConnectionStateConnected, controlledStateChanges) require.NoError(t, wan.Stop()) - require.NoError(t, controllingAgent.Close()) - require.NoError(t, controlledAgent.Close()) } // Agent.Write should use the best valid pair if a selected pair is not yet available @@ -593,6 +597,9 @@ func TestWriteUseValidPair(t *testing.T) { Net: net0, }) require.NoError(t, err) + defer func() { + require.NoError(t, controllingAgent.Close()) + }() controlledAgent, err := NewAgent(&AgentConfig{ NetworkTypes: supportedNetworkTypes(), @@ -600,6 +607,9 @@ func TestWriteUseValidPair(t *testing.T) { Net: net1, }) require.NoError(t, err) + defer func() { + require.NoError(t, controlledAgent.Close()) + }() gatherAndExchangeCandidates(controllingAgent, controlledAgent) @@ -630,6 +640,4 @@ func TestWriteUseValidPair(t *testing.T) { require.Equal(t, readBuf, testMessage) require.NoError(t, wan.Stop()) - require.NoError(t, controllingAgent.Close()) - require.NoError(t, controlledAgent.Close()) } diff --git a/gather_test.go b/gather_test.go index 1e4b8960..8d54e2be 100644 --- a/gather_test.go +++ b/gather_test.go @@ -33,6 +33,9 @@ import ( func TestListenUDP(t *testing.T) { a, err := NewAgent(&AgentConfig{}) require.NoError(t, err) + defer func() { + require.NoError(t, a.Close()) + }() _, localAddrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) require.NotEqual(t, len(localAddrs), 0, "localInterfaces found no interfaces, unable to test") @@ -84,8 +87,6 @@ func TestListenUDP(t *testing.T) { } _, err = listenUDPInPortRange(a.net, a.log, portMax, portMin, udp, &net.UDPAddr{IP: ip, Port: 0}) require.Equal(t, err, ErrPort, "listenUDP with port restriction [%d, %d], did not return ErrPort", portMin, portMax) - - require.NoError(t, a.Close()) } func TestGatherConcurrency(t *testing.T) { @@ -98,6 +99,9 @@ func TestGatherConcurrency(t *testing.T) { IncludeLoopback: true, }) require.NoError(t, err) + defer func() { + require.NoError(t, a.Close()) + }() candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background()) require.NoError(t, a.OnCandidate(func(Candidate) { @@ -110,8 +114,6 @@ func TestGatherConcurrency(t *testing.T) { } <-candidateGathered.Done() - - require.NoError(t, a.Close()) } func TestLoopbackCandidate(t *testing.T) { @@ -194,6 +196,9 @@ func TestLoopbackCandidate(t *testing.T) { t.Run(tcase.name, func(t *testing.T) { a, err := NewAgent(tc.agentConfig) require.NoError(t, err) + defer func() { + require.NoError(t, a.Close()) + }() candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background()) var loopback int32 @@ -212,7 +217,6 @@ func TestLoopbackCandidate(t *testing.T) { <-candidateGathered.Done() - require.NoError(t, a.Close()) require.Equal(t, tcase.loExpected, atomic.LoadInt32(&loopback) == 1) }) } @@ -243,6 +247,9 @@ func TestSTUNConcurrency(t *testing.T) { }, }) require.NoError(t, err) + defer func() { + require.NoError(t, server.Close()) + }() urls := []*stun.URI{} for i := 0; i <= 10; i++ { @@ -279,6 +286,9 @@ func TestSTUNConcurrency(t *testing.T) { ), }) require.NoError(t, err) + defer func() { + require.NoError(t, a.Close()) + }() candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background()) require.NoError(t, a.OnCandidate(func(c Candidate) { @@ -291,9 +301,6 @@ func TestSTUNConcurrency(t *testing.T) { require.NoError(t, a.GatherCandidates()) <-candidateGathered.Done() - - require.NoError(t, a.Close()) - require.NoError(t, server.Close()) } // Assert that TURN gathering is done concurrently @@ -326,6 +333,9 @@ func TestTURNConcurrency(t *testing.T) { ListenerConfigs: listenerConfigs, }) require.NoError(t, err) + defer func() { + require.NoError(t, server.Close()) + }() urls := []*stun.URI{} for i := 0; i <= 10; i++ { @@ -354,6 +364,9 @@ func TestTURNConcurrency(t *testing.T) { Urls: urls, }) require.NoError(t, err) + defer func() { + require.NoError(t, a.Close()) + }() candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background()) require.NoError(t, a.OnCandidate(func(c Candidate) { @@ -364,9 +377,6 @@ func TestTURNConcurrency(t *testing.T) { require.NoError(t, a.GatherCandidates()) <-candidateGathered.Done() - - require.NoError(t, a.Close()) - require.NoError(t, server.Close()) } t.Run("UDP Relay", func(t *testing.T) { @@ -433,6 +443,9 @@ func TestSTUNTURNConcurrency(t *testing.T) { }, }) require.NoError(t, err) + defer func() { + require.NoError(t, server.Close()) + }() urls := []*stun.URI{} for i := 0; i <= 10; i++ { @@ -457,6 +470,9 @@ func TestSTUNTURNConcurrency(t *testing.T) { CandidateTypes: []CandidateType{CandidateTypeServerReflexive, CandidateTypeRelay}, }) require.NoError(t, err) + defer func() { + require.NoError(t, a.Close()) + }() { gatherLim := test.TimeOut(time.Second * 3) // As TURN and STUN should be checked in parallel, this should complete before the default STUN timeout (5s) @@ -471,9 +487,6 @@ func TestSTUNTURNConcurrency(t *testing.T) { <-candidateGathered.Done() gatherLim.Stop() } - - require.NoError(t, a.Close()) - require.NoError(t, server.Close()) } // Assert that srflx candidates can be gathered from TURN servers @@ -502,6 +515,9 @@ func TestTURNSrflx(t *testing.T) { }, }) require.NoError(t, err) + defer func() { + require.NoError(t, server.Close()) + }() urls := []*stun.URI{{ Scheme: stun.SchemeTypeTURN, @@ -518,6 +534,9 @@ func TestTURNSrflx(t *testing.T) { CandidateTypes: []CandidateType{CandidateTypeServerReflexive, CandidateTypeRelay}, }) require.NoError(t, err) + defer func() { + require.NoError(t, a.Close()) + }() candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background()) require.NoError(t, a.OnCandidate(func(c Candidate) { @@ -529,21 +548,19 @@ func TestTURNSrflx(t *testing.T) { require.NoError(t, a.GatherCandidates()) <-candidateGathered.Done() - - require.NoError(t, a.Close()) - require.NoError(t, server.Close()) } func TestCloseConnLog(t *testing.T) { a, err := NewAgent(&AgentConfig{}) require.NoError(t, err) + defer func() { + require.NoError(t, a.Close()) + }() closeConnAndLog(nil, a.log, "normal nil") var nc *net.UDPConn closeConnAndLog(nc, a.log, "nil ptr") - - require.NoError(t, a.Close()) } type mockProxy struct { @@ -598,6 +615,9 @@ func TestTURNProxyDialer(t *testing.T) { ProxyDialer: proxyDialer, }) require.NoError(t, err) + defer func() { + require.NoError(t, a.Close()) + }() candidateGatherFinish, candidateGatherFinishFunc := context.WithCancel(context.Background()) require.NoError(t, a.OnCandidate(func(c Candidate) { @@ -609,8 +629,6 @@ func TestTURNProxyDialer(t *testing.T) { require.NoError(t, a.GatherCandidates()) <-candidateGatherFinish.Done() <-proxyWasDialed.Done() - - require.NoError(t, a.Close()) } // TestUDPMuxDefaultWithNAT1To1IPsUsage requires that candidates @@ -639,6 +657,9 @@ func TestUDPMuxDefaultWithNAT1To1IPsUsage(t *testing.T) { UDPMux: mux, }) require.NoError(t, err) + defer func() { + require.NoError(t, a.Close()) + }() gatherCandidateDone := make(chan struct{}) require.NoError(t, a.OnCandidate(func(c Candidate) { @@ -652,8 +673,6 @@ func TestUDPMuxDefaultWithNAT1To1IPsUsage(t *testing.T) { <-gatherCandidateDone require.NotEqual(t, 0, len(mux.connsIPv4)) - - require.NoError(t, a.Close()) } // Assert that candidates are given for each mux in a MultiUDPMux @@ -687,6 +706,9 @@ func TestMultiUDPMuxUsage(t *testing.T) { UDPMux: NewMultiUDPMuxDefault(udpMuxInstances...), }) require.NoError(t, err) + defer func() { + require.NoError(t, a.Close()) + }() candidateCh := make(chan Candidate) require.NoError(t, a.OnCandidate(func(c Candidate) { @@ -707,8 +729,6 @@ func TestMultiUDPMuxUsage(t *testing.T) { for _, port := range expectedPorts { require.True(t, portFound[port], "There should be a candidate for each UDP mux port") } - - require.NoError(t, a.Close()) } // Assert that candidates are given for each mux in a MultiTCPMux @@ -743,6 +763,9 @@ func TestMultiTCPMuxUsage(t *testing.T) { TCPMux: NewMultiTCPMuxDefault(tcpMuxInstances...), }) require.NoError(t, err) + defer func() { + require.NoError(t, a.Close()) + }() candidateCh := make(chan Candidate) require.NoError(t, a.OnCandidate(func(c Candidate) { @@ -765,8 +788,6 @@ func TestMultiTCPMuxUsage(t *testing.T) { for _, port := range expectedPorts { require.True(t, portFound[port], "There should be a candidate for each TCP mux port") } - - require.NoError(t, a.Close()) } // Assert that UniversalUDPMux is used while gathering when configured in the Agent @@ -802,6 +823,13 @@ func TestUniversalUDPMuxUsage(t *testing.T) { UDPMuxSrflx: udpMuxSrflx, }) require.NoError(t, err) + var aClosed bool + defer func() { + if aClosed { + return + } + require.NoError(t, a.Close()) + }() candidateGathered, candidateGatheredFunc := context.WithCancel(context.Background()) require.NoError(t, a.OnCandidate(func(c Candidate) { @@ -816,6 +844,8 @@ func TestUniversalUDPMuxUsage(t *testing.T) { <-candidateGathered.Done() require.NoError(t, a.Close()) + aClosed = true + // Twice because of 2 STUN servers configured require.Equal(t, numSTUNS, udpMuxSrflx.getXORMappedAddrUsedTimes, "expected times that GetXORMappedAddr should be called") // One for Restart() when agent has been initialized and one time when Close() the agent diff --git a/gather_vnet_test.go b/gather_vnet_test.go index 9fda3fdc..15f9e3df 100644 --- a/gather_vnet_test.go +++ b/gather_vnet_test.go @@ -33,14 +33,15 @@ func TestVNetGather(t *testing.T) { Net: n, }) require.NoError(t, err) + defer func() { + require.NoError(t, a.Close()) + }() _, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) if len(localIPs) > 0 { t.Fatal("should return no local IP") } require.NoError(t, err) - - require.NoError(t, a.Close()) }) t.Run("Gather a dynamic IP address", func(t *testing.T) { @@ -72,6 +73,9 @@ func TestVNetGather(t *testing.T) { Net: nw, }) require.NoError(t, err) + defer func() { + require.NoError(t, a.Close()) + }() _, localAddrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) if len(localAddrs) == 0 { @@ -87,8 +91,6 @@ func TestVNetGather(t *testing.T) { t.Fatal("should be contained in the CIDR") } } - - require.NoError(t, a.Close()) }) t.Run("listenUDP", func(t *testing.T) { @@ -114,6 +116,9 @@ func TestVNetGather(t *testing.T) { if err != nil { t.Fatalf("Failed to create agent: %s", err) } + defer func() { + require.NoError(t, a.Close()) + }() _, localAddrs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) if len(localAddrs) == 0 { @@ -145,6 +150,9 @@ func TestVNetGather(t *testing.T) { } else if conn == nil { t.Fatalf("listenUDP error with no port restriction return a nil conn") } + defer func() { + require.NoError(t, conn.Close()) + }() _, port, err := net.SplitHostPort(conn.LocalAddr().String()) @@ -152,9 +160,6 @@ func TestVNetGather(t *testing.T) { if port != "5000" { t.Fatalf("listenUDP with port restriction of 5000 listened on incorrect port (%s)", port) } - - require.NoError(t, conn.Close()) - require.NoError(t, a.Close()) }) } @@ -209,7 +214,9 @@ func TestVNetGatherWithNAT1To1(t *testing.T) { Net: nw, }) require.NoError(t, err, "should succeed") - defer a.Close() //nolint:errcheck + defer func() { + require.NoError(t, a.Close()) + }() done := make(chan struct{}) err = a.OnCandidate(func(c Candidate) { @@ -309,7 +316,9 @@ func TestVNetGatherWithNAT1To1(t *testing.T) { Net: nw, }) require.NoError(t, err, "should succeed") - defer a.Close() //nolint:errcheck + defer func() { + require.NoError(t, a.Close()) + }() done := make(chan struct{}) err = a.OnCandidate(func(c Candidate) { @@ -384,6 +393,9 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) { }, }) require.NoError(t, err) + defer func() { + require.NoError(t, a.Close()) + }() _, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) require.NoError(t, err) @@ -391,8 +403,6 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) { if len(localIPs) != 0 { t.Fatal("InterfaceFilter should have excluded everything") } - - require.NoError(t, a.Close()) }) t.Run("IPFilter should exclude the IP", func(t *testing.T) { @@ -404,6 +414,9 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) { }, }) require.NoError(t, err) + defer func() { + require.NoError(t, a.Close()) + }() _, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) require.NoError(t, err) @@ -411,8 +424,6 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) { if len(localIPs) != 0 { t.Fatal("IPFilter should have excluded everything") } - - require.NoError(t, a.Close()) }) t.Run("InterfaceFilter should not exclude the interface", func(t *testing.T) { @@ -424,6 +435,9 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) { }, }) require.NoError(t, err) + defer func() { + require.NoError(t, a.Close()) + }() _, localIPs, err := localInterfaces(a.net, a.interfaceFilter, a.ipFilter, []NetworkType{NetworkTypeUDP4}, false) require.NoError(t, err) @@ -431,8 +445,6 @@ func TestVNetGatherWithInterfaceFilter(t *testing.T) { if len(localIPs) == 0 { t.Fatal("InterfaceFilter should not have excluded anything") } - - require.NoError(t, a.Close()) }) } @@ -469,8 +481,10 @@ func TestVNetGather_TURNConnectionLeak(t *testing.T) { } aAgent, err := NewAgent(cfg0) require.NoError(t, err, "should succeed") + defer func() { + // Assert relay conn leak on close. + require.NoError(t, aAgent.Close()) + }() aAgent.gatherCandidatesRelay(context.Background(), []*stun.URI{turnServerURL}) - // Assert relay conn leak on close. - require.NoError(t, aAgent.Close()) } diff --git a/go.mod b/go.mod index 731e7cf2..c122997e 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/pion/mdns/v2 v2.0.7 github.com/pion/randutil v0.1.0 github.com/pion/stun/v2 v2.0.0 - github.com/pion/transport/v3 v3.0.5 + github.com/pion/transport/v3 v3.0.6 github.com/pion/turn/v3 v3.0.3 github.com/stretchr/testify v1.9.0 golang.org/x/net v0.26.0 diff --git a/go.sum b/go.sum index 22a1dc70..abaf6503 100644 --- a/go.sum +++ b/go.sum @@ -23,8 +23,8 @@ github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1A github.com/pion/transport/v2 v2.2.4 h1:41JJK6DZQYSeVLxILA2+F4ZkKb4Xd/tFJZRFZQ9QAlo= github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0= github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0= -github.com/pion/transport/v3 v3.0.5 h1:ofVrcbPNqVPuKaTO5AMFnFuJ1ZX7ElYiWzC5PCf9YVQ= -github.com/pion/transport/v3 v3.0.5/go.mod h1:HvJr2N/JwNJAfipsRleqwFoR3t/pWyHeZUs89v3+t5s= +github.com/pion/transport/v3 v3.0.6 h1:k1mQU06bmmX143qSWgXFqSH1KUJceQvIUuVH/K5ELWw= +github.com/pion/transport/v3 v3.0.6/go.mod h1:HvJr2N/JwNJAfipsRleqwFoR3t/pWyHeZUs89v3+t5s= github.com/pion/turn/v3 v3.0.3 h1:1e3GVk8gHZLPBA5LqadWYV60lmaKUaHCkm9DX9CkGcE= github.com/pion/turn/v3 v3.0.3/go.mod h1:vw0Dz420q7VYAF3J4wJKzReLHIo2LGp4ev8nXQexYsc= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/mdns_test.go b/mdns_test.go index 617492b2..a65d836e 100644 --- a/mdns_test.go +++ b/mdns_test.go @@ -49,12 +49,18 @@ func TestMulticastDNSOnlyConnection(t *testing.T) { aAgent, err := NewAgent(cfg) require.NoError(t, err) + defer func() { + require.NoError(t, aAgent.Close()) + }() aNotifier, aConnected := onConnected() require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) bAgent, err := NewAgent(cfg) require.NoError(t, err) + defer func() { + require.NoError(t, bAgent.Close()) + }() bNotifier, bConnected := onConnected() require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) @@ -62,9 +68,6 @@ func TestMulticastDNSOnlyConnection(t *testing.T) { connect(aAgent, bAgent) <-aConnected <-bConnected - - require.NoError(t, aAgent.Close()) - require.NoError(t, bAgent.Close()) }) } } @@ -100,6 +103,9 @@ func TestMulticastDNSMixedConnection(t *testing.T) { InterfaceFilter: problematicNetworkInterfaces, }) require.NoError(t, err) + defer func() { + require.NoError(t, aAgent.Close()) + }() aNotifier, aConnected := onConnected() require.NoError(t, aAgent.OnConnectionStateChange(aNotifier)) @@ -111,6 +117,9 @@ func TestMulticastDNSMixedConnection(t *testing.T) { InterfaceFilter: problematicNetworkInterfaces, }) require.NoError(t, err) + defer func() { + require.NoError(t, bAgent.Close()) + }() bNotifier, bConnected := onConnected() require.NoError(t, bAgent.OnConnectionStateChange(bNotifier)) @@ -118,9 +127,6 @@ func TestMulticastDNSMixedConnection(t *testing.T) { connect(aAgent, bAgent) <-aConnected <-bConnected - - require.NoError(t, aAgent.Close()) - require.NoError(t, bAgent.Close()) }) } } @@ -165,6 +171,9 @@ func TestMulticastDNSStaticHostName(t *testing.T) { InterfaceFilter: problematicNetworkInterfaces, }) require.NoError(t, err) + defer func() { + require.NoError(t, agent.Close()) + }() correctHostName, resolveFunc := context.WithCancel(context.Background()) require.NoError(t, agent.OnCandidate(func(c Candidate) { @@ -175,7 +184,6 @@ func TestMulticastDNSStaticHostName(t *testing.T) { require.NoError(t, agent.GatherCandidates()) <-correctHostName.Done() - require.NoError(t, agent.Close()) }) } } diff --git a/transport_test.go b/transport_test.go index c7907a17..511d2537 100644 --- a/transport_test.go +++ b/transport_test.go @@ -324,6 +324,7 @@ func TestConnStats(t *testing.T) { if _, err := ca.Write(make([]byte, 10)); err != nil { t.Fatal("unexpected error trying to write") } + defer closePipe(t, ca, cb) var wg sync.WaitGroup wg.Add(1) @@ -344,16 +345,4 @@ func TestConnStats(t *testing.T) { if cb.BytesReceived() != 10 { t.Fatal("bytes received don't match") } - - err := ca.Close() - if err != nil { - // We should never get here. - panic(err) - } - - err = cb.Close() - if err != nil { - // We should never get here. - panic(err) - } } diff --git a/transport_vnet_test.go b/transport_vnet_test.go index 36a22f5b..1644742d 100644 --- a/transport_vnet_test.go +++ b/transport_vnet_test.go @@ -61,6 +61,7 @@ func TestRemoteLocalAddr(t *testing.T) { urls: []*stun.URI{stunServerURL}, }, ) + defer closePipe(t, ca, cb) aRAddr := ca.RemoteAddr() aLAddr := ca.LocalAddr() @@ -86,9 +87,5 @@ func TestRemoteLocalAddr(t *testing.T) { require.Equal(t, bRAddr.String(), fmt.Sprintf("%s:%d", vnetGlobalIPA, aLAddr.(*net.UDPAddr).Port), //nolint:forcetypeassert ) - - // Close - require.NoError(t, ca.Close()) - require.NoError(t, cb.Close()) }) } diff --git a/udp_mux_test.go b/udp_mux_test.go index d25a0705..1e218a81 100644 --- a/udp_mux_test.go +++ b/udp_mux_test.go @@ -250,6 +250,7 @@ func TestUDPMux_Agent_Restart(t *testing.T) { DisconnectedTimeout: &oneSecond, FailedTimeout: &oneSecond, }) + defer closePipe(t, connA, connB) aNotifier, aConnected := onConnected() require.NoError(t, connA.agent.OnConnectionStateChange(aNotifier)) @@ -277,7 +278,4 @@ func TestUDPMux_Agent_Restart(t *testing.T) { // Wait until both have gone back to connected <-aConnected <-bConnected - - require.NoError(t, connA.agent.Close()) - require.NoError(t, connB.agent.Close()) }