Skip to content

Commit 5107a8c

Browse files
committed
Run handler only when opening handshake succeeds.
When process_request or process_response returned a HTTP response with a status code other than 101, the connection handler used to start, which was incorrect. Fix #1419.
1 parent 070ff1a commit 5107a8c

File tree

6 files changed

+61
-25
lines changed

6 files changed

+61
-25
lines changed

docs/project/changelog.rst

+7
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,13 @@ Improvements
6868
Previously, :exc:`RuntimeError` was raised. For backwards compatibility,
6969
:exc:`~exceptions.ConcurrencyError` is a subclass of :exc:`RuntimeError`.
7070

71+
Bug fixes
72+
.........
73+
74+
* The new :mod:`asyncio` and :mod:`threading` implementations of servers don't
75+
start the connection handler anymore when ``process_request`` or
76+
``process_response`` returns a HTTP response.
77+
7178
13.0.1
7279
------
7380

src/websockets/asyncio/client.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,7 @@ async def handshake(
9898
# before receiving a response, when the response cannot be parsed, or
9999
# when the response fails the handshake.
100100

101-
if self.protocol.handshake_exc is None:
102-
self.start_keepalive()
103-
else:
101+
if self.protocol.handshake_exc is not None:
104102
raise self.protocol.handshake_exc
105103

106104
def process_event(self, event: Event) -> None:
@@ -465,6 +463,7 @@ async def __await_impl__(self) -> ClientConnection:
465463
raise uri_or_exc from exc
466464

467465
else:
466+
self.connection.start_keepalive()
468467
return self.connection
469468
else:
470469
raise SecurityError(f"more than {MAX_REDIRECTS} redirects")

src/websockets/asyncio/server.py

+15-10
Original file line numberDiff line numberDiff line change
@@ -202,11 +202,10 @@ async def handshake(
202202

203203
# self.protocol.handshake_exc is always set when the connection is lost
204204
# before receiving a request, when the request cannot be parsed, or when
205-
# the response fails the handshake.
205+
# the handshake encounters an error. It isn't set when process_request
206+
# or process_response sends a HTTP response that rejects the handshake.
206207

207-
if self.protocol.handshake_exc is None:
208-
self.start_keepalive()
209-
else:
208+
if self.protocol.handshake_exc is not None:
210209
raise self.protocol.handshake_exc
211210

212211
def process_event(self, event: Event) -> None:
@@ -369,13 +368,19 @@ async def conn_handler(self, connection: ServerConnection) -> None:
369368
connection.close_transport()
370369
return
371370

372-
try:
373-
await self.handler(connection)
374-
except Exception:
375-
connection.logger.error("connection handler failed", exc_info=True)
376-
await connection.close(CloseCode.INTERNAL_ERROR)
371+
if connection.protocol.state is OPEN:
372+
try:
373+
connection.start_keepalive()
374+
await self.handler(connection)
375+
except Exception:
376+
connection.logger.error("connection handler failed", exc_info=True)
377+
await connection.close(CloseCode.INTERNAL_ERROR)
378+
else:
379+
await connection.close()
377380
else:
378-
await connection.close()
381+
# process_request or process_response sent a non-101 response.
382+
connection.close_transport()
383+
return
379384

380385
except TimeoutError:
381386
# When the opening handshake times out, there's nothing to log.

src/websockets/sync/server.py

+14-10
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
validate_subprotocols,
2424
)
2525
from ..http11 import SERVER, Request, Response
26-
from ..protocol import CONNECTING, Event
26+
from ..protocol import CONNECTING, OPEN, Event
2727
from ..server import ServerProtocol
2828
from ..typing import LoggerLike, Origin, StatusLike, Subprotocol
2929
from .connection import Connection
@@ -167,7 +167,8 @@ def handshake(
167167

168168
# self.protocol.handshake_exc is always set when the connection is lost
169169
# before receiving a request, when the request cannot be parsed, or when
170-
# the response fails the handshake.
170+
# the handshake encounters an error. It isn't set when process_request
171+
# or process_response sends a HTTP response that rejects the handshake.
171172

172173
if self.protocol.handshake_exc is not None:
173174
raise self.protocol.handshake_exc
@@ -562,20 +563,23 @@ def protocol_select_subprotocol(
562563
except TimeoutError:
563564
connection.close_socket()
564565
connection.recv_events_thread.join()
565-
return
566566
except Exception:
567567
connection.logger.error("opening handshake failed", exc_info=True)
568568
connection.close_socket()
569569
connection.recv_events_thread.join()
570-
return
571570

572-
try:
573-
handler(connection)
574-
except Exception:
575-
connection.logger.error("connection handler failed", exc_info=True)
576-
connection.close(CloseCode.INTERNAL_ERROR)
571+
if connection.protocol.state is OPEN:
572+
try:
573+
handler(connection)
574+
except Exception:
575+
connection.logger.error("connection handler failed", exc_info=True)
576+
connection.close(CloseCode.INTERNAL_ERROR)
577+
else:
578+
connection.close()
577579
else:
578-
connection.close()
580+
# process_request or process_response sent a non-101 response.
581+
connection.close_socket()
582+
connection.recv_events_thread.join()
579583

580584
except Exception: # pragma: no cover
581585
# Don't leak sockets on unexpected errors.

tests/asyncio/test_server.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -142,32 +142,46 @@ async def process_request(ws, request):
142142
async def test_process_request_returns_response(self):
143143
"""Server aborts handshake if process_request returns a response."""
144144

145+
handler_ran = True
146+
145147
def process_request(ws, request):
146148
return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden")
147149

148-
async with serve(*args, process_request=process_request) as server:
150+
async def handler(ws):
151+
nonlocal handler_ran
152+
handler_ran = True
153+
154+
async with serve(handler, *args[1:], process_request=process_request) as server:
149155
with self.assertRaises(InvalidStatus) as raised:
150156
async with connect(get_uri(server)):
151157
self.fail("did not raise")
152158
self.assertEqual(
153159
str(raised.exception),
154160
"server rejected WebSocket connection: HTTP 403",
155161
)
162+
self.assertFalse(handler_ran)
156163

157164
async def test_async_process_request_returns_response(self):
158165
"""Server aborts handshake if async process_request returns a response."""
159166

167+
handler_ran = True
168+
160169
async def process_request(ws, request):
161170
return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden")
162171

163-
async with serve(*args, process_request=process_request) as server:
172+
async def handler(ws):
173+
nonlocal handler_ran
174+
handler_ran = True
175+
176+
async with serve(handler, *args[1:], process_request=process_request) as server:
164177
with self.assertRaises(InvalidStatus) as raised:
165178
async with connect(get_uri(server)):
166179
self.fail("did not raise")
167180
self.assertEqual(
168181
str(raised.exception),
169182
"server rejected WebSocket connection: HTTP 403",
170183
)
184+
self.assertFalse(handler_ran)
171185

172186
async def test_process_request_raises_exception(self):
173187
"""Server returns an error if process_request raises an exception."""

tests/sync/test_server.py

+7
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,15 @@ def process_request(ws, request):
130130
def test_process_request_returns_response(self):
131131
"""Server aborts handshake if process_request returns a response."""
132132

133+
handler_ran = True
134+
133135
def process_request(ws, request):
134136
return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden")
135137

138+
async def handler(ws):
139+
nonlocal handler_ran
140+
handler_ran = True
141+
136142
with run_server(process_request=process_request) as server:
137143
with self.assertRaises(InvalidStatus) as raised:
138144
with connect(get_uri(server)):
@@ -141,6 +147,7 @@ def process_request(ws, request):
141147
str(raised.exception),
142148
"server rejected WebSocket connection: HTTP 403",
143149
)
150+
self.assertFalse(handler_ran)
144151

145152
def test_process_request_raises_exception(self):
146153
"""Server returns an error if process_request raises an exception."""

0 commit comments

Comments
 (0)