diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index bfbfa793f..7bb94b349 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -35,12 +35,12 @@ notice. Backwards-incompatible changes .............................. -.. admonition:: Client connections use SOCKS proxies automatically. +.. admonition:: Client connections use SOCKS and HTTP proxies automatically. :class: important If a proxy is configured in the operating system or with an environment variable, websockets uses it automatically when connecting to a server. - This feature requires installing the third-party library `python-socks`_. + SOCKS proxies require installing the third-party library `python-socks`_. If you want to disable the proxy, add ``proxy=None`` when calling :func:`~asyncio.client.connect`. See :doc:`../topics/proxies` for details. diff --git a/docs/reference/features.rst b/docs/reference/features.rst index eaecd02a9..93b083d20 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -166,12 +166,11 @@ Client | Perform HTTP Digest Authentication | ❌ | ❌ | ❌ | ❌ | | (`#784`_) | | | | | +------------------------------------+--------+--------+--------+--------+ - | Connect via HTTP proxy (`#364`_) | ❌ | ❌ | — | ❌ | + | Connect via HTTP proxy | ✅ | ✅ | — | ❌ | +------------------------------------+--------+--------+--------+--------+ | Connect via SOCKS5 proxy | ✅ | ✅ | — | ❌ | +------------------------------------+--------+--------+--------+--------+ -.. _#364: https://github.com/python-websockets/websockets/issues/364 .. _#784: https://github.com/python-websockets/websockets/issues/784 Known limitations diff --git a/docs/topics/proxies.rst b/docs/topics/proxies.rst index fd3ae78b6..14fc68c0c 100644 --- a/docs/topics/proxies.rst +++ b/docs/topics/proxies.rst @@ -30,6 +30,9 @@ most common, for `historical reasons`_, and recommended. .. _historical reasons: https://unix.stackexchange.com/questions/212894/ +websockets authenticates automatically when the address of the proxy includes +credentials e.g. ``http://user:password@proxy:8080/``. + .. admonition:: Any environment variable can configure a SOCKS proxy or an HTTP proxy. :class: tip @@ -64,3 +67,19 @@ SOCKS proxy is configured in the operating system, python-socks uses SOCKS5h. python-socks supports username/password authentication for SOCKS5 (:rfc:`1929`) but does not support other authentication methods such as GSSAPI (:rfc:`1961`). + +HTTP proxies +------------ + +When the address of the proxy starts with ``https://``, websockets secures the +connection to the proxy with TLS. + +When the address of the server starts with ``wss://``, websockets secures the +connection from the proxy to the server with TLS. + +These two options are compatible. TLS-in-TLS is supported. + +The documentation of :func:`~asyncio.client.connect` describes how to configure +TLS from websockets to the proxy and from the proxy to the server. + +websockets supports proxy authentication with Basic Auth. diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 1e560fe0c..c19a53f8c 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -4,20 +4,29 @@ import logging import os import socket +import ssl as ssl_module import traceback import urllib.parse from collections.abc import AsyncIterator, Generator, Sequence from types import TracebackType -from typing import Any, Callable, Literal +from typing import Any, Callable, Literal, cast from ..client import ClientProtocol, backoff -from ..datastructures import HeadersLike -from ..exceptions import InvalidMessage, InvalidStatus, ProxyError, SecurityError +from ..datastructures import Headers, HeadersLike +from ..exceptions import ( + InvalidMessage, + InvalidProxyMessage, + InvalidProxyStatus, + InvalidStatus, + ProxyError, + SecurityError, +) from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate -from ..headers import validate_subprotocols +from ..headers import build_authorization_basic, build_host, validate_subprotocols from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event +from ..streams import StreamReader from ..typing import LoggerLike, Origin, Subprotocol from ..uri import Proxy, WebSocketURI, get_proxy, parse_proxy, parse_uri from .compatibility import TimeoutError, asyncio_timeout @@ -266,6 +275,16 @@ class connect: :meth:`~asyncio.loop.create_connection` method) to create a suitable client socket and customize it. + When using a proxy: + + * Prefix keyword arguments with ``proxy_`` for configuring TLS between the + client and an HTTPS proxy: ``proxy_ssl``, ``proxy_server_hostname``, + ``proxy_ssl_handshake_timeout``, and ``proxy_ssl_shutdown_timeout``. + * Use the standard keyword arguments for configuring TLS between the proxy + and the WebSocket server: ``ssl``, ``server_hostname``, + ``ssl_handshake_timeout``, and ``ssl_shutdown_timeout``. + * Other keyword arguments are used only for connecting to the proxy. + Raises: InvalidURI: If ``uri`` isn't a valid WebSocket URI. InvalidProxy: If ``proxy`` isn't a valid proxy. @@ -383,16 +402,69 @@ def factory() -> ClientConnection: if kwargs.pop("unix", False): _, connection = await loop.create_unix_connection(factory, **kwargs) elif proxy is not None: - kwargs["sock"] = await connect_proxy( - parse_proxy(proxy), - ws_uri, - local_addr=kwargs.pop("local_addr", None), - ) - _, connection = await loop.create_connection(factory, **kwargs) + proxy_parsed = parse_proxy(proxy) + if proxy_parsed.scheme[:5] == "socks": + # Connect to the server through the proxy. + sock = await connect_socks_proxy( + proxy_parsed, + ws_uri, + local_addr=kwargs.pop("local_addr", None), + ) + # Initialize WebSocket connection via the proxy. + _, connection = await loop.create_connection( + factory, + sock=sock, + **kwargs, + ) + elif proxy_parsed.scheme[:4] == "http": + # Split keyword arguments between the proxy and the server. + all_kwargs, proxy_kwargs, kwargs = kwargs, {}, {} + for key, value in all_kwargs.items(): + if key.startswith("ssl") or key == "server_hostname": + kwargs[key] = value + elif key.startswith("proxy_"): + proxy_kwargs[key[6:]] = value + else: + proxy_kwargs[key] = value + # Validate the proxy_ssl argument. + if proxy_parsed.scheme == "https": + proxy_kwargs.setdefault("ssl", True) + if proxy_kwargs.get("ssl") is None: + raise ValueError( + "proxy_ssl=None is incompatible with an https:// proxy" + ) + else: + if proxy_kwargs.get("ssl") is not None: + raise ValueError( + "proxy_ssl argument is incompatible with an http:// proxy" + ) + # Connect to the server through the proxy. + transport = await connect_http_proxy( + proxy_parsed, + ws_uri, + **proxy_kwargs, + ) + # Initialize WebSocket connection via the proxy. + connection = factory() + transport.set_protocol(connection) + ssl = kwargs.pop("ssl", None) + if ssl is True: + ssl = ssl_module.create_default_context() + if ssl is not None: + new_transport = await loop.start_tls( + transport, connection, ssl, **kwargs + ) + assert new_transport is not None # help mypy + transport = new_transport + connection.connection_made(transport) + else: + raise AssertionError("unsupported proxy") else: + # Connect to the server directly. if kwargs.get("sock") is None: kwargs.setdefault("host", ws_uri.host) kwargs.setdefault("port", ws_uri.port) + # Initialize WebSocket connection. _, connection = await loop.create_connection(factory, **kwargs) return connection @@ -499,9 +571,9 @@ async def __await_impl__(self) -> ClientConnection: else: raise SecurityError(f"more than {MAX_REDIRECTS} redirects") - except TimeoutError: + except TimeoutError as exc: # Re-raise exception with an informative error message. - raise TimeoutError("timed out during handshake") from None + raise TimeoutError("timed out during opening handshake") from exc # ... = yield from connect(...) - remove when dropping Python < 3.10 @@ -645,14 +717,87 @@ async def connect_socks_proxy( raise ImportError("python-socks is required to use a SOCKS proxy") -async def connect_proxy( +def prepare_connect_request(proxy: Proxy, ws_uri: WebSocketURI) -> bytes: + host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True) + headers = Headers() + headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure) + if proxy.username is not None: + assert proxy.password is not None # enforced by parse_proxy() + headers["Proxy-Authorization"] = build_authorization_basic( + proxy.username, proxy.password + ) + # We cannot use the Request class because it supports only GET requests. + return f"CONNECT {host} HTTP/1.1\r\n".encode() + headers.serialize() + + +class HTTPProxyConnection(asyncio.Protocol): + def __init__(self, ws_uri: WebSocketURI, proxy: Proxy): + self.ws_uri = ws_uri + self.proxy = proxy + + self.reader = StreamReader() + self.parser = Response.parse( + self.reader.read_line, + self.reader.read_exact, + self.reader.read_to_eof, + include_body=False, + ) + + loop = asyncio.get_running_loop() + self.response: asyncio.Future[Response] = loop.create_future() + + def run_parser(self) -> None: + try: + next(self.parser) + except StopIteration as exc: + response = exc.value + if 200 <= response.status_code < 300: + self.response.set_result(response) + else: + self.response.set_exception(InvalidProxyStatus(response)) + except Exception as exc: + proxy_exc = InvalidProxyMessage( + "did not receive a valid HTTP response from proxy" + ) + proxy_exc.__cause__ = exc + self.response.set_exception(proxy_exc) + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + transport = cast(asyncio.Transport, transport) + self.transport = transport + self.transport.write(prepare_connect_request(self.proxy, self.ws_uri)) + + def data_received(self, data: bytes) -> None: + self.reader.feed_data(data) + self.run_parser() + + def eof_received(self) -> None: + self.reader.feed_eof() + self.run_parser() + + def connection_lost(self, exc: Exception | None) -> None: + self.reader.feed_eof() + if exc is not None: + self.response.set_exception(exc) + + +async def connect_http_proxy( proxy: Proxy, ws_uri: WebSocketURI, **kwargs: Any, -) -> socket.socket: - """Connect via a proxy and return the socket.""" - # parse_proxy() validates proxy.scheme. - if proxy.scheme[:5] == "socks": - return await connect_socks_proxy(proxy, ws_uri, **kwargs) - else: - raise AssertionError("unsupported proxy") +) -> asyncio.Transport: + transport, protocol = await asyncio.get_running_loop().create_connection( + lambda: HTTPProxyConnection(ws_uri, proxy), + proxy.host, + proxy.port, + **kwargs, + ) + + try: + # This raises exceptions if the connection to the proxy fails. + await protocol.response + except Exception: + transport.close() + raise + + return transport diff --git a/src/websockets/headers.py b/src/websockets/headers.py index e05948a1f..c42abd976 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -36,7 +36,13 @@ T = TypeVar("T") -def build_host(host: str, port: int, secure: bool) -> str: +def build_host( + host: str, + port: int, + secure: bool, + *, + always_include_port: bool = False, +) -> str: """ Build a ``Host`` header. @@ -53,7 +59,7 @@ def build_host(host: str, port: int, secure: bool) -> str: if address.version == 6: host = f"[{host}]" - if port != (443 if secure else 80): + if always_include_port or port != (443 if secure else 80): host = f"{host}:{port}" return host diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 49d7b9a41..530ac3d09 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -210,6 +210,7 @@ def parse( read_line: Callable[[int], Generator[None, None, bytes]], read_exact: Callable[[int], Generator[None, None, bytes]], read_to_eof: Callable[[int], Generator[None, None, bytes]], + include_body: bool = True, ) -> Generator[None, None, Response]: """ Parse a WebSocket handshake response. @@ -265,9 +266,12 @@ def parse( headers = yield from parse_headers(read_line) - body = yield from read_body( - status_code, headers, read_line, read_exact, read_to_eof - ) + if include_body: + body = yield from read_body( + status_code, headers, read_line, read_exact, read_to_eof + ) + else: + body = b"" return cls(status_code, reason, headers, body) diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index b7ab83664..c0fe6901a 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -5,16 +5,17 @@ import threading import warnings from collections.abc import Sequence -from typing import Any, Literal +from typing import Any, Callable, Literal, TypeVar, cast from ..client import ClientProtocol -from ..datastructures import HeadersLike -from ..exceptions import ProxyError +from ..datastructures import Headers, HeadersLike +from ..exceptions import InvalidProxyMessage, InvalidProxyStatus, ProxyError from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate -from ..headers import validate_subprotocols +from ..headers import build_authorization_basic, build_host, validate_subprotocols from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event +from ..streams import StreamReader from ..typing import LoggerLike, Origin, Subprotocol from ..uri import Proxy, WebSocketURI, get_proxy, parse_proxy, parse_uri from .connection import Connection @@ -90,7 +91,7 @@ def handshake( self.protocol.send_request(self.request) if not self.response_rcvd.wait(timeout): - raise TimeoutError("timed out during handshake") + raise TimeoutError("timed out while waiting for handshake response") # self.protocol.handshake_exc is set when the connection is lost before # receiving a response, when the response cannot be parsed, or when the @@ -141,6 +142,8 @@ def connect( additional_headers: HeadersLike | None = None, user_agent_header: str | None = USER_AGENT, proxy: str | Literal[True] | None = True, + proxy_ssl: ssl_module.SSLContext | None = None, + proxy_server_hostname: str | None = None, # Timeouts open_timeout: float | None = 10, ping_interval: float | None = 20, @@ -195,6 +198,9 @@ def connect( to :obj:`None` to disable the proxy or to the address of a proxy to override the system configuration. See the :doc:`proxy docs <../../topics/proxies>` for details. + proxy_ssl: Configuration for enabling TLS on the proxy connection. + proxy_server_hostname: Host name for the TLS handshake with the proxy. + ``proxy_server_hostname`` overrides the host name from ``proxy``. open_timeout: Timeout for opening the connection in seconds. :obj:`None` disables the timeout. ping_interval: Interval between keepalive pings in seconds. @@ -284,14 +290,34 @@ def connect( assert path is not None # mypy cannot figure this out sock.connect(path) elif proxy is not None: - sock = connect_proxy( - parse_proxy(proxy), - ws_uri, - deadline, - # websockets is consistent with the socket module while - # python_socks is consistent across implementations. - local_addr=kwargs.pop("source_address", None), - ) + proxy_parsed = parse_proxy(proxy) + if proxy_parsed.scheme[:5] == "socks": + # Connect to the server through the proxy. + sock = connect_socks_proxy( + proxy_parsed, + ws_uri, + deadline, + # websockets is consistent with the socket module while + # python_socks is consistent across implementations. + local_addr=kwargs.pop("source_address", None), + ) + elif proxy_parsed.scheme[:4] == "http": + # Validate the proxy_ssl argument. + if proxy_parsed.scheme != "https" and proxy_ssl is not None: + raise ValueError( + "proxy_ssl argument is incompatible with an http:// proxy" + ) + # Connect to the server through the proxy. + sock = connect_http_proxy( + proxy_parsed, + ws_uri, + deadline, + ssl=proxy_ssl, + server_hostname=proxy_server_hostname, + **kwargs, + ) + else: + raise AssertionError("unsupported proxy") else: kwargs.setdefault("timeout", deadline.timeout()) sock = socket.create_connection( @@ -313,7 +339,12 @@ def connect( if server_hostname is None: server_hostname = ws_uri.host sock.settimeout(deadline.timeout()) - sock = ssl.wrap_socket(sock, server_hostname=server_hostname) + if proxy_ssl is None: + sock = ssl.wrap_socket(sock, server_hostname=server_hostname) + else: + sock_2 = SSLSSLSocket(sock, ssl, server_hostname=server_hostname) + # Let's pretend that sock is a socket, even though it isn't. + sock = cast(socket.socket, sock_2) sock.settimeout(None) # Initialize WebSocket protocol @@ -441,15 +472,169 @@ def connect_socks_proxy( raise ImportError("python-socks is required to use a SOCKS proxy") -def connect_proxy( +def prepare_connect_request(proxy: Proxy, ws_uri: WebSocketURI) -> bytes: + host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True) + headers = Headers() + headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure) + if proxy.username is not None: + assert proxy.password is not None # enforced by parse_proxy() + headers["Proxy-Authorization"] = build_authorization_basic( + proxy.username, proxy.password + ) + # We cannot use the Request class because it supports only GET requests. + return f"CONNECT {host} HTTP/1.1\r\n".encode() + headers.serialize() + + +def read_connect_response(sock: socket.socket, deadline: Deadline) -> Response: + reader = StreamReader() + parser = Response.parse( + reader.read_line, + reader.read_exact, + reader.read_to_eof, + include_body=False, + ) + try: + while True: + sock.settimeout(deadline.timeout()) + data = sock.recv(4096) + if data: + reader.feed_data(data) + else: + reader.feed_eof() + next(parser) + except StopIteration as exc: + assert isinstance(exc.value, Response) # help mypy + response = exc.value + if 200 <= response.status_code < 300: + return response + else: + raise InvalidProxyStatus(response) + except socket.timeout: + raise TimeoutError("timed out while connecting to HTTP proxy") + except Exception as exc: + raise InvalidProxyMessage( + "did not receive a valid HTTP response from proxy" + ) from exc + finally: + sock.settimeout(None) + + +def connect_http_proxy( proxy: Proxy, ws_uri: WebSocketURI, deadline: Deadline, + *, + ssl: ssl_module.SSLContext | None = None, + server_hostname: str | None = None, **kwargs: Any, ) -> socket.socket: - """Connect via a proxy and return the socket.""" - # parse_proxy() validates proxy.scheme. - if proxy.scheme[:5] == "socks": - return connect_socks_proxy(proxy, ws_uri, deadline, **kwargs) - else: - raise AssertionError("unsupported proxy") + # Connect socket + + kwargs.setdefault("timeout", deadline.timeout()) + sock = socket.create_connection((proxy.host, proxy.port), **kwargs) + + # Initialize TLS wrapper and perform TLS handshake + + if proxy.scheme == "https": + if ssl is None: + ssl = ssl_module.create_default_context() + if server_hostname is None: + server_hostname = proxy.host + sock.settimeout(deadline.timeout()) + sock = ssl.wrap_socket(sock, server_hostname=server_hostname) + sock.settimeout(None) + + # Send CONNECT request to the proxy and read response. + + sock.sendall(prepare_connect_request(proxy, ws_uri)) + try: + read_connect_response(sock, deadline) + except Exception: + sock.close() + raise + + return sock + + +T = TypeVar("T") +F = TypeVar("F", bound=Callable[..., T]) + + +class SSLSSLSocket: + """ + Socket-like object providing TLS-in-TLS. + + Only methods that are used by websockets are implemented. + + """ + + recv_bufsize = 65536 + + def __init__( + self, + sock: socket.socket, + ssl_context: ssl_module.SSLContext, + server_hostname: str | None = None, + ) -> None: + self.incoming = ssl_module.MemoryBIO() + self.outgoing = ssl_module.MemoryBIO() + self.ssl_socket = sock + self.ssl_object = ssl_context.wrap_bio( + self.incoming, + self.outgoing, + server_hostname=server_hostname, + ) + self.run_io(self.ssl_object.do_handshake) + + def run_io(self, func: Callable[..., T], *args: Any) -> T: + while True: + want_read = False + want_write = False + try: + result = func(*args) + except ssl_module.SSLWantReadError: + want_read = True + except ssl_module.SSLWantWriteError: # pragma: no cover + want_write = True + + # Write outgoing data in all cases. + data = self.outgoing.read() + if data: + self.ssl_socket.sendall(data) + + # Read incoming data and retry on SSLWantReadError. + if want_read: + data = self.ssl_socket.recv(self.recv_bufsize) + if data: + self.incoming.write(data) + else: + self.incoming.write_eof() + continue + # Retry after writing outgoing data on SSLWantWriteError. + if want_write: # pragma: no cover + continue + # Return result if no error happened. + return result + + def recv(self, buflen: int) -> bytes: + try: + return self.run_io(self.ssl_object.read, buflen) + except ssl_module.SSLEOFError: + return b"" # always ignore ragged EOFs + + def send(self, data: bytes) -> int: + return self.run_io(self.ssl_object.write, data) + + def sendall(self, data: bytes) -> None: + # adapted from ssl_module.SSLSocket.sendall() + count = 0 + with memoryview(data) as view, view.cast("B") as byte_view: + amount = len(byte_view) + while count < amount: + count += self.send(byte_view[count:]) + + # recv_into(), recvfrom(), recvfrom_into(), sendto(), unwrap(), and the + # flags argument aren't implemented because websockets doesn't need them. + + def __getattr__(self, name: str) -> Any: + return getattr(self.ssl_socket, name) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 2b753b2c5..10e3b6816 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -128,7 +128,7 @@ def handshake( """ if not self.request_rcvd.wait(timeout): - raise TimeoutError("timed out during handshake") + raise TimeoutError("timed out while waiting for handshake request") if self.request is not None: with self.send_context(expected_state=CONNECTING): diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 9c7ee46ad..c2a96f3ec 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -2,10 +2,12 @@ import contextlib import http import logging +import os import socket import ssl import sys import unittest +from unittest.mock import patch from websockets.asyncio.client import * from websockets.asyncio.compatibility import TimeoutError @@ -15,6 +17,7 @@ InvalidHandshake, InvalidMessage, InvalidProxy, + InvalidProxyMessage, InvalidStatus, InvalidURI, ProxyError, @@ -22,14 +25,8 @@ ) from websockets.extensions.permessage_deflate import PerMessageDeflate -from ..proxy import async_proxy -from ..utils import ( - CLIENT_CONTEXT, - MS, - SERVER_CONTEXT, - patch_environ, - temp_unix_socket_path, -) +from ..proxy import ProxyMixin +from ..utils import CLIENT_CONTEXT, MS, SERVER_CONTEXT, temp_unix_socket_path from .server import args, get_host_port, get_uri, handler @@ -388,7 +385,7 @@ def remove_accept_header(self, request, response): async def test_timeout_during_handshake(self): """Client times out before receiving handshake response from server.""" - # Replace the WebSocket server with a TCP server that does't respond. + # Replace the WebSocket server with a TCP server that doesn't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() with self.assertRaises(TimeoutError) as raised: @@ -396,7 +393,7 @@ async def test_timeout_during_handshake(self): self.fail("did not raise") self.assertEqual( str(raised.exception), - "timed out during handshake", + "timed out during opening handshake", ) async def test_connection_closed_during_handshake(self): @@ -508,7 +505,7 @@ async def test_reject_invalid_server_certificate(self): """Client rejects certificate where server certificate isn't trusted.""" async with serve(*args, ssl=SERVER_CONTEXT) as server: with self.assertRaises(ssl.SSLCertVerificationError) as raised: - # The test certificate isn't trusted system-wide. + # The test certificate is self-signed. async with connect(get_uri(server)): self.fail("did not raise") self.assertIn( @@ -566,126 +563,276 @@ def redirect(connection, request): @unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") -class ProxyClientTests(unittest.IsolatedAsyncioTestCase): - @contextlib.asynccontextmanager - async def socks_proxy(self, auth=None): - if auth: - proxyauth = "hello:iloveyou" - proxy_uri = "http://hello:iloveyou@localhost:51080" - else: - proxyauth = None - proxy_uri = "http://localhost:51080" - async with async_proxy( - mode=["socks5@51080"], - proxyauth=proxyauth, - ) as record_flows: - with patch_environ({"socks_proxy": proxy_uri}): - yield record_flows +class SocksProxyClientTests(ProxyMixin, unittest.IsolatedAsyncioTestCase): + proxy_mode = "socks5@51080" + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) async def test_socks_proxy(self): """Client connects to server through a SOCKS5 proxy.""" - async with self.socks_proxy() as proxy: - async with serve(*args) as server: - async with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 1) + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) async def test_secure_socks_proxy(self): """Client connects to server securely through a SOCKS5 proxy.""" - async with self.socks_proxy() as proxy: - async with serve(*args, ssl=SERVER_CONTEXT) as server: - async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 1) + async with serve(*args, ssl=SERVER_CONTEXT) as server: + async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + @patch.dict(os.environ, {"socks_proxy": "http://hello:iloveyou@localhost:51080"}) async def test_authenticated_socks_proxy(self): """Client connects to server through an authenticated SOCKS5 proxy.""" - async with self.socks_proxy(auth=True) as proxy: + try: + self.proxy_options.update(proxyauth="hello:iloveyou") async with serve(*args) as server: async with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 1) + finally: + self.proxy_options.update(proxyauth=None) + self.assertNumFlows(1) - async def test_socks_proxy_connection_error(self): - """Client receives an error when connecting to the SOCKS5 proxy.""" + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) + async def test_authenticated_socks_proxy_error(self): + """Client fails to authenticate to the SOCKS5 proxy.""" from python_socks import ProxyError as SocksProxyError - async with self.socks_proxy(auth=True) as proxy: + try: + self.proxy_options.update(proxyauth="any") with self.assertRaises(ProxyError) as raised: - async with connect( - "ws://example.com/", - proxy="socks5h://localhost:51080", # remove credentials - ): + async with connect("ws://example.com/"): self.fail("did not raise") + finally: + self.proxy_options.update(proxyauth=None) self.assertEqual( str(raised.exception), "failed to connect to SOCKS proxy", ) self.assertIsInstance(raised.exception.__cause__, SocksProxyError) - self.assertEqual(len(proxy.get_flows()), 0) + self.assertNumFlows(0) - async def test_socks_proxy_connection_fails(self): + @patch.dict(os.environ, {"socks_proxy": "http://localhost:61080"}) # bad port + async def test_socks_proxy_connection_failure(self): """Client fails to connect to the SOCKS5 proxy.""" from python_socks import ProxyConnectionError as SocksProxyConnectionError with self.assertRaises(OSError) as raised: - async with connect( - "ws://example.com/", - proxy="socks5h://localhost:51080", # nothing at this address - ): + async with connect("ws://example.com/"): self.fail("did not raise") # Don't test str(raised.exception) because we don't control it. self.assertIsInstance(raised.exception, SocksProxyConnectionError) + self.assertNumFlows(0) async def test_socks_proxy_connection_timeout(self): """Client times out while connecting to the SOCKS5 proxy.""" - # Replace the proxy with a TCP server that does't respond. + # Replace the proxy with a TCP server that doesn't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() - with self.assertRaises(TimeoutError) as raised: - async with connect( - "ws://example.com/", - proxy=f"socks5h://{host}:{port}/", - open_timeout=MS, - ): - self.fail("did not raise") + with patch.dict(os.environ, {"socks_proxy": f"http://{host}:{port}"}): + with self.assertRaises(TimeoutError) as raised: + async with connect("ws://example.com/", open_timeout=MS): + self.fail("did not raise") self.assertEqual( str(raised.exception), - "timed out during handshake", + "timed out during opening handshake", ) + self.assertNumFlows(0) - async def test_explicit_proxy(self): - """Client connects to server through a proxy set explicitly.""" - async with async_proxy(mode=["socks5@51080"]) as proxy: - async with serve(*args) as server: - async with connect( - get_uri(server), - # Take this opportunity to test socks5 instead of socks5h. - proxy="socks5://localhost:51080", - ) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 1) + async def test_explicit_socks_proxy(self): + """Client connects to server through a SOCKS5 proxy set explicitly.""" + async with serve(*args) as server: + async with connect( + get_uri(server), + # Take this opportunity to test socks5 instead of socks5h. + proxy="socks5://localhost:51080", + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) async def test_ignore_proxy_with_existing_socket(self): """Client connects using a pre-existing socket.""" - async with self.socks_proxy() as proxy: - async with serve(*args) as server: - with socket.create_connection(get_host_port(server)) as sock: - # Use a non-existing domain to ensure we connect to sock. - async with connect("ws://invalid/", sock=sock) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 0) + async with serve(*args) as server: + with socket.create_connection(get_host_port(server)) as sock: + # Use a non-existing domain to ensure we connect to sock. + async with connect("ws://invalid/", sock=sock) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(0) - async def test_unsupported_proxy(self): - """Client connects to server through an unsupported proxy.""" - with patch_environ({"ws_proxy": "other://localhost:51080"}): - with self.assertRaises(InvalidProxy) as raised: + +@unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") +class HTTPProxyClientTests(ProxyMixin, unittest.IsolatedAsyncioTestCase): + proxy_mode = "regular@58080" + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + async def test_http_proxy(self): + """Client connects to server through an HTTP proxy.""" + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + async def test_secure_http_proxy(self): + """Client connects to server securely through an HTTP proxy.""" + async with serve(*args, ssl=SERVER_CONTEXT) as server: + async with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.version()[:3], "TLS") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "http://hello:iloveyou@localhost:58080"}) + async def test_authenticated_http_proxy(self): + """Client connects to server through an authenticated HTTP proxy.""" + try: + self.proxy_options.update(proxyauth="hello:iloveyou") + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + finally: + self.proxy_options.update(proxyauth=None) + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + async def test_authenticated_http_proxy_error(self): + """Client fails to authenticate to the HTTP proxy.""" + try: + self.proxy_options.update(proxyauth="any") + with self.assertRaises(ProxyError) as raised: async with connect("ws://example.com/"): self.fail("did not raise") + finally: + self.proxy_options.update(proxyauth=None) self.assertEqual( str(raised.exception), - "other://localhost:51080 isn't a valid proxy: scheme other isn't supported", + "proxy rejected connection: HTTP 407", + ) + self.assertNumFlows(0) + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + async def test_http_proxy_protocol_error(self): + """Client receives invalid data when connecting to the HTTP proxy.""" + try: + self.proxy_options.update(break_http_connect=True) + with self.assertRaises(InvalidProxyMessage) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(break_http_connect=False) + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response from proxy", + ) + self.assertNumFlows(0) + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + async def test_http_proxy_connection_error(self): + """Client receives no response when connecting to the HTTP proxy.""" + try: + self.proxy_options.update(close_http_connect=True) + with self.assertRaises(InvalidProxyMessage) as raised: + async with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(close_http_connect=False) + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response from proxy", + ) + self.assertNumFlows(0) + + @patch.dict(os.environ, {"https_proxy": "http://localhost:48080"}) # bad port + async def test_http_proxy_connection_failure(self): + """Client fails to connect to the HTTP proxy.""" + with self.assertRaises(OSError): + async with connect("ws://example.com/"): + self.fail("did not raise") + # Don't test str(raised.exception) because we don't control it. + self.assertNumFlows(0) + + async def test_http_proxy_connection_timeout(self): + """Client times out while connecting to the HTTP proxy.""" + # Replace the proxy with a TCP server that doesn't respond. + with socket.create_server(("localhost", 0)) as sock: + host, port = sock.getsockname() + with patch.dict(os.environ, {"https_proxy": f"http://{host}:{port}"}): + with self.assertRaises(TimeoutError) as raised: + async with connect("ws://example.com/", open_timeout=MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out during opening handshake", + ) + + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) + async def test_https_proxy(self): + """Client connects to server through an HTTPS proxy.""" + async with serve(*args) as server: + async with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) + async def test_secure_https_proxy(self): + """Client connects to server securely through an HTTPS proxy.""" + async with serve(*args, ssl=SERVER_CONTEXT) as server: + async with connect( + get_uri(server), + ssl=CLIENT_CONTEXT, + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.version()[:3], "TLS") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) + async def test_https_server_hostname(self): + """Client sets server_hostname to the value of proxy_server_hostname.""" + async with serve(*args) as server: + # Pass an argument not prefixed with proxy_ for coverage. + kwargs = {"all_errors": True} if sys.version_info >= (3, 12) else {} + async with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + proxy_server_hostname="overridden", + **kwargs, + ) as client: + ssl_object = client.transport.get_extra_info("ssl_object") + self.assertEqual(ssl_object.server_hostname, "overridden") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) + async def test_https_proxy_invalid_proxy_certificate(self): + """Client rejects certificate when proxy certificate isn't trusted.""" + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The proxy certificate isn't trusted. + async with connect("wss://example.com/"): + self.fail("did not raise") + self.assertIn( + "certificate verify failed: unable to get local issuer certificate", + str(raised.exception), + ) + + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) + async def test_https_proxy_invalid_server_certificate(self): + """Client rejects certificate when proxy certificate isn't trusted.""" + async with serve(*args, ssl=SERVER_CONTEXT) as server: + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The test certificate is self-signed. + async with connect(get_uri(server), proxy_ssl=self.proxy_context): + self.fail("did not raise") + self.assertIn( + "certificate verify failed: self signed certificate", + str(raised.exception).replace("-", " "), ) + self.assertNumFlows(1) @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") @@ -724,10 +871,7 @@ def redirect(connection, request): "cannot follow cross-origin redirect to ws://other/ with a Unix socket", ) - -@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") -class SecureUnixClientTests(unittest.IsolatedAsyncioTestCase): - async def test_connection(self): + async def test_secure_connection(self): """Client connects to server securely over a Unix socket.""" with temp_unix_socket_path() as path: async with unix_serve(handler, path, ssl=SERVER_CONTEXT): @@ -761,7 +905,7 @@ async def test_ssl_without_secure_uri(self): ) async def test_secure_uri_without_ssl(self): - """Client rejects no ssl when URI is secure.""" + """Client rejects ssl=None when URI is secure.""" with self.assertRaises(ValueError) as raised: await connect("wss://localhost/", ssl=None) self.assertEqual( @@ -769,6 +913,42 @@ async def test_secure_uri_without_ssl(self): "ssl=None is incompatible with a wss:// URI", ) + async def test_proxy_ssl_without_https_proxy(self): + """Client rejects proxy_ssl when proxy isn't HTTPS.""" + with self.assertRaises(ValueError) as raised: + await connect( + "ws://localhost/", + proxy="http://localhost:8080", + proxy_ssl=True, + ) + self.assertEqual( + str(raised.exception), + "proxy_ssl argument is incompatible with an http:// proxy", + ) + + async def test_https_proxy_without_ssl(self): + """Client rejects proxy_ssl=None when proxy is HTTPS.""" + with self.assertRaises(ValueError) as raised: + await connect( + "ws://localhost/", + proxy="https://localhost:8080", + proxy_ssl=None, + ) + self.assertEqual( + str(raised.exception), + "proxy_ssl=None is incompatible with an https:// proxy", + ) + + async def test_unsupported_proxy(self): + """Client rejects unsupported proxy.""" + with self.assertRaises(InvalidProxy) as raised: + async with connect("ws://example.com/", proxy="other://localhost:51080"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "other://localhost:51080 isn't a valid proxy: scheme other isn't supported", + ) + async def test_unix_without_path_or_sock(self): """Unix client requires path when sock isn't provided.""" with self.assertRaises(ValueError) as raised: diff --git a/tests/proxy.py b/tests/proxy.py index 95525a360..9746e3382 100644 --- a/tests/proxy.py +++ b/tests/proxy.py @@ -1,89 +1,137 @@ import asyncio -import contextlib import pathlib +import ssl import threading import warnings -warnings.filterwarnings("ignore", category=DeprecationWarning, module="mitmproxy") -warnings.filterwarnings("ignore", category=DeprecationWarning, module="passlib") -warnings.filterwarnings("ignore", category=DeprecationWarning, module="pyasn1") - try: + # Ignore deprecation warnings raised by mitmproxy dependencies at import time. + warnings.filterwarnings("ignore", category=DeprecationWarning, module="passlib") + warnings.filterwarnings("ignore", category=DeprecationWarning, module="pyasn1") + + from mitmproxy import ctx from mitmproxy.addons import core, next_layer, proxyauth, proxyserver, tlsconfig + from mitmproxy.http import Response from mitmproxy.master import Master - from mitmproxy.options import Options + from mitmproxy.options import CONF_BASENAME, CONF_DIR, Options except ImportError: pass class RecordFlows: - def __init__(self): - self.ready = asyncio.get_running_loop().create_future() + def __init__(self, on_running): + self.running = on_running self.flows = [] - def running(self): - self.ready.set_result(None) - - def websocket_start(self, flow): + def tcp_start(self, flow): self.flows.append(flow) def get_flows(self): flows, self.flows[:] = self.flows[:], [] return flows + def reset_flows(self): + self.flows = [] + + +class AlterRequest: + def load(self, loader): + loader.add_option( + name="break_http_connect", + typespec=bool, + default=False, + help="Respond to HTTP CONNECT requests with a 999 status code.", + ) + loader.add_option( + name="close_http_connect", + typespec=bool, + default=False, + help="Do not respond to HTTP CONNECT requests.", + ) + + def http_connect(self, flow): + if ctx.options.break_http_connect: + # mitmproxy can send a response with a status code not between 100 + # and 599, while websockets treats it as a protocol error. + # This is used for testing HTTP parsing errors. + flow.response = Response.make(999, "not a valid HTTP response") + if ctx.options.close_http_connect: + flow.kill() + + +class ProxyMixin: + """ + Run mitmproxy in a background thread. + + While it's uncommon to run two event loops in two threads, tests for the + asyncio implementation rely on this class too because it starts an event + loop for mitm proxy once, then a new event loop for each test. + """ + + proxy_mode = None + + @classmethod + async def run_proxy(cls): + cls.proxy_loop = loop = asyncio.get_event_loop() + cls.proxy_stop = stop = loop.create_future() + + cls.proxy_options = options = Options( + mode=[cls.proxy_mode], + # Don't intercept connections, but record them. + ignore_hosts=["^localhost:", "^127.0.0.1:", "^::1:"], + # This option requires mitmproxy 11.0.0, which requires Python 3.11. + show_ignored_hosts=True, + ) + cls.proxy_master = master = Master(options) + master.addons.add( + core.Core(), + proxyauth.ProxyAuth(), + proxyserver.Proxyserver(), + next_layer.NextLayer(), + tlsconfig.TlsConfig(), + RecordFlows(on_running=cls.proxy_ready.set), + AlterRequest(), + ) + + task = loop.create_task(cls.proxy_master.run()) + await stop -@contextlib.asynccontextmanager -async def async_proxy(mode, **config): - options = Options(mode=mode) - master = Master(options) - record_flows = RecordFlows() - master.addons.add( - core.Core(), - proxyauth.ProxyAuth(), - proxyserver.Proxyserver(), - next_layer.NextLayer(), - tlsconfig.TlsConfig(), - record_flows, - ) - config.update( - # Use our test certificate for TLS between client and proxy - # and disable TLS verification between proxy and upstream. - certs=[str(pathlib.Path(__file__).with_name("test_localhost.pem"))], - ssl_insecure=True, - ) - options.update(**config) - - asyncio.create_task(master.run()) - try: - await record_flows.ready - yield record_flows - finally: for server in master.addons.get("proxyserver").servers: await server.stop() master.shutdown() + await task + + @classmethod + def setUpClass(cls): + super().setUpClass() + + # Ignore deprecation warnings raised by mitmproxy at run time. + warnings.filterwarnings( + "ignore", category=DeprecationWarning, module="mitmproxy" + ) + + cls.proxy_ready = threading.Event() + cls.proxy_thread = threading.Thread(target=asyncio.run, args=(cls.run_proxy(),)) + cls.proxy_thread.start() + cls.proxy_ready.wait() + + certificate = pathlib.Path(CONF_DIR) / f"{CONF_BASENAME}-ca-cert.pem" + certificate = certificate.expanduser() + cls.proxy_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + cls.proxy_context.load_verify_locations(bytes(certificate)) + + def assertNumFlows(self, num_flows): + record_flows = self.proxy_master.addons.get("recordflows") + self.assertEqual(len(record_flows.get_flows()), num_flows) + def tearDown(self): + record_flows = self.proxy_master.addons.get("recordflows") + record_flows.reset_flows() + super().tearDown() -@contextlib.contextmanager -def sync_proxy(mode, **config): - loop = None - test_done = None - proxy_ready = threading.Event() - record_flows = None - - async def proxy_coroutine(): - nonlocal loop, test_done, proxy_ready, record_flows - loop = asyncio.get_running_loop() - test_done = loop.create_future() - async with async_proxy(mode, **config) as record_flows: - proxy_ready.set() - await test_done - - proxy_thread = threading.Thread(target=asyncio.run, args=(proxy_coroutine(),)) - proxy_thread.start() - try: - proxy_ready.wait() - yield record_flows - finally: - loop.call_soon_threadsafe(test_done.set_result, None) - proxy_thread.join() + @classmethod + def tearDownClass(cls): + cls.proxy_loop.call_soon_threadsafe(cls.proxy_stop.set_result, None) + cls.proxy_thread.join() + super().tearDownClass() diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 4844d3b5e..e4927bb32 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -1,6 +1,6 @@ -import contextlib import http import logging +import os import socket import socketserver import ssl @@ -8,11 +8,13 @@ import threading import time import unittest +from unittest.mock import patch from websockets.exceptions import ( InvalidHandshake, InvalidMessage, InvalidProxy, + InvalidProxyMessage, InvalidStatus, InvalidURI, ProxyError, @@ -20,13 +22,12 @@ from websockets.extensions.permessage_deflate import PerMessageDeflate from websockets.sync.client import * -from ..proxy import sync_proxy +from ..proxy import ProxyMixin from ..utils import ( CLIENT_CONTEXT, MS, SERVER_CONTEXT, DeprecationTestCase, - patch_environ, temp_unix_socket_path, ) from .server import get_uri, run_server, run_unix_server @@ -157,7 +158,7 @@ def remove_accept_header(self, request, response): def test_timeout_during_handshake(self): """Client times out before receiving handshake response from server.""" - # Replace the WebSocket server with a TCP server that does't respond. + # Replace the WebSocket server with a TCP server that doesn't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() with self.assertRaises(TimeoutError) as raised: @@ -165,7 +166,7 @@ def test_timeout_during_handshake(self): self.fail("did not raise") self.assertEqual( str(raised.exception), - "timed out during handshake", + "timed out while waiting for handshake response", ) def test_connection_closed_during_handshake(self): @@ -283,7 +284,7 @@ def test_reject_invalid_server_certificate(self): """Client rejects certificate where server certificate isn't trusted.""" with run_server(ssl=SERVER_CONTEXT) as server: with self.assertRaises(ssl.SSLCertVerificationError) as raised: - # The test certificate isn't trusted system-wide. + # The test certificate is self-signed. with connect(get_uri(server)): self.fail("did not raise") self.assertIn( @@ -307,127 +308,274 @@ def test_reject_invalid_server_hostname(self): @unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") -class ProxyClientTests(unittest.TestCase): - @contextlib.contextmanager - def socks_proxy(self, auth=None): - if auth: - proxyauth = "hello:iloveyou" - proxy_uri = "http://hello:iloveyou@localhost:51080" - else: - proxyauth = None - proxy_uri = "http://localhost:51080" - - with sync_proxy( - mode=["socks5@51080"], - proxyauth=proxyauth, - ) as record_flows: - with patch_environ({"socks_proxy": proxy_uri}): - yield record_flows +class SocksProxyClientTests(ProxyMixin, unittest.TestCase): + proxy_mode = "socks5@51080" + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) def test_socks_proxy(self): """Client connects to server through a SOCKS5 proxy.""" - with self.socks_proxy() as proxy: - with run_server() as server: - with connect(get_uri(server)) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 1) + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) def test_secure_socks_proxy(self): """Client connects to server securely through a SOCKS5 proxy.""" - with self.socks_proxy() as proxy: - with run_server(ssl=SERVER_CONTEXT) as server: - with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 1) + with run_server(ssl=SERVER_CONTEXT) as server: + with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + @patch.dict(os.environ, {"socks_proxy": "http://hello:iloveyou@localhost:51080"}) def test_authenticated_socks_proxy(self): """Client connects to server through an authenticated SOCKS5 proxy.""" - with self.socks_proxy(auth=True) as proxy: + try: + self.proxy_options.update(proxyauth="hello:iloveyou") with run_server() as server: with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 1) + finally: + self.proxy_options.update(proxyauth=None) + self.assertNumFlows(1) - def test_socks_proxy_connection_error(self): - """Client receives an error when connecting to the SOCKS5 proxy.""" + @patch.dict(os.environ, {"socks_proxy": "http://localhost:51080"}) + def test_authenticated_socks_proxy_error(self): + """Client fails to authenticate to the SOCKS5 proxy.""" from python_socks import ProxyError as SocksProxyError - with self.socks_proxy(auth=True) as proxy: + try: + self.proxy_options.update(proxyauth="any") with self.assertRaises(ProxyError) as raised: - with connect( - "ws://example.com/", - proxy="socks5h://localhost:51080", # remove credentials - ): + with connect("ws://example.com/"): self.fail("did not raise") + finally: + self.proxy_options.update(proxyauth=None) self.assertEqual( str(raised.exception), "failed to connect to SOCKS proxy", ) self.assertIsInstance(raised.exception.__cause__, SocksProxyError) - self.assertEqual(len(proxy.get_flows()), 0) + self.assertNumFlows(0) - def test_socks_proxy_connection_fails(self): + @patch.dict(os.environ, {"socks_proxy": "http://localhost:61080"}) # bad port + def test_socks_proxy_connection_failure(self): """Client fails to connect to the SOCKS5 proxy.""" from python_socks import ProxyConnectionError as SocksProxyConnectionError with self.assertRaises(OSError) as raised: - with connect( - "ws://example.com/", - proxy="socks5h://localhost:51080", # nothing at this address - ): + with connect("ws://example.com/"): self.fail("did not raise") # Don't test str(raised.exception) because we don't control it. self.assertIsInstance(raised.exception, SocksProxyConnectionError) + self.assertNumFlows(0) - def test_socks_proxy_timeout(self): - """Client times out before connecting to the SOCKS5 proxy.""" + def test_socks_proxy_connection_timeout(self): + """Client times out while connecting to the SOCKS5 proxy.""" from python_socks import ProxyTimeoutError as SocksProxyTimeoutError - # Replace the proxy with a TCP server that does't respond. + # Replace the proxy with a TCP server that doesn't respond. with socket.create_server(("localhost", 0)) as sock: host, port = sock.getsockname() - with self.assertRaises(TimeoutError) as raised: - with connect( - "ws://example.com/", - proxy=f"socks5h://{host}:{port}/", - open_timeout=MS, - ): - self.fail("did not raise") + with patch.dict(os.environ, {"socks_proxy": f"http://{host}:{port}"}): + with self.assertRaises(TimeoutError) as raised: + with connect("ws://example.com/", open_timeout=MS): + self.fail("did not raise") # Don't test str(raised.exception) because we don't control it. self.assertIsInstance(raised.exception, SocksProxyTimeoutError) + self.assertNumFlows(0) - def test_explicit_proxy(self): - """Client connects to server through a proxy set explicitly.""" - with sync_proxy(mode=["socks5@51080"]) as proxy: - with run_server() as server: - with connect( - get_uri(server), - # Take this opportunity to test socks5 instead of socks5h. - proxy="socks5://localhost:51080", - ) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 1) + def test_explicit_socks_proxy(self): + """Client connects to server through a SOCKS5 proxy set explicitly.""" + with run_server() as server: + with connect( + get_uri(server), + # Take this opportunity to test socks5 instead of socks5h. + proxy="socks5://localhost:51080", + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + @patch.dict(os.environ, {"ws_proxy": "http://localhost:58080"}) def test_ignore_proxy_with_existing_socket(self): """Client connects using a pre-existing socket.""" - with self.socks_proxy() as proxy: - with run_server() as server: - with socket.create_connection(server.socket.getsockname()) as sock: - # Use a non-existing domain to ensure we connect to sock. - with connect("ws://invalid/", sock=sock) as client: - self.assertEqual(client.protocol.state.name, "OPEN") - self.assertEqual(len(proxy.get_flows()), 0) + with run_server() as server: + with socket.create_connection(server.socket.getsockname()) as sock: + # Use a non-existing domain to ensure we connect to sock. + with connect("ws://invalid/", sock=sock) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(0) - def test_unsupported_proxy(self): - """Client connects to server through an unsupported proxy.""" - with patch_environ({"ws_proxy": "other://localhost:51080"}): - with self.assertRaises(InvalidProxy) as raised: + +@unittest.skipUnless("mitmproxy" in sys.modules, "mitmproxy not installed") +class HTTPProxyClientTests(ProxyMixin, unittest.IsolatedAsyncioTestCase): + proxy_mode = "regular@58080" + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + def test_http_proxy(self): + """Client connects to server through an HTTP proxy.""" + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + def test_secure_http_proxy(self): + """Client connects to server securely through an HTTP proxy.""" + with run_server(ssl=SERVER_CONTEXT) as server: + with connect(get_uri(server), ssl=CLIENT_CONTEXT) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(client.socket.version()[:3], "TLS") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "http://hello:iloveyou@localhost:58080"}) + def test_authenticated_http_proxy(self): + """Client connects to server through an authenticated HTTP proxy.""" + try: + self.proxy_options.update(proxyauth="hello:iloveyou") + with run_server() as server: + with connect(get_uri(server)) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + finally: + self.proxy_options.update(proxyauth=None) + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + def test_authenticated_http_proxy_error(self): + """Client fails to authenticate to the HTTP proxy.""" + try: + self.proxy_options.update(proxyauth="any") + with self.assertRaises(ProxyError) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(proxyauth=None) + self.assertEqual( + str(raised.exception), + "proxy rejected connection: HTTP 407", + ) + self.assertNumFlows(0) + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + def test_http_proxy_protocol_error(self): + """Client receives invalid data when connecting to the HTTP proxy.""" + try: + self.proxy_options.update(break_http_connect=True) + with self.assertRaises(InvalidProxyMessage) as raised: with connect("ws://example.com/"): self.fail("did not raise") + finally: + self.proxy_options.update(break_http_connect=False) self.assertEqual( str(raised.exception), - "other://localhost:51080 isn't a valid proxy: scheme other isn't supported", + "did not receive a valid HTTP response from proxy", ) + self.assertNumFlows(0) + + @patch.dict(os.environ, {"https_proxy": "http://localhost:58080"}) + def test_http_proxy_connection_error(self): + """Client receives no response when connecting to the HTTP proxy.""" + try: + self.proxy_options.update(close_http_connect=True) + with self.assertRaises(InvalidProxyMessage) as raised: + with connect("ws://example.com/"): + self.fail("did not raise") + finally: + self.proxy_options.update(close_http_connect=False) + self.assertEqual( + str(raised.exception), + "did not receive a valid HTTP response from proxy", + ) + self.assertNumFlows(0) + + @patch.dict(os.environ, {"https_proxy": "http://localhost:48080"}) # bad port + def test_http_proxy_connection_failure(self): + """Client fails to connect to the HTTP proxy.""" + with self.assertRaises(OSError): + with connect("ws://example.com/"): + self.fail("did not raise") + # Don't test str(raised.exception) because we don't control it. + self.assertNumFlows(0) + + def test_http_proxy_connection_timeout(self): + """Client times out while connecting to the HTTP proxy.""" + # Replace the proxy with a TCP server that does't respond. + with socket.create_server(("localhost", 0)) as sock: + host, port = sock.getsockname() + with patch.dict(os.environ, {"https_proxy": f"http://{host}:{port}"}): + with self.assertRaises(TimeoutError) as raised: + with connect("ws://example.com/", open_timeout=MS): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "timed out while connecting to HTTP proxy", + ) + + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) + def test_https_proxy(self): + """Client connects to server through an HTTPS proxy.""" + with run_server() as server: + with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) + def test_secure_https_proxy(self): + """Client connects to server securely through an HTTPS proxy.""" + with run_server(ssl=SERVER_CONTEXT) as server: + with connect( + get_uri(server), + ssl=CLIENT_CONTEXT, + proxy_ssl=self.proxy_context, + ) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + self.assertEqual(client.socket.version()[:3], "TLS") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) + def test_https_proxy_server_hostname(self): + """Client sets server_hostname to the value of proxy_server_hostname.""" + with run_server() as server: + # Pass an argument not prefixed with proxy_ for coverage. + kwargs = {"all_errors": True} if sys.version_info >= (3, 11) else {} + with connect( + get_uri(server), + proxy_ssl=self.proxy_context, + proxy_server_hostname="overridden", + **kwargs, + ) as client: + self.assertEqual(client.socket.server_hostname, "overridden") + self.assertNumFlows(1) + + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) + def test_https_proxy_invalid_proxy_certificate(self): + """Client rejects certificate when proxy certificate isn't trusted.""" + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The proxy certificate isn't trusted. + with connect("wss://example.com/"): + self.fail("did not raise") + self.assertIn( + "certificate verify failed: unable to get local issuer certificate", + str(raised.exception), + ) + self.assertNumFlows(0) + + @patch.dict(os.environ, {"https_proxy": "https://localhost:58080"}) + def test_https_proxy_invalid_server_certificate(self): + """Client rejects certificate when server certificate isn't trusted.""" + with run_server(ssl=SERVER_CONTEXT) as server: + with self.assertRaises(ssl.SSLCertVerificationError) as raised: + # The test certificate is self-signed. + with connect(get_uri(server), proxy_ssl=self.proxy_context): + self.fail("did not raise") + self.assertIn( + "certificate verify failed: self signed certificate", + str(raised.exception).replace("-", " "), + ) + self.assertNumFlows(1) @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") @@ -447,10 +595,7 @@ def test_set_host_header(self): with unix_connect(path, uri="ws://overridden/") as client: self.assertEqual(client.request.headers["Host"], "overridden") - -@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") -class SecureUnixClientTests(unittest.TestCase): - def test_connection(self): + def test_secure_connection(self): """Client connects to server securely over a Unix socket.""" with temp_unix_socket_path() as path: with run_unix_server(path, ssl=SERVER_CONTEXT): @@ -479,6 +624,19 @@ def test_ssl_without_secure_uri(self): "ssl argument is incompatible with a ws:// URI", ) + def test_proxy_ssl_without_https_proxy(self): + """Client rejects proxy_ssl when proxy isn't HTTPS.""" + with self.assertRaises(ValueError) as raised: + connect( + "ws://localhost/", + proxy="http://localhost:8080", + proxy_ssl=True, + ) + self.assertEqual( + str(raised.exception), + "proxy_ssl argument is incompatible with an http:// proxy", + ) + def test_unix_without_path_or_sock(self): """Unix client requires path when sock isn't provided.""" with self.assertRaises(ValueError) as raised: @@ -488,6 +646,16 @@ def test_unix_without_path_or_sock(self): "missing path argument", ) + def test_unsupported_proxy(self): + """Client rejects unsupported proxy.""" + with self.assertRaises(InvalidProxy) as raised: + with connect("ws://example.com/", proxy="other://localhost:58080"): + self.fail("did not raise") + self.assertEqual( + str(raised.exception), + "other://localhost:58080 isn't a valid proxy: scheme other isn't supported", + ) + def test_unix_with_path_and_sock(self): """Unix client rejects path when sock is provided.""" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) diff --git a/tests/test_headers.py b/tests/test_headers.py index 4ebd8b90c..816afc541 100644 --- a/tests/test_headers.py +++ b/tests/test_headers.py @@ -6,26 +6,33 @@ class HeadersTests(unittest.TestCase): def test_build_host(self): - for (host, port, secure), result in [ - (("localhost", 80, False), "localhost"), - (("localhost", 8000, False), "localhost:8000"), - (("localhost", 443, True), "localhost"), - (("localhost", 8443, True), "localhost:8443"), - (("example.com", 80, False), "example.com"), - (("example.com", 8000, False), "example.com:8000"), - (("example.com", 443, True), "example.com"), - (("example.com", 8443, True), "example.com:8443"), - (("127.0.0.1", 80, False), "127.0.0.1"), - (("127.0.0.1", 8000, False), "127.0.0.1:8000"), - (("127.0.0.1", 443, True), "127.0.0.1"), - (("127.0.0.1", 8443, True), "127.0.0.1:8443"), - (("::1", 80, False), "[::1]"), - (("::1", 8000, False), "[::1]:8000"), - (("::1", 443, True), "[::1]"), - (("::1", 8443, True), "[::1]:8443"), + for (host, port, secure), (result, result_with_port) in [ + (("localhost", 80, False), ("localhost", "localhost:80")), + (("localhost", 8000, False), ("localhost:8000", "localhost:8000")), + (("localhost", 443, True), ("localhost", "localhost:443")), + (("localhost", 8443, True), ("localhost:8443", "localhost:8443")), + (("example.com", 80, False), ("example.com", "example.com:80")), + (("example.com", 8000, False), ("example.com:8000", "example.com:8000")), + (("example.com", 443, True), ("example.com", "example.com:443")), + (("example.com", 8443, True), ("example.com:8443", "example.com:8443")), + (("127.0.0.1", 80, False), ("127.0.0.1", "127.0.0.1:80")), + (("127.0.0.1", 8000, False), ("127.0.0.1:8000", "127.0.0.1:8000")), + (("127.0.0.1", 443, True), ("127.0.0.1", "127.0.0.1:443")), + (("127.0.0.1", 8443, True), ("127.0.0.1:8443", "127.0.0.1:8443")), + (("::1", 80, False), ("[::1]", "[::1]:80")), + (("::1", 8000, False), ("[::1]:8000", "[::1]:8000")), + (("::1", 443, True), ("[::1]", "[::1]:443")), + (("::1", 8443, True), ("[::1]:8443", "[::1]:8443")), ]: with self.subTest(host=host, port=port, secure=secure): - self.assertEqual(build_host(host, port, secure), result) + self.assertEqual( + build_host(host, port, secure), + result, + ) + self.assertEqual( + build_host(host, port, secure, always_include_port=True), + result_with_port, + ) def test_parse_connection(self): for header, parsed in [ diff --git a/tests/test_http11.py b/tests/test_http11.py index bb0d27b95..3afb6d02c 100644 --- a/tests/test_http11.py +++ b/tests/test_http11.py @@ -130,11 +130,12 @@ def setUp(self): super().setUp() self.reader = StreamReader() - def parse(self): + def parse(self, **kwargs): return Response.parse( self.reader.read_line, self.reader.read_exact, self.reader.read_to_eof, + **kwargs, ) def test_parse(self): @@ -322,6 +323,11 @@ def test_parse_body_not_modified(self): response = self.assertGeneratorReturns(self.parse()) self.assertEqual(response.body, b"") + def test_parse_without_body(self): + self.reader.feed_data(b"HTTP/1.1 200 Connection Established\r\n\r\n") + response = self.assertGeneratorReturns(self.parse(include_body=False)) + self.assertEqual(response.body, b"") + def test_serialize(self): # Example from the protocol overview in RFC 6455 response = Response( diff --git a/tests/test_uri.py b/tests/test_uri.py index 35b51fa58..3ccf21158 100644 --- a/tests/test_uri.py +++ b/tests/test_uri.py @@ -1,11 +1,11 @@ +import os import unittest +from unittest.mock import patch from websockets.exceptions import InvalidProxy, InvalidURI from websockets.uri import * from websockets.uri import Proxy, get_proxy, parse_proxy -from .utils import patch_environ - VALID_URIS = [ ( @@ -255,6 +255,6 @@ def test_parse_proxy_user_info(self): def test_get_proxy(self): for environ, uri, proxy in PROXY_ENVS: - with patch_environ(environ): + with patch.dict(os.environ, environ): with self.subTest(environ=environ, uri=uri): self.assertEqual(get_proxy(parse_uri(uri)), proxy) diff --git a/tests/utils.py b/tests/utils.py index f68a447b1..7932aae60 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -20,14 +20,13 @@ # $ cat test_localhost.key test_localhost.crt > test_localhost.pem # $ rm test_localhost.key test_localhost.crt -CERTIFICATE = bytes(pathlib.Path(__file__).with_name("test_localhost.pem")) +CERTIFICATE = pathlib.Path(__file__).with_name("test_localhost.pem") CLIENT_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) -CLIENT_CONTEXT.load_verify_locations(CERTIFICATE) - +CLIENT_CONTEXT.load_verify_locations(bytes(CERTIFICATE)) SERVER_CONTEXT = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) -SERVER_CONTEXT.load_cert_chain(CERTIFICATE) +SERVER_CONTEXT.load_cert_chain(bytes(CERTIFICATE)) # Work around https://github.com/openssl/openssl/issues/7967 @@ -139,22 +138,6 @@ def assertNoLogs(self, logger=None, level=None): self.assertEqual(logs.output, [f"{level_name}:{logger}:dummy"]) -@contextlib.contextmanager -def patch_environ(environ): - backup = {} - for key, value in environ.items(): - backup[key] = os.environ.get(key) - os.environ[key] = value - try: - yield - finally: - for key, value in backup.items(): - if value is None: - del os.environ[key] - else: # pragma: no cover - os.environ[key] = value - - @contextlib.contextmanager def temp_unix_socket_path(): with tempfile.TemporaryDirectory() as temp_dir: diff --git a/tox.ini b/tox.ini index f5a2f5d3c..918aeaaec 100644 --- a/tox.ini +++ b/tox.ini @@ -15,8 +15,8 @@ commands = pass_env = WEBSOCKETS_* deps = - mitmproxy - python-socks[asyncio] + py311,py312,py313,coverage,maxi_cov: mitmproxy + py311,py312,py313,coverage,maxi_cov: python-socks[asyncio] [testenv:coverage] commands =