Skip to content

Commit 604d140

Browse files
committed
refactor: handle audio message decoding in WebSocketStreamingServer and make WebSocketAudioSource a proxy class
1 parent c6a2f92 commit 604d140

File tree

2 files changed

+39
-35
lines changed

2 files changed

+39
-35
lines changed

src/diart/sources.py

-32
Original file line numberDiff line numberDiff line change
@@ -201,38 +201,6 @@ def close(self):
201201
self._mic_stream.close()
202202

203203

204-
class WebSocketAudioSource(AudioSource):
    """Audio source whose chunks arrive as messages received over the network.

    The source does not listen on a socket itself: incoming messages are
    handed to ``process_message``, which decodes them and emits the audio.

    Parameters
    ----------
    uri: str
        Identifier forwarded to the base ``AudioSource``.
    sample_rate: int
        Sample rate of the chunks emitted.
    """

    def __init__(self, uri: str, sample_rate: int):
        # FIXME sample_rate is not being used, this can be confusing and lead to incompatibilities.
        # I would prefer the client to send a JSON with data and sample rate, then resample if needed
        super().__init__(uri, sample_rate)

    def process_message(self, message: AnyStr):
        """Decode an incoming audio message and emit it on the stream."""
        emit = self.stream.on_next
        emit(utils.decode_audio(message))

    def read(self):
        """Do nothing; audio is supplied through ``process_message`` instead."""
        return None

    def close(self):
        """Mark the audio stream of this client as completed."""
        audio_stream = self.stream
        audio_stream.on_completed()
234-
235-
236204
class TorchStreamAudioSource(AudioSource):
237205
def __init__(
238206
self,

src/diart/websockets.py

+39-3
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
from pathlib import Path
55
from typing import Any, AnyStr, Callable, Dict, Optional, Text, Union
66

7+
import numpy as np
78
from websocket_server import WebsocketServer
89

10+
from . import utils
911
from . import blocks
1012
from . import sources as src
1113
from .inference import StreamingInference
@@ -17,11 +19,43 @@
1719
logger = logging.getLogger(__name__)
1820

1921

22+
class ProxyAudioSource(src.AudioSource):
    """Audio source that relays chunks handed to it by its owner.

    The source performs no I/O of its own: already-decoded audio is pushed in
    through ``process_message`` and the stream is ended through ``close``.

    Parameters
    ----------
    uri: str
        Identifier forwarded to the base ``AudioSource``.
    sample_rate: int
        Sample rate of the chunks emitted.
    """

    def __init__(self, uri: str, sample_rate: int):
        # FIXME sample_rate is not being used, this can be confusing and lead to incompatibilities.
        # I would prefer the client to send a JSON with data and sample rate, then resample if needed
        super().__init__(uri, sample_rate)

    def process_message(self, message: np.ndarray):
        """Emit one already-decoded audio chunk on the stream."""
        emit = self.stream.on_next
        emit(message)

    def read(self):
        """Do nothing; audio is supplied through ``process_message`` instead."""
        return None

    def close(self):
        """Mark the audio stream of this client as completed."""
        audio_stream = self.stream
        audio_stream.on_completed()
52+
53+
2054
@dataclass
2155
class ClientState:
2256
"""Represents the state of a connected client."""
2357

24-
audio_source: src.WebSocketAudioSource
58+
audio_source: ProxyAudioSource
2559
inference: StreamingInference
2660

2761

@@ -93,7 +127,7 @@ def _create_client_state(self, client_id: Text) -> ClientState:
93127
# This ensures each client has its own state while sharing model weights
94128
pipeline = self.pipeline_class(self.pipeline_config)
95129

96-
audio_source = src.WebSocketAudioSource(
130+
audio_source = ProxyAudioSource(
97131
uri=f"{self.uri}:{client_id}", sample_rate=self.pipeline_config.sample_rate,
98132
)
99133

@@ -186,7 +220,9 @@ def _on_message_received(
186220
return
187221

188222
try:
189-
self._clients[client_id].audio_source.process_message(message)
223+
# decode message to audio
224+
decoded_audio = utils.decode_audio(message)
225+
self._clients[client_id].audio_source.process_message(decoded_audio)
190226
except (socket.error, ConnectionError) as e:
191227
logger.warning(f"Client {client_id} disconnected: {e}")
192228
# Just cleanup since client is already disconnected

0 commit comments

Comments
 (0)