Skip to content

add notify_closing #896

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Apr 13, 2025
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
1ca5d18
add notify_closing
graingert Mar 17, 2025
abc6c87
invert the condition for a ClosedResourceError
graingert Mar 17, 2025
af37214
fix notify_closing on windows and uvloop
graingert Mar 18, 2025
0a07e76
checkpoint after registration
graingert Mar 18, 2025
8b95665
Merge branch 'master' into notify-closing
graingert Mar 18, 2025
23b4792
add tests
graingert Mar 18, 2025
3458b0d
Merge branch 'master' into notify-closing
graingert Mar 18, 2025
2a7295d
add version
graingert Mar 18, 2025
d9f7c3b
doc notify_closing
graingert Mar 18, 2025
45bee93
Update tests/test_sockets.py
graingert Mar 18, 2025
e34a775
Update src/anyio/_core/_sockets.py
graingert Mar 18, 2025
850eb9a
Apply suggestions from code review
agronholm Mar 18, 2025
aa2ebc0
replace HasFileno | int with FileDescriptorLike
graingert Mar 19, 2025
6b442ab
use Future[bool]s instead of ClosableEvents in socket wait api
graingert Mar 20, 2025
b07dfd8
Merge branch 'master' into notify-closing
agronholm Apr 7, 2025
7e3601d
Merge branch 'master' into notify-closing
agronholm Apr 8, 2025
917dd0d
pop futs immedeately from read_events/write_events when possible
graingert Apr 8, 2025
4103157
also remove reader/writers as soon as the callback is called
graingert Apr 8, 2025
cebbfff
Removed a blank line
agronholm Apr 12, 2025
b52d65e
Added a blank line
agronholm Apr 12, 2025
4736920
Adjusted the code style
agronholm Apr 12, 2025
85a15ea
A couple more style adjustments
agronholm Apr 12, 2025
756682b
Moved fill_socket()
agronholm Apr 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@ Version history

This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.

**UNRELEASED**

- Added ``anyio.notify_closing`` to allow waking ``anyio.wait_readable``
and ``anyio.wait_writable`` before closing a socket. Among other things
this prevents an OSError on the ``ProactorEventLoop``.
(`#123 <https://github.com/agronholm/anyio/pull/896>`_; PR by @graingert)

**4.9.0**

- Added async support for temporary file handling
Expand Down
1 change: 1 addition & 0 deletions src/anyio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from ._core._sockets import create_unix_listener as create_unix_listener
from ._core._sockets import getaddrinfo as getaddrinfo
from ._core._sockets import getnameinfo as getnameinfo
from ._core._sockets import notify_closing as notify_closing
from ._core._sockets import wait_readable as wait_readable
from ._core._sockets import wait_socket_readable as wait_socket_readable
from ._core._sockets import wait_socket_writable as wait_socket_writable
Expand Down
77 changes: 70 additions & 7 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from concurrent.futures import Future
from contextlib import AbstractContextManager, suppress
from contextvars import Context, copy_context
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import partial, wraps
from inspect import (
CORO_RUNNING,
Expand Down Expand Up @@ -1745,8 +1745,26 @@ async def send(self, item: bytes) -> None:
return


_read_events: RunVar[dict[int, asyncio.Event]] = RunVar("read_events")
_write_events: RunVar[dict[int, asyncio.Event]] = RunVar("write_events")
@dataclass(eq=False)
class ClosableEvent:
_closed: bool = field(default=False, init=False)
_event: asyncio.Event = field(default_factory=asyncio.Event, init=False)

def set(self) -> None:
self._event.set()

async def wait(self) -> None:
await self._event.wait()
if self._closed:
raise ClosedResourceError

def close(self) -> None:
self._closed = True
self._event.set()


_read_events: RunVar[dict[int, ClosableEvent]] = RunVar("read_events")
_write_events: RunVar[dict[int, ClosableEvent]] = RunVar("write_events")


#
Expand Down Expand Up @@ -2701,7 +2719,6 @@ async def getnameinfo(

@classmethod
async def wait_readable(cls, obj: FileDescriptorLike) -> None:
await cls.checkpoint()
try:
read_events = _read_events.get()
except LookupError:
Expand All @@ -2715,7 +2732,7 @@ async def wait_readable(cls, obj: FileDescriptorLike) -> None:
raise BusyResourceError("reading from")

loop = get_running_loop()
event = asyncio.Event()
event = ClosableEvent()
try:
loop.add_reader(obj, event.set)
except NotImplementedError:
Expand All @@ -2729,14 +2746,14 @@ async def wait_readable(cls, obj: FileDescriptorLike) -> None:

read_events[obj] = event
try:
await cls.checkpoint()
await event.wait()
finally:
remove_reader(obj)
del read_events[obj]

@classmethod
async def wait_writable(cls, obj: FileDescriptorLike) -> None:
await cls.checkpoint()
try:
write_events = _write_events.get()
except LookupError:
Expand All @@ -2750,7 +2767,7 @@ async def wait_writable(cls, obj: FileDescriptorLike) -> None:
raise BusyResourceError("writing to")

loop = get_running_loop()
event = asyncio.Event()
event = ClosableEvent()
try:
loop.add_writer(obj, event.set)
except NotImplementedError:
Expand All @@ -2764,11 +2781,57 @@ async def wait_writable(cls, obj: FileDescriptorLike) -> None:

write_events[obj] = event
try:
await cls.checkpoint()
await event.wait()
finally:
del write_events[obj]
remove_writer(obj)

@classmethod
def notify_closing(cls, obj: FileDescriptorLike) -> None:
if not isinstance(obj, int):
obj = obj.fileno()

loop = get_running_loop()

try:
write_events = _write_events.get()
except LookupError:
pass
else:
try:
event = write_events[obj]
except KeyError:
pass
else:
event.close()
try:
loop.remove_writer(obj)
except NotImplementedError:
from anyio._core._asyncio_selector_thread import get_selector

selector = get_selector()
selector.remove_writer(obj)

try:
read_events = _read_events.get()
except LookupError:
pass
else:
try:
event = read_events[obj]
except KeyError:
pass
else:
event.close()
try:
loop.remove_reader(obj)
except NotImplementedError:
from anyio._core._asyncio_selector_thread import get_selector

selector = get_selector()
selector.remove_reader(obj)

@classmethod
def current_default_thread_limiter(cls) -> CapacityLimiter:
try:
Expand Down
5 changes: 5 additions & 0 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from trio.lowlevel import (
current_root_task,
current_task,
notify_closing,
wait_readable,
wait_writable,
)
Expand Down Expand Up @@ -1281,6 +1282,10 @@ async def wait_writable(cls, obj: HasFileno | int) -> None:
except trio.BusyResourceError:
raise BusyResourceError("writing to") from None

@classmethod
def notify_closing(cls, obj: HasFileno | int) -> None:
notify_closing(obj)

@classmethod
def current_default_thread_limiter(cls) -> CapacityLimiter:
try:
Expand Down
29 changes: 29 additions & 0 deletions src/anyio/_core/_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,35 @@ def wait_writable(obj: FileDescriptorLike) -> Awaitable[None]:
return get_async_backend().wait_writable(obj)


def notify_closing(obj: FileDescriptorLike) -> None:
"""
Call this before closing a file descriptor (on Unix) or socket (on
Windows). This will cause any `wait_readable` or `wait_writable`
calls on the given object to immediately wake up and raise
`~anyio.ClosedResourceError`.

This doesn't actually close the object – you still have to do that
yourself afterwards. Also, you want to be careful to make sure no
new tasks start waiting on the object in between when you call this
and when it's actually closed. So to close something properly, you
usually want to do these steps in order:

1. Explicitly mark the object as closed, so that any new attempts
to use it will abort before they start.
2. Call `notify_closing` to wake up any already-existing users.
3. Actually close the object.

It's also possible to do them in a different order if that's more
convenient, *but only if* you make sure not to have any checkpoints in
between the steps. This way they all happen in a single atomic
step, so other tasks won't be able to tell what order they happened
in anyway.

:param obj: an object with a ``.fileno()`` method or an integer handle
"""
get_async_backend().notify_closing(obj)


#
# Private API
#
Expand Down
5 changes: 5 additions & 0 deletions src/anyio/abc/_eventloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,11 @@ async def wait_readable(cls, obj: HasFileno | int) -> None:
async def wait_writable(cls, obj: HasFileno | int) -> None:
pass

@classmethod
@abstractmethod
def notify_closing(cls, obj: HasFileno | int) -> None:
pass

@classmethod
@abstractmethod
def current_default_thread_limiter(cls) -> CapacityLimiter:
Expand Down
40 changes: 40 additions & 0 deletions tests/test_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
getaddrinfo,
getnameinfo,
move_on_after,
notify_closing,
wait_all_tasks_blocked,
wait_readable,
wait_socket_readable,
Expand Down Expand Up @@ -1932,3 +1933,42 @@ async def test_deprecated_wait_socket(anyio_backend_name: str) -> None:
):
with move_on_after(0.1):
await wait_socket_writable(sock)


def fill_socket(sock: socket.socket) -> None:
try:
while True:
sock.send(b"x" * 65536)
except BlockingIOError:
pass


@pytest.mark.parametrize("socket_type", ["socket", "fd"])
async def test_interrupted_by_close(socket_type: str) -> None:
a_sock, b = socket.socketpair()
with a_sock, b:
a_sock.setblocking(False)
b.setblocking(False)

a: FileDescriptorLike = a_sock.fileno() if socket_type == "fd" else a_sock

async def reader() -> None:
with pytest.raises(ClosedResourceError):
await wait_readable(a)

async def writer() -> None:
with pytest.raises(ClosedResourceError):
await wait_writable(a)

try:
while True:
a_sock.send(b"x" * 65536)
except BlockingIOError:
pass

async with create_task_group() as tg:
tg.start_soon(reader)
tg.start_soon(writer)
await wait_all_tasks_blocked()
notify_closing(a_sock)
a_sock.close()