From 1a43fe22694686046e2f397c29a6f025e02d8a16 Mon Sep 17 00:00:00 2001 From: dan005 Date: Mon, 13 Jan 2025 23:22:20 +0500 Subject: [PATCH] Add regex support in `ServerProtocol(origins=...)`. --- src/websockets/asyncio/server.py | 10 +++++---- src/websockets/server.py | 21 +++++++++++++----- src/websockets/sync/server.py | 10 +++++---- tests/test_server.py | 37 ++++++++++++++++++++++++++++++++ 4 files changed, 65 insertions(+), 13 deletions(-) diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index fdb928004..4aad7122c 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -4,6 +4,7 @@ import hmac import http import logging +import re import socket import sys from collections.abc import Awaitable, Generator, Iterable, Sequence @@ -599,9 +600,10 @@ def handler(websocket): See :meth:`~asyncio.loop.create_server` for details. port: TCP port the server listens on. See :meth:`~asyncio.loop.create_server` for details. - origins: Acceptable values of the ``Origin`` header, for defending - against Cross-Site WebSocket Hijacking attacks. Include :obj:`None` - in the list if the lack of an origin is acceptable. + origins: Acceptable values of the ``Origin`` header, including regular + expressions, for defending against Cross-Site WebSocket Hijacking + attacks. Include :obj:`None` in the list if the lack of an origin + is acceptable. extensions: List of supported extensions, in order in which they should be negotiated and run. subprotocols: List of supported subprotocols, in order of decreasing @@ -681,7 +683,7 @@ def __init__( port: int | None = None, *, # WebSocket - origins: Sequence[Origin | None] | None = None, + origins: Sequence[Origin | re.Pattern[str] | None] | None = None, extensions: Sequence[ServerExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, select_subprotocol: ( diff --git a/src/websockets/server.py b/src/websockets/server.py index fe3c65a7d..67082ed72 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -4,6 +4,7 @@ import binascii import email.utils import http +import re import warnings from collections.abc import Generator, Sequence from typing import Any, Callable, cast @@ -49,9 +50,9 @@ class ServerProtocol(Protocol): Sans-I/O implementation of a WebSocket server connection. Args: - origins: Acceptable values of the ``Origin`` header; include - :obj:`None` in the list if the lack of an origin is acceptable. - This is useful for defending against Cross-Site WebSocket + origins: Acceptable values of the ``Origin`` header, including regular + expressions; include :obj:`None` in the list if the lack of an origin + is acceptable. This is useful for defending against Cross-Site WebSocket Hijacking attacks. extensions: List of supported extensions, in order in which they should be tried. @@ -73,7 +74,7 @@ class ServerProtocol(Protocol): def __init__( self, *, - origins: Sequence[Origin | None] | None = None, + origins: Sequence[Origin | re.Pattern[str] | None] | None = None, extensions: Sequence[ServerExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, select_subprotocol: ( @@ -309,7 +310,17 @@ def process_origin(self, headers: Headers) -> Origin | None: if origin is not None: origin = cast(Origin, origin) if self.origins is not None: - if origin not in self.origins: + valid = False + for acceptable_origin_or_regex in self.origins: + if isinstance(acceptable_origin_or_regex, re.Pattern): + # `str(origin)` is needed for compatibility + # between `Pattern.match(string=...)` and `origin`. + valid = acceptable_origin_or_regex.match(str(origin)) is not None + else: + valid = acceptable_origin_or_regex == origin + if valid: + break + if not valid: raise InvalidOrigin(origin) return origin diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 9506d6830..c14e558ac 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -4,6 +4,7 @@ import http import logging import os +import re import selectors import socket import ssl as ssl_module @@ -325,7 +326,7 @@ def serve( sock: socket.socket | None = None, ssl: ssl_module.SSLContext | None = None, # WebSocket - origins: Sequence[Origin | None] | None = None, + origins: Sequence[Origin | re.Pattern[str] | None] | None = None, extensions: Sequence[ServerExtensionFactory] | None = None, subprotocols: Sequence[Subprotocol] | None = None, select_subprotocol: ( @@ -399,9 +400,10 @@ def handler(websocket): You may call :func:`socket.create_server` to create a suitable TCP socket. ssl: Configuration for enabling TLS on the connection. - origins: Acceptable values of the ``Origin`` header, for defending - against Cross-Site WebSocket Hijacking attacks. Include :obj:`None` - in the list if the lack of an origin is acceptable. + origins: Acceptable values of the ``Origin`` header, including regular + expressions, for defending against Cross-Site WebSocket Hijacking + attacks. Include :obj:`None` in the list if the lack of an origin + is acceptable. extensions: List of supported extensions, in order in which they should be negotiated and run. subprotocols: List of supported subprotocols, in order of decreasing diff --git a/tests/test_server.py b/tests/test_server.py index 69f555689..dd5e0d09a 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,5 +1,6 @@ import http import logging +import re import sys import unittest from unittest.mock import patch @@ -623,6 +624,42 @@ def test_unsupported_origin(self): "invalid Origin header: https://original.example.com", ) + def test_supported_origin_by_regex(self): + """ + Handshake succeeds when checking origins and the origin is supported + by a regular expression. + """ + server = ServerProtocol( + origins=["https://example.com", re.compile(r"https://other.*")] + ) + request = make_request() + request.headers["Origin"] = "https://other.example.com" + response = server.accept(request) + server.send_response(response) + + self.assertHandshakeSuccess(server) + self.assertEqual(server.origin, "https://other.example.com") + + def test_unsupported_origin_by_regex(self): + """ + Handshake succeeds when checking origins and the origin is unsupported + by a regular expression. + """ + server = ServerProtocol( + origins=["https://example.com", re.compile(r"https://other.*")] + ) + request = make_request() + request.headers["Origin"] = "https://original.example.com" + response = server.accept(request) + server.send_response(response) + + self.assertEqual(response.status_code, 403) + self.assertHandshakeError( + server, + InvalidOrigin, + "invalid Origin header: https://original.example.com", + ) + def test_no_origin_accepted(self): """Handshake succeeds when the lack of an origin is accepted.""" server = ServerProtocol(origins=[None])