|
4 | 4 | import logging
|
5 | 5 | import os
|
6 | 6 | import socket
|
| 7 | +import ssl as ssl_module |
7 | 8 | import traceback
|
8 | 9 | import urllib.parse
|
9 | 10 | from collections.abc import AsyncIterator, Generator, Sequence
|
10 | 11 | from types import TracebackType
|
11 |
| -from typing import Any, Callable, Literal |
| 12 | +from typing import Any, Callable, Literal, cast |
12 | 13 |
|
13 | 14 | from ..client import ClientProtocol, backoff
|
14 |
| -from ..datastructures import HeadersLike |
15 |
| -from ..exceptions import InvalidMessage, InvalidStatus, ProxyError, SecurityError |
| 15 | +from ..datastructures import Headers, HeadersLike |
| 16 | +from ..exceptions import ( |
| 17 | + InvalidMessage, |
| 18 | + InvalidProxyMessage, |
| 19 | + InvalidProxyStatus, |
| 20 | + InvalidStatus, |
| 21 | + ProxyError, |
| 22 | + SecurityError, |
| 23 | +) |
16 | 24 | from ..extensions.base import ClientExtensionFactory
|
17 | 25 | from ..extensions.permessage_deflate import enable_client_permessage_deflate
|
18 |
| -from ..headers import validate_subprotocols |
| 26 | +from ..headers import build_authorization_basic, build_host, validate_subprotocols |
19 | 27 | from ..http11 import USER_AGENT, Response
|
20 | 28 | from ..protocol import CONNECTING, Event
|
| 29 | +from ..streams import StreamReader |
21 | 30 | from ..typing import LoggerLike, Origin, Subprotocol
|
22 | 31 | from ..uri import Proxy, WebSocketURI, get_proxy, parse_proxy, parse_uri
|
23 | 32 | from .compatibility import TimeoutError, asyncio_timeout
|
@@ -266,6 +275,16 @@ class connect:
|
266 | 275 | :meth:`~asyncio.loop.create_connection` method) to create a suitable
|
267 | 276 | client socket and customize it.
|
268 | 277 |
|
| 278 | + When using a proxy: |
| 279 | +
|
| 280 | + * Prefix keyword arguments with ``proxy_`` for configuring TLS between the |
| 281 | + client and an HTTPS proxy: ``proxy_ssl``, ``proxy_server_hostname``, |
| 282 | + ``proxy_ssl_handshake_timeout``, and ``proxy_ssl_shutdown_timeout``. |
| 283 | + * Use the standard keyword arguments for configuring TLS between the proxy |
| 284 | + and the WebSocket server: ``ssl``, ``server_hostname``, |
| 285 | + ``ssl_handshake_timeout``, and ``ssl_shutdown_timeout``. |
| 286 | + * Other keyword arguments are used only for connecting to the proxy. |
| 287 | +
|
269 | 288 | Raises:
|
270 | 289 | InvalidURI: If ``uri`` isn't a valid WebSocket URI.
|
271 | 290 | InvalidProxy: If ``proxy`` isn't a valid proxy.
|
@@ -397,6 +416,47 @@ def factory() -> ClientConnection:
|
397 | 416 | sock=sock,
|
398 | 417 | **kwargs,
|
399 | 418 | )
|
| 419 | + elif proxy_parsed.scheme[:4] == "http": |
| 420 | + # Split keyword arguments between the proxy and the server. |
| 421 | + all_kwargs, proxy_kwargs, kwargs = kwargs, {}, {} |
| 422 | + for key, value in all_kwargs.items(): |
| 423 | + if key.startswith("ssl") or key == "server_hostname": |
| 424 | + kwargs[key] = value |
| 425 | + elif key.startswith("proxy_"): |
| 426 | + proxy_kwargs[key[6:]] = value |
| 427 | + else: |
| 428 | + proxy_kwargs[key] = value |
| 429 | + # Validate the proxy_ssl argument. |
| 430 | + if proxy_parsed.scheme == "https": |
| 431 | + proxy_kwargs.setdefault("ssl", True) |
| 432 | + if proxy_kwargs.get("ssl") is None: |
| 433 | + raise ValueError( |
| 434 | + "proxy_ssl=None is incompatible with an https:// proxy" |
| 435 | + ) |
| 436 | + else: |
| 437 | + if proxy_kwargs.get("ssl") is not None: |
| 438 | + raise ValueError( |
| 439 | + "proxy_ssl argument is incompatible with an http:// proxy" |
| 440 | + ) |
| 441 | + # Connect to the server through the proxy. |
| 442 | + transport = await connect_http_proxy( |
| 443 | + proxy_parsed, |
| 444 | + ws_uri, |
| 445 | + **proxy_kwargs, |
| 446 | + ) |
| 447 | + # Initialize WebSocket connection via the proxy. |
| 448 | + connection = factory() |
| 449 | + transport.set_protocol(connection) |
| 450 | + ssl = kwargs.pop("ssl", None) |
| 451 | + if ssl is True: |
| 452 | + ssl = ssl_module.create_default_context() |
| 453 | + if ssl is not None: |
| 454 | + new_transport = await loop.start_tls( |
| 455 | + transport, connection, ssl, **kwargs |
| 456 | + ) |
| 457 | + assert new_transport is not None # help mypy |
| 458 | + transport = new_transport |
| 459 | + connection.connection_made(transport) |
400 | 460 | else:
|
401 | 461 | raise AssertionError("unsupported proxy")
|
402 | 462 | else:
|
@@ -655,3 +715,89 @@ async def connect_socks_proxy(
|
655 | 715 | **kwargs: Any,
|
656 | 716 | ) -> socket.socket:
|
657 | 717 | raise ImportError("python-socks is required to use a SOCKS proxy")
|
| 718 | + |
| 719 | + |
| 720 | +def prepare_connect_request(proxy: Proxy, ws_uri: WebSocketURI) -> bytes: |
| 721 | + host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True) |
| 722 | + headers = Headers() |
| 723 | + headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure) |
| 724 | + if proxy.username is not None: |
| 725 | + assert proxy.password is not None # enforced by parse_proxy() |
| 726 | + headers["Proxy-Authorization"] = build_authorization_basic( |
| 727 | + proxy.username, proxy.password |
| 728 | + ) |
| 729 | + # We cannot use the Request class because it supports only GET requests. |
| 730 | + return f"CONNECT {host} HTTP/1.1\r\n".encode() + headers.serialize() |
| 731 | + |
| 732 | + |
| 733 | +class HTTPProxyConnection(asyncio.Protocol): |
| 734 | + def __init__(self, ws_uri: WebSocketURI, proxy: Proxy): |
| 735 | + self.ws_uri = ws_uri |
| 736 | + self.proxy = proxy |
| 737 | + |
| 738 | + self.reader = StreamReader() |
| 739 | + self.parser = Response.parse( |
| 740 | + self.reader.read_line, |
| 741 | + self.reader.read_exact, |
| 742 | + self.reader.read_to_eof, |
| 743 | + include_body=False, |
| 744 | + ) |
| 745 | + |
| 746 | + loop = asyncio.get_running_loop() |
| 747 | + self.response: asyncio.Future[Response] = loop.create_future() |
| 748 | + |
| 749 | + def run_parser(self) -> None: |
| 750 | + try: |
| 751 | + next(self.parser) |
| 752 | + except StopIteration as exc: |
| 753 | + response = exc.value |
| 754 | + if 200 <= response.status_code < 300: |
| 755 | + self.response.set_result(response) |
| 756 | + else: |
| 757 | + self.response.set_exception(InvalidProxyStatus(response)) |
| 758 | + except Exception as exc: |
| 759 | + proxy_exc = InvalidProxyMessage( |
| 760 | + "did not receive a valid HTTP response from proxy" |
| 761 | + ) |
| 762 | + proxy_exc.__cause__ = exc |
| 763 | + self.response.set_exception(proxy_exc) |
| 764 | + |
| 765 | + def connection_made(self, transport: asyncio.BaseTransport) -> None: |
| 766 | + transport = cast(asyncio.Transport, transport) |
| 767 | + self.transport = transport |
| 768 | + self.transport.write(prepare_connect_request(self.proxy, self.ws_uri)) |
| 769 | + |
| 770 | + def data_received(self, data: bytes) -> None: |
| 771 | + self.reader.feed_data(data) |
| 772 | + self.run_parser() |
| 773 | + |
| 774 | + def eof_received(self) -> None: |
| 775 | + self.reader.feed_eof() |
| 776 | + self.run_parser() |
| 777 | + |
| 778 | + def connection_lost(self, exc: Exception | None) -> None: |
| 779 | + self.reader.feed_eof() |
| 780 | + if exc is not None: |
| 781 | + self.response.set_exception(exc) |
| 782 | + |
| 783 | + |
| 784 | +async def connect_http_proxy( |
| 785 | + proxy: Proxy, |
| 786 | + ws_uri: WebSocketURI, |
| 787 | + **kwargs: Any, |
| 788 | +) -> asyncio.Transport: |
| 789 | + transport, protocol = await asyncio.get_running_loop().create_connection( |
| 790 | + lambda: HTTPProxyConnection(ws_uri, proxy), |
| 791 | + proxy.host, |
| 792 | + proxy.port, |
| 793 | + **kwargs, |
| 794 | + ) |
| 795 | + |
| 796 | + try: |
| 797 | + # This raises exceptions if the connection to the proxy fails. |
| 798 | + await protocol.response |
| 799 | + except Exception: |
| 800 | + transport.close() |
| 801 | + raise |
| 802 | + |
| 803 | + return transport |
0 commit comments