Skip to content

Commit 3034834

Browse files
committed
Support recv() after the connection is closed.
Fix #1538.
1 parent bdfc8cf commit 3034834

File tree

7 files changed

+142
-43
lines changed

7 files changed

+142
-43
lines changed

docs/project/changelog.rst

+7
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@ notice.
3434

3535
.. _14.0:
3636

37+
Bug fixes
38+
.........
39+
40+
* Once the connection is closed, messages previously received and buffered can
41+
be read in the :mod:`asyncio` and :mod:`threading` implementations, just like
42+
in the legacy implementation.
43+
3744
14.0
3845
----
3946

src/websockets/asyncio/messages.py

+7-13
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@ def put(self, item: T) -> None:
4040
if self.get_waiter is not None and not self.get_waiter.done():
4141
self.get_waiter.set_result(None)
4242

43-
async def get(self) -> T:
43+
async def get(self, block: bool = True) -> T:
4444
"""Remove and return an item from the queue, waiting if necessary."""
4545
if not self.queue:
46+
if not block:
47+
raise EOFError("stream of frames ended")
4648
assert self.get_waiter is None, "cannot call get() concurrently"
4749
self.get_waiter = self.loop.create_future()
4850
try:
@@ -133,20 +135,16 @@ async def get(self, decode: bool | None = None) -> Data:
133135
:meth:`get_iter` concurrently.
134136
135137
"""
136-
if self.closed:
137-
raise EOFError("stream of frames ended")
138-
139138
if self.get_in_progress:
140139
raise ConcurrencyError("get() or get_iter() is already running")
141-
142140
self.get_in_progress = True
143141

144142
# Locking with get_in_progress prevents concurrent execution
145143
# until get() fetches a complete message or is cancelled.
146144

147145
try:
148146
# First frame
149-
frame = await self.frames.get()
147+
frame = await self.frames.get(not self.closed)
150148
self.maybe_resume()
151149
assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
152150
if decode is None:
@@ -156,7 +154,7 @@ async def get(self, decode: bool | None = None) -> Data:
156154
# Following frames, for fragmented messages
157155
while not frame.fin:
158156
try:
159-
frame = await self.frames.get()
157+
frame = await self.frames.get(not self.closed)
160158
except asyncio.CancelledError:
161159
# Put frames already received back into the queue
162160
# so that future calls to get() can return them.
@@ -200,12 +198,8 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
200198
:meth:`get_iter` concurrently.
201199
202200
"""
203-
if self.closed:
204-
raise EOFError("stream of frames ended")
205-
206201
if self.get_in_progress:
207202
raise ConcurrencyError("get() or get_iter() is already running")
208-
209203
self.get_in_progress = True
210204

211205
# Locking with get_in_progress prevents concurrent execution
@@ -216,7 +210,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
216210

217211
# First frame
218212
try:
219-
frame = await self.frames.get()
213+
frame = await self.frames.get(not self.closed)
220214
except asyncio.CancelledError:
221215
self.get_in_progress = False
222216
raise
@@ -236,7 +230,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
236230
# previous fragments — we're streaming them. Canceling get_iter()
237231
# here will leave the assembler in a stuck state. Future calls to
238232
# get() or get_iter() will raise ConcurrencyError.
239-
frame = await self.frames.get()
233+
frame = await self.frames.get(not self.closed)
240234
self.maybe_resume()
241235
assert frame.opcode is OP_CONT
242236
if decode:

src/websockets/sync/messages.py

+14-13
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,16 @@ def __init__(
6969
def get_next_frame(self, timeout: float | None = None) -> Frame:
7070
# Helper to factor out the logic for getting the next frame from the
7171
# queue, while handling timeouts and reaching the end of the stream.
72-
try:
73-
frame = self.frames.get(timeout=timeout)
74-
except queue.Empty:
75-
raise TimeoutError(f"timed out in {timeout:.1f}s") from None
72+
if self.closed:
73+
try:
74+
frame = self.frames.get(block=False)
75+
except queue.Empty:
76+
raise EOFError("stream of frames ended") from None
77+
else:
78+
try:
79+
frame = self.frames.get(block=True, timeout=timeout)
80+
except queue.Empty:
81+
raise TimeoutError(f"timed out in {timeout:.1f}s") from None
7682
if frame is None:
7783
raise EOFError("stream of frames ended")
7884
return frame
@@ -87,7 +93,7 @@ def reset_queue(self, frames: Iterable[Frame]) -> None:
8793
queued = []
8894
try:
8995
while True:
90-
queued.append(self.frames.get_nowait())
96+
queued.append(self.frames.get(block=False))
9197
except queue.Empty:
9298
pass
9399
for frame in frames:
@@ -123,9 +129,6 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data:
123129
124130
"""
125131
with self.mutex:
126-
if self.closed:
127-
raise EOFError("stream of frames ended")
128-
129132
if self.get_in_progress:
130133
raise ConcurrencyError("get() or get_iter() is already running")
131134
self.get_in_progress = True
@@ -194,9 +197,6 @@ def get_iter(self, decode: bool | None = None) -> Iterator[Data]:
194197
195198
"""
196199
with self.mutex:
197-
if self.closed:
198-
raise EOFError("stream of frames ended")
199-
200200
if self.get_in_progress:
201201
raise ConcurrencyError("get() or get_iter() is already running")
202202
self.get_in_progress = True
@@ -288,5 +288,6 @@ def close(self) -> None:
288288

289289
self.closed = True
290290

291-
# Unblock get() or get_iter().
292-
self.frames.put(None)
291+
if self.get_in_progress:
292+
# Unblock get() or get_iter().
293+
self.frames.put(None)

tests/asyncio/test_connection.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -793,14 +793,12 @@ async def test_close_timeout_waiting_for_connection_closed(self):
793793
# Remove socket.timeout when dropping Python < 3.10.
794794
self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError))
795795

796-
async def test_close_does_not_wait_for_recv(self):
797-
# Closing the connection discards messages buffered in the assembler.
798-
# This is allowed by the RFC:
799-
# > However, there is no guarantee that the endpoint that has already
800-
# > sent a Close frame will continue to process data.
796+
async def test_close_preserves_queued_messages(self):
797+
"""close preserves messages buffered in the assembler."""
801798
await self.remote_connection.send("😀")
802799
await self.connection.close()
803800

801+
self.assertEqual(await self.connection.recv(), "😀")
804802
with self.assertRaises(ConnectionClosedOK) as raised:
805803
await self.connection.recv()
806804

tests/asyncio/test_messages.py

+52
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,58 @@ async def test_get_iter_fails_after_close(self):
395395
async for _ in self.assembler.get_iter():
396396
self.fail("no fragment expected")
397397

398+
async def test_get_queued_message_after_close(self):
399+
"""get returns a message after close is called."""
400+
self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9"))
401+
self.assembler.close()
402+
message = await self.assembler.get()
403+
self.assertEqual(message, "café")
404+
405+
async def test_get_iter_queued_message_after_close(self):
406+
"""get_iter yields a message after close is called."""
407+
self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9"))
408+
self.assembler.close()
409+
fragments = await alist(self.assembler.get_iter())
410+
self.assertEqual(fragments, ["café"])
411+
412+
async def test_get_queued_fragmented_message_after_close(self):
413+
"""get reassembles a fragmented message after close is called."""
414+
self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
415+
self.assembler.put(Frame(OP_CONT, b"e", fin=False))
416+
self.assembler.put(Frame(OP_CONT, b"a"))
417+
self.assembler.close()
418+
self.assembler.close()
419+
message = await self.assembler.get()
420+
self.assertEqual(message, b"tea")
421+
422+
async def test_get_iter_queued_fragmented_message_after_close(self):
423+
"""get_iter yields a fragmented message after close is called."""
424+
self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
425+
self.assembler.put(Frame(OP_CONT, b"e", fin=False))
426+
self.assembler.put(Frame(OP_CONT, b"a"))
427+
self.assembler.close()
428+
fragments = await alist(self.assembler.get_iter())
429+
self.assertEqual(fragments, [b"t", b"e", b"a"])
430+
431+
async def test_get_partially_queued_fragmented_message_after_close(self):
432+
"""get raises EOF on a partial fragmented message after close is called."""
433+
self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
434+
self.assembler.put(Frame(OP_CONT, b"e", fin=False))
435+
self.assembler.close()
436+
with self.assertRaises(EOFError):
437+
await self.assembler.get()
438+
439+
async def test_get_iter_partially_queued_fragmented_message_after_close(self):
440+
"""get_iter yields a partial fragmented message after close is called."""
441+
self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
442+
self.assembler.put(Frame(OP_CONT, b"e", fin=False))
443+
self.assembler.close()
444+
fragments = []
445+
with self.assertRaises(EOFError):
446+
async for fragment in self.assembler.get_iter():
447+
fragments.append(fragment)
448+
self.assertEqual(fragments, [b"t", b"e"])
449+
398450
async def test_put_fails_after_close(self):
399451
"""put raises EOFError after close is called."""
400452
self.assembler.close()

tests/sync/test_connection.py

+7-12
Original file line numberDiff line numberDiff line change
@@ -543,17 +543,12 @@ def test_close_timeout_waiting_for_connection_closed(self):
543543
# Remove socket.timeout when dropping Python < 3.10.
544544
self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError))
545545

546-
def test_close_does_not_wait_for_recv(self):
547-
# Closing the connection discards messages buffered in the assembler.
548-
# This is allowed by the RFC:
549-
# > However, there is no guarantee that the endpoint that has already
550-
# > sent a Close frame will continue to process data.
546+
def test_close_preserves_queued_messages(self):
547+
"""close preserves messages buffered in the assembler."""
551548
self.remote_connection.send("😀")
552549
self.connection.close()
553550

554-
close_thread = threading.Thread(target=self.connection.close)
555-
close_thread.start()
556-
551+
self.assertEqual(self.connection.recv(), "😀")
557552
with self.assertRaises(ConnectionClosedOK) as raised:
558553
self.connection.recv()
559554

@@ -576,10 +571,10 @@ def test_close_idempotency(self):
576571
def test_close_idempotency_race_condition(self):
577572
"""close waits if the connection is already closing."""
578573

579-
self.connection.close_timeout = 5 * MS
574+
self.connection.close_timeout = 6 * MS
580575

581576
def closer():
582-
with self.delay_frames_rcvd(3 * MS):
577+
with self.delay_frames_rcvd(4 * MS):
583578
self.connection.close()
584579

585580
close_thread = threading.Thread(target=closer)
@@ -591,14 +586,14 @@ def closer():
591586

592587
# Connection isn't closed yet.
593588
with self.assertRaises(TimeoutError):
594-
self.connection.recv(timeout=0)
589+
self.connection.recv(timeout=MS)
595590

596591
self.connection.close()
597592
self.assertNoFrameSent()
598593

599594
# Connection is closed now.
600595
with self.assertRaises(ConnectionClosedOK):
601-
self.connection.recv(timeout=0)
596+
self.connection.recv(timeout=MS)
602597

603598
close_thread.join()
604599

tests/sync/test_messages.py

+52
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,58 @@ def test_get_iter_fails_after_close(self):
374374
for _ in self.assembler.get_iter():
375375
self.fail("no fragment expected")
376376

377+
def test_get_queued_message_after_close(self):
378+
"""get returns a message after close is called."""
379+
self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9"))
380+
self.assembler.close()
381+
message = self.assembler.get()
382+
self.assertEqual(message, "café")
383+
384+
def test_get_iter_queued_message_after_close(self):
385+
"""get_iter yields a message after close is called."""
386+
self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9"))
387+
self.assembler.close()
388+
fragments = list(self.assembler.get_iter())
389+
self.assertEqual(fragments, ["café"])
390+
391+
def test_get_queued_fragmented_message_after_close(self):
392+
"""get reassembles a fragmented message after close is called."""
393+
self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
394+
self.assembler.put(Frame(OP_CONT, b"e", fin=False))
395+
self.assembler.put(Frame(OP_CONT, b"a"))
396+
self.assembler.close()
397+
self.assembler.close()
398+
message = self.assembler.get()
399+
self.assertEqual(message, b"tea")
400+
401+
def test_get_iter_queued_fragmented_message_after_close(self):
402+
"""get_iter yields a fragmented message after close is called."""
403+
self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
404+
self.assembler.put(Frame(OP_CONT, b"e", fin=False))
405+
self.assembler.put(Frame(OP_CONT, b"a"))
406+
self.assembler.close()
407+
fragments = list(self.assembler.get_iter())
408+
self.assertEqual(fragments, [b"t", b"e", b"a"])
409+
410+
def test_get_partially_queued_fragmented_message_after_close(self):
411+
"""get raises EOF on a partial fragmented message after close is called."""
412+
self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
413+
self.assembler.put(Frame(OP_CONT, b"e", fin=False))
414+
self.assembler.close()
415+
with self.assertRaises(EOFError):
416+
self.assembler.get()
417+
418+
def test_get_iter_partially_queued_fragmented_message_after_close(self):
419+
"""get_iter yields a partial fragmented message after close is called."""
420+
self.assembler.put(Frame(OP_BINARY, b"t", fin=False))
421+
self.assembler.put(Frame(OP_CONT, b"e", fin=False))
422+
self.assembler.close()
423+
fragments = []
424+
with self.assertRaises(EOFError):
425+
for fragment in self.assembler.get_iter():
426+
fragments.append(fragment)
427+
self.assertEqual(fragments, [b"t", b"e"])
428+
377429
def test_put_fails_after_close(self):
378430
"""put raises EOFError after close is called."""
379431
self.assembler.close()

0 commit comments

Comments
 (0)