Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type overloads for recv and recv_streaming. #1579

Merged
merged 2 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/project/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ New features
* Added :doc:`keepalive and latency measurement <../topics/keepalive>` to the
:mod:`threading` implementation.

Improvements
............

* Added type overloads for the ``decode`` argument of
:meth:`~asyncio.connection.Connection.recv`. This may simplify static typing.

.. _14.2:

14.2
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,10 @@ exclude_lines = [
"except ImportError:",
"if self.debug:",
"if sys.platform != \"win32\":",
"if typing.TYPE_CHECKING:",
"if TYPE_CHECKING:",
"raise AssertionError",
"self.fail\\(\".*\"\\)",
"@overload",
"@unittest.skip",
]
partial_branches = [
Expand Down
5 changes: 3 additions & 2 deletions src/websockets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import typing
# Importing the typing module would conflict with websockets.typing.
from typing import TYPE_CHECKING

from .imports import lazy_import
from .version import version as __version__ # noqa: F401
Expand Down Expand Up @@ -72,7 +73,7 @@
]

# When type checking, import non-deprecated aliases eagerly. Else, import on demand.
if typing.TYPE_CHECKING:
if TYPE_CHECKING:
from .asyncio.client import ClientConnection, connect, unix_connect
from .asyncio.server import (
Server,
Expand Down
20 changes: 19 additions & 1 deletion src/websockets/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import uuid
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterable, Mapping
from types import TracebackType
from typing import Any, cast
from typing import Any, Literal, cast, overload

from ..exceptions import (
ConcurrencyError,
Expand Down Expand Up @@ -243,6 +243,15 @@ async def __aiter__(self) -> AsyncIterator[Data]:
except ConnectionClosedOK:
return

@overload
async def recv(self, decode: Literal[True]) -> str: ...

@overload
async def recv(self, decode: Literal[False]) -> bytes: ...

@overload
async def recv(self, decode: bool | None = None) -> Data: ...

async def recv(self, decode: bool | None = None) -> Data:
"""
Receive the next message.
Expand Down Expand Up @@ -312,6 +321,15 @@ async def recv(self, decode: bool | None = None) -> Data:
await asyncio.shield(self.connection_lost_waiter)
raise self.protocol.close_exc from self.recv_exc

@overload
def recv_streaming(self, decode: Literal[True]) -> AsyncIterator[str]: ...

@overload
def recv_streaming(self, decode: Literal[False]) -> AsyncIterator[bytes]: ...

@overload
def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: ...

async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]:
"""
Receive the next message frame by frame.
Expand Down
20 changes: 19 additions & 1 deletion src/websockets/asyncio/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import codecs
import collections
from collections.abc import AsyncIterator, Iterable
from typing import Any, Callable, Generic, TypeVar
from typing import Any, Callable, Generic, Literal, TypeVar, overload

from ..exceptions import ConcurrencyError
from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
Expand Down Expand Up @@ -116,6 +116,15 @@ def __init__( # pragma: no cover
# This flag marks the end of the connection.
self.closed = False

@overload
async def get(self, decode: Literal[True]) -> str: ...

@overload
async def get(self, decode: Literal[False]) -> bytes: ...

@overload
async def get(self, decode: bool | None = None) -> Data: ...

async def get(self, decode: bool | None = None) -> Data:
"""
Read the next message.
Expand Down Expand Up @@ -176,6 +185,15 @@ async def get(self, decode: bool | None = None) -> Data:
else:
return data

@overload
def get_iter(self, decode: Literal[True]) -> AsyncIterator[str]: ...

@overload
def get_iter(self, decode: Literal[False]) -> AsyncIterator[bytes]: ...

@overload
def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: ...

async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
"""
Stream the next message.
Expand Down
33 changes: 32 additions & 1 deletion src/websockets/sync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import uuid
from collections.abc import Iterable, Iterator, Mapping
from types import TracebackType
from typing import Any
from typing import Any, Literal, overload

from ..exceptions import (
ConcurrencyError,
Expand Down Expand Up @@ -241,6 +241,28 @@ def __iter__(self) -> Iterator[Data]:
except ConnectionClosedOK:
return

# This overload structure is required to avoid the error:
# "parameter without a default follows parameter with a default"

@overload
def recv(self, timeout: float | None, decode: Literal[True]) -> str: ...

@overload
def recv(self, timeout: float | None, decode: Literal[False]) -> bytes: ...

@overload
def recv(self, timeout: float | None = None, *, decode: Literal[True]) -> str: ...

@overload
def recv(
self, timeout: float | None = None, *, decode: Literal[False]
) -> bytes: ...

@overload
def recv(
self, timeout: float | None = None, decode: bool | None = None
) -> Data: ...

def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data:
"""
Receive the next message.
Expand Down Expand Up @@ -311,6 +333,15 @@ def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data
self.recv_events_thread.join()
raise self.protocol.close_exc from self.recv_exc

@overload
def recv_streaming(self, decode: Literal[True]) -> Iterator[str]: ...

@overload
def recv_streaming(self, decode: Literal[False]) -> Iterator[bytes]: ...

@overload
def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: ...

def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]:
"""
Receive the next message frame by frame.
Expand Down
29 changes: 28 additions & 1 deletion src/websockets/sync/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import codecs
import queue
import threading
from typing import Any, Callable, Iterable, Iterator
from typing import Any, Callable, Iterable, Iterator, Literal, overload

from ..exceptions import ConcurrencyError
from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
Expand Down Expand Up @@ -110,6 +110,24 @@ def reset_queue(self, frames: Iterable[Frame]) -> None:
for frame in queued: # pragma: no cover
self.frames.put(frame)

# This overload structure is required to avoid the error:
# "parameter without a default follows parameter with a default"

@overload
def get(self, timeout: float | None, decode: Literal[True]) -> str: ...

@overload
def get(self, timeout: float | None, decode: Literal[False]) -> bytes: ...

@overload
def get(self, timeout: float | None = None, *, decode: Literal[True]) -> str: ...

@overload
def get(self, timeout: float | None = None, *, decode: Literal[False]) -> bytes: ...

@overload
def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: ...

def get(self, timeout: float | None = None, decode: bool | None = None) -> Data:
"""
Read the next message.
Expand Down Expand Up @@ -181,6 +199,15 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data:
else:
return data

@overload
def get_iter(self, decode: Literal[True]) -> Iterator[str]: ...

@overload
def get_iter(self, decode: Literal[False]) -> Iterator[bytes]: ...

@overload
def get_iter(self, decode: bool | None = None) -> Iterator[Data]: ...

def get_iter(self, decode: bool | None = None) -> Iterator[Data]:
"""
Stream the next message.
Expand Down
5 changes: 2 additions & 3 deletions src/websockets/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

import http
import logging
import typing
from typing import Any, NewType, Optional, Union
from typing import TYPE_CHECKING, Any, NewType, Optional, Union


__all__ = [
Expand Down Expand Up @@ -31,7 +30,7 @@


# Change to logging.Logger | ... when dropping Python < 3.10.
if typing.TYPE_CHECKING:
if TYPE_CHECKING:
LoggerLike = Union[logging.Logger, logging.LoggerAdapter[Any]]
"""Types accepted where a :class:`~logging.Logger` is expected."""
else: # remove this branch when dropping support for Python < 3.11
Expand Down
Loading