Skip to content

Commit 34ac84c

Browse files
committed
Add type overloads for recv and recv_streaming.
Fix #1578.
1 parent 8f12d8f commit 34ac84c

File tree

4 files changed

+88
-4
lines changed

4 files changed

+88
-4
lines changed

src/websockets/asyncio/connection.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import uuid
1212
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterable, Mapping
1313
from types import TracebackType
14-
from typing import Any, cast
14+
from typing import Any, Literal, cast, overload
1515

1616
from ..exceptions import (
1717
ConcurrencyError,
@@ -243,6 +243,15 @@ async def __aiter__(self) -> AsyncIterator[Data]:
243243
except ConnectionClosedOK:
244244
return
245245

246+
@overload
247+
async def recv(self, decode: Literal[True] = True) -> str: ...
248+
249+
@overload
250+
async def recv(self, decode: Literal[False] = False) -> bytes: ...
251+
252+
@overload
253+
async def recv(self, decode: bool | None = None) -> Data: ...
254+
246255
async def recv(self, decode: bool | None = None) -> Data:
247256
"""
248257
Receive the next message.
@@ -312,6 +321,17 @@ async def recv(self, decode: bool | None = None) -> Data:
312321
await asyncio.shield(self.connection_lost_waiter)
313322
raise self.protocol.close_exc from self.recv_exc
314323

324+
@overload
325+
def recv_streaming(self, decode: Literal[True] = True) -> AsyncIterator[str]: ...
326+
327+
@overload
328+
def recv_streaming(
329+
self, decode: Literal[False] = False
330+
) -> AsyncIterator[bytes]: ...
331+
332+
@overload
333+
def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: ...
334+
315335
async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]:
316336
"""
317337
Receive the next message frame by frame.

src/websockets/asyncio/messages.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import codecs
55
import collections
66
from collections.abc import AsyncIterator, Iterable
7-
from typing import Any, Callable, Generic, TypeVar
7+
from typing import Any, Callable, Generic, Literal, TypeVar, overload
88

99
from ..exceptions import ConcurrencyError
1010
from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
@@ -116,6 +116,15 @@ def __init__( # pragma: no cover
116116
# This flag marks the end of the connection.
117117
self.closed = False
118118

119+
@overload
120+
async def get(self, decode: Literal[True] = True) -> str: ...
121+
122+
@overload
123+
async def get(self, decode: Literal[False] = False) -> bytes: ...
124+
125+
@overload
126+
async def get(self, decode: bool | None = None) -> Data: ...
127+
119128
async def get(self, decode: bool | None = None) -> Data:
120129
"""
121130
Read the next message.
@@ -176,6 +185,15 @@ async def get(self, decode: bool | None = None) -> Data:
176185
else:
177186
return data
178187

188+
@overload
189+
def get_iter(self, decode: Literal[True] = True) -> AsyncIterator[str]: ...
190+
191+
@overload
192+
def get_iter(self, decode: Literal[False] = False) -> AsyncIterator[bytes]: ...
193+
194+
@overload
195+
def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: ...
196+
179197
async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
180198
"""
181199
Stream the next message.

src/websockets/sync/connection.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import uuid
1111
from collections.abc import Iterable, Iterator, Mapping
1212
from types import TracebackType
13-
from typing import Any
13+
from typing import Any, Literal, overload
1414

1515
from ..exceptions import (
1616
ConcurrencyError,
@@ -241,6 +241,21 @@ def __iter__(self) -> Iterator[Data]:
241241
except ConnectionClosedOK:
242242
return
243243

244+
@overload
245+
def recv(
246+
self, timeout: float | None = None, decode: Literal[True] = True
247+
) -> str: ...
248+
249+
@overload
250+
def recv(
251+
self, timeout: float | None = None, decode: Literal[False] = False
252+
) -> bytes: ...
253+
254+
@overload
255+
def recv(
256+
self, timeout: float | None = None, decode: bool | None = None
257+
) -> Data: ...
258+
244259
def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data:
245260
"""
246261
Receive the next message.
@@ -311,6 +326,15 @@ def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data
311326
self.recv_events_thread.join()
312327
raise self.protocol.close_exc from self.recv_exc
313328

329+
@overload
330+
def recv_streaming(self, decode: Literal[True] = True) -> Iterator[str]: ...
331+
332+
@overload
333+
def recv_streaming(self, decode: Literal[False] = False) -> Iterator[bytes]: ...
334+
335+
@overload
336+
def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: ...
337+
314338
def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]:
315339
"""
316340
Receive the next message frame by frame.

src/websockets/sync/messages.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import codecs
44
import queue
55
import threading
6-
from typing import Any, Callable, Iterable, Iterator
6+
from typing import Any, Callable, Iterable, Iterator, Literal, overload
77

88
from ..exceptions import ConcurrencyError
99
from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
@@ -110,6 +110,19 @@ def reset_queue(self, frames: Iterable[Frame]) -> None:
110110
for frame in queued: # pragma: no cover
111111
self.frames.put(frame)
112112

113+
@overload
114+
def get(
115+
self, timeout: float | None = None, decode: Literal[True] = True
116+
) -> str: ...
117+
118+
@overload
119+
def get(
120+
self, timeout: float | None = None, decode: Literal[False] = False
121+
) -> bytes: ...
122+
123+
@overload
124+
def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: ...
125+
113126
def get(self, timeout: float | None = None, decode: bool | None = None) -> Data:
114127
"""
115128
Read the next message.
@@ -181,6 +194,15 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data:
181194
else:
182195
return data
183196

197+
@overload
198+
def get_iter(self, decode: Literal[True] = True) -> Iterator[str]: ...
199+
200+
@overload
201+
def get_iter(self, decode: Literal[False] = False) -> Iterator[bytes]: ...
202+
203+
@overload
204+
def get_iter(self, decode: bool | None = None) -> Iterator[Data]: ...
205+
184206
def get_iter(self, decode: bool | None = None) -> Iterator[Data]:
185207
"""
186208
Stream the next message.

0 commit comments

Comments
 (0)