@@ -111,7 +111,7 @@ def __init__(
111
111
self .recv_events_thread .start ()
112
112
113
113
# Mapping of ping IDs to pong waiters, in chronological order.
114
- self .pong_waiters : dict [bytes , tuple [threading .Event , float ]] = {}
114
+ self .pong_waiters : dict [bytes , tuple [threading .Event , float , bool ]] = {}
115
115
116
116
self .latency : float = 0
117
117
"""
@@ -554,7 +554,11 @@ def close(self, code: int = CloseCode.NORMAL_CLOSURE, reason: str = "") -> None:
554
554
# They mean that the connection is closed, which was the goal.
555
555
pass
556
556
557
- def ping (self , data : Data | None = None ) -> threading .Event :
557
+ def ping (
558
+ self ,
559
+ data : Data | None = None ,
560
+ ack_on_close : bool = False ,
561
+ ) -> threading .Event :
558
562
"""
559
563
Send a Ping_.
560
564
@@ -566,6 +570,12 @@ def ping(self, data: Data | None = None) -> threading.Event:
566
570
Args:
567
571
data: Payload of the ping. A :class:`str` will be encoded to UTF-8.
568
572
If ``data`` is :obj:`None`, the payload is four random bytes.
573
+ ack_on_close: when this option is :obj:`True`, the event will also
574
+ be set when the connection is closed. While this avoids getting
575
+ stuck waiting for a pong that will never arrive, it requires
576
+ checking that the state of the connection is still ``OPEN`` to
577
+ confirm that a pong was received, rather than the connection
578
+ being closed.
569
579
570
580
Returns:
571
581
An event that will be set when the corresponding pong is received.
@@ -599,7 +609,7 @@ def ping(self, data: Data | None = None) -> threading.Event:
599
609
data = struct .pack ("!I" , random .getrandbits (32 ))
600
610
601
611
pong_waiter = threading .Event ()
602
- self .pong_waiters [data ] = (pong_waiter , time .monotonic ())
612
+ self .pong_waiters [data ] = (pong_waiter , time .monotonic (), ack_on_close )
603
613
self .protocol .send_ping (data )
604
614
return pong_waiter
605
615
@@ -660,7 +670,11 @@ def acknowledge_pings(self, data: bytes) -> None:
660
670
# Acknowledge all previous pings too in that case.
661
671
ping_id = None
662
672
ping_ids = []
663
- for ping_id , (pong_waiter , ping_timestamp ) in self .pong_waiters .items ():
673
+ for ping_id , (
674
+ pong_waiter ,
675
+ ping_timestamp ,
676
+ ack_on_close ,
677
+ ) in self .pong_waiters .items ():
664
678
ping_ids .append (ping_id )
665
679
pong_waiter .set ()
666
680
if ping_id == data :
@@ -673,6 +687,19 @@ def acknowledge_pings(self, data: bytes) -> None:
673
687
for ping_id in ping_ids :
674
688
del self .pong_waiters [ping_id ]
675
689
690
+ def acknowledge_pending_pings (self ) -> None :
691
+ """
692
+ Acknowledge pending pings when the connection is closed.
693
+
694
+ """
695
+ assert self .protocol .state is CLOSED
696
+
697
+ for pong_waiter , _ping_timestamp , ack_on_close in self .pong_waiters .values ():
698
+ if ack_on_close :
699
+ pong_waiter .set ()
700
+
701
+ self .pong_waiters .clear ()
702
+
676
703
def recv_events (self ) -> None :
677
704
"""
678
705
Read incoming data from the socket and process events.
@@ -944,3 +971,6 @@ def close_socket(self) -> None:
944
971
945
972
# Abort recv() with a ConnectionClosed exception.
946
973
self .recv_messages .close ()
974
+
975
+ # Acknowledge pings sent with the ack_on_close option.
976
+ self .acknowledge_pending_pings ()
0 commit comments