Skip to content

Commit de768cf

Browse files
committed
Improve tests for sync implementation.
1 parent 9b5273c commit de768cf

File tree

3 files changed

+53
-50
lines changed

3 files changed

+53
-50
lines changed

tests/sync/server.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ def assertEval(self, client, expr, value):
2525

2626

2727
@contextlib.contextmanager
28-
def run_server(ws_handler=eval_shell, host="localhost", port=0, **kwargs):
29-
with serve(ws_handler, host, port, **kwargs) as server:
28+
def run_server(handler=eval_shell, host="localhost", port=0, **kwargs):
29+
with serve(handler, host, port, **kwargs) as server:
3030
thread = threading.Thread(target=server.serve_forever)
3131
thread.start()
3232
try:
@@ -37,8 +37,8 @@ def run_server(ws_handler=eval_shell, host="localhost", port=0, **kwargs):
3737

3838

3939
@contextlib.contextmanager
40-
def run_unix_server(path, ws_handler=eval_shell, **kwargs):
41-
with unix_serve(ws_handler, path, **kwargs) as server:
40+
def run_unix_server(path, handler=eval_shell, **kwargs):
41+
with unix_serve(handler, path, **kwargs) as server:
4242
thread = threading.Thread(target=server.serve_forever)
4343
thread.start()
4444
try:

tests/sync/test_client.py

+31-28
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import threading
44
import unittest
55

6-
from websockets.exceptions import InvalidHandshake
6+
from websockets.exceptions import InvalidHandshake, InvalidURI
77
from websockets.extensions.permessage_deflate import PerMessageDeflate
88
from websockets.sync.client import *
99

@@ -25,29 +25,6 @@ def test_connection(self):
2525
with run_client(server) as client:
2626
self.assertEqual(client.protocol.state.name, "OPEN")
2727

28-
def test_connection_fails(self):
29-
"""Client connects to server but the handshake fails."""
30-
31-
def remove_accept_header(self, request, response):
32-
del response.headers["Sec-WebSocket-Accept"]
33-
34-
# The connection will be open for the server but failed for the client.
35-
# Use a connection handler that exits immediately to avoid an exception.
36-
with run_server(do_nothing, process_response=remove_accept_header) as server:
37-
with self.assertRaises(InvalidHandshake) as raised:
38-
with run_client(server, close_timeout=MS):
39-
self.fail("did not raise")
40-
self.assertEqual(
41-
str(raised.exception),
42-
"missing Sec-WebSocket-Accept header",
43-
)
44-
45-
def test_tcp_connection_fails(self):
46-
"""Client fails to connect to server."""
47-
with self.assertRaises(OSError):
48-
with run_client("ws://localhost:54321"): # invalid port
49-
self.fail("did not raise")
50-
5128
def test_existing_socket(self):
5229
"""Client connects using a pre-existing socket."""
5330
with run_server() as server:
@@ -103,6 +80,35 @@ def create_connection(*args, **kwargs):
10380
with run_client(server, create_connection=create_connection) as client:
10481
self.assertTrue(client.create_connection_ran)
10582

83+
def test_invalid_uri(self):
84+
"""Client receives an invalid URI."""
85+
with self.assertRaises(InvalidURI):
86+
with run_client("http://localhost"): # invalid scheme
87+
self.fail("did not raise")
88+
89+
def test_tcp_connection_fails(self):
90+
"""Client fails to connect to server."""
91+
with self.assertRaises(OSError):
92+
with run_client("ws://localhost:54321"): # invalid port
93+
self.fail("did not raise")
94+
95+
def test_handshake_fails(self):
96+
"""Client connects to server but the handshake fails."""
97+
98+
def remove_accept_header(self, request, response):
99+
del response.headers["Sec-WebSocket-Accept"]
100+
101+
# The connection will be open for the server but failed for the client.
102+
# Use a connection handler that exits immediately to avoid an exception.
103+
with run_server(do_nothing, process_response=remove_accept_header) as server:
104+
with self.assertRaises(InvalidHandshake) as raised:
105+
with run_client(server, close_timeout=MS):
106+
self.fail("did not raise")
107+
self.assertEqual(
108+
str(raised.exception),
109+
"missing Sec-WebSocket-Accept header",
110+
)
111+
106112
def test_timeout_during_handshake(self):
107113
"""Client times out before receiving handshake response from server."""
108114
gate = threading.Event()
@@ -115,10 +121,7 @@ def stall_connection(self, request):
115121
with run_server(do_nothing, process_request=stall_connection) as server:
116122
try:
117123
with self.assertRaises(TimeoutError) as raised:
118-
# While it shouldn't take 50ms to open a connection, this
119-
# test becomes flaky in CI when setting a smaller timeout,
120-
# even after increasing WEBSOCKETS_TESTS_TIMEOUT_FACTOR.
121-
with run_client(server, open_timeout=5 * MS):
124+
with run_client(server, open_timeout=2 * MS):
122125
self.fail("did not raise")
123126
self.assertEqual(
124127
str(raised.exception),

tests/sync/test_server.py

+18-18
Original file line numberDiff line numberDiff line change
@@ -39,21 +39,6 @@ def test_connection(self):
3939
with run_client(server) as client:
4040
self.assertEval(client, "ws.protocol.state.name", "OPEN")
4141

42-
def test_connection_fails(self):
43-
"""Server receives connection from client but the handshake fails."""
44-
45-
def remove_key_header(self, request):
46-
del request.headers["Sec-WebSocket-Key"]
47-
48-
with run_server(process_request=remove_key_header) as server:
49-
with self.assertRaises(InvalidStatus) as raised:
50-
with run_client(server):
51-
self.fail("did not raise")
52-
self.assertEqual(
53-
str(raised.exception),
54-
"server rejected WebSocket connection: HTTP 400",
55-
)
56-
5742
def test_connection_handler_returns(self):
5843
"""Connection handler returns."""
5944
with run_server(do_nothing) as server:
@@ -81,8 +66,8 @@ def test_existing_socket(self):
8166
"""Server receives connection using a pre-existing socket."""
8267
with socket.create_server(("localhost", 0)) as sock:
8368
with run_server(sock=sock):
84-
# Build WebSocket URI to ensure we connect to the right socket.
85-
with run_client("ws://{}:{}/".format(*sock.getsockname())) as client:
69+
uri = "ws://{}:{}/".format(*sock.getsockname())
70+
with run_client(uri) as client:
8671
self.assertEval(client, "ws.protocol.state.name", "OPEN")
8772

8873
def test_select_subprotocol(self):
@@ -185,7 +170,7 @@ def process_response(ws, request, response):
185170
self.assertEval(client, "ws.process_response_ran", "True")
186171

187172
def test_process_response_override_response(self):
188-
"""Server runs process_response after processing the handshake."""
173+
"""Server runs process_response and overrides the handshake response."""
189174

190175
def process_response(ws, request, response):
191176
headers = response.headers.copy()
@@ -253,6 +238,21 @@ def create_connection(*args, **kwargs):
253238
with run_client(server) as client:
254239
self.assertEval(client, "ws.create_connection_ran", "True")
255240

241+
def test_handshake_fails(self):
242+
"""Server receives connection from client but the handshake fails."""
243+
244+
def remove_key_header(self, request):
245+
del request.headers["Sec-WebSocket-Key"]
246+
247+
with run_server(process_request=remove_key_header) as server:
248+
with self.assertRaises(InvalidStatus) as raised:
249+
with run_client(server):
250+
self.fail("did not raise")
251+
self.assertEqual(
252+
str(raised.exception),
253+
"server rejected WebSocket connection: HTTP 400",
254+
)
255+
256256
def test_timeout_during_handshake(self):
257257
"""Server times out before receiving handshake request from client."""
258258
with run_server(open_timeout=MS) as server:

0 commit comments

Comments
 (0)