Skip to content

Commit 2cbcdc3

Browse files
committed
improved code quality and style
1 parent 96ab282 commit 2cbcdc3

File tree

1 file changed

+75
-60
lines changed

1 file changed

+75
-60
lines changed

src/diart/websockets.py

+75-60
Original file line numberDiff line numberDiff line change
@@ -124,24 +124,26 @@ def _on_connect(self, client: Dict[Text, Any], server: WebsocketServer) -> None:
124124
client_id = client["id"]
125125
logger.info(f"New client connected: {client_id}")
126126

127-
if client_id not in self._clients:
128-
try:
129-
self._clients[client_id] = self._create_client_state(client_id)
127+
if client_id in self._clients:
128+
return
130129

131-
# Setup RTTM response hook
132-
self._clients[client_id].inference.attach_hooks(
133-
lambda ann_wav: self.send(client_id, ann_wav[0].to_rttm())
134-
)
130+
try:
131+
self._clients[client_id] = self._create_client_state(client_id)
135132

136-
# Start inference
137-
self._clients[client_id].inference()
138-
logger.info(f"Started inference for client: {client_id}")
133+
# Setup RTTM response hook
134+
self._clients[client_id].inference.attach_hooks(
135+
lambda ann_wav: self.send(client_id, ann_wav[0].to_rttm())
136+
)
139137

140-
# Send ready notification to client
141-
self.send(client_id, "READY")
142-
except Exception as e:
143-
logger.error(f"Failed to initialize client {client_id}: {e}")
144-
self.close(client_id)
138+
# Start inference
139+
self._clients[client_id].inference()
140+
logger.info(f"Started inference for client: {client_id}")
141+
142+
# Send ready notification to client
143+
self.send(client_id, "READY")
144+
except Exception as e:
145+
logger.error(f"Failed to initialize client {client_id}: {e}")
146+
self.close(client_id)
145147

146148
def _on_disconnect(self, client: Dict[Text, Any], server: WebsocketServer) -> None:
147149
"""Handle client disconnection.
@@ -172,16 +174,19 @@ def _on_message_received(
172174
Received message data
173175
"""
174176
client_id = client["id"]
175-
if client_id in self._clients:
176-
try:
177-
self._clients[client_id].audio_source.process_message(message)
178-
except (socket.error, ConnectionError) as e:
179-
logger.warning(f"Client {client_id} disconnected: {e}")
180-
self.close(client_id)
181-
except Exception as e:
182-
logger.error(f"Error processing message from client {client_id}: {e}")
183-
# Don't close the connection for non-connection related errors
184-
# This allows the client to retry sending the message
177+
178+
if client_id not in self._clients:
179+
return
180+
181+
try:
182+
self._clients[client_id].audio_source.process_message(message)
183+
except (socket.error, ConnectionError) as e:
184+
logger.warning(f"Client {client_id} disconnected: {e}")
185+
self.close(client_id)
186+
except Exception as e:
187+
logger.error(f"Error processing message from client {client_id}: {e}")
188+
# Don't close the connection for non-connection related errors
189+
# This allows the client to retry sending the message
185190

186191
def send(self, client_id: Text, message: AnyStr) -> None:
187192
"""Send a message to a specific client.
@@ -198,16 +203,18 @@ def send(self, client_id: Text, message: AnyStr) -> None:
198203

199204
client = next((c for c in self.server.clients if c["id"] == client_id), None)
200205

201-
if client is not None:
202-
try:
203-
self.server.send_message(client, message)
204-
except (socket.error, ConnectionError) as e:
205-
logger.warning(
206-
f"Client {client_id} disconnected while sending message: {e}"
207-
)
208-
self.close(client_id)
209-
except Exception as e:
210-
logger.error(f"Failed to send message to client {client_id}: {e}")
206+
if client is None:
207+
return
208+
209+
try:
210+
self.server.send_message(client, message)
211+
except (socket.error, ConnectionError) as e:
212+
logger.warning(
213+
f"Client {client_id} disconnected while sending message: {e}"
214+
)
215+
self.close(client_id)
216+
except Exception as e:
217+
logger.error(f"Failed to send message to client {client_id}: {e}")
211218

212219
def run(self) -> None:
213220
"""Start the WebSocket server."""
@@ -242,39 +249,47 @@ def close(self, client_id: Text) -> None:
242249
client_id : Text
243250
Client identifier to close
244251
"""
245-
if client_id in self._clients:
246-
try:
247-
# Clean up pipeline state using built-in reset method
248-
client_state = self._clients[client_id]
249-
client_state.inference.pipeline.reset()
250-
251-
# Close audio source and remove client
252-
client_state.audio_source.close()
253-
del self._clients[client_id]
254-
255-
# Try to send a close frame to the client
256-
try:
257-
client = next(
258-
(c for c in self.server.clients if c["id"] == client_id), None
259-
)
260-
if client:
261-
self.server.send_message(client, "CLOSE")
262-
except Exception:
263-
pass # Ignore errors when trying to send close message
252+
if client_id not in self._clients:
253+
return
264254

265-
logger.info(
266-
f"Closed connection and cleaned up state for client: {client_id}"
255+
try:
256+
# Clean up pipeline state using built-in reset method
257+
client_state = self._clients[client_id]
258+
client_state.inference.pipeline.reset()
259+
260+
# Close audio source and remove client
261+
client_state.audio_source.close()
262+
del self._clients[client_id]
263+
264+
# Try to send a close frame to the client
265+
client = next((c for c in self.server.clients if c["id"] == client_id), None)
266+
267+
if client is None:
268+
return
269+
270+
try:
271+
self.server.send_message(client, "CLOSE")
272+
except (socket.error, ConnectionError) as e:
273+
logger.warning(
274+
f"Client {client_id} disconnected while sending message: {e}"
267275
)
276+
self.close(client_id)
268277
except Exception as e:
269-
logger.error(f"Error closing client {client_id}: {e}")
270-
# Ensure client is removed from dictionary even if cleanup fails
271-
self._clients.pop(client_id, None)
278+
logger.error(f"Failed to send message to client {client_id}: {e}")
279+
280+
logger.info(
281+
f"Closed connection and cleaned up state for client: {client_id}"
282+
)
283+
except Exception as e:
284+
logger.error(f"Error closing client {client_id}: {e}")
285+
# Ensure client is removed from dictionary even if cleanup fails
286+
self._clients.pop(client_id, None)
272287

273288
def close_all(self) -> None:
274289
"""Shutdown the server and cleanup all client connections."""
275290
logger.info("Shutting down server...")
276291
try:
277-
for client_id in list(self._clients.keys()):
292+
for client_id in self._clients.keys():
278293
self.close(client_id)
279294
if self.server is not None:
280295
self.server.shutdown_gracefully()

0 commit comments

Comments
 (0)