Skip to content

Commit 10175f7

Browse files
committed
Refactor SOCKS proxy implementation.
1 parent 4a89e56 commit 10175f7

File tree

2 files changed

+141
-95
lines changed

2 files changed

+141
-95
lines changed

src/websockets/asyncio/client.py

+65-44
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import asyncio
44
import logging
55
import os
6+
import socket
67
import traceback
78
import urllib.parse
89
from collections.abc import AsyncIterator, Generator, Sequence
@@ -357,15 +358,12 @@ async def create_connection(self) -> ClientConnection:
357358
ws_uri = parse_uri(self.uri)
358359

359360
proxy = self.proxy
360-
proxy_uri: Proxy | None = None
361361
if kwargs.get("unix", False):
362362
proxy = None
363363
if kwargs.get("sock") is not None:
364364
proxy = None
365365
if proxy is True:
366366
proxy = get_proxy(ws_uri)
367-
if proxy is not None:
368-
proxy_uri = parse_proxy(proxy)
369367

370368
def factory() -> ClientConnection:
371369
return self.protocol_factory(ws_uri)
@@ -381,48 +379,14 @@ def factory() -> ClientConnection:
381379

382380
if kwargs.pop("unix", False):
383381
_, connection = await loop.create_unix_connection(factory, **kwargs)
382+
elif proxy is not None:
383+
kwargs["sock"] = await connect_proxy(
384+
parse_proxy(proxy),
385+
ws_uri,
386+
local_addr=kwargs.pop("local_addr", None),
387+
)
388+
_, connection = await loop.create_connection(factory, **kwargs)
384389
else:
385-
if proxy_uri is not None:
386-
if proxy_uri.scheme[:5] == "socks":
387-
try:
388-
from python_socks import ProxyType
389-
from python_socks.async_.asyncio import Proxy
390-
except ImportError:
391-
raise ImportError(
392-
"python-socks is required to use a SOCKS proxy"
393-
)
394-
if proxy_uri.scheme == "socks5h":
395-
proxy_type = ProxyType.SOCKS5
396-
rdns = True
397-
elif proxy_uri.scheme == "socks5":
398-
proxy_type = ProxyType.SOCKS5
399-
rdns = False
400-
# We use mitmproxy for testing and it doesn't support SOCKS4.
401-
elif proxy_uri.scheme == "socks4a": # pragma: no cover
402-
proxy_type = ProxyType.SOCKS4
403-
rdns = True
404-
elif proxy_uri.scheme == "socks4": # pragma: no cover
405-
proxy_type = ProxyType.SOCKS4
406-
rdns = False
407-
# Proxy types are enforced in parse_proxy().
408-
else:
409-
raise AssertionError("unsupported SOCKS proxy")
410-
socks_proxy = Proxy(
411-
proxy_type,
412-
proxy_uri.host,
413-
proxy_uri.port,
414-
proxy_uri.username,
415-
proxy_uri.password,
416-
rdns,
417-
)
418-
kwargs["sock"] = await socks_proxy.connect(
419-
ws_uri.host,
420-
ws_uri.port,
421-
local_addr=kwargs.pop("local_addr", None),
422-
)
423-
# Proxy types are enforced in parse_proxy().
424-
else:
425-
raise AssertionError("unsupported proxy")
426390
if kwargs.get("sock") is None:
427391
kwargs.setdefault("host", ws_uri.host)
428392
kwargs.setdefault("port", ws_uri.port)
@@ -624,3 +588,60 @@ def unix_connect(
624588
else:
625589
uri = "wss://localhost/"
626590
return connect(uri=uri, unix=True, path=path, **kwargs)
591+
592+
593+
try:
    from python_socks import ProxyType
    from python_socks.async_.asyncio import Proxy as SocksProxy

except ImportError:

    async def connect_socks_proxy(
        proxy: Proxy,
        ws_uri: WebSocketURI,
        **kwargs: Any,
    ) -> socket.socket:
        """Stand-in used when python-socks cannot be imported.

        Raises:
            ImportError: always; SOCKS support needs python-socks.

        """
        raise ImportError("python-socks is required to use a SOCKS proxy")

else:
    # Proxy URI scheme -> protocol version passed to python-socks.
    SOCKS_PROXY_TYPES = {
        "socks5h": ProxyType.SOCKS5,
        "socks5": ProxyType.SOCKS5,
        "socks4a": ProxyType.SOCKS4,
        "socks4": ProxyType.SOCKS4,
    }

    # Proxy URI scheme -> value of the rdns flag passed to python-socks.
    SOCKS_PROXY_RDNS = {
        "socks5h": True,
        "socks5": False,
        "socks4a": True,
        "socks4": False,
    }

    async def connect_socks_proxy(
        proxy: Proxy,
        ws_uri: WebSocketURI,
        **kwargs: Any,
    ) -> socket.socket:
        """Open a connection to ``ws_uri`` through a SOCKS proxy.

        Args:
            proxy: Proxy description; its scheme must be a key of
                ``SOCKS_PROXY_TYPES``.
            ws_uri: WebSocket URI providing the target host and port.
            **kwargs: Forwarded unchanged to the python-socks ``connect()``.

        Returns:
            The socket returned by python-socks.

        """
        scheme = proxy.scheme
        tunnel = SocksProxy(
            SOCKS_PROXY_TYPES[scheme],
            proxy.host,
            proxy.port,
            proxy.username,
            proxy.password,
            SOCKS_PROXY_RDNS[scheme],
        )
        sock = await tunnel.connect(ws_uri.host, ws_uri.port, **kwargs)
        return sock
635+
636+
637+
async def connect_proxy(
    proxy: Proxy,
    ws_uri: WebSocketURI,
    **kwargs: Any,
) -> socket.socket:
    """Open a connection to ``ws_uri`` through ``proxy`` and return the socket.

    Only proxies whose scheme starts with ``"socks"`` are handled; they are
    delegated to ``connect_socks_proxy()`` together with ``**kwargs``.

    Raises:
        AssertionError: if the proxy scheme is not a SOCKS scheme.

    """
    # parse_proxy() validates proxy.scheme, so any other scheme reaching
    # this point is treated as a programming error.
    if not proxy.scheme.startswith("socks"):
        raise AssertionError("unsupported proxy")
    return await connect_socks_proxy(proxy, ws_uri, **kwargs)

src/websockets/sync/client.py

+76-51
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ..http11 import USER_AGENT, Response
1616
from ..protocol import CONNECTING, Event
1717
from ..typing import LoggerLike, Origin, Subprotocol
18-
from ..uri import Proxy, get_proxy, parse_proxy, parse_uri
18+
from ..uri import Proxy, WebSocketURI, get_proxy, parse_proxy, parse_uri
1919
from .connection import Connection
2020
from .utils import Deadline
2121

@@ -258,15 +258,12 @@ def connect(
258258
elif compression is not None:
259259
raise ValueError(f"unsupported compression: {compression}")
260260

261-
proxy_uri: Proxy | None = None
262261
if unix:
263262
proxy = None
264263
if sock is not None:
265264
proxy = None
266265
if proxy is True:
267266
proxy = get_proxy(ws_uri)
268-
if proxy is not None:
269-
proxy_uri = parse_proxy(proxy)
270267

271268
# Calculate timeouts on the TCP, TLS, and WebSocket handshakes.
272269
# The TCP and TLS timeouts must be set on the socket, then removed
@@ -285,54 +282,21 @@ def connect(
285282
sock.settimeout(deadline.timeout())
286283
assert path is not None # mypy cannot figure this out
287284
sock.connect(path)
285+
elif proxy is not None:
286+
sock = connect_proxy(
287+
parse_proxy(proxy),
288+
ws_uri,
289+
deadline,
290+
# websockets is consistent with the socket module while
291+
# python_socks is consistent across implementations.
292+
local_addr=kwargs.pop("source_address", None),
293+
)
288294
else:
289-
if proxy_uri is not None:
290-
if proxy_uri.scheme[:5] == "socks":
291-
try:
292-
from python_socks import ProxyType
293-
from python_socks.sync import Proxy
294-
except ImportError:
295-
raise ImportError(
296-
"python-socks is required to use a SOCKS proxy"
297-
)
298-
if proxy_uri.scheme == "socks5h":
299-
proxy_type = ProxyType.SOCKS5
300-
rdns = True
301-
elif proxy_uri.scheme == "socks5":
302-
proxy_type = ProxyType.SOCKS5
303-
rdns = False
304-
# We use mitmproxy for testing and it doesn't support SOCKS4.
305-
elif proxy_uri.scheme == "socks4a": # pragma: no cover
306-
proxy_type = ProxyType.SOCKS4
307-
rdns = True
308-
elif proxy_uri.scheme == "socks4": # pragma: no cover
309-
proxy_type = ProxyType.SOCKS4
310-
rdns = False
311-
# Proxy types are enforced in parse_proxy().
312-
else:
313-
raise AssertionError("unsupported SOCKS proxy")
314-
socks_proxy = Proxy(
315-
proxy_type,
316-
proxy_uri.host,
317-
proxy_uri.port,
318-
proxy_uri.username,
319-
proxy_uri.password,
320-
rdns,
321-
)
322-
sock = socks_proxy.connect(
323-
ws_uri.host,
324-
ws_uri.port,
325-
timeout=deadline.timeout(),
326-
local_addr=kwargs.pop("local_addr", None),
327-
)
328-
# Proxy types are enforced in parse_proxy().
329-
else:
330-
raise AssertionError("unsupported proxy")
331-
else:
332-
kwargs.setdefault("timeout", deadline.timeout())
333-
sock = socket.create_connection(
334-
(ws_uri.host, ws_uri.port), **kwargs
335-
)
295+
kwargs.setdefault("timeout", deadline.timeout())
296+
sock = socket.create_connection(
297+
(ws_uri.host, ws_uri.port),
298+
**kwargs,
299+
)
336300
sock.settimeout(None)
337301

338302
# Disable Nagle algorithm
@@ -420,3 +384,64 @@ def unix_connect(
420384
else:
421385
uri = "wss://localhost/"
422386
return connect(uri=uri, unix=True, path=path, **kwargs)
387+
388+
389+
try:
    from python_socks import ProxyType
    from python_socks.sync import Proxy as SocksProxy

except ImportError:

    def connect_socks_proxy(
        proxy: Proxy,
        ws_uri: WebSocketURI,
        deadline: Deadline,
        **kwargs: Any,
    ) -> socket.socket:
        """Stand-in used when python-socks cannot be imported.

        Raises:
            ImportError: always; SOCKS support needs python-socks.

        """
        raise ImportError("python-socks is required to use a SOCKS proxy")

else:
    # Proxy URI scheme -> protocol version passed to python-socks.
    SOCKS_PROXY_TYPES = {
        "socks5h": ProxyType.SOCKS5,
        "socks5": ProxyType.SOCKS5,
        "socks4a": ProxyType.SOCKS4,
        "socks4": ProxyType.SOCKS4,
    }

    # Proxy URI scheme -> value of the rdns flag passed to python-socks.
    SOCKS_PROXY_RDNS = {
        "socks5h": True,
        "socks5": False,
        "socks4a": True,
        "socks4": False,
    }

    def connect_socks_proxy(
        proxy: Proxy,
        ws_uri: WebSocketURI,
        deadline: Deadline,
        **kwargs: Any,
    ) -> socket.socket:
        """Open a connection to ``ws_uri`` through a SOCKS proxy.

        Args:
            proxy: Proxy description; its scheme must be a key of
                ``SOCKS_PROXY_TYPES``.
            ws_uri: WebSocket URI providing the target host and port.
            deadline: Source of the default ``timeout`` value.
            **kwargs: Forwarded to the python-socks ``connect()``; an
                explicit ``timeout`` takes precedence over the deadline.

        Returns:
            The socket returned by python-socks.

        """
        scheme = proxy.scheme
        tunnel = SocksProxy(
            SOCKS_PROXY_TYPES[scheme],
            proxy.host,
            proxy.port,
            proxy.username,
            proxy.password,
            SOCKS_PROXY_RDNS[scheme],
        )
        # Only fills in "timeout" when the caller did not provide one.
        kwargs.setdefault("timeout", deadline.timeout())
        sock = tunnel.connect(ws_uri.host, ws_uri.port, **kwargs)
        return sock
434+
435+
436+
def connect_proxy(
    proxy: Proxy,
    ws_uri: WebSocketURI,
    deadline: Deadline,
    **kwargs: Any,
) -> socket.socket:
    """Open a connection to ``ws_uri`` through ``proxy`` and return the socket.

    Only proxies whose scheme starts with ``"socks"`` are handled; they are
    delegated to ``connect_socks_proxy()`` together with ``deadline`` and
    ``**kwargs``.

    Raises:
        AssertionError: if the proxy scheme is not a SOCKS scheme.

    """
    # parse_proxy() validates proxy.scheme, so any other scheme reaching
    # this point is treated as a programming error.
    if not proxy.scheme.startswith("socks"):
        raise AssertionError("unsupported proxy")
    return connect_socks_proxy(proxy, ws_uri, deadline, **kwargs)

0 commit comments

Comments
 (0)