Skip to content

Commit 4037a91

Browse files
committed
Add asyncio message assembler.
1 parent dbada9d commit 4037a91

File tree

3 files changed

+656
-0
lines changed

3 files changed

+656
-0
lines changed

src/websockets/asyncio/messages.py

+254
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import codecs
5+
import collections
6+
from typing import (
7+
Any,
8+
AsyncIterator,
9+
Callable,
10+
Generic,
11+
TypeVar,
12+
)
13+
14+
from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
15+
from ..typing import Data
16+
17+
18+
__all__ = ["Assembler"]
19+
20+
UTF8Decoder = codecs.getincrementaldecoder("utf-8")
21+
22+
T = TypeVar("T")
23+
24+
25+
class SimpleQueue(Generic[T]):
26+
"""
27+
Simplified version of asyncio.Queue.
28+
29+
Doesn't support maxsize nor concurrent calls to get().
30+
31+
"""
32+
33+
def __init__(self) -> None:
34+
self.loop = asyncio.get_running_loop()
35+
self.get_waiter: asyncio.Future[None] | None = None
36+
self.queue: collections.deque[T] = collections.deque()
37+
38+
def __len__(self) -> int:
39+
return len(self.queue)
40+
41+
def put(self, item: T) -> None:
42+
"""Put an item into the queue without waiting."""
43+
self.queue.append(item)
44+
if self.get_waiter is not None and not self.get_waiter.done():
45+
self.get_waiter.set_result(None)
46+
47+
async def get(self) -> T:
48+
"""Remove and return an item from the queue, waiting if necessary."""
49+
if not self.queue:
50+
if self.get_waiter is not None:
51+
raise RuntimeError("get is already running")
52+
self.get_waiter = self.loop.create_future()
53+
try:
54+
await self.get_waiter
55+
finally:
56+
self.get_waiter.cancel()
57+
self.get_waiter = None
58+
return self.queue.popleft()
59+
60+
def abort(self) -> None:
61+
if self.get_waiter is not None and not self.get_waiter.done():
62+
self.get_waiter.set_exception(EOFError("stream of frames ended"))
63+
self.queue.clear() # TODO add test for this
64+
65+
66+
class Assembler:
67+
"""
68+
Assemble messages from frames.
69+
70+
:class:`Assembler` expects only data frames. The stream of frames must
71+
respect the protocol; if it doesn't, the behavior is undefined.
72+
73+
Args:
74+
pause: Called when the buffer of frames goes above the high water mark;
75+
should pause reading from the network.
76+
resume: Called when the buffer of frames goes below the low water mark;
77+
should resume reading from the network.
78+
79+
"""
80+
81+
# coverage reports incorrectly: "line NN didn't jump to the function exit"
82+
def __init__( # pragma: no cover
83+
self,
84+
pause: Callable[[], Any] = lambda: None,
85+
resume: Callable[[], Any] = lambda: None,
86+
) -> None:
87+
# Queue of incoming messages. Each item is a queue of frames.
88+
self.frames: SimpleQueue[Frame] = SimpleQueue()
89+
90+
# We cannot put a hard limit on the size of the queues because a single
91+
# call to Protocol.data_received() could produce thousands of frames,
92+
# which must be buffered. Instead, we pause reading when the buffer goes
93+
# above the high limit and we resume when it goes under the low limit.
94+
self.high = 16
95+
self.low = 4
96+
self.paused = False
97+
self.pause = pause
98+
self.resume = resume
99+
100+
# This flag prevents concurrent calls to get() by user code.
101+
self.get_in_progress = False
102+
103+
# This flag marks the end of the connection.
104+
self.closed = False
105+
106+
async def get(self, decode: bool | None = None) -> Data:
107+
"""
108+
Read the next message.
109+
110+
:meth:`get` returns a single :class:`str` or :class:`bytes`.
111+
112+
If the message is fragmented, :meth:`get` waits until the last frame is
113+
received, then it reassembles the message and returns it. To receive
114+
messages frame by frame, use :meth:`get_iter` instead.
115+
116+
Raises:
117+
EOFError: If the stream of frames has ended.
118+
RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter`
119+
concurrently.
120+
121+
"""
122+
if self.closed:
123+
raise EOFError("stream of frames ended")
124+
125+
if self.get_in_progress:
126+
raise RuntimeError("get or get_iter is already running")
127+
128+
# Locking with get_in_progress ensures only one coroutine can get here.
129+
self.get_in_progress = True
130+
try:
131+
# First frame
132+
frame = await self.frames.get()
133+
self.maybe_resume()
134+
assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
135+
if decode is None:
136+
decode = frame.opcode is OP_TEXT
137+
frames = [frame]
138+
# Following frames, for fragmented messages
139+
while not frame.fin:
140+
frame = await self.frames.get()
141+
self.maybe_resume()
142+
assert frame.opcode is OP_CONT
143+
frames.append(frame)
144+
finally:
145+
self.get_in_progress = False
146+
147+
data = b"".join(frame.data for frame in frames)
148+
if decode:
149+
return data.decode()
150+
else:
151+
return data
152+
153+
async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
154+
"""
155+
Stream the next message.
156+
157+
Iterating the return value of :meth:`get_iter` asynchronously yields a
158+
:class:`str` or :class:`bytes` for each frame in the message.
159+
160+
The iterator must be fully consumed before calling :meth:`get_iter` or
161+
:meth:`get` again. Else, :exc:`RuntimeError` is raised.
162+
163+
This method only makes sense for fragmented messages. If messages aren't
164+
fragmented, use :meth:`get` instead.
165+
166+
Raises:
167+
EOFError: If the stream of frames has ended.
168+
RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter`
169+
concurrently.
170+
171+
"""
172+
if self.closed:
173+
raise EOFError("stream of frames ended")
174+
175+
if self.get_in_progress:
176+
raise RuntimeError("get or get_iter is already running")
177+
178+
# Locking with get_in_progress ensures only one coroutine can get here.
179+
self.get_in_progress = True
180+
try:
181+
# First frame
182+
frame = await self.frames.get()
183+
self.maybe_resume()
184+
assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
185+
if decode is None:
186+
decode = frame.opcode is OP_TEXT
187+
if decode:
188+
decoder = UTF8Decoder()
189+
yield decoder.decode(frame.data, frame.fin)
190+
else:
191+
yield frame.data
192+
# Following frames, for fragmented messages
193+
while not frame.fin:
194+
frame = await self.frames.get()
195+
self.maybe_resume()
196+
assert frame.opcode is OP_CONT
197+
if decode:
198+
yield decoder.decode(frame.data, frame.fin)
199+
else:
200+
yield frame.data
201+
finally:
202+
self.get_in_progress = False
203+
204+
def put(self, frame: Frame) -> None:
205+
"""
206+
Add ``frame`` to the next message.
207+
208+
Raises:
209+
EOFError: If the stream of frames has ended.
210+
211+
"""
212+
if self.closed:
213+
raise EOFError("stream of frames ended")
214+
215+
self.frames.put(frame)
216+
self.maybe_pause()
217+
218+
def get_limits(self) -> tuple[int, int]:
219+
"""Return low and high water marks for flow control."""
220+
return self.low, self.high
221+
222+
def set_limits(self, low: int = 4, high: int = 16) -> None:
223+
"""Configure low and high water marks for flow control."""
224+
self.low, self.high = low, high
225+
226+
def maybe_pause(self) -> None:
227+
"""Pause the writer if queue is above the high water mark."""
228+
# Check for "> high" to support high = 0
229+
if len(self.frames) > self.high and not self.paused:
230+
self.paused = True
231+
self.pause()
232+
233+
def maybe_resume(self) -> None:
234+
"""Resume the writer if queue is below the low water mark."""
235+
# Check for "<= low" to support low = 0
236+
if len(self.frames) <= self.low and self.paused:
237+
self.paused = False
238+
self.resume()
239+
240+
def close(self) -> None:
241+
"""
242+
End the stream of frames.
243+
244+
Callling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`,
245+
or :meth:`put` is safe. They will raise :exc:`EOFError`.
246+
247+
"""
248+
if self.closed:
249+
return
250+
251+
self.closed = True
252+
253+
# Unblock get or get_iter.
254+
self.frames.abort()

tests/asyncio/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)