1
1
import asyncio
2
2
import inspect
3
3
from concurrent .futures import Executor
4
+ from contextlib import asynccontextmanager
4
5
from logging import getLogger
5
6
from time import time
6
- from typing import Any , Callable , Dict , List , Optional , Set , Union , get_type_hints
7
+ from typing import (
8
+ Any ,
9
+ AsyncIterator ,
10
+ Callable ,
11
+ Dict ,
12
+ List ,
13
+ Optional ,
14
+ Set ,
15
+ Union ,
16
+ get_type_hints ,
17
+ )
7
18
8
19
import anyio
9
20
from taskiq_dependencies import DependencyGraph
21
32
22
33
logger = getLogger (__name__ )
23
34
QUEUE_DONE = b"-1"
35
+ QUEUE_SKIP = b"-2"
24
36
25
37
26
38
def _run_sync (
@@ -83,6 +95,11 @@ def __init__(
83
95
"can result in undefined behavior" ,
84
96
)
85
97
self .sem_prefetch = asyncio .Semaphore (max_prefetch )
98
+ self .idle_tasks : "Set[asyncio.Task[Any]]" = set ()
99
+ self .sem_lock : asyncio .Lock = asyncio .Lock ()
100
+ self .listen_queue : "asyncio.Queue[Union[AckableMessage, bytes]]" = (
101
+ asyncio .Queue ()
102
+ )
86
103
87
104
async def callback ( # noqa: C901, PLR0912
88
105
self ,
@@ -227,7 +244,7 @@ async def run_task( # noqa: C901, PLR0912, PLR0915
227
244
broker_ctx = self .broker .custom_dependency_context
228
245
broker_ctx .update (
229
246
{
230
- Context : Context (message , self .broker ),
247
+ Context : Context (message , self .broker , self . idle ),
231
248
TaskiqState : self .broker .state ,
232
249
},
233
250
)
@@ -329,6 +346,7 @@ async def listen(self) -> None: # pragma: no cover
329
346
await self .broker .startup ()
330
347
logger .info ("Listening started." )
331
348
queue : "asyncio.Queue[Union[bytes, AckableMessage]]" = asyncio .Queue ()
349
+ self .listen_queue = queue
332
350
333
351
async with anyio .create_task_group () as gr :
334
352
gr .start_soon (self .prefetcher , queue )
@@ -396,7 +414,8 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
396
414
while True :
397
415
# Waits for semaphore to be released.
398
416
if self .sem is not None :
399
- await self .sem .acquire ()
417
+ async with self .sem_lock :
418
+ await self .sem .acquire ()
400
419
401
420
self .sem_prefetch .release ()
402
421
message = await queue .get ()
@@ -407,6 +426,11 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
407
426
await asyncio .wait (tasks , timeout = self .wait_tasks_timeout )
408
427
break
409
428
429
+ if message is QUEUE_SKIP :
430
+ if self .sem is not None :
431
+ self .sem .release ()
432
+ continue
433
+
410
434
task = asyncio .create_task (
411
435
self .callback (message = message , raise_err = False ),
412
436
)
@@ -420,6 +444,49 @@ def task_cb(task: "asyncio.Task[Any]") -> None:
420
444
# https://textual.textualize.io/blog/2023/02/11/the-heisenbug-lurking-in-your-async-code/
421
445
task .add_done_callback (task_cb )
422
446
447
+ @asynccontextmanager
448
+ async def idle (self , timeout : Optional [int ] = None ) -> AsyncIterator [None ]:
449
+ """Idle task.
450
+
451
+ :param timeout: idle time
452
+ """
453
+ if self .sem is not None :
454
+ self .sem .release ()
455
+
456
+ def acquire () -> "asyncio.Task[Any]" :
457
+ if self .sem is None :
458
+ raise ValueError (self .sem )
459
+
460
+ task = asyncio .create_task (self .sem .acquire ())
461
+ task .add_done_callback (self .idle_tasks .discard )
462
+ self .idle_tasks .add (task )
463
+ return task
464
+
465
+ cancelled = False
466
+ try :
467
+ with anyio .fail_after (timeout ):
468
+ yield
469
+ except asyncio .CancelledError :
470
+ if self .sem :
471
+ acquire ()
472
+
473
+ cancelled = True
474
+ raise
475
+
476
+ finally :
477
+ if not cancelled and self .sem is not None :
478
+ try :
479
+ await self .sem_lock .acquire ()
480
+ except asyncio .CancelledError :
481
+ acquire ()
482
+ raise
483
+
484
+ try :
485
+ self .listen_queue .put_nowait (QUEUE_SKIP )
486
+ await acquire ()
487
+ finally :
488
+ self .sem_lock .release ()
489
+
423
490
def _prepare_task (self , name : str , handler : Callable [..., Any ]) -> None :
424
491
"""
425
492
Prepare task for execution.
0 commit comments