Skip to content

Commit f9cea9c

Browse files
committed
Improve isolation of tests of sync implementation.
Before this change, threads handling requests could continue running after the end of the test. This caused spurious failures. Specifically, a test expecting an error log could get an error log from a previous tests. This happened sporadically on PyPy.
1 parent 14d9d40 commit f9cea9c

File tree

2 files changed

+21
-51
lines changed

2 files changed

+21
-51
lines changed

tests/sync/server.py

+18
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,30 @@ def run_server(handler=handler, host="localhost", port=0, **kwargs):
3838
with serve(handler, host, port, **kwargs) as server:
3939
thread = threading.Thread(target=server.serve_forever)
4040
thread.start()
41+
42+
# HACK: since the sync server doesn't track connections (yet), we record
43+
# a reference to the thread handling the most recent connection, then we
44+
# can wait for that thread to terminate when exiting the context.
45+
handler_thread = None
46+
original_handler = server.handler
47+
48+
def handler(sock, addr):
49+
nonlocal handler_thread
50+
handler_thread = threading.current_thread()
51+
original_handler(sock, addr)
52+
53+
server.handler = handler
54+
4155
try:
4256
yield server
4357
finally:
4458
server.shutdown()
4559
thread.join()
4660

61+
# HACK: wait for the thread handling the most recent connection.
62+
if handler_thread is not None:
63+
handler_thread.join()
64+
4765

4866
@contextlib.contextmanager
4967
def run_unix_server(path, handler=handler, **kwargs):

tests/sync/test_server.py

+3-51
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import http
44
import logging
55
import socket
6-
import threading
6+
import time
77
import unittest
88

99
from websockets.exceptions import (
@@ -289,50 +289,19 @@ def test_timeout_during_handshake(self):
289289
def test_connection_closed_during_handshake(self):
290290
"""Server reads EOF before receiving handshake request from client."""
291291
with run_server() as server:
292-
# Patch handler to record a reference to the thread running it.
293-
server_thread = None
294-
conn_received = threading.Event()
295-
original_handler = server.handler
296-
297-
def handler(sock, addr):
298-
nonlocal server_thread
299-
server_thread = threading.current_thread()
300-
nonlocal conn_received
301-
conn_received.set()
302-
original_handler(sock, addr)
303-
304-
server.handler = handler
305-
306292
with socket.create_connection(server.socket.getsockname()):
307293
# Wait for the server to receive the connection, then close it.
308-
conn_received.wait()
309-
310-
# Wait for the server thread to terminate.
311-
server_thread.join()
294+
time.sleep(MS)
312295

313296
def test_junk_handshake(self):
314297
"""Server closes the connection when receiving non-HTTP request from client."""
315298
with self.assertLogs("websockets.server", logging.ERROR) as logs:
316299
with run_server() as server:
317-
# Patch handler to record a reference to the thread running it.
318-
server_thread = None
319-
original_handler = server.handler
320-
321-
def handler(sock, addr):
322-
nonlocal server_thread
323-
server_thread = threading.current_thread()
324-
original_handler(sock, addr)
325-
326-
server.handler = handler
327-
328300
with socket.create_connection(server.socket.getsockname()) as sock:
329301
sock.send(b"HELO relay.invalid\r\n")
330302
# Wait for the server to close the connection.
331303
self.assertEqual(sock.recv(4096), b"")
332304

333-
# Wait for the server thread to terminate.
334-
server_thread.join()
335-
336305
self.assertEqual(
337306
[record.getMessage() for record in logs.records],
338307
["opening handshake failed"],
@@ -360,26 +329,9 @@ def test_timeout_during_tls_handshake(self):
360329
def test_connection_closed_during_tls_handshake(self):
361330
"""Server reads EOF before receiving TLS handshake request from client."""
362331
with run_server(ssl=SERVER_CONTEXT) as server:
363-
# Patch handler to record a reference to the thread running it.
364-
server_thread = None
365-
conn_received = threading.Event()
366-
original_handler = server.handler
367-
368-
def handler(sock, addr):
369-
nonlocal server_thread
370-
server_thread = threading.current_thread()
371-
nonlocal conn_received
372-
conn_received.set()
373-
original_handler(sock, addr)
374-
375-
server.handler = handler
376-
377332
with socket.create_connection(server.socket.getsockname()):
378333
# Wait for the server to receive the connection, then close it.
379-
conn_received.wait()
380-
381-
# Wait for the server thread to terminate.
382-
server_thread.join()
334+
time.sleep(MS)
383335

384336

385337
@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets")

0 commit comments

Comments
 (0)