Skip to content

Commit 1a1e267

Browse files
committed
Improve error handling for SOCKS proxy.
1 parent c074802 commit 1a1e267

File tree

11 files changed

+203
-79
lines changed

11 files changed

+203
-79
lines changed

docs/reference/exceptions.rst

+5-3
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,16 @@ also reported by :func:`~websockets.asyncio.server.serve` in logs.
3434

3535
.. autoexception:: SecurityError
3636

37-
.. autoexception:: InvalidMessage
38-
39-
.. autoexception:: InvalidStatus
37+
.. autoexception:: ProxyError
4038

4139
.. autoexception:: InvalidProxyMessage
4240

4341
.. autoexception:: InvalidProxyStatus
4442

43+
.. autoexception:: InvalidMessage
44+
45+
.. autoexception:: InvalidStatus
46+
4547
.. autoexception:: InvalidHeader
4648

4749
.. autoexception:: InvalidHeaderFormat

example/asyncio/client.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,17 @@
66

77
from websockets.asyncio.client import connect
88

9+
import logging; logging.basicConfig(level=logging.INFO)
10+
911

1012
async def hello():
11-
async with connect("ws://localhost:8765") as websocket:
13+
async for websocket in connect("ws://localhost:8765", proxy="socks5h://localhost:1080"):
1214
name = input("What's your name? ")
1315

1416
await websocket.send(name)
1517
print(f">>> {name}")
1618

19+
1720
greeting = await websocket.recv()
1821
print(f"<<< {greeting}")
1922

src/websockets/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
"NegotiationError",
5050
"PayloadTooBig",
5151
"ProtocolError",
52+
"ProxyError",
5253
"SecurityError",
5354
"WebSocketException",
5455
# .frames
@@ -112,6 +113,7 @@
112113
NegotiationError,
113114
PayloadTooBig,
114115
ProtocolError,
116+
ProxyError,
115117
SecurityError,
116118
WebSocketException,
117119
)
@@ -173,6 +175,7 @@
173175
"NegotiationError": ".exceptions",
174176
"PayloadTooBig": ".exceptions",
175177
"ProtocolError": ".exceptions",
178+
"ProxyError": ".exceptions",
176179
"SecurityError": ".exceptions",
177180
"WebSocketException": ".exceptions",
178181
# .frames

src/websockets/asyncio/client.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from ..client import ClientProtocol, backoff
1414
from ..datastructures import HeadersLike
15-
from ..exceptions import InvalidMessage, InvalidStatus, SecurityError
15+
from ..exceptions import InvalidMessage, InvalidStatus, ProxyError, SecurityError
1616
from ..extensions.base import ClientExtensionFactory
1717
from ..extensions.permessage_deflate import enable_client_permessage_deflate
1818
from ..headers import validate_subprotocols
@@ -148,7 +148,9 @@ def process_exception(exc: Exception) -> Exception | None:
148148
That exception will be raised, breaking out of the retry loop.
149149
150150
"""
151-
if isinstance(exc, (OSError, asyncio.TimeoutError)):
151+
# This catches socks_proxy's ProxyConnectionError and ProxyTimeoutError.
152+
# Remove asyncio.TimeoutError when dropping Python < 3.11.
153+
if isinstance(exc, (OSError, TimeoutError, asyncio.TimeoutError)):
152154
return None
153155
if isinstance(exc, InvalidMessage) and isinstance(exc.__cause__, EOFError):
154156
return None
@@ -266,6 +268,7 @@ class connect:
266268
267269
Raises:
268270
InvalidURI: If ``uri`` isn't a valid WebSocket URI.
271+
InvalidProxy: If ``proxy`` isn't a valid proxy.
269272
OSError: If the TCP connection fails.
270273
InvalidHandshake: If the opening handshake fails.
271274
TimeoutError: If the opening handshake times out.
@@ -622,7 +625,15 @@ async def connect_socks_proxy(
622625
proxy.password,
623626
SOCKS_PROXY_RDNS[proxy.scheme],
624627
)
625-
return await socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs)
628+
# connect() is documented to raise OSError.
629+
# socks_proxy.connect() doesn't raise TimeoutError; it gets canceled.
630+
# Wrap other exceptions in ProxyError, a subclass of InvalidHandshake.
631+
try:
632+
return await socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs)
633+
except OSError:
634+
raise
635+
except Exception as exc:
636+
raise ProxyError("failed to connect to SOCKS proxy") from exc
626637

627638
except ImportError:
628639

src/websockets/exceptions.py

+25-16
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99
* :exc:`InvalidProxy`
1010
* :exc:`InvalidHandshake`
1111
* :exc:`SecurityError`
12+
* :exc:`ProxyError`
13+
* :exc:`InvalidProxyMessage`
14+
* :exc:`InvalidProxyStatus`
1215
* :exc:`InvalidMessage`
1316
* :exc:`InvalidStatus`
1417
* :exc:`InvalidStatusCode` (legacy)
15-
* :exc:`InvalidProxyMessage`
16-
* :exc:`InvalidProxyStatus`
1718
* :exc:`InvalidHeader`
1819
* :exc:`InvalidHeaderFormat`
1920
* :exc:`InvalidHeaderValue`
@@ -48,10 +49,11 @@
4849
"InvalidProxy",
4950
"InvalidHandshake",
5051
"SecurityError",
51-
"InvalidMessage",
52-
"InvalidStatus",
52+
"ProxyError",
5353
"InvalidProxyMessage",
5454
"InvalidProxyStatus",
55+
"InvalidMessage",
56+
"InvalidStatus",
5557
"InvalidHeader",
5658
"InvalidHeaderFormat",
5759
"InvalidHeaderValue",
@@ -206,46 +208,53 @@ class SecurityError(InvalidHandshake):
206208
"""
207209

208210

209-
class InvalidMessage(InvalidHandshake):
211+
class ProxyError(InvalidHandshake):
210212
"""
211-
Raised when a handshake request or response is malformed.
213+
Raised when failing to connect to a proxy.
212214
213215
"""
214216

215217

216-
class InvalidStatus(InvalidHandshake):
218+
class InvalidProxyMessage(ProxyError):
217219
"""
218-
Raised when a handshake response rejects the WebSocket upgrade.
220+
Raised when an HTTP proxy response is malformed.
221+
222+
"""
223+
224+
225+
class InvalidProxyStatus(ProxyError):
226+
"""
227+
Raised when an HTTP proxy rejects the connection.
219228
220229
"""
221230

222231
def __init__(self, response: http11.Response) -> None:
223232
self.response = response
224233

225234
def __str__(self) -> str:
226-
return (
227-
f"server rejected WebSocket connection: HTTP {self.response.status_code:d}"
228-
)
235+
return f"proxy rejected connection: HTTP {self.response.status_code:d}"
229236

230237

231-
class InvalidProxyMessage(InvalidHandshake):
238+
class InvalidMessage(InvalidHandshake):
232239
"""
233-
Raised when a proxy response is malformed.
240+
Raised when a handshake request or response is malformed.
234241
235242
"""
236243

237244

238-
class InvalidProxyStatus(InvalidHandshake):
245+
class InvalidStatus(InvalidHandshake):
239246
"""
240-
Raised when a proxy rejects the connection.
247+
Raised when a handshake response rejects the WebSocket upgrade.
241248
242249
"""
243250

244251
def __init__(self, response: http11.Response) -> None:
245252
self.response = response
246253

247254
def __str__(self) -> str:
248-
return f"proxy rejected connection: HTTP {self.response.status_code:d}"
255+
return (
256+
f"server rejected WebSocket connection: HTTP {self.response.status_code:d}"
257+
)
249258

250259

251260
class InvalidHeader(InvalidHandshake):

src/websockets/sync/client.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from ..client import ClientProtocol
1111
from ..datastructures import HeadersLike
12+
from ..exceptions import ProxyError
1213
from ..extensions.base import ClientExtensionFactory
1314
from ..extensions.permessage_deflate import enable_client_permessage_deflate
1415
from ..headers import validate_subprotocols
@@ -418,7 +419,14 @@ def connect_socks_proxy(
418419
SOCKS_PROXY_RDNS[proxy.scheme],
419420
)
420421
kwargs.setdefault("timeout", deadline.timeout())
421-
return socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs)
422+
# connect() is documented to raise OSError and TimeoutError.
423+
# Wrap other exceptions in ProxyError, a subclass of InvalidHandshake.
424+
try:
425+
return socks_proxy.connect(ws_uri.host, ws_uri.port, **kwargs)
426+
except (OSError, TimeoutError, socket.timeout):
427+
raise
428+
except Exception as exc:
429+
raise ProxyError("failed to connect to SOCKS proxy") from exc
422430

423431
except ImportError:
424432

tests/asyncio/test_client.py

+68-25
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
InvalidProxy,
1818
InvalidStatus,
1919
InvalidURI,
20+
ProxyError,
2021
SecurityError,
2122
)
2223
from websockets.extensions.permessage_deflate import PerMessageDeflate
@@ -379,24 +380,16 @@ def remove_accept_header(self, request, response):
379380

380381
async def test_timeout_during_handshake(self):
381382
"""Client times out before receiving handshake response from server."""
382-
gate = asyncio.get_running_loop().create_future()
383-
384-
async def stall_connection(self, request):
385-
await gate
386-
387-
# The connection will be open for the server but failed for the client.
388-
# Use a connection handler that exits immediately to avoid an exception.
389-
async with serve(*args, process_request=stall_connection) as server:
390-
try:
391-
with self.assertRaises(TimeoutError) as raised:
392-
async with connect(get_uri(server) + "/no-op", open_timeout=2 * MS):
393-
self.fail("did not raise")
394-
self.assertEqual(
395-
str(raised.exception),
396-
"timed out during handshake",
397-
)
398-
finally:
399-
gate.set_result(None)
383+
# Replace the WebSocket server with a TCP server that does't respond.
384+
with socket.create_server(("localhost", 0)) as sock:
385+
host, port = sock.getsockname()
386+
with self.assertRaises(TimeoutError) as raised:
387+
async with connect(f"ws://{host}:{port}", open_timeout=MS):
388+
self.fail("did not raise")
389+
self.assertEqual(
390+
str(raised.exception),
391+
"timed out during handshake",
392+
)
400393

401394
async def test_connection_closed_during_handshake(self):
402395
"""Client reads EOF before receiving handshake response from server."""
@@ -570,11 +563,13 @@ class ProxyClientTests(unittest.IsolatedAsyncioTestCase):
570563
async def socks_proxy(self, auth=None):
571564
if auth:
572565
proxyauth = "hello:iloveyou"
573-
proxy_uri = "http://hello:iloveyou@localhost:1080"
566+
proxy_uri = "http://hello:iloveyou@localhost:51080"
574567
else:
575568
proxyauth = None
576-
proxy_uri = "http://localhost:1080"
577-
async with async_proxy(mode=["socks5"], proxyauth=proxyauth) as record_flows:
569+
proxy_uri = "http://localhost:51080"
570+
async with async_proxy(
571+
mode=["socks5@51080"], proxyauth=proxyauth
572+
) as record_flows:
578573
with patch_environ({"socks_proxy": proxy_uri}):
579574
yield record_flows
580575

@@ -602,14 +597,62 @@ async def test_authenticated_socks_proxy(self):
602597
self.assertEqual(client.protocol.state.name, "OPEN")
603598
self.assertEqual(len(proxy.get_flows()), 1)
604599

600+
async def test_socks_proxy_connection_error(self):
601+
"""Client receives an error when connecting to the SOCKS5 proxy."""
602+
from python_socks import ProxyError as SocksProxyError
603+
604+
async with self.socks_proxy(auth=True) as proxy:
605+
with self.assertRaises(ProxyError) as raised:
606+
async with connect(
607+
"ws://example.com/",
608+
proxy="socks5h://localhost:51080", # remove credentials
609+
):
610+
self.fail("did not raise")
611+
self.assertEqual(
612+
str(raised.exception),
613+
"failed to connect to SOCKS proxy",
614+
)
615+
self.assertIsInstance(raised.exception.__cause__, SocksProxyError)
616+
self.assertEqual(len(proxy.get_flows()), 0)
617+
618+
async def test_socks_proxy_connection_fails(self):
619+
"""Client fails to connect to the SOCKS5 proxy."""
620+
from python_socks import ProxyConnectionError as SocksProxyConnectionError
621+
622+
with self.assertRaises(OSError) as raised:
623+
async with connect(
624+
"ws://example.com/",
625+
proxy="socks5h://localhost:51080", # nothing at this address
626+
):
627+
self.fail("did not raise")
628+
# Don't test str(raised.exception) because we don't control it.
629+
self.assertIsInstance(raised.exception, SocksProxyConnectionError)
630+
631+
async def test_socks_proxy_connection_timeout(self):
632+
"""Client times out while connecting to the SOCKS5 proxy."""
633+
# Replace the proxy with a TCP server that does't respond.
634+
with socket.create_server(("localhost", 0)) as sock:
635+
host, port = sock.getsockname()
636+
with self.assertRaises(TimeoutError) as raised:
637+
async with connect(
638+
"ws://example.com/",
639+
proxy=f"socks5h://{host}:{port}/",
640+
open_timeout=MS,
641+
):
642+
self.fail("did not raise")
643+
self.assertEqual(
644+
str(raised.exception),
645+
"timed out during handshake",
646+
)
647+
605648
async def test_explicit_proxy(self):
606649
"""Client connects to server through a proxy set explicitly."""
607-
async with async_proxy(mode=["socks5"]) as proxy:
650+
async with async_proxy(mode=["socks5@51080"]) as proxy:
608651
async with serve(*args) as server:
609652
async with connect(
610653
get_uri(server),
611654
# Take this opportunity to test socks5 instead of socks5h.
612-
proxy="socks5://localhost:1080",
655+
proxy="socks5://localhost:51080",
613656
) as client:
614657
self.assertEqual(client.protocol.state.name, "OPEN")
615658
self.assertEqual(len(proxy.get_flows()), 1)
@@ -626,13 +669,13 @@ async def test_ignore_proxy_with_existing_socket(self):
626669

627670
async def test_unsupported_proxy(self):
628671
"""Client connects to server through an unsupported proxy."""
629-
with patch_environ({"ws_proxy": "other://localhost:1080"}):
672+
with patch_environ({"ws_proxy": "other://localhost:51080"}):
630673
with self.assertRaises(InvalidProxy) as raised:
631674
async with connect("ws://example.com/"):
632675
self.fail("did not raise")
633676
self.assertEqual(
634677
str(raised.exception),
635-
"other://localhost:1080 isn't a valid proxy: scheme other isn't supported",
678+
"other://localhost:51080 isn't a valid proxy: scheme other isn't supported",
636679
)
637680

638681

tests/asyncio/test_server.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ async def test_connection_handler_raises_exception(self):
6565
async def test_existing_socket(self):
6666
"""Server receives connection using a pre-existing socket."""
6767
with socket.create_server(("localhost", 0)) as sock:
68-
async with serve(handler, sock=sock, host=None, port=None):
69-
uri = "ws://{}:{}/".format(*sock.getsockname())
70-
async with connect(uri) as client:
68+
host, port = sock.getsockname()
69+
async with serve(handler, sock=sock):
70+
async with connect(f"ws://{host}:{port}/") as client:
7171
await self.assertEval(client, "ws.protocol.state.name", "OPEN")
7272

7373
async def test_select_subprotocol(self):

0 commit comments

Comments
 (0)