Skip to content

Commit e97aabe

Browse files
committed
Run handler only when opening handshake succeeds.
When process_request() or process_response() returned a HTTP response without calling accept() or reject() and with a status code other than 101, the connection handler used to start, which was incorrect. Fix #1419. Also move start_keepalive() outside of handshake() and bring it together with starting the connection handler, which is more logical.
1 parent d19ed26 commit e97aabe

File tree

7 files changed

+57
-22
lines changed

7 files changed

+57
-22
lines changed

docs/project/changelog.rst

+7
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,13 @@ Improvements
6868
Previously, :exc:`RuntimeError` was raised. For backwards compatibility,
6969
:exc:`~exceptions.ConcurrencyError` is a subclass of :exc:`RuntimeError`.
7070

71+
Bug fixes
72+
.........
73+
74+
* The new :mod:`asyncio` and :mod:`threading` implementations of servers don't
75+
start the connection handler anymore when ``process_request`` or
76+
``process_response`` returns a HTTP response.
77+
7178
13.0.1
7279
------
7380

src/websockets/asyncio/client.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,7 @@ async def handshake(
9898
# before receiving a response, when the response cannot be parsed, or
9999
# when the response fails the handshake.
100100

101-
if self.protocol.handshake_exc is None:
102-
self.start_keepalive()
103-
else:
101+
if self.protocol.handshake_exc is not None:
104102
raise self.protocol.handshake_exc
105103

106104
def process_event(self, event: Event) -> None:
@@ -465,6 +463,7 @@ async def __await_impl__(self) -> ClientConnection:
465463
raise uri_or_exc from exc
466464

467465
else:
466+
self.connection.start_keepalive()
468467
return self.connection
469468
else:
470469
raise SecurityError(f"more than {MAX_REDIRECTS} redirects")

src/websockets/asyncio/server.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -201,12 +201,11 @@ async def handshake(
201201
self.protocol.send_response(self.response)
202202

203203
# self.protocol.handshake_exc is always set when the connection is lost
204-
# before receiving a request, when the request cannot be parsed, or when
205-
# the response fails the handshake.
204+
# before receiving a request, when the request cannot be parsed, when
205+
# the handshake encounters an error, or when process_request or
206+
# process_response sends a HTTP response that rejects the handshake.
206207

207-
if self.protocol.handshake_exc is None:
208-
self.start_keepalive()
209-
else:
208+
if self.protocol.handshake_exc is not None:
210209
raise self.protocol.handshake_exc
211210

212211
def process_event(self, event: Event) -> None:
@@ -369,7 +368,9 @@ async def conn_handler(self, connection: ServerConnection) -> None:
369368
connection.close_transport()
370369
return
371370

371+
assert connection.protocol.state is OPEN
372372
try:
373+
connection.start_keepalive()
373374
await self.handler(connection)
374375
except Exception:
375376
connection.logger.error("connection handler failed", exc_info=True)

src/websockets/server.py

+14-9
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,6 @@ def accept(self, request: Request) -> Response:
204204
if protocol_header is not None:
205205
headers["Sec-WebSocket-Protocol"] = protocol_header
206206

207-
self.logger.info("connection open")
208207
return Response(101, "Switching Protocols", headers)
209208

210209
def process_request(
@@ -515,14 +514,7 @@ def reject(self, status: StatusLike, text: str) -> Response:
515514
("Content-Type", "text/plain; charset=utf-8"),
516515
]
517516
)
518-
response = Response(status.value, status.phrase, headers, body)
519-
# When reject() is called from accept(), handshake_exc is already set.
520-
# If a user calls reject(), set handshake_exc to guarantee invariant:
521-
# "handshake_exc is None if and only if opening handshake succeeded."
522-
if self.handshake_exc is None:
523-
self.handshake_exc = InvalidStatus(response)
524-
self.logger.info("connection rejected (%d %s)", status.value, status.phrase)
525-
return response
517+
return Response(status.value, status.phrase, headers, body)
526518

527519
def send_response(self, response: Response) -> None:
528520
"""
@@ -545,7 +537,20 @@ def send_response(self, response: Response) -> None:
545537
if response.status_code == 101:
546538
assert self.state is CONNECTING
547539
self.state = OPEN
540+
self.logger.info("connection open")
541+
548542
else:
543+
# handshake_exc may be already set if accept() encountered an error.
544+
# If the connection isn't open, set handshake_exc to guarantee that
545+
# handshake_exc is None if and only if opening handshake succeeded.
546+
if self.handshake_exc is None:
547+
self.handshake_exc = InvalidStatus(response)
548+
self.logger.info(
549+
"connection rejected (%d %s)",
550+
response.status_code,
551+
response.reason_phrase,
552+
)
553+
549554
self.send_eof()
550555
self.parser = self.discard()
551556
next(self.parser) # start coroutine

src/websockets/sync/server.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
validate_subprotocols,
2424
)
2525
from ..http11 import SERVER, Request, Response
26-
from ..protocol import CONNECTING, Event
26+
from ..protocol import CONNECTING, OPEN, Event
2727
from ..server import ServerProtocol
2828
from ..typing import LoggerLike, Origin, StatusLike, Subprotocol
2929
from .connection import Connection
@@ -166,8 +166,9 @@ def handshake(
166166
self.protocol.send_response(self.response)
167167

168168
# self.protocol.handshake_exc is always set when the connection is lost
169-
# before receiving a request, when the request cannot be parsed, or when
170-
# the response fails the handshake.
169+
# before receiving a request, when the request cannot be parsed, when
170+
# the handshake encounters an error, or when process_request or
171+
# process_response sends a HTTP response that rejects the handshake.
171172

172173
if self.protocol.handshake_exc is not None:
173174
raise self.protocol.handshake_exc
@@ -569,6 +570,7 @@ def protocol_select_subprotocol(
569570
connection.recv_events_thread.join()
570571
return
571572

573+
assert connection.protocol.state is OPEN
572574
try:
573575
handler(connection)
574576
except Exception:

tests/asyncio/test_server.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -142,32 +142,46 @@ async def process_request(ws, request):
142142
async def test_process_request_returns_response(self):
143143
"""Server aborts handshake if process_request returns a response."""
144144

145+
handler_ran = False
146+
145147
def process_request(ws, request):
146148
return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden")
147149

148-
async with serve(*args, process_request=process_request) as server:
150+
async def handler(ws):
151+
nonlocal handler_ran
152+
handler_ran = True
153+
154+
async with serve(handler, *args[1:], process_request=process_request) as server:
149155
with self.assertRaises(InvalidStatus) as raised:
150156
async with connect(get_uri(server)):
151157
self.fail("did not raise")
152158
self.assertEqual(
153159
str(raised.exception),
154160
"server rejected WebSocket connection: HTTP 403",
155161
)
162+
self.assertFalse(handler_ran)
156163

157164
async def test_async_process_request_returns_response(self):
158165
"""Server aborts handshake if async process_request returns a response."""
159166

167+
handler_ran = False
168+
160169
async def process_request(ws, request):
161170
return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden")
162171

163-
async with serve(*args, process_request=process_request) as server:
172+
async def handler(ws):
173+
nonlocal handler_ran
174+
handler_ran = True
175+
176+
async with serve(handler, *args[1:], process_request=process_request) as server:
164177
with self.assertRaises(InvalidStatus) as raised:
165178
async with connect(get_uri(server)):
166179
self.fail("did not raise")
167180
self.assertEqual(
168181
str(raised.exception),
169182
"server rejected WebSocket connection: HTTP 403",
170183
)
184+
self.assertFalse(handler_ran)
171185

172186
async def test_process_request_raises_exception(self):
173187
"""Server returns an error if process_request raises an exception."""

tests/sync/test_server.py

+7
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,15 @@ def process_request(ws, request):
130130
def test_process_request_returns_response(self):
131131
"""Server aborts handshake if process_request returns a response."""
132132

133+
handler_ran = False
134+
133135
def process_request(ws, request):
134136
return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden")
135137

138+
async def handler(ws):
139+
nonlocal handler_ran
140+
handler_ran = True
141+
136142
with run_server(process_request=process_request) as server:
137143
with self.assertRaises(InvalidStatus) as raised:
138144
with connect(get_uri(server)):
@@ -141,6 +147,7 @@ def process_request(ws, request):
141147
str(raised.exception),
142148
"server rejected WebSocket connection: HTTP 403",
143149
)
150+
self.assertFalse(handler_ran)
144151

145152
def test_process_request_raises_exception(self):
146153
"""Server returns an error if process_request raises an exception."""

0 commit comments

Comments
 (0)