Skip to content

Commit 7de24bd

Browse files
committed
Improve previous commit.
* Require fullmatch instead of match — this avoids a vulnerability. * Shorten code and tweak to match my preferred style. * Add changelog.
1 parent 7e617b2 commit 7de24bd

File tree

5 files changed

+52
-34
lines changed

5 files changed

+52
-34
lines changed

docs/project/changelog.rst

+6
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ notice.
3232

3333
*In development*
3434

35+
New features
36+
............
37+
38+
* Added support for regular expressions in the ``origins`` argument of
39+
:func:`~asyncio.server.serve`.
40+
3541
Bug fixes
3642
.........
3743

src/websockets/asyncio/server.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -600,10 +600,11 @@ def handler(websocket):
600600
See :meth:`~asyncio.loop.create_server` for details.
601601
port: TCP port the server listens on.
602602
See :meth:`~asyncio.loop.create_server` for details.
603-
origins: Acceptable values of the ``Origin`` header, including regular
604-
expressions, for defending against Cross-Site WebSocket Hijacking
605-
attacks. Include :obj:`None` in the list if the lack of an origin
606-
is acceptable.
603+
origins: Acceptable values of the ``Origin`` header, for defending
604+
against Cross-Site WebSocket Hijacking attacks. Values can be
605+
:class:`str` to test for an exact match or regular expressions
606+
compiled by :func:`re.compile` to test against a pattern. Include
607+
:obj:`None` in the list if the lack of an origin is acceptable.
607608
extensions: List of supported extensions, in order in which they
608609
should be negotiated and run.
609610
subprotocols: List of supported subprotocols, in order of decreasing

src/websockets/server.py

+12-13
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,11 @@ class ServerProtocol(Protocol):
5050
Sans-I/O implementation of a WebSocket server connection.
5151
5252
Args:
53-
origins: Acceptable values of the ``Origin`` header, including regular
54-
expressions; include :obj:`None` in the list if the lack of an origin
55-
is acceptable. This is useful for defending against Cross-Site WebSocket
53+
origins: Acceptable values of the ``Origin`` header. Values can be
54+
:class:`str` to test for an exact match or regular expressions
55+
compiled by :func:`re.compile` to test against a pattern. Include
56+
:obj:`None` in the list if the lack of an origin is acceptable.
57+
This is useful for defending against Cross-Site WebSocket
5658
Hijacking attacks.
5759
extensions: List of supported extensions, in order in which they
5860
should be tried.
@@ -310,17 +312,14 @@ def process_origin(self, headers: Headers) -> Origin | None:
310312
if origin is not None:
311313
origin = cast(Origin, origin)
312314
if self.origins is not None:
313-
valid = False
314-
for acceptable_origin_or_regex in self.origins:
315-
if isinstance(acceptable_origin_or_regex, re.Pattern):
316-
# `str(origin)` is needed for compatibility
317-
# between `Pattern.match(string=...)` and `origin`.
318-
valid = acceptable_origin_or_regex.match(str(origin)) is not None
319-
else:
320-
valid = acceptable_origin_or_regex == origin
321-
if valid:
315+
for origin_or_regex in self.origins:
316+
if origin_or_regex == origin or (
317+
isinstance(origin_or_regex, re.Pattern)
318+
and origin is not None
319+
and origin_or_regex.fullmatch(origin) is not None
320+
):
322321
break
323-
if not valid:
322+
else:
324323
raise InvalidOrigin(origin)
325324
return origin
326325

src/websockets/sync/server.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -400,10 +400,11 @@ def handler(websocket):
400400
You may call :func:`socket.create_server` to create a suitable TCP
401401
socket.
402402
ssl: Configuration for enabling TLS on the connection.
403-
origins: Acceptable values of the ``Origin`` header, including regular
404-
expressions, for defending against Cross-Site WebSocket Hijacking
405-
attacks. Include :obj:`None` in the list if the lack of an origin
406-
is acceptable.
403+
origins: Acceptable values of the ``Origin`` header, for defending
404+
against Cross-Site WebSocket Hijacking attacks. Values can be
405+
:class:`str` to test for an exact match or regular expressions
406+
compiled by :func:`re.compile` to test against a pattern. Include
407+
:obj:`None` in the list if the lack of an origin is acceptable.
407408
extensions: List of supported extensions, in order in which they
408409
should be negotiated and run.
409410
subprotocols: List of supported subprotocols, in order of decreasing

tests/test_server.py

+24-13
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,7 @@ def test_supported_origin(self):
608608
self.assertEqual(server.origin, "https://other.example.com")
609609

610610
def test_unsupported_origin(self):
611-
"""Handshake succeeds when checking origins and the origin is unsupported."""
611+
"""Handshake fails when checking origins and the origin is unsupported."""
612612
server = ServerProtocol(
613613
origins=["https://example.com", "https://other.example.com"]
614614
)
@@ -624,13 +624,10 @@ def test_unsupported_origin(self):
624624
"invalid Origin header: https://original.example.com",
625625
)
626626

627-
def test_supported_origin_by_regex(self):
628-
"""
629-
Handshake succeeds when checking origins and the origin is supported
630-
by a regular expression.
631-
"""
627+
def test_supported_origin_regex(self):
628+
"""Handshake succeeds when checking origins and the origin is supported."""
632629
server = ServerProtocol(
633-
origins=["https://example.com", re.compile(r"https://other.*")]
630+
origins=[re.compile(r"https://(?!original)[a-z]+\.example\.com")]
634631
)
635632
request = make_request()
636633
request.headers["Origin"] = "https://other.example.com"
@@ -640,13 +637,10 @@ def test_supported_origin_by_regex(self):
640637
self.assertHandshakeSuccess(server)
641638
self.assertEqual(server.origin, "https://other.example.com")
642639

643-
def test_unsupported_origin_by_regex(self):
644-
"""
645-
Handshake succeeds when checking origins and the origin is unsupported
646-
by a regular expression.
647-
"""
640+
def test_unsupported_origin_regex(self):
641+
"""Handshake fails when checking origins and the origin is unsupported."""
648642
server = ServerProtocol(
649-
origins=["https://example.com", re.compile(r"https://other.*")]
643+
origins=[re.compile(r"https://(?!original)[a-z]+\.example\.com")]
650644
)
651645
request = make_request()
652646
request.headers["Origin"] = "https://original.example.com"
@@ -660,6 +654,23 @@ def test_unsupported_origin_by_regex(self):
660654
"invalid Origin header: https://original.example.com",
661655
)
662656

657+
def test_partial_match_origin_regex(self):
658+
"""Handshake fails when checking origins and the origin a partial match."""
659+
server = ServerProtocol(
660+
origins=[re.compile(r"https://(?!original)[a-z]+\.example\.com")]
661+
)
662+
request = make_request()
663+
request.headers["Origin"] = "https://other.example.com.hacked"
664+
response = server.accept(request)
665+
server.send_response(response)
666+
667+
self.assertEqual(response.status_code, 403)
668+
self.assertHandshakeError(
669+
server,
670+
InvalidOrigin,
671+
"invalid Origin header: https://other.example.com.hacked",
672+
)
673+
663674
def test_no_origin_accepted(self):
664675
"""Handshake succeeds when the lack of an origin is accepted."""
665676
server = ServerProtocol(origins=[None])

0 commit comments

Comments
 (0)