Skip to content

Commit 44ce58a

Browse files
committed
Add regex support in ServerProtocol(origins=...)
1 parent 031ec31 commit 44ce58a

File tree

5 files changed

+47
-6
lines changed

5 files changed

+47
-6
lines changed

src/websockets/asyncio/server.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,7 @@ def handler(websocket):
599599
See :meth:`~asyncio.loop.create_server` for details.
600600
port: TCP port the server listens on.
601601
See :meth:`~asyncio.loop.create_server` for details.
602-
origins: Acceptable values of the ``Origin`` header, for defending
602+
origins: Acceptable values of the ``Origin`` header, including regular expressions, for defending
603603
against Cross-Site WebSocket Hijacking attacks. Include :obj:`None`
604604
in the list if the lack of an origin is acceptable.
605605
extensions: List of supported extensions, in order in which they

src/websockets/server.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import email.utils
66
import http
77
import warnings
8+
import re
89
from collections.abc import Generator, Sequence
910
from typing import Any, Callable, cast
1011

@@ -49,7 +50,7 @@ class ServerProtocol(Protocol):
4950
Sans-I/O implementation of a WebSocket server connection.
5051
5152
Args:
52-
origins: Acceptable values of the ``Origin`` header; include
53+
origins: Acceptable values of the ``Origin`` header, including regular expressions; include
5354
:obj:`None` in the list if the lack of an origin is acceptable.
5455
This is useful for defending against Cross-Site WebSocket
5556
Hijacking attacks.
@@ -309,7 +310,15 @@ def process_origin(self, headers: Headers) -> Origin | None:
309310
if origin is not None:
310311
origin = cast(Origin, origin)
311312
if self.origins is not None:
312-
if origin not in self.origins:
313+
valid = False
314+
for acceptable_origin in self.origins:
315+
if isinstance(acceptable_origin, re.Pattern):
316+
valid = acceptable_origin.match(origin)
317+
else:
318+
valid = acceptable_origin == origin
319+
if valid:
320+
break
321+
if not valid:
313322
raise InvalidOrigin(origin)
314323
return origin
315324

src/websockets/sync/server.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ def handler(websocket):
399399
You may call :func:`socket.create_server` to create a suitable TCP
400400
socket.
401401
ssl: Configuration for enabling TLS on the connection.
402-
origins: Acceptable values of the ``Origin`` header, for defending
402+
origins: Acceptable values of the ``Origin`` header, including regular expressions, for defending
403403
against Cross-Site WebSocket Hijacking attacks. Include :obj:`None`
404404
in the list if the lack of an origin is acceptable.
405405
extensions: List of supported extensions, in order in which they

src/websockets/typing.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import http
44
import logging
5+
import re
56
import typing
67
from typing import Any, NewType, Optional, Union
78

@@ -45,8 +46,8 @@
4546
Types accepted where an :class:`~http.HTTPStatus` is expected."""
4647

4748

48-
Origin = NewType("Origin", str)
49-
"""Value of a ``Origin`` header."""
49+
Origin = Union[str, re.Pattern]
50+
"""Value of a ``Origin`` header or a regular expression."""
5051

5152

5253
Subprotocol = NewType("Subprotocol", str)

tests/test_server.py

+31
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import http
22
import logging
3+
import re
34
import sys
45
import unittest
56
from unittest.mock import patch
@@ -623,6 +624,36 @@ def test_unsupported_origin(self):
623624
"invalid Origin header: https://original.example.com",
624625
)
625626

627+
def test_supported_origin_by_regex(self):
628+
"""Handshake succeeds when checking origins and the origin is supported by a regular expression."""
629+
server = ServerProtocol(
630+
origins=["https://example.com", re.compile(r"https://other.*")]
631+
)
632+
request = make_request()
633+
request.headers["Origin"] = "https://other.example.com"
634+
response = server.accept(request)
635+
server.send_response(response)
636+
637+
self.assertHandshakeSuccess(server)
638+
self.assertEqual(server.origin, "https://other.example.com")
639+
640+
def test_unsupported_origin_by_regex(self):
641+
"""Handshake succeeds when checking origins and the origin is unsupported by a regular expression."""
642+
server = ServerProtocol(
643+
origins=["https://example.com", re.compile(r"https://other.*")]
644+
)
645+
request = make_request()
646+
request.headers["Origin"] = "https://original.example.com"
647+
response = server.accept(request)
648+
server.send_response(response)
649+
650+
self.assertEqual(response.status_code, 403)
651+
self.assertHandshakeError(
652+
server,
653+
InvalidOrigin,
654+
"invalid Origin header: https://original.example.com",
655+
)
656+
626657
def test_no_origin_accepted(self):
627658
"""Handshake succeeds when the lack of an origin is accepted."""
628659
server = ServerProtocol(origins=[None])

0 commit comments

Comments
 (0)