Skip to content

Commit cba2981

Browse files
authored
Add regex support in ServerProtocol(origins=...)
1 parent 031ec31 commit cba2981

File tree

1 file changed

+30
-2
lines changed

1 file changed

+30
-2
lines changed

src/websockets/server.py

+30-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import email.utils
66
import http
77
import warnings
8+
import re
89
from collections.abc import Generator, Sequence
910
from typing import Any, Callable, cast
1011

@@ -49,7 +50,7 @@ class ServerProtocol(Protocol):
4950
Sans-I/O implementation of a WebSocket server connection.
5051
5152
Args:
52-
origins: Acceptable values of the ``Origin`` header; include
53+
origins: Acceptable values of the ``Origin`` header, including regular expressions; include
5354
:obj:`None` in the list if the lack of an origin is acceptable.
5455
This is useful for defending against Cross-Site WebSocket
5556
Hijacking attacks.
@@ -309,10 +310,37 @@ def process_origin(self, headers: Headers) -> Origin | None:
309310
if origin is not None:
310311
origin = cast(Origin, origin)
311312
if self.origins is not None:
312-
if origin not in self.origins:
313+
valid = False
314+
for origin_maybe_regex in self.origins:
315+
if origin_maybe_regex is None:
316+
continue
317+
if self.probably_regex(origin_maybe_regex):
318+
valid = re.match(origin_maybe_regex, origin)
319+
else:
320+
valid = origin_maybe_regex == origin
321+
if valid:
322+
break
323+
if not valid:
313324
raise InvalidOrigin(origin)
314325
return origin
315326

327+
@staticmethod
328+
def probably_regex(maybe_regex: str) -> bool:
329+
"""
330+
Determine if the given string is a regex.
331+
332+
Args:
333+
maybe_regex: A string that may be a regex.
334+
335+
Returns:
336+
True if the string is a regex, False otherwise.
337+
338+
"""
339+
common_regex_chars = ['*', '\\', ']', '?', '$', '^', '[', ']', '(', ')']
340+
# Use common characters used in regular expressions as a proxy
341+
# for if this string is in fact a regex.
342+
return any((c in maybe_regex for c in common_regex_chars))
343+
316344
def process_extensions(
317345
self,
318346
headers: Headers,

0 commit comments

Comments
 (0)