From 7a8b757c56777509bf183bf973407de943d0d973 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Mon, 20 Jan 2025 19:50:41 +0100 Subject: [PATCH 1/4] Rename and reorder for consistency. --- src/websockets/asyncio/connection.py | 16 ++++++++-------- src/websockets/sync/connection.py | 24 ++++++++++++------------ 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index e2e587e7c..91bc0dda5 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -101,6 +101,14 @@ def __init__( # Protect sending fragmented messages. self.fragmented_send_waiter: asyncio.Future[None] | None = None + # Exception raised while reading from the connection, to be chained to + # ConnectionClosed in order to show why the TCP connection dropped. + self.recv_exc: BaseException | None = None + + # Completed when the TCP connection is closed and the WebSocket + # connection state becomes CLOSED. + self.connection_lost_waiter: asyncio.Future[None] = self.loop.create_future() + # Mapping of ping IDs to pong waiters, in chronological order. self.pong_waiters: dict[bytes, tuple[asyncio.Future[float], float]] = {} @@ -120,14 +128,6 @@ def __init__( # Task that sends keepalive pings. None when ping_interval is None. self.keepalive_task: asyncio.Task[None] | None = None - # Exception raised while reading from the connection, to be chained to - # ConnectionClosed in order to show why the TCP connection dropped. - self.recv_exc: BaseException | None = None - - # Completed when the TCP connection is closed and the WebSocket - # connection state becomes CLOSED. - self.connection_lost_waiter: asyncio.Future[None] = self.loop.create_future() - # Adapted from asyncio.FlowControlMixin self.paused: bool = False self.drain_waiters: collections.deque[asyncio.Future[None]] = ( diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 06ea00efc..653310c35 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -90,14 +90,11 @@ def __init__( resume=self.recv_flow_control.release, ) - # Whether we are busy sending a fragmented message. - self.send_in_progress = False - # Deadline for the closing handshake. self.close_deadline: Deadline | None = None - # Mapping of ping IDs to pong waiters, in chronological order. - self.ping_waiters: dict[bytes, threading.Event] = {} + # Whether we are busy sending a fragmented message. + self.send_in_progress = False # Exception raised in recv_events, to be chained to ConnectionClosed # in the user thread in order to show why the TCP connection dropped. @@ -112,6 +109,9 @@ def __init__( ) self.recv_events_thread.start() + # Mapping of ping IDs to pong waiters, in chronological order. + self.pong_waiters: dict[bytes, threading.Event] = {} + # Public attributes @property @@ -581,15 +581,15 @@ def ping(self, data: Data | None = None) -> threading.Event: with self.send_context(): # Protect against duplicates if a payload is explicitly set. - if data in self.ping_waiters: + if data in self.pong_waiters: raise ConcurrencyError("already waiting for a pong with the same data") # Generate a unique random payload otherwise. - while data is None or data in self.ping_waiters: + while data is None or data in self.pong_waiters: data = struct.pack("!I", random.getrandbits(32)) pong_waiter = threading.Event() - self.ping_waiters[data] = pong_waiter + self.pong_waiters[data] = pong_waiter self.protocol.send_ping(data) return pong_waiter @@ -641,22 +641,22 @@ def acknowledge_pings(self, data: bytes) -> None: """ with self.protocol_mutex: # Ignore unsolicited pong. - if data not in self.ping_waiters: + if data not in self.pong_waiters: return # Sending a pong for only the most recent ping is legal. # Acknowledge all previous pings too in that case. ping_id = None ping_ids = [] - for ping_id, ping in self.ping_waiters.items(): + for ping_id, ping in self.pong_waiters.items(): ping_ids.append(ping_id) ping.set() if ping_id == data: break else: raise AssertionError("solicited pong not found in pings") - # Remove acknowledged pings from self.ping_waiters. + # Remove acknowledged pings from self.pong_waiters. for ping_id in ping_ids: - del self.ping_waiters[ping_id] + del self.pong_waiters[ping_id] def recv_events(self) -> None: """ From 0fac3829353905c2079e05c36958a294b2b49cf1 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 21 Jan 2025 22:35:08 +0100 Subject: [PATCH 2/4] Add latency measurement to the threading implementation. --- docs/project/changelog.rst | 5 +++++ docs/reference/features.rst | 3 +-- docs/reference/sync/client.rst | 2 ++ docs/reference/sync/common.rst | 2 ++ docs/reference/sync/server.rst | 2 ++ src/websockets/sync/connection.py | 23 +++++++++++++++++++---- 6 files changed, 31 insertions(+), 6 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 4ad8f5532..867231241 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -32,6 +32,11 @@ notice. *In development* +New features +............ + +* Added latency measurement to the :mod:`threading` implementation. + .. _14.2: 14.2 diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 9187fa505..1135bf829 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -65,10 +65,9 @@ Both sides +------------------------------------+--------+--------+--------+--------+ | Heartbeat | ✅ | ❌ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Measure latency | ✅ | ❌ | — | ✅ | + | Measure latency | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Perform the closing handshake | ✅ | ✅ | ✅ | ✅ | - +------------------------------------+--------+--------+--------+--------+ | Enforce closing timeout | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Report close codes and reasons | ✅ | ✅ | ✅ | ❌ | diff --git a/docs/reference/sync/client.rst b/docs/reference/sync/client.rst index 2aa491f6a..89316c997 100644 --- a/docs/reference/sync/client.rst +++ b/docs/reference/sync/client.rst @@ -39,6 +39,8 @@ Using a connection .. autoproperty:: remote_address + .. autoproperty:: latency + .. autoproperty:: state The following attributes are available after the opening handshake, diff --git a/docs/reference/sync/common.rst b/docs/reference/sync/common.rst index 3c03b25b6..d44ff55b6 100644 --- a/docs/reference/sync/common.rst +++ b/docs/reference/sync/common.rst @@ -31,6 +31,8 @@ Both sides (:mod:`threading`) .. autoproperty:: remote_address + .. autoattribute:: latency + .. autoproperty:: state The following attributes are available after the opening handshake, diff --git a/docs/reference/sync/server.rst b/docs/reference/sync/server.rst index 1d80450f9..c3d0e8f25 100644 --- a/docs/reference/sync/server.rst +++ b/docs/reference/sync/server.rst @@ -52,6 +52,8 @@ Using a connection .. autoproperty:: remote_address + .. autoproperty:: latency + .. autoproperty:: state The following attributes are available after the opening handshake, diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 653310c35..b0fbf45b6 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -6,6 +6,7 @@ import socket import struct import threading +import time import uuid from collections.abc import Iterable, Iterator, Mapping from types import TracebackType @@ -110,7 +111,16 @@ def __init__( self.recv_events_thread.start() # Mapping of ping IDs to pong waiters, in chronological order. - self.pong_waiters: dict[bytes, threading.Event] = {} + self.pong_waiters: dict[bytes, tuple[threading.Event, float]] = {} + + self.latency: float = 0 + """ + Latency of the connection, in seconds. + + Latency is defined as the round-trip time of the connection. It is + measured by sending a Ping frame and waiting for a matching Pong frame. + Before the first measurement, :attr:`latency` is ``0``. + """ # Public attributes @@ -589,7 +599,7 @@ def ping(self, data: Data | None = None) -> threading.Event: data = struct.pack("!I", random.getrandbits(32)) pong_waiter = threading.Event() - self.pong_waiters[data] = pong_waiter + self.pong_waiters[data] = (pong_waiter, time.monotonic()) self.protocol.send_ping(data) return pong_waiter @@ -643,17 +653,22 @@ def acknowledge_pings(self, data: bytes) -> None: # Ignore unsolicited pong. if data not in self.pong_waiters: return + + pong_timestamp = time.monotonic() + # Sending a pong for only the most recent ping is legal. # Acknowledge all previous pings too in that case. ping_id = None ping_ids = [] - for ping_id, ping in self.pong_waiters.items(): + for ping_id, (pong_waiter, ping_timestamp) in self.pong_waiters.items(): ping_ids.append(ping_id) - ping.set() + pong_waiter.set() if ping_id == data: + self.latency = pong_timestamp - ping_timestamp break else: raise AssertionError("solicited pong not found in pings") + # Remove acknowledged pings from self.pong_waiters. for ping_id in ping_ids: del self.pong_waiters[ping_id] From fc7b151fdfbc092a8d2062ef522d374074153cfe Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 21 Jan 2025 22:32:30 +0100 Subject: [PATCH 3/4] Add option to set pong waiters on connection close. --- src/websockets/sync/connection.py | 38 +++++++++++++++++++++++++++---- tests/sync/test_connection.py | 9 ++++++++ 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index b0fbf45b6..5270c1fab 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -111,7 +111,7 @@ def __init__( self.recv_events_thread.start() # Mapping of ping IDs to pong waiters, in chronological order. - self.pong_waiters: dict[bytes, tuple[threading.Event, float]] = {} + self.pong_waiters: dict[bytes, tuple[threading.Event, float, bool]] = {} self.latency: float = 0 """ @@ -554,7 +554,11 @@ def close(self, code: int = CloseCode.NORMAL_CLOSURE, reason: str = "") -> None: # They mean that the connection is closed, which was the goal. pass - def ping(self, data: Data | None = None) -> threading.Event: + def ping( + self, + data: Data | None = None, + ack_on_close: bool = False, + ) -> threading.Event: """ Send a Ping_. @@ -566,6 +570,12 @@ def ping(self, data: Data | None = None) -> threading.Event: Args: data: Payload of the ping. A :class:`str` will be encoded to UTF-8. If ``data`` is :obj:`None`, the payload is four random bytes. + ack_on_close: when this option is :obj:`True`, the event will also + be set when the connection is closed. While this avoids getting + stuck waiting for a pong that will never arrive, it requires + checking that the state of the connection is still ``OPEN`` to + confirm that a pong was received, rather than the connection + being closed. Returns: An event that will be set when the corresponding pong is received. @@ -599,7 +609,7 @@ def ping(self, data: Data | None = None) -> threading.Event: data = struct.pack("!I", random.getrandbits(32)) pong_waiter = threading.Event() - self.pong_waiters[data] = (pong_waiter, time.monotonic()) + self.pong_waiters[data] = (pong_waiter, time.monotonic(), ack_on_close) self.protocol.send_ping(data) return pong_waiter @@ -660,7 +670,11 @@ def acknowledge_pings(self, data: bytes) -> None: # Acknowledge all previous pings too in that case. ping_id = None ping_ids = [] - for ping_id, (pong_waiter, ping_timestamp) in self.pong_waiters.items(): + for ping_id, ( + pong_waiter, + ping_timestamp, + _ack_on_close, + ) in self.pong_waiters.items(): ping_ids.append(ping_id) pong_waiter.set() if ping_id == data: @@ -673,6 +687,19 @@ def acknowledge_pings(self, data: bytes) -> None: for ping_id in ping_ids: del self.pong_waiters[ping_id] + def acknowledge_pending_pings(self) -> None: + """ + Acknowledge pending pings when the connection is closed. + + """ + assert self.protocol.state is CLOSED + + for pong_waiter, _ping_timestamp, ack_on_close in self.pong_waiters.values(): + if ack_on_close: + pong_waiter.set() + + self.pong_waiters.clear() + def recv_events(self) -> None: """ Read incoming data from the socket and process events. @@ -944,3 +971,6 @@ def close_socket(self) -> None: # Abort recv() with a ConnectionClosed exception. self.recv_messages.close() + + # Acknowledge pings sent with the ack_on_close option. + self.acknowledge_pending_pings() diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index aa445498c..ee4aec5f6 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -685,6 +685,15 @@ def test_acknowledge_previous_ping(self): self.remote_connection.pong("that") self.assertTrue(pong_waiter.wait(MS)) + def test_acknowledge_ping_on_close(self): + """ping with ack_on_close is acknowledged when the connection is closed.""" + with self.drop_frames_rcvd(): # drop automatic response to ping + pong_waiter_ack_on_close = self.connection.ping("this", ack_on_close=True) + pong_waiter = self.connection.ping("that") + self.connection.close() + self.assertTrue(pong_waiter_ack_on_close.wait(MS)) + self.assertFalse(pong_waiter.wait(MS)) + def test_ping_duplicate_payload(self): """ping rejects the same payload until receiving the pong.""" with self.drop_frames_rcvd(): # drop automatic response to ping From 8f12d8fd16d8945cbb7ad2d37962c47c16cb804f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Tue, 21 Jan 2025 22:21:31 +0100 Subject: [PATCH 4/4] Add keepalive to the threading implementation. --- docs/project/changelog.rst | 3 +- docs/reference/features.rst | 4 +- docs/topics/keepalive.rst | 5 -- src/websockets/asyncio/connection.py | 17 +++-- src/websockets/sync/client.py | 17 ++++- src/websockets/sync/connection.py | 63 +++++++++++++++ src/websockets/sync/server.py | 17 ++++- tests/asyncio/test_connection.py | 26 +++---- tests/sync/test_client.py | 15 ++++ tests/sync/test_connection.py | 110 +++++++++++++++++++++++++++ tests/sync/test_server.py | 21 +++++ 11 files changed, 267 insertions(+), 31 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 867231241..67c16ba9e 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -35,7 +35,8 @@ notice. New features ............ -* Added latency measurement to the :mod:`threading` implementation. +* Added :doc:`keepalive and latency measurement <../topics/keepalive>` to the + :mod:`threading` implementation. .. _14.2: diff --git a/docs/reference/features.rst b/docs/reference/features.rst index 1135bf829..6ba42f66b 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -61,9 +61,9 @@ Both sides +------------------------------------+--------+--------+--------+--------+ | Send a pong | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Keepalive | ✅ | ❌ | — | ✅ | + | Keepalive | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Heartbeat | ✅ | ❌ | — | ✅ | + | Heartbeat | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Measure latency | ✅ | ✅ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ diff --git a/docs/topics/keepalive.rst b/docs/topics/keepalive.rst index a0467ced2..e63c2f8f5 100644 --- a/docs/topics/keepalive.rst +++ b/docs/topics/keepalive.rst @@ -1,11 +1,6 @@ Keepalive and latency ===================== -.. admonition:: This guide applies only to the :mod:`asyncio` implementation. - :class: tip - - The :mod:`threading` implementation doesn't provide keepalive yet. - .. currentmodule:: websockets Long-lived connections diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 91bc0dda5..75c43fa8a 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -686,8 +686,7 @@ async def ping(self, data: Data | None = None) -> Awaitable[float]: pong_waiter = self.loop.create_future() # The event loop's default clock is time.monotonic(). Its resolution # is a bit low on Windows (~16ms). This is improved in Python 3.13. - ping_timestamp = self.loop.time() - self.pong_waiters[data] = (pong_waiter, ping_timestamp) + self.pong_waiters[data] = (pong_waiter, self.loop.time()) self.protocol.send_ping(data) return pong_waiter @@ -792,13 +791,19 @@ async def keepalive(self) -> None: latency = 0.0 try: while True: - # If self.ping_timeout > latency > self.ping_interval, pings - # will be sent immediately after receiving pongs. The period - # will be longer than self.ping_interval. + # If self.ping_timeout > latency > self.ping_interval, + # pings will be sent immediately after receiving pongs. + # The period will be longer than self.ping_interval. await asyncio.sleep(self.ping_interval - latency) - self.logger.debug("% sending keepalive ping") + # This cannot raise ConnectionClosed when the connection is + # closing because ping(), via send_context(), waits for the + # connection to be closed before raising ConnectionClosed. + # However, connection_lost() cancels keepalive_task before + # it gets a chance to resume excuting. pong_waiter = await self.ping() + if self.debug: + self.logger.debug("% sent keepalive ping") if self.ping_timeout is not None: try: diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 9e6da7caf..8325721b7 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -40,8 +40,8 @@ class ClientConnection(Connection): :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. - The ``close_timeout`` and ``max_queue`` arguments have the same meaning as - in :func:`connect`. + The ``ping_interval``, ``ping_timeout``, ``close_timeout``, and + ``max_queue`` arguments have the same meaning as in :func:`connect`. Args: socket: Socket connected to a WebSocket server. @@ -54,6 +54,8 @@ def __init__( socket: socket.socket, protocol: ClientProtocol, *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: @@ -62,6 +64,8 @@ def __init__( super().__init__( socket, protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, ) @@ -136,6 +140,8 @@ def connect( compression: str | None = "deflate", # Timeouts open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, @@ -184,6 +190,10 @@ def connect( :doc:`compression guide <../../topics/compression>` for details. open_timeout: Timeout for opening the connection in seconds. :obj:`None` disables the timeout. + ping_interval: Interval between keepalive pings in seconds. + :obj:`None` disables keepalive. + ping_timeout: Timeout for keepalive pings in seconds. + :obj:`None` disables timeouts. close_timeout: Timeout for closing the connection in seconds. :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. @@ -296,6 +306,8 @@ def connect( connection = create_connection( sock, protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, ) @@ -315,6 +327,7 @@ def connect( connection.recv_events_thread.join() raise + connection.start_keepalive() return connection diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 5270c1fab..07f0543e4 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -49,11 +49,15 @@ def __init__( socket: socket.socket, protocol: Protocol, *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: self.socket = socket self.protocol = protocol + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout self.close_timeout = close_timeout if isinstance(max_queue, int) or max_queue is None: max_queue = (max_queue, None) @@ -120,8 +124,15 @@ def __init__( Latency is defined as the round-trip time of the connection. It is measured by sending a Ping frame and waiting for a matching Pong frame. Before the first measurement, :attr:`latency` is ``0``. + + By default, websockets enables a :ref:`keepalive ` mechanism + that sends Ping frames automatically at regular intervals. You can also + send Ping frames and measure latency with :meth:`ping`. """ + # Thread that sends keepalive pings. None when ping_interval is None. + self.keepalive_thread: threading.Thread | None = None + # Public attributes @property @@ -700,6 +711,58 @@ def acknowledge_pending_pings(self) -> None: self.pong_waiters.clear() + def keepalive(self) -> None: + """ + Send a Ping frame and wait for a Pong frame at regular intervals. + + """ + assert self.ping_interval is not None + try: + while True: + # If self.ping_timeout > self.latency > self.ping_interval, + # pings will be sent immediately after receiving pongs. + # The period will be longer than self.ping_interval. + self.recv_events_thread.join(self.ping_interval - self.latency) + if not self.recv_events_thread.is_alive(): + break + + try: + pong_waiter = self.ping(ack_on_close=True) + except ConnectionClosed: + break + if self.debug: + self.logger.debug("% sent keepalive ping") + + if self.ping_timeout is not None: + # + if pong_waiter.wait(self.ping_timeout): + if self.debug: + self.logger.debug("% received keepalive pong") + else: + if self.debug: + self.logger.debug("- timed out waiting for keepalive pong") + with self.send_context(): + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "keepalive ping timeout", + ) + break + except Exception: + self.logger.error("keepalive ping failed", exc_info=True) + + def start_keepalive(self) -> None: + """ + Run :meth:`keepalive` in a thread, unless keepalive is disabled. + + """ + if self.ping_interval is not None: + # This thread is marked as daemon like self.recv_events_thread. + self.keepalive_thread = threading.Thread( + target=self.keepalive, + daemon=True, + ) + self.keepalive_thread.start() + def recv_events(self) -> None: """ Read incoming data from the socket and process events. diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 50a2f3c06..643f9b44b 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -52,8 +52,8 @@ class ServerConnection(Connection): :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. - The ``close_timeout`` and ``max_queue`` arguments have the same meaning as - in :func:`serve`. + The ``ping_interval``, ``ping_timeout``, ``close_timeout``, and + ``max_queue`` arguments have the same meaning as in :func:`serve`. Args: socket: Socket connected to a WebSocket client. @@ -66,6 +66,8 @@ def __init__( socket: socket.socket, protocol: ServerProtocol, *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, max_queue: int | None | tuple[int | None, int | None] = 16, ) -> None: @@ -74,6 +76,8 @@ def __init__( super().__init__( socket, protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, ) @@ -354,6 +358,8 @@ def serve( compression: str | None = "deflate", # Timeouts open_timeout: float | None = 10, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, @@ -434,6 +440,10 @@ def handler(websocket): :doc:`compression guide <../../topics/compression>` for details. open_timeout: Timeout for opening connections in seconds. :obj:`None` disables the timeout. + ping_interval: Interval between keepalive pings in seconds. + :obj:`None` disables keepalive. + ping_timeout: Timeout for keepalive pings in seconds. + :obj:`None` disables timeouts. close_timeout: Timeout for closing connections in seconds. :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. @@ -563,6 +573,8 @@ def protocol_select_subprotocol( connection = create_connection( sock, protocol, + ping_interval=ping_interval, + ping_timeout=ping_timeout, close_timeout=close_timeout, max_queue=max_queue, ) @@ -590,6 +602,7 @@ def protocol_select_subprotocol( assert connection.protocol.state is OPEN try: + connection.start_keepalive() handler(connection) except Exception: connection.logger.error("connection handler failed", exc_info=True) diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 788a457ed..b53c97030 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -1010,7 +1010,7 @@ async def test_keepalive_times_out(self, getrandbits): self.connection.start_keepalive() # 4 ms: keepalive() sends a ping frame. await asyncio.sleep(4 * MS) - # Exiting the context manager sleeps for MS. + # Exiting the context manager sleeps for 1 ms. # 4.x ms: a pong frame is dropped. # 6 ms: no pong frame is received; the connection is closed. await asyncio.sleep(2 * MS) @@ -1026,9 +1026,9 @@ async def test_keepalive_ignores_timeout(self, getrandbits): getrandbits.return_value = 1918987876 self.connection.start_keepalive() # 4 ms: keepalive() sends a ping frame. - await asyncio.sleep(4 * MS) - # Exiting the context manager sleeps for MS. # 4.x ms: a pong frame is dropped. + await asyncio.sleep(4 * MS) + # Exiting the context manager sleeps for 1 ms. # 6 ms: no pong frame is received; the connection remains open. await asyncio.sleep(2 * MS) # 7 ms: check that the connection is still open. @@ -1036,7 +1036,7 @@ async def test_keepalive_ignores_timeout(self, getrandbits): async def test_keepalive_terminates_while_sleeping(self): """keepalive task terminates while waiting to send a ping.""" - self.connection.ping_interval = 2 * MS + self.connection.ping_interval = 3 * MS self.connection.start_keepalive() await asyncio.sleep(MS) await self.connection.close() @@ -1044,15 +1044,15 @@ async def test_keepalive_terminates_while_sleeping(self): async def test_keepalive_terminates_while_waiting_for_pong(self): """keepalive task terminates while waiting to receive a pong.""" - self.connection.ping_interval = 2 * MS - self.connection.ping_timeout = 2 * MS + self.connection.ping_interval = MS + self.connection.ping_timeout = 3 * MS async with self.drop_frames_rcvd(): self.connection.start_keepalive() - # 2 ms: keepalive() sends a ping frame. - await asyncio.sleep(2 * MS) - # Exiting the context manager sleeps for MS. - # 2.x ms: a pong frame is dropped. - # 3 ms: close the connection before ping_timeout elapses. + # 1 ms: keepalive() sends a ping frame. + # 1.x ms: a pong frame is dropped. + await asyncio.sleep(MS) + # Exiting the context manager sleeps for 1 ms. + # 2 ms: close the connection before ping_timeout elapses. await self.connection.close() self.assertTrue(self.connection.keepalive_task.done()) @@ -1062,9 +1062,9 @@ async def test_keepalive_reports_errors(self): async with self.drop_frames_rcvd(): self.connection.start_keepalive() # 2 ms: keepalive() sends a ping frame. - await asyncio.sleep(2 * MS) - # Exiting the context manager sleeps for MS. # 2.x ms: a pong frame is dropped. + await asyncio.sleep(2 * MS) + # Exiting the context manager sleeps for 1 ms. # 3 ms: inject a fault: raise an exception in the pending pong waiter. pong_waiter = next(iter(self.connection.pong_waiters.values()))[0] with self.assertLogs("websockets", logging.ERROR) as logs: diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 7ab8f4dd4..1669a0e84 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -76,6 +76,21 @@ def test_disable_compression(self): with connect(get_uri(server), compression=None) as client: self.assertEqual(client.protocol.extensions, []) + def test_keepalive_is_enabled(self): + """Client enables keepalive and measures latency by default.""" + with run_server() as server: + with connect(get_uri(server), ping_interval=MS) as client: + self.assertEqual(client.latency, 0) + time.sleep(2 * MS) + self.assertGreater(client.latency, 0) + + def test_disable_keepalive(self): + """Client disables keepalive.""" + with run_server() as server: + with connect(get_uri(server), ping_interval=None) as client: + time.sleep(2 * MS) + self.assertEqual(client.latency, 0) + def test_logger(self): """Client accepts a logger argument.""" logger = logging.getLogger("test") diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index ee4aec5f6..f191d8944 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -738,6 +738,116 @@ def test_pong_unsupported_type(self): with self.assertRaises(TypeError): self.connection.pong([]) + # Test keepalive. + + @patch("random.getrandbits") + def test_keepalive(self, getrandbits): + """keepalive sends pings at ping_interval and measures latency.""" + self.connection.ping_interval = 2 * MS + getrandbits.return_value = 1918987876 + self.connection.start_keepalive() + self.assertEqual(self.connection.latency, 0) + # 2 ms: keepalive() sends a ping frame. + # 2.x ms: a pong frame is received. + time.sleep(3 * MS) + # 3 ms: check that the ping frame was sent. + self.assertFrameSent(Frame(Opcode.PING, b"rand")) + self.assertGreater(self.connection.latency, 0) + self.assertLess(self.connection.latency, MS) + + def test_disable_keepalive(self): + """keepalive is disabled when ping_interval is None.""" + self.connection.ping_interval = None + self.connection.start_keepalive() + time.sleep(3 * MS) + self.assertNoFrameSent() + + @patch("random.getrandbits") + def test_keepalive_times_out(self, getrandbits): + """keepalive closes the connection if ping_timeout elapses.""" + self.connection.ping_interval = 4 * MS + self.connection.ping_timeout = 2 * MS + with self.drop_frames_rcvd(): + getrandbits.return_value = 1918987876 + self.connection.start_keepalive() + # 4 ms: keepalive() sends a ping frame. + time.sleep(4 * MS) + # Exiting the context manager sleeps for 1 ms. + # 4.x ms: a pong frame is dropped. + # 6 ms: no pong frame is received; the connection is closed. + time.sleep(2 * MS) + # 7 ms: check that the connection is closed. + self.assertEqual(self.connection.state, State.CLOSED) + + @patch("random.getrandbits") + def test_keepalive_ignores_timeout(self, getrandbits): + """keepalive ignores timeouts if ping_timeout isn't set.""" + self.connection.ping_interval = 4 * MS + self.connection.ping_timeout = None + with self.drop_frames_rcvd(): + getrandbits.return_value = 1918987876 + self.connection.start_keepalive() + # 4 ms: keepalive() sends a ping frame. + time.sleep(4 * MS) + # Exiting the context manager sleeps for 1 ms. + # 4.x ms: a pong frame is dropped. + # 6 ms: no pong frame is received; the connection remains open. + time.sleep(2 * MS) + # 7 ms: check that the connection is still open. + self.assertEqual(self.connection.state, State.OPEN) + + def test_keepalive_terminates_while_sleeping(self): + """keepalive task terminates while waiting to send a ping.""" + self.connection.ping_interval = 3 * MS + self.connection.start_keepalive() + time.sleep(MS) + self.connection.close() + self.connection.keepalive_thread.join(MS) + self.assertFalse(self.connection.keepalive_thread.is_alive()) + + def test_keepalive_terminates_when_sending_ping_fails(self): + """keepalive task terminates when sending a ping fails.""" + self.connection.ping_interval = 1 * MS + self.connection.start_keepalive() + with self.drop_eof_rcvd(), self.drop_frames_rcvd(): + self.connection.close() + self.assertFalse(self.connection.keepalive_thread.is_alive()) + + def test_keepalive_terminates_while_waiting_for_pong(self): + """keepalive task terminates while waiting to receive a pong.""" + self.connection.ping_interval = MS + self.connection.ping_timeout = 4 * MS + with self.drop_frames_rcvd(): + self.connection.start_keepalive() + # 1 ms: keepalive() sends a ping frame. + # 1.x ms: a pong frame is dropped. + time.sleep(MS) + # Exiting the context manager sleeps for 1 ms. + # 2 ms: close the connection before ping_timeout elapses. + self.connection.close() + self.connection.keepalive_thread.join(MS) + self.assertFalse(self.connection.keepalive_thread.is_alive()) + + def test_keepalive_reports_errors(self): + """keepalive reports unexpected errors in logs.""" + self.connection.ping_interval = 2 * MS + with self.drop_frames_rcvd(): + self.connection.start_keepalive() + # 2 ms: keepalive() sends a ping frame. + # 2.x ms: a pong frame is dropped. + with self.assertLogs("websockets", logging.ERROR) as logs: + with patch("threading.Event.wait", side_effect=Exception("BOOM")): + time.sleep(3 * MS) + # Exiting the context manager sleeps for 1 ms. + self.assertEqual( + [record.getMessage() for record in logs.records], + ["keepalive ping failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) + # Test parameters. def test_close_timeout(self): diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index bb2ebae14..8e83f2a81 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -236,6 +236,27 @@ def test_disable_compression(self): with connect(get_uri(server)) as client: self.assertEval(client, "ws.protocol.extensions", "[]") + def test_keepalive_is_enabled(self): + """Server enables keepalive and measures latency.""" + with run_server(ping_interval=MS) as server: + with connect(get_uri(server)) as client: + client.send("ws.latency") + latency = eval(client.recv()) + self.assertEqual(latency, 0) + time.sleep(2 * MS) + client.send("ws.latency") + latency = eval(client.recv()) + self.assertGreater(latency, 0) + + def test_disable_keepalive(self): + """Server disables keepalive.""" + with run_server(ping_interval=None) as server: + with connect(get_uri(server)) as client: + time.sleep(2 * MS) + client.send("ws.latency") + latency = eval(client.recv()) + self.assertEqual(latency, 0) + def test_logger(self): """Server accepts a logger argument.""" logger = logging.getLogger("test")