Skip to content

Commit 227d4a1

Browse files
committed
Add regex support in ServerProtocol(origins=...)
1 parent 031ec31 commit 227d4a1

File tree

3 files changed

+32
-4
lines changed

3 files changed

+32
-4
lines changed

src/websockets/asyncio/server.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,7 @@ def handler(websocket):
599599
See :meth:`~asyncio.loop.create_server` for details.
600600
port: TCP port the server listens on.
601601
See :meth:`~asyncio.loop.create_server` for details.
602-
origins: Acceptable values of the ``Origin`` header, for defending
602+
origins: Acceptable values of the ``Origin`` header, including regular expressions, for defending
603603
against Cross-Site WebSocket Hijacking attacks. Include :obj:`None`
604604
in the list if the lack of an origin is acceptable.
605605
extensions: List of supported extensions, in order in which they

src/websockets/server.py

+30-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import email.utils
66
import http
77
import warnings
8+
import re
89
from collections.abc import Generator, Sequence
910
from typing import Any, Callable, cast
1011

@@ -49,7 +50,7 @@ class ServerProtocol(Protocol):
4950
Sans-I/O implementation of a WebSocket server connection.
5051
5152
Args:
52-
origins: Acceptable values of the ``Origin`` header; include
53+
origins: Acceptable values of the ``Origin`` header, including regular expressions; include
5354
:obj:`None` in the list if the lack of an origin is acceptable.
5455
This is useful for defending against Cross-Site WebSocket
5556
Hijacking attacks.
@@ -309,10 +310,37 @@ def process_origin(self, headers: Headers) -> Origin | None:
309310
if origin is not None:
310311
origin = cast(Origin, origin)
311312
if self.origins is not None:
312-
if origin not in self.origins:
313+
valid = False
314+
for origin_maybe_regex in self.origins:
315+
if origin_maybe_regex is None:
316+
continue
317+
if self.probably_regex(origin_maybe_regex):
318+
valid = re.match(origin_maybe_regex, origin)
319+
else:
320+
valid = origin_maybe_regex == origin
321+
if valid:
322+
break
323+
if not valid:
313324
raise InvalidOrigin(origin)
314325
return origin
315326

327+
@staticmethod
328+
def probably_regex(maybe_regex: str) -> bool:
329+
"""
330+
Determine if the given string is a regex.
331+
332+
Args:
333+
maybe_regex: A string that may be a regex.
334+
335+
Returns:
336+
True if the string is a regex, False otherwise.
337+
338+
"""
339+
common_regex_chars = ['*', '\\', ']', '?', '$', '^', '[', ']', '(', ')']
340+
# Use common characters used in regular expressions as a proxy
341+
# for if this string is in fact a regex.
342+
return any((c in maybe_regex for c in common_regex_chars))
343+
316344
def process_extensions(
317345
self,
318346
headers: Headers,

src/websockets/sync/server.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ def handler(websocket):
399399
You may call :func:`socket.create_server` to create a suitable TCP
400400
socket.
401401
ssl: Configuration for enabling TLS on the connection.
402-
origins: Acceptable values of the ``Origin`` header, for defending
402+
origins: Acceptable values of the ``Origin`` header, including regular expressions, for defending
403403
against Cross-Site WebSocket Hijacking attacks. Include :obj:`None`
404404
in the list if the lack of an origin is acceptable.
405405
extensions: List of supported extensions, in order in which they

0 commit comments

Comments
 (0)