diff --git a/client/client.go b/client/client.go index d3e42af..5de0329 100644 --- a/client/client.go +++ b/client/client.go @@ -99,7 +99,7 @@ type Client struct { pc *webrtc.PeerConnection dc *webrtc.DataChannel iceCh chan webrtc.ICECandidateInit - receivers map[string]*webrtc.RTPReceiver + receivers map[string][]*webrtc.RTPReceiver voiceSender *webrtc.RTPSender screenTransceiver *webrtc.RTPTransceiver @@ -126,7 +126,7 @@ func New(cfg Config, opts ...Option) (*Client, error) { wsCloseCh: make(chan struct{}), wsClientSeqNo: 1, iceCh: make(chan webrtc.ICECandidateInit, iceChSize), - receivers: make(map[string]*webrtc.RTPReceiver), + receivers: make(map[string][]*webrtc.RTPReceiver), apiClient: apiClient, } diff --git a/client/helper_test.go b/client/helper_test.go index 9fb344e..f3d29e1 100644 --- a/client/helper_test.go +++ b/client/helper_test.go @@ -136,6 +136,29 @@ func (th *TestHelper) screenTrackWriter(track *webrtc.TrackLocalStaticRTP, close } } +func (th *TestHelper) transmitScreenTrack(c *Client) { + th.tb.Helper() + + track := th.newScreenTrack() + + sender, err := c.pc.AddTrack(track) + require.NoError(th.tb, err) + + closeCh := make(chan struct{}) + go func() { + rtcpBuf := make([]byte, receiveMTU) + for { + if _, _, rtcpErr := sender.Read(rtcpBuf); rtcpErr != nil { + log.Printf("failed to read rtcp: %s", rtcpErr.Error()) + close(closeCh) + return + } + } + }() + + go th.screenTrackWriter(track, closeCh) +} + func (th *TestHelper) newVoiceTrack() *webrtc.TrackLocalStaticSample { th.tb.Helper() diff --git a/client/rtc.go b/client/rtc.go index 936d6e1..155f466 100644 --- a/client/rtc.go +++ b/client/rtc.go @@ -238,7 +238,7 @@ func (c *Client) initRTCSession() error { } c.mut.Lock() - c.receivers[sessionID] = receiver + c.receivers[sessionID] = append(c.receivers[sessionID], receiver) c.mut.Unlock() // RTCP handler @@ -259,7 +259,10 @@ func (c *Client) initRTCSession() error { } }(track.RID()) - c.emit(RTCTrackEvent, track) + c.emit(RTCTrackEvent, map[string]any{ + "track": track, + "receiver": receiver, + }) }) pc.OnNegotiationNeeded(func() { diff --git a/client/rtc_test.go b/client/rtc_test.go index bbd97f5..3994b24 100644 --- a/client/rtc_test.go +++ b/client/rtc_test.go @@ -143,7 +143,9 @@ func TestRTCTrack(t *testing.T) { rtcTrackCh := make(chan struct{}) err = th.userClient.On(RTCTrackEvent, func(ctx any) error { - track, ok := ctx.(*webrtc.TrackRemote) + ctxMap, ok := ctx.(map[string]any) + require.True(t, ok) + track, ok := ctxMap["track"].(*webrtc.TrackRemote) require.True(t, ok) require.Equal(t, webrtc.PayloadType(0x6f), track.PayloadType()) require.Equal(t, "audio/opus", track.Codec().MimeType) @@ -239,7 +241,9 @@ func TestRTCTrack(t *testing.T) { rtcTrackCh := make(chan struct{}) err = th.userClient.On(RTCTrackEvent, func(ctx any) error { - track, ok := ctx.(*webrtc.TrackRemote) + ctxMap, ok := ctx.(map[string]any) + require.True(t, ok) + track, ok := ctxMap["track"].(*webrtc.TrackRemote) require.True(t, ok) require.Equal(t, webrtc.PayloadType(0x6f), track.PayloadType()) require.Equal(t, "audio/opus", track.Codec().MimeType) @@ -330,7 +334,9 @@ func TestRTCTrack(t *testing.T) { rtcTrackCh := make(chan struct{}) err = th.userClient.On(RTCTrackEvent, func(ctx any) error { - track, ok := ctx.(*webrtc.TrackRemote) + ctxMap, ok := ctx.(map[string]any) + require.True(t, ok) + track, ok := ctxMap["track"].(*webrtc.TrackRemote) require.True(t, ok) require.Equal(t, webrtc.PayloadType(0x6f), track.PayloadType()) require.Equal(t, "audio/opus", track.Codec().MimeType) @@ -397,4 +403,138 @@ func TestRTCTrack(t *testing.T) { require.Fail(t, "timed out waiting for close event") } }) + + t.Run("multiple remote tracks per session", func(t *testing.T) { + th := setupTestHelper(t, "calls0") + + rtcConnectChA := make(chan struct{}) + err := th.userClient.On(RTCConnectEvent, func(_ any) error { + close(rtcConnectChA) + return nil + }) + require.NoError(t, err) + + rtcConnectChB := make(chan struct{}) + err = th.adminClient.On(RTCConnectEvent, func(_ any) error { + close(rtcConnectChB) + return nil + }) + require.NoError(t, err) + + go func() { + err := th.adminClient.Connect() + require.NoError(t, err) + }() + + go func() { + err := th.userClient.Connect() + require.NoError(t, err) + }() + + select { + case <-rtcConnectChA: + case <-time.After(waitTimeout): + require.Fail(t, "timed out waiting for rtc connect event") + } + + select { + case <-rtcConnectChB: + case <-time.After(waitTimeout): + require.Fail(t, "timed out waiting for rtc connect event") + } + + voiceTrackCh := make(chan struct{}) + screenTrackCh := make(chan struct{}) + err = th.adminClient.On(RTCTrackEvent, func(ctx any) error { + ctxMap, ok := ctx.(map[string]any) + require.True(t, ok) + track, ok := ctxMap["track"].(*webrtc.TrackRemote) + require.True(t, ok) + + receiver, ok := ctxMap["receiver"].(*webrtc.RTPReceiver) + require.True(t, ok) + require.Equal(t, track, receiver.Track()) + + trackType, sessionID, err := ParseTrackID(track.ID()) + require.NoError(t, err) + + require.Equal(t, th.userClient.originalConnID, sessionID) + + if trackType == TrackTypeVoice { + require.Equal(t, webrtc.PayloadType(0x6f), track.PayloadType()) + require.Equal(t, "audio/opus", track.Codec().MimeType) + close(voiceTrackCh) + } else if trackType == TrackTypeScreen { + require.Equal(t, webrtc.PayloadType(0x60), track.PayloadType()) + require.Equal(t, "video/VP8", track.Codec().MimeType) + close(screenTrackCh) + } else { + require.Fail(t, "unexpected track type received") + } + + return nil + }) + require.NoError(t, err) + + go func() { + th.transmitAudioTrack(th.userClient) + }() + + go func() { + th.transmitScreenTrack(th.userClient) + }() + + select { + case <-voiceTrackCh: + case <-time.After(waitTimeout): + require.Fail(t, "timed out waiting for voice track") + } + + select { + case <-screenTrackCh: + case <-time.After(waitTimeout): + require.Fail(t, "timed out waiting for screen track") + } + + th.userClient.mut.RLock() + require.Len(t, th.adminClient.receivers, 1) + require.Len(t, th.adminClient.receivers[th.userClient.originalConnID], 2) + th.userClient.mut.RUnlock() + + closeChA := make(chan struct{}) + err = th.userClient.On(CloseEvent, func(_ any) error { + close(closeChA) + return nil + }) + require.NoError(t, err) + + closeChB := make(chan struct{}) + err = th.adminClient.On(CloseEvent, func(_ any) error { + close(closeChB) + return nil + }) + require.NoError(t, err) + + go func() { + err := th.userClient.Close() + require.NoError(t, err) + }() + + go func() { + err := th.adminClient.Close() + require.NoError(t, err) + }() + + select { + case <-closeChA: + case <-time.After(waitTimeout): + require.Fail(t, "timed out waiting for close event") + } + + select { + case <-closeChB: + case <-time.After(waitTimeout): + require.Fail(t, "timed out waiting for close event") + } + }) } diff --git a/client/websocket.go b/client/websocket.go index 19acf76..a357f69 100644 --- a/client/websocket.go +++ b/client/websocket.go @@ -190,7 +190,7 @@ func (c *Client) handleWSMsg(msg ws.Message) error { return fmt.Errorf("missing session_id from user_left event") } c.mut.Lock() - if rx := c.receivers[sessionID]; rx != nil { + for _, rx := range c.receivers[sessionID] { log.Printf("stopping receiver for disconnected session %q", sessionID) if err := rx.Stop(); err != nil { log.Printf("failed to stop receiver for session %q: %s", sessionID, err)