1
1
from __future__ import annotations
2
2
3
3
import asyncio
4
- import functools
5
4
import logging
5
+ import os
6
+ import urllib .parse
6
7
from types import TracebackType
7
8
from typing import Any , AsyncIterator , Callable , Generator , Sequence
8
9
9
10
from ..client import ClientProtocol , backoff
10
11
from ..datastructures import HeadersLike
11
- from ..exceptions import InvalidStatus
12
+ from ..exceptions import InvalidStatus , SecurityError
12
13
from ..extensions .base import ClientExtensionFactory
13
14
from ..extensions .permessage_deflate import enable_client_permessage_deflate
14
15
from ..headers import validate_subprotocols
15
16
from ..http11 import USER_AGENT , Response
16
17
from ..protocol import CONNECTING , Event
17
18
from ..typing import LoggerLike , Origin , Subprotocol
18
- from ..uri import parse_uri
19
+ from ..uri import WebSocketURI , parse_uri
19
20
from .compatibility import TimeoutError , asyncio_timeout
20
21
from .connection import Connection
21
22
22
23
23
24
__all__ = ["connect" , "unix_connect" , "ClientConnection" ]
24
25
26
+ MAX_REDIRECTS = int (os .environ .get ("WEBSOCKETS_MAX_REDIRECTS" , "10" ))
27
+
25
28
26
29
class ClientConnection (Connection ):
27
30
"""
@@ -126,7 +129,7 @@ def connection_lost(self, exc: Exception | None) -> None:
126
129
127
130
def process_exception (exc : Exception ) -> Exception | None :
128
131
"""
129
- Determine whether an error is retryable or fatal.
132
+ Determine whether a connection error is retryable or fatal.
130
133
131
134
When reconnecting automatically with ``async for ... in connect(...)``, if a
132
135
connection attempt fails, :func:`process_exception` is called to determine
@@ -297,16 +300,7 @@ def __init__(
297
300
# Other keyword arguments are passed to loop.create_connection
298
301
** kwargs : Any ,
299
302
) -> None :
300
- wsuri = parse_uri (uri )
301
-
302
- if wsuri .secure :
303
- kwargs .setdefault ("ssl" , True )
304
- kwargs .setdefault ("server_hostname" , wsuri .host )
305
- if kwargs .get ("ssl" ) is None :
306
- raise TypeError ("ssl=None is incompatible with a wss:// URI" )
307
- else :
308
- if kwargs .get ("ssl" ) is not None :
309
- raise TypeError ("ssl argument is incompatible with a ws:// URI" )
303
+ self .uri = uri
310
304
311
305
if subprotocols is not None :
312
306
validate_subprotocols (subprotocols )
@@ -316,10 +310,13 @@ def __init__(
316
310
elif compression is not None :
317
311
raise ValueError (f"unsupported compression: { compression } " )
318
312
313
+ if logger is None :
314
+ logger = logging .getLogger ("websockets.client" )
315
+
319
316
if create_connection is None :
320
317
create_connection = ClientConnection
321
318
322
- def factory ( ) -> ClientConnection :
319
+ def protocol_factory ( wsuri : WebSocketURI ) -> ClientConnection :
323
320
# This is a protocol in the Sans-I/O implementation of websockets.
324
321
protocol = ClientProtocol (
325
322
wsuri ,
@@ -340,28 +337,104 @@ def factory() -> ClientConnection:
340
337
)
341
338
return connection
342
339
340
+ self .protocol_factory = protocol_factory
341
+ self .handshake_args = (
342
+ additional_headers ,
343
+ user_agent_header ,
344
+ )
345
+ self .process_exception = process_exception
346
+ self .open_timeout = open_timeout
347
+ self .logger = logger
348
+ self .connection_kwargs = kwargs
349
+
350
+ async def create_connection (self ) -> ClientConnection :
351
+ """Create TCP or Unix connection."""
343
352
loop = asyncio .get_running_loop ()
353
+
354
+ wsuri = parse_uri (self .uri )
355
+ kwargs = self .connection_kwargs .copy ()
356
+
357
+ def factory () -> ClientConnection :
358
+ return self .protocol_factory (wsuri )
359
+
360
+ if wsuri .secure :
361
+ kwargs .setdefault ("ssl" , True )
362
+ kwargs .setdefault ("server_hostname" , wsuri .host )
363
+ if kwargs .get ("ssl" ) is None :
364
+ raise TypeError ("ssl=None is incompatible with a wss:// URI" )
365
+ else :
366
+ if kwargs .get ("ssl" ) is not None :
367
+ raise TypeError ("ssl argument is incompatible with a ws:// URI" )
368
+
344
369
if kwargs .pop ("unix" , False ):
345
- self .create_connection = functools .partial (
346
- loop .create_unix_connection , factory , ** kwargs
347
- )
370
+ _ , connection = await loop .create_unix_connection (factory , ** kwargs )
348
371
else :
349
372
if kwargs .get ("sock" ) is None :
350
373
kwargs .setdefault ("host" , wsuri .host )
351
374
kwargs .setdefault ("port" , wsuri .port )
352
- self .create_connection = functools .partial (
353
- loop .create_connection , factory , ** kwargs
375
+ _ , connection = await loop .create_connection (factory , ** kwargs )
376
+ return connection
377
+
378
+ def process_redirect (self , exc : Exception ) -> Exception | str :
379
+ """
380
+ Determine whether a connection error is a redirect that can be followed.
381
+
382
+ Return the new URI if it's a valid redirect. Else, return an exception.
383
+
384
+ """
385
+ if not (
386
+ isinstance (exc , InvalidStatus )
387
+ and exc .response .status_code
388
+ in [
389
+ 300 , # Multiple Choices
390
+ 301 , # Moved Permanently
391
+ 302 , # Found
392
+ 303 , # See Other
393
+ 307 , # Temporary Redirect
394
+ 308 , # Permanent Redirect
395
+ ]
396
+ and "Location" in exc .response .headers
397
+ ):
398
+ return exc
399
+
400
+ old_wsuri = parse_uri (self .uri )
401
+ new_uri = urllib .parse .urljoin (self .uri , exc .response .headers ["Location" ])
402
+ new_wsuri = parse_uri (new_uri )
403
+
404
+ # If connect() received a socket, it is closed and cannot be reused.
405
+ if self .connection_kwargs .get ("sock" ) is not None :
406
+ return ValueError (
407
+ f"cannot follow redirect to { new_uri } with a preexisting socket"
354
408
)
355
409
356
- self .handshake_args = (
357
- additional_headers ,
358
- user_agent_header ,
359
- )
360
- self .process_exception = process_exception
361
- self .open_timeout = open_timeout
362
- if logger is None :
363
- logger = logging .getLogger ("websockets.client" )
364
- self .logger = logger
410
+ # TLS downgrade is forbidden.
411
+ if old_wsuri .secure and not new_wsuri .secure :
412
+ return SecurityError (f"cannot follow redirect to non-secure URI { new_uri } " )
413
+
414
+ # Apply restrictions to cross-origin redirects.
415
+ if (
416
+ old_wsuri .secure != new_wsuri .secure
417
+ or old_wsuri .host != new_wsuri .host
418
+ or old_wsuri .port != new_wsuri .port
419
+ ):
420
+ # Cross-origin redirects on Unix sockets don't quite make sense.
421
+ if self .connection_kwargs .get ("unix" , False ):
422
+ return ValueError (
423
+ f"cannot follow cross-origin redirect to { new_uri } "
424
+ f"with a Unix socket"
425
+ )
426
+
427
+ # Cross-origin redirects when host and port are overridden are ill-defined.
428
+ if (
429
+ self .connection_kwargs .get ("host" ) is not None
430
+ or self .connection_kwargs .get ("port" ) is not None
431
+ ):
432
+ return ValueError (
433
+ f"cannot follow cross-origin redirect to { new_uri } "
434
+ f"with an explicit host or port"
435
+ )
436
+
437
+ return new_uri
365
438
366
439
# ... = await connect(...)
367
440
@@ -372,14 +445,38 @@ def __await__(self) -> Generator[Any, None, ClientConnection]:
372
445
async def __await_impl__ (self ) -> ClientConnection :
373
446
try :
374
447
async with asyncio_timeout (self .open_timeout ):
375
- _transport , self .connection = await self .create_connection ()
376
- try :
377
- await self .connection .handshake (* self .handshake_args )
378
- except (Exception , asyncio .CancelledError ):
379
- self .connection .transport .close ()
380
- raise
448
+ for _ in range (MAX_REDIRECTS ):
449
+ self .connection = await self .create_connection ()
450
+ try :
451
+ await self .connection .handshake (* self .handshake_args )
452
+ except asyncio .CancelledError :
453
+ self .connection .transport .close ()
454
+ raise
455
+ except Exception as exc :
456
+ # Always close the connection even though keep-alive is
457
+ # the default in HTTP/1.1 because create_connection ties
458
+ # opening the network connection with initializing the
459
+ # protocol. In the current design of connect(), there is
460
+ # no easy way to reuse the network connection that works
461
+ # in every case nor to reinitialize the protocol.
462
+ self .connection .transport .close ()
463
+
464
+ uri_or_exc = self .process_redirect (exc )
465
+ # Response is a valid redirect; follow it.
466
+ if isinstance (uri_or_exc , str ):
467
+ self .uri = uri_or_exc
468
+ continue
469
+ # Response isn't a valid redirect; raise the exception.
470
+ if uri_or_exc is exc :
471
+ raise
472
+ else :
473
+ raise uri_or_exc from exc
474
+
475
+ else :
476
+ return self .connection
381
477
else :
382
- return self .connection
478
+ raise SecurityError (f"more than { MAX_REDIRECTS } redirects" )
479
+
383
480
except TimeoutError :
384
481
# Re-raise exception with an informative error message.
385
482
raise TimeoutError ("timed out during handshake" ) from None
0 commit comments