Skip to content

Commit 6b2f060

Browse files
committed
Follow redirects in the new asyncio implementation.
Fix #631.
1 parent 1f89db7 commit 6b2f060

File tree

7 files changed

+403
-130
lines changed

7 files changed

+403
-130
lines changed

docs/howto/upgrade.rst

+4-23
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,13 @@ It provides a very similar API. However, there are a few differences.
1010

1111
The recommended upgrade process is:
1212

13-
1. Make sure that your application doesn't use any `deprecated APIs`_. If it
13+
#. Make sure that your application doesn't use any `deprecated APIs`_. If it
1414
doesn't raise any warnings, you can skip this step.
15-
2. Check if your application depends on `missing features`_. If it does, you
16-
should stick to the original implementation until they're added.
17-
3. `Update import paths`_. For straightforward usage of websockets, this could
15+
#. `Update import paths`_. For straightforward usage of websockets, this could
1816
be the only step you need to take. Upgrading could be transparent.
19-
4. Check out `new features and improvements`_ and consider taking advantage of
17+
#. Check out `new features and improvements`_ and consider taking advantage of
2018
them to improve your application.
21-
5. Review `API changes`_ and adapt your application to preserve its current
19+
#. Review `API changes`_ and adapt your application to preserve its current
2220
functionality.
2321

2422
In the interest of brevity, only :func:`~asyncio.client.connect` and
@@ -62,23 +60,6 @@ the release notes of the version in which the feature was deprecated.
6260
* The ``host``, ``port``, and ``secure`` attributes of connections — deprecated
6361
in :ref:`8.0`.
6462

65-
.. _missing features:
66-
67-
Missing features
68-
----------------
69-
70-
.. admonition:: All features listed below will be provided in a future release.
71-
:class: tip
72-
73-
If your application relies on one of them, you should stick to the original
74-
implementation until the new implementation supports it in a future release.
75-
76-
Following redirects
77-
...................
78-
79-
The new implementation of :func:`~asyncio.client.connect` doesn't follow HTTP
80-
redirects yet.
81-
8263
.. _Update import paths:
8364

8465
Import paths

docs/project/changelog.rst

+3
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ New features
5050
:func:`~asyncio.client.connect` as an asynchronous iterator to the new
5151
:mod:`asyncio` implementation.
5252

53+
* :func:`~asyncio.client.connect` now follows redirects in the new
54+
:mod:`asyncio` implementation.
55+
5356
* Added HTTP Basic Auth to the new :mod:`asyncio` and :mod:`threading`
5457
implementations of servers.
5558

docs/reference/features.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ Client
154154
+------------------------------------+--------+--------+--------+--------+
155155
| Connect to non-ASCII IRIs |||||
156156
+------------------------------------+--------+--------+--------+--------+
157-
| Follow HTTP redirects | ||||
157+
| Follow HTTP redirects | ||||
158158
+------------------------------------+--------+--------+--------+--------+
159159
| Perform HTTP Basic Authentication |||||
160160
+------------------------------------+--------+--------+--------+--------+

docs/reference/variables.rst

+11
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
Environment variables
22
=====================
33

4+
.. currentmodule:: websockets
5+
46
Logging
57
-------
68

@@ -77,3 +79,12 @@ Reconnection attempts are spaced out with truncated exponential backoff.
7779
The delay between attempts is capped at ``BACKOFF_MAX_DELAY`` seconds.
7880

7981
The default value is ``90.0`` seconds.
82+
83+
Redirections
84+
------------
85+
86+
.. envvar:: WEBSOCKETS_MAX_REDIRECTS
87+
88+
Maximum number of redirects that :func:`~asyncio.client.connect` follows.
89+
90+
The default value is ``10``.

src/websockets/asyncio/client.py

+133-36
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,30 @@
11
from __future__ import annotations
22

33
import asyncio
4-
import functools
54
import logging
5+
import os
6+
import urllib.parse
67
from types import TracebackType
78
from typing import Any, AsyncIterator, Callable, Generator, Sequence
89

910
from ..client import ClientProtocol, backoff
1011
from ..datastructures import HeadersLike
11-
from ..exceptions import InvalidStatus
12+
from ..exceptions import InvalidStatus, SecurityError
1213
from ..extensions.base import ClientExtensionFactory
1314
from ..extensions.permessage_deflate import enable_client_permessage_deflate
1415
from ..headers import validate_subprotocols
1516
from ..http11 import USER_AGENT, Response
1617
from ..protocol import CONNECTING, Event
1718
from ..typing import LoggerLike, Origin, Subprotocol
18-
from ..uri import parse_uri
19+
from ..uri import WebSocketURI, parse_uri
1920
from .compatibility import TimeoutError, asyncio_timeout
2021
from .connection import Connection
2122

2223

2324
__all__ = ["connect", "unix_connect", "ClientConnection"]
2425

26+
MAX_REDIRECTS = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10"))
27+
2528

2629
class ClientConnection(Connection):
2730
"""
@@ -126,7 +129,7 @@ def connection_lost(self, exc: Exception | None) -> None:
126129

127130
def process_exception(exc: Exception) -> Exception | None:
128131
"""
129-
Determine whether an error is retryable or fatal.
132+
Determine whether a connection error is retryable or fatal.
130133
131134
When reconnecting automatically with ``async for ... in connect(...)``, if a
132135
connection attempt fails, :func:`process_exception` is called to determine
@@ -297,16 +300,7 @@ def __init__(
297300
# Other keyword arguments are passed to loop.create_connection
298301
**kwargs: Any,
299302
) -> None:
300-
wsuri = parse_uri(uri)
301-
302-
if wsuri.secure:
303-
kwargs.setdefault("ssl", True)
304-
kwargs.setdefault("server_hostname", wsuri.host)
305-
if kwargs.get("ssl") is None:
306-
raise TypeError("ssl=None is incompatible with a wss:// URI")
307-
else:
308-
if kwargs.get("ssl") is not None:
309-
raise TypeError("ssl argument is incompatible with a ws:// URI")
303+
self.uri = uri
310304

311305
if subprotocols is not None:
312306
validate_subprotocols(subprotocols)
@@ -316,10 +310,13 @@ def __init__(
316310
elif compression is not None:
317311
raise ValueError(f"unsupported compression: {compression}")
318312

313+
if logger is None:
314+
logger = logging.getLogger("websockets.client")
315+
319316
if create_connection is None:
320317
create_connection = ClientConnection
321318

322-
def factory() -> ClientConnection:
319+
def protocol_factory(wsuri: WebSocketURI) -> ClientConnection:
323320
# This is a protocol in the Sans-I/O implementation of websockets.
324321
protocol = ClientProtocol(
325322
wsuri,
@@ -340,28 +337,104 @@ def factory() -> ClientConnection:
340337
)
341338
return connection
342339

340+
self.protocol_factory = protocol_factory
341+
self.handshake_args = (
342+
additional_headers,
343+
user_agent_header,
344+
)
345+
self.process_exception = process_exception
346+
self.open_timeout = open_timeout
347+
self.logger = logger
348+
self.connection_kwargs = kwargs
349+
350+
async def create_connection(self) -> ClientConnection:
351+
"""Create TCP or Unix connection."""
343352
loop = asyncio.get_running_loop()
353+
354+
wsuri = parse_uri(self.uri)
355+
kwargs = self.connection_kwargs.copy()
356+
357+
def factory() -> ClientConnection:
358+
return self.protocol_factory(wsuri)
359+
360+
if wsuri.secure:
361+
kwargs.setdefault("ssl", True)
362+
kwargs.setdefault("server_hostname", wsuri.host)
363+
if kwargs.get("ssl") is None:
364+
raise TypeError("ssl=None is incompatible with a wss:// URI")
365+
else:
366+
if kwargs.get("ssl") is not None:
367+
raise TypeError("ssl argument is incompatible with a ws:// URI")
368+
344369
if kwargs.pop("unix", False):
345-
self.create_connection = functools.partial(
346-
loop.create_unix_connection, factory, **kwargs
347-
)
370+
_, connection = await loop.create_unix_connection(factory, **kwargs)
348371
else:
349372
if kwargs.get("sock") is None:
350373
kwargs.setdefault("host", wsuri.host)
351374
kwargs.setdefault("port", wsuri.port)
352-
self.create_connection = functools.partial(
353-
loop.create_connection, factory, **kwargs
375+
_, connection = await loop.create_connection(factory, **kwargs)
376+
return connection
377+
378+
def process_redirect(self, exc: Exception) -> Exception | str:
379+
"""
380+
Determine whether a connection error is a redirect that can be followed.
381+
382+
Return the new URI if it's a valid redirect. Else, return an exception.
383+
384+
"""
385+
if not (
386+
isinstance(exc, InvalidStatus)
387+
and exc.response.status_code
388+
in [
389+
300, # Multiple Choices
390+
301, # Moved Permanently
391+
302, # Found
392+
303, # See Other
393+
307, # Temporary Redirect
394+
308, # Permanent Redirect
395+
]
396+
and "Location" in exc.response.headers
397+
):
398+
return exc
399+
400+
old_wsuri = parse_uri(self.uri)
401+
new_uri = urllib.parse.urljoin(self.uri, exc.response.headers["Location"])
402+
new_wsuri = parse_uri(new_uri)
403+
404+
# If connect() received a socket, it is closed and cannot be reused.
405+
if self.connection_kwargs.get("sock") is not None:
406+
return ValueError(
407+
f"cannot follow redirect to {new_uri} with a preexisting socket"
354408
)
355409

356-
self.handshake_args = (
357-
additional_headers,
358-
user_agent_header,
359-
)
360-
self.process_exception = process_exception
361-
self.open_timeout = open_timeout
362-
if logger is None:
363-
logger = logging.getLogger("websockets.client")
364-
self.logger = logger
410+
# TLS downgrade is forbidden.
411+
if old_wsuri.secure and not new_wsuri.secure:
412+
return SecurityError(f"cannot follow redirect to non-secure URI {new_uri}")
413+
414+
# Apply restrictions to cross-origin redirects.
415+
if (
416+
old_wsuri.secure != new_wsuri.secure
417+
or old_wsuri.host != new_wsuri.host
418+
or old_wsuri.port != new_wsuri.port
419+
):
420+
# Cross-origin redirects on Unix sockets don't quite make sense.
421+
if self.connection_kwargs.get("unix", False):
422+
return ValueError(
423+
f"cannot follow cross-origin redirect to {new_uri} "
424+
f"with a Unix socket"
425+
)
426+
427+
# Cross-origin redirects when host and port are overridden are ill-defined.
428+
if (
429+
self.connection_kwargs.get("host") is not None
430+
or self.connection_kwargs.get("port") is not None
431+
):
432+
return ValueError(
433+
f"cannot follow cross-origin redirect to {new_uri} "
434+
f"with an explicit host or port"
435+
)
436+
437+
return new_uri
365438

366439
# ... = await connect(...)
367440

@@ -372,14 +445,38 @@ def __await__(self) -> Generator[Any, None, ClientConnection]:
372445
async def __await_impl__(self) -> ClientConnection:
373446
try:
374447
async with asyncio_timeout(self.open_timeout):
375-
_transport, self.connection = await self.create_connection()
376-
try:
377-
await self.connection.handshake(*self.handshake_args)
378-
except (Exception, asyncio.CancelledError):
379-
self.connection.transport.close()
380-
raise
448+
for _ in range(MAX_REDIRECTS):
449+
self.connection = await self.create_connection()
450+
try:
451+
await self.connection.handshake(*self.handshake_args)
452+
except asyncio.CancelledError:
453+
self.connection.transport.close()
454+
raise
455+
except Exception as exc:
456+
# Always close the connection even though keep-alive is
457+
# the default in HTTP/1.1 because create_connection ties
458+
# opening the network connection with initializing the
459+
# protocol. In the current design of connect(), there is
460+
# no easy way to reuse the network connection that works
461+
# in every case nor to reinitialize the protocol.
462+
self.connection.transport.close()
463+
464+
uri_or_exc = self.process_redirect(exc)
465+
# Response is a valid redirect; follow it.
466+
if isinstance(uri_or_exc, str):
467+
self.uri = uri_or_exc
468+
continue
469+
# Response isn't a valid redirect; raise the exception.
470+
if uri_or_exc is exc:
471+
raise
472+
else:
473+
raise uri_or_exc from exc
474+
475+
else:
476+
return self.connection
381477
else:
382-
return self.connection
478+
raise SecurityError(f"more than {MAX_REDIRECTS} redirects")
479+
383480
except TimeoutError:
384481
# Re-raise exception with an informative error message.
385482
raise TimeoutError("timed out during handshake") from None

src/websockets/legacy/client.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ class Connect:
418418
419419
"""
420420

421-
MAX_REDIRECTS_ALLOWED = 10
421+
MAX_REDIRECTS_ALLOWED = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10"))
422422

423423
def __init__(
424424
self,

0 commit comments

Comments
 (0)