|
1 | 1 | import http
|
2 | 2 | import logging
|
| 3 | +import re |
3 | 4 | import sys
|
4 | 5 | import unittest
|
5 | 6 | from unittest.mock import patch
|
@@ -623,6 +624,36 @@ def test_unsupported_origin(self):
|
623 | 624 | "invalid Origin header: https://original.example.com",
|
624 | 625 | )
|
625 | 626 |
|
| 627 | + def test_supported_origin_by_regex(self): |
| 628 | + """Handshake succeeds when checking origins and the origin is supported by a regular expression.""" |
| 629 | + server = ServerProtocol( |
| 630 | + origins=["https://example.com", re.compile(r"https://other.*")] |
| 631 | + ) |
| 632 | + request = make_request() |
| 633 | + request.headers["Origin"] = "https://other.example.com" |
| 634 | + response = server.accept(request) |
| 635 | + server.send_response(response) |
| 636 | + |
| 637 | + self.assertHandshakeSuccess(server) |
| 638 | + self.assertEqual(server.origin, "https://other.example.com") |
| 639 | + |
| 640 | + def test_unsupported_origin_by_regex(self): |
| 641 | + """Handshake succeeds when checking origins and the origin is unsupported by a regular expression.""" |
| 642 | + server = ServerProtocol( |
| 643 | + origins=["https://example.com", re.compile(r"https://other.*")] |
| 644 | + ) |
| 645 | + request = make_request() |
| 646 | + request.headers["Origin"] = "https://original.example.com" |
| 647 | + response = server.accept(request) |
| 648 | + server.send_response(response) |
| 649 | + |
| 650 | + self.assertEqual(response.status_code, 403) |
| 651 | + self.assertHandshakeError( |
| 652 | + server, |
| 653 | + InvalidOrigin, |
| 654 | + "invalid Origin header: https://original.example.com", |
| 655 | + ) |
| 656 | + |
626 | 657 | def test_no_origin_accepted(self):
|
627 | 658 | """Handshake succeeds when the lack of an origin is accepted."""
|
628 | 659 | server = ServerProtocol(origins=[None])
|
|
0 commit comments