Skip to content

Commit 1975f75

Browse files
committed
apply styling with black and isort
1 parent fb9fecf commit 1975f75

File tree

4 files changed

+43
-38
lines changed

4 files changed

+43
-38
lines changed

src/diart/console/client.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import argparse
22
from pathlib import Path
33
from threading import Thread
4-
from typing import Text, Optional
4+
from typing import Optional, Text
55

66
import rx.operators as ops
77
from websocket import WebSocket
@@ -66,7 +66,7 @@ def run():
6666
# Run websocket client
6767
ws = WebSocket()
6868
ws.connect(f"ws://{args.host}:{args.port}")
69-
69+
7070
# Wait for READY signal from server
7171
print("Waiting for server to be ready...", end="", flush=True)
7272
while True:
@@ -75,7 +75,7 @@ def run():
7575
print(" OK")
7676
break
7777
print(f"\nUnexpected message while waiting for READY: {message}")
78-
78+
7979
sender = Thread(
8080
target=send_audio, args=[ws, args.source, args.step, args.sample_rate]
8181
)

src/diart/console/serve.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from diart import argdoc
77
from diart import models as m
88
from diart import utils
9-
from diart.handler import StreamingInferenceHandler, StreamingInferenceConfig
9+
from diart.handler import StreamingInferenceConfig, StreamingInferenceHandler
1010

1111

1212
def run():

src/diart/handler.py

+37-32
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import logging
2+
import socket
13
from dataclasses import dataclass
24
from pathlib import Path
3-
from typing import Union, Text, Optional, AnyStr, Dict, Any, Callable
4-
import logging
5+
from typing import Any, AnyStr, Callable, Dict, Optional, Text, Union
6+
57
from websocket_server import WebsocketServer
6-
import socket
78

89
from . import blocks
910
from . import sources as src
@@ -12,8 +13,7 @@
1213

1314
# Configure logging
1415
logging.basicConfig(
15-
level=logging.INFO,
16-
format='%(asctime)s - %(levelname)s - %(message)s'
16+
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
1717
)
1818
logger = logging.getLogger(__name__)
1919

@@ -29,6 +29,7 @@ class WebSocketAudioSourceConfig:
2929
sample_rate : int
3030
Audio sample rate in Hz
3131
"""
32+
3233
uri: str
3334
sample_rate: int = 16000
3435

@@ -52,6 +53,7 @@ class StreamingInferenceConfig:
5253
progress_bar : Optional[ProgressBar]
5354
Custom progress bar implementation
5455
"""
56+
5557
pipeline: blocks.Pipeline
5658
batch_size: int = 1
5759
do_profile: bool = True
@@ -63,6 +65,7 @@ class StreamingInferenceConfig:
6365
@dataclass
6466
class ClientState:
6567
"""Represents the state of a connected client."""
68+
6669
audio_source: src.WebSocketAudioSource
6770
inference: StreamingInference
6871

@@ -102,7 +105,7 @@ def __init__(
102105
self.sample_rate = sample_rate
103106
self.host = host
104107
self.port = port
105-
108+
106109
# Server configuration
107110
self.uri = f"{host}:{port}"
108111
self._clients: Dict[Text, ClientState] = {}
@@ -132,16 +135,16 @@ def _create_client_state(self, client_id: Text) -> ClientState:
132135
"""
133136
# Create a new pipeline instance with the same config
134137
# This ensures each client has its own state while sharing model weights
135-
pipeline = self.inference_config.pipeline.__class__(self.inference_config.pipeline.config)
136-
138+
pipeline = self.inference_config.pipeline.__class__(
139+
self.inference_config.pipeline.config
140+
)
141+
137142
audio_config = WebSocketAudioSourceConfig(
138-
uri=f"{self.uri}:{client_id}",
139-
sample_rate=self.sample_rate
143+
uri=f"{self.uri}:{client_id}", sample_rate=self.sample_rate
140144
)
141-
145+
142146
audio_source = src.WebSocketAudioSource(
143-
uri=audio_config.uri,
144-
sample_rate=audio_config.sample_rate
147+
uri=audio_config.uri, sample_rate=audio_config.sample_rate
145148
)
146149

147150
inference = StreamingInference(
@@ -151,7 +154,7 @@ def _create_client_state(self, client_id: Text) -> ClientState:
151154
do_profile=self.inference_config.do_profile,
152155
do_plot=self.inference_config.do_plot,
153156
show_progress=self.inference_config.show_progress,
154-
progress_bar=self.inference_config.progress_bar
157+
progress_bar=self.inference_config.progress_bar,
155158
)
156159

157160
return ClientState(audio_source=audio_source, inference=inference)
@@ -182,7 +185,7 @@ def _on_connect(self, client: Dict[Text, Any], server: WebsocketServer) -> None:
182185
# Start inference
183186
client_state.inference()
184187
logger.info(f"Started inference for client: {client_id}")
185-
188+
186189
# Send ready notification to client
187190
self.send(client_id, "READY")
188191
except Exception as e:
@@ -204,10 +207,7 @@ def _on_disconnect(self, client: Dict[Text, Any], server: WebsocketServer) -> No
204207
self.close(client_id)
205208

206209
def _on_message_received(
207-
self,
208-
client: Dict[Text, Any],
209-
server: WebsocketServer,
210-
message: AnyStr
210+
self, client: Dict[Text, Any], server: WebsocketServer, message: AnyStr
211211
) -> None:
212212
"""Process incoming client messages.
213213
@@ -245,16 +245,15 @@ def send(self, client_id: Text, message: AnyStr) -> None:
245245
if not message:
246246
return
247247

248-
client = next(
249-
(c for c in self.server.clients if c["id"] == client_id),
250-
None
251-
)
252-
248+
client = next((c for c in self.server.clients if c["id"] == client_id), None)
249+
253250
if client is not None:
254251
try:
255252
self.server.send_message(client, message)
256253
except (socket.error, ConnectionError) as e:
257-
logger.warning(f"Client {client_id} disconnected while sending message: {e}")
254+
logger.warning(
255+
f"Client {client_id} disconnected while sending message: {e}"
256+
)
258257
self.close(client_id)
259258
except Exception as e:
260259
logger.error(f"Failed to send message to client {client_id}: {e}")
@@ -264,7 +263,7 @@ def run(self) -> None:
264263
logger.info(f"Starting WebSocket server on {self.uri}")
265264
max_retries = 3
266265
retry_count = 0
267-
266+
268267
while retry_count < max_retries:
269268
try:
270269
self.server.run_forever()
@@ -273,7 +272,9 @@ def run(self) -> None:
273272
logger.warning(f"WebSocket connection error: {e}")
274273
retry_count += 1
275274
if retry_count < max_retries:
276-
logger.info(f"Attempting to restart server (attempt {retry_count + 1}/{max_retries})")
275+
logger.info(
276+
f"Attempting to restart server (attempt {retry_count + 1}/{max_retries})"
277+
)
277278
else:
278279
logger.error("Max retry attempts reached. Server shutting down.")
279280
except Exception as e:
@@ -295,20 +296,24 @@ def close(self, client_id: Text) -> None:
295296
# Clean up pipeline state using built-in reset method
296297
client_state = self._clients[client_id]
297298
client_state.inference.pipeline.reset()
298-
299+
299300
# Close audio source and remove client
300301
client_state.audio_source.close()
301302
del self._clients[client_id]
302-
303+
303304
# Try to send a close frame to the client
304305
try:
305-
client = next((c for c in self.server.clients if c["id"] == client_id), None)
306+
client = next(
307+
(c for c in self.server.clients if c["id"] == client_id), None
308+
)
306309
if client:
307310
self.server.send_message(client, "CLOSE")
308311
except Exception:
309312
pass # Ignore errors when trying to send close message
310-
311-
logger.info(f"Closed connection and cleaned up state for client: {client_id}")
313+
314+
logger.info(
315+
f"Closed connection and cleaned up state for client: {client_id}"
316+
)
312317
except Exception as e:
313318
logger.error(f"Error closing client {client_id}: {e}")
314319
# Ensure client is removed from dictionary even if cleanup fails

src/diart/sources.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from abc import ABC, abstractmethod
22
from pathlib import Path
33
from queue import SimpleQueue
4-
from typing import Text, Optional, AnyStr, Dict, Any, Union, Tuple
4+
from typing import Any, AnyStr, Dict, Optional, Text, Tuple, Union
55

66
import numpy as np
77
import sounddevice as sd
@@ -12,7 +12,7 @@
1212
from websocket_server import WebsocketServer
1313

1414
from . import utils
15-
from .audio import FilePath, AudioLoader
15+
from .audio import AudioLoader, FilePath
1616

1717

1818
class AudioSource(ABC):

0 commit comments

Comments
 (0)