Skip to content

Commit 44ca5ea

Browse files
Fixed cancellation propagation when task group host is in a shielded scope (#648)
Co-authored-by: Ganden Schaffner <gschaffner@pm.me>
1 parent 3ea17f9 commit 44ca5ea

File tree

3 files changed

+110
-47
lines changed

3 files changed

+110
-47
lines changed

docs/versionhistory.rst

+3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
1313
from Egor Blagov)
1414
- Fixed ``loop_factory`` and ``use_uvloop`` options not being used on the asyncio
1515
backend (`#643 <https://github.com/agronholm/anyio/issues/643>`_)
16+
- Fixed cancellation propagating on asyncio from a task group to child tasks if the task
17+
hosting the task group is in a shielded cancel scope
18+
(`#642 <https://github.com/agronholm/anyio/issues/642>`_)
1619

1720
**4.1.0**
1821

src/anyio/_backends/_asyncio.py

+63-46
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ def __init__(self, deadline: float = math.inf, shield: bool = False):
343343
self._deadline = deadline
344344
self._shield = shield
345345
self._parent_scope: CancelScope | None = None
346+
self._child_scopes: set[CancelScope] = set()
346347
self._cancel_called = False
347348
self._cancelled_caught = False
348349
self._active = False
@@ -369,6 +370,9 @@ def __enter__(self) -> CancelScope:
369370
else:
370371
self._parent_scope = task_state.cancel_scope
371372
task_state.cancel_scope = self
373+
if self._parent_scope is not None:
374+
self._parent_scope._child_scopes.add(self)
375+
self._parent_scope._tasks.remove(host_task)
372376

373377
self._timeout()
374378
self._active = True
@@ -377,7 +381,7 @@ def __enter__(self) -> CancelScope:
377381

378382
# Start cancelling the host task if the scope was cancelled before entering
379383
if self._cancel_called:
380-
self._deliver_cancellation()
384+
self._deliver_cancellation(self)
381385

382386
return self
383387

@@ -409,13 +413,15 @@ def __exit__(
409413
self._timeout_handle = None
410414

411415
self._tasks.remove(self._host_task)
416+
if self._parent_scope is not None:
417+
self._parent_scope._child_scopes.remove(self)
418+
self._parent_scope._tasks.add(self._host_task)
412419

413420
host_task_state.cancel_scope = self._parent_scope
414421

415-
# Restart the cancellation effort in the farthest directly cancelled parent
422+
# Restart the cancellation effort in the closest directly cancelled parent
416423
# scope if this one was shielded
417-
if self._shield:
418-
self._deliver_cancellation_to_parent()
424+
self._restart_cancellation_in_parent()
419425

420426
if self._cancel_called and exc_val is not None:
421427
for exc in iterate_exceptions(exc_val):
@@ -451,65 +457,70 @@ def _timeout(self) -> None:
451457
else:
452458
self._timeout_handle = loop.call_at(self._deadline, self._timeout)
453459

454-
def _deliver_cancellation(self) -> None:
460+
def _deliver_cancellation(self, origin: CancelScope) -> bool:
455461
"""
456462
Deliver cancellation to directly contained tasks and nested cancel scopes.
457463
458464
Schedule another run at the end if we still have tasks eligible for
459465
cancellation.
466+
467+
:param origin: the cancel scope that originated the cancellation
468+
:return: ``True`` if the delivery needs to be retried on the next cycle
469+
460470
"""
461471
should_retry = False
462472
current = current_task()
463473
for task in self._tasks:
464474
if task._must_cancel: # type: ignore[attr-defined]
465475
continue
466476

467-
# The task is eligible for cancellation if it has started and is not in a
468-
# cancel scope shielded from this one
469-
cancel_scope = _task_states[task].cancel_scope
470-
while cancel_scope is not self:
471-
if cancel_scope is None or cancel_scope._shield:
472-
break
473-
else:
474-
cancel_scope = cancel_scope._parent_scope
475-
else:
476-
should_retry = True
477-
if task is not current and (
478-
task is self._host_task or _task_started(task)
479-
):
480-
waiter = task._fut_waiter # type: ignore[attr-defined]
481-
if not isinstance(waiter, asyncio.Future) or not waiter.done():
482-
self._cancel_calls += 1
483-
if sys.version_info >= (3, 9):
484-
task.cancel(f"Cancelled by cancel scope {id(self):x}")
485-
else:
486-
task.cancel()
477+
# The task is eligible for cancellation if it has started
478+
should_retry = True
479+
if task is not current and (task is self._host_task or _task_started(task)):
480+
waiter = task._fut_waiter # type: ignore[attr-defined]
481+
if not isinstance(waiter, asyncio.Future) or not waiter.done():
482+
self._cancel_calls += 1
483+
if sys.version_info >= (3, 9):
484+
task.cancel(f"Cancelled by cancel scope {id(origin):x}")
485+
else:
486+
task.cancel()
487+
488+
# Deliver cancellation to child scopes that aren't shielded or running their own
489+
# cancellation callbacks
490+
for scope in self._child_scopes:
491+
if not scope._shield and not scope.cancel_called:
492+
should_retry = scope._deliver_cancellation(origin) or should_retry
487493

488494
# Schedule another callback if there are still tasks left
489-
if should_retry:
490-
self._cancel_handle = get_running_loop().call_soon(
491-
self._deliver_cancellation
492-
)
493-
else:
494-
self._cancel_handle = None
495+
if origin is self:
496+
if should_retry:
497+
self._cancel_handle = get_running_loop().call_soon(
498+
self._deliver_cancellation, origin
499+
)
500+
else:
501+
self._cancel_handle = None
502+
503+
return should_retry
504+
505+
def _restart_cancellation_in_parent(self) -> None:
506+
"""
507+
Restart the cancellation effort in the closest directly cancelled parent scope.
495508
496-
def _deliver_cancellation_to_parent(self) -> None:
497-
"""Start cancellation effort in the farthest directly cancelled parent scope"""
509+
"""
498510
scope = self._parent_scope
499-
scope_to_cancel: CancelScope | None = None
500511
while scope is not None:
501-
if scope._cancel_called and scope._cancel_handle is None:
502-
scope_to_cancel = scope
512+
if scope._cancel_called:
513+
if scope._cancel_handle is None:
514+
scope._deliver_cancellation(scope)
515+
516+
break
503517

504518
# No point in looking beyond any shielded scope
505519
if scope._shield:
506520
break
507521

508522
scope = scope._parent_scope
509523

510-
if scope_to_cancel is not None:
511-
scope_to_cancel._deliver_cancellation()
512-
513524
def _parent_cancelled(self) -> bool:
514525
# Check whether any parent has been cancelled
515526
cancel_scope = self._parent_scope
@@ -529,7 +540,7 @@ def cancel(self) -> None:
529540

530541
self._cancel_called = True
531542
if self._host_task is not None:
532-
self._deliver_cancellation()
543+
self._deliver_cancellation(self)
533544

534545
@property
535546
def deadline(self) -> float:
@@ -562,7 +573,7 @@ def shield(self, value: bool) -> None:
562573
if self._shield != value:
563574
self._shield = value
564575
if not value:
565-
self._deliver_cancellation_to_parent()
576+
self._restart_cancellation_in_parent()
566577

567578

568579
#
@@ -623,6 +634,7 @@ def __init__(self) -> None:
623634
self.cancel_scope: CancelScope = CancelScope()
624635
self._active = False
625636
self._exceptions: list[BaseException] = []
637+
self._tasks: set[asyncio.Task] = set()
626638

627639
async def __aenter__(self) -> TaskGroup:
628640
self.cancel_scope.__enter__()
@@ -642,9 +654,9 @@ async def __aexit__(
642654
self._exceptions.append(exc_val)
643655

644656
cancelled_exc_while_waiting_tasks: CancelledError | None = None
645-
while self.cancel_scope._tasks:
657+
while self._tasks:
646658
try:
647-
await asyncio.wait(self.cancel_scope._tasks)
659+
await asyncio.wait(self._tasks)
648660
except CancelledError as exc:
649661
# This task was cancelled natively; reraise the CancelledError later
650662
# unless this task was already interrupted by another exception
@@ -676,8 +688,11 @@ def _spawn(
676688
task_status_future: asyncio.Future | None = None,
677689
) -> asyncio.Task:
678690
def task_done(_task: asyncio.Task) -> None:
679-
assert _task in self.cancel_scope._tasks
680-
self.cancel_scope._tasks.remove(_task)
691+
task_state = _task_states[_task]
692+
assert task_state.cancel_scope is not None
693+
assert _task in task_state.cancel_scope._tasks
694+
task_state.cancel_scope._tasks.remove(_task)
695+
self._tasks.remove(task)
681696
del _task_states[_task]
682697

683698
try:
@@ -693,7 +708,8 @@ def task_done(_task: asyncio.Task) -> None:
693708
if not isinstance(exc, CancelledError):
694709
self._exceptions.append(exc)
695710

696-
self.cancel_scope.cancel()
711+
if not self.cancel_scope._parent_cancelled():
712+
self.cancel_scope.cancel()
697713
else:
698714
task_status_future.set_exception(exc)
699715
elif task_status_future is not None and not task_status_future.done():
@@ -732,6 +748,7 @@ def task_done(_task: asyncio.Task) -> None:
732748
parent_id=parent_id, cancel_scope=self.cancel_scope
733749
)
734750
self.cancel_scope._tasks.add(task)
751+
self._tasks.add(task)
735752
return task
736753

737754
def start_soon(

tests/test_taskgroups.py

+44-1
Original file line numberDiff line numberDiff line change
@@ -844,7 +844,7 @@ async def killer(scope: CancelScope) -> None:
844844
assert isinstance(exc.value.exceptions[0], TimeoutError)
845845

846846

847-
async def test_triple_nested_shield() -> None:
847+
async def test_triple_nested_shield_checkpoint_in_outer() -> None:
848848
"""Regression test for #370."""
849849

850850
got_past_checkpoint = False
@@ -867,6 +867,26 @@ async def taskfunc() -> None:
867867
assert not got_past_checkpoint
868868

869869

870+
async def test_triple_nested_shield_checkpoint_in_middle() -> None:
871+
got_past_checkpoint = False
872+
873+
async def taskfunc() -> None:
874+
nonlocal got_past_checkpoint
875+
876+
with CancelScope() as scope1:
877+
with CancelScope():
878+
with CancelScope(shield=True):
879+
scope1.cancel()
880+
881+
await checkpoint()
882+
got_past_checkpoint = True
883+
884+
async with create_task_group() as tg:
885+
tg.start_soon(taskfunc)
886+
887+
assert not got_past_checkpoint
888+
889+
870890
def test_task_group_in_generator(
871891
anyio_backend_name: str, anyio_backend_options: dict[str, Any]
872892
) -> None:
@@ -1293,6 +1313,29 @@ def handler(excgrp: BaseExceptionGroup) -> None:
12931313
await anyio.sleep_forever()
12941314

12951315

1316+
async def test_cancel_child_task_when_host_is_shielded() -> None:
1317+
# Regression test for #642
1318+
# Tests that cancellation propagates to a child task even if the host task is within
1319+
# a shielded cancel scope.
1320+
cancelled = anyio.Event()
1321+
1322+
async def wait_cancel() -> None:
1323+
try:
1324+
await anyio.sleep_forever()
1325+
except anyio.get_cancelled_exc_class():
1326+
cancelled.set()
1327+
raise
1328+
1329+
with CancelScope() as parent_scope:
1330+
async with anyio.create_task_group() as task_group:
1331+
task_group.start_soon(wait_cancel)
1332+
await wait_all_tasks_blocked()
1333+
1334+
with CancelScope(shield=True), fail_after(1):
1335+
parent_scope.cancel()
1336+
await cancelled.wait()
1337+
1338+
12961339
class TestTaskStatusTyping:
12971340
"""
12981341
These tests do not do anything at run time, but since the test suite is also checked

0 commit comments

Comments
 (0)