Skip to content

Commit c9fc059

Browse files
committed
Add asyncio message reassembler.
1 parent 570da0a commit c9fc059

File tree

3 files changed

+652
-0
lines changed

3 files changed

+652
-0
lines changed

src/websockets/asyncio/messages.py

+261
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
from __future__ import annotations

import asyncio
import codecs
from typing import AsyncIterator, Iterator

from ..frames import Frame, Opcode
from ..typing import Data, List, Optional
from .compatibility import asyncio_timeout
10+
11+
12+
# Public API of this module.
__all__ = ["Assembler"]
13+
14+
# Incremental decoder class: put() creates one instance per text message and
# feeds it each frame's data, passing the frame's FIN bit as the "final" flag.
UTF8Decoder = codecs.getincrementaldecoder("utf-8")
15+
16+
17+
class Assembler:
    """
    Assemble messages from frames.

    Library code feeds frames with :meth:`put`. User code reads messages with
    :meth:`get` (whole message) or :meth:`get_iter` (frame by frame).
    :meth:`close` ends the stream and unblocks pending calls.

    """

    def __init__(self) -> None:
        self.loop = asyncio.get_event_loop()

        # We create a latch with two futures to ensure proper interleaving of
        # writing and reading messages.
        # put() sets this future to tell get() that a message can be fetched.
        self.message_complete: asyncio.Future[None] = self.loop.create_future()
        # get() sets this future to let put() know that the message was fetched.
        self.message_fetched: asyncio.Future[None] = self.loop.create_future()

        # This flag prevents concurrent calls to get() by user code.
        self.get_in_progress = False
        # This flag prevents concurrent calls to put() by library code.
        self.put_in_progress = False

        # Decoder for text frames, None for binary frames.
        self.decoder: Optional[codecs.IncrementalDecoder] = None

        # Buffer of frames belonging to the same message.
        self.chunks: List[Data] = []

        # When switching from "buffering" to "streaming", we use a queue for
        # transferring frames from the writing coroutine (library code) to the
        # reading coroutine (user code). We're buffering when chunks_queue is
        # None and streaming when it's a Queue. None is a sentinel value marking
        # the end of the stream, superseding message_complete.

        # Stream data from frames belonging to the same message.
        self.chunks_queue: Optional[asyncio.Queue[Optional[Data]]] = None

        # This flag marks the end of the stream.
        self.closed = False

    async def get(self, timeout: Optional[float] = None) -> Data:
        """
        Read the next message.

        :meth:`get` returns a single :class:`str` or :class:`bytes`.

        If the message is fragmented, :meth:`get` waits until the last frame is
        received, then it reassembles the message and returns it. To receive
        messages frame by frame, use :meth:`get_iter` instead.

        Args:
            timeout: If a timeout is provided and elapses before a complete
                message is received, :meth:`get` raises :exc:`TimeoutError`.

        Raises:
            EOFError: If the stream of frames has ended.
            RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter`
                concurrently.
            TimeoutError: If a timeout is provided and elapses before a
                complete message is received.

        """
        if self.closed:
            raise EOFError("stream of frames ended")

        if self.get_in_progress:
            raise RuntimeError("get or get_iter is already running")

        # If the message_complete future isn't set yet, yield control to allow
        # put() to run and eventually set it.
        # Locking with get_in_progress ensures only one coroutine can get here.
        self.get_in_progress = True
        try:
            async with asyncio_timeout(timeout):
                # Shield the latch: when the timeout elapses or the caller is
                # cancelled, asyncio cancels whatever the task is awaiting.
                # Without the shield, message_complete itself would end up
                # cancelled and the next put() couldn't set its result.
                await asyncio.shield(self.message_complete)
        finally:
            self.get_in_progress = False

        # get() was unblocked by close() rather than put().
        if self.closed:
            raise EOFError("stream of frames ended")

        # Re-arm the latch for the next message.
        assert self.message_complete.done()
        self.message_complete = self.loop.create_future()

        joiner: Data = b"" if self.decoder is None else ""
        # mypy cannot figure out that chunks have the proper type.
        message: Data = joiner.join(self.chunks)  # type: ignore

        # Tell put() that it may return.
        self.message_fetched.set_result(None)

        self.chunks = []
        assert self.chunks_queue is None

        return message

    async def get_iter(self) -> AsyncIterator[Data]:
        """
        Stream the next message.

        Iterating the return value of :meth:`get_iter` asynchronously yields a
        :class:`str` or :class:`bytes` for each frame in the message.

        The iterator must be fully consumed before calling :meth:`get_iter` or
        :meth:`get` again. Else, :exc:`RuntimeError` is raised.

        This method only makes sense for fragmented messages. If messages aren't
        fragmented, use :meth:`get` instead.

        Raises:
            EOFError: If the stream of frames has ended.
            RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter`
                concurrently.

        """
        if self.closed:
            raise EOFError("stream of frames ended")

        if self.get_in_progress:
            raise RuntimeError("get or get_iter is already running")

        # Switch from "buffering" to "streaming": take the frames buffered so
        # far and install a queue where put() will send the following ones.
        chunks = self.chunks
        self.chunks = []
        queue: asyncio.Queue[Optional[Data]] = asyncio.Queue()
        self.chunks_queue = queue

        # Sending None in chunk_queue supersedes setting message_complete
        # when switching to "streaming". If message is already complete
        # when the switch happens, put() didn't send None, so we have to.
        if self.message_complete.done():
            queue.put_nowait(None)

        # Locking with get_in_progress ensures only one coroutine can get here.
        self.get_in_progress = True
        try:
            for chunk in chunks:
                yield chunk
            while (chunk := await queue.get()) is not None:
                yield chunk
        finally:
            self.get_in_progress = False

        # Re-arm the latch for the next message.
        assert self.message_complete.done()
        self.message_complete = self.loop.create_future()

        # get_iter() was unblocked by close() rather than put().
        if self.closed:
            raise EOFError("stream of frames ended")

        # Tell put() that it may return.
        self.message_fetched.set_result(None)

        # Switch back to "buffering".
        assert self.chunks == []
        self.chunks_queue = None

    async def put(self, frame: Frame) -> None:
        """
        Add ``frame`` to the next message.

        When ``frame`` is the final frame in a message, :meth:`put` waits until
        the message is fetched, either by calling :meth:`get` or by fully
        consuming the return value of :meth:`get_iter`.

        :meth:`put` assumes that the stream of frames respects the protocol. If
        it doesn't, the behavior is undefined.

        Raises:
            EOFError: If the stream of frames has ended.
            RuntimeError: If two coroutines run :meth:`put` concurrently.

        """
        if self.closed:
            raise EOFError("stream of frames ended")

        if self.put_in_progress:
            raise RuntimeError("put is already running")

        # The first frame of a message selects text or binary handling;
        # continuation frames keep the current decoder.
        if frame.opcode is Opcode.TEXT:
            self.decoder = UTF8Decoder(errors="strict")
        elif frame.opcode is Opcode.BINARY:
            self.decoder = None
        elif frame.opcode is Opcode.CONT:
            pass
        else:
            # Ignore control frames.
            return

        data: Data
        if self.decoder is not None:
            # frame.fin tells the incremental decoder whether this is the
            # last piece of input for the message.
            data = self.decoder.decode(frame.data, frame.fin)
        else:
            data = frame.data

        # Buffer the frame, or hand it to get_iter() when streaming.
        if self.chunks_queue is None:
            self.chunks.append(data)
        else:
            self.chunks_queue.put_nowait(data)

        if not frame.fin:
            return

        # Message is complete. Wait until it's fetched to return.

        self.message_complete.set_result(None)

        if self.chunks_queue is not None:
            self.chunks_queue.put_nowait(None)

        # Yield control to allow get() to run and eventually set the future.
        # Locking with put_in_progress ensures only one coroutine can get here.
        self.put_in_progress = True
        try:
            assert not self.message_fetched.done()
            await self.message_fetched
        finally:
            self.put_in_progress = False

        # Re-arm the latch for the next message.
        assert self.message_fetched.done()
        self.message_fetched = self.loop.create_future()

        # put() was unblocked by close() rather than get() or get_iter().
        if self.closed:
            raise EOFError("stream of frames ended")

        self.decoder = None

    def close(self) -> None:
        """
        End the stream of frames.

        Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`,
        or :meth:`put` is safe. They will raise :exc:`EOFError`.

        Calling :meth:`close` more than once is a no-op.

        """
        if self.closed:
            return

        self.closed = True

        # Unblock get or get_iter.
        if self.get_in_progress:
            # The latch may already be resolved, for instance when get_iter()
            # is streaming a message that put() has already completed.
            # Setting the result again would raise InvalidStateError.
            if not self.message_complete.done():
                self.message_complete.set_result(None)
            if self.chunks_queue is not None:
                self.chunks_queue.put_nowait(None)

        # Unblock put().
        if self.put_in_progress:
            # Same guard: get() may have resolved the latch while put() hasn't
            # resumed yet and therefore hasn't cleared put_in_progress.
            if not self.message_fetched.done():
                self.message_fetched.set_result(None)

tests/asyncio/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)