From bf9c7ed4ed442e98c6bd57eed0a74c43b60560ae Mon Sep 17 00:00:00 2001 From: Joe Turki Date: Thu, 30 Jan 2025 00:17:06 -0600 Subject: [PATCH] Add methods to add and remove extensions Added `AddExtension` and `RemoveExtension` methods to `ICECandidate`, allowing extensions to be managed dynamically --- candidate.go | 8 ++++- candidate_base.go | 43 +++++++++++++++++++++++- candidate_test.go | 84 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 133 insertions(+), 2 deletions(-) diff --git a/candidate.go b/candidate.go index 4eb52061..89082f98 100644 --- a/candidate.go +++ b/candidate.go @@ -58,12 +58,18 @@ type Candidate interface { // https://datatracker.ietf.org/doc/html/rfc5245#section-15.1 //. Extensions() []CandidateExtension - // GetExtension returns the value of the extension attribute associated with the ICECandidate. // Extension attributes are defined in RFC 5245, Section 15.1: // https://datatracker.ietf.org/doc/html/rfc5245#section-15.1 //. GetExtension(key string) (value CandidateExtension, ok bool) + // AddExtension adds an extension attribute to the ICECandidate. + // If an extension with the same key already exists, it will be overwritten. + // Extension attributes are defined in RFC 5245, Section 15.1: + AddExtension(extension CandidateExtension) error + // RemoveExtension removes an extension attribute from the ICECandidate. + // Extension attributes are defined in RFC 5245, Section 15.1: + RemoveExtension(key string) (ok bool) String() string Type() CandidateType diff --git a/candidate_base.go b/candidate_base.go index 55a6ce86..9273a8bd 100644 --- a/candidate_base.go +++ b/candidate_base.go @@ -576,7 +576,7 @@ func (c *candidateBase) GetExtension(key string) (CandidateExtension, bool) { } // TCPType was manually set. - if key == "tcptype" && c.TCPType() != TCPTypeUnspecified { + if key == "tcptype" && c.TCPType() != TCPTypeUnspecified { //nolint:goconst extension.Value = c.TCPType().String() return extension, true @@ -585,6 +585,47 @@ func (c *candidateBase) GetExtension(key string) (CandidateExtension, bool) { return extension, false } +func (c *candidateBase) AddExtension(ext CandidateExtension) error { + if ext.Key == "tcptype" { + tcpType := NewTCPType(ext.Value) + if tcpType == TCPTypeUnspecified { + return fmt.Errorf("%w: invalid or unsupported TCPtype %s", errParseTCPType, ext.Value) + } + + c.tcpType = tcpType + } + + if ext.Key == "" { + return fmt.Errorf("%w: key is empty", errParseExtension) + } + + // per spec, Extensions aren't explicitly unique, we only set the first one. + // If the exteion is set multiple times. + for i := range c.extensions { + if c.extensions[i].Key == ext.Key { + c.extensions[i] = ext + + return nil + } + } + + c.extensions = append(c.extensions, ext) + + return nil +} + +func (c *candidateBase) RemoveExtension(key string) bool { + for i := range c.extensions { + if c.extensions[i].Key == key { + c.extensions = append(c.extensions[:i], c.extensions[i+1:]...) + + return true + } + } + + return false +} + // marshalExtensions returns the string representation of the candidate extensions. func (c *candidateBase) marshalExtensions() string { value := "" diff --git a/candidate_test.go b/candidate_test.go index caafefe7..a5a2af61 100644 --- a/candidate_test.go +++ b/candidate_test.go @@ -1271,3 +1271,87 @@ func TestBaseCandidateExtensionsEqual(t *testing.T) { }) } } + +func TestCandidateAddExtension(t *testing.T) { + t.Run("Add extension", func(t *testing.T) { + candidate, err := NewCandidateHost(&CandidateHostConfig{ + Network: NetworkTypeUDP4.String(), + Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", + Port: 53987, + Priority: 500, + Foundation: "750", + }) + if err != nil { + t.Error(err) + } + + require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"})) + require.NoError(t, candidate.AddExtension(CandidateExtension{"c", "d"})) + + extensions := candidate.Extensions() + require.Equal(t, []CandidateExtension{{"a", "b"}, {"c", "d"}}, extensions) + }) + + t.Run("Add extension with existing key", func(t *testing.T) { + candidate, err := NewCandidateHost(&CandidateHostConfig{ + Network: NetworkTypeUDP4.String(), + Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", + Port: 53987, + Priority: 500, + Foundation: "750", + }) + if err != nil { + t.Error(err) + } + + require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"})) + require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "d"})) + + extensions := candidate.Extensions() + require.Equal(t, []CandidateExtension{{"a", "d"}}, extensions) + }) +} + +func TestCandidateRemoveExtension(t *testing.T) { + t.Run("Remove extension", func(t *testing.T) { + candidate, err := NewCandidateHost(&CandidateHostConfig{ + Network: NetworkTypeUDP4.String(), + Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", + Port: 53987, + Priority: 500, + Foundation: "750", + }) + if err != nil { + t.Error(err) + } + + require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"})) + require.NoError(t, candidate.AddExtension(CandidateExtension{"c", "d"})) + + require.True(t, candidate.RemoveExtension("a")) + + extensions := candidate.Extensions() + require.Equal(t, []CandidateExtension{{"c", "d"}}, extensions) + }) + + t.Run("Remove extension that does not exist", func(t *testing.T) { + candidate, err := NewCandidateHost(&CandidateHostConfig{ + Network: NetworkTypeUDP4.String(), + Address: "fcd9:e3b8:12ce:9fc5:74a5:c6bb:d8b:e08a", + Port: 53987, + Priority: 500, + Foundation: "750", + }) + if err != nil { + t.Error(err) + } + + require.NoError(t, candidate.AddExtension(CandidateExtension{"a", "b"})) + require.NoError(t, candidate.AddExtension(CandidateExtension{"c", "d"})) + + require.False(t, candidate.RemoveExtension("b")) + + extensions := candidate.Extensions() + require.Equal(t, []CandidateExtension{{"a", "b"}, {"c", "d"}}, extensions) + }) +}