Skip to content

Commit

Permalink
Calls AI
Browse files Browse the repository at this point in the history
  • Loading branch information
streamer45 committed May 3, 2024
1 parent eb3c3f0 commit 81a8199
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 17 deletions.
12 changes: 6 additions & 6 deletions client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,30 @@ const (
httpResponseBodyMaxSizeBytes = 1024 * 1024 // 1MB
)

func (c *Client) Unmute(track webrtc.TrackLocal) error {
func (c *Client) Unmute(track webrtc.TrackLocal) (*webrtc.RTPSender, error) {
if track == nil {
return fmt.Errorf("invalid nil track")
return nil, fmt.Errorf("invalid nil track")
}

c.mut.Lock()
defer c.mut.Unlock()

if c.pc == nil {
return fmt.Errorf("rtc client is not initialized")
return nil, fmt.Errorf("rtc client is not initialized")
}

sender := c.voiceSender

if sender == nil {
snd, err := c.pc.AddTrack(track)
if err != nil {
return fmt.Errorf("failed to add track: %w", err)
return nil, fmt.Errorf("failed to add track: %w", err)
}
c.voiceSender = snd
sender = snd
} else {
if err := sender.ReplaceTrack(track); err != nil {
return fmt.Errorf("failed to replace track: %w", err)
return nil, fmt.Errorf("failed to replace track: %w", err)
}
}

Expand All @@ -58,7 +58,7 @@ func (c *Client) Unmute(track webrtc.TrackLocal) error {
}
}()

return c.sendWS(wsEventUnmute, nil, false)
return sender, c.sendWS(wsEventUnmute, nil, false)
}

func (c *Client) Mute() error {
Expand Down
18 changes: 9 additions & 9 deletions client/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func TestAPIMuteUnmute(t *testing.T) {
require.NoError(t, err)

t.Run("not initialized", func(t *testing.T) {
err := th.userClient.Unmute(th.newVoiceTrack())
_, err := th.userClient.Unmute(th.newVoiceTrack())
require.EqualError(t, err, "rtc client is not initialized")
})

Expand Down Expand Up @@ -92,7 +92,7 @@ func TestAPIMuteUnmute(t *testing.T) {

// User unmutes, admin should receive the track
userVoiceTrack := th.newVoiceTrack()
err = th.userClient.Unmute(userVoiceTrack)
_, err = th.userClient.Unmute(userVoiceTrack)
require.NoError(t, err)
go th.voiceTrackWriter(userVoiceTrack, userCloseCh)

Expand Down Expand Up @@ -144,7 +144,7 @@ func TestAPIMuteUnmute(t *testing.T) {

// Admin unmutes, user should receive the track
adminVoiceTrack := th.newVoiceTrack()
err = th.adminClient.Unmute(adminVoiceTrack)
_, err = th.adminClient.Unmute(adminVoiceTrack)
require.NoError(t, err)
go th.voiceTrackWriter(adminVoiceTrack, adminCloseCh)

Expand Down Expand Up @@ -257,7 +257,7 @@ func TestAPIMuteUnmuteNegotiation(t *testing.T) {
adminCloseCh := make(chan struct{})

userVoiceTrack := th.newVoiceTrack()
err = th.userClient.Unmute(userVoiceTrack)
_, err = th.userClient.Unmute(userVoiceTrack)
require.NoError(t, err)
go th.voiceTrackWriter(userVoiceTrack, userCloseCh)

Expand All @@ -283,13 +283,13 @@ func TestAPIMuteUnmuteNegotiation(t *testing.T) {
})

adminVoiceTrack := th.newVoiceTrack()
err = th.adminClient.Unmute(adminVoiceTrack)
_, err = th.adminClient.Unmute(adminVoiceTrack)
require.NoError(t, err)
go th.voiceTrackWriter(adminVoiceTrack, adminCloseCh)

time.Sleep(time.Second)

err = th.userClient.Unmute(userVoiceTrack)
_, err = th.userClient.Unmute(userVoiceTrack)
require.NoError(t, err)

time.Sleep(time.Second)
Expand Down Expand Up @@ -652,7 +652,7 @@ func TestAPIConcurrency(t *testing.T) {
for i := 0; i < 10; i++ {
go func() {
defer wg.Done()
err := th.userClient.Unmute(nil)
_, err := th.userClient.Unmute(nil)
require.EqualError(t, err, "invalid nil track")
}()
}
Expand All @@ -668,7 +668,7 @@ func TestAPIConcurrency(t *testing.T) {
for i := 0; i < 10; i++ {
go func() {
defer wg.Done()
err := th.userClient.Unmute(track)
_, err := th.userClient.Unmute(track)
require.NoError(t, err)
}()

Expand All @@ -688,7 +688,7 @@ func TestAPIConcurrency(t *testing.T) {
for i := 0; i < 10; i++ {
go func() {
defer wg.Done()
err := th.userClient.Unmute(th.newVoiceTrack())
_, err := th.userClient.Unmute(th.newVoiceTrack())
require.NoError(t, err)
}()

Expand Down
6 changes: 5 additions & 1 deletion client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ const (
WSCallLoweredHandEvent EventType = "WSCallLoweredHand"
WSCallScreenOnEvent EventType = "WSCallScreenOn"
WSCallScreenOffEvent EventType = "WSCallScreenOff"
WSSummonAIEvent EventType = "WSSummonAI"
WSGenericEvent EventType = "WSGeneric"
)

func (e EventType) IsValid() bool {
Expand All @@ -58,7 +60,9 @@ func (e EventType) IsValid() bool {
WSCallRaisedHandEvent, WSCallLoweredHandEvent,
WSCallScreenOnEvent, WSCallScreenOffEvent,
WSCallJobStateEvent,
WSJobStopEvent:
WSJobStopEvent,
WSSummonAIEvent,
WSGenericEvent:
return true
default:
return false
Expand Down
9 changes: 9 additions & 0 deletions client/rtc.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ var (
"urn:ietf:params:rtp-hdrext:sdes:rtp-stream-id",
"urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id",
}
rtpAudioExtensions = []string{
"urn:ietf:params:rtp-hdrext:ssrc-audio-level",
}
)

func (c *Client) handleWSEventSignal(evData map[string]any) error {
Expand Down Expand Up @@ -163,6 +166,12 @@ func (c *Client) initRTCSession() error {
}
}

for _, ext := range rtpAudioExtensions {
if err := m.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{URI: ext}, webrtc.RTPCodecTypeAudio); err != nil {
return fmt.Errorf("failed to register header extension: %w", err)
}
}

api := webrtc.NewAPI(webrtc.WithMediaEngine(&m), webrtc.WithInterceptorRegistry(&i))

pc, err := api.NewPeerConnection(cfg)
Expand Down
11 changes: 10 additions & 1 deletion client/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ const (
wsEventUserScreenOn = wsEvPrefix + "user_screen_on"
wsEventUserScreenOff = wsEvPrefix + "user_screen_off"
wsEventUserReacted = wsEvPrefix + "user_reacted"
wsEventSummonAI = wsEvPrefix + "summon_ai"
)

var (
Expand Down Expand Up @@ -195,8 +196,8 @@ func (c *Client) handleWSMsg(msg ws.Message) error {
if err := rx.Stop(); err != nil {
log.Printf("failed to stop receiver for session %q: %s", sessionID, err)
}
delete(c.receivers, sessionID)
}
delete(c.receivers, sessionID)
c.mut.Unlock()
case wsEventCallEnd:
channelID := ev.GetBroadcast().ChannelId
Expand Down Expand Up @@ -295,7 +296,15 @@ func (c *Client) handleWSMsg(msg ws.Message) error {
evType = WSCallScreenOffEvent
}
c.emit(evType, sessionID)
case wsEventSummonAI:
channelID, _ := ev.GetData()["channel_id"].(string)
if channelID != c.cfg.ChannelID {
return nil
}
authToken, _ := ev.GetData()["auth_token"].(string)
c.emit(WSSummonAIEvent, authToken)
default:
c.emit(WSGenericEvent, ev)
}
case ws.BinaryMessage:
default:
Expand Down

0 comments on commit 81a8199

Please sign in to comment.