Skip to content

Commit bbfb87e

Browse files
committed
Add asyncio message assembler.
1 parent 127d56e commit bbfb87e

File tree

3 files changed

+625
-0
lines changed

3 files changed

+625
-0
lines changed

src/websockets/asyncio/messages.py

+239
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import codecs
5+
import collections
6+
from typing import (
7+
Any,
8+
AsyncIterator,
9+
Callable,
10+
Generic,
11+
Optional,
12+
TypeVar,
13+
)
14+
15+
from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
16+
from ..typing import Data
17+
18+
19+
__all__ = ["Assembler"]
20+
21+
UTF8Decoder = codecs.getincrementaldecoder("utf-8")
22+
23+
T = TypeVar("T")


class SimpleQueue(Generic[T]):
    """
    Minimal single-consumer FIFO queue for asyncio.

    Compared to :class:`asyncio.Queue`, there is no ``maxsize`` and at most
    one coroutine may be waiting in :meth:`get` at any time.

    The queue binds to the running event loop, so it must be created while
    a loop is running.

    """

    def __init__(self) -> None:
        self.loop = asyncio.get_running_loop()
        self.queue: collections.deque[T] = collections.deque()
        # Future awaited by the coroutine currently blocked in get(), if any.
        self.get_waiter: Optional[asyncio.Future[None]] = None

    def __len__(self) -> int:
        """Return the number of items currently buffered."""
        return len(self.queue)

    def put(self, item: T) -> None:
        """Append an item without waiting and wake up a pending get()."""
        self.queue.append(item)
        waiter = self.get_waiter
        if waiter is None or waiter.done():
            return
        waiter.set_result(None)

    async def get(self) -> T:
        """
        Pop and return the oldest item, waiting until one is available.

        Raises:
            EOFError: If :meth:`abort` is called while waiting.
            RuntimeError: If another call to :meth:`get` is already waiting.

        """
        # Fast path: an item is already buffered.
        if self.queue:
            return self.queue.popleft()

        if self.get_waiter is not None:
            raise RuntimeError("get is already running")

        waiter = self.get_waiter = self.loop.create_future()
        try:
            await waiter
        finally:
            # Runs whether the wait succeeded, was aborted, or was cancelled,
            # so the queue is always left without a registered waiter.
            waiter.cancel()
            self.get_waiter = None
        return self.queue.popleft()

    def abort(self) -> None:
        """Make a pending get() raise EOFError; do nothing if none is waiting."""
        waiter = self.get_waiter
        if waiter is None or waiter.done():
            return
        waiter.set_exception(EOFError("stream of frames ended"))
64+
65+
66+
class Assembler:
    """
    Assemble messages from frames.

    :class:`Assembler` expects only data frames and that the stream of
    frames respects the protocol. If it doesn't, the behavior is undefined.

    The frame queue binds to the running event loop, so an
    :class:`Assembler` must be created while a loop is running.

    Args:
        high: High-water mark. ``pause`` is called when the number of
            buffered frames reaches this value.
        low: Low-water mark. ``resume`` is called when reading is paused
            and the number of buffered frames drops below this value.
        pause: Called without arguments to pause reading.
        resume: Called without arguments to resume reading.

    """

    def __init__(
        self,
        high: int = 16,
        low: int = 4,
        pause: Callable[[], Any] = lambda: None,
        resume: Callable[[], Any] = lambda: None,
    ) -> None:
        # Queue of incoming frames. Messages are reassembled from
        # consecutive frames when they're read.
        self.frames: SimpleQueue[Frame] = SimpleQueue()

        # We cannot put a hard limit on the size of the queue because a single
        # call to Protocol.data_received() could produce thousands of frames,
        # which must be buffered. Instead, we pause reading when the buffer
        # reaches the high limit and we resume when it goes under the low limit.
        self.paused = False
        self.high = high
        self.low = low
        self.pause = pause
        self.resume = resume

        # This flag prevents concurrent calls to get() by user code.
        self.get_in_progress = False

        # This flag marks the end of the connection.
        self.closed = False

    async def get(self, decode: Optional[bool] = None) -> Data:
        """
        Read the next message.

        :meth:`get` returns a single :class:`str` or :class:`bytes`.

        If the message is fragmented, :meth:`get` waits until the last frame is
        received, then it reassembles the message and returns it. To receive
        messages frame by frame, use :meth:`get_iter` instead.

        If :meth:`get` is cancelled before the message is complete, the frames
        read so far are put back, so the next call returns the whole message.

        Args:
            decode: :obj:`True` to return :class:`str` (decoded as UTF-8),
                :obj:`False` to return :class:`bytes`. By default, text
                frames are decoded and binary frames aren't.

        Raises:
            EOFError: If the stream of frames has ended.
            RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter`
                concurrently.

        """
        if self.closed:
            raise EOFError("stream of frames ended")

        if self.get_in_progress:
            raise RuntimeError("get or get_iter is already running")

        # Locking with get_in_progress ensures only one coroutine can get here.
        self.get_in_progress = True
        frames: list[Frame] = []
        try:
            # First frame
            frame = await self.frames.get()
            self.maybe_resume()
            assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
            if decode is None:
                decode = frame.opcode is OP_TEXT
            frames.append(frame)
            # Following frames, for fragmented messages
            while not frame.fin:
                frame = await self.frames.get()
                self.maybe_resume()
                assert frame.opcode is OP_CONT
                frames.append(frame)
        except asyncio.CancelledError:
            # Put the frames already read back at the head of the queue, in
            # their original order, ahead of any frame that arrived in the
            # meantime. Otherwise the next get() would start in the middle of
            # this message. The next put() re-checks the high-water mark.
            self.frames.queue.extendleft(reversed(frames))
            raise
        finally:
            self.get_in_progress = False

        data = b"".join(frame.data for frame in frames)
        if decode:
            return data.decode()
        else:
            return data

    async def get_iter(self, decode: Optional[bool] = None) -> AsyncIterator[Data]:
        """
        Stream the next message.

        Iterating the return value of :meth:`get_iter` asynchronously yields a
        :class:`str` or :class:`bytes` for each frame in the message.

        The iterator must be fully consumed before calling :meth:`get_iter` or
        :meth:`get` again. Else, :exc:`RuntimeError` is raised.

        This method only makes sense for fragmented messages. If messages aren't
        fragmented, use :meth:`get` instead.

        Unlike :meth:`get`, fragments that were already yielded aren't put
        back if iteration is interrupted.

        Args:
            decode: :obj:`True` to yield :class:`str` (decoded incrementally
                as UTF-8), :obj:`False` to yield :class:`bytes`. By default,
                text frames are decoded and binary frames aren't.

        Raises:
            EOFError: If the stream of frames has ended.
            RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter`
                concurrently.

        """
        if self.closed:
            raise EOFError("stream of frames ended")

        if self.get_in_progress:
            raise RuntimeError("get or get_iter is already running")

        # Locking with get_in_progress ensures only one coroutine can get here.
        self.get_in_progress = True
        try:
            # First frame
            frame = await self.frames.get()
            self.maybe_resume()
            assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
            if decode is None:
                decode = frame.opcode is OP_TEXT
            if decode:
                # An incremental decoder copes with multi-byte characters
                # split across two frames.
                decoder = UTF8Decoder()
                yield decoder.decode(frame.data, frame.fin)
            else:
                yield frame.data
            # Following frames, for fragmented messages
            while not frame.fin:
                frame = await self.frames.get()
                self.maybe_resume()
                assert frame.opcode is OP_CONT
                if decode:
                    yield decoder.decode(frame.data, frame.fin)
                else:
                    yield frame.data
        finally:
            self.get_in_progress = False

    def put(self, frame: Frame) -> None:
        """
        Add ``frame`` to the next message.

        Raises:
            EOFError: If the stream of frames has ended.

        """
        if self.closed:
            raise EOFError("stream of frames ended")

        self.frames.put(frame)
        self.maybe_pause()

    def maybe_resume(self) -> None:
        """Call ``resume`` if paused and fewer than ``low`` frames are buffered."""
        if len(self.frames) < self.low and self.paused:
            self.paused = False
            self.resume()

    def maybe_pause(self) -> None:
        """Call ``pause`` if not paused and at least ``high`` frames are buffered."""
        if len(self.frames) >= self.high and not self.paused:
            self.paused = True
            self.pause()

    def close(self) -> None:
        """
        End the stream of frames.

        Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`,
        or :meth:`put` is safe. They will raise :exc:`EOFError`.

        Calling :meth:`close` more than once is a no-op.

        """
        if self.closed:
            return

        self.closed = True

        # Unblock get or get_iter.
        self.frames.abort()

tests/asyncio/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)