Skip to content

Commit 5fffd54

Browse files
committed
Add regex support in ServerProtocol(origins=...).
1 parent 031ec31 commit 5fffd54

File tree

5 files changed

+61
-12
lines changed

5 files changed

+61
-12
lines changed

src/websockets/asyncio/server.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -599,9 +599,10 @@ def handler(websocket):
599599
See :meth:`~asyncio.loop.create_server` for details.
600600
port: TCP port the server listens on.
601601
See :meth:`~asyncio.loop.create_server` for details.
602-
origins: Acceptable values of the ``Origin`` header, for defending
603-
against Cross-Site WebSocket Hijacking attacks. Include :obj:`None`
604-
in the list if the lack of an origin is acceptable.
602+
origins: Acceptable values of the ``Origin`` header, including regular
603+
expressions, for defending against Cross-Site WebSocket Hijacking
604+
attacks. Include :obj:`None` in the list if the lack of an origin
605+
is acceptable.
605606
extensions: List of supported extensions, in order in which they
606607
should be negotiated and run.
607608
subprotocols: List of supported subprotocols, in order of decreasing

src/websockets/server.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import binascii
55
import email.utils
66
import http
7+
import re
78
import warnings
89
from collections.abc import Generator, Sequence
910
from typing import Any, Callable, cast
@@ -49,9 +50,9 @@ class ServerProtocol(Protocol):
4950
Sans-I/O implementation of a WebSocket server connection.
5051
5152
Args:
52-
origins: Acceptable values of the ``Origin`` header; include
53-
:obj:`None` in the list if the lack of an origin is acceptable.
54-
This is useful for defending against Cross-Site WebSocket
53+
origins: Acceptable values of the ``Origin`` header, including regular
54+
expressions; include :obj:`None` in the list if the lack of an origin
55+
is acceptable. This is useful for defending against Cross-Site WebSocket
5556
Hijacking attacks.
5657
extensions: List of supported extensions, in order in which they
5758
should be tried.
@@ -309,7 +310,15 @@ def process_origin(self, headers: Headers) -> Origin | None:
309310
if origin is not None:
310311
origin = cast(Origin, origin)
311312
if self.origins is not None:
312-
if origin not in self.origins:
313+
valid = False
314+
for acceptable_origin in self.origins:
315+
if isinstance(acceptable_origin, re.Pattern):
316+
valid = acceptable_origin.match(origin)
317+
else:
318+
valid = acceptable_origin == origin
319+
if valid:
320+
break
321+
if not valid:
313322
raise InvalidOrigin(origin)
314323
return origin
315324

src/websockets/sync/server.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -399,9 +399,10 @@ def handler(websocket):
399399
You may call :func:`socket.create_server` to create a suitable TCP
400400
socket.
401401
ssl: Configuration for enabling TLS on the connection.
402-
origins: Acceptable values of the ``Origin`` header, for defending
403-
against Cross-Site WebSocket Hijacking attacks. Include :obj:`None`
404-
in the list if the lack of an origin is acceptable.
402+
origins: Acceptable values of the ``Origin`` header, including regular
403+
expressions, for defending against Cross-Site WebSocket Hijacking
404+
attacks. Include :obj:`None` in the list if the lack of an origin
405+
is acceptable.
405406
extensions: List of supported extensions, in order in which they
406407
should be negotiated and run.
407408
subprotocols: List of supported subprotocols, in order of decreasing

src/websockets/typing.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import http
44
import logging
5+
import re
56
import typing
67
from typing import Any, NewType, Optional, Union
78

@@ -45,8 +46,8 @@
4546
Types accepted where an :class:`~http.HTTPStatus` is expected."""
4647

4748

48-
Origin = NewType("Origin", str)
49-
"""Value of a ``Origin`` header."""
49+
Origin = Union[str, re.Pattern]
50+
"""Value of a ``Origin`` header or a regular expression."""
5051

5152

5253
Subprotocol = NewType("Subprotocol", str)

tests/test_server.py

+37
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import http
22
import logging
3+
import re
34
import sys
45
import unittest
56
from unittest.mock import patch
@@ -623,6 +624,42 @@ def test_unsupported_origin(self):
623624
"invalid Origin header: https://original.example.com",
624625
)
625626

627+
def test_supported_origin_by_regex(self):
628+
"""
629+
Handshake succeeds when checking origins and the origin is supported
630+
by a regular expression.
631+
"""
632+
server = ServerProtocol(
633+
origins=["https://example.com", re.compile(r"https://other.*")]
634+
)
635+
request = make_request()
636+
request.headers["Origin"] = "https://other.example.com"
637+
response = server.accept(request)
638+
server.send_response(response)
639+
640+
self.assertHandshakeSuccess(server)
641+
self.assertEqual(server.origin, "https://other.example.com")
642+
643+
def test_unsupported_origin_by_regex(self):
644+
"""
645+
Handshake succeeds when checking origins and the origin is unsupported
646+
by a regular expression.
647+
"""
648+
server = ServerProtocol(
649+
origins=["https://example.com", re.compile(r"https://other.*")]
650+
)
651+
request = make_request()
652+
request.headers["Origin"] = "https://original.example.com"
653+
response = server.accept(request)
654+
server.send_response(response)
655+
656+
self.assertEqual(response.status_code, 403)
657+
self.assertHandshakeError(
658+
server,
659+
InvalidOrigin,
660+
"invalid Origin header: https://original.example.com",
661+
)
662+
626663
def test_no_origin_accepted(self):
627664
"""Handshake succeeds when the lack of an origin is accepted."""
628665
server = ServerProtocol(origins=[None])

0 commit comments

Comments
 (0)