Skip to content

Commit 9f345ee

Browse files
committed
Add option to set pong waiters on connection close.
1 parent 0fac382 commit 9f345ee

File tree

2 files changed

+43
-4
lines changed

2 files changed

+43
-4
lines changed

src/websockets/sync/connection.py

+34-4
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def __init__(
111111
self.recv_events_thread.start()
112112

113113
# Mapping of ping IDs to pong waiters, in chronological order.
114-
self.pong_waiters: dict[bytes, tuple[threading.Event, float]] = {}
114+
self.pong_waiters: dict[bytes, tuple[threading.Event, float, bool]] = {}
115115

116116
self.latency: float = 0
117117
"""
@@ -554,7 +554,11 @@ def close(self, code: int = CloseCode.NORMAL_CLOSURE, reason: str = "") -> None:
554554
# They mean that the connection is closed, which was the goal.
555555
pass
556556

557-
def ping(self, data: Data | None = None) -> threading.Event:
557+
def ping(
558+
self,
559+
data: Data | None = None,
560+
ack_on_close: bool = False,
561+
) -> threading.Event:
558562
"""
559563
Send a Ping_.
560564
@@ -566,6 +570,12 @@ def ping(self, data: Data | None = None) -> threading.Event:
566570
Args:
567571
data: Payload of the ping. A :class:`str` will be encoded to UTF-8.
568572
If ``data`` is :obj:`None`, the payload is four random bytes.
573+
ack_on_close: when this option is :obj:`True`, the event will also
574+
be set when the connection is closed. While this avoids getting
575+
stuck waiting for a pong that will never arrive, it requires
576+
checking that the state of the connection is still ``OPEN`` to
577+
confirm that a pong was received, rather than the connection
578+
being closed.
569579
570580
Returns:
571581
An event that will be set when the corresponding pong is received.
@@ -599,7 +609,7 @@ def ping(self, data: Data | None = None) -> threading.Event:
599609
data = struct.pack("!I", random.getrandbits(32))
600610

601611
pong_waiter = threading.Event()
602-
self.pong_waiters[data] = (pong_waiter, time.monotonic())
612+
self.pong_waiters[data] = (pong_waiter, time.monotonic(), ack_on_close)
603613
self.protocol.send_ping(data)
604614
return pong_waiter
605615

@@ -660,7 +670,11 @@ def acknowledge_pings(self, data: bytes) -> None:
660670
# Acknowledge all previous pings too in that case.
661671
ping_id = None
662672
ping_ids = []
663-
for ping_id, (pong_waiter, ping_timestamp) in self.pong_waiters.items():
673+
for ping_id, (
674+
pong_waiter,
675+
ping_timestamp,
676+
ack_on_close,
677+
) in self.pong_waiters.items():
664678
ping_ids.append(ping_id)
665679
pong_waiter.set()
666680
if ping_id == data:
@@ -673,6 +687,19 @@ def acknowledge_pings(self, data: bytes) -> None:
673687
for ping_id in ping_ids:
674688
del self.pong_waiters[ping_id]
675689

690+
def acknowledge_pending_pings(self) -> None:
691+
"""
692+
Acknowledge pending pings when the connection is closed.
693+
694+
"""
695+
assert self.protocol.state is CLOSED
696+
697+
for pong_waiter, _ping_timestamp, ack_on_close in self.pong_waiters.values():
698+
if ack_on_close:
699+
pong_waiter.set()
700+
701+
self.pong_waiters.clear()
702+
676703
def recv_events(self) -> None:
677704
"""
678705
Read incoming data from the socket and process events.
@@ -944,3 +971,6 @@ def close_socket(self) -> None:
944971

945972
# Abort recv() with a ConnectionClosed exception.
946973
self.recv_messages.close()
974+
975+
# Acknowledge pings sent with the ack_on_close option.
976+
self.acknowledge_pending_pings()

tests/sync/test_connection.py

+9
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,15 @@ def test_acknowledge_previous_ping(self):
685685
self.remote_connection.pong("that")
686686
self.assertTrue(pong_waiter.wait(MS))
687687

688+
def test_acknowledge_ping_on_close(self):
689+
"""ping with ack_on_close is acknowledged when the connection is closed."""
690+
with self.drop_frames_rcvd(): # drop automatic response to ping
691+
pong_waiter_ack_on_close = self.connection.ping("this", ack_on_close=True)
692+
pong_waiter = self.connection.ping("that")
693+
self.connection.close()
694+
self.assertTrue(pong_waiter_ack_on_close.wait(MS))
695+
self.assertFalse(pong_waiter.wait(MS))
696+
688697
def test_ping_duplicate_payload(self):
689698
"""ping rejects the same payload until receiving the pong."""
690699
with self.drop_frames_rcvd(): # drop automatic response to ping

0 commit comments

Comments
 (0)