Skip to content

Commit 12f7ba9

Browse files
committed
refactor: make WebSocketAudioSource a proxy and handle audio decoding in the server itself
1 parent c6a2f92 commit 12f7ba9

File tree

2 files changed

+41
-35
lines changed

2 files changed

+41
-35
lines changed

src/diart/sources.py

-32
Original file line numberDiff line numberDiff line change
@@ -201,38 +201,6 @@ def close(self):
201201
self._mic_stream.close()
202202

203203

204-
class WebSocketAudioSource(AudioSource):
205-
"""Represents a source of audio coming from the network using the WebSocket protocol.
206-
207-
Parameters
208-
----------
209-
sample_rate: int
210-
Sample rate of the chunks emitted.
211-
"""
212-
213-
def __init__(
214-
self,
215-
uri: str,
216-
sample_rate: int,
217-
):
218-
# FIXME sample_rate is not being used, this can be confusing and lead to incompatibilities.
219-
# I would prefer the client to send a JSON with data and sample rate, then resample if needed
220-
super().__init__(uri, sample_rate)
221-
222-
def process_message(self, message: AnyStr):
223-
"""Decode and process an incoming audio message."""
224-
# Send decoded audio to pipeline
225-
self.stream.on_next(utils.decode_audio(message))
226-
227-
def read(self):
228-
"""Starts running the websocket server and listening for audio chunks"""
229-
pass
230-
231-
def close(self):
232-
"""Complete the audio stream for this client."""
233-
self.stream.on_completed()
234-
235-
236204
class TorchStreamAudioSource(AudioSource):
237205
def __init__(
238206
self,

src/diart/websockets.py

+41 -3
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
from pathlib import Path
55
from typing import Any, AnyStr, Callable, Dict, Optional, Text, Union
66

7+
import numpy as np
78
from websocket_server import WebsocketServer
89

10+
from . import utils
911
from . import blocks
1012
from . import sources as src
1113
from .inference import StreamingInference
@@ -17,11 +19,45 @@
1719
logger = logging.getLogger(__name__)
1820

1921

22+
class ProxyAudioSource(src.AudioSource):
23+
"""A proxy audio source that forwards decoded audio chunks to a processing pipeline.
24+
25+
Parameters
26+
----------
27+
uri : str
28+
Unique identifier for this audio source
29+
sample_rate : int
30+
Expected sample rate of the audio chunks
31+
"""
32+
33+
def __init__(
34+
self,
35+
uri: str,
36+
sample_rate: int,
37+
):
38+
# FIXME sample_rate is not being used, this can be confusing and lead to incompatibilities.
39+
# I would prefer the client to send a JSON with data and sample rate, then resample if needed
40+
super().__init__(uri, sample_rate)
41+
42+
def process_message(self, message: np.ndarray):
43+
"""Process an incoming audio message."""
44+
# Send audio to pipeline
45+
self.stream.on_next(message)
46+
47+
def read(self):
48+
"""No-op: audio chunks are pushed into this source through process_message(), so there is nothing to read here."""
49+
pass
50+
51+
def close(self):
52+
"""Complete the audio stream for this client."""
53+
self.stream.on_completed()
54+
55+
2056
@dataclass
2157
class ClientState:
2258
"""Represents the state of a connected client."""
2359

24-
audio_source: src.WebSocketAudioSource
60+
audio_source: ProxyAudioSource
2561
inference: StreamingInference
2662

2763

@@ -93,7 +129,7 @@ def _create_client_state(self, client_id: Text) -> ClientState:
93129
# This ensures each client has its own state while sharing model weights
94130
pipeline = self.pipeline_class(self.pipeline_config)
95131

96-
audio_source = src.WebSocketAudioSource(
132+
audio_source = ProxyAudioSource(
97133
uri=f"{self.uri}:{client_id}", sample_rate=self.pipeline_config.sample_rate,
98134
)
99135

@@ -186,7 +222,9 @@ def _on_message_received(
186222
return
187223

188224
try:
189-
self._clients[client_id].audio_source.process_message(message)
225+
# decode message to audio
226+
decoded_audio = utils.decode_audio(message)
227+
self._clients[client_id].audio_source.process_message(decoded_audio)
190228
except (socket.error, ConnectionError) as e:
191229
logger.warning(f"Client {client_id} disconnected: {e}")
192230
# Just cleanup since client is already disconnected

0 commit comments

Comments
 (0)