diff --git a/lt/client/client.go b/lt/client/client.go index 4decef8a1..270388925 100644 --- a/lt/client/client.go +++ b/lt/client/client.go @@ -2,35 +2,29 @@ package client import ( "bytes" - "compress/zlib" "context" "encoding/binary" - "encoding/json" "errors" "fmt" "io" "log" - "net/http" "os" - "strings" "sync" + "sync/atomic" "time" - "github.com/aws/aws-sdk-go/service/polly" - "github.com/pion/webrtc/v3" - - "github.com/mattermost/mattermost-plugin-calls/lt/ws" + "github.com/mattermost/rtcd/client" "github.com/mattermost/mattermost/server/public/model" - "github.com/pion/interceptor" - "github.com/pion/rtcp" "github.com/pion/rtp" "github.com/pion/rtp/codecs" + "github.com/pion/webrtc/v3" "github.com/pion/webrtc/v3/pkg/media" "github.com/pion/webrtc/v3/pkg/media/ivfreader" "github.com/pion/webrtc/v3/pkg/media/oggreader" + "github.com/aws/aws-sdk-go/service/polly" "gopkg.in/hraban/opus.v2" ) @@ -58,7 +52,6 @@ var ( "urn:ietf:params:rtp-hdrext:sdes:rtp-stream-id", "urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id", } - audioLevelExtensionURI = "urn:ietf:params:rtp-hdrext:ssrc-audio-level" ) const ( @@ -81,7 +74,6 @@ type Config struct { Speak bool ScreenSharing bool Recording bool - Simulcast bool Setup bool SpeechFile string PollySession *polly.Polly @@ -91,40 +83,20 @@ type Config struct { type User struct { userID string cfg Config - client *model.Client4 - pc *webrtc.PeerConnection - dc *webrtc.DataChannel - connectedCh chan struct{} - doneCh chan struct{} - iceCh chan webrtc.ICECandidateInit - initCh chan struct{} - isHost bool + apiClient *model.Client4 + callsClient *client.Client + callsConfig map[string]any + hostID atomic.Value pollySession *polly.Polly pollyVoiceID *string speechTextCh chan string doneSpeakingCh chan struct{} - - // WebSocket - wsCloseCh chan struct{} - wsSendCh chan wsMsg -} - -type wsMsg struct { - event string - data map[string]interface{} - binary bool } func NewUser(cfg Config) *User { return &User{ cfg: cfg, - connectedCh: make(chan struct{}), - doneCh: make(chan struct{}), - iceCh: make(chan webrtc.ICECandidateInit, 10), - wsCloseCh: make(chan struct{}), - wsSendCh: make(chan wsMsg, 256), - initCh: make(chan struct{}), speechTextCh: make(chan string, 8), doneSpeakingCh: make(chan struct{}), pollySession: cfg.PollySession, @@ -132,7 +104,7 @@ func NewUser(cfg Config) *User { } } -func (u *User) sendVideoFile(track *webrtc.TrackLocalStaticRTP, trx *webrtc.RTPTransceiver, simulcast bool) { +func (u *User) sendVideoFile(track *webrtc.TrackLocalStaticRTP, trx *webrtc.RTPTransceiver) { getExtensionID := func(URI string) uint8 { for _, ext := range trx.Sender().GetParameters().RTPParameters.HeaderExtensions { if ext.URI == URI { @@ -154,7 +126,7 @@ func (u *User) sendVideoFile(track *webrtc.TrackLocalStaticRTP, trx *webrtc.RTPT ) // Open a IVF file and start reading using our IVFReader - file, ivfErr := os.Open(fmt.Sprintf("./lt/samples/video_%s.ivf", track.RID())) + file, ivfErr := os.Open(fmt.Sprintf("./samples/video_%s.ivf", track.RID())) if ivfErr != nil { log.Fatalf(ivfErr.Error()) } @@ -165,9 +137,6 @@ func (u *User) sendVideoFile(track *webrtc.TrackLocalStaticRTP, trx *webrtc.RTPT log.Fatalf(ivfErr.Error()) } - // Wait for connection established - <-u.connectedCh - // Send our video file frame at a time. Pace our sending so we send it at the same speed it should be played back as. // This isn't required since the video is timestamped, but we will such much higher loss if we send all at once. // @@ -198,7 +167,7 @@ func (u *User) sendVideoFile(track *webrtc.TrackLocalStaticRTP, trx *webrtc.RTPT packets := packetizer.Packetize(frame, rtpVideoCodecVP8.ClockRate/header.TimebaseDenominator) for _, p := range packets { - if simulcast { + if u.callsConfig["EnableSimulcast"].(bool) { if err := p.Header.SetExtension(getExtensionID(rtpVideoExtensions[0]), []byte(trx.Mid())); err != nil { log.Printf("failed to set header extension: %s", err.Error()) } @@ -216,25 +185,7 @@ func (u *User) sendVideoFile(track *webrtc.TrackLocalStaticRTP, trx *webrtc.RTPT } } -func (u *User) startRecording() error { - log.Printf("%s: starting recording", u.cfg.Username) - ctx, cancel := context.WithTimeout(context.Background(), HTTPRequestTimeout) - defer cancel() - res, err := u.client.DoAPIRequest(ctx, http.MethodPost, - fmt.Sprintf("%s/plugins/com.mattermost.calls/calls/%s/recording/start", u.client.URL, u.cfg.ChannelID), "", "") - if err != nil { - return fmt.Errorf("request failed: %w", err) - } - defer res.Body.Close() - - if res.StatusCode == 200 { - return nil - } - - return fmt.Errorf("unexpected status code %d", res.StatusCode) -} - -func (u *User) transmitScreen(simulcast bool) { +func (u *User) transmitScreen() { streamID := model.NewId() trackHigh, err := webrtc.NewTrackLocalStaticRTP(rtpVideoCodecVP8, "video", streamID, webrtc.WithRTPStreamID(simulcastLevelHigh)) @@ -242,165 +193,90 @@ func (u *User) transmitScreen(simulcast bool) { log.Fatalf(err.Error()) } - trx, err := u.pc.AddTransceiverFromTrack(trackHigh, webrtc.RTPTransceiverInit{Direction: webrtc.RTPTransceiverDirectionSendonly}) - if err != nil { - log.Fatalf(err.Error()) - } - - info := map[string]string{ - "screenStreamID": trackHigh.StreamID(), - } - data, err := json.Marshal(&info) - if err != nil { - log.Fatalf(err.Error()) - } - - select { - case u.wsSendCh <- wsMsg{event: "custom_com.mattermost.calls_screen_on", data: map[string]interface{}{ - "data": string(data), - }}: - default: - log.Printf("failed to send ws message") - } - - rtpSender := trx.Sender() + tracks := []webrtc.TrackLocal{trackHigh} var trackLow *webrtc.TrackLocalStaticRTP - if simulcast { + if u.callsConfig["EnableSimulcast"].(bool) { trackLow, err = webrtc.NewTrackLocalStaticRTP(rtpVideoCodecVP8, "video", streamID, webrtc.WithRTPStreamID(simulcastLevelLow)) if err != nil { log.Fatalf(err.Error()) } - if err := rtpSender.AddEncoding(trackLow); err != nil { - log.Fatalf(err.Error()) - } + tracks = []webrtc.TrackLocal{trackLow, trackHigh} } - go func() { - rtcpBuf := make([]byte, receiveMTU) - for { - if _, _, rtcpErr := rtpSender.Read(rtcpBuf); rtcpErr != nil { - return - } - } - }() - - go func() { - defer func() { - select { - case u.wsSendCh <- wsMsg{event: "custom_com.mattermost.calls_screen_off", data: nil}: - default: - log.Printf("failed to send ws message") - } - }() + trx, err := u.callsClient.StartScreenShare(tracks) + if err != nil { + log.Fatalf(err.Error()) + } - if simulcast { - go u.sendVideoFile(trackLow, trx, simulcast) - } + if u.callsConfig["EnableSimulcast"].(bool) { + go u.sendVideoFile(trackLow, trx) + } - u.sendVideoFile(trackHigh, trx, simulcast) - }() + u.sendVideoFile(trackHigh, trx) } func (u *User) transmitAudio() { - track, err := webrtc.NewTrackLocalStaticSample(rtpAudioCodec, "audio", "voice"+model.NewId()) + track, err := webrtc.NewTrackLocalStaticSample(rtpAudioCodec, "audio", "voice_"+model.NewId()) if err != nil { log.Fatalf(err.Error()) } - sender, err := u.pc.AddTrack(track) - if err != nil { + + // Open a OGG file and start reading using our OGGReader + file, oggErr := os.Open(u.cfg.SpeechFile) + if oggErr != nil { + log.Fatalf(oggErr.Error()) + } + defer file.Close() + + // Open on oggfile in non-checksum mode. + ogg, _, oggErr := oggreader.NewWith(file) + if oggErr != nil { + log.Fatalf(oggErr.Error()) + } + + if err := u.callsClient.Unmute(track); err != nil { log.Fatalf(err.Error()) } - go func() { - rtcpBuf := make([]byte, receiveMTU) - for { - if _, _, rtcpErr := sender.Read(rtcpBuf); rtcpErr != nil { - log.Printf("%s: failed to read rtcp: %s", u.cfg.Username, rtcpErr.Error()) - return - } - } - }() + // Keep track of last granule, the difference is the amount of samples in the buffer + var lastGranule uint64 - go func() { - // Open a OGG file and start reading using our OGGReader - file, oggErr := os.Open(u.cfg.SpeechFile) - if oggErr != nil { - log.Fatalf(oggErr.Error()) + // It is important to use a time.Ticker instead of time.Sleep because + // * avoids accumulating skew, just calling time.Sleep didn't compensate for the time spent parsing the data + // * works around latency issues with Sleep (see https://github.com/golang/go/issues/44343) + oggPageDuration := time.Millisecond * 20 + ticker := time.NewTicker(oggPageDuration) + for ; true; <-ticker.C { + var oggErr error + var pageData []byte + var pageHeader *oggreader.OggPageHeader + pageData, pageHeader, oggErr = ogg.ParseNextPage() + if oggErr == io.EOF { + ogg.ResetReader(func(_ int64) io.Reader { + _, _ = file.Seek(0, 0) + return file + }) + pageData, pageHeader, oggErr = ogg.ParseNextPage() } - defer file.Close() - - // Open on oggfile in non-checksum mode. - ogg, _, oggErr := oggreader.NewWith(file) if oggErr != nil { log.Fatalf(oggErr.Error()) } - // Wait for connection established - <-u.connectedCh - - select { - case u.wsSendCh <- wsMsg{event: "custom_com.mattermost.calls_unmute", data: nil}: - default: - log.Printf("failed to send ws message") - } - defer func() { - select { - case u.wsSendCh <- wsMsg{event: "custom_com.mattermost.calls_mute", data: nil}: - default: - log.Printf("failed to send ws message") - } - }() - - // Keep track of last granule, the difference is the amount of samples in the buffer - var lastGranule uint64 - - // It is important to use a time.Ticker instead of time.Sleep because - // * avoids accumulating skew, just calling time.Sleep didn't compensate for the time spent parsing the data - // * works around latency issues with Sleep (see https://github.com/golang/go/issues/44343) - oggPageDuration := time.Millisecond * 20 - ticker := time.NewTicker(oggPageDuration) - for ; true; <-ticker.C { - var oggErr error - var pageData []byte - var pageHeader *oggreader.OggPageHeader - pageData, pageHeader, oggErr = ogg.ParseNextPage() - if oggErr == io.EOF { - ogg.ResetReader(func(_ int64) io.Reader { - _, _ = file.Seek(0, 0) - return file - }) - pageData, pageHeader, oggErr = ogg.ParseNextPage() - } - if oggErr != nil { - log.Fatalf(oggErr.Error()) - } - - // The amount of samples is the difference between the last and current timestamp - sampleCount := float64(pageHeader.GranulePosition - lastGranule) - lastGranule = pageHeader.GranulePosition - sampleDuration := time.Duration((sampleCount/48000)*1000) * time.Millisecond + // The amount of samples is the difference between the last and current timestamp + sampleCount := float64(pageHeader.GranulePosition - lastGranule) + lastGranule = pageHeader.GranulePosition + sampleDuration := time.Duration((sampleCount/48000)*1000) * time.Millisecond - if err := track.WriteSample(media.Sample{Data: pageData, Duration: sampleDuration}); err != nil { - log.Printf("failed to write audio sample: %s", err.Error()) - } + if err := track.WriteSample(media.Sample{Data: pageData, Duration: sampleDuration}); err != nil { + log.Printf("failed to write audio sample: %s", err.Error()) } - }() -} - -func (u *User) Unmute() { - select { - case u.wsSendCh <- wsMsg{event: "custom_com.mattermost.calls_unmute", data: nil}: - default: - log.Printf("failed to send ws message") } } func (u *User) Mute() { - select { - case u.wsSendCh <- wsMsg{event: "custom_com.mattermost.calls_mute", data: nil}: - default: - log.Printf("failed to send ws message") + if err := u.callsClient.Mute(); err != nil { + log.Printf("%s: failed to mute: %s", u.cfg.Username, err.Error()) } } @@ -409,470 +285,96 @@ func (u *User) transmitSpeech() { if err != nil { log.Fatalf(err.Error()) } - sender, err := u.pc.AddTrack(track) + + enc, err := opus.NewEncoder(24000, 1, opus.AppVoIP) if err != nil { + log.Fatalf("%s: failed to create opus encoder: %s", u.cfg.Username, err.Error()) + } + + if err := u.callsClient.Unmute(track); err != nil { log.Fatalf(err.Error()) } - go func() { - rtcpBuf := make([]byte, receiveMTU) - for { - if _, _, rtcpErr := sender.Read(rtcpBuf); rtcpErr != nil { - log.Printf("%s: failed to read rtcp: %s", u.cfg.Username, rtcpErr.Error()) + for text := range u.speechTextCh { + func() { + defer func() { + u.doneSpeakingCh <- struct{}{} + }() + log.Printf("%s: received text to speak: %q", u.cfg.Username, text) + + var rd io.Reader + var rate int + var err error + if u.pollySession != nil { + rd, rate, err = u.pollyToSpeech(text) + } + if err != nil { + log.Printf("%s: textToSpeech failed: %s", u.cfg.Username, err.Error()) return } - } - }() - go func() { - // Wait for connection established - <-u.connectedCh + log.Printf("%s: raw speech samples decoded (%d)", u.cfg.Username, rate) - enc, err := opus.NewEncoder(24000, 1, opus.AppVoIP) - if err != nil { - log.Fatalf("%s: failed to create opus encoder: %s", u.cfg.Username, err.Error()) - } + audioSamplesDataBuf := bytes.NewBuffer([]byte{}) + if _, err := audioSamplesDataBuf.ReadFrom(rd); err != nil { + log.Printf("%s: failed to read samples data: %s", u.cfg.Username, err.Error()) + return + } - for text := range u.speechTextCh { - func() { - defer func() { - u.doneSpeakingCh <- struct{}{} - }() - log.Printf("%s: received text to speak: %q", u.cfg.Username, text) - - var rd io.Reader - var rate int - var err error - if u.pollySession != nil { - rd, rate, err = u.pollyToSpeech(text) - } + log.Printf("read %d samples bytes", audioSamplesDataBuf.Len()) + + sampleDuration := time.Millisecond * 20 + ticker := time.NewTicker(sampleDuration) + audioSamplesData := make([]byte, 480*4) + audioSamples := make([]int16, 480) + opusData := make([]byte, 8192) + for ; true; <-ticker.C { + n, err := audioSamplesDataBuf.Read(audioSamplesData) if err != nil { - log.Printf("%s: textToSpeech failed: %s", u.cfg.Username, err.Error()) - return + if !errors.Is(err, io.EOF) { + log.Printf("%s: failed to read audio samples: %s", u.cfg.Username, err.Error()) + } + break } - log.Printf("%s: raw speech samples decoded (%d)", u.cfg.Username, rate) - - audioSamplesDataBuf := bytes.NewBuffer([]byte{}) - if _, err := audioSamplesDataBuf.ReadFrom(rd); err != nil { - log.Printf("%s: failed to read samples data: %s", u.cfg.Username, err.Error()) - return + // Convert []byte to []int16 + for i := 0; i < n; i += 4 { + audioSamples[i/4] = int16(binary.LittleEndian.Uint16(audioSamplesData[i : i+4])) } - log.Printf("read %d samples bytes", audioSamplesDataBuf.Len()) - - sampleDuration := time.Millisecond * 20 - ticker := time.NewTicker(sampleDuration) - audioSamplesData := make([]byte, 480*4) - audioSamples := make([]int16, 480) - opusData := make([]byte, 8192) - for ; true; <-ticker.C { - n, err := audioSamplesDataBuf.Read(audioSamplesData) - if err != nil { - if !errors.Is(err, io.EOF) { - log.Printf("%s: failed to read audio samples: %s", u.cfg.Username, err.Error()) - } - break - } - - // Convert []byte to []int16 - for i := 0; i < n; i += 4 { - audioSamples[i/4] = int16(binary.LittleEndian.Uint16(audioSamplesData[i : i+4])) - } - - n, err = enc.Encode(audioSamples, opusData) - if err != nil { - log.Printf("%s: failed to encode: %s", u.cfg.Username, err.Error()) - continue - } + n, err = enc.Encode(audioSamples, opusData) + if err != nil { + log.Printf("%s: failed to encode: %s", u.cfg.Username, err.Error()) + continue + } - if err := track.WriteSample(media.Sample{Data: opusData[:n], Duration: sampleDuration}); err != nil { - log.Printf("%s: failed to write audio sample: %s", u.cfg.Username, err.Error()) - } + if err := track.WriteSample(media.Sample{Data: opusData[:n], Duration: sampleDuration}); err != nil { + log.Printf("%s: failed to write audio sample: %s", u.cfg.Username, err.Error()) } - }() - } - }() + } + }() + } } func (u *User) Speak(text string) chan struct{} { - u.Unmute() u.speechTextCh <- text return u.doneSpeakingCh } -func (u *User) initRTC() error { - log.Printf("%s: setting up RTC connection", u.cfg.Username) - - peerConnConfig := webrtc.Configuration{ - ICEServers: []webrtc.ICEServer{}, - SDPSemantics: webrtc.SDPSemanticsUnifiedPlan, - } - - var m webrtc.MediaEngine - if err := m.RegisterDefaultCodecs(); err != nil { - return err - } - - i := interceptor.Registry{} - if err := webrtc.RegisterDefaultInterceptors(&m, &i); err != nil { - return err - } - - if err := m.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{ - URI: audioLevelExtensionURI, - }, webrtc.RTPCodecTypeAudio); err != nil { - return err - } - - if u.cfg.Simulcast { - for _, ext := range rtpVideoExtensions { - if err := m.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{URI: ext}, webrtc.RTPCodecTypeVideo); err != nil { - return err - } - } - } - - api := webrtc.NewAPI(webrtc.WithMediaEngine(&m), webrtc.WithInterceptorRegistry(&i)) - - pc, err := api.NewPeerConnection(peerConnConfig) - if err != nil { - return err - } - u.pc = pc - - gatherCh := make(chan struct{}, 1) - pc.OnICECandidate(func(c *webrtc.ICECandidate) { - if c == nil { - log.Printf("%s: end of candidates", u.cfg.Username) - select { - case gatherCh <- struct{}{}: - default: - } - return - } - - log.Printf("%s: ice: %v", u.cfg.Username, c) - - data, err := json.Marshal(c.ToJSON()) - if err != nil { - log.Fatalf(err.Error()) - } - - select { - case u.wsSendCh <- wsMsg{"custom_com.mattermost.calls_ice", map[string]interface{}{ - "data": string(data), - }, false}: - default: - log.Fatalf("failed to send ice ws message") - } - }) - - pc.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) { - if connectionState == webrtc.ICEConnectionStateConnected { - log.Printf("%s: rtc connected", u.cfg.Username) - close(u.connectedCh) - - if u.cfg.Recording && u.isHost { - if err := u.startRecording(); err != nil { - log.Printf("%s: failed to start recording: %s", u.cfg.Username, err) - } else { - log.Printf("%s: recording started successfully", u.cfg.Username) - } - } - } - - if connectionState == webrtc.ICEConnectionStateDisconnected || connectionState == webrtc.ICEConnectionStateFailed { - log.Printf("%s: ice disconnect", u.cfg.Username) - close(u.wsCloseCh) - } - }) - - pc.OnTrack(func(track *webrtc.TrackRemote, _ *webrtc.RTPReceiver) { - if track.Kind() == webrtc.RTPCodecTypeVideo { - rtcpSendErr := pc.WriteRTCP([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(track.SSRC())}}) - if rtcpSendErr != nil { - log.Printf("%s: rtcp send error: %s", u.cfg.Username, rtcpSendErr.Error()) - } - } - - codecName := strings.Split(track.Codec().RTPCodecCapability.MimeType, "/")[1] - log.Printf("%s: Track has started, of type %d: %s \n", u.cfg.Username, track.PayloadType(), codecName) - - buf := make([]byte, receiveMTU) - for { - _, _, readErr := track.Read(buf) - if readErr != nil { - log.Printf("%s: track read error: %s", u.cfg.Username, readErr.Error()) - return - } - } - }) - +func (u *User) onConnect() { if u.cfg.Unmuted { - u.transmitAudio() + go u.transmitAudio() } else if u.cfg.Speak { - u.transmitSpeech() + go u.transmitSpeech() } - if u.cfg.ScreenSharing { - u.transmitScreen(u.cfg.Simulcast) + go u.transmitScreen() } - dc, err := pc.CreateDataChannel("calls-dc", nil) - if err != nil { - return err - } - - u.dc = dc - - offer, err := pc.CreateOffer(nil) - if err != nil { - return err - } - - if err := pc.SetLocalDescription(offer); err != nil { - return err - } - - var sdpData bytes.Buffer - w := zlib.NewWriter(&sdpData) - if err := json.NewEncoder(w).Encode(offer); err != nil { - return err - } - w.Close() - - data := map[string]interface{}{ - "data": sdpData.Bytes(), - } - - select { - case u.wsSendCh <- wsMsg{"custom_com.mattermost.calls_sdp", data, true}: - default: - log.Fatalf("failed to send sdp ws message") - } - - <-gatherCh - - close(u.initCh) - - return nil -} - -func (u *User) handleSignal(ev *model.WebSocketEvent) { - evData := ev.GetData() - var data map[string]interface{} - if err := json.Unmarshal([]byte(evData["data"].(string)), &data); err != nil { - log.Fatalf(err.Error()) - } - - t, _ := data["type"].(string) - - if t == "candidate" { - log.Printf("%s: ice!", u.cfg.Username) - u.iceCh <- webrtc.ICECandidateInit{Candidate: data["candidate"].(map[string]interface{})["candidate"].(string)} - } else if t == "answer" { - log.Printf("%s: sdp answer!", u.cfg.Username) - if err := u.pc.SetRemoteDescription(webrtc.SessionDescription{ - Type: webrtc.SDPTypeAnswer, - SDP: data["sdp"].(string), - }); err != nil { - log.Fatalf("%s: SetRemoteDescription failed: %s", u.cfg.Username, err.Error()) - } - - go func() { - for ice := range u.iceCh { - if err := u.pc.AddICECandidate(ice); err != nil { - log.Printf("%s: %s", u.cfg.Username, err.Error()) - } - } - }() - } else if t == "offer" { - log.Printf("%s: sdp offer", u.cfg.Username) - - if u.pc.SignalingState() != webrtc.SignalingStateStable { - log.Printf("%s: signaling conflict on offer, queuing", u.cfg.Username) - go func() { - time.Sleep(100 * time.Millisecond) - log.Printf("%s: applying previously queued offer", u.cfg.Username) - u.handleSignal(ev) - }() - return - } - - if err := u.pc.SetRemoteDescription(webrtc.SessionDescription{ - Type: webrtc.SDPTypeOffer, - SDP: data["sdp"].(string), - }); err != nil { - log.Fatalf("%s: SetRemoteDescription failed: %s", u.cfg.Username, err.Error()) - } - - sdp, err := u.pc.CreateAnswer(nil) - if err != nil { - log.Printf("%s: %s", u.cfg.Username, err.Error()) - } - - if err := u.pc.SetLocalDescription(sdp); err != nil { - log.Printf("%s: SetLocalDescription failed: %s", u.cfg.Username, err.Error()) - } - - var sdpData bytes.Buffer - w := zlib.NewWriter(&sdpData) - if err := json.NewEncoder(w).Encode(sdp); err != nil { - log.Fatalf("%s: %s", u.cfg.Username, err.Error()) - } - w.Close() - - data := map[string]interface{}{ - "data": sdpData.Bytes(), - } - select { - case u.wsSendCh <- wsMsg{"custom_com.mattermost.calls_sdp", data, true}: - default: - log.Printf("failed to send ws message") - } - } -} - -func (u *User) wsListen(authToken string) { - defer close(u.iceCh) - - var wsConnID string - var originalConnID string - var wsServerSeq int64 - - connect := func() (*ws.Client, error) { - ws, err := ws.NewClient(&ws.ClientParams{ - WsURL: u.cfg.WsURL, - AuthToken: authToken, - ConnID: wsConnID, - ServerSequence: wsServerSeq, - }) - return ws, err - } - - ws, err := connect() - if err != nil { - log.Fatalf(err.Error()) - return - } - - defer func() { - err := ws.SendMessage("custom_com.mattermost.calls_leave", nil) - if err != nil { - log.Printf("%s: ws send error: %s", u.cfg.Username, err.Error()) - } - ws.Close() - }() - - for { - select { - case ev, ok := <-ws.EventChannel: - if !ok { - log.Printf("ws disconnected") - for { - time.Sleep(time.Second) - log.Printf("attempting ws reconnection") - ws, err = connect() - if err != nil { - log.Print(err.Error()) - continue - } - - data := map[string]interface{}{ - "channelID": u.cfg.ChannelID, - "originalConnID": originalConnID, - "prevConnID": wsConnID, - } - if err := ws.SendMessage("custom_com.mattermost.calls_reconnect", data); err != nil { - log.Printf("%s: ws send error: %s", u.cfg.Username, err.Error()) - continue - } - - break - } - continue - } - if ev.EventType() == "hello" { - if connID, ok := ev.GetData()["connection_id"].(string); ok { - if wsConnID != connID { - log.Printf("new connection id from server") - wsServerSeq = 0 - } - wsConnID = connID - if originalConnID == "" { - log.Printf("setting original conn id") - originalConnID = connID - - log.Printf("%s: joining call", u.cfg.Username) - data := map[string]interface{}{ - "channelID": u.cfg.ChannelID, - } - if err := ws.SendMessage("custom_com.mattermost.calls_join", data); err != nil { - log.Fatalf(err.Error()) - } - } - } - } - - if ev.GetSequence() != wsServerSeq { - log.Printf("missed websocket event") - return - } - - wsServerSeq = ev.GetSequence() + 1 - - if ev.EventType() == "custom_com.mattermost.calls_call_start" { - channelID, _ := ev.GetData()["channelID"].(string) - hostID, _ := ev.GetData()["host_id"].(string) - if channelID == u.cfg.ChannelID && hostID == u.userID { - log.Printf("%s: I am call host", u.cfg.Username) - u.isHost = true - } - continue - } - - if connID, ok := ev.GetData()["connID"].(string); !ok || (connID != wsConnID && connID != originalConnID) { - continue - } - - switch ev.EventType() { - case "custom_com.mattermost.calls_join": - log.Printf("%s: joined call", u.cfg.Username) - if err := u.initRTC(); err != nil { - log.Fatalf(err.Error()) - } - defer u.pc.Close() - case "custom_com.mattermost.calls_signal": - log.Printf("%s: received signal", u.cfg.Username) - select { - case <-u.initCh: - u.handleSignal(ev) - case <-time.After(2 * time.Second): - log.Printf("%s: timed out waiting for init", u.cfg.Username) - } - case "custom_com.mattermost.calls_call_end": - log.Printf("%s: call end event, exiting", u.cfg.Username) - return - default: - } - case msg, ok := <-u.wsSendCh: - if !ok { - return - } - if msg.binary { - if err := ws.SendBinaryMessage(msg.event, msg.data); err != nil { - log.Fatalf(err.Error()) - } - } else { - if err := ws.SendMessage(msg.event, msg.data); err != nil { - log.Fatalf(err.Error()) - } - } - case <-u.wsCloseCh: - return - case <-u.doneCh: - return + if u.cfg.Recording && u.hostID.Load() == u.userID { + log.Printf("%s: I am host, starting recording", u.cfg.Username) + if err := u.callsClient.StartRecording(); err != nil { + log.Fatalf("failed to start recording: %s", err.Error()) } } } @@ -881,12 +383,12 @@ func (u *User) Connect(stopCh chan struct{}) error { log.Printf("%s: connecting user", u.cfg.Username) var user *model.User - client := model.NewAPIv4Client(u.cfg.SiteURL) - u.client = client + apiClient := model.NewAPIv4Client(u.cfg.SiteURL) + u.apiClient = apiClient // login (or create) user ctx, cancel := context.WithTimeout(context.Background(), HTTPRequestTimeout) defer cancel() - user, _, err := client.Login(ctx, u.cfg.Username, u.cfg.Password) + user, _, err := apiClient.Login(ctx, u.cfg.Username, u.cfg.Password) appErr, ok := err.(*model.AppError) if err != nil && !ok { return err @@ -902,7 +404,7 @@ func (u *User) Connect(stopCh chan struct{}) error { log.Printf("%s: registering user", u.cfg.Username) ctx, cancel := context.WithTimeout(context.Background(), HTTPRequestTimeout) - _, _, err = client.CreateUser(ctx, &model.User{ + _, _, err = apiClient.CreateUser(ctx, &model.User{ Username: u.cfg.Username, Password: u.cfg.Password, Email: u.cfg.Username + "@example.com", @@ -913,7 +415,7 @@ func (u *User) Connect(stopCh chan struct{}) error { } ctx, cancel = context.WithTimeout(context.Background(), HTTPRequestTimeout) defer cancel() - user, _, err = client.Login(ctx, u.cfg.Username, u.cfg.Password) + user, _, err = apiClient.Login(ctx, u.cfg.Username, u.cfg.Password) if err != nil { return err } @@ -927,7 +429,7 @@ func (u *User) Connect(stopCh chan struct{}) error { if u.cfg.Setup { ctx, cancel = context.WithTimeout(context.Background(), HTTPRequestTimeout) defer cancel() - _, _, err = client.AddTeamMember(ctx, u.cfg.TeamID, user.Id) + _, _, err = apiClient.AddTeamMember(ctx, u.cfg.TeamID, user.Id) if err != nil { return err } @@ -935,7 +437,7 @@ func (u *User) Connect(stopCh chan struct{}) error { ctx, cancel = context.WithTimeout(context.Background(), HTTPRequestTimeout) defer cancel() - channel, _, err := client.GetChannel(ctx, u.cfg.ChannelID, "") + channel, _, err := apiClient.GetChannel(ctx, u.cfg.ChannelID, "") if err != nil { return err } @@ -945,7 +447,7 @@ func (u *User) Connect(stopCh chan struct{}) error { // join channel ctx, cancel = context.WithTimeout(context.Background(), HTTPRequestTimeout) defer cancel() - _, _, err = client.AddChannelMember(ctx, u.cfg.ChannelID, user.Id) + _, _, err = apiClient.AddChannelMember(ctx, u.cfg.ChannelID, user.Id) if err != nil { return err } @@ -953,26 +455,79 @@ func (u *User) Connect(stopCh chan struct{}) error { } } - log.Printf("%s: connecting to websocket", u.cfg.Username) + log.Printf("%s: creating calls client", u.cfg.Username) + + callsClient, err := client.New(client.Config{ + SiteURL: u.cfg.SiteURL, + AuthToken: apiClient.AuthToken, + ChannelID: u.cfg.ChannelID, + }) + if err != nil { + return fmt.Errorf("failed to create calls client: %w", err) + } + + callsConfig, err := callsClient.GetCallsConfig() + if err != nil { + return fmt.Errorf("failed to get calls config: %w", err) + } + + u.callsClient = callsClient + u.callsConfig = callsConfig + + log.Printf("%s: connecting to call", u.cfg.Username) - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - u.wsListen(client.AuthToken) - }() + var connectOnce sync.Once + err = callsClient.On(client.RTCConnectEvent, func(_ any) error { + log.Printf("%s: connected to call", u.cfg.Username) + connectOnce.Do(u.onConnect) + return nil + }) + if err != nil { + return fmt.Errorf("failed to subscribe to connect event: %w", err) + } + + closedCh := make(chan struct{}) + err = callsClient.On(client.CloseEvent, func(_ any) error { + log.Printf("%s: disconnected from call", u.cfg.Username) + close(closedCh) + return nil + }) + if err != nil { + return fmt.Errorf("failed to subscribe to close event: %w", err) + } + + err = callsClient.On(client.WSCallHostChangedEvent, func(ctx any) error { + u.hostID.Store(ctx.(string)) + return nil + }) + if err != nil { + return fmt.Errorf("failed to subscribe to host changed event: %w", err) + } + + if err := callsClient.Connect(); err != nil { + return fmt.Errorf("failed to connect: %w", err) + } ticker := time.NewTicker(u.cfg.Duration) defer ticker.Stop() select { case <-ticker.C: + case <-closedCh: case <-stopCh: } log.Printf("%s: disconnecting...", u.cfg.Username) - close(u.doneCh) - wg.Wait() + + if err := callsClient.Close(); err != nil { + return fmt.Errorf("failed to close calls client: %w", err) + } + + select { + case <-closedCh: + case <-time.After(10 * time.Second): + return fmt.Errorf("timed out waiting for close event") + } log.Printf("%s: disconnected", u.cfg.Username) diff --git a/lt/cmd/lt/main.go b/lt/cmd/lt/main.go index 9372fac24..aa42fc8e5 100644 --- a/lt/cmd/lt/main.go +++ b/lt/cmd/lt/main.go @@ -6,7 +6,6 @@ import ( "fmt" "log" "math/rand" - "net/url" "os" "os/signal" "sync" @@ -35,7 +34,6 @@ func main() { var numCalls int var numUsersPerCall int var numRecordings int - var simulcast bool var setup bool var speechFile string @@ -54,9 +52,8 @@ func main() { flag.StringVar(&joinDuration, "join-duration", "30s", "The amount of time it takes for all participants to join their calls") flag.StringVar(&adminUsername, "admin-username", "sysadmin", "The username of a system admin account") flag.StringVar(&adminPassword, "admin-password", "Sys@dmin-sample1", "The password of a system admin account") - flag.BoolVar(&simulcast, "simulcast", false, "Whether or not to enable simulcast for screen") flag.BoolVar(&setup, "setup", true, "Whether or not setup actions like creating users, channels, teams and/or members should be executed.") - flag.StringVar(&speechFile, "speech-file", "./lt/samples/speech_0.ogg", "The path to a speech OGG file to read to simulate real voice samples") + flag.StringVar(&speechFile, "speech-file", "./samples/speech_0.ogg", "The path to a speech OGG file to read to simulate real voice samples") flag.Parse() @@ -94,17 +91,6 @@ func main() { log.Fatalf(err.Error()) } - var wsURL string - u, err := url.Parse(siteURL) - if err != nil { - log.Fatalf(err.Error()) - } - if u.Scheme == "https" { - wsURL = "wss://" + u.Host - } else { - wsURL = "ws://" + u.Host - } - if numUnmuted > numUsersPerCall { log.Fatalf("unmuted cannot be greater than the number of users per call") } @@ -215,12 +201,10 @@ func main() { TeamID: teamID, ChannelID: channelID, SiteURL: siteURL, - WsURL: wsURL, Duration: dur, Unmuted: unmuted, ScreenSharing: screenSharing, Recording: recording, - Simulcast: simulcast, Setup: setup, SpeechFile: speechFile, } diff --git a/lt/cmd/speech/main.go b/lt/cmd/speech/main.go index 68bdc9d4f..2c2187a7c 100644 --- a/lt/cmd/speech/main.go +++ b/lt/cmd/speech/main.go @@ -36,84 +36,20 @@ func main() { flag.Parse() if channelID == "" { - log.Fatalf("need a --channelID flag") + log.Fatalf("need a -channelID flag") } - if script != "" { - if setup && teamID == "" { - log.Fatalf("need a --teamID flag") - } - - if err := performScript(script); err != nil { - log.Fatalf("error performing script: %v", err) - } - return + if script == "" { + log.Fatalf("need a -script flag") } - stopCh := make(chan struct{}) - var wg sync.WaitGroup - wg.Add(2) - - userA := client.NewUser(client.Config{ - Username: "testuser-0", - Password: userPassword, - ChannelID: channelID, - SiteURL: siteURL, - WsURL: wsURL, - Duration: duration, - Speak: true, - }) - go func() { - defer wg.Done() - if err := userA.Connect(stopCh); err != nil { - log.Fatalf("connectUser failed: %s", err.Error()) - } - }() - - userB := client.NewUser(client.Config{ - Username: "testuser-1", - Password: userPassword, - ChannelID: channelID, - SiteURL: siteURL, - WsURL: wsURL, - Duration: duration, - Speak: true, - }) - go func() { - defer wg.Done() - if err := userB.Connect(stopCh); err != nil { - log.Fatalf("connectUser failed: %s", err.Error()) - } - }() - - // "Conversation" logic - go func() { - time.Sleep(2 * time.Second) - - userA.Unmute() - doneCh := userA.Speak("Hi, this is user A") - <-doneCh - userA.Mute() - - userB.Unmute() - doneCh = userB.Speak("Hi user A, this is user B responding") - <-doneCh - userB.Mute() - - userA.Unmute() - doneCh = userA.Speak("Nice to meet you user B!") - <-doneCh - userA.Mute() - }() - - go func() { - sig := make(chan os.Signal, 1) - signal.Notify(sig, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) - <-sig - close(stopCh) - }() + if setup && teamID == "" { + log.Fatalf("need a -teamID flag") + } - wg.Wait() + if err := performScript(script); err != nil { + log.Fatalf("error performing script: %v", err) + } } func performScript(filename string) error { diff --git a/lt/go.mod b/lt/go.mod index a1b7bfa7c..11f7ed4a5 100644 --- a/lt/go.mod +++ b/lt/go.mod @@ -1,18 +1,14 @@ module github.com/mattermost/mattermost-plugin-calls/lt -go 1.21.5 +go 1.21.9 require ( github.com/aws/aws-sdk-go v1.50.3 - github.com/gorilla/websocket v1.5.1 github.com/hajimehoshi/go-mp3 v0.3.4 github.com/mattermost/mattermost/server/public v0.0.12 - github.com/pion/interceptor v0.1.28 - github.com/pion/rtcp v1.2.14 + github.com/mattermost/rtcd v0.14.1-0.20240416222725-6b418f9231d3 github.com/pion/rtp v1.8.5 github.com/pion/webrtc/v3 v3.2.37 - github.com/stretchr/testify v1.9.0 - github.com/vmihailenco/msgpack/v5 v5.4.1 gopkg.in/hraban/opus.v2 v2.0.0-20230925203106-0188a62cb302 ) @@ -23,6 +19,7 @@ require ( github.com/francoispqt/gojay v1.2.13 // indirect github.com/go-asn1-ber/asn1-ber v1.5.5 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/gorilla/websocket v1.5.1 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/mattermost/go-i18n v1.11.1-0.20211013152124-5c415071e404 // indirect github.com/mattermost/ldap v0.0.0-20231116144001-0f480c025956 // indirect @@ -33,9 +30,11 @@ require ( github.com/pion/datachannel v1.5.6 // indirect github.com/pion/dtls/v2 v2.2.10 // indirect github.com/pion/ice/v2 v2.3.14 // indirect + github.com/pion/interceptor v0.1.28 // indirect github.com/pion/logging v0.2.2 // indirect github.com/pion/mdns v0.0.12 // indirect github.com/pion/randutil v0.1.0 // indirect + github.com/pion/rtcp v1.2.14 // indirect github.com/pion/sctp v1.8.15 // indirect github.com/pion/sdp/v3 v3.0.9 // indirect github.com/pion/srtp/v2 v2.0.18 // indirect @@ -44,7 +43,9 @@ require ( github.com/pion/turn/v2 v2.1.5 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/testify v1.9.0 // indirect github.com/tinylib/msgp v1.1.9 // indirect + github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/wiggin77/merror v1.0.5 // indirect github.com/wiggin77/srslog v1.0.1 // indirect diff --git a/lt/go.sum b/lt/go.sum index 7a7b5be33..b7b715fbd 100644 --- a/lt/go.sum +++ b/lt/go.sum @@ -85,6 +85,8 @@ github.com/mattermost/logr/v2 v2.0.21 h1:CMHsP+nrbRlEC4g7BwOk1GAnMtHkniFhlSQPXy5 github.com/mattermost/logr/v2 v2.0.21/go.mod h1:kZkB/zqKL9e+RY5gB3vGpsyenC+TpuiOenjMkvJJbzc= github.com/mattermost/mattermost/server/public v0.0.12 h1:iunc9q4/XkArOrndEUn73uFw6v9TOEXEtp6Nm6Iv218= github.com/mattermost/mattermost/server/public v0.0.12/go.mod h1:Bk+atJcELCIk9Yeq5FoqTr+gra9704+X4amrlwlTgSc= +github.com/mattermost/rtcd v0.14.1-0.20240416222725-6b418f9231d3 h1:0YQB7sY8Qpg1BjGascPa5YUMm9CaKIhzLG6m8wD91PY= +github.com/mattermost/rtcd v0.14.1-0.20240416222725-6b418f9231d3/go.mod h1:CaDYwB/c+UC0MHs+nMZkE3p7Sd4f0Sh+66jjZNDVdRQ= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -151,8 +153,8 @@ github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXP github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= -github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= diff --git a/lt/ws/client.go b/lt/ws/client.go deleted file mode 100644 index 2724edbcd..000000000 --- a/lt/ws/client.go +++ /dev/null @@ -1,154 +0,0 @@ -package ws - -import ( - "bytes" - "fmt" - "net/http" - "sync" - - "github.com/mattermost/mattermost/server/public/model" - "github.com/mattermost/mattermost/server/public/shared/mlog" - - "github.com/gorilla/websocket" - "github.com/vmihailenco/msgpack/v5" -) - -const avgReadMsgSizeBytes = 1024 - -// Client is the websocket client to perform all actions. -type Client struct { - EventChannel chan *model.WebSocketEvent - - conn *websocket.Conn - authToken string - sequence int64 - readWg sync.WaitGroup - writeMut sync.RWMutex -} - -type ClientParams struct { - WsURL string - AuthToken string - ConnID string - ServerSequence int64 -} - -// NewClient constructs a new WebSocket client. -func NewClient(param *ClientParams) (*Client, error) { - header := http.Header{ - "Authorization": []string{"Bearer " + param.AuthToken}, - } - - url := param.WsURL + model.APIURLSuffix + "/websocket" + fmt.Sprintf("?connection_id=%s&sequence_number=%d", param.ConnID, param.ServerSequence) - conn, _, err := websocket.DefaultDialer.Dial(url, header) - if err != nil { - return nil, err - } - - client := &Client{ - EventChannel: make(chan *model.WebSocketEvent, 100), - - conn: conn, - authToken: param.AuthToken, - sequence: 1, - } - - client.readWg.Add(1) - go client.reader() - - return client, nil -} - -// Close closes the client. -func (c *Client) Close() { - // If Close gets called concurrently during the time - // a connection-break happens, this will become a no-op. - c.conn.Close() - // Wait for reader to return. - // If the reader has already quit, this will just fall-through. - c.readWg.Wait() -} - -func (c *Client) reader() { - defer func() { - close(c.EventChannel) - // Mark wg as Done. - c.readWg.Done() - }() - - var buf bytes.Buffer - buf.Grow(avgReadMsgSizeBytes) - - for { - // Reset buffer. - buf.Reset() - _, r, err := c.conn.NextReader() - if err != nil { - if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) { - // log error - mlog.Debug("error from conn.NextReader", mlog.Err(err)) - } - return - } - // Use pre-allocated buffer. - _, err = buf.ReadFrom(r) - if err != nil { - if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) { - // log error - mlog.Warn("error from buf.ReadFrom", mlog.Err(err)) - } - return - } - - event, err := model.WebSocketEventFromJSON(&buf) - if event == nil || err != nil { - continue - } - if event.IsValid() { - // non-blocking send in case event channel is full. - select { - case c.EventChannel <- event: - default: - } - } - } -} - -// SendMessage is the method to write to the websocket. -func (c *Client) SendMessage(action string, data map[string]interface{}) error { - // It uses a mutex to synchronize writes. - // Intentionally no atomics are used to perform additional state tracking. - // Therefore, we let it fail if the user tries to write again on a closed connection. - c.writeMut.Lock() - defer c.writeMut.Unlock() - - req := &model.WebSocketRequest{ - Seq: c.sequence, - Action: action, - Data: data, - } - - c.sequence++ - return c.conn.WriteJSON(req) -} - -// SendBinaryMessage is the method to write to the websocket using binary data type -// (MessagePack encoded). -func (c *Client) SendBinaryMessage(action string, data map[string]interface{}) error { - req := &model.WebSocketRequest{ - Seq: c.sequence, - Action: action, - Data: data, - } - - binaryData, err := msgpack.Marshal(req) - if err != nil { - return fmt.Errorf("failed to marshal request to msgpack: %w", err) - } - - c.writeMut.Lock() - defer c.writeMut.Unlock() - - c.sequence++ - return c.conn.WriteMessage(websocket.BinaryMessage, binaryData) -} diff --git a/lt/ws/client_test.go b/lt/ws/client_test.go deleted file mode 100644 index 50974303e..000000000 --- a/lt/ws/client_test.go +++ /dev/null @@ -1,260 +0,0 @@ -package ws - -import ( - "net/http" - "net/http/httptest" - "strings" - "sync" - "testing" - "time" - - "github.com/mattermost/mattermost/server/public/model" - - "github.com/gorilla/websocket" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/vmihailenco/msgpack/v5" -) - -func dummyWebsocketHandler(t *testing.T, wg *sync.WaitGroup) http.HandlerFunc { - return func(w http.ResponseWriter, req *http.Request) { - defer wg.Done() - upgrader := &websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - } - conn, err := upgrader.Upgrade(w, req, nil) - require.Nil(t, err) - var buf []byte - for { - _, buf, err = conn.ReadMessage() - if err != nil { - break - } - t.Logf("%s\n", buf) - err = conn.WriteMessage(websocket.TextMessage, []byte("hello world")) - if err != nil { - break - } - } - } -} - -// TestClose verifies that the client is properly and safely closed in all possible ways. -func TestClose(t *testing.T) { - var wg sync.WaitGroup - s := httptest.NewServer(dummyWebsocketHandler(t, &wg)) - defer func() { - wg.Wait() - s.Close() - }() - - checkEventChan := func(eventChan chan *model.WebSocketEvent) { - defer func() { - if x := recover(); x == nil { - require.Fail(t, "should have panicked due to closing a closed channel") - } - }() - close(eventChan) - } - - t.Run("Sudden", func(t *testing.T) { - wg.Add(1) - url := strings.Replace(s.URL, "http://", "ws://", 1) - c, err := NewClient(&ClientParams{ - WsURL: url, - AuthToken: "authToken", - }) - require.Nil(t, err) - - err = c.SendMessage("test_action", map[string]interface{}{"test": "data"}) - assert.Nil(t, err) - - err = c.conn.Close() - assert.Nil(t, err) - - // wait for a while for reader to exit - time.Sleep(200 * time.Millisecond) - - // Verify that event channel is closed. - checkEventChan(c.EventChannel) - }) - - t.Run("Normal", func(t *testing.T) { - wg.Add(1) - url := strings.Replace(s.URL, "http://", "ws://", 1) - c, err := NewClient(&ClientParams{ - WsURL: url, - AuthToken: "authToken", - }) - require.Nil(t, err) - - err = c.SendMessage("test_action", map[string]interface{}{"test": "data"}) - assert.Nil(t, err) - - c.Close() - - // Verify that event channel is closed. - checkEventChan(c.EventChannel) - }) - - t.Run("Concurrent", func(t *testing.T) { - wg.Add(1) - url := strings.Replace(s.URL, "http://", "ws://", 1) - c, err := NewClient(&ClientParams{ - WsURL: url, - AuthToken: "authToken", - }) - require.Nil(t, err) - - err = c.SendMessage("test_action", map[string]interface{}{"test": "data"}) - assert.Nil(t, err) - - var wg2 sync.WaitGroup - wg2.Add(2) - go func() { - defer wg2.Done() - c.Close() - }() - - go func() { - defer wg2.Done() - c.conn.Close() - }() - - wg2.Wait() - // Verify that event channel is closed. - checkEventChan(c.EventChannel) - }) -} - -// TestSendMessage verifies that there are no races or panics during message send -// in various conditions. -func TestSendMessage(t *testing.T) { - var wg sync.WaitGroup - s := httptest.NewServer(dummyWebsocketHandler(t, &wg)) - defer func() { - wg.Wait() - s.Close() - }() - - t.Run("SendAfterSuddenClose", func(t *testing.T) { - wg.Add(1) - url := strings.Replace(s.URL, "http://", "ws://", 1) - c, err := NewClient(&ClientParams{ - WsURL: url, - AuthToken: "authToken", - }) - require.Nil(t, err) - - err = c.SendMessage("test_action", map[string]interface{}{"test": "data"}) - assert.Nil(t, err) - - err = c.conn.Close() - assert.Nil(t, err) - - err = c.SendMessage("test_action", map[string]interface{}{"test": "data"}) - assert.NotNil(t, err) - }) - - t.Run("SendAfterClose", func(t *testing.T) { - wg.Add(1) - url := strings.Replace(s.URL, "http://", "ws://", 1) - c, err := NewClient(&ClientParams{ - WsURL: url, - AuthToken: "authToken", - }) - require.Nil(t, err) - - err = c.SendMessage("test_action", map[string]interface{}{"test": "data"}) - assert.Nil(t, err) - - c.Close() - - err = c.SendMessage("test_action", map[string]interface{}{"test": "data"}) - assert.NotNil(t, err) - }) - - t.Run("SendDuringSuddenClose", func(t *testing.T) { - wg.Add(1) - url := strings.Replace(s.URL, "http://", "ws://", 1) - c, err := NewClient(&ClientParams{ - WsURL: url, - AuthToken: "authToken", - }) - require.Nil(t, err) - - err = c.SendMessage("test_action", map[string]interface{}{"test": "data"}) - assert.Nil(t, err) - - go func() { - _ = c.SendMessage("test_action", map[string]interface{}{"test": "data"}) - }() - - err = c.conn.Close() - assert.Nil(t, err) - }) - - t.Run("SendDuringClose", func(t *testing.T) { - wg.Add(1) - url := strings.Replace(s.URL, "http://", "ws://", 1) - c, err := NewClient(&ClientParams{ - WsURL: url, - AuthToken: "authToken", - }) - require.Nil(t, err) - - err = c.SendMessage("test_action", map[string]interface{}{"test": "data"}) - assert.Nil(t, err) - - go func() { - _ = c.SendMessage("test_action", map[string]interface{}{"test": "data"}) - }() - - c.Close() - }) -} - -func TestSendBinaryMessage(t *testing.T) { - var wg sync.WaitGroup - inputData := map[string]interface{}{ - "data": "testing binary data", - } - - wsHandler := func(w http.ResponseWriter, req *http.Request) { - defer wg.Done() - upgrader := &websocket.Upgrader{} - conn, err := upgrader.Upgrade(w, req, nil) - require.NoError(t, err) - for { - msgType, buf, err := conn.ReadMessage() - if err != nil { - break - } - require.Equal(t, websocket.BinaryMessage, msgType) - var outputData map[string]interface{} - err = msgpack.Unmarshal(buf, &outputData) - require.NoError(t, err) - require.Equal(t, "test_action", outputData["action"]) - require.Equal(t, inputData, outputData["data"]) - } - } - - wg.Add(1) - s := httptest.NewServer(http.HandlerFunc(wsHandler)) - defer func() { - wg.Wait() - s.Close() - }() - - url := strings.Replace(s.URL, "http://", "ws://", 1) - c, err := NewClient(&ClientParams{ - WsURL: url, - AuthToken: "authToken", - }) - require.Nil(t, err) - - err = c.SendBinaryMessage("test_action", inputData) - require.NoError(t, err) - c.Close() -}