Skip to content

Commit f58355e

Browse files
committed
Add asyncio message reassembler.
1 parent 6818280 commit f58355e

File tree

4 files changed

+636
-7
lines changed

4 files changed

+636
-7
lines changed

src/websockets/asyncio/messages.py

+254
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import codecs
5+
from typing import AsyncIterator, List, Optional
6+
7+
from ..frames import Frame, Opcode
8+
from ..typing import Data
9+
10+
11+
__all__ = ["Assembler"]
12+
13+
UTF8Decoder = codecs.getincrementaldecoder("utf-8")
14+
15+
16+
class Assembler:
    """
    Assemble messages from frames.

    Library code feeds frames with :meth:`put`; user code reads whole messages
    with :meth:`get` or streams them chunk by chunk with :meth:`get_iter`.
    :meth:`close` ends the stream and unblocks any pending call.

    At most one :meth:`put` and at most one :meth:`get` / :meth:`get_iter` may
    be running at any given time.

    """

    def __init__(self) -> None:
        self.loop = asyncio.get_event_loop()

        # We create a latch with two futures to ensure proper interleaving of
        # writing and reading messages.
        # put() sets this future to tell get() that a message can be fetched.
        self.message_complete: asyncio.Future[None] = self.loop.create_future()
        # get() sets this future to let put() know that the message was
        # fetched.
        self.message_fetched: asyncio.Future[None] = self.loop.create_future()

        # This flag prevents concurrent calls to get() by user code.
        self.get_in_progress = False
        # This flag prevents concurrent calls to put() by library code.
        self.put_in_progress = False

        # Decoder for text frames, None for binary frames.
        self.decoder: Optional[codecs.IncrementalDecoder] = None

        # Buffer of frames belonging to the same message.
        self.chunks: List[Data] = []

        # When switching from "buffering" to "streaming", we use a queue for
        # transferring frames from the writing coroutine (library code) to the
        # reading coroutine (user code). We're buffering when chunks_queue is
        # None and streaming when it's a Queue. None is a sentinel value marking
        # the end of the stream, superseding message_complete.

        # Stream data from frames belonging to the same message.
        self.chunks_queue: Optional[asyncio.Queue[Optional[Data]]] = None

        # This flag marks the end of the stream.
        self.closed = False

    async def get(self) -> Data:
        """
        Read the next message.

        :meth:`get` returns a single :class:`str` or :class:`bytes`.

        If the message is fragmented, :meth:`get` waits until the last frame is
        received, then it reassembles the message and returns it. To receive
        messages frame by frame, use :meth:`get_iter` instead.

        Raises:
            EOFError: If the stream of frames has ended.
            RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter`
                concurrently.

        """
        if self.closed:
            raise EOFError("stream of frames ended")

        if self.get_in_progress:
            raise RuntimeError("get or get_iter is already running")

        # If the message_complete future isn't set yet, yield control to allow
        # put() to run and eventually set it.
        # Locking with get_in_progress ensures only one coroutine can get here.
        self.get_in_progress = True
        try:
            await self.message_complete
        finally:
            # Release the lock even if the await is cancelled.
            self.get_in_progress = False

        # get() was unblocked by close() rather than put().
        if self.closed:
            raise EOFError("stream of frames ended")

        # Re-arm the latch for the next message.
        assert self.message_complete.done()
        self.message_complete = self.loop.create_future()

        # The decoder is set for text messages and None for binary messages,
        # which tells us whether the chunks are str or bytes.
        joiner: Data = b"" if self.decoder is None else ""
        # mypy cannot figure out that chunks have the proper type.
        message: Data = joiner.join(self.chunks)  # type: ignore

        # Unblock put(), which is waiting for the message to be fetched.
        self.message_fetched.set_result(None)

        self.chunks = []
        assert self.chunks_queue is None

        return message

    async def get_iter(self) -> AsyncIterator[Data]:
        """
        Stream the next message.

        Iterating the return value of :meth:`get_iter` asynchronously yields a
        :class:`str` or :class:`bytes` for each frame in the message.

        The iterator must be fully consumed before calling :meth:`get_iter` or
        :meth:`get` again. Else, :exc:`RuntimeError` is raised.

        This method only makes sense for fragmented messages. If messages aren't
        fragmented, use :meth:`get` instead.

        Raises:
            EOFError: If the stream of frames has ended.
            RuntimeError: If two coroutines run :meth:`get` or :meth:`get_iter`
                concurrently.

        """
        if self.closed:
            raise EOFError("stream of frames ended")

        if self.get_in_progress:
            raise RuntimeError("get or get_iter is already running")

        # Switch from "buffering" to "streaming": take over the chunks that
        # were buffered so far and route subsequent ones through a queue.
        chunks = self.chunks
        self.chunks = []
        self.chunks_queue = asyncio.Queue()

        # Sending None in chunk_queue supersedes setting message_complete
        # when switching to "streaming". If message is already complete
        # when the switch happens, put() didn't send None, so we have to.
        if self.message_complete.done():
            self.chunks_queue.put_nowait(None)

        # Locking with get_in_progress ensures only one coroutine can get here.
        self.get_in_progress = True
        try:
            chunk: Optional[Data]
            # First, the chunks received before streaming started.
            for chunk in chunks:
                yield chunk
            # Then, the chunks received since, until the None sentinel.
            while (chunk := await self.chunks_queue.get()) is not None:
                yield chunk
        finally:
            self.get_in_progress = False

        # Re-arm the latch for the next message.
        assert self.message_complete.done()
        self.message_complete = self.loop.create_future()

        # get_iter() was unblocked by close() rather than put().
        if self.closed:
            raise EOFError("stream of frames ended")

        # Unblock put(), which is waiting for the message to be fetched.
        self.message_fetched.set_result(None)

        # Switch back from "streaming" to "buffering".
        assert self.chunks == []
        self.chunks_queue = None

    async def put(self, frame: Frame) -> None:
        """
        Add ``frame`` to the next message.

        When ``frame`` is the final frame in a message, :meth:`put` waits until
        the message is fetched, either by calling :meth:`get` or by fully
        consuming the return value of :meth:`get_iter`.

        :meth:`put` assumes that the stream of frames respects the protocol. If
        it doesn't, the behavior is undefined.

        Raises:
            EOFError: If the stream of frames has ended.
            RuntimeError: If two coroutines run :meth:`put` concurrently.

        """
        if self.closed:
            raise EOFError("stream of frames ended")

        if self.put_in_progress:
            raise RuntimeError("put is already running")

        # The first frame of a message selects text or binary mode;
        # continuation frames keep the mode of the message they belong to.
        if frame.opcode is Opcode.TEXT:
            self.decoder = UTF8Decoder(errors="strict")
        elif frame.opcode is Opcode.BINARY:
            self.decoder = None
        elif frame.opcode is Opcode.CONT:
            pass
        else:
            # Ignore control frames.
            return

        data: Data
        if self.decoder is not None:
            # Passing frame.fin as "final" makes the incremental decoder
            # reject a message that ends in the middle of a character.
            data = self.decoder.decode(frame.data, frame.fin)
        else:
            data = frame.data

        # Buffer the chunk or stream it, depending on the current mode.
        if self.chunks_queue is None:
            self.chunks.append(data)
        else:
            self.chunks_queue.put_nowait(data)

        if not frame.fin:
            return

        # Message is complete. Wait until it's fetched to return.

        self.message_complete.set_result(None)

        if self.chunks_queue is not None:
            self.chunks_queue.put_nowait(None)

        # Yield control to allow get() to run and eventually set the future.
        # Locking with put_in_progress ensures only one coroutine can get here.
        self.put_in_progress = True
        try:
            assert not self.message_fetched.done()
            await self.message_fetched
        finally:
            # Release the lock even if the await is cancelled.
            self.put_in_progress = False

        # Re-arm the latch for the next message.
        assert self.message_fetched.done()
        self.message_fetched = self.loop.create_future()

        # put() was unblocked by close() rather than get() or get_iter().
        if self.closed:
            raise EOFError("stream of frames ended")

        self.decoder = None

    def close(self) -> None:
        """
        End the stream of frames.

        Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`,
        or :meth:`put` is safe. They will raise :exc:`EOFError`.

        Calling :meth:`close` more than once is a no-op.

        """
        if self.closed:
            return

        self.closed = True

        # Unblock get or get_iter.
        if self.get_in_progress:
            # The future may already be done: put() may have completed the
            # message while get_iter() is still suspended, or before get()
            # was resumed by the event loop. Setting the result of a done
            # future raises InvalidStateError, so only set it when pending.
            if not self.message_complete.done():
                self.message_complete.set_result(None)
            if self.chunks_queue is not None:
                self.chunks_queue.put_nowait(None)

        # Unblock put().
        if self.put_in_progress:
            # Likewise, get() or get_iter() may already have set this future
            # before put() was resumed by the event loop.
            if not self.message_fetched.done():
                self.message_fetched.set_result(None)

src/websockets/sync/messages.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,10 @@ def get(self, timeout: Optional[float] = None) -> Data:
7272
7373
Raises:
7474
EOFError: If the stream of frames has ended.
75-
RuntimeError: If two threads run :meth:`get` or :meth:``get_iter`
75+
RuntimeError: If two threads run :meth:`get` or :meth:`get_iter`
7676
concurrently.
77+
TimeoutError: If a timeout is provided and elapses before a
78+
complete message is received.
7779
7880
"""
7981
with self.mutex:
@@ -131,7 +133,7 @@ def get_iter(self) -> Iterator[Data]:
131133
132134
Raises:
133135
EOFError: If the stream of frames has ended.
134-
RuntimeError: If two threads run :meth:`get` or :meth:``get_iter`
136+
RuntimeError: If two threads run :meth:`get` or :meth:`get_iter`
135137
concurrently.
136138
137139
"""
@@ -159,11 +161,10 @@ def get_iter(self) -> Iterator[Data]:
159161
self.get_in_progress = True
160162

161163
# Locking with get_in_progress ensures only one thread can get here.
162-
yield from chunks
163-
while True:
164-
chunk = self.chunks_queue.get()
165-
if chunk is None:
166-
break
164+
chunk: Optional[Data]
165+
for chunk in chunks:
166+
yield chunk
167+
while (chunk := self.chunks_queue.get()) is not None:
167168
yield chunk
168169

169170
with self.mutex:
@@ -242,6 +243,7 @@ def put(self, frame: Frame) -> None:
242243
self.put_in_progress = True
243244

244245
# Release the lock to allow get() to run and eventually set the event.
246+
# Locking with put_in_progress ensures only one thread can get here.
245247
self.message_fetched.wait()
246248

247249
with self.mutex:

tests/asyncio/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)