|
4 | 4 | from pathlib import Path
|
5 | 5 | from typing import Any, AnyStr, Callable, Dict, Optional, Text, Union
|
6 | 6 |
|
| 7 | +import numpy as np |
7 | 8 | from websocket_server import WebsocketServer
|
8 | 9 |
|
| 10 | +from . import utils |
9 | 11 | from . import blocks
|
10 | 12 | from . import sources as src
|
11 | 13 | from .inference import StreamingInference
|
|
17 | 19 | logger = logging.getLogger(__name__)
|
18 | 20 |
|
19 | 21 |
|
class ProxyAudioSource(src.AudioSource):
    """Audio source fed with chunks that arrive over the network via WebSocket.

    This source does not open any connection by itself: audio chunks are
    handed to :meth:`process_message`, which forwards them to ``self.stream``.

    Parameters
    ----------
    uri: str
        Identifier of this source, forwarded to the parent constructor.
    sample_rate: int
        Sample rate of the chunks emitted.
    """

    def __init__(self, uri: str, sample_rate: int):
        # FIXME: sample_rate is not used to check or resample incoming chunks,
        # which can be confusing and lead to incompatibilities. Preferably the
        # client would send a JSON with data and sample rate, and the audio
        # would then be resampled if needed.
        super().__init__(uri, sample_rate)

    def process_message(self, message: np.ndarray):
        """Forward an incoming audio chunk to the pipeline stream."""
        self.stream.on_next(message)

    def read(self):
        """Do nothing: chunks are pushed in through :meth:`process_message`."""
        return None

    def close(self):
        """Signal that the audio stream for this client is complete."""
        self.stream.on_completed()
| 52 | + |
| 53 | + |
20 | 54 | @dataclass
|
21 | 55 | class ClientState:
|
22 | 56 | """Represents the state of a connected client."""
|
23 | 57 |
|
24 |
| - audio_source: src.WebSocketAudioSource |
| 58 | + audio_source: ProxyAudioSource |
25 | 59 | inference: StreamingInference
|
26 | 60 |
|
27 | 61 |
|
@@ -93,7 +127,7 @@ def _create_client_state(self, client_id: Text) -> ClientState:
|
93 | 127 | # This ensures each client has its own state while sharing model weights
|
94 | 128 | pipeline = self.pipeline_class(self.pipeline_config)
|
95 | 129 |
|
96 |
| - audio_source = src.WebSocketAudioSource( |
| 130 | + audio_source = ProxyAudioSource( |
97 | 131 | uri=f"{self.uri}:{client_id}", sample_rate=self.pipeline_config.sample_rate,
|
98 | 132 | )
|
99 | 133 |
|
@@ -186,7 +220,9 @@ def _on_message_received(
|
186 | 220 | return
|
187 | 221 |
|
188 | 222 | try:
|
189 |
| - self._clients[client_id].audio_source.process_message(message) |
| 223 | + # decode message to audio |
| 224 | + decoded_audio = utils.decode_audio(message) |
| 225 | + self._clients[client_id].audio_source.process_message(decoded_audio) |
190 | 226 | except (socket.error, ConnectionError) as e:
|
191 | 227 | logger.warning(f"Client {client_id} disconnected: {e}")
|
192 | 228 | # Just cleanup since client is already disconnected
|
|
0 commit comments