Skip to content

Commit 96ab282

Browse files
committed
simplified websocket-server class and improved naming
1 parent 7ba2f55 commit 96ab282

File tree

2 files changed

+24
-59
lines changed

2 files changed

+24
-59
lines changed

src/diart/console/serve.py

+4-13
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from diart import argdoc
77
from diart import models as m
88
from diart import utils
9-
from diart.handler import StreamingHandlerConfig, StreamingHandler
9+
from diart.websockets import StreamingHandlerConfig, WebSocketStreamingServer
1010

1111

1212
def run():
@@ -98,24 +98,15 @@ def run():
9898
pipeline_class = utils.get_pipeline_class(args.pipeline)
9999
pipeline_config = pipeline_class.get_config_class()(**vars(args))
100100

101-
# Create handler configuration for inference
102-
config = StreamingHandlerConfig(
101+
# Initialize Websocket server
102+
server = WebSocketStreamingServer(
103103
pipeline_class=pipeline_class,
104104
pipeline_config=pipeline_config,
105-
batch_size=1,
106-
do_profile=False,
107-
do_plot=False,
108-
show_progress=False,
109-
)
110-
111-
# Initialize handler
112-
handler = StreamingHandler(
113-
config=config,
114105
host=args.host,
115106
port=args.port,
116107
)
117108

118-
handler.run()
109+
server.run()
119110

120111

121112
if __name__ == "__main__":

src/diart/handler.py → src/diart/websockets.py (renamed)

+20-46
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from . import blocks
1010
from . import sources as src
1111
from .inference import StreamingInference
12-
from .progress import ProgressBar, RichProgressBar
1312

1413
# Configure logging
1514
logging.basicConfig(
@@ -18,37 +17,6 @@
1817
logger = logging.getLogger(__name__)
1918

2019

21-
@dataclass
22-
class StreamingHandlerConfig:
23-
"""Configuration for streaming inference.
24-
25-
Parameters
26-
----------
27-
pipeline_class : type
28-
Pipeline class
29-
pipeline_config : blocks.PipelineConfig
30-
Pipeline configuration
31-
batch_size : int
32-
Number of inputs to process at once
33-
do_profile : bool
34-
Enable processing time profiling
35-
do_plot : bool
36-
Enable real-time prediction plotting
37-
show_progress : bool
38-
Display progress bar
39-
progress_bar : Optional[ProgressBar]
40-
Custom progress bar implementation
41-
"""
42-
43-
pipeline_class: type
44-
pipeline_config: blocks.PipelineConfig
45-
batch_size: int = 1
46-
do_profile: bool = True
47-
do_plot: bool = False
48-
show_progress: bool = True
49-
progress_bar: Optional[ProgressBar] = None
50-
51-
5220
@dataclass
5321
class ClientState:
5422
"""Represents the state of a connected client."""
@@ -57,16 +25,18 @@ class ClientState:
5725
inference: StreamingInference
5826

5927

60-
class StreamingHandler:
28+
class WebSocketStreamingServer:
6129
"""Handles real-time speaker diarization inference for multiple audio sources over WebSocket.
6230
6331
This handler manages WebSocket connections from multiple clients, processing
6432
audio streams and performing speaker diarization in real-time.
6533
6634
Parameters
6735
----------
68-
config : StreamingHandlerConfig
69-
Streaming inference configuration
36+
pipeline_class : type
37+
Pipeline class
38+
pipeline_config : blocks.PipelineConfig
39+
Pipeline configuration
7040
host : str, optional
7141
WebSocket server host, by default "127.0.0.1"
7242
port : int, optional
@@ -79,17 +49,20 @@ class StreamingHandler:
7949

8050
def __init__(
8151
self,
82-
config: StreamingHandlerConfig,
52+
pipeline_class: type,
53+
pipeline_config: blocks.PipelineConfig,
8354
host: Text = "127.0.0.1",
8455
port: int = 7007,
8556
key: Optional[Union[Text, Path]] = None,
8657
certificate: Optional[Union[Text, Path]] = None,
8758
):
88-
self.config = config
89-
self.host = host
90-
self.port = port
59+
# Pipeline configuration
60+
self.pipeline_class = pipeline_class
61+
self.pipeline_config = pipeline_config
9162

9263
# Server configuration
64+
self.host = host
65+
self.port = port
9366
self.uri = f"{host}:{port}"
9467
self._clients: Dict[Text, ClientState] = {}
9568

@@ -118,21 +91,22 @@ def _create_client_state(self, client_id: Text) -> ClientState:
11891
"""
11992
# Create a new pipeline instance with the same config
12093
# This ensures each client has its own state while sharing model weights
121-
pipeline = self.config.pipeline_class(self.config.pipeline_config)
94+
pipeline = self.pipeline_class(self.pipeline_config)
12295

12396
audio_source = src.WebSocketAudioSource(
12497
uri=f"{self.uri}:{client_id}",
125-
sample_rate=self.config.pipeline_config.sample_rate,
98+
sample_rate=self.pipeline_config.sample_rate,
12699
)
127100

128101
inference = StreamingInference(
129102
pipeline=pipeline,
130103
source=audio_source,
131-
batch_size=self.config.batch_size,
132-
do_profile=self.config.do_profile,
133-
do_plot=self.config.do_plot,
134-
show_progress=self.config.show_progress,
135-
progress_bar=self.config.progress_bar,
104+
# The following variables are fixed for a client
105+
batch_size=1,
106+
do_profile=False, # for minimal latency
107+
do_plot=False,
108+
show_progress=False,
109+
progress_bar=None,
136110
)
137111

138112
return ClientState(audio_source=audio_source, inference=inference)

0 commit comments

Comments (0)