Skip to content

Commit 1a90f1e

Browse files
committed
Set User-Agent header in CONNECT requests.
1 parent 2b9a90a commit 1a90f1e

File tree

5 files changed

+99
-24
lines changed

5 files changed

+99
-24
lines changed

src/websockets/asyncio/client.py

+27-10
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ async def handshake(
9797
self.request = self.protocol.connect()
9898
if additional_headers is not None:
9999
self.request.headers.update(additional_headers)
100-
if user_agent_header:
100+
if user_agent_header is not None:
101101
self.request.headers.setdefault("User-Agent", user_agent_header)
102102
self.protocol.send_request(self.request)
103103

@@ -363,10 +363,8 @@ def protocol_factory(uri: WebSocketURI) -> ClientConnection:
363363

364364
self.proxy = proxy
365365
self.protocol_factory = protocol_factory
366-
self.handshake_args = (
367-
additional_headers,
368-
user_agent_header,
369-
)
366+
self.additional_headers = additional_headers
367+
self.user_agent_header = user_agent_header
370368
self.process_exception = process_exception
371369
self.open_timeout = open_timeout
372370
self.logger = logger
@@ -442,6 +440,7 @@ def factory() -> ClientConnection:
442440
transport = await connect_http_proxy(
443441
proxy_parsed,
444442
ws_uri,
443+
user_agent_header=self.user_agent_header,
445444
**proxy_kwargs,
446445
)
447446
# Initialize WebSocket connection via the proxy.
@@ -541,7 +540,10 @@ async def __await_impl__(self) -> ClientConnection:
541540
for _ in range(MAX_REDIRECTS):
542541
self.connection = await self.create_connection()
543542
try:
544-
await self.connection.handshake(*self.handshake_args)
543+
await self.connection.handshake(
544+
self.additional_headers,
545+
self.user_agent_header,
546+
)
545547
except asyncio.CancelledError:
546548
self.connection.transport.abort()
547549
raise
@@ -717,10 +719,16 @@ async def connect_socks_proxy(
717719
raise ImportError("python-socks is required to use a SOCKS proxy")
718720

719721

720-
def prepare_connect_request(proxy: Proxy, ws_uri: WebSocketURI) -> bytes:
722+
def prepare_connect_request(
723+
proxy: Proxy,
724+
ws_uri: WebSocketURI,
725+
user_agent_header: str | None = None,
726+
) -> bytes:
721727
host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True)
722728
headers = Headers()
723729
headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure)
730+
if user_agent_header is not None:
731+
headers["User-Agent"] = user_agent_header
724732
if proxy.username is not None:
725733
assert proxy.password is not None # enforced by parse_proxy()
726734
headers["Proxy-Authorization"] = build_authorization_basic(
@@ -731,9 +739,15 @@ def prepare_connect_request(proxy: Proxy, ws_uri: WebSocketURI) -> bytes:
731739

732740

733741
class HTTPProxyConnection(asyncio.Protocol):
734-
def __init__(self, ws_uri: WebSocketURI, proxy: Proxy):
742+
def __init__(
743+
self,
744+
ws_uri: WebSocketURI,
745+
proxy: Proxy,
746+
user_agent_header: str | None = None,
747+
):
735748
self.ws_uri = ws_uri
736749
self.proxy = proxy
750+
self.user_agent_header = user_agent_header
737751

738752
self.reader = StreamReader()
739753
self.parser = Response.parse(
@@ -765,7 +779,9 @@ def run_parser(self) -> None:
765779
def connection_made(self, transport: asyncio.BaseTransport) -> None:
766780
transport = cast(asyncio.Transport, transport)
767781
self.transport = transport
768-
self.transport.write(prepare_connect_request(self.proxy, self.ws_uri))
782+
self.transport.write(
783+
prepare_connect_request(self.proxy, self.ws_uri, self.user_agent_header)
784+
)
769785

770786
def data_received(self, data: bytes) -> None:
771787
self.reader.feed_data(data)
@@ -784,10 +800,11 @@ def connection_lost(self, exc: Exception | None) -> None:
784800
async def connect_http_proxy(
785801
proxy: Proxy,
786802
ws_uri: WebSocketURI,
803+
user_agent_header: str | None = None,
787804
**kwargs: Any,
788805
) -> asyncio.Transport:
789806
transport, protocol = await asyncio.get_running_loop().create_connection(
790-
lambda: HTTPProxyConnection(ws_uri, proxy),
807+
lambda: HTTPProxyConnection(ws_uri, proxy, user_agent_header),
791808
proxy.host,
792809
proxy.port,
793810
**kwargs,

src/websockets/sync/client.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ def connect(
312312
proxy_parsed,
313313
ws_uri,
314314
deadline,
315+
user_agent_header=user_agent_header,
315316
ssl=proxy_ssl,
316317
server_hostname=proxy_server_hostname,
317318
**kwargs,
@@ -472,10 +473,16 @@ def connect_socks_proxy(
472473
raise ImportError("python-socks is required to use a SOCKS proxy")
473474

474475

475-
def prepare_connect_request(proxy: Proxy, ws_uri: WebSocketURI) -> bytes:
476+
def prepare_connect_request(
477+
proxy: Proxy,
478+
ws_uri: WebSocketURI,
479+
user_agent_header: str | None = None,
480+
) -> bytes:
476481
host = build_host(ws_uri.host, ws_uri.port, ws_uri.secure, always_include_port=True)
477482
headers = Headers()
478483
headers["Host"] = build_host(ws_uri.host, ws_uri.port, ws_uri.secure)
484+
if user_agent_header is not None:
485+
headers["User-Agent"] = user_agent_header
479486
if proxy.username is not None:
480487
assert proxy.password is not None # enforced by parse_proxy()
481488
headers["Proxy-Authorization"] = build_authorization_basic(
@@ -524,6 +531,7 @@ def connect_http_proxy(
524531
ws_uri: WebSocketURI,
525532
deadline: Deadline,
526533
*,
534+
user_agent_header: str | None = None,
527535
ssl: ssl_module.SSLContext | None = None,
528536
server_hostname: str | None = None,
529537
**kwargs: Any,
@@ -546,7 +554,7 @@ def connect_http_proxy(
546554

547555
# Send CONNECT request to the proxy and read response.
548556

549-
sock.sendall(prepare_connect_request(proxy, ws_uri))
557+
sock.sendall(prepare_connect_request(proxy, ws_uri, user_agent_header))
550558
try:
551559
read_connect_response(sock, deadline)
552560
except Exception:

tests/asyncio/test_client.py

+18
Original file line numberDiff line numberDiff line change
@@ -712,6 +712,24 @@ async def test_authenticated_http_proxy_error(self):
712712
)
713713
self.assertNumFlows(0)
714714

715+
@patch.dict(os.environ, {"https_proxy": "http://localhost:58080"})
716+
async def test_http_proxy_override_user_agent(self):
717+
"""Client can override User-Agent header with user_agent_header."""
718+
async with serve(*args) as server:
719+
async with connect(get_uri(server), user_agent_header="Smith") as client:
720+
self.assertEqual(client.protocol.state.name, "OPEN")
721+
[http_connect] = self.get_http_connects()
722+
self.assertEqual(http_connect.request.headers[b"User-Agent"], "Smith")
723+
724+
@patch.dict(os.environ, {"https_proxy": "http://localhost:58080"})
725+
async def test_http_proxy_remove_user_agent(self):
726+
"""Client can remove User-Agent header with user_agent_header."""
727+
async with serve(*args) as server:
728+
async with connect(get_uri(server), user_agent_header=None) as client:
729+
self.assertEqual(client.protocol.state.name, "OPEN")
730+
[http_connect] = self.get_http_connects()
731+
self.assertNotIn(b"User-Agent", http_connect.request.headers)
732+
715733
@patch.dict(os.environ, {"https_proxy": "http://localhost:58080"})
716734
async def test_http_proxy_protocol_error(self):
717735
"""Client receives invalid data when connecting to the HTTP proxy."""

tests/proxy.py

+26-12
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,26 @@
2222
class RecordFlows:
2323
def __init__(self, on_running):
2424
self.running = on_running
25-
self.flows = []
25+
self.http_connects = []
26+
self.tcp_flows = []
27+
28+
def http_connect(self, flow):
29+
self.http_connects.append(flow)
2630

2731
def tcp_start(self, flow):
28-
self.flows.append(flow)
32+
self.tcp_flows.append(flow)
33+
34+
def get_http_connects(self):
35+
http_connects, self.http_connects[:] = self.http_connects[:], []
36+
return http_connects
2937

30-
def get_flows(self):
31-
flows, self.flows[:] = self.flows[:], []
32-
return flows
38+
def get_tcp_flows(self):
39+
tcp_flows, self.tcp_flows[:] = self.tcp_flows[:], []
40+
return tcp_flows
3341

34-
def reset_flows(self):
35-
self.flows = []
42+
def reset(self):
43+
self.http_connects = []
44+
self.tcp_flows = []
3645

3746

3847
class AlterRequest:
@@ -121,13 +130,18 @@ def setUpClass(cls):
121130
cls.proxy_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
122131
cls.proxy_context.load_verify_locations(bytes(certificate))
123132

124-
def assertNumFlows(self, num_flows):
125-
record_flows = self.proxy_master.addons.get("recordflows")
126-
self.assertEqual(len(record_flows.get_flows()), num_flows)
133+
def get_http_connects(self):
134+
return self.proxy_master.addons.get("recordflows").get_http_connects()
135+
136+
def get_tcp_flows(self):
137+
return self.proxy_master.addons.get("recordflows").get_tcp_flows()
138+
139+
def assertNumFlows(self, num_tcp_flows):
140+
self.assertEqual(len(self.get_tcp_flows()), num_tcp_flows)
127141

128142
def tearDown(self):
129-
record_flows = self.proxy_master.addons.get("recordflows")
130-
record_flows.reset_flows()
143+
record_tcp_flows = self.proxy_master.addons.get("recordflows")
144+
record_tcp_flows.reset()
131145
super().tearDown()
132146

133147
@classmethod

tests/sync/test_client.py

+18
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,24 @@ def test_authenticated_http_proxy_error(self):
456456
)
457457
self.assertNumFlows(0)
458458

459+
@patch.dict(os.environ, {"https_proxy": "http://localhost:58080"})
460+
def test_http_proxy_override_user_agent(self):
461+
"""Client can override User-Agent header with user_agent_header."""
462+
with run_server() as server:
463+
with connect(get_uri(server), user_agent_header="Smith") as client:
464+
self.assertEqual(client.protocol.state.name, "OPEN")
465+
[http_connect] = self.get_http_connects()
466+
self.assertEqual(http_connect.request.headers[b"User-Agent"], "Smith")
467+
468+
@patch.dict(os.environ, {"https_proxy": "http://localhost:58080"})
469+
def test_http_proxy_remove_user_agent(self):
470+
"""Client can remove User-Agent header with user_agent_header."""
471+
with run_server() as server:
472+
with connect(get_uri(server), user_agent_header=None) as client:
473+
self.assertEqual(client.protocol.state.name, "OPEN")
474+
[http_connect] = self.get_http_connects()
475+
self.assertNotIn(b"User-Agent", http_connect.request.headers)
476+
459477
@patch.dict(os.environ, {"https_proxy": "http://localhost:58080"})
460478
def test_http_proxy_protocol_error(self):
461479
"""Client receives invalid data when connecting to the HTTP proxy."""

0 commit comments

Comments
 (0)