|
4 | 4 | # SPDX-License-Identifier: BSD 2-Clause License
|
5 | 5 | #
|
6 | 6 |
|
| 7 | +import asyncio |
7 | 8 | import base64
|
8 | 9 | import json
|
9 | 10 | import warnings
|
@@ -224,6 +225,7 @@ def __init__(
|
224 | 225 | self._params = params
|
225 | 226 | self._websocket = None
|
226 | 227 | self._receive_task = None
|
| 228 | + self._keepalive_task = None |
227 | 229 |
|
228 | 230 | def language_to_service_language(self, language: Language) -> Optional[str]:
|
229 | 231 | """Convert pipecat Language enum to Gladia's language code."""
|
@@ -287,22 +289,38 @@ async def start(self, frame: StartFrame):
|
287 | 289 | self._websocket = await websockets.connect(response["url"])
|
288 | 290 | if self._websocket and not self._receive_task:
|
289 | 291 | self._receive_task = self.create_task(self._receive_task_handler())
|
| 292 | + if self._websocket and not self._keepalive_task: |
| 293 | + self._keepalive_task = self.create_task(self._keepalive_task_handler()) |
290 | 294 |
|
291 | 295 | async def stop(self, frame: EndFrame):
|
292 | 296 | """Stop the Gladia STT websocket connection."""
|
293 | 297 | await super().stop(frame)
|
294 | 298 | await self._send_stop_recording()
|
| 299 | + |
| 300 | + if self._keepalive_task: |
| 301 | + await self.cancel_task(self._keepalive_task) |
| 302 | + self._keepalive_task = None |
| 303 | + |
295 | 304 | if self._websocket:
|
296 | 305 | await self._websocket.close()
|
297 | 306 | self._websocket = None
|
| 307 | + |
298 | 308 | if self._receive_task:
|
299 | 309 | await self.wait_for_task(self._receive_task)
|
300 | 310 | self._receive_task = None
|
301 | 311 |
|
302 | 312 | async def cancel(self, frame: CancelFrame):
|
303 | 313 | """Cancel the Gladia STT websocket connection."""
|
304 | 314 | await super().cancel(frame)
|
305 |
| - await self._websocket.close() |
| 315 | + |
| 316 | + if self._keepalive_task: |
| 317 | + await self.cancel_task(self._keepalive_task) |
| 318 | + self._keepalive_task = None |
| 319 | + |
| 320 | + if self._websocket: |
| 321 | + await self._websocket.close() |
| 322 | + self._websocket = None |
| 323 | + |
306 | 324 | if self._receive_task:
|
307 | 325 | await self.cancel_task(self._receive_task)
|
308 | 326 | self._receive_task = None
|
@@ -341,6 +359,24 @@ async def _send_stop_recording(self):
|
341 | 359 | if self._websocket and not self._websocket.closed:
|
342 | 360 | await self._websocket.send(json.dumps({"type": "stop_recording"}))
|
343 | 361 |
|
| 362 | + async def _keepalive_task_handler(self): |
| 363 | + """Send periodic empty audio chunks to keep the connection alive.""" |
| 364 | + try: |
| 365 | + while True: |
| 366 | + # Send keepalive every 20 seconds (Gladia times out after 30 seconds) |
| 367 | + await asyncio.sleep(20) |
| 368 | + if self._websocket and not self._websocket.closed: |
| 369 | + # Send an empty audio chunk as keepalive |
| 370 | + empty_audio = b"" |
| 371 | + await self._send_audio(empty_audio) |
| 372 | + else: |
| 373 | + logger.debug("Websocket closed, stopping keepalive") |
| 374 | + break |
| 375 | + except websockets.exceptions.ConnectionClosed: |
| 376 | + logger.debug("Connection closed during keepalive") |
| 377 | + except Exception as e: |
| 378 | + logger.error(f"Error in Gladia keepalive task: {e}") |
| 379 | + |
344 | 380 | async def _receive_task_handler(self):
|
345 | 381 | try:
|
346 | 382 | async for message in self._websocket:
|
|
0 commit comments