Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add regex support in ServerProtocol(origins=...) #1575

Merged
merged 1 commit into from
Jan 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/websockets/asyncio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import hmac
import http
import logging
import re
import socket
import sys
from collections.abc import Awaitable, Generator, Iterable, Sequence
Expand Down Expand Up @@ -599,9 +600,10 @@ def handler(websocket):
See :meth:`~asyncio.loop.create_server` for details.
port: TCP port the server listens on.
See :meth:`~asyncio.loop.create_server` for details.
origins: Acceptable values of the ``Origin`` header, for defending
against Cross-Site WebSocket Hijacking attacks. Include :obj:`None`
in the list if the lack of an origin is acceptable.
origins: Acceptable values of the ``Origin`` header, including regular
expressions, for defending against Cross-Site WebSocket Hijacking
attacks. Include :obj:`None` in the list if the lack of an origin
is acceptable.
extensions: List of supported extensions, in order in which they
should be negotiated and run.
subprotocols: List of supported subprotocols, in order of decreasing
Expand Down Expand Up @@ -681,7 +683,7 @@ def __init__(
port: int | None = None,
*,
# WebSocket
origins: Sequence[Origin | None] | None = None,
origins: Sequence[Origin | re.Pattern[str] | None] | None = None,
extensions: Sequence[ServerExtensionFactory] | None = None,
subprotocols: Sequence[Subprotocol] | None = None,
select_subprotocol: (
Expand Down
21 changes: 16 additions & 5 deletions src/websockets/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import binascii
import email.utils
import http
import re
import warnings
from collections.abc import Generator, Sequence
from typing import Any, Callable, cast
Expand Down Expand Up @@ -49,9 +50,9 @@ class ServerProtocol(Protocol):
Sans-I/O implementation of a WebSocket server connection.

Args:
origins: Acceptable values of the ``Origin`` header; include
:obj:`None` in the list if the lack of an origin is acceptable.
This is useful for defending against Cross-Site WebSocket
origins: Acceptable values of the ``Origin`` header, including regular
expressions; include :obj:`None` in the list if the lack of an origin
is acceptable. This is useful for defending against Cross-Site WebSocket
Hijacking attacks.
extensions: List of supported extensions, in order in which they
should be tried.
Expand All @@ -73,7 +74,7 @@ class ServerProtocol(Protocol):
def __init__(
self,
*,
origins: Sequence[Origin | None] | None = None,
origins: Sequence[Origin | re.Pattern[str] | None] | None = None,
extensions: Sequence[ServerExtensionFactory] | None = None,
subprotocols: Sequence[Subprotocol] | None = None,
select_subprotocol: (
Expand Down Expand Up @@ -309,7 +310,17 @@ def process_origin(self, headers: Headers) -> Origin | None:
if origin is not None:
origin = cast(Origin, origin)
if self.origins is not None:
if origin not in self.origins:
valid = False
for acceptable_origin_or_regex in self.origins:
if isinstance(acceptable_origin_or_regex, re.Pattern):
# `str(origin)` is needed for compatibility
# between `Pattern.match(string=...)` and `origin`.
valid = acceptable_origin_or_regex.match(str(origin)) is not None
else:
valid = acceptable_origin_or_regex == origin
if valid:
break
if not valid:
raise InvalidOrigin(origin)
return origin

Expand Down
10 changes: 6 additions & 4 deletions src/websockets/sync/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import http
import logging
import os
import re
import selectors
import socket
import ssl as ssl_module
Expand Down Expand Up @@ -325,7 +326,7 @@ def serve(
sock: socket.socket | None = None,
ssl: ssl_module.SSLContext | None = None,
# WebSocket
origins: Sequence[Origin | None] | None = None,
origins: Sequence[Origin | re.Pattern[str] | None] | None = None,
extensions: Sequence[ServerExtensionFactory] | None = None,
subprotocols: Sequence[Subprotocol] | None = None,
select_subprotocol: (
Expand Down Expand Up @@ -399,9 +400,10 @@ def handler(websocket):
You may call :func:`socket.create_server` to create a suitable TCP
socket.
ssl: Configuration for enabling TLS on the connection.
origins: Acceptable values of the ``Origin`` header, for defending
against Cross-Site WebSocket Hijacking attacks. Include :obj:`None`
in the list if the lack of an origin is acceptable.
origins: Acceptable values of the ``Origin`` header, including regular
expressions, for defending against Cross-Site WebSocket Hijacking
attacks. Include :obj:`None` in the list if the lack of an origin
is acceptable.
extensions: List of supported extensions, in order in which they
should be negotiated and run.
subprotocols: List of supported subprotocols, in order of decreasing
Expand Down
37 changes: 37 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import http
import logging
import re
import sys
import unittest
from unittest.mock import patch
Expand Down Expand Up @@ -623,6 +624,42 @@ def test_unsupported_origin(self):
"invalid Origin header: https://original.example.com",
)

def test_supported_origin_by_regex(self):
"""
Handshake succeeds when checking origins and the origin is supported
by a regular expression.
"""
server = ServerProtocol(
origins=["https://example.com", re.compile(r"https://other.*")]
)
request = make_request()
request.headers["Origin"] = "https://other.example.com"
response = server.accept(request)
server.send_response(response)

self.assertHandshakeSuccess(server)
self.assertEqual(server.origin, "https://other.example.com")

def test_unsupported_origin_by_regex(self):
"""
Handshake succeeds when checking origins and the origin is unsupported
by a regular expression.
"""
server = ServerProtocol(
origins=["https://example.com", re.compile(r"https://other.*")]
)
request = make_request()
request.headers["Origin"] = "https://original.example.com"
response = server.accept(request)
server.send_response(response)

self.assertEqual(response.status_code, 403)
self.assertHandshakeError(
server,
InvalidOrigin,
"invalid Origin header: https://original.example.com",
)

def test_no_origin_accepted(self):
"""Handshake succeeds when the lack of an origin is accepted."""
server = ServerProtocol(origins=[None])
Expand Down
Loading