Skip to content

Commit 7b424eb

Browse files
authored
chore: port route chaining, fallback, async, times (#1376)
This is part 2/n of the 1.23 port. Relates #1308, #1374 Ports: - [x] microsoft/playwright@a1324bd (fix(route): support route w/ async handler & times (#14317)) - [x] microsoft/playwright@7a568a2 (feat(route): chain routes (#14771)) - [x] microsoft/playwright@dcdd3c3 (feat(route): explicitly fall back to the next handler (#14834)) - [x] microsoft/playwright@9cf068a (feat(fallback): allow falling back w/ overrides (#14849)) - [x] microsoft/playwright@48f9867 (chore: remove stray fallback overrides check) - [x] microsoft/playwright@ae6f48c (fix(route): match against updated url while chaining (#15112))
1 parent 84d94a3 commit 7b424eb

14 files changed

+2035
-70
lines changed

playwright/_impl/_browser_context.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,11 @@ def __init__(
9696
)
9797
self._channel.on(
9898
"route",
99-
lambda params: self._on_route(
100-
from_channel(params.get("route")), from_channel(params.get("request"))
99+
lambda params: asyncio.create_task(
100+
self._on_route(
101+
from_channel(params.get("route")),
102+
from_channel(params.get("request")),
103+
)
101104
),
102105
)
103106

@@ -156,18 +159,21 @@ def _on_page(self, page: Page) -> None:
156159
if page._opener and not page._opener.is_closed():
157160
page._opener.emit(Page.Events.Popup, page)
158161

159-
def _on_route(self, route: Route, request: Request) -> None:
160-
for handler_entry in self._routes:
161-
if handler_entry.matches(request.url):
162-
try:
163-
handler_entry.handle(route, request)
164-
finally:
165-
if not handler_entry.is_active:
166-
self._routes.remove(handler_entry)
167-
if not len(self._routes) == 0:
168-
asyncio.create_task(self._disable_interception())
169-
break
170-
route._internal_continue()
162+
async def _on_route(self, route: Route, request: Request) -> None:
163+
route_handlers = self._routes.copy()
164+
for route_handler in route_handlers:
165+
if not route_handler.matches(request.url):
166+
continue
167+
if route_handler.will_expire:
168+
self._routes.remove(route_handler)
169+
try:
170+
handled = await route_handler.handle(route, request)
171+
finally:
172+
if len(self._routes) == 0:
173+
asyncio.create_task(self._disable_interception())
174+
if handled:
175+
return
176+
await route._internal_continue(is_internal=True)
171177

172178
def _on_binding(self, binding_call: BindingCall) -> None:
173179
func = self._bindings.get(binding_call._initializer["name"])

playwright/_impl/_helper.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,11 @@ class ErrorPayload(TypedDict, total=False):
7676
value: Optional[Any]
7777

7878

79-
class ContinueParameters(TypedDict, total=False):
79+
class FallbackOverrideParameters(TypedDict, total=False):
8080
url: Optional[str]
8181
method: Optional[str]
82-
headers: Optional[List[NameValue]]
83-
postData: Optional[str]
82+
headers: Optional[Dict[str, str]]
83+
postData: Optional[Union[str, bytes]]
8484

8585

8686
class ParsedMessageParams(TypedDict):
@@ -225,14 +225,17 @@ def __init__(
225225
def matches(self, request_url: str) -> bool:
226226
return self.matcher.matches(request_url)
227227

228-
def handle(self, route: "Route", request: "Request") -> None:
228+
async def handle(self, route: "Route", request: "Request") -> bool:
229+
handled_future = route._start_handling()
230+
handler_task = []
231+
229232
def impl() -> None:
230233
self._handled_count += 1
231234
result = cast(
232235
Callable[["Route", "Request"], Union[Coroutine, Any]], self.handler
233236
)(route, request)
234237
if inspect.iscoroutine(result):
235-
asyncio.create_task(result)
238+
handler_task.append(asyncio.create_task(result))
236239

237240
# As with event handlers, each route handler is a potentially blocking context
238241
# so it needs a fiber.
@@ -242,9 +245,12 @@ def impl() -> None:
242245
else:
243246
impl()
244247

248+
[handled, *_] = await asyncio.gather(handled_future, *handler_task)
249+
return handled
250+
245251
@property
246-
def is_active(self) -> bool:
247-
return self._handled_count < self._times
252+
def will_expire(self) -> bool:
253+
return self._handled_count + 1 >= self._times
248254

249255

250256
def is_safe_close_error(error: Exception) -> bool:

playwright/_impl/_network.py

Lines changed: 114 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,22 @@
4747
from_nullable_channel,
4848
)
4949
from playwright._impl._event_context_manager import EventContextManagerImpl
50-
from playwright._impl._helper import ContinueParameters, locals_to_params
50+
from playwright._impl._helper import FallbackOverrideParameters, locals_to_params
5151
from playwright._impl._wait_helper import WaitHelper
5252

5353
if TYPE_CHECKING: # pragma: no cover
5454
from playwright._impl._fetch import APIResponse
5555
from playwright._impl._frame import Frame
5656

5757

58+
def serialize_headers(headers: Dict[str, str]) -> HeadersArray:
59+
return [
60+
{"name": name, "value": value}
61+
for name, value in headers.items()
62+
if value is not None
63+
]
64+
65+
5866
class Request(ChannelOwner):
5967
def __init__(
6068
self, parent: ChannelOwner, type: str, guid: str, initializer: Dict
@@ -80,21 +88,31 @@ def __init__(
8088
}
8189
self._provisional_headers = RawHeaders(self._initializer["headers"])
8290
self._all_headers_future: Optional[asyncio.Future[RawHeaders]] = None
91+
self._fallback_overrides: FallbackOverrideParameters = (
92+
FallbackOverrideParameters()
93+
)
8394

8495
def __repr__(self) -> str:
8596
return f"<Request url={self.url!r} method={self.method!r}>"
8697

98+
def _apply_fallback_overrides(self, overrides: FallbackOverrideParameters) -> None:
99+
self._fallback_overrides = cast(
100+
FallbackOverrideParameters, {**self._fallback_overrides, **overrides}
101+
)
102+
87103
@property
88104
def url(self) -> str:
89-
return self._initializer["url"]
105+
return cast(str, self._fallback_overrides.get("url", self._initializer["url"]))
90106

91107
@property
92108
def resource_type(self) -> str:
93109
return self._initializer["resourceType"]
94110

95111
@property
96112
def method(self) -> str:
97-
return self._initializer["method"]
113+
return cast(
114+
str, self._fallback_overrides.get("method", self._initializer["method"])
115+
)
98116

99117
async def sizes(self) -> RequestSizes:
100118
response = await self.response()
@@ -104,10 +122,10 @@ async def sizes(self) -> RequestSizes:
104122

105123
@property
106124
def post_data(self) -> Optional[str]:
107-
data = self.post_data_buffer
125+
data = self._fallback_overrides.get("postData", self.post_data_buffer)
108126
if not data:
109127
return None
110-
return data.decode()
128+
return data.decode() if isinstance(data, bytes) else data
111129

112130
@property
113131
def post_data_json(self) -> Optional[Any]:
@@ -124,6 +142,13 @@ def post_data_json(self) -> Optional[Any]:
124142

125143
@property
126144
def post_data_buffer(self) -> Optional[bytes]:
145+
override = self._fallback_overrides.get("post_data")
146+
if override:
147+
return (
148+
override.encode()
149+
if isinstance(override, str)
150+
else cast(bytes, override)
151+
)
127152
b64_content = self._initializer.get("postData")
128153
if b64_content is None:
129154
return None
@@ -157,6 +182,9 @@ def timing(self) -> ResourceTiming:
157182

158183
@property
159184
def headers(self) -> Headers:
185+
override = self._fallback_overrides.get("headers")
186+
if override:
187+
return RawHeaders._from_headers_dict_lossy(override).headers()
160188
return self._provisional_headers.headers()
161189

162190
async def all_headers(self) -> Headers:
@@ -169,6 +197,9 @@ async def header_value(self, name: str) -> Optional[str]:
169197
return (await self._actual_headers()).get(name)
170198

171199
async def _actual_headers(self) -> "RawHeaders":
200+
override = self._fallback_overrides.get("headers")
201+
if override:
202+
return RawHeaders(serialize_headers(override))
172203
if not self._all_headers_future:
173204
self._all_headers_future = asyncio.Future()
174205
headers = await self._channel.send("rawRequestHeaders")
@@ -181,6 +212,21 @@ def __init__(
181212
self, parent: ChannelOwner, type: str, guid: str, initializer: Dict
182213
) -> None:
183214
super().__init__(parent, type, guid, initializer)
215+
self._handling_future: Optional[asyncio.Future["bool"]] = None
216+
217+
def _start_handling(self) -> "asyncio.Future[bool]":
218+
self._handling_future = asyncio.Future()
219+
return self._handling_future
220+
221+
def _report_handled(self, done: bool) -> None:
222+
chain = self._handling_future
223+
assert chain
224+
self._handling_future = None
225+
chain.set_result(done)
226+
227+
def _check_not_handled(self) -> None:
228+
if not self._handling_future:
229+
raise Error("Route is already handled!")
184230

185231
def __repr__(self) -> str:
186232
return f"<Route request={self.request}>"
@@ -203,6 +249,7 @@ async def fulfill(
203249
contentType: str = None,
204250
response: "APIResponse" = None,
205251
) -> None:
252+
self._check_not_handled()
206253
params = locals_to_params(locals())
207254
if response:
208255
del params["response"]
@@ -247,37 +294,74 @@ async def fulfill(
247294
headers["content-length"] = str(length)
248295
params["headers"] = serialize_headers(headers)
249296
await self._race_with_page_close(self._channel.send("fulfill", params))
297+
self._report_handled(True)
250298

251-
async def continue_(
299+
async def fallback(
252300
self,
253301
url: str = None,
254302
method: str = None,
255303
headers: Dict[str, str] = None,
256304
postData: Union[str, bytes] = None,
257305
) -> None:
258-
overrides: ContinueParameters = {}
259-
if url:
260-
overrides["url"] = url
261-
if method:
262-
overrides["method"] = method
263-
if headers:
264-
overrides["headers"] = serialize_headers(headers)
265-
if isinstance(postData, str):
266-
overrides["postData"] = base64.b64encode(postData.encode()).decode()
267-
elif isinstance(postData, bytes):
268-
overrides["postData"] = base64.b64encode(postData).decode()
269-
await self._race_with_page_close(
270-
self._channel.send("continue", cast(Any, overrides))
271-
)
306+
overrides = cast(FallbackOverrideParameters, locals_to_params(locals()))
307+
self._check_not_handled()
308+
self.request._apply_fallback_overrides(overrides)
309+
self._report_handled(False)
272310

273-
def _internal_continue(self) -> None:
311+
async def continue_(
312+
self,
313+
url: str = None,
314+
method: str = None,
315+
headers: Dict[str, str] = None,
316+
postData: Union[str, bytes] = None,
317+
) -> None:
318+
overrides = cast(FallbackOverrideParameters, locals_to_params(locals()))
319+
self._check_not_handled()
320+
self.request._apply_fallback_overrides(overrides)
321+
await self._internal_continue()
322+
self._report_handled(True)
323+
324+
def _internal_continue(
325+
self, is_internal: bool = False
326+
) -> Coroutine[Any, Any, None]:
274327
async def continue_route() -> None:
275328
try:
276-
await self.continue_()
277-
except Exception:
278-
pass
279-
280-
asyncio.create_task(continue_route())
329+
post_data_for_wire: Optional[str] = None
330+
post_data_from_overrides = self.request._fallback_overrides.get(
331+
"postData"
332+
)
333+
if post_data_from_overrides is not None:
334+
post_data_for_wire = (
335+
base64.b64encode(post_data_from_overrides.encode()).decode()
336+
if isinstance(post_data_from_overrides, str)
337+
else base64.b64encode(post_data_from_overrides).decode()
338+
)
339+
params = locals_to_params(
340+
cast(Dict[str, str], self.request._fallback_overrides)
341+
)
342+
if "headers" in params:
343+
params["headers"] = serialize_headers(params["headers"])
344+
if post_data_for_wire is not None:
345+
params["postData"] = post_data_for_wire
346+
await self._race_with_page_close(
347+
self._channel.send(
348+
"continue",
349+
params,
350+
)
351+
)
352+
except Exception as e:
353+
if not is_internal:
354+
raise e
355+
356+
return continue_route()
357+
358+
# FIXME: Port corresponding tests, and call this method
359+
async def _redirected_navigation_request(self, url: str) -> None:
360+
self._check_not_handled()
361+
await self._race_with_page_close(
362+
self._channel.send("redirectNavigationRequest", {"url": url})
363+
)
364+
self._report_handled(True)
281365

282366
async def _race_with_page_close(self, future: Coroutine) -> None:
283367
if hasattr(self.request.frame, "_page"):
@@ -484,17 +568,17 @@ def _on_close(self) -> None:
484568
self.emit(WebSocket.Events.Close, self)
485569

486570

487-
def serialize_headers(headers: Dict[str, str]) -> HeadersArray:
488-
return [{"name": name, "value": value} for name, value in headers.items()]
489-
490-
491571
class RawHeaders:
492572
def __init__(self, headers: HeadersArray) -> None:
493573
self._headers_array = headers
494574
self._headers_map: Dict[str, Dict[str, bool]] = defaultdict(dict)
495575
for header in headers:
496576
self._headers_map[header["name"].lower()][header["value"]] = True
497577

578+
@staticmethod
579+
def _from_headers_dict_lossy(headers: Dict[str, str]) -> "RawHeaders":
580+
return RawHeaders(serialize_headers(headers))
581+
498582
def get(self, name: str) -> Optional[str]:
499583
values = self.get_all(name)
500584
if not values:

0 commit comments

Comments
 (0)