19
19
20
20
21
21
@dataclass
22
- class WebSocketAudioSourceConfig :
23
- """Configuration for WebSocket audio source.
24
-
25
- Parameters
26
- ----------
27
- uri : str
28
- WebSocket URI for the audio source
29
- sample_rate : int
30
- Audio sample rate in Hz
31
- """
32
-
33
- uri : str
34
- sample_rate : int = 16000
35
-
36
-
37
- @dataclass
38
- class StreamingInferenceConfig :
22
+ class StreamingHandlerConfig :
39
23
"""Configuration for streaming inference.
40
24
41
25
Parameters
42
26
----------
43
- pipeline : blocks.Pipeline
44
- Diarization pipeline configuration
27
+ pipeline_class : type
28
+ Pipeline class
29
+ pipeline_config : blocks.PipelineConfig
30
+ Pipeline configuration
45
31
batch_size : int
46
32
Number of inputs to process at once
47
33
do_profile : bool
@@ -54,7 +40,8 @@ class StreamingInferenceConfig:
54
40
Custom progress bar implementation
55
41
"""
56
42
57
- pipeline : blocks .Pipeline
43
+ pipeline_class : type
44
+ pipeline_config : blocks .PipelineConfig
58
45
batch_size : int = 1
59
46
do_profile : bool = True
60
47
do_plot : bool = False
@@ -70,18 +57,16 @@ class ClientState:
70
57
inference : StreamingInference
71
58
72
59
73
- class StreamingInferenceHandler :
60
+ class StreamingHandler :
74
61
"""Handles real-time speaker diarization inference for multiple audio sources over WebSocket.
75
62
76
63
This handler manages WebSocket connections from multiple clients, processing
77
64
audio streams and performing speaker diarization in real-time.
78
65
79
66
Parameters
80
67
----------
81
- inference_config : StreamingInferenceConfig
68
+ config : StreamingHandlerConfig
82
69
Streaming inference configuration
83
- sample_rate : int, optional
84
- Audio sample rate in Hz, by default 16000
85
70
host : str, optional
86
71
WebSocket server host, by default "127.0.0.1"
87
72
port : int, optional
@@ -94,15 +79,13 @@ class StreamingInferenceHandler:
94
79
95
80
def __init__ (
96
81
self ,
97
- inference_config : StreamingInferenceConfig ,
98
- sample_rate : int = 16000 ,
82
+ config : StreamingHandlerConfig ,
99
83
host : Text = "127.0.0.1" ,
100
84
port : int = 7007 ,
101
85
key : Optional [Union [Text , Path ]] = None ,
102
86
certificate : Optional [Union [Text , Path ]] = None ,
103
87
):
104
- self .inference_config = inference_config
105
- self .sample_rate = sample_rate
88
+ self .config = config
106
89
self .host = host
107
90
self .port = port
108
91
@@ -135,26 +118,21 @@ def _create_client_state(self, client_id: Text) -> ClientState:
135
118
"""
136
119
# Create a new pipeline instance with the same config
137
120
# This ensures each client has its own state while sharing model weights
138
- pipeline = self .inference_config .pipeline .__class__ (
139
- self .inference_config .pipeline .config
140
- )
141
-
142
- audio_config = WebSocketAudioSourceConfig (
143
- uri = f"{ self .uri } :{ client_id } " , sample_rate = self .sample_rate
144
- )
121
+ pipeline = self .config .pipeline_class (self .config .pipeline_config )
145
122
146
123
audio_source = src .WebSocketAudioSource (
147
- uri = audio_config .uri , sample_rate = audio_config .sample_rate
124
+ uri = f"{ self .uri } :{ client_id } " ,
125
+ sample_rate = self .config .pipeline_config .sample_rate ,
148
126
)
149
127
150
128
inference = StreamingInference (
151
129
pipeline = pipeline ,
152
130
source = audio_source ,
153
- batch_size = self .inference_config .batch_size ,
154
- do_profile = self .inference_config .do_profile ,
155
- do_plot = self .inference_config .do_plot ,
156
- show_progress = self .inference_config .show_progress ,
157
- progress_bar = self .inference_config .progress_bar ,
131
+ batch_size = self .config .batch_size ,
132
+ do_profile = self .config .do_profile ,
133
+ do_plot = self .config .do_plot ,
134
+ show_progress = self .config .show_progress ,
135
+ progress_bar = self .config .progress_bar ,
158
136
)
159
137
160
138
return ClientState (audio_source = audio_source , inference = inference )
@@ -174,16 +152,15 @@ def _on_connect(self, client: Dict[Text, Any], server: WebsocketServer) -> None:
174
152
175
153
if client_id not in self ._clients :
176
154
try :
177
- client_state = self ._create_client_state (client_id )
178
- self ._clients [client_id ] = client_state
155
+ self ._clients [client_id ] = self ._create_client_state (client_id )
179
156
180
157
# Setup RTTM response hook
181
- client_state .inference .attach_hooks (
158
+ self . _clients [ client_id ] .inference .attach_hooks (
182
159
lambda ann_wav : self .send (client_id , ann_wav [0 ].to_rttm ())
183
160
)
184
161
185
162
# Start inference
186
- client_state .inference ()
163
+ self . _clients [ client_id ] .inference ()
187
164
logger .info (f"Started inference for client: { client_id } " )
188
165
189
166
# Send ready notification to client
0 commit comments