
Commit d4380c4

refactor StreamingHandler to use LazyModel for resource mgmt

1 parent 1975f75 · commit d4380c4

2 files changed, 33 insertions(+), 57 deletions(-)


src/diart/console/serve.py (+11 -12)

@@ -6,7 +6,7 @@
 from diart import argdoc
 from diart import models as m
 from diart import utils
-from diart.handler import StreamingInferenceConfig, StreamingInferenceHandler
+from diart.handler import StreamingHandlerConfig, StreamingHandler
 
 
 def run():
@@ -94,24 +94,23 @@ def run():
     args.segmentation = m.SegmentationModel.from_pretrained(args.segmentation, hf_token)
     args.embedding = m.EmbeddingModel.from_pretrained(args.embedding, hf_token)
 
-    # Resolve pipeline
+    # Resolve pipeline configuration
     pipeline_class = utils.get_pipeline_class(args.pipeline)
-    config = pipeline_class.get_config_class()(**vars(args))
-    pipeline = pipeline_class(config)
+    pipeline_config = pipeline_class.get_config_class()(**vars(args))
 
-    # Create inference configuration
-    inference_config = StreamingInferenceConfig(
-        pipeline=pipeline,
+    # Create handler configuration for inference
+    config = StreamingHandlerConfig(
+        pipeline_class=pipeline_class,
+        pipeline_config=pipeline_config,
         batch_size=1,
         do_profile=False,
         do_plot=False,
         show_progress=False,
     )
 
-    # Initialize handler with new configuration
-    handler = StreamingInferenceHandler(
-        inference_config=inference_config,
-        sample_rate=config.sample_rate,
+    # Initialize handler
+    handler = StreamingHandler(
+        config=config,
         host=args.host,
         port=args.port,
     )
@@ -120,4 +119,4 @@ def run():
 
 
 if __name__ == "__main__":
-    run()
+    run()
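
With this change, serve.py no longer builds a pipeline up front: it resolves the pipeline class plus its configuration and hands both to the handler, which constructs a pipeline only when a client connects. Below is a minimal sketch of the resulting wiring, shown with diart's stock SpeakerDiarization pipeline instead of CLI arguments; the final start call is an assumption, since the line that launches the server falls outside the hunks above.

# Hedged sketch of the wiring after this commit. StreamingHandlerConfig and
# StreamingHandler come from the diff above; SpeakerDiarization and
# SpeakerDiarizationConfig are diart's stock pipeline and config classes.
# The commented-out start call at the end is an assumption -- the method
# that launches the server is not part of the hunks shown in this commit.
from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.handler import StreamingHandler, StreamingHandlerConfig

# The handler stores the class/config pair; no pipeline instance is
# built before a client actually connects.
config = StreamingHandlerConfig(
    pipeline_class=SpeakerDiarization,
    pipeline_config=SpeakerDiarizationConfig(),
    batch_size=1,
    do_profile=False,
    do_plot=False,
    show_progress=False,
)

handler = StreamingHandler(config=config, host="127.0.0.1", port=7007)
# handler.run()  # assumed entry point; the actual start call is outside these hunks

Note that the handler's separate sample_rate argument is gone; as the handler.py diff below shows, the sample rate is now read from pipeline_config.sample_rate.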

src/diart/handler.py (+22 -45)

@@ -19,29 +19,15 @@
 
 
 @dataclass
-class WebSocketAudioSourceConfig:
-    """Configuration for WebSocket audio source.
-
-    Parameters
-    ----------
-    uri : str
-        WebSocket URI for the audio source
-    sample_rate : int
-        Audio sample rate in Hz
-    """
-
-    uri: str
-    sample_rate: int = 16000
-
-
-@dataclass
-class StreamingInferenceConfig:
+class StreamingHandlerConfig:
     """Configuration for streaming inference.
 
     Parameters
     ----------
-    pipeline : blocks.Pipeline
-        Diarization pipeline configuration
+    pipeline_class : type
+        Pipeline class
+    pipeline_config : blocks.PipelineConfig
+        Pipeline configuration
     batch_size : int
         Number of inputs to process at once
     do_profile : bool
@@ -54,7 +40,8 @@ class StreamingInferenceConfig:
         Custom progress bar implementation
     """
 
-    pipeline: blocks.Pipeline
+    pipeline_class: type
+    pipeline_config: blocks.PipelineConfig
     batch_size: int = 1
     do_profile: bool = True
     do_plot: bool = False
@@ -70,18 +57,16 @@ class ClientState:
     inference: StreamingInference
 
 
-class StreamingInferenceHandler:
+class StreamingHandler:
     """Handles real-time speaker diarization inference for multiple audio sources over WebSocket.
 
     This handler manages WebSocket connections from multiple clients, processing
     audio streams and performing speaker diarization in real-time.
 
     Parameters
     ----------
-    inference_config : StreamingInferenceConfig
+    config : StreamingHandlerConfig
         Streaming inference configuration
-    sample_rate : int, optional
-        Audio sample rate in Hz, by default 16000
     host : str, optional
         WebSocket server host, by default "127.0.0.1"
     port : int, optional
@@ -94,15 +79,13 @@ class StreamingInferenceHandler:
 
     def __init__(
         self,
-        inference_config: StreamingInferenceConfig,
-        sample_rate: int = 16000,
+        config: StreamingHandlerConfig,
        host: Text = "127.0.0.1",
        port: int = 7007,
        key: Optional[Union[Text, Path]] = None,
        certificate: Optional[Union[Text, Path]] = None,
     ):
-        self.inference_config = inference_config
-        self.sample_rate = sample_rate
+        self.config = config
         self.host = host
         self.port = port
 
@@ -135,26 +118,21 @@ def _create_client_state(self, client_id: Text) -> ClientState:
         """
         # Create a new pipeline instance with the same config
         # This ensures each client has its own state while sharing model weights
-        pipeline = self.inference_config.pipeline.__class__(
-            self.inference_config.pipeline.config
-        )
-
-        audio_config = WebSocketAudioSourceConfig(
-            uri=f"{self.uri}:{client_id}", sample_rate=self.sample_rate
-        )
+        pipeline = self.config.pipeline_class(self.config.pipeline_config)
 
         audio_source = src.WebSocketAudioSource(
-            uri=audio_config.uri, sample_rate=audio_config.sample_rate
+            uri=f"{self.uri}:{client_id}",
+            sample_rate=self.config.pipeline_config.sample_rate,
         )
 
         inference = StreamingInference(
             pipeline=pipeline,
             source=audio_source,
-            batch_size=self.inference_config.batch_size,
-            do_profile=self.inference_config.do_profile,
-            do_plot=self.inference_config.do_plot,
-            show_progress=self.inference_config.show_progress,
-            progress_bar=self.inference_config.progress_bar,
+            batch_size=self.config.batch_size,
+            do_profile=self.config.do_profile,
+            do_plot=self.config.do_plot,
+            show_progress=self.config.show_progress,
+            progress_bar=self.config.progress_bar,
         )
 
         return ClientState(audio_source=audio_source, inference=inference)
@@ -174,16 +152,15 @@ def _on_connect(self, client: Dict[Text, Any], server: WebsocketServer) -> None:
 
         if client_id not in self._clients:
             try:
-                client_state = self._create_client_state(client_id)
-                self._clients[client_id] = client_state
+                self._clients[client_id] = self._create_client_state(client_id)
 
                 # Setup RTTM response hook
-                client_state.inference.attach_hooks(
+                self._clients[client_id].inference.attach_hooks(
                     lambda ann_wav: self.send(client_id, ann_wav[0].to_rttm())
                 )
 
                 # Start inference
-                client_state.inference()
+                self._clients[client_id].inference()
                 logger.info(f"Started inference for client: {client_id}")
 
                 # Send ready notification to client
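
The core of the refactor is visible in _create_client_state: the handler keeps a (pipeline_class, pipeline_config) pair and builds a fresh pipeline for each client, so every connection gets its own state while the heavyweight model weights stay shared. Below is a stripped-down, self-contained sketch of that pattern; all names are illustrative stand-ins, not diart's API.

# Illustrative sketch of the deferred, per-client instantiation pattern
# used by StreamingHandler. Every name below is hypothetical; only the
# (class, config) -> one-instance-per-client idea mirrors the diff above.
from dataclasses import dataclass, field
from typing import Dict, Type


@dataclass
class PipelineConfig:
    sample_rate: int = 16000


class Pipeline:
    """Stand-in for a diarization pipeline: instances are cheap, while
    the heavy model weights load once and are shared (LazyModel-style)."""

    _shared_weights = None  # loaded lazily, shared by all instances

    def __init__(self, config: PipelineConfig):
        self.config = config
        self.state = []  # per-client mutable state
        if Pipeline._shared_weights is None:
            Pipeline._shared_weights = object()  # pretend: expensive load


@dataclass
class Handler:
    pipeline_class: Type[Pipeline]
    pipeline_config: PipelineConfig
    clients: Dict[str, Pipeline] = field(default_factory=dict)

    def on_connect(self, client_id: str) -> Pipeline:
        # One fresh pipeline per client: isolated state, shared weights.
        self.clients[client_id] = self.pipeline_class(self.pipeline_config)
        return self.clients[client_id]


handler = Handler(Pipeline, PipelineConfig())
a, b = handler.on_connect("a"), handler.on_connect("b")
assert a is not b                              # isolated per-client state
assert a._shared_weights is b._shared_weights  # shared model weights

This also explains why WebSocketAudioSourceConfig could be deleted outright: the sample rate now travels with the pipeline config, so the audio source is built directly from pipeline_config.sample_rate.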
