Skip to content

Commit f95d6ca

Browse files
committed
refactor close and send methods of WebSocketStreamingServer to separate functionality
1 parent bba43ae commit f95d6ca

File tree

1 file changed

+49
-50
lines changed

1 file changed

+49
-50
lines changed

src/diart/websockets.py

+49-50
Original file line numberDiff line numberDiff line change
@@ -94,16 +94,15 @@ def _create_client_state(self, client_id: Text) -> ClientState:
9494
pipeline = self.pipeline_class(self.pipeline_config)
9595

9696
audio_source = src.WebSocketAudioSource(
97-
uri=f"{self.uri}:{client_id}",
98-
sample_rate=self.pipeline_config.sample_rate,
97+
uri=f"{self.uri}:{client_id}", sample_rate=self.pipeline_config.sample_rate,
9998
)
10099

101100
inference = StreamingInference(
102101
pipeline=pipeline,
103102
source=audio_source,
104103
# The following variables are fixed for a client
105104
batch_size=1,
106-
do_profile=False, # for minimal latency
105+
do_profile=False, # for minimal latency
107106
do_plot=False,
108107
show_progress=False,
109108
progress_bar=None,
@@ -143,6 +142,11 @@ def _on_connect(self, client: Dict[Text, Any], server: WebsocketServer) -> None:
143142
self.send(client_id, "READY")
144143
except Exception as e:
145144
logger.error(f"Failed to initialize client {client_id}: {e}")
145+
146+
# Send close notification to client
147+
self.send(client_id, "CLOSE")
148+
149+
# Close audio source and remove client
146150
self.close(client_id)
147151

148152
def _on_disconnect(self, client: Dict[Text, Any], server: WebsocketServer) -> None:
@@ -157,6 +161,11 @@ def _on_disconnect(self, client: Dict[Text, Any], server: WebsocketServer) -> No
157161
"""
158162
client_id = client["id"]
159163
logger.info(f"Client disconnected: {client_id}")
164+
165+
# Send close notification to client
166+
self.send(client_id, "CLOSE")
167+
168+
# Close audio source and remove client
160169
self.close(client_id)
161170

162171
def _on_message_received(
@@ -182,7 +191,13 @@ def _on_message_received(
182191
self._clients[client_id].audio_source.process_message(message)
183192
except (socket.error, ConnectionError) as e:
184193
logger.warning(f"Client {client_id} disconnected: {e}")
194+
195+
# Send close notification to client
196+
self.send(client_id, "CLOSE")
197+
198+
# Close audio source and remove client
185199
self.close(client_id)
200+
186201
except Exception as e:
187202
logger.error(f"Error processing message from client {client_id}: {e}")
188203
# Don't close the connection for non-connection related errors
@@ -202,45 +217,14 @@ def send(self, client_id: Text, message: AnyStr) -> None:
202217
return
203218

204219
client = next((c for c in self.server.clients if c["id"] == client_id), None)
205-
206220
if client is None:
207221
return
208222

209223
try:
210224
self.server.send_message(client, message)
211-
except (socket.error, ConnectionError) as e:
212-
logger.warning(
213-
f"Client {client_id} disconnected while sending message: {e}"
214-
)
215-
self.close(client_id)
216225
except Exception as e:
217226
logger.error(f"Failed to send message to client {client_id}: {e}")
218227

219-
def run(self) -> None:
220-
"""Start the WebSocket server."""
221-
logger.info(f"Starting WebSocket server on {self.uri}")
222-
max_retries = 3
223-
retry_count = 0
224-
225-
while retry_count < max_retries:
226-
try:
227-
self.server.run_forever()
228-
break # If server exits normally, break the retry loop
229-
except (socket.error, ConnectionError) as e:
230-
logger.warning(f"WebSocket connection error: {e}")
231-
retry_count += 1
232-
if retry_count < max_retries:
233-
logger.info(
234-
f"Attempting to restart server (attempt {retry_count + 1}/{max_retries})"
235-
)
236-
else:
237-
logger.error("Max retry attempts reached. Server shutting down.")
238-
except Exception as e:
239-
logger.error(f"Fatal server error: {e}")
240-
break
241-
finally:
242-
self.close_all()
243-
244228
def close(self, client_id: Text) -> None:
245229
"""Close a specific client's connection and cleanup resources.
246230
@@ -261,22 +245,6 @@ def close(self, client_id: Text) -> None:
261245
client_state.audio_source.close()
262246
del self._clients[client_id]
263247

264-
# Try to send a close frame to the client
265-
client = next((c for c in self.server.clients if c["id"] == client_id), None)
266-
267-
if client is None:
268-
return
269-
270-
try:
271-
self.server.send_message(client, "CLOSE")
272-
except (socket.error, ConnectionError) as e:
273-
logger.warning(
274-
f"Client {client_id} disconnected while sending message: {e}"
275-
)
276-
self.close(client_id)
277-
except Exception as e:
278-
logger.error(f"Failed to send message to client {client_id}: {e}")
279-
280248
logger.info(
281249
f"Closed connection and cleaned up state for client: {client_id}"
282250
)
@@ -290,9 +258,40 @@ def close_all(self) -> None:
290258
logger.info("Shutting down server...")
291259
try:
292260
for client_id in self._clients.keys():
261+
# Close audio source and remove client
293262
self.close(client_id)
263+
264+
# Send close notification to client
265+
self.send(client_id, "CLOSE")
266+
294267
if self.server is not None:
295268
self.server.shutdown_gracefully()
269+
296270
logger.info("Server shutdown complete")
297271
except Exception as e:
298272
logger.error(f"Error during shutdown: {e}")
273+
274+
def run(self) -> None:
275+
"""Start the WebSocket server."""
276+
logger.info(f"Starting WebSocket server on {self.uri}")
277+
max_retries = 3
278+
retry_count = 0
279+
280+
while retry_count < max_retries:
281+
try:
282+
self.server.run_forever()
283+
break # If server exits normally, break the retry loop
284+
except (socket.error, ConnectionError) as e:
285+
logger.warning(f"WebSocket connection error: {e}")
286+
retry_count += 1
287+
if retry_count < max_retries:
288+
logger.info(
289+
f"Attempting to restart server (attempt {retry_count + 1}/{max_retries})"
290+
)
291+
else:
292+
logger.error("Max retry attempts reached. Server shutting down.")
293+
except Exception as e:
294+
logger.error(f"Fatal server error: {e}")
295+
break
296+
finally:
297+
self.close_all()

0 commit comments

Comments
 (0)