Skip to content

simplify wait_readable loop + notify_closing #31

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
4 changes: 3 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: pip install -e ".[test]"
run: |
pip install -e ".[test]"
pip install git+https://github.com/agronholm/anyio.git@notify-closing#egg=anyio --ignore-installed
- name: Check with mypy and ruff
if: ${{ (matrix.python-version == '3.13') && (matrix.os == 'ubuntu-latest') }}
run: |
Expand Down
63 changes: 33 additions & 30 deletions src/zmq_anyio/_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
get_cancelled_exc_class,
sleep,
wait_readable,
ClosedResourceError,
notify_closing,
)
from anyio.abc import TaskGroup, TaskStatus
from anyioutils import FIRST_COMPLETED, Future, create_task, wait
from anyioutils import Future, create_task

import zmq
from zmq import EVENTS, POLLIN, POLLOUT
Expand Down Expand Up @@ -890,36 +892,35 @@ async def _start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
task_status.started()
self.started.set()
self._thread = get_ident()

async def wait_or_cancel() -> None:
assert self.stopped is not None
await self.stopped.wait()
tg.cancel_scope.cancel()

def fileno() -> int:
if self.closed:
return -1
try:
return self._shadow_sock.fileno()
except zmq.ZMQError:
return -1

try:
while True:
wait_stopped_task = create_task(
self.stopped.wait(),
self._task_group,
exception_handler=ignore_exceptions,
)
tasks = [
create_task(
wait_readable(self._shadow_sock), # type: ignore[arg-type]
self._task_group,
exception_handler=ignore_exceptions,
),
wait_stopped_task,
]
done, pending = await wait(
tasks, self._task_group, return_when=FIRST_COMPLETED
)
for task in pending:
task.cancel()
if wait_stopped_task in done:
while (fd := fileno()) > 0:
async with create_task_group() as tg:
tg.start_soon(wait_or_cancel)
try:
await wait_readable(fd)
except ClosedResourceError:
break
tg.cancel_scope.cancel()
if self.stopped.is_set():
break
await self._handle_events()
except BaseException:
pass
finally:
self._exited.set()

assert self.stopped is not None
self.stopped.set()
self.stopped.set()

async def stop(self):
assert self._exited is not None
Expand All @@ -933,11 +934,13 @@ async def stop(self):
self.close()

def close(self, linger: int | None = None) -> None:
try:
if not self.closed and self._fd is not None:
fd = self._fd
if not self.closed and fd is not None:
notify_closing(fd)
try:
super().close(linger=linger)
except BaseException:
pass
except BaseException:
pass

assert self.stopped is not None
self.stopped.set()
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def context(contexts):


@pytest.fixture
def sockets(contexts):
async def sockets(contexts):
sockets = []
yield sockets
# ensure any tracked sockets get their contexts cleaned up
Expand Down