Skip to content

Commit 3640923

Browse files
committed
Wait until state is CLOSED to acces close_exc.
Fix #1449.
1 parent 20739e0 commit 3640923

File tree

3 files changed

+27
-5
lines changed

3 files changed

+27
-5
lines changed

docs/project/changelog.rst

+4
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ Bug fixes
7575
start the connection handler anymore when ``process_request`` or
7676
``process_response`` returns an HTTP response.
7777

78+
* Fixed a bug in the :mod:`threading` implementation that could lead to
79+
incorrect error reporting when closing a connection while
80+
:meth:`~sync.connection.Connection.recv` is running.
81+
7882
13.0.1
7983
------
8084

src/websockets/asyncio/connection.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,8 @@ async def recv(self, decode: bool | None = None) -> Data:
274274
try:
275275
return await self.recv_messages.get(decode)
276276
except EOFError:
277+
# Wait for the protocol state to be CLOSED before accessing close_exc.
278+
await asyncio.shield(self.connection_lost_waiter)
277279
raise self.protocol.close_exc from self.recv_exc
278280
except ConcurrencyError:
279281
raise ConcurrencyError(
@@ -329,6 +331,8 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data
329331
async for frame in self.recv_messages.get_iter(decode):
330332
yield frame
331333
except EOFError:
334+
# Wait for the protocol state to be CLOSED before accessing close_exc.
335+
await asyncio.shield(self.connection_lost_waiter)
332336
raise self.protocol.close_exc from self.recv_exc
333337
except ConcurrencyError:
334338
raise ConcurrencyError(
@@ -864,6 +868,7 @@ async def send_context(
864868
# raise an exception.
865869
if raise_close_exc:
866870
self.close_transport()
871+
# Wait for the protocol state to be CLOSED before accessing close_exc.
867872
await asyncio.shield(self.connection_lost_waiter)
868873
raise self.protocol.close_exc from original_exc
869874

@@ -926,11 +931,14 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None:
926931
self.transport = transport
927932

928933
def connection_lost(self, exc: Exception | None) -> None:
929-
self.protocol.receive_eof() # receive_eof is idempotent
934+
# Calling protocol.receive_eof() is safe because it's idempotent.
935+
# This guarantees that the protocol state becomes CLOSED.
936+
self.protocol.receive_eof()
937+
assert self.protocol.state is CLOSED
930938

931-
# Abort recv() and pending pings with a ConnectionClosed exception.
932-
# Set recv_exc first to get proper exception reporting.
933939
self.set_recv_exc(exc)
940+
941+
# Abort recv() and pending pings with a ConnectionClosed exception.
934942
self.recv_messages.close()
935943
self.abort_pings()
936944

src/websockets/sync/connection.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,8 @@ def recv(self, timeout: float | None = None) -> Data:
206206
try:
207207
return self.recv_messages.get(timeout)
208208
except EOFError:
209+
# Wait for the protocol state to be CLOSED before accessing close_exc.
210+
self.recv_events_thread.join()
209211
raise self.protocol.close_exc from self.recv_exc
210212
except ConcurrencyError:
211213
raise ConcurrencyError(
@@ -240,6 +242,8 @@ def recv_streaming(self) -> Iterator[Data]:
240242
for frame in self.recv_messages.get_iter():
241243
yield frame
242244
except EOFError:
245+
# Wait for the protocol state to be CLOSED before accessing close_exc.
246+
self.recv_events_thread.join()
243247
raise self.protocol.close_exc from self.recv_exc
244248
except ConcurrencyError:
245249
raise ConcurrencyError(
@@ -629,8 +633,6 @@ def recv_events(self) -> None:
629633
self.logger.error("unexpected internal error", exc_info=True)
630634
with self.protocol_mutex:
631635
self.set_recv_exc(exc)
632-
# We don't know where we crashed. Force protocol state to CLOSED.
633-
self.protocol.state = CLOSED
634636
finally:
635637
# This isn't expected to raise an exception.
636638
self.close_socket()
@@ -738,6 +740,7 @@ def send_context(
738740
# raise an exception.
739741
if raise_close_exc:
740742
self.close_socket()
743+
# Wait for the protocol state to be CLOSED before accessing close_exc.
741744
self.recv_events_thread.join()
742745
raise self.protocol.close_exc from original_exc
743746

@@ -788,4 +791,11 @@ def close_socket(self) -> None:
788791
except OSError:
789792
pass # socket is already closed
790793
self.socket.close()
794+
795+
# Calling protocol.receive_eof() is safe because it's idempotent.
796+
# This guarantees that the protocol state becomes CLOSED.
797+
self.protocol.receive_eof()
798+
assert self.protocol.state is CLOSED
799+
800+
# Abort recv() with a ConnectionClosed exception.
791801
self.recv_messages.close()

0 commit comments

Comments
 (0)