Skip to content

Commit 98f236f

Browse files
committed
Run handler only when opening handshake succeeds.
When process_request() or process_response() returned a HTTP response without calling accept() or reject() and with a status code other than 101, the connection handler used to start, which was incorrect. Fix #1419. Also move start_keepalive() outside of handshake() and bring it together with starting the connection handler, which is more logical.
1 parent d19ed26 commit 98f236f

File tree

8 files changed

+88
-26
lines changed

8 files changed

+88
-26
lines changed

docs/project/changelog.rst

+7
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,13 @@ Improvements
6868
Previously, :exc:`RuntimeError` was raised. For backwards compatibility,
6969
:exc:`~exceptions.ConcurrencyError` is a subclass of :exc:`RuntimeError`.
7070

71+
Bug fixes
72+
.........
73+
74+
* The new :mod:`asyncio` and :mod:`threading` implementations of servers don't
75+
start the connection handler anymore when ``process_request`` or
76+
``process_response`` returns a HTTP response.
77+
7178
13.0.1
7279
------
7380

src/websockets/asyncio/client.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,7 @@ async def handshake(
9898
# before receiving a response, when the response cannot be parsed, or
9999
# when the response fails the handshake.
100100

101-
if self.protocol.handshake_exc is None:
102-
self.start_keepalive()
103-
else:
101+
if self.protocol.handshake_exc is not None:
104102
raise self.protocol.handshake_exc
105103

106104
def process_event(self, event: Event) -> None:
@@ -465,6 +463,7 @@ async def __await_impl__(self) -> ClientConnection:
465463
raise uri_or_exc from exc
466464

467465
else:
466+
self.connection.start_keepalive()
468467
return self.connection
469468
else:
470469
raise SecurityError(f"more than {MAX_REDIRECTS} redirects")

src/websockets/asyncio/server.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -201,12 +201,11 @@ async def handshake(
201201
self.protocol.send_response(self.response)
202202

203203
# self.protocol.handshake_exc is always set when the connection is lost
204-
# before receiving a request, when the request cannot be parsed, or when
205-
# the response fails the handshake.
204+
# before receiving a request, when the request cannot be parsed, when
205+
# the handshake encounters an error, or when process_request or
206+
# process_response sends a HTTP response that rejects the handshake.
206207

207-
if self.protocol.handshake_exc is None:
208-
self.start_keepalive()
209-
else:
208+
if self.protocol.handshake_exc is not None:
210209
raise self.protocol.handshake_exc
211210

212211
def process_event(self, event: Event) -> None:
@@ -369,7 +368,9 @@ async def conn_handler(self, connection: ServerConnection) -> None:
369368
connection.close_transport()
370369
return
371370

371+
assert connection.protocol.state is OPEN
372372
try:
373+
connection.start_keepalive()
373374
await self.handler(connection)
374375
except Exception:
375376
connection.logger.error("connection handler failed", exc_info=True)

src/websockets/server.py

+14-9
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,6 @@ def accept(self, request: Request) -> Response:
204204
if protocol_header is not None:
205205
headers["Sec-WebSocket-Protocol"] = protocol_header
206206

207-
self.logger.info("connection open")
208207
return Response(101, "Switching Protocols", headers)
209208

210209
def process_request(
@@ -515,14 +514,7 @@ def reject(self, status: StatusLike, text: str) -> Response:
515514
("Content-Type", "text/plain; charset=utf-8"),
516515
]
517516
)
518-
response = Response(status.value, status.phrase, headers, body)
519-
# When reject() is called from accept(), handshake_exc is already set.
520-
# If a user calls reject(), set handshake_exc to guarantee invariant:
521-
# "handshake_exc is None if and only if opening handshake succeeded."
522-
if self.handshake_exc is None:
523-
self.handshake_exc = InvalidStatus(response)
524-
self.logger.info("connection rejected (%d %s)", status.value, status.phrase)
525-
return response
517+
return Response(status.value, status.phrase, headers, body)
526518

527519
def send_response(self, response: Response) -> None:
528520
"""
@@ -545,7 +537,20 @@ def send_response(self, response: Response) -> None:
545537
if response.status_code == 101:
546538
assert self.state is CONNECTING
547539
self.state = OPEN
540+
self.logger.info("connection open")
541+
548542
else:
543+
# handshake_exc may be already set if accept() encountered an error.
544+
# If the connection isn't open, set handshake_exc to guarantee that
545+
# handshake_exc is None if and only if opening handshake succeeded.
546+
if self.handshake_exc is None:
547+
self.handshake_exc = InvalidStatus(response)
548+
self.logger.info(
549+
"connection rejected (%d %s)",
550+
response.status_code,
551+
response.reason_phrase,
552+
)
553+
549554
self.send_eof()
550555
self.parser = self.discard()
551556
next(self.parser) # start coroutine

src/websockets/sync/server.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
validate_subprotocols,
2424
)
2525
from ..http11 import SERVER, Request, Response
26-
from ..protocol import CONNECTING, Event
26+
from ..protocol import CONNECTING, OPEN, Event
2727
from ..server import ServerProtocol
2828
from ..typing import LoggerLike, Origin, StatusLike, Subprotocol
2929
from .connection import Connection
@@ -166,8 +166,9 @@ def handshake(
166166
self.protocol.send_response(self.response)
167167

168168
# self.protocol.handshake_exc is always set when the connection is lost
169-
# before receiving a request, when the request cannot be parsed, or when
170-
# the response fails the handshake.
169+
# before receiving a request, when the request cannot be parsed, when
170+
# the handshake encounters an error, or when process_request or
171+
# process_response sends a HTTP response that rejects the handshake.
171172

172173
if self.protocol.handshake_exc is not None:
173174
raise self.protocol.handshake_exc
@@ -569,6 +570,7 @@ def protocol_select_subprotocol(
569570
connection.recv_events_thread.join()
570571
return
571572

573+
assert connection.protocol.state is OPEN
572574
try:
573575
handler(connection)
574576
except Exception:

tests/asyncio/test_server.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,10 @@ async def test_process_request_returns_response(self):
145145
def process_request(ws, request):
146146
return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden")
147147

148-
async with serve(*args, process_request=process_request) as server:
148+
async def handler(ws):
149+
self.fail("handler must not run")
150+
151+
async with serve(handler, *args[1:], process_request=process_request) as server:
149152
with self.assertRaises(InvalidStatus) as raised:
150153
async with connect(get_uri(server)):
151154
self.fail("did not raise")
@@ -160,7 +163,10 @@ async def test_async_process_request_returns_response(self):
160163
async def process_request(ws, request):
161164
return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden")
162165

163-
async with serve(*args, process_request=process_request) as server:
166+
async def handler(ws):
167+
self.fail("handler must not run")
168+
169+
async with serve(handler, *args[1:], process_request=process_request) as server:
164170
with self.assertRaises(InvalidStatus) as raised:
165171
async with connect(get_uri(server)):
166172
self.fail("did not raise")

tests/sync/test_server.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,10 @@ def test_process_request_returns_response(self):
133133
def process_request(ws, request):
134134
return ws.respond(http.HTTPStatus.FORBIDDEN, "Forbidden")
135135

136-
with run_server(process_request=process_request) as server:
136+
def handler(ws):
137+
self.fail("handler must not run")
138+
139+
with run_server(handler, process_request=process_request) as server:
137140
with self.assertRaises(InvalidStatus) as raised:
138141
with connect(get_uri(server)):
139142
self.fail("did not raise")

tests/test_server.py

+42-3
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,11 @@ def make_request(self):
106106
),
107107
)
108108

109-
def test_send_accept(self):
109+
def test_send_response_after_successful_accept(self):
110110
server = ServerProtocol()
111+
request = self.make_request()
111112
with unittest.mock.patch("email.utils.formatdate", return_value=DATE):
112-
response = server.accept(self.make_request())
113+
response = server.accept(request)
113114
self.assertIsInstance(response, Response)
114115
server.send_response(response)
115116
self.assertEqual(
@@ -126,7 +127,32 @@ def test_send_accept(self):
126127
self.assertFalse(server.close_expected())
127128
self.assertEqual(server.state, OPEN)
128129

129-
def test_send_reject(self):
130+
def test_send_response_after_failed_accept(self):
131+
server = ServerProtocol()
132+
request = self.make_request()
133+
del request.headers["Sec-WebSocket-Key"]
134+
with unittest.mock.patch("email.utils.formatdate", return_value=DATE):
135+
response = server.accept(request)
136+
self.assertIsInstance(response, Response)
137+
server.send_response(response)
138+
self.assertEqual(
139+
server.data_to_send(),
140+
[
141+
f"HTTP/1.1 400 Bad Request\r\n"
142+
f"Date: {DATE}\r\n"
143+
f"Connection: close\r\n"
144+
f"Content-Length: 94\r\n"
145+
f"Content-Type: text/plain; charset=utf-8\r\n"
146+
f"\r\n"
147+
f"Failed to open a WebSocket connection: "
148+
f"missing Sec-WebSocket-Key header; 'sec-websocket-key'.\n".encode(),
149+
b"",
150+
],
151+
)
152+
self.assertTrue(server.close_expected())
153+
self.assertEqual(server.state, CONNECTING)
154+
155+
def test_send_response_after_reject(self):
130156
server = ServerProtocol()
131157
with unittest.mock.patch("email.utils.formatdate", return_value=DATE):
132158
response = server.reject(http.HTTPStatus.NOT_FOUND, "Sorry folks.\n")
@@ -148,6 +174,19 @@ def test_send_reject(self):
148174
self.assertTrue(server.close_expected())
149175
self.assertEqual(server.state, CONNECTING)
150176

177+
def test_send_response_without_accept_or_reject(self):
178+
server = ServerProtocol()
179+
server.send_response(Response(410, "Gone", Headers(), b"AWOL.\n"))
180+
self.assertEqual(
181+
server.data_to_send(),
182+
[
183+
"HTTP/1.1 410 Gone\r\n\r\nAWOL.\n".encode(),
184+
b"",
185+
],
186+
)
187+
self.assertTrue(server.close_expected())
188+
self.assertEqual(server.state, CONNECTING)
189+
151190
def test_accept_response(self):
152191
server = ServerProtocol()
153192
with unittest.mock.patch("email.utils.formatdate", return_value=DATE):

0 commit comments

Comments
 (0)