1
+ import logging
2
+ import socket
1
3
from dataclasses import dataclass
2
4
from pathlib import Path
3
- from typing import Union , Text , Optional , AnyStr , Dict , Any , Callable
4
- import logging
5
+ from typing import Any , AnyStr , Callable , Dict , Optional , Text , Union
6
+
5
7
from websocket_server import WebsocketServer
6
- import socket
7
8
8
9
from . import blocks
9
10
from . import sources as src
12
13
13
14
# Configure logging
14
15
logging .basicConfig (
15
- level = logging .INFO ,
16
- format = '%(asctime)s - %(levelname)s - %(message)s'
16
+ level = logging .INFO , format = "%(asctime)s - %(levelname)s - %(message)s"
17
17
)
18
18
logger = logging .getLogger (__name__ )
19
19
@@ -29,6 +29,7 @@ class WebSocketAudioSourceConfig:
29
29
sample_rate : int
30
30
Audio sample rate in Hz
31
31
"""
32
+
32
33
uri : str
33
34
sample_rate : int = 16000
34
35
@@ -52,6 +53,7 @@ class StreamingInferenceConfig:
52
53
progress_bar : Optional[ProgressBar]
53
54
Custom progress bar implementation
54
55
"""
56
+
55
57
pipeline : blocks .Pipeline
56
58
batch_size : int = 1
57
59
do_profile : bool = True
@@ -63,6 +65,7 @@ class StreamingInferenceConfig:
63
65
@dataclass
64
66
class ClientState :
65
67
"""Represents the state of a connected client."""
68
+
66
69
audio_source : src .WebSocketAudioSource
67
70
inference : StreamingInference
68
71
@@ -102,7 +105,7 @@ def __init__(
102
105
self .sample_rate = sample_rate
103
106
self .host = host
104
107
self .port = port
105
-
108
+
106
109
# Server configuration
107
110
self .uri = f"{ host } :{ port } "
108
111
self ._clients : Dict [Text , ClientState ] = {}
@@ -132,16 +135,16 @@ def _create_client_state(self, client_id: Text) -> ClientState:
132
135
"""
133
136
# Create a new pipeline instance with the same config
134
137
# This ensures each client has its own state while sharing model weights
135
- pipeline = self .inference_config .pipeline .__class__ (self .inference_config .pipeline .config )
136
-
138
+ pipeline = self .inference_config .pipeline .__class__ (
139
+ self .inference_config .pipeline .config
140
+ )
141
+
137
142
audio_config = WebSocketAudioSourceConfig (
138
- uri = f"{ self .uri } :{ client_id } " ,
139
- sample_rate = self .sample_rate
143
+ uri = f"{ self .uri } :{ client_id } " , sample_rate = self .sample_rate
140
144
)
141
-
145
+
142
146
audio_source = src .WebSocketAudioSource (
143
- uri = audio_config .uri ,
144
- sample_rate = audio_config .sample_rate
147
+ uri = audio_config .uri , sample_rate = audio_config .sample_rate
145
148
)
146
149
147
150
inference = StreamingInference (
@@ -151,7 +154,7 @@ def _create_client_state(self, client_id: Text) -> ClientState:
151
154
do_profile = self .inference_config .do_profile ,
152
155
do_plot = self .inference_config .do_plot ,
153
156
show_progress = self .inference_config .show_progress ,
154
- progress_bar = self .inference_config .progress_bar
157
+ progress_bar = self .inference_config .progress_bar ,
155
158
)
156
159
157
160
return ClientState (audio_source = audio_source , inference = inference )
@@ -182,7 +185,7 @@ def _on_connect(self, client: Dict[Text, Any], server: WebsocketServer) -> None:
182
185
# Start inference
183
186
client_state .inference ()
184
187
logger .info (f"Started inference for client: { client_id } " )
185
-
188
+
186
189
# Send ready notification to client
187
190
self .send (client_id , "READY" )
188
191
except Exception as e :
@@ -204,10 +207,7 @@ def _on_disconnect(self, client: Dict[Text, Any], server: WebsocketServer) -> No
204
207
self .close (client_id )
205
208
206
209
def _on_message_received (
207
- self ,
208
- client : Dict [Text , Any ],
209
- server : WebsocketServer ,
210
- message : AnyStr
210
+ self , client : Dict [Text , Any ], server : WebsocketServer , message : AnyStr
211
211
) -> None :
212
212
"""Process incoming client messages.
213
213
@@ -245,16 +245,15 @@ def send(self, client_id: Text, message: AnyStr) -> None:
245
245
if not message :
246
246
return
247
247
248
- client = next (
249
- (c for c in self .server .clients if c ["id" ] == client_id ),
250
- None
251
- )
252
-
248
+ client = next ((c for c in self .server .clients if c ["id" ] == client_id ), None )
249
+
253
250
if client is not None :
254
251
try :
255
252
self .server .send_message (client , message )
256
253
except (socket .error , ConnectionError ) as e :
257
- logger .warning (f"Client { client_id } disconnected while sending message: { e } " )
254
+ logger .warning (
255
+ f"Client { client_id } disconnected while sending message: { e } "
256
+ )
258
257
self .close (client_id )
259
258
except Exception as e :
260
259
logger .error (f"Failed to send message to client { client_id } : { e } " )
@@ -264,7 +263,7 @@ def run(self) -> None:
264
263
logger .info (f"Starting WebSocket server on { self .uri } " )
265
264
max_retries = 3
266
265
retry_count = 0
267
-
266
+
268
267
while retry_count < max_retries :
269
268
try :
270
269
self .server .run_forever ()
@@ -273,7 +272,9 @@ def run(self) -> None:
273
272
logger .warning (f"WebSocket connection error: { e } " )
274
273
retry_count += 1
275
274
if retry_count < max_retries :
276
- logger .info (f"Attempting to restart server (attempt { retry_count + 1 } /{ max_retries } )" )
275
+ logger .info (
276
+ f"Attempting to restart server (attempt { retry_count + 1 } /{ max_retries } )"
277
+ )
277
278
else :
278
279
logger .error ("Max retry attempts reached. Server shutting down." )
279
280
except Exception as e :
@@ -295,20 +296,24 @@ def close(self, client_id: Text) -> None:
295
296
# Clean up pipeline state using built-in reset method
296
297
client_state = self ._clients [client_id ]
297
298
client_state .inference .pipeline .reset ()
298
-
299
+
299
300
# Close audio source and remove client
300
301
client_state .audio_source .close ()
301
302
del self ._clients [client_id ]
302
-
303
+
303
304
# Try to send a close frame to the client
304
305
try :
305
- client = next ((c for c in self .server .clients if c ["id" ] == client_id ), None )
306
+ client = next (
307
+ (c for c in self .server .clients if c ["id" ] == client_id ), None
308
+ )
306
309
if client :
307
310
self .server .send_message (client , "CLOSE" )
308
311
except Exception :
309
312
pass # Ignore errors when trying to send close message
310
-
311
- logger .info (f"Closed connection and cleaned up state for client: { client_id } " )
313
+
314
+ logger .info (
315
+ f"Closed connection and cleaned up state for client: { client_id } "
316
+ )
312
317
except Exception as e :
313
318
logger .error (f"Error closing client { client_id } : { e } " )
314
319
# Ensure client is removed from dictionary even if cleanup fails
0 commit comments