Skip to content

Commit 1a43fe2

Browse files
committed
Add regex support in ServerProtocol(origins=...).
1 parent e7a098e commit 1a43fe2

File tree

4 files changed

+65
-13
lines changed

4 files changed

+65
-13
lines changed

src/websockets/asyncio/server.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import hmac
55
import http
66
import logging
7+
import re
78
import socket
89
import sys
910
from collections.abc import Awaitable, Generator, Iterable, Sequence
@@ -599,9 +600,10 @@ def handler(websocket):
599600
See :meth:`~asyncio.loop.create_server` for details.
600601
port: TCP port the server listens on.
601602
See :meth:`~asyncio.loop.create_server` for details.
602-
origins: Acceptable values of the ``Origin`` header, for defending
603-
against Cross-Site WebSocket Hijacking attacks. Include :obj:`None`
604-
in the list if the lack of an origin is acceptable.
603+
origins: Acceptable values of the ``Origin`` header, including regular
604+
expressions, for defending against Cross-Site WebSocket Hijacking
605+
attacks. Include :obj:`None` in the list if the lack of an origin
606+
is acceptable.
605607
extensions: List of supported extensions, in order in which they
606608
should be negotiated and run.
607609
subprotocols: List of supported subprotocols, in order of decreasing
@@ -681,7 +683,7 @@ def __init__(
681683
port: int | None = None,
682684
*,
683685
# WebSocket
684-
origins: Sequence[Origin | None] | None = None,
686+
origins: Sequence[Origin | re.Pattern[str] | None] | None = None,
685687
extensions: Sequence[ServerExtensionFactory] | None = None,
686688
subprotocols: Sequence[Subprotocol] | None = None,
687689
select_subprotocol: (

src/websockets/server.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import binascii
55
import email.utils
66
import http
7+
import re
78
import warnings
89
from collections.abc import Generator, Sequence
910
from typing import Any, Callable, cast
@@ -49,9 +50,9 @@ class ServerProtocol(Protocol):
4950
Sans-I/O implementation of a WebSocket server connection.
5051
5152
Args:
52-
origins: Acceptable values of the ``Origin`` header; include
53-
:obj:`None` in the list if the lack of an origin is acceptable.
54-
This is useful for defending against Cross-Site WebSocket
53+
origins: Acceptable values of the ``Origin`` header, including regular
54+
expressions; include :obj:`None` in the list if the lack of an origin
55+
is acceptable. This is useful for defending against Cross-Site WebSocket
5556
Hijacking attacks.
5657
extensions: List of supported extensions, in order in which they
5758
should be tried.
@@ -73,7 +74,7 @@ class ServerProtocol(Protocol):
7374
def __init__(
7475
self,
7576
*,
76-
origins: Sequence[Origin | None] | None = None,
77+
origins: Sequence[Origin | re.Pattern[str] | None] | None = None,
7778
extensions: Sequence[ServerExtensionFactory] | None = None,
7879
subprotocols: Sequence[Subprotocol] | None = None,
7980
select_subprotocol: (
@@ -309,7 +310,17 @@ def process_origin(self, headers: Headers) -> Origin | None:
309310
if origin is not None:
310311
origin = cast(Origin, origin)
311312
if self.origins is not None:
312-
if origin not in self.origins:
313+
valid = False
314+
for acceptable_origin_or_regex in self.origins:
315+
if isinstance(acceptable_origin_or_regex, re.Pattern):
316+
# `str(origin)` is needed for compatibility
317+
# between `Pattern.match(string=...)` and `origin`.
318+
valid = acceptable_origin_or_regex.match(str(origin)) is not None
319+
else:
320+
valid = acceptable_origin_or_regex == origin
321+
if valid:
322+
break
323+
if not valid:
313324
raise InvalidOrigin(origin)
314325
return origin
315326

src/websockets/sync/server.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import http
55
import logging
66
import os
7+
import re
78
import selectors
89
import socket
910
import ssl as ssl_module
@@ -325,7 +326,7 @@ def serve(
325326
sock: socket.socket | None = None,
326327
ssl: ssl_module.SSLContext | None = None,
327328
# WebSocket
328-
origins: Sequence[Origin | None] | None = None,
329+
origins: Sequence[Origin | re.Pattern[str] | None] | None = None,
329330
extensions: Sequence[ServerExtensionFactory] | None = None,
330331
subprotocols: Sequence[Subprotocol] | None = None,
331332
select_subprotocol: (
@@ -399,9 +400,10 @@ def handler(websocket):
399400
You may call :func:`socket.create_server` to create a suitable TCP
400401
socket.
401402
ssl: Configuration for enabling TLS on the connection.
402-
origins: Acceptable values of the ``Origin`` header, for defending
403-
against Cross-Site WebSocket Hijacking attacks. Include :obj:`None`
404-
in the list if the lack of an origin is acceptable.
403+
origins: Acceptable values of the ``Origin`` header, including regular
404+
expressions, for defending against Cross-Site WebSocket Hijacking
405+
attacks. Include :obj:`None` in the list if the lack of an origin
406+
is acceptable.
405407
extensions: List of supported extensions, in order in which they
406408
should be negotiated and run.
407409
subprotocols: List of supported subprotocols, in order of decreasing

tests/test_server.py

+37
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import http
22
import logging
3+
import re
34
import sys
45
import unittest
56
from unittest.mock import patch
@@ -623,6 +624,42 @@ def test_unsupported_origin(self):
623624
"invalid Origin header: https://original.example.com",
624625
)
625626

627+
def test_supported_origin_by_regex(self):
628+
"""
629+
Handshake succeeds when checking origins and the origin is supported
630+
by a regular expression.
631+
"""
632+
server = ServerProtocol(
633+
origins=["https://example.com", re.compile(r"https://other.*")]
634+
)
635+
request = make_request()
636+
request.headers["Origin"] = "https://other.example.com"
637+
response = server.accept(request)
638+
server.send_response(response)
639+
640+
self.assertHandshakeSuccess(server)
641+
self.assertEqual(server.origin, "https://other.example.com")
642+
643+
def test_unsupported_origin_by_regex(self):
644+
"""
645+
Handshake succeeds when checking origins and the origin is unsupported
646+
by a regular expression.
647+
"""
648+
server = ServerProtocol(
649+
origins=["https://example.com", re.compile(r"https://other.*")]
650+
)
651+
request = make_request()
652+
request.headers["Origin"] = "https://original.example.com"
653+
response = server.accept(request)
654+
server.send_response(response)
655+
656+
self.assertEqual(response.status_code, 403)
657+
self.assertHandshakeError(
658+
server,
659+
InvalidOrigin,
660+
"invalid Origin header: https://original.example.com",
661+
)
662+
626663
def test_no_origin_accepted(self):
627664
"""Handshake succeeds when the lack of an origin is accepted."""
628665
server = ServerProtocol(origins=[None])

0 commit comments

Comments
 (0)