From ee0d7441940131cc30b0f5a9a2afc905bb490564 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 24 Jan 2025 21:15:08 +0100 Subject: [PATCH 1/2] Avoid shadowing an import with another import. --- pyproject.toml | 2 +- src/websockets/__init__.py | 5 +++-- src/websockets/typing.py | 5 ++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d3128f8ec..4044de0f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,7 +67,7 @@ exclude_lines = [ "except ImportError:", "if self.debug:", "if sys.platform != \"win32\":", - "if typing.TYPE_CHECKING:", + "if TYPE_CHECKING:", "raise AssertionError", "self.fail\\(\".*\"\\)", "@unittest.skip", diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index c8df54e0b..1d0abe5cd 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -1,6 +1,7 @@ from __future__ import annotations -import typing +# Importing the typing module would conflict with websockets.typing. +from typing import TYPE_CHECKING from .imports import lazy_import from .version import version as __version__ # noqa: F401 @@ -72,7 +73,7 @@ ] # When type checking, import non-deprecated aliases eagerly. Else, import on demand. -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from .asyncio.client import ClientConnection, connect, unix_connect from .asyncio.server import ( Server, diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 0a37141c6..f10481b8b 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -2,8 +2,7 @@ import http import logging -import typing -from typing import Any, NewType, Optional, Union +from typing import TYPE_CHECKING, Any, NewType, Optional, Union __all__ = [ @@ -31,7 +30,7 @@ # Change to logging.Logger | ... when dropping Python < 3.10. -if typing.TYPE_CHECKING: +if TYPE_CHECKING: LoggerLike = Union[logging.Logger, logging.LoggerAdapter[Any]] """Types accepted where a :class:`~logging.Logger` is expected.""" else: # remove this branch when dropping support for Python < 3.11 From bba423e510cc422e27dfd77a95771d208e3e766a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Thu, 23 Jan 2025 21:22:14 +0100 Subject: [PATCH 2/2] Add type overloads for recv and recv_streaming. Fix #1578. --- docs/project/changelog.rst | 6 +++++ pyproject.toml | 1 + src/websockets/asyncio/connection.py | 20 ++++++++++++++++- src/websockets/asyncio/messages.py | 20 ++++++++++++++++- src/websockets/sync/connection.py | 33 +++++++++++++++++++++++++++- src/websockets/sync/messages.py | 29 +++++++++++++++++++++++- 6 files changed, 105 insertions(+), 4 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 67c16ba9e..7f341d942 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -38,6 +38,12 @@ New features * Added :doc:`keepalive and latency measurement <../topics/keepalive>` to the :mod:`threading` implementation. +Improvements +............ + +* Added type overloads for the ``decode`` argument of + :meth:`~asyncio.connection.Connection.recv`. This may simplify static typing. + .. _14.2: 14.2 diff --git a/pyproject.toml b/pyproject.toml index 4044de0f2..c0d9fcfd9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,7 @@ exclude_lines = [ "if TYPE_CHECKING:", "raise AssertionError", "self.fail\\(\".*\"\\)", + "@overload", "@unittest.skip", ] partial_branches = [ diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 75c43fa8a..79429923e 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -11,7 +11,7 @@ import uuid from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterable, Mapping from types import TracebackType -from typing import Any, cast +from typing import Any, Literal, cast, overload from ..exceptions import ( ConcurrencyError, @@ -243,6 +243,15 @@ async def __aiter__(self) -> AsyncIterator[Data]: except ConnectionClosedOK: return + @overload + async def recv(self, decode: Literal[True]) -> str: ... + + @overload + async def recv(self, decode: Literal[False]) -> bytes: ... + + @overload + async def recv(self, decode: bool | None = None) -> Data: ... + async def recv(self, decode: bool | None = None) -> Data: """ Receive the next message. @@ -312,6 +321,15 @@ async def recv(self, decode: bool | None = None) -> Data: await asyncio.shield(self.connection_lost_waiter) raise self.protocol.close_exc from self.recv_exc + @overload + def recv_streaming(self, decode: Literal[True]) -> AsyncIterator[str]: ... + + @overload + def recv_streaming(self, decode: Literal[False]) -> AsyncIterator[bytes]: ... + + @overload + def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: ... + async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: """ Receive the next message frame by frame. diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index c10072467..581870037 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -4,7 +4,7 @@ import codecs import collections from collections.abc import AsyncIterator, Iterable -from typing import Any, Callable, Generic, TypeVar +from typing import Any, Callable, Generic, Literal, TypeVar, overload from ..exceptions import ConcurrencyError from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame @@ -116,6 +116,15 @@ def __init__( # pragma: no cover # This flag marks the end of the connection. self.closed = False + @overload + async def get(self, decode: Literal[True]) -> str: ... + + @overload + async def get(self, decode: Literal[False]) -> bytes: ... + + @overload + async def get(self, decode: bool | None = None) -> Data: ... + async def get(self, decode: bool | None = None) -> Data: """ Read the next message. @@ -176,6 +185,15 @@ async def get(self, decode: bool | None = None) -> Data: else: return data + @overload + def get_iter(self, decode: Literal[True]) -> AsyncIterator[str]: ... + + @overload + def get_iter(self, decode: Literal[False]) -> AsyncIterator[bytes]: ... + + @overload + def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: ... + async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: """ Stream the next message. diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 07f0543e4..0c517cc64 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -10,7 +10,7 @@ import uuid from collections.abc import Iterable, Iterator, Mapping from types import TracebackType -from typing import Any +from typing import Any, Literal, overload from ..exceptions import ( ConcurrencyError, @@ -241,6 +241,28 @@ def __iter__(self) -> Iterator[Data]: except ConnectionClosedOK: return + # This overload structure is required to avoid the error: + # "parameter without a default follows parameter with a default" + + @overload + def recv(self, timeout: float | None, decode: Literal[True]) -> str: ... + + @overload + def recv(self, timeout: float | None, decode: Literal[False]) -> bytes: ... + + @overload + def recv(self, timeout: float | None = None, *, decode: Literal[True]) -> str: ... + + @overload + def recv( + self, timeout: float | None = None, *, decode: Literal[False] + ) -> bytes: ... + + @overload + def recv( + self, timeout: float | None = None, decode: bool | None = None + ) -> Data: ... + def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data: """ Receive the next message. @@ -311,6 +333,15 @@ def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data self.recv_events_thread.join() raise self.protocol.close_exc from self.recv_exc + @overload + def recv_streaming(self, decode: Literal[True]) -> Iterator[str]: ... + + @overload + def recv_streaming(self, decode: Literal[False]) -> Iterator[bytes]: ... + + @overload + def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: ... + def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: """ Receive the next message frame by frame. diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index dfabedd65..c619e78a1 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -3,7 +3,7 @@ import codecs import queue import threading -from typing import Any, Callable, Iterable, Iterator +from typing import Any, Callable, Iterable, Iterator, Literal, overload from ..exceptions import ConcurrencyError from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame @@ -110,6 +110,24 @@ def reset_queue(self, frames: Iterable[Frame]) -> None: for frame in queued: # pragma: no cover self.frames.put(frame) + # This overload structure is required to avoid the error: + # "parameter without a default follows parameter with a default" + + @overload + def get(self, timeout: float | None, decode: Literal[True]) -> str: ... + + @overload + def get(self, timeout: float | None, decode: Literal[False]) -> bytes: ... + + @overload + def get(self, timeout: float | None = None, *, decode: Literal[True]) -> str: ... + + @overload + def get(self, timeout: float | None = None, *, decode: Literal[False]) -> bytes: ... + + @overload + def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: ... + def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: """ Read the next message. @@ -181,6 +199,15 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: else: return data + @overload + def get_iter(self, decode: Literal[True]) -> Iterator[str]: ... + + @overload + def get_iter(self, decode: Literal[False]) -> Iterator[bytes]: ... + + @overload + def get_iter(self, decode: bool | None = None) -> Iterator[Data]: ... + def get_iter(self, decode: bool | None = None) -> Iterator[Data]: """ Stream the next message.