Skip to content

add notify_closing #896

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Apr 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
1ca5d18
add notify_closing
graingert Mar 17, 2025
abc6c87
invert the condition for a ClosedResourceError
graingert Mar 17, 2025
af37214
fix notify_closing on windows and uvloop
graingert Mar 18, 2025
0a07e76
checkpoint after registration
graingert Mar 18, 2025
8b95665
Merge branch 'master' into notify-closing
graingert Mar 18, 2025
23b4792
add tests
graingert Mar 18, 2025
3458b0d
Merge branch 'master' into notify-closing
graingert Mar 18, 2025
2a7295d
add version
graingert Mar 18, 2025
d9f7c3b
doc notify_closing
graingert Mar 18, 2025
45bee93
Update tests/test_sockets.py
graingert Mar 18, 2025
e34a775
Update src/anyio/_core/_sockets.py
graingert Mar 18, 2025
850eb9a
Apply suggestions from code review
agronholm Mar 18, 2025
aa2ebc0
replace HasFileno | int with FileDescriptorLike
graingert Mar 19, 2025
6b442ab
use Future[bool]s instead of ClosableEvents in socket wait api
graingert Mar 20, 2025
b07dfd8
Merge branch 'master' into notify-closing
agronholm Apr 7, 2025
7e3601d
Merge branch 'master' into notify-closing
agronholm Apr 8, 2025
917dd0d
pop futs immedeately from read_events/write_events when possible
graingert Apr 8, 2025
4103157
also remove reader/writers as soon as the callback is called
graingert Apr 8, 2025
cebbfff
Removed a blank line
agronholm Apr 12, 2025
b52d65e
Added a blank line
agronholm Apr 12, 2025
4736920
Adjusted the code style
agronholm Apr 12, 2025
85a15ea
A couple more style adjustments
agronholm Apr 12, 2025
756682b
Moved fill_socket()
agronholm Apr 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.

- Added the ability to specify the thread name in ``start_blocking_portal()``
(`#818 <https://github.com/agronholm/anyio/issues/818>`_; PR by @davidbrochart)
- Added ``anyio.notify_closing`` to allow waking ``anyio.wait_readable``
and ``anyio.wait_writable`` before closing a socket. Among other things,
this prevents an OSError on the ``ProactorEventLoop``.
(`#123 <https://github.com/agronholm/anyio/pull/896>`_; PR by @graingert)

**4.9.0**

Expand Down
1 change: 1 addition & 0 deletions src/anyio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from ._core._sockets import create_unix_listener as create_unix_listener
from ._core._sockets import getaddrinfo as getaddrinfo
from ._core._sockets import getnameinfo as getnameinfo
from ._core._sockets import notify_closing as notify_closing
from ._core._sockets import wait_readable as wait_readable
from ._core._sockets import wait_socket_readable as wait_socket_readable
from ._core._sockets import wait_socket_writable as wait_socket_writable
Expand Down
137 changes: 111 additions & 26 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -1745,8 +1745,8 @@ async def send(self, item: bytes) -> None:
return


_read_events: RunVar[dict[int, asyncio.Event]] = RunVar("read_events")
_write_events: RunVar[dict[int, asyncio.Event]] = RunVar("write_events")
_read_events: RunVar[dict[int, asyncio.Future[bool]]] = RunVar("read_events")
_write_events: RunVar[dict[int, asyncio.Future[bool]]] = RunVar("write_events")


#
Expand Down Expand Up @@ -2701,73 +2701,158 @@ async def getnameinfo(

@classmethod
async def wait_readable(cls, obj: FileDescriptorLike) -> None:
await cls.checkpoint()
try:
read_events = _read_events.get()
except LookupError:
read_events = {}
_read_events.set(read_events)

if not isinstance(obj, int):
obj = obj.fileno()

if read_events.get(obj):
fd = obj if isinstance(obj, int) else obj.fileno()
if read_events.get(fd):
raise BusyResourceError("reading from")

loop = get_running_loop()
event = asyncio.Event()
fut: asyncio.Future[bool] = loop.create_future()

def cb() -> None:
try:
del read_events[fd]
except KeyError:
pass
else:
remove_reader(fd)

try:
fut.set_result(True)
except asyncio.InvalidStateError:
pass

try:
loop.add_reader(obj, event.set)
loop.add_reader(fd, cb)
except NotImplementedError:
from anyio._core._asyncio_selector_thread import get_selector

selector = get_selector()
selector.add_reader(obj, event.set)
selector.add_reader(fd, cb)
remove_reader = selector.remove_reader
else:
remove_reader = loop.remove_reader

read_events[obj] = event
read_events[fd] = fut
try:
await event.wait()
success = await fut
finally:
remove_reader(obj)
del read_events[obj]
try:
del read_events[fd]
except KeyError:
pass
else:
remove_reader(fd)

if not success:
raise ClosedResourceError

@classmethod
async def wait_writable(cls, obj: FileDescriptorLike) -> None:
await cls.checkpoint()
try:
write_events = _write_events.get()
except LookupError:
write_events = {}
_write_events.set(write_events)

if not isinstance(obj, int):
obj = obj.fileno()

if write_events.get(obj):
fd = obj if isinstance(obj, int) else obj.fileno()
if write_events.get(fd):
raise BusyResourceError("writing to")

loop = get_running_loop()
event = asyncio.Event()
fut: asyncio.Future[bool] = loop.create_future()

def cb() -> None:
try:
del write_events[fd]
except KeyError:
pass
else:
remove_writer(fd)

try:
fut.set_result(True)
except asyncio.InvalidStateError:
pass

try:
loop.add_writer(obj, event.set)
loop.add_writer(fd, cb)
except NotImplementedError:
from anyio._core._asyncio_selector_thread import get_selector

selector = get_selector()
selector.add_writer(obj, event.set)
selector.add_writer(fd, cb)
remove_writer = selector.remove_writer
else:
remove_writer = loop.remove_writer

write_events[obj] = event
write_events[fd] = fut
try:
await event.wait()
success = await fut
finally:
del write_events[obj]
remove_writer(obj)
try:
del write_events[fd]
except KeyError:
pass
else:
remove_writer(fd)

if not success:
raise ClosedResourceError

@classmethod
def notify_closing(cls, obj: FileDescriptorLike) -> None:
fd = obj if isinstance(obj, int) else obj.fileno()
loop = get_running_loop()

try:
write_events = _write_events.get()
except LookupError:
pass
else:
try:
fut = write_events.pop(fd)
except KeyError:
pass
else:
try:
fut.set_result(False)
except asyncio.InvalidStateError:
pass

try:
loop.remove_writer(fd)
except NotImplementedError:
from anyio._core._asyncio_selector_thread import get_selector

get_selector().remove_writer(fd)

try:
read_events = _read_events.get()
except LookupError:
pass
else:
try:
fut = read_events.pop(fd)
except KeyError:
pass
else:
try:
fut.set_result(False)
except asyncio.InvalidStateError:
pass

try:
loop.remove_reader(fd)
except NotImplementedError:
from anyio._core._asyncio_selector_thread import get_selector

get_selector().remove_reader(fd)

@classmethod
def current_default_thread_limiter(cls) -> CapacityLimiter:
Expand Down
11 changes: 8 additions & 3 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from trio.lowlevel import (
current_root_task,
current_task,
notify_closing,
wait_readable,
wait_writable,
)
Expand Down Expand Up @@ -82,7 +83,7 @@
from ..streams.memory import MemoryObjectSendStream

if TYPE_CHECKING:
from _typeshed import HasFileno
from _typeshed import FileDescriptorLike

if sys.version_info >= (3, 10):
from typing import ParamSpec
Expand Down Expand Up @@ -1264,7 +1265,7 @@ async def getnameinfo(
return await trio.socket.getnameinfo(sockaddr, flags)

@classmethod
async def wait_readable(cls, obj: HasFileno | int) -> None:
async def wait_readable(cls, obj: FileDescriptorLike) -> None:
try:
await wait_readable(obj)
except trio.ClosedResourceError as exc:
Expand All @@ -1273,14 +1274,18 @@ async def wait_readable(cls, obj: HasFileno | int) -> None:
raise BusyResourceError("reading from") from None

@classmethod
async def wait_writable(cls, obj: HasFileno | int) -> None:
async def wait_writable(cls, obj: FileDescriptorLike) -> None:
try:
await wait_writable(obj)
except trio.ClosedResourceError as exc:
raise ClosedResourceError().with_traceback(exc.__traceback__) from None
except trio.BusyResourceError:
raise BusyResourceError("writing to") from None

@classmethod
def notify_closing(cls, obj: FileDescriptorLike) -> None:
notify_closing(obj)

@classmethod
def current_default_thread_limiter(cls) -> CapacityLimiter:
try:
Expand Down
30 changes: 30 additions & 0 deletions src/anyio/_core/_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,36 @@ def wait_writable(obj: FileDescriptorLike) -> Awaitable[None]:
return get_async_backend().wait_writable(obj)


def notify_closing(obj: FileDescriptorLike) -> None:
"""
Call this before closing a file descriptor (on Unix) or socket (on
Windows). This will cause any `wait_readable` or `wait_writable`
calls on the given object to immediately wake up and raise
`~anyio.ClosedResourceError`.

This doesn't actually close the object – you still have to do that
yourself afterwards. Also, you want to be careful to make sure no
new tasks start waiting on the object in between when you call this
and when it's actually closed. So to close something properly, you
usually want to do these steps in order:

1. Explicitly mark the object as closed, so that any new attempts
to use it will abort before they start.
2. Call `notify_closing` to wake up any already-existing users.
3. Actually close the object.

It's also possible to do them in a different order if that's more
convenient, *but only if* you make sure not to have any checkpoints in
between the steps. This way they all happen in a single atomic
step, so other tasks won't be able to tell what order they happened
in anyway.

:param obj: an object with a ``.fileno()`` method or an integer handle

"""
get_async_backend().notify_closing(obj)


#
# Private API
#
Expand Down
11 changes: 8 additions & 3 deletions src/anyio/abc/_eventloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from typing_extensions import TypeAlias

if TYPE_CHECKING:
from _typeshed import HasFileno
from _typeshed import FileDescriptorLike

from .._core._synchronization import CapacityLimiter, Event, Lock, Semaphore
from .._core._tasks import CancelScope
Expand Down Expand Up @@ -335,12 +335,17 @@ async def getnameinfo(

@classmethod
@abstractmethod
async def wait_readable(cls, obj: HasFileno | int) -> None:
async def wait_readable(cls, obj: FileDescriptorLike) -> None:
pass

@classmethod
@abstractmethod
async def wait_writable(cls, obj: HasFileno | int) -> None:
async def wait_writable(cls, obj: FileDescriptorLike) -> None:
pass

@classmethod
@abstractmethod
def notify_closing(cls, obj: FileDescriptorLike) -> None:
pass

@classmethod
Expand Down
36 changes: 36 additions & 0 deletions tests/test_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
getaddrinfo,
getnameinfo,
move_on_after,
notify_closing,
wait_all_tasks_blocked,
wait_readable,
wait_socket_readable,
Expand Down Expand Up @@ -142,6 +143,14 @@ def _identity(v: _T) -> _T:
return v


def fill_socket(sock: socket.socket) -> None:
try:
while True:
sock.send(b"x" * 65536)
except BlockingIOError:
pass


# _ProactorBasePipeTransport.abort() after _ProactorBasePipeTransport.close()
# does not cancel writes: https://bugs.python.org/issue44428
_ignore_win32_resource_warnings = (
Expand Down Expand Up @@ -1932,3 +1941,30 @@ async def test_deprecated_wait_socket(anyio_backend_name: str) -> None:
):
with move_on_after(0.1):
await wait_socket_writable(sock)


@pytest.mark.parametrize("socket_type", ["socket", "fd"])
async def test_interrupted_by_close(socket_type: str) -> None:
a_sock, b = socket.socketpair()
with a_sock, b:
a_sock.setblocking(False)
b.setblocking(False)

a: FileDescriptorLike = a_sock.fileno() if socket_type == "fd" else a_sock

async def reader() -> None:
with pytest.raises(ClosedResourceError):
await wait_readable(a)

async def writer() -> None:
with pytest.raises(ClosedResourceError):
await wait_writable(a)

fill_socket(a_sock)

async with create_task_group() as tg:
tg.start_soon(reader)
tg.start_soon(writer)
await wait_all_tasks_blocked()
notify_closing(a_sock)
a_sock.close()
Loading