47
47
from_nullable_channel ,
48
48
)
49
49
from playwright ._impl ._event_context_manager import EventContextManagerImpl
50
- from playwright ._impl ._helper import ContinueParameters , locals_to_params
50
+ from playwright ._impl ._helper import FallbackOverrideParameters , locals_to_params
51
51
from playwright ._impl ._wait_helper import WaitHelper
52
52
53
53
if TYPE_CHECKING : # pragma: no cover
54
54
from playwright ._impl ._fetch import APIResponse
55
55
from playwright ._impl ._frame import Frame
56
56
57
57
58
+ def serialize_headers (headers : Dict [str , str ]) -> HeadersArray :
59
+ return [
60
+ {"name" : name , "value" : value }
61
+ for name , value in headers .items ()
62
+ if value is not None
63
+ ]
64
+
65
+
58
66
class Request (ChannelOwner ):
59
67
def __init__ (
60
68
self , parent : ChannelOwner , type : str , guid : str , initializer : Dict
@@ -80,21 +88,31 @@ def __init__(
80
88
}
81
89
self ._provisional_headers = RawHeaders (self ._initializer ["headers" ])
82
90
self ._all_headers_future : Optional [asyncio .Future [RawHeaders ]] = None
91
+ self ._fallback_overrides : FallbackOverrideParameters = (
92
+ FallbackOverrideParameters ()
93
+ )
83
94
84
95
def __repr__ (self ) -> str :
85
96
return f"<Request url={ self .url !r} method={ self .method !r} >"
86
97
98
+ def _apply_fallback_overrides (self , overrides : FallbackOverrideParameters ) -> None :
99
+ self ._fallback_overrides = cast (
100
+ FallbackOverrideParameters , {** self ._fallback_overrides , ** overrides }
101
+ )
102
+
87
103
@property
88
104
def url (self ) -> str :
89
- return self ._initializer ["url" ]
105
+ return cast ( str , self ._fallback_overrides . get ( "url" , self . _initializer ["url" ]))
90
106
91
107
@property
92
108
def resource_type (self ) -> str :
93
109
return self ._initializer ["resourceType" ]
94
110
95
111
@property
96
112
def method (self ) -> str :
97
- return self ._initializer ["method" ]
113
+ return cast (
114
+ str , self ._fallback_overrides .get ("method" , self ._initializer ["method" ])
115
+ )
98
116
99
117
async def sizes (self ) -> RequestSizes :
100
118
response = await self .response ()
@@ -104,10 +122,10 @@ async def sizes(self) -> RequestSizes:
104
122
105
123
@property
106
124
def post_data (self ) -> Optional [str ]:
107
- data = self .post_data_buffer
125
+ data = self ._fallback_overrides . get ( "postData" , self . post_data_buffer )
108
126
if not data :
109
127
return None
110
- return data .decode ()
128
+ return data .decode () if isinstance ( data , bytes ) else data
111
129
112
130
@property
113
131
def post_data_json (self ) -> Optional [Any ]:
@@ -124,6 +142,13 @@ def post_data_json(self) -> Optional[Any]:
124
142
125
143
@property
126
144
def post_data_buffer (self ) -> Optional [bytes ]:
145
+ override = self ._fallback_overrides .get ("post_data" )
146
+ if override :
147
+ return (
148
+ override .encode ()
149
+ if isinstance (override , str )
150
+ else cast (bytes , override )
151
+ )
127
152
b64_content = self ._initializer .get ("postData" )
128
153
if b64_content is None :
129
154
return None
@@ -157,6 +182,9 @@ def timing(self) -> ResourceTiming:
157
182
158
183
@property
159
184
def headers (self ) -> Headers :
185
+ override = self ._fallback_overrides .get ("headers" )
186
+ if override :
187
+ return RawHeaders ._from_headers_dict_lossy (override ).headers ()
160
188
return self ._provisional_headers .headers ()
161
189
162
190
async def all_headers (self ) -> Headers :
@@ -169,6 +197,9 @@ async def header_value(self, name: str) -> Optional[str]:
169
197
return (await self ._actual_headers ()).get (name )
170
198
171
199
async def _actual_headers (self ) -> "RawHeaders" :
200
+ override = self ._fallback_overrides .get ("headers" )
201
+ if override :
202
+ return RawHeaders (serialize_headers (override ))
172
203
if not self ._all_headers_future :
173
204
self ._all_headers_future = asyncio .Future ()
174
205
headers = await self ._channel .send ("rawRequestHeaders" )
@@ -181,6 +212,21 @@ def __init__(
181
212
self , parent : ChannelOwner , type : str , guid : str , initializer : Dict
182
213
) -> None :
183
214
super ().__init__ (parent , type , guid , initializer )
215
+ self ._handling_future : Optional [asyncio .Future ["bool" ]] = None
216
+
217
+ def _start_handling (self ) -> "asyncio.Future[bool]" :
218
+ self ._handling_future = asyncio .Future ()
219
+ return self ._handling_future
220
+
221
+ def _report_handled (self , done : bool ) -> None :
222
+ chain = self ._handling_future
223
+ assert chain
224
+ self ._handling_future = None
225
+ chain .set_result (done )
226
+
227
+ def _check_not_handled (self ) -> None :
228
+ if not self ._handling_future :
229
+ raise Error ("Route is already handled!" )
184
230
185
231
def __repr__ (self ) -> str :
186
232
return f"<Route request={ self .request } >"
@@ -203,6 +249,7 @@ async def fulfill(
203
249
contentType : str = None ,
204
250
response : "APIResponse" = None ,
205
251
) -> None :
252
+ self ._check_not_handled ()
206
253
params = locals_to_params (locals ())
207
254
if response :
208
255
del params ["response" ]
@@ -247,37 +294,74 @@ async def fulfill(
247
294
headers ["content-length" ] = str (length )
248
295
params ["headers" ] = serialize_headers (headers )
249
296
await self ._race_with_page_close (self ._channel .send ("fulfill" , params ))
297
+ self ._report_handled (True )
250
298
251
- async def continue_ (
299
+ async def fallback (
252
300
self ,
253
301
url : str = None ,
254
302
method : str = None ,
255
303
headers : Dict [str , str ] = None ,
256
304
postData : Union [str , bytes ] = None ,
257
305
) -> None :
258
- overrides : ContinueParameters = {}
259
- if url :
260
- overrides ["url" ] = url
261
- if method :
262
- overrides ["method" ] = method
263
- if headers :
264
- overrides ["headers" ] = serialize_headers (headers )
265
- if isinstance (postData , str ):
266
- overrides ["postData" ] = base64 .b64encode (postData .encode ()).decode ()
267
- elif isinstance (postData , bytes ):
268
- overrides ["postData" ] = base64 .b64encode (postData ).decode ()
269
- await self ._race_with_page_close (
270
- self ._channel .send ("continue" , cast (Any , overrides ))
271
- )
306
+ overrides = cast (FallbackOverrideParameters , locals_to_params (locals ()))
307
+ self ._check_not_handled ()
308
+ self .request ._apply_fallback_overrides (overrides )
309
+ self ._report_handled (False )
272
310
273
- def _internal_continue (self ) -> None :
311
+ async def continue_ (
312
+ self ,
313
+ url : str = None ,
314
+ method : str = None ,
315
+ headers : Dict [str , str ] = None ,
316
+ postData : Union [str , bytes ] = None ,
317
+ ) -> None :
318
+ overrides = cast (FallbackOverrideParameters , locals_to_params (locals ()))
319
+ self ._check_not_handled ()
320
+ self .request ._apply_fallback_overrides (overrides )
321
+ await self ._internal_continue ()
322
+ self ._report_handled (True )
323
+
324
+ def _internal_continue (
325
+ self , is_internal : bool = False
326
+ ) -> Coroutine [Any , Any , None ]:
274
327
async def continue_route () -> None :
275
328
try :
276
- await self .continue_ ()
277
- except Exception :
278
- pass
279
-
280
- asyncio .create_task (continue_route ())
329
+ post_data_for_wire : Optional [str ] = None
330
+ post_data_from_overrides = self .request ._fallback_overrides .get (
331
+ "postData"
332
+ )
333
+ if post_data_from_overrides is not None :
334
+ post_data_for_wire = (
335
+ base64 .b64encode (post_data_from_overrides .encode ()).decode ()
336
+ if isinstance (post_data_from_overrides , str )
337
+ else base64 .b64encode (post_data_from_overrides ).decode ()
338
+ )
339
+ params = locals_to_params (
340
+ cast (Dict [str , str ], self .request ._fallback_overrides )
341
+ )
342
+ if "headers" in params :
343
+ params ["headers" ] = serialize_headers (params ["headers" ])
344
+ if post_data_for_wire is not None :
345
+ params ["postData" ] = post_data_for_wire
346
+ await self ._race_with_page_close (
347
+ self ._channel .send (
348
+ "continue" ,
349
+ params ,
350
+ )
351
+ )
352
+ except Exception as e :
353
+ if not is_internal :
354
+ raise e
355
+
356
+ return continue_route ()
357
+
358
+ # FIXME: Port corresponding tests, and call this method
359
+ async def _redirected_navigation_request (self , url : str ) -> None :
360
+ self ._check_not_handled ()
361
+ await self ._race_with_page_close (
362
+ self ._channel .send ("redirectNavigationRequest" , {"url" : url })
363
+ )
364
+ self ._report_handled (True )
281
365
282
366
async def _race_with_page_close (self , future : Coroutine ) -> None :
283
367
if hasattr (self .request .frame , "_page" ):
@@ -484,17 +568,17 @@ def _on_close(self) -> None:
484
568
self .emit (WebSocket .Events .Close , self )
485
569
486
570
487
- def serialize_headers (headers : Dict [str , str ]) -> HeadersArray :
488
- return [{"name" : name , "value" : value } for name , value in headers .items ()]
489
-
490
-
491
571
class RawHeaders :
492
572
def __init__ (self , headers : HeadersArray ) -> None :
493
573
self ._headers_array = headers
494
574
self ._headers_map : Dict [str , Dict [str , bool ]] = defaultdict (dict )
495
575
for header in headers :
496
576
self ._headers_map [header ["name" ].lower ()][header ["value" ]] = True
497
577
578
+ @staticmethod
579
+ def _from_headers_dict_lossy (headers : Dict [str , str ]) -> "RawHeaders" :
580
+ return RawHeaders (serialize_headers (headers ))
581
+
498
582
def get (self , name : str ) -> Optional [str ]:
499
583
values = self .get_all (name )
500
584
if not values :
0 commit comments