@@ -343,6 +343,7 @@ def __init__(self, deadline: float = math.inf, shield: bool = False):
343
343
self ._deadline = deadline
344
344
self ._shield = shield
345
345
self ._parent_scope : CancelScope | None = None
346
+ self ._child_scopes : set [CancelScope ] = set ()
346
347
self ._cancel_called = False
347
348
self ._cancelled_caught = False
348
349
self ._active = False
@@ -369,6 +370,9 @@ def __enter__(self) -> CancelScope:
369
370
else :
370
371
self ._parent_scope = task_state .cancel_scope
371
372
task_state .cancel_scope = self
373
+ if self ._parent_scope is not None :
374
+ self ._parent_scope ._child_scopes .add (self )
375
+ self ._parent_scope ._tasks .remove (host_task )
372
376
373
377
self ._timeout ()
374
378
self ._active = True
@@ -377,7 +381,7 @@ def __enter__(self) -> CancelScope:
377
381
378
382
# Start cancelling the host task if the scope was cancelled before entering
379
383
if self ._cancel_called :
380
- self ._deliver_cancellation ()
384
+ self ._deliver_cancellation (self )
381
385
382
386
return self
383
387
@@ -409,13 +413,15 @@ def __exit__(
409
413
self ._timeout_handle = None
410
414
411
415
self ._tasks .remove (self ._host_task )
416
+ if self ._parent_scope is not None :
417
+ self ._parent_scope ._child_scopes .remove (self )
418
+ self ._parent_scope ._tasks .add (self ._host_task )
412
419
413
420
host_task_state .cancel_scope = self ._parent_scope
414
421
415
- # Restart the cancellation effort in the farthest directly cancelled parent
422
+ # Restart the cancellation effort in the closest directly cancelled parent
416
423
# scope if this one was shielded
417
- if self ._shield :
418
- self ._deliver_cancellation_to_parent ()
424
+ self ._restart_cancellation_in_parent ()
419
425
420
426
if self ._cancel_called and exc_val is not None :
421
427
for exc in iterate_exceptions (exc_val ):
@@ -451,65 +457,70 @@ def _timeout(self) -> None:
451
457
else :
452
458
self ._timeout_handle = loop .call_at (self ._deadline , self ._timeout )
453
459
454
- def _deliver_cancellation (self ) -> None :
460
+ def _deliver_cancellation (self , origin : CancelScope ) -> bool :
455
461
"""
456
462
Deliver cancellation to directly contained tasks and nested cancel scopes.
457
463
458
464
Schedule another run at the end if we still have tasks eligible for
459
465
cancellation.
466
+
467
+ :param origin: the cancel scope that originated the cancellation
468
+ :return: ``True`` if the delivery needs to be retried on the next cycle
469
+
460
470
"""
461
471
should_retry = False
462
472
current = current_task ()
463
473
for task in self ._tasks :
464
474
if task ._must_cancel : # type: ignore[attr-defined]
465
475
continue
466
476
467
- # The task is eligible for cancellation if it has started and is not in a
468
- # cancel scope shielded from this one
469
- cancel_scope = _task_states [task ].cancel_scope
470
- while cancel_scope is not self :
471
- if cancel_scope is None or cancel_scope ._shield :
472
- break
473
- else :
474
- cancel_scope = cancel_scope ._parent_scope
475
- else :
476
- should_retry = True
477
- if task is not current and (
478
- task is self ._host_task or _task_started (task )
479
- ):
480
- waiter = task ._fut_waiter # type: ignore[attr-defined]
481
- if not isinstance (waiter , asyncio .Future ) or not waiter .done ():
482
- self ._cancel_calls += 1
483
- if sys .version_info >= (3 , 9 ):
484
- task .cancel (f"Cancelled by cancel scope { id (self ):x} " )
485
- else :
486
- task .cancel ()
477
+ # The task is eligible for cancellation if it has started
478
+ should_retry = True
479
+ if task is not current and (task is self ._host_task or _task_started (task )):
480
+ waiter = task ._fut_waiter # type: ignore[attr-defined]
481
+ if not isinstance (waiter , asyncio .Future ) or not waiter .done ():
482
+ self ._cancel_calls += 1
483
+ if sys .version_info >= (3 , 9 ):
484
+ task .cancel (f"Cancelled by cancel scope { id (origin ):x} " )
485
+ else :
486
+ task .cancel ()
487
+
488
+ # Deliver cancellation to child scopes that aren't shielded or running their own
489
+ # cancellation callbacks
490
+ for scope in self ._child_scopes :
491
+ if not scope ._shield and not scope .cancel_called :
492
+ should_retry = scope ._deliver_cancellation (origin ) or should_retry
487
493
488
494
# Schedule another callback if there are still tasks left
489
- if should_retry :
490
- self ._cancel_handle = get_running_loop ().call_soon (
491
- self ._deliver_cancellation
492
- )
493
- else :
494
- self ._cancel_handle = None
495
+ if origin is self :
496
+ if should_retry :
497
+ self ._cancel_handle = get_running_loop ().call_soon (
498
+ self ._deliver_cancellation , origin
499
+ )
500
+ else :
501
+ self ._cancel_handle = None
502
+
503
+ return should_retry
504
+
505
+ def _restart_cancellation_in_parent (self ) -> None :
506
+ """
507
+ Restart the cancellation effort in the closest directly cancelled parent scope.
495
508
496
- def _deliver_cancellation_to_parent (self ) -> None :
497
- """Start cancellation effort in the farthest directly cancelled parent scope"""
509
+ """
498
510
scope = self ._parent_scope
499
- scope_to_cancel : CancelScope | None = None
500
511
while scope is not None :
501
- if scope ._cancel_called and scope ._cancel_handle is None :
502
- scope_to_cancel = scope
512
+ if scope ._cancel_called :
513
+ if scope ._cancel_handle is None :
514
+ scope ._deliver_cancellation (scope )
515
+
516
+ break
503
517
504
518
# No point in looking beyond any shielded scope
505
519
if scope ._shield :
506
520
break
507
521
508
522
scope = scope ._parent_scope
509
523
510
- if scope_to_cancel is not None :
511
- scope_to_cancel ._deliver_cancellation ()
512
-
513
524
def _parent_cancelled (self ) -> bool :
514
525
# Check whether any parent has been cancelled
515
526
cancel_scope = self ._parent_scope
@@ -529,7 +540,7 @@ def cancel(self) -> None:
529
540
530
541
self ._cancel_called = True
531
542
if self ._host_task is not None :
532
- self ._deliver_cancellation ()
543
+ self ._deliver_cancellation (self )
533
544
534
545
@property
535
546
def deadline (self ) -> float :
@@ -562,7 +573,7 @@ def shield(self, value: bool) -> None:
562
573
if self ._shield != value :
563
574
self ._shield = value
564
575
if not value :
565
- self ._deliver_cancellation_to_parent ()
576
+ self ._restart_cancellation_in_parent ()
566
577
567
578
568
579
#
@@ -623,6 +634,7 @@ def __init__(self) -> None:
623
634
self .cancel_scope : CancelScope = CancelScope ()
624
635
self ._active = False
625
636
self ._exceptions : list [BaseException ] = []
637
+ self ._tasks : set [asyncio .Task ] = set ()
626
638
627
639
async def __aenter__ (self ) -> TaskGroup :
628
640
self .cancel_scope .__enter__ ()
@@ -642,9 +654,9 @@ async def __aexit__(
642
654
self ._exceptions .append (exc_val )
643
655
644
656
cancelled_exc_while_waiting_tasks : CancelledError | None = None
645
- while self .cancel_scope . _tasks :
657
+ while self ._tasks :
646
658
try :
647
- await asyncio .wait (self .cancel_scope . _tasks )
659
+ await asyncio .wait (self ._tasks )
648
660
except CancelledError as exc :
649
661
# This task was cancelled natively; reraise the CancelledError later
650
662
# unless this task was already interrupted by another exception
@@ -676,8 +688,11 @@ def _spawn(
676
688
task_status_future : asyncio .Future | None = None ,
677
689
) -> asyncio .Task :
678
690
def task_done (_task : asyncio .Task ) -> None :
679
- assert _task in self .cancel_scope ._tasks
680
- self .cancel_scope ._tasks .remove (_task )
691
+ task_state = _task_states [_task ]
692
+ assert task_state .cancel_scope is not None
693
+ assert _task in task_state .cancel_scope ._tasks
694
+ task_state .cancel_scope ._tasks .remove (_task )
695
+ self ._tasks .remove (task )
681
696
del _task_states [_task ]
682
697
683
698
try :
@@ -693,7 +708,8 @@ def task_done(_task: asyncio.Task) -> None:
693
708
if not isinstance (exc , CancelledError ):
694
709
self ._exceptions .append (exc )
695
710
696
- self .cancel_scope .cancel ()
711
+ if not self .cancel_scope ._parent_cancelled ():
712
+ self .cancel_scope .cancel ()
697
713
else :
698
714
task_status_future .set_exception (exc )
699
715
elif task_status_future is not None and not task_status_future .done ():
@@ -732,6 +748,7 @@ def task_done(_task: asyncio.Task) -> None:
732
748
parent_id = parent_id , cancel_scope = self .cancel_scope
733
749
)
734
750
self .cancel_scope ._tasks .add (task )
751
+ self ._tasks .add (task )
735
752
return task
736
753
737
754
def start_soon (
0 commit comments