|
4 | 4 | from pathlib import Path
|
5 | 5 | from typing import Any, AnyStr, Callable, Dict, Optional, Text, Union
|
6 | 6 |
|
| 7 | +import numpy as np |
7 | 8 | from websocket_server import WebsocketServer
|
8 | 9 |
|
| 10 | +from . import utils |
9 | 11 | from . import blocks
|
10 | 12 | from . import sources as src
|
11 | 13 | from .inference import StreamingInference
|
|
17 | 19 | logger = logging.getLogger(__name__)
|
18 | 20 |
|
19 | 21 |
|
| 22 | +class ProxyAudioSource(src.AudioSource): |
| 23 | + """A proxy audio source that forwards decoded audio chunks to a processing pipeline. |
| 24 | +
|
| 25 | + Parameters |
| 26 | + ---------- |
| 27 | + uri : str |
| 28 | + Unique identifier for this audio source |
| 29 | + sample_rate : int |
| 30 | + Expected sample rate of the audio chunks |
| 31 | + """ |
| 32 | + |
| 33 | + def __init__( |
| 34 | + self, |
| 35 | + uri: str, |
| 36 | + sample_rate: int, |
| 37 | + ): |
| 38 | + # FIXME sample_rate is not being used, this can be confusing and lead to incompatibilities. |
| 39 | + # I would prefer the client to send a JSON with data and sample rate, then resample if needed |
| 40 | + super().__init__(uri, sample_rate) |
| 41 | + |
| 42 | + def process_message(self, message: np.ndarray): |
| 43 | + """Process an incoming audio message.""" |
| 44 | + # Send audio to pipeline |
| 45 | + self.stream.on_next(message) |
| 46 | + |
| 47 | + def read(self): |
| 48 | + """Starts running the websocket server and listening for audio chunks""" |
| 49 | + pass |
| 50 | + |
| 51 | + def close(self): |
| 52 | + """Complete the audio stream for this client.""" |
| 53 | + self.stream.on_completed() |
| 54 | + |
| 55 | + |
20 | 56 | @dataclass
|
21 | 57 | class ClientState:
|
22 | 58 | """Represents the state of a connected client."""
|
23 | 59 |
|
24 |
| - audio_source: src.WebSocketAudioSource |
| 60 | + audio_source: ProxyAudioSource |
25 | 61 | inference: StreamingInference
|
26 | 62 |
|
27 | 63 |
|
@@ -93,7 +129,7 @@ def _create_client_state(self, client_id: Text) -> ClientState:
|
93 | 129 | # This ensures each client has its own state while sharing model weights
|
94 | 130 | pipeline = self.pipeline_class(self.pipeline_config)
|
95 | 131 |
|
96 |
| - audio_source = src.WebSocketAudioSource( |
| 132 | + audio_source = ProxyAudioSource( |
97 | 133 | uri=f"{self.uri}:{client_id}", sample_rate=self.pipeline_config.sample_rate,
|
98 | 134 | )
|
99 | 135 |
|
@@ -186,7 +222,9 @@ def _on_message_received(
|
186 | 222 | return
|
187 | 223 |
|
188 | 224 | try:
|
189 |
| - self._clients[client_id].audio_source.process_message(message) |
| 225 | + # decode message to audio |
| 226 | + decoded_audio = utils.decode_audio(message) |
| 227 | + self._clients[client_id].audio_source.process_message(decoded_audio) |
190 | 228 | except (socket.error, ConnectionError) as e:
|
191 | 229 | logger.warning(f"Client {client_id} disconnected: {e}")
|
192 | 230 | # Just cleanup since client is already disconnected
|
|
0 commit comments