|
| 1 | +import asyncio |
| 2 | +from asyncio.exceptions import CancelledError |
| 3 | + |
| 4 | +import anyio |
| 5 | +import pytest |
| 6 | +from taskiq_dependencies import Depends |
| 7 | + |
| 8 | +from taskiq.api.receiver import run_receiver_task |
| 9 | +from taskiq.brokers.inmemory_broker import InmemoryResultBackend |
| 10 | +from taskiq.depends.task_idler import TaskIdler |
| 11 | +from tests.utils import AsyncQueueBroker |
| 12 | + |
| 13 | + |
| 14 | +@pytest.mark.anyio |
| 15 | +async def test_task_idler() -> None: |
| 16 | + broker = AsyncQueueBroker().with_result_backend(InmemoryResultBackend()) |
| 17 | + kicked = 0 |
| 18 | + desired_kicked = 20 |
| 19 | + |
| 20 | + @broker.task(timeout=1) |
| 21 | + async def test_func(idle: TaskIdler = Depends()) -> None: |
| 22 | + nonlocal kicked |
| 23 | + async with idle(): |
| 24 | + await asyncio.sleep(0.5) |
| 25 | + kicked += 1 |
| 26 | + |
| 27 | + receiver_task = asyncio.create_task(run_receiver_task(broker, max_async_tasks=1)) |
| 28 | + |
| 29 | + tasks = [] |
| 30 | + for _ in range(desired_kicked): |
| 31 | + tasks.append(await test_func.kiq()) |
| 32 | + |
| 33 | + with anyio.fail_after(1): |
| 34 | + for task in tasks: |
| 35 | + await task.wait_result(check_interval=0.01) |
| 36 | + |
| 37 | + receiver_task.cancel() |
| 38 | + assert kicked == desired_kicked |
| 39 | + |
| 40 | + |
| 41 | +@pytest.mark.anyio |
| 42 | +async def test_task_idler_task_cancelled() -> None: |
| 43 | + broker = AsyncQueueBroker().with_result_backend(InmemoryResultBackend()) |
| 44 | + kicked = 0 |
| 45 | + desired_kicked = 20 |
| 46 | + |
| 47 | + @broker.task(timeout=0.2) |
| 48 | + async def test_func_timeout(idle: TaskIdler = Depends()) -> None: |
| 49 | + nonlocal kicked |
| 50 | + try: |
| 51 | + async with idle(): |
| 52 | + await asyncio.sleep(2) |
| 53 | + except CancelledError: |
| 54 | + kicked += 1 |
| 55 | + raise |
| 56 | + |
| 57 | + @broker.task(timeout=2) |
| 58 | + async def test_func(idle: TaskIdler = Depends()) -> None: |
| 59 | + nonlocal kicked |
| 60 | + async with idle(): |
| 61 | + await asyncio.sleep(0.5) |
| 62 | + kicked += 1 |
| 63 | + |
| 64 | + receiver_task = asyncio.create_task(run_receiver_task(broker, max_async_tasks=1)) |
| 65 | + |
| 66 | + tasks = [] |
| 67 | + tasks.append(await test_func_timeout.kiq()) |
| 68 | + for _ in range(desired_kicked): |
| 69 | + tasks.append(await test_func.kiq()) |
| 70 | + |
| 71 | + with anyio.fail_after(1): |
| 72 | + for task in tasks: |
| 73 | + await task.wait_result(check_interval=0.01) |
| 74 | + |
| 75 | + receiver_task.cancel() |
| 76 | + assert kicked == desired_kicked + 1 |
0 commit comments