62
62
TimeoutSettings ,
63
63
URLMatch ,
64
64
URLMatcher ,
65
+ WebSocketRouteHandlerCallback ,
65
66
async_readfile ,
66
67
async_writefile ,
67
68
locals_to_params ,
68
69
parse_error ,
69
70
prepare_record_har_options ,
70
71
to_impl ,
71
72
)
72
- from playwright ._impl ._network import Request , Response , Route , serialize_headers
73
+ from playwright ._impl ._network import (
74
+ Request ,
75
+ Response ,
76
+ Route ,
77
+ WebSocketRoute ,
78
+ WebSocketRouteHandler ,
79
+ serialize_headers ,
80
+ )
73
81
from playwright ._impl ._page import BindingCall , Page , Worker
74
82
from playwright ._impl ._str_utils import escape_regex_flags
75
83
from playwright ._impl ._tracing import Tracing
@@ -106,6 +114,7 @@ def __init__(
106
114
self ._browser ._contexts .append (self )
107
115
self ._pages : List [Page ] = []
108
116
self ._routes : List [RouteHandler ] = []
117
+ self ._web_socket_routes : List [WebSocketRouteHandler ] = []
109
118
self ._bindings : Dict [str , Any ] = {}
110
119
self ._timeout_settings = TimeoutSettings (None )
111
120
self ._owner_page : Optional [Page ] = None
@@ -132,7 +141,14 @@ def __init__(
132
141
)
133
142
),
134
143
)
135
-
144
+ self ._channel .on (
145
+ "webSocketRoute" ,
146
+ lambda params : self ._loop .create_task (
147
+ self ._on_web_socket_route (
148
+ from_channel (params ["webSocketRoute" ]),
149
+ )
150
+ ),
151
+ )
136
152
self ._channel .on (
137
153
"backgroundPage" ,
138
154
lambda params : self ._on_background_page (from_channel (params ["page" ])),
@@ -244,10 +260,24 @@ async def _on_route(self, route: Route) -> None:
244
260
try :
245
261
# If the page is closed or unrouteAll() was called without waiting and interception disabled,
246
262
# the method will throw an error - silence it.
247
- await route ._internal_continue ( is_internal = True )
263
+ await route ._inner_continue ( True )
248
264
except Exception :
249
265
pass
250
266
267
+ async def _on_web_socket_route (self , web_socket_route : WebSocketRoute ) -> None :
268
+ route_handler = next (
269
+ (
270
+ route_handler
271
+ for route_handler in self ._web_socket_routes
272
+ if route_handler .matches (web_socket_route .url )
273
+ ),
274
+ None ,
275
+ )
276
+ if route_handler :
277
+ await route_handler .handle (web_socket_route )
278
+ else :
279
+ web_socket_route .connect_to_server ()
280
+
251
281
def _on_binding (self , binding_call : BindingCall ) -> None :
252
282
func = self ._bindings .get (binding_call ._initializer ["name" ])
253
283
if func is None :
@@ -418,6 +448,17 @@ async def _unroute_internal(
418
448
return
419
449
await asyncio .gather (* map (lambda router : router .stop (behavior ), removed )) # type: ignore
420
450
451
+ async def route_web_socket (
452
+ self , url : URLMatch , handler : WebSocketRouteHandlerCallback
453
+ ) -> None :
454
+ self ._web_socket_routes .insert (
455
+ 0 ,
456
+ WebSocketRouteHandler (
457
+ URLMatcher (self ._options .get ("baseURL" ), url ), handler
458
+ ),
459
+ )
460
+ await self ._update_web_socket_interception_patterns ()
461
+
421
462
def _dispose_har_routers (self ) -> None :
422
463
for router in self ._har_routers :
423
464
router .dispose ()
@@ -488,6 +529,14 @@ async def _update_interception_patterns(self) -> None:
488
529
"setNetworkInterceptionPatterns" , {"patterns" : patterns }
489
530
)
490
531
532
+ async def _update_web_socket_interception_patterns (self ) -> None :
533
+ patterns = WebSocketRouteHandler .prepare_interception_patterns (
534
+ self ._web_socket_routes
535
+ )
536
+ await self ._channel .send (
537
+ "setWebSocketInterceptionPatterns" , {"patterns" : patterns }
538
+ )
539
+
491
540
def expect_event (
492
541
self ,
493
542
event : str ,
0 commit comments