Skip to content

Commit b04d272

Browse files
committed
Add asyncio message assembler.
1 parent e29f3b3 commit b04d272

File tree

5 files changed

+774
-4
lines changed

5 files changed

+774
-4
lines changed

src/websockets/asyncio/compatibility.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,22 @@
33
import sys
44

55

6-
__all__ = ["asyncio_timeout"]
6+
__all__ = ["TimeoutError", "aiter", "anext", "asyncio_timeout"]
77

88

99
if sys.version_info[:2] >= (3, 11):
10-
from asyncio import timeout as asyncio_timeout # noqa: F401
11-
else:
12-
from .async_timeout import timeout as asyncio_timeout # noqa: F401
10+
TimeoutError = TimeoutError
11+
aiter = aiter
12+
anext = anext
13+
from asyncio import timeout as asyncio_timeout
14+
15+
else: # Python < 3.11
16+
from asyncio import TimeoutError
17+
18+
def aiter(async_iterable):
    """Return an asynchronous iterator for ``async_iterable``.

    Stand-in for the built-in ``aiter()``. ``__aiter__`` is looked up on the
    type of the object rather than on the instance, which is how special
    methods are resolved by the interpreter.

    """
    iterable_type = type(async_iterable)
    return iterable_type.__aiter__(async_iterable)
20+
21+
async def anext(async_iterator, *default):
    """Return the next item from ``async_iterator``.

    Stand-in for the built-in ``anext()``. Like the built-in, it accepts an
    optional second positional argument that is returned instead of raising
    when the iterator is exhausted.

    Args:
        async_iterator: Object whose type implements ``__anext__``.
        default: At most one value, returned if the iterator is exhausted.

    Raises:
        StopAsyncIteration: If the iterator is exhausted and no default
            was given.
        TypeError: If more than one default is given.

    """
    if len(default) > 1:
        raise TypeError(
            f"anext expected at most 2 arguments, got {1 + len(default)}"
        )
    try:
        # Special methods are looked up on the type, not on the instance.
        return await type(async_iterator).__anext__(async_iterator)
    except StopAsyncIteration:
        if default:
            return default[0]
        raise
23+
24+
from .async_timeout import timeout as asyncio_timeout

src/websockets/asyncio/messages.py

+282
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import codecs
5+
import collections
6+
from typing import (
7+
Any,
8+
AsyncIterator,
9+
Callable,
10+
Generic,
11+
Iterable,
12+
TypeVar,
13+
)
14+
15+
from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
16+
from ..typing import Data
17+
18+
19+
__all__ = ["Assembler"]
20+
21+
UTF8Decoder = codecs.getincrementaldecoder("utf-8")
22+
23+
T = TypeVar("T")


class SimpleQueue(Generic[T]):
    """
    Simplified version of asyncio.Queue.

    Doesn't support maxsize nor concurrent calls to get().

    """

    def __init__(self) -> None:
        # get_running_loop() requires that an event loop is running.
        self.loop = asyncio.get_running_loop()
        # Future that the single pending get() is waiting on, if any.
        self.get_waiter: asyncio.Future[None] | None = None
        self.queue: collections.deque[T] = collections.deque()

    def __len__(self) -> int:
        return len(self.queue)

    def put(self, item: T) -> None:
        """Put an item into the queue without waiting."""
        self.queue.append(item)
        # Wake up a pending get(), unless it was already woken up.
        if self.get_waiter is not None and not self.get_waiter.done():
            self.get_waiter.set_result(None)

    async def get(self) -> T:
        """
        Remove and return an item from the queue, waiting if necessary.

        Raises:
            EOFError: If abort() is called while waiting for an item.
            RuntimeError: If another call to get() is already waiting.

        """
        if not self.queue:
            if self.get_waiter is not None:
                raise RuntimeError("get is already running")
            self.get_waiter = self.loop.create_future()
            try:
                await self.get_waiter
            finally:
                # Runs on success, cancellation, and abort alike.
                self.get_waiter.cancel()
                self.get_waiter = None
            if not self.queue:
                # put() woke us up, then abort() emptied the queue before
                # this coroutine resumed. Report the end of the stream
                # instead of failing with IndexError in popleft().
                raise EOFError("stream of frames ended")
        return self.queue.popleft()

    def reset(self, items: Iterable[T]) -> None:
        """Put back items into an empty queue."""
        assert self.get_waiter is None
        assert not self.queue
        self.queue.extend(items)

    def abort(self) -> None:
        """Raise EOFError in a pending get() and discard buffered items."""
        if self.get_waiter is not None and not self.get_waiter.done():
            self.get_waiter.set_exception(EOFError("stream of frames ended"))
        # Clear the queue to avoid storing unnecessary data in memory.
        self.queue.clear()
72+
73+
74+
class Assembler:
    """
    Assemble messages from frames.

    :class:`Assembler` expects only data frames. The stream of frames must
    respect the protocol; if it doesn't, the behavior is undefined.

    Args:
        pause: Called when the buffer of frames goes above the high water mark;
            should pause reading from the network.
        resume: Called when the buffer of frames goes below the low water mark;
            should resume reading from the network.

    """

    # coverage reports incorrectly: "line NN didn't jump to the function exit"
    def __init__(  # pragma: no cover
        self,
        pause: Callable[[], Any] = lambda: None,
        resume: Callable[[], Any] = lambda: None,
    ) -> None:
        # Queue of incoming frames.
        self.frames: SimpleQueue[Frame] = SimpleQueue()

        # We cannot put a hard limit on the size of the queues because a single
        # call to Protocol.data_received() could produce thousands of frames,
        # which must be buffered. Instead, we pause reading when the buffer goes
        # above the high limit and we resume when it goes under the low limit.
        self.high = 16
        self.low = 4
        self.paused = False
        self.pause = pause
        self.resume = resume

        # This flag prevents concurrent calls to get() by user code.
        self.get_in_progress = False

        # This flag marks the end of the connection.
        self.closed = False

    async def get(self, decode: bool | None = None) -> Data:
        """
        Read the next message.

        :meth:`get` returns a single :class:`str` or :class:`bytes`.

        If the message is fragmented, :meth:`get` waits until the last frame is
        received, then it reassembles the message and returns it. To receive
        messages frame by frame, use :meth:`get_iter` instead.

        Args:
            decode: :obj:`True` decodes the message as UTF-8 and returns
                :class:`str`; :obj:`False` returns :class:`bytes`. By default,
                the message is decoded if its first frame is a text frame.

        Raises:
            EOFError: If the stream of frames has ended.
            RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter`
                concurrently.

        """
        if self.closed:
            raise EOFError("stream of frames ended")

        if self.get_in_progress:
            raise RuntimeError("get or get_iter is already running")

        # Locking with get_in_progress ensures only one coroutine can get here.
        self.get_in_progress = True

        # Release the lock however this block exits: a complete message,
        # cancellation, the end of the stream, or an error in a callback.
        try:
            # First frame
            frame = await self.frames.get()
            self.maybe_resume()
            assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
            if decode is None:
                decode = frame.opcode is OP_TEXT
            frames = [frame]

            # Following frames, for fragmented messages
            while not frame.fin:
                try:
                    frame = await self.frames.get()
                except asyncio.CancelledError:
                    # Put frames already received back into the queue.
                    self.frames.reset(frames)
                    raise
                self.maybe_resume()
                assert frame.opcode is OP_CONT
                frames.append(frame)
        finally:
            self.get_in_progress = False

        data = b"".join(frame.data for frame in frames)
        if decode:
            return data.decode()
        else:
            return data

    async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
        """
        Stream the next message.

        Iterating the return value of :meth:`get_iter` asynchronously yields a
        :class:`str` or :class:`bytes` for each frame in the message.

        The iterator must be fully consumed before calling :meth:`get_iter` or
        :meth:`get` again. Else, :exc:`RuntimeError` is raised.

        This method only makes sense for fragmented messages. If messages aren't
        fragmented, use :meth:`get` instead.

        Args:
            decode: :obj:`True` decodes each frame incrementally as UTF-8 and
                yields :class:`str`; :obj:`False` yields :class:`bytes`. By
                default, frames are decoded if the first frame is a text frame.

        Raises:
            EOFError: If the stream of frames has ended.
            RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter`
                concurrently.

        """
        if self.closed:
            raise EOFError("stream of frames ended")

        if self.get_in_progress:
            raise RuntimeError("get or get_iter is already running")

        # Locking with get_in_progress ensures only one coroutine can get here.
        self.get_in_progress = True

        # First frame
        try:
            frame = await self.frames.get()
        except asyncio.CancelledError:
            self.get_in_progress = False
            raise
        self.maybe_resume()
        assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
        if decode is None:
            decode = frame.opcode is OP_TEXT
        if decode:
            # An incremental decoder copes with a multi-byte character that
            # is split across two frames.
            decoder = UTF8Decoder()
            yield decoder.decode(frame.data, frame.fin)
        else:
            yield frame.data

        # Following frames, for fragmented messages
        while not frame.fin:
            # We cannot handle asyncio.CancelledError because we don't buffer
            # previous fragments — we're streaming them. Canceling get_iter()
            # here will leave the assembler in a stuck state. Future calls to
            # get() or get_iter() will raise RuntimeError.
            frame = await self.frames.get()
            self.maybe_resume()
            assert frame.opcode is OP_CONT
            if decode:
                yield decoder.decode(frame.data, frame.fin)
            else:
                yield frame.data

        self.get_in_progress = False

    def put(self, frame: Frame) -> None:
        """
        Add ``frame`` to the next message.

        Raises:
            EOFError: If the stream of frames has ended.

        """
        if self.closed:
            raise EOFError("stream of frames ended")

        self.frames.put(frame)
        self.maybe_pause()

    def get_limits(self) -> tuple[int, int]:
        """Return low and high water marks for flow control."""
        return self.low, self.high

    def set_limits(self, low: int = 4, high: int = 16) -> None:
        """Configure low and high water marks for flow control."""
        self.low, self.high = low, high

    def maybe_pause(self) -> None:
        """Pause the writer if queue is above the high water mark."""
        # Check for "> high" to support high = 0
        if len(self.frames) > self.high and not self.paused:
            self.paused = True
            self.pause()

    def maybe_resume(self) -> None:
        """Resume the writer if queue is below the low water mark."""
        # Check for "<= low" to support low = 0
        if len(self.frames) <= self.low and self.paused:
            self.paused = False
            self.resume()

    def close(self) -> None:
        """
        End the stream of frames.

        Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`,
        or :meth:`put` is safe. They will raise :exc:`EOFError`.

        """
        if self.closed:
            return

        self.closed = True

        # Unblock get or get_iter.
        self.frames.abort()

tests/asyncio/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)