@@ -9,7 +9,6 @@
 from . import blocks
 from . import sources as src
 from .inference import StreamingInference
-from .progress import ProgressBar, RichProgressBar

 # Configure logging
 logging.basicConfig(
@@ -18,37 +17,6 @@
 logger = logging.getLogger(__name__)


-@dataclass
-class StreamingHandlerConfig:
-    """Configuration for streaming inference.
-
-    Parameters
-    ----------
-    pipeline_class : type
-        Pipeline class
-    pipeline_config : blocks.PipelineConfig
-        Pipeline configuration
-    batch_size : int
-        Number of inputs to process at once
-    do_profile : bool
-        Enable processing time profiling
-    do_plot : bool
-        Enable real-time prediction plotting
-    show_progress : bool
-        Display progress bar
-    progress_bar : Optional[ProgressBar]
-        Custom progress bar implementation
-    """
-
-    pipeline_class: type
-    pipeline_config: blocks.PipelineConfig
-    batch_size: int = 1
-    do_profile: bool = True
-    do_plot: bool = False
-    show_progress: bool = True
-    progress_bar: Optional[ProgressBar] = None
-
-
 @dataclass
 class ClientState:
     """Represents the state of a connected client."""
@@ -57,16 +25,18 @@ class ClientState:
     inference: StreamingInference


-class StreamingHandler:
+class WebSocketStreamingServer:
     """Handles real-time speaker diarization inference for multiple audio sources over WebSocket.

     This handler manages WebSocket connections from multiple clients, processing
     audio streams and performing speaker diarization in real-time.

     Parameters
     ----------
-    config : StreamingHandlerConfig
-        Streaming inference configuration
+    pipeline_class : type
+        Pipeline class
+    pipeline_config : blocks.PipelineConfig
+        Pipeline configuration
     host : str, optional
         WebSocket server host, by default "127.0.0.1"
     port : int, optional
@@ -79,17 +49,20 @@ class StreamingHandler:

     def __init__(
         self,
-        config: StreamingHandlerConfig,
+        pipeline_class: type,
+        pipeline_config: blocks.PipelineConfig,
         host: Text = "127.0.0.1",
         port: int = 7007,
         key: Optional[Union[Text, Path]] = None,
         certificate: Optional[Union[Text, Path]] = None,
     ):
-        self.config = config
-        self.host = host
-        self.port = port
+        # Pipeline configuration
+        self.pipeline_class = pipeline_class
+        self.pipeline_config = pipeline_config

         # Server configuration
+        self.host = host
+        self.port = port
         self.uri = f"{host}:{port}"
         self._clients: Dict[Text, ClientState] = {}

@@ -118,21 +91,22 @@ def _create_client_state(self, client_id: Text) -> ClientState:
         """
        # Create a new pipeline instance with the same config
        # This ensures each client has its own state while sharing model weights
-        pipeline = self.config.pipeline_class(self.config.pipeline_config)
+        pipeline = self.pipeline_class(self.pipeline_config)

         audio_source = src.WebSocketAudioSource(
             uri=f"{self.uri}:{client_id}",
-            sample_rate=self.config.pipeline_config.sample_rate,
+            sample_rate=self.pipeline_config.sample_rate,
         )

         inference = StreamingInference(
             pipeline=pipeline,
             source=audio_source,
-            batch_size=self.config.batch_size,
-            do_profile=self.config.do_profile,
-            do_plot=self.config.do_plot,
-            show_progress=self.config.show_progress,
-            progress_bar=self.config.progress_bar,
+            # The following variables are fixed for a client
+            batch_size=1,
+            do_profile=False,  # for minimal latency
+            do_plot=False,
+            show_progress=False,
+            progress_bar=None,
         )

         return ClientState(audio_source=audio_source, inference=inference)
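
For context, here is a minimal sketch of how a caller would start the refactored server. `SpeakerDiarization` and `SpeakerDiarizationConfig` are diart's standard pipeline pair; the `diart.websockets` import path and the blocking `run()` entry point are assumptions, since neither appears in the hunks above.

# Sketch only: import path and run() are assumed, not shown in this diff.
from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.websockets import WebSocketStreamingServer  # assumed module path

config = SpeakerDiarizationConfig(step=0.5, latency=0.5)  # example values
server = WebSocketStreamingServer(
    pipeline_class=SpeakerDiarization,
    pipeline_config=config,
    host="127.0.0.1",
    port=7007,
)
server.run()  # assumed blocking entry point

Relative to the removed StreamingHandlerConfig, per-client inference knobs are no longer exposed: every connection is pinned to batch_size=1, with profiling, plotting, and progress reporting disabled in favor of minimal latency.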