Skip to content

Commit b6d6bc6

Browse files
committed
fix(client): manage stop events and handle errors correctly
1 parent f2c3144 commit b6d6bc6

File tree

1 file changed

+77
-38
lines changed

1 file changed

+77
-38
lines changed

src/diart/console/client.py

+77-38
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,60 @@
11
import argparse
22
from pathlib import Path
3-
from threading import Thread
3+
from threading import Event, Thread
44
from typing import Optional, Text
55

66
import rx.operators as ops
7-
from websocket import WebSocket
7+
from websocket import WebSocket, WebSocketException
88

99
from diart import argdoc
1010
from diart import sources as src
1111
from diart import utils
1212

1313

14-
def send_audio(ws: WebSocket, source: Text, step: float, sample_rate: int):
15-
# Create audio source
16-
source_components = source.split(":")
17-
if source_components[0] != "microphone":
18-
audio_source = src.FileAudioSource(source, sample_rate, block_duration=step)
19-
else:
20-
device = int(source_components[1]) if len(source_components) > 1 else None
21-
audio_source = src.MicrophoneAudioSource(step, device)
14+
def send_audio(
15+
ws: WebSocket, source: Text, step: float, sample_rate: int, stop_event: Event
16+
):
17+
try:
18+
# Create audio source
19+
source_components = source.split(":")
20+
if source_components[0] != "microphone":
21+
audio_source = src.FileAudioSource(source, sample_rate, block_duration=step)
22+
else:
23+
device = int(source_components[1]) if len(source_components) > 1 else None
24+
audio_source = src.MicrophoneAudioSource(step, device)
2225

23-
# Encode audio, then send through websocket
24-
audio_source.stream.pipe(ops.map(utils.encode_audio)).subscribe_(ws.send)
26+
# Encode audio, then send through websocket
27+
def on_next(data):
28+
if not stop_event.is_set():
29+
try:
30+
ws.send(utils.encode_audio(data))
31+
except WebSocketException:
32+
stop_event.set()
2533

26-
# Start reading audio
27-
audio_source.read()
34+
audio_source.stream.subscribe_(on_next)
2835

36+
# Start reading audio
37+
audio_source.read()
38+
except Exception as e:
39+
print(f"Error in send_audio: {e}")
40+
stop_event.set()
2941

30-
def receive_audio(ws: WebSocket, output: Optional[Path]):
31-
while True:
32-
message = ws.recv()
33-
print(f"Received: {message}", end="")
34-
if output is not None:
35-
with open(output, "a") as file:
36-
file.write(message)
42+
43+
def receive_audio(ws: WebSocket, output: Optional[Path], stop_event: Event):
44+
try:
45+
while not stop_event.is_set():
46+
try:
47+
message = ws.recv()
48+
print(f"Received: {message}", end="")
49+
if output is not None:
50+
with open(output, "a") as file:
51+
file.write(message)
52+
except WebSocketException:
53+
break
54+
except Exception as e:
55+
print(f"Error in receive_audio: {e}")
56+
finally:
57+
stop_event.set()
3758

3859

3960
def run():
@@ -65,23 +86,41 @@ def run():
6586

6687
# Run websocket client
6788
ws = WebSocket()
68-
ws.connect(f"ws://{args.host}:{args.port}")
69-
70-
# Wait for READY signal from server
71-
print("Waiting for server to be ready...", end="", flush=True)
72-
while True:
73-
message = ws.recv()
74-
if message.strip() == "READY":
75-
print(" OK")
76-
break
77-
print(f"\nUnexpected message while waiting for READY: {message}")
78-
79-
sender = Thread(
80-
target=send_audio, args=[ws, args.source, args.step, args.sample_rate]
81-
)
82-
receiver = Thread(target=receive_audio, args=[ws, args.output_file])
83-
sender.start()
84-
receiver.start()
89+
stop_event = Event()
90+
91+
try:
92+
ws.connect(f"ws://{args.host}:{args.port}")
93+
94+
# Wait for READY signal from server
95+
print("Waiting for server to be ready...", end="", flush=True)
96+
while True:
97+
try:
98+
message = ws.recv()
99+
if message.strip() == "READY":
100+
print(" OK")
101+
break
102+
print(f"\nUnexpected message while waiting for READY: {message}")
103+
except WebSocketException as e:
104+
print(f"\nError while waiting for server: {e}")
105+
return
106+
107+
sender = Thread(
108+
target=send_audio,
109+
args=[ws, args.source, args.step, args.sample_rate, stop_event],
110+
)
111+
receiver = Thread(target=receive_audio, args=[ws, args.output_file, stop_event])
112+
113+
sender.start()
114+
receiver.start()
115+
116+
except Exception as e:
117+
print(f"Error: {e}")
118+
stop_event.set()
119+
finally:
120+
try:
121+
ws.close()
122+
except:
123+
pass
85124

86125

87126
if __name__ == "__main__":

0 commit comments

Comments
 (0)