Skip to content

Commit 72b76e2

Browse files
author
Anton
committed
fix: receiver
1 parent 25ffb83 commit 72b76e2

File tree

1 file changed

+70
-3
lines changed

1 file changed

+70
-3
lines changed

taskiq/receiver/receiver.py

+70-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,20 @@
11
import asyncio
22
import inspect
33
from concurrent.futures import Executor
4+
from contextlib import asynccontextmanager
45
from logging import getLogger
56
from time import time
6-
from typing import Any, Callable, Dict, List, Optional, Set, Union, get_type_hints
7+
from typing import (
8+
Any,
9+
AsyncIterator,
10+
Callable,
11+
Dict,
12+
List,
13+
Optional,
14+
Set,
15+
Union,
16+
get_type_hints,
17+
)
718

819
import anyio
920
from taskiq_dependencies import DependencyGraph
@@ -21,6 +32,7 @@
2132

2233
logger = getLogger(__name__)
2334
QUEUE_DONE = b"-1"
35+
QUEUE_SKIP = b"-2"
2436

2537

2638
def _run_sync(
@@ -83,6 +95,11 @@ def __init__(
8395
"can result in undefined behavior",
8496
)
8597
self.sem_prefetch = asyncio.Semaphore(max_prefetch)
98+
self.idle_tasks: "Set[asyncio.Task[Any]]" = set()
99+
self.sem_lock: asyncio.Lock = asyncio.Lock()
100+
self.listen_queue: "asyncio.Queue[Union[AckableMessage, bytes]]" = (
101+
asyncio.Queue()
102+
)
86103

87104
async def callback( # noqa: C901, PLR0912
88105
self,
@@ -227,7 +244,7 @@ async def run_task( # noqa: C901, PLR0912, PLR0915
227244
broker_ctx = self.broker.custom_dependency_context
228245
broker_ctx.update(
229246
{
230-
Context: Context(message, self.broker),
247+
Context: Context(message, self.broker, self.idle),
231248
TaskiqState: self.broker.state,
232249
},
233250
)
@@ -329,6 +346,7 @@ async def listen(self) -> None: # pragma: no cover
329346
await self.broker.startup()
330347
logger.info("Listening started.")
331348
queue: "asyncio.Queue[Union[bytes, AckableMessage]]" = asyncio.Queue()
349+
self.listen_queue = queue
332350

333351
async with anyio.create_task_group() as gr:
334352
gr.start_soon(self.prefetcher, queue)
@@ -396,7 +414,8 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
396414
while True:
397415
# Waits for semaphore to be released.
398416
if self.sem is not None:
399-
await self.sem.acquire()
417+
async with self.sem_lock:
418+
await self.sem.acquire()
400419

401420
self.sem_prefetch.release()
402421
message = await queue.get()
@@ -407,6 +426,11 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
407426
await asyncio.wait(tasks, timeout=self.wait_tasks_timeout)
408427
break
409428

429+
if message is QUEUE_SKIP:
430+
if self.sem is not None:
431+
self.sem.release()
432+
continue
433+
410434
task = asyncio.create_task(
411435
self.callback(message=message, raise_err=False),
412436
)
@@ -420,6 +444,49 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
420444
# https://textual.textualize.io/blog/2023/02/11/the-heisenbug-lurking-in-your-async-code/
421445
task.add_done_callback(task_cb)
422446

447+
@asynccontextmanager
448+
async def idle(self, timeout: Optional[int] = None) -> AsyncIterator[None]:
449+
"""Idle task.
450+
451+
:param timeout: idle time
452+
"""
453+
if self.sem is not None:
454+
self.sem.release()
455+
456+
def acquire() -> "asyncio.Task[Any]":
457+
if self.sem is None:
458+
raise ValueError(self.sem)
459+
460+
task = asyncio.create_task(self.sem.acquire())
461+
task.add_done_callback(self.idle_tasks.discard)
462+
self.idle_tasks.add(task)
463+
return task
464+
465+
cancelled = False
466+
try:
467+
with anyio.fail_after(timeout):
468+
yield
469+
except asyncio.CancelledError:
470+
if self.sem:
471+
acquire()
472+
473+
cancelled = True
474+
raise
475+
476+
finally:
477+
if not cancelled and self.sem is not None:
478+
try:
479+
await self.sem_lock.acquire()
480+
except asyncio.CancelledError:
481+
acquire()
482+
raise
483+
484+
try:
485+
self.listen_queue.put_nowait(QUEUE_SKIP)
486+
await acquire()
487+
finally:
488+
self.sem_lock.release()
489+
423490
def _prepare_task(self, name: str, handler: Callable[..., Any]) -> None:
424491
"""
425492
Prepare task for execution.

0 commit comments

Comments
 (0)