Skip to content

Commit 2feed9f

Browse files
committed
Add type overloads for recv and recv_streaming.
Fix #1578.
1 parent 8f12d8f commit 2feed9f

File tree

6 files changed

+105
-4
lines changed

6 files changed

+105
-4
lines changed

docs/project/changelog.rst

+6
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ New features
3838
* Added :doc:`keepalive and latency measurement <../topics/keepalive>` to the
3939
:mod:`threading` implementation.
4040

41+
Improvements
42+
............
43+
44+
* Added type overloads for the ``decode`` argument of
45+
:meth:`~asyncio.connection.Connection.recv`. This may simplify static typing.
46+
4147
.. _14.2:
4248

4349
14.2

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ exclude_lines = [
7070
"if typing.TYPE_CHECKING:",
7171
"raise AssertionError",
7272
"self.fail\\(\".*\"\\)",
73+
"@overload",
7374
"@unittest.skip",
7475
]
7576
partial_branches = [

src/websockets/asyncio/connection.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import uuid
1212
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterable, Mapping
1313
from types import TracebackType
14-
from typing import Any, cast
14+
from typing import Any, Literal, cast, overload
1515

1616
from ..exceptions import (
1717
ConcurrencyError,
@@ -243,6 +243,15 @@ async def __aiter__(self) -> AsyncIterator[Data]:
243243
except ConnectionClosedOK:
244244
return
245245

246+
@overload
247+
async def recv(self, decode: Literal[True]) -> str: ...
248+
249+
@overload
250+
async def recv(self, decode: Literal[False]) -> bytes: ...
251+
252+
@overload
253+
async def recv(self, decode: bool | None = None) -> Data: ...
254+
246255
async def recv(self, decode: bool | None = None) -> Data:
247256
"""
248257
Receive the next message.
@@ -312,6 +321,15 @@ async def recv(self, decode: bool | None = None) -> Data:
312321
await asyncio.shield(self.connection_lost_waiter)
313322
raise self.protocol.close_exc from self.recv_exc
314323

324+
@overload
325+
def recv_streaming(self, decode: Literal[True]) -> AsyncIterator[str]: ...
326+
327+
@overload
328+
def recv_streaming(self, decode: Literal[False]) -> AsyncIterator[bytes]: ...
329+
330+
@overload
331+
def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: ...
332+
315333
async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]:
316334
"""
317335
Receive the next message frame by frame.

src/websockets/asyncio/messages.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import codecs
55
import collections
66
from collections.abc import AsyncIterator, Iterable
7-
from typing import Any, Callable, Generic, TypeVar
7+
from typing import Any, Callable, Generic, Literal, TypeVar, overload
88

99
from ..exceptions import ConcurrencyError
1010
from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
@@ -116,6 +116,15 @@ def __init__( # pragma: no cover
116116
# This flag marks the end of the connection.
117117
self.closed = False
118118

119+
@overload
120+
async def get(self, decode: Literal[True]) -> str: ...
121+
122+
@overload
123+
async def get(self, decode: Literal[False]) -> bytes: ...
124+
125+
@overload
126+
async def get(self, decode: bool | None = None) -> Data: ...
127+
119128
async def get(self, decode: bool | None = None) -> Data:
120129
"""
121130
Read the next message.
@@ -176,6 +185,15 @@ async def get(self, decode: bool | None = None) -> Data:
176185
else:
177186
return data
178187

188+
@overload
189+
def get_iter(self, decode: Literal[True]) -> AsyncIterator[str]: ...
190+
191+
@overload
192+
def get_iter(self, decode: Literal[False]) -> AsyncIterator[bytes]: ...
193+
194+
@overload
195+
def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: ...
196+
179197
async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
180198
"""
181199
Stream the next message.

src/websockets/sync/connection.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import uuid
1111
from collections.abc import Iterable, Iterator, Mapping
1212
from types import TracebackType
13-
from typing import Any
13+
from typing import Any, Literal, overload
1414

1515
from ..exceptions import (
1616
ConcurrencyError,
@@ -241,6 +241,28 @@ def __iter__(self) -> Iterator[Data]:
241241
except ConnectionClosedOK:
242242
return
243243

244+
# This overload structure is required to avoid the error:
245+
# "parameter without a default follows parameter with a default"
246+
247+
@overload
248+
def recv(self, timeout: float | None, decode: Literal[True]) -> str: ...
249+
250+
@overload
251+
def recv(self, timeout: float | None, decode: Literal[False]) -> bytes: ...
252+
253+
@overload
254+
def recv(self, timeout: float | None = None, *, decode: Literal[True]) -> str: ...
255+
256+
@overload
257+
def recv(
258+
self, timeout: float | None = None, *, decode: Literal[False]
259+
) -> bytes: ...
260+
261+
@overload
262+
def recv(
263+
self, timeout: float | None = None, decode: bool | None = None
264+
) -> Data: ...
265+
244266
def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data:
245267
"""
246268
Receive the next message.
@@ -311,6 +333,15 @@ def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data
311333
self.recv_events_thread.join()
312334
raise self.protocol.close_exc from self.recv_exc
313335

336+
@overload
337+
def recv_streaming(self, decode: Literal[True]) -> Iterator[str]: ...
338+
339+
@overload
340+
def recv_streaming(self, decode: Literal[False]) -> Iterator[bytes]: ...
341+
342+
@overload
343+
def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: ...
344+
314345
def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]:
315346
"""
316347
Receive the next message frame by frame.

src/websockets/sync/messages.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import codecs
44
import queue
55
import threading
6-
from typing import Any, Callable, Iterable, Iterator
6+
from typing import Any, Callable, Iterable, Iterator, Literal, overload
77

88
from ..exceptions import ConcurrencyError
99
from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
@@ -110,6 +110,24 @@ def reset_queue(self, frames: Iterable[Frame]) -> None:
110110
for frame in queued: # pragma: no cover
111111
self.frames.put(frame)
112112

113+
# This overload structure is required to avoid the error:
114+
# "parameter without a default follows parameter with a default"
115+
116+
@overload
117+
def get(self, timeout: float | None, decode: Literal[True]) -> str: ...
118+
119+
@overload
120+
def get(self, timeout: float | None, decode: Literal[False]) -> bytes: ...
121+
122+
@overload
123+
def get(self, timeout: float | None = None, *, decode: Literal[True]) -> str: ...
124+
125+
@overload
126+
def get(self, timeout: float | None = None, *, decode: Literal[False]) -> bytes: ...
127+
128+
@overload
129+
def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: ...
130+
113131
def get(self, timeout: float | None = None, decode: bool | None = None) -> Data:
114132
"""
115133
Read the next message.
@@ -181,6 +199,15 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data:
181199
else:
182200
return data
183201

202+
@overload
203+
def get_iter(self, decode: Literal[True]) -> Iterator[str]: ...
204+
205+
@overload
206+
def get_iter(self, decode: Literal[False]) -> Iterator[bytes]: ...
207+
208+
@overload
209+
def get_iter(self, decode: bool | None = None) -> Iterator[Data]: ...
210+
184211
def get_iter(self, decode: bool | None = None) -> Iterator[Data]:
185212
"""
186213
Stream the next message.

0 commit comments

Comments
 (0)