Skip to content

Commit 28516e2

Browse files
authored
Enabled Event and CapacityLimiter to be instantiated outside an event loop (#651)
1 parent 44ca5ea commit 28516e2

File tree

4 files changed

+232
-3
lines changed

4 files changed

+232
-3
lines changed

Diff for: docs/versionhistory.rst

+2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
88
- Add support for ``byte``-based paths in ``connect_unix``, ``create_unix_listeners``,
99
``create_unix_datagram_socket``, and ``create_connected_unix_datagram_socket``. (PR by
1010
Lura Skye)
11+
- Enabled the ``Event`` and ``CapacityLimiter`` classes to be instantiated outside an
12+
event loop thread
1113
- Fixed adjusting the total number of tokens in a ``CapacityLimiter`` on asyncio failing
1214
to wake up tasks waiting to acquire the limiter in certain edge cases (fixed with help
1315
from Egor Blagov)

Diff for: src/anyio/_backends/_trio.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,9 @@ def set(self) -> None:
615615

616616

617617
class CapacityLimiter(BaseCapacityLimiter):
618-
def __new__(cls, *args: object, **kwargs: object) -> CapacityLimiter:
618+
def __new__(
619+
cls, *args: Any, original: trio.CapacityLimiter | None = None
620+
) -> CapacityLimiter:
619621
return object.__new__(cls)
620622

621623
def __init__(

Diff for: src/anyio/_core/_synchronization.py

+133-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from __future__ import annotations
22

3+
import math
34
from collections import deque
45
from dataclasses import dataclass
56
from types import TracebackType
67

8+
from sniffio import AsyncLibraryNotFoundError
9+
710
from ..lowlevel import cancel_shielded_checkpoint, checkpoint, checkpoint_if_cancelled
811
from ._eventloop import get_async_backend
912
from ._exceptions import BusyResourceError, WouldBlock
@@ -76,7 +79,10 @@ class SemaphoreStatistics:
7679

7780
class Event:
7881
def __new__(cls) -> Event:
79-
return get_async_backend().create_event()
82+
try:
83+
return get_async_backend().create_event()
84+
except AsyncLibraryNotFoundError:
85+
return EventAdapter()
8086

8187
def set(self) -> None:
8288
"""Set the flag, notifying all listeners."""
@@ -101,6 +107,35 @@ def statistics(self) -> EventStatistics:
101107
raise NotImplementedError
102108

103109

110+
class EventAdapter(Event):
111+
_internal_event: Event | None = None
112+
113+
def __new__(cls) -> EventAdapter:
114+
return object.__new__(cls)
115+
116+
@property
117+
def _event(self) -> Event:
118+
if self._internal_event is None:
119+
self._internal_event = get_async_backend().create_event()
120+
121+
return self._internal_event
122+
123+
def set(self) -> None:
124+
self._event.set()
125+
126+
def is_set(self) -> bool:
127+
return self._internal_event is not None and self._internal_event.is_set()
128+
129+
async def wait(self) -> None:
130+
await self._event.wait()
131+
132+
def statistics(self) -> EventStatistics:
133+
if self._internal_event is None:
134+
return EventStatistics(tasks_waiting=0)
135+
136+
return self._internal_event.statistics()
137+
138+
104139
class Lock:
105140
_owner_task: TaskInfo | None = None
106141

@@ -373,7 +408,10 @@ def statistics(self) -> SemaphoreStatistics:
373408

374409
class CapacityLimiter:
375410
def __new__(cls, total_tokens: float) -> CapacityLimiter:
376-
return get_async_backend().create_capacity_limiter(total_tokens)
411+
try:
412+
return get_async_backend().create_capacity_limiter(total_tokens)
413+
except AsyncLibraryNotFoundError:
414+
return CapacityLimiterAdapter(total_tokens)
377415

378416
async def __aenter__(self) -> None:
379417
raise NotImplementedError
@@ -482,6 +520,99 @@ def statistics(self) -> CapacityLimiterStatistics:
482520
raise NotImplementedError
483521

484522

523+
class CapacityLimiterAdapter(CapacityLimiter):
524+
_internal_limiter: CapacityLimiter | None = None
525+
526+
def __new__(cls, total_tokens: float) -> CapacityLimiterAdapter:
527+
return object.__new__(cls)
528+
529+
def __init__(self, total_tokens: float) -> None:
530+
self.total_tokens = total_tokens
531+
532+
@property
533+
def _limiter(self) -> CapacityLimiter:
534+
if self._internal_limiter is None:
535+
self._internal_limiter = get_async_backend().create_capacity_limiter(
536+
self._total_tokens
537+
)
538+
539+
return self._internal_limiter
540+
541+
async def __aenter__(self) -> None:
542+
await self._limiter.__aenter__()
543+
544+
async def __aexit__(
545+
self,
546+
exc_type: type[BaseException] | None,
547+
exc_val: BaseException | None,
548+
exc_tb: TracebackType | None,
549+
) -> bool | None:
550+
return await self._limiter.__aexit__(exc_type, exc_val, exc_tb)
551+
552+
@property
553+
def total_tokens(self) -> float:
554+
if self._internal_limiter is None:
555+
return self._total_tokens
556+
557+
return self._internal_limiter.total_tokens
558+
559+
@total_tokens.setter
560+
def total_tokens(self, value: float) -> None:
561+
if not isinstance(value, int) and value is not math.inf:
562+
raise TypeError("total_tokens must be an int or math.inf")
563+
elif value < 1:
564+
raise ValueError("total_tokens must be >= 1")
565+
566+
if self._internal_limiter is None:
567+
self._total_tokens = value
568+
return
569+
570+
self._limiter.total_tokens = value
571+
572+
@property
573+
def borrowed_tokens(self) -> int:
574+
if self._internal_limiter is None:
575+
return 0
576+
577+
return self._internal_limiter.borrowed_tokens
578+
579+
@property
580+
def available_tokens(self) -> float:
581+
if self._internal_limiter is None:
582+
return self._total_tokens
583+
584+
return self._internal_limiter.available_tokens
585+
586+
def acquire_nowait(self) -> None:
587+
self._limiter.acquire_nowait()
588+
589+
def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
590+
self._limiter.acquire_on_behalf_of_nowait(borrower)
591+
592+
async def acquire(self) -> None:
593+
await self._limiter.acquire()
594+
595+
async def acquire_on_behalf_of(self, borrower: object) -> None:
596+
await self._limiter.acquire_on_behalf_of(borrower)
597+
598+
def release(self) -> None:
599+
self._limiter.release()
600+
601+
def release_on_behalf_of(self, borrower: object) -> None:
602+
self._limiter.release_on_behalf_of(borrower)
603+
604+
def statistics(self) -> CapacityLimiterStatistics:
605+
if self._internal_limiter is None:
606+
return CapacityLimiterStatistics(
607+
borrowed_tokens=0,
608+
total_tokens=self.total_tokens,
609+
borrowers=(),
610+
tasks_waiting=0,
611+
)
612+
613+
return self._internal_limiter.statistics()
614+
615+
485616
class ResourceGuard:
486617
"""
487618
A context manager for ensuring that a resource is only used by a single task at a

Diff for: tests/test_synchronization.py

+94
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import asyncio
4+
from typing import Any
45

56
import pytest
67

@@ -13,6 +14,7 @@
1314
WouldBlock,
1415
create_task_group,
1516
fail_after,
17+
run,
1618
to_thread,
1719
wait_all_tasks_blocked,
1820
)
@@ -141,6 +143,21 @@ async def acquire() -> None:
141143
task1.cancel()
142144
await asyncio.wait_for(task2, 1)
143145

146+
def test_instantiate_outside_event_loop(
147+
self, anyio_backend_name: str, anyio_backend_options: dict[str, Any]
148+
) -> None:
149+
async def use_lock() -> None:
150+
async with lock:
151+
pass
152+
153+
lock = Lock()
154+
statistics = lock.statistics()
155+
assert not statistics.locked
156+
assert statistics.owner is None
157+
assert statistics.tasks_waiting == 0
158+
159+
run(use_lock, backend=anyio_backend_name, backend_options=anyio_backend_options)
160+
144161

145162
class TestEvent:
146163
async def test_event(self) -> None:
@@ -208,6 +225,21 @@ async def waiter() -> None:
208225

209226
assert event.statistics().tasks_waiting == 0
210227

228+
def test_instantiate_outside_event_loop(
229+
self, anyio_backend_name: str, anyio_backend_options: dict[str, Any]
230+
) -> None:
231+
async def use_event() -> None:
232+
event.set()
233+
await event.wait()
234+
235+
event = Event()
236+
assert not event.is_set()
237+
assert event.statistics().tasks_waiting == 0
238+
239+
run(
240+
use_event, backend=anyio_backend_name, backend_options=anyio_backend_options
241+
)
242+
211243

212244
class TestCondition:
213245
async def test_contextmanager(self) -> None:
@@ -304,6 +336,22 @@ async def waiter() -> None:
304336
assert not condition.statistics().lock_statistics.locked
305337
assert condition.statistics().tasks_waiting == 0
306338

339+
def test_instantiate_outside_event_loop(
340+
self, anyio_backend_name: str, anyio_backend_options: dict[str, Any]
341+
) -> None:
342+
async def use_condition() -> None:
343+
async with condition:
344+
pass
345+
346+
condition = Condition()
347+
assert condition.statistics().tasks_waiting == 0
348+
349+
run(
350+
use_condition,
351+
backend=anyio_backend_name,
352+
backend_options=anyio_backend_options,
353+
)
354+
307355

308356
class TestSemaphore:
309357
async def test_contextmanager(self) -> None:
@@ -426,6 +474,22 @@ async def acquire() -> None:
426474
task1.cancel()
427475
await asyncio.wait_for(task2, 1)
428476

477+
def test_instantiate_outside_event_loop(
478+
self, anyio_backend_name: str, anyio_backend_options: dict[str, Any]
479+
) -> None:
480+
async def use_semaphore() -> None:
481+
async with semaphore:
482+
pass
483+
484+
semaphore = Semaphore(1)
485+
assert semaphore.statistics().tasks_waiting == 0
486+
487+
run(
488+
use_semaphore,
489+
backend=anyio_backend_name,
490+
backend_options=anyio_backend_options,
491+
)
492+
429493

430494
class TestCapacityLimiter:
431495
async def test_bad_init_type(self) -> None:
@@ -595,3 +659,33 @@ async def worker(entered_event: Event) -> None:
595659

596660
# Allow all tasks to exit
597661
continue_event.set()
662+
663+
def test_instantiate_outside_event_loop(
664+
self, anyio_backend_name: str, anyio_backend_options: dict[str, Any]
665+
) -> None:
666+
async def use_limiter() -> None:
667+
async with limiter:
668+
pass
669+
670+
limiter = CapacityLimiter(1)
671+
limiter.total_tokens = 2
672+
673+
with pytest.raises(TypeError):
674+
limiter.total_tokens = "2" # type: ignore[assignment]
675+
676+
with pytest.raises(TypeError):
677+
limiter.total_tokens = 3.0
678+
679+
assert limiter.total_tokens == 2
680+
assert limiter.borrowed_tokens == 0
681+
statistics = limiter.statistics()
682+
assert statistics.total_tokens == 2
683+
assert statistics.borrowed_tokens == 0
684+
assert statistics.borrowers == ()
685+
assert statistics.tasks_waiting == 0
686+
687+
run(
688+
use_limiter,
689+
backend=anyio_backend_name,
690+
backend_options=anyio_backend_options,
691+
)

0 commit comments

Comments
 (0)