Skip to content

Commit

Permalink
Handle multiple remote tracks per session
Browse files Browse the repository at this point in the history
  • Loading branch information
streamer45 committed Apr 15, 2024
1 parent 648d760 commit f2dd785
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 8 deletions.
4 changes: 2 additions & 2 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ type Client struct {
pc *webrtc.PeerConnection
dc *webrtc.DataChannel
iceCh chan webrtc.ICECandidateInit
receivers map[string]*webrtc.RTPReceiver
receivers map[string][]*webrtc.RTPReceiver
voiceSender *webrtc.RTPSender
screenTransceiver *webrtc.RTPTransceiver

Expand All @@ -126,7 +126,7 @@ func New(cfg Config, opts ...Option) (*Client, error) {
wsCloseCh: make(chan struct{}),
wsClientSeqNo: 1,
iceCh: make(chan webrtc.ICECandidateInit, iceChSize),
receivers: make(map[string]*webrtc.RTPReceiver),
receivers: make(map[string][]*webrtc.RTPReceiver),
apiClient: apiClient,
}

Expand Down
23 changes: 23 additions & 0 deletions client/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,29 @@ func (th *TestHelper) screenTrackWriter(track *webrtc.TrackLocalStaticRTP, close
}
}

func (th *TestHelper) transmitScreenTrack(c *Client) {
th.tb.Helper()

track := th.newScreenTrack()

sender, err := c.pc.AddTrack(track)
require.NoError(th.tb, err)

closeCh := make(chan struct{})
go func() {
rtcpBuf := make([]byte, receiveMTU)
for {
if _, _, rtcpErr := sender.Read(rtcpBuf); rtcpErr != nil {
log.Printf("failed to read rtcp: %s", rtcpErr.Error())
close(closeCh)
return
}
}
}()

go th.screenTrackWriter(track, closeCh)
}

func (th *TestHelper) newVoiceTrack() *webrtc.TrackLocalStaticSample {
th.tb.Helper()

Expand Down
7 changes: 5 additions & 2 deletions client/rtc.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ func (c *Client) initRTCSession() error {
}

c.mut.Lock()
c.receivers[sessionID] = receiver
c.receivers[sessionID] = append(c.receivers[sessionID], receiver)
c.mut.Unlock()

// RTCP handler
Expand All @@ -259,7 +259,10 @@ func (c *Client) initRTCSession() error {
}
}(track.RID())

c.emit(RTCTrackEvent, track)
c.emit(RTCTrackEvent, map[string]any{
"track": track,
"receiver": receiver,
})
})

pc.OnNegotiationNeeded(func() {
Expand Down
146 changes: 143 additions & 3 deletions client/rtc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@ func TestRTCTrack(t *testing.T) {

rtcTrackCh := make(chan struct{})
err = th.userClient.On(RTCTrackEvent, func(ctx any) error {
track, ok := ctx.(*webrtc.TrackRemote)
ctxMap, ok := ctx.(map[string]any)
require.True(t, ok)
track, ok := ctxMap["track"].(*webrtc.TrackRemote)
require.True(t, ok)
require.Equal(t, webrtc.PayloadType(0x6f), track.PayloadType())
require.Equal(t, "audio/opus", track.Codec().MimeType)
Expand Down Expand Up @@ -239,7 +241,9 @@ func TestRTCTrack(t *testing.T) {

rtcTrackCh := make(chan struct{})
err = th.userClient.On(RTCTrackEvent, func(ctx any) error {
track, ok := ctx.(*webrtc.TrackRemote)
ctxMap, ok := ctx.(map[string]any)
require.True(t, ok)
track, ok := ctxMap["track"].(*webrtc.TrackRemote)
require.True(t, ok)
require.Equal(t, webrtc.PayloadType(0x6f), track.PayloadType())
require.Equal(t, "audio/opus", track.Codec().MimeType)
Expand Down Expand Up @@ -330,7 +334,9 @@ func TestRTCTrack(t *testing.T) {

rtcTrackCh := make(chan struct{})
err = th.userClient.On(RTCTrackEvent, func(ctx any) error {
track, ok := ctx.(*webrtc.TrackRemote)
ctxMap, ok := ctx.(map[string]any)
require.True(t, ok)
track, ok := ctxMap["track"].(*webrtc.TrackRemote)
require.True(t, ok)
require.Equal(t, webrtc.PayloadType(0x6f), track.PayloadType())
require.Equal(t, "audio/opus", track.Codec().MimeType)
Expand Down Expand Up @@ -397,4 +403,138 @@ func TestRTCTrack(t *testing.T) {
require.Fail(t, "timed out waiting for close event")
}
})

t.Run("multiple remote tracks per session", func(t *testing.T) {
th := setupTestHelper(t, "calls0")

rtcConnectChA := make(chan struct{})
err := th.userClient.On(RTCConnectEvent, func(_ any) error {
close(rtcConnectChA)
return nil
})
require.NoError(t, err)

rtcConnectChB := make(chan struct{})
err = th.adminClient.On(RTCConnectEvent, func(_ any) error {
close(rtcConnectChB)
return nil
})
require.NoError(t, err)

go func() {
err := th.adminClient.Connect()
require.NoError(t, err)
}()

go func() {
err := th.userClient.Connect()
require.NoError(t, err)
}()

select {
case <-rtcConnectChA:
case <-time.After(waitTimeout):
require.Fail(t, "timed out waiting for rtc connect event")
}

select {
case <-rtcConnectChB:
case <-time.After(waitTimeout):
require.Fail(t, "timed out waiting for rtc connect event")
}

voiceTrackCh := make(chan struct{})
screenTrackCh := make(chan struct{})
err = th.adminClient.On(RTCTrackEvent, func(ctx any) error {
ctxMap, ok := ctx.(map[string]any)
require.True(t, ok)
track, ok := ctxMap["track"].(*webrtc.TrackRemote)
require.True(t, ok)

receiver, ok := ctxMap["receiver"].(*webrtc.RTPReceiver)
require.True(t, ok)
require.Equal(t, track, receiver.Track())

trackType, sessionID, err := ParseTrackID(track.ID())
require.NoError(t, err)

require.Equal(t, th.userClient.originalConnID, sessionID)

if trackType == TrackTypeVoice {
require.Equal(t, webrtc.PayloadType(0x6f), track.PayloadType())
require.Equal(t, "audio/opus", track.Codec().MimeType)
close(voiceTrackCh)
} else if trackType == TrackTypeScreen {
require.Equal(t, webrtc.PayloadType(0x60), track.PayloadType())
require.Equal(t, "video/VP8", track.Codec().MimeType)
close(screenTrackCh)
} else {
require.Fail(t, "unexpected track type received")
}

return nil
})
require.NoError(t, err)

go func() {
th.transmitAudioTrack(th.userClient)
}()

go func() {
th.transmitScreenTrack(th.userClient)
}()

select {
case <-voiceTrackCh:
case <-time.After(waitTimeout):
require.Fail(t, "timed out waiting for voice track")
}

select {
case <-screenTrackCh:
case <-time.After(waitTimeout):
require.Fail(t, "timed out waiting for screen track")
}

th.userClient.mut.RLock()
require.Len(t, th.adminClient.receivers, 1)
require.Len(t, th.adminClient.receivers[th.userClient.originalConnID], 2)
th.userClient.mut.RUnlock()

closeChA := make(chan struct{})
err = th.userClient.On(CloseEvent, func(_ any) error {
close(closeChA)
return nil
})
require.NoError(t, err)

closeChB := make(chan struct{})
err = th.adminClient.On(CloseEvent, func(_ any) error {
close(closeChB)
return nil
})
require.NoError(t, err)

go func() {
err := th.userClient.Close()
require.NoError(t, err)
}()

go func() {
err := th.adminClient.Close()
require.NoError(t, err)
}()

select {
case <-closeChA:
case <-time.After(waitTimeout):
require.Fail(t, "timed out waiting for close event")
}

select {
case <-closeChB:
case <-time.After(waitTimeout):
require.Fail(t, "timed out waiting for close event")
}
})
}
2 changes: 1 addition & 1 deletion client/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ func (c *Client) handleWSMsg(msg ws.Message) error {
return fmt.Errorf("missing session_id from user_left event")
}
c.mut.Lock()
if rx := c.receivers[sessionID]; rx != nil {
for _, rx := range c.receivers[sessionID] {
log.Printf("stopping receiver for disconnected session %q", sessionID)
if err := rx.Stop(); err != nil {
log.Printf("failed to stop receiver for session %q: %s", sessionID, err)
Expand Down

0 comments on commit f2dd785

Please sign in to comment.