@@ -97,7 +97,7 @@ async def handshake(
97
97
self .request = self .protocol .connect ()
98
98
if additional_headers is not None :
99
99
self .request .headers .update (additional_headers )
100
- if user_agent_header :
100
+ if user_agent_header is not None :
101
101
self .request .headers .setdefault ("User-Agent" , user_agent_header )
102
102
self .protocol .send_request (self .request )
103
103
@@ -363,10 +363,8 @@ def protocol_factory(uri: WebSocketURI) -> ClientConnection:
363
363
364
364
self .proxy = proxy
365
365
self .protocol_factory = protocol_factory
366
- self .handshake_args = (
367
- additional_headers ,
368
- user_agent_header ,
369
- )
366
+ self .additional_headers = additional_headers
367
+ self .user_agent_header = user_agent_header
370
368
self .process_exception = process_exception
371
369
self .open_timeout = open_timeout
372
370
self .logger = logger
@@ -442,6 +440,7 @@ def factory() -> ClientConnection:
442
440
transport = await connect_http_proxy (
443
441
proxy_parsed ,
444
442
ws_uri ,
443
+ user_agent_header = self .user_agent_header ,
445
444
** proxy_kwargs ,
446
445
)
447
446
# Initialize WebSocket connection via the proxy.
@@ -541,7 +540,10 @@ async def __await_impl__(self) -> ClientConnection:
541
540
for _ in range (MAX_REDIRECTS ):
542
541
self .connection = await self .create_connection ()
543
542
try :
544
- await self .connection .handshake (* self .handshake_args )
543
+ await self .connection .handshake (
544
+ self .additional_headers ,
545
+ self .user_agent_header ,
546
+ )
545
547
except asyncio .CancelledError :
546
548
self .connection .transport .abort ()
547
549
raise
@@ -717,10 +719,16 @@ async def connect_socks_proxy(
717
719
raise ImportError ("python-socks is required to use a SOCKS proxy" )
718
720
719
721
720
- def prepare_connect_request (proxy : Proxy , ws_uri : WebSocketURI ) -> bytes :
722
+ def prepare_connect_request (
723
+ proxy : Proxy ,
724
+ ws_uri : WebSocketURI ,
725
+ user_agent_header : str | None = None ,
726
+ ) -> bytes :
721
727
host = build_host (ws_uri .host , ws_uri .port , ws_uri .secure , always_include_port = True )
722
728
headers = Headers ()
723
729
headers ["Host" ] = build_host (ws_uri .host , ws_uri .port , ws_uri .secure )
730
+ if user_agent_header is not None :
731
+ headers ["User-Agent" ] = user_agent_header
724
732
if proxy .username is not None :
725
733
assert proxy .password is not None # enforced by parse_proxy()
726
734
headers ["Proxy-Authorization" ] = build_authorization_basic (
@@ -731,9 +739,15 @@ def prepare_connect_request(proxy: Proxy, ws_uri: WebSocketURI) -> bytes:
731
739
732
740
733
741
class HTTPProxyConnection (asyncio .Protocol ):
734
- def __init__ (self , ws_uri : WebSocketURI , proxy : Proxy ):
742
+ def __init__ (
743
+ self ,
744
+ ws_uri : WebSocketURI ,
745
+ proxy : Proxy ,
746
+ user_agent_header : str | None = None ,
747
+ ):
735
748
self .ws_uri = ws_uri
736
749
self .proxy = proxy
750
+ self .user_agent_header = user_agent_header
737
751
738
752
self .reader = StreamReader ()
739
753
self .parser = Response .parse (
@@ -765,7 +779,9 @@ def run_parser(self) -> None:
765
779
def connection_made (self , transport : asyncio .BaseTransport ) -> None :
766
780
transport = cast (asyncio .Transport , transport )
767
781
self .transport = transport
768
- self .transport .write (prepare_connect_request (self .proxy , self .ws_uri ))
782
+ self .transport .write (
783
+ prepare_connect_request (self .proxy , self .ws_uri , self .user_agent_header )
784
+ )
769
785
770
786
def data_received (self , data : bytes ) -> None :
771
787
self .reader .feed_data (data )
@@ -784,10 +800,11 @@ def connection_lost(self, exc: Exception | None) -> None:
784
800
async def connect_http_proxy (
785
801
proxy : Proxy ,
786
802
ws_uri : WebSocketURI ,
803
+ user_agent_header : str | None = None ,
787
804
** kwargs : Any ,
788
805
) -> asyncio .Transport :
789
806
transport , protocol = await asyncio .get_running_loop ().create_connection (
790
- lambda : HTTPProxyConnection (ws_uri , proxy ),
807
+ lambda : HTTPProxyConnection (ws_uri , proxy , user_agent_header ),
791
808
proxy .host ,
792
809
proxy .port ,
793
810
** kwargs ,
0 commit comments