Skip to content

Commit af1c709

Browse files
authored
Add support for ASGI pathsend extension (#2671)
1 parent f099743 commit af1c709

File tree

6 files changed

+152
-11
lines changed

6 files changed

+152
-11
lines changed

starlette/middleware/base.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
from __future__ import annotations
22

3-
from collections.abc import AsyncGenerator, Awaitable, Mapping
4-
from typing import Any, Callable, TypeVar
3+
from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Mapping, MutableMapping
4+
from typing import Any, Callable, TypeVar, Union
55

66
import anyio
77

88
from starlette._utils import collapse_excgroups
99
from starlette.requests import ClientDisconnect, Request
10-
from starlette.responses import AsyncContentStream, Response
10+
from starlette.responses import Response
1111
from starlette.types import ASGIApp, Message, Receive, Scope, Send
1212

1313
RequestResponseEndpoint = Callable[[Request], Awaitable[Response]]
1414
DispatchFunction = Callable[[Request, RequestResponseEndpoint], Awaitable[Response]]
15+
BodyStreamGenerator = AsyncGenerator[Union[bytes, MutableMapping[str, Any]], None]
16+
AsyncContentStream = AsyncIterable[Union[str, bytes, memoryview, MutableMapping[str, Any]]]
1517
T = TypeVar("T")
1618

1719

@@ -159,9 +161,12 @@ async def coro() -> None:
159161

160162
assert message["type"] == "http.response.start"
161163

162-
async def body_stream() -> AsyncGenerator[bytes, None]:
164+
async def body_stream() -> BodyStreamGenerator:
163165
async for message in recv_stream:
164-
assert message["type"] == "http.response.body"
166+
if message["type"] == "http.response.pathsend":
167+
yield message
168+
break
169+
assert message["type"] == "http.response.body", f"Unexpected message: {message}"
165170
body = message.get("body", b"")
166171
if body:
167172
yield body
@@ -214,10 +219,17 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
214219
}
215220
)
216221

222+
should_close_body = True
217223
async for chunk in self.body_iterator:
224+
if isinstance(chunk, dict):
225+
# We got an ASGI message which is not response body (eg: pathsend)
226+
should_close_body = False
227+
await send(chunk)
228+
continue
218229
await send({"type": "http.response.body", "body": chunk, "more_body": True})
219230

220-
await send({"type": "http.response.body", "body": b"", "more_body": False})
231+
if should_close_body:
232+
await send({"type": "http.response.body", "body": b"", "more_body": False})
221233

222234
if self.background:
223235
await self.background()

starlette/middleware/gzip.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,18 @@ async def send_with_compression(self, message: Message) -> None:
9393

9494
await self.send(self.initial_message)
9595
await self.send(message)
96-
elif message_type == "http.response.body": # pragma: no branch
96+
elif message_type == "http.response.body":
9797
# Remaining body in streaming response.
9898
body = message.get("body", b"")
9999
more_body = message.get("more_body", False)
100100

101101
message["body"] = self.apply_compression(body, more_body=more_body)
102102

103103
await self.send(message)
104+
elif message_type == "http.response.pathsend": # pragma: no branch
105+
# Don't apply GZip to pathsend responses
106+
await self.send(self.initial_message)
107+
await self.send(message)
104108

105109
def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
106110
"""Apply compression on the response body.

starlette/responses.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,8 @@ def set_stat_headers(self, stat_result: os.stat_result) -> None:
346346

347347
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
348348
send_header_only: bool = scope["method"].upper() == "HEAD"
349+
send_pathsend: bool = "http.response.pathsend" in scope.get("extensions", {})
350+
349351
if self.stat_result is None:
350352
try:
351353
stat_result = await anyio.to_thread.run_sync(os.stat, self.path)
@@ -364,7 +366,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
364366
http_if_range = headers.get("if-range")
365367

366368
if http_range is None or (http_if_range is not None and not self._should_use_range(http_if_range)):
367-
await self._handle_simple(send, send_header_only)
369+
await self._handle_simple(send, send_header_only, send_pathsend)
368370
else:
369371
try:
370372
ranges = self._parse_range_header(http_range, stat_result.st_size)
@@ -383,10 +385,12 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
383385
if self.background is not None:
384386
await self.background()
385387

386-
async def _handle_simple(self, send: Send, send_header_only: bool) -> None:
388+
async def _handle_simple(self, send: Send, send_header_only: bool, send_pathsend: bool) -> None:
387389
await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers})
388390
if send_header_only:
389391
await send({"type": "http.response.body", "body": b"", "more_body": False})
392+
elif send_pathsend:
393+
await send({"type": "http.response.pathsend", "path": str(self.path)})
390394
else:
391395
async with await anyio.open_file(self.path, mode="rb") as file:
392396
more_body = True

tests/middleware/test_base.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import contextvars
44
from collections.abc import AsyncGenerator, AsyncIterator, Generator
55
from contextlib import AsyncExitStack
6+
from pathlib import Path
67
from typing import Any
78

89
import anyio
@@ -14,7 +15,7 @@
1415
from starlette.middleware import Middleware, _MiddlewareFactory
1516
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
1617
from starlette.requests import ClientDisconnect, Request
17-
from starlette.responses import PlainTextResponse, Response, StreamingResponse
18+
from starlette.responses import FileResponse, PlainTextResponse, Response, StreamingResponse
1819
from starlette.routing import Route, WebSocketRoute
1920
from starlette.testclient import TestClient
2021
from starlette.types import ASGIApp, Message, Receive, Scope, Send
@@ -1198,3 +1199,47 @@ async def send(message: Message) -> None:
11981199
{"type": "http.response.body", "body": b"good!", "more_body": True},
11991200
{"type": "http.response.body", "body": b"", "more_body": False},
12001201
]
1202+
1203+
1204+
@pytest.mark.anyio
1205+
async def test_asgi_pathsend_events(tmpdir: Path) -> None:
1206+
path = tmpdir / "example.txt"
1207+
with path.open("w") as file:
1208+
file.write("<file content>")
1209+
1210+
response_complete = anyio.Event()
1211+
events: list[Message] = []
1212+
1213+
async def endpoint_with_pathsend(_: Request) -> FileResponse:
1214+
return FileResponse(path)
1215+
1216+
async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> Response:
1217+
return await call_next(request)
1218+
1219+
app = Starlette(
1220+
middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)],
1221+
routes=[Route("/", endpoint_with_pathsend)],
1222+
)
1223+
1224+
scope = {
1225+
"type": "http",
1226+
"version": "3",
1227+
"method": "GET",
1228+
"path": "/",
1229+
"headers": [],
1230+
"extensions": {"http.response.pathsend": {}},
1231+
}
1232+
1233+
async def receive() -> Message:
1234+
raise NotImplementedError("Should not be called!") # pragma: no cover
1235+
1236+
async def send(message: Message) -> None:
1237+
events.append(message)
1238+
if message["type"] == "http.response.pathsend":
1239+
response_complete.set()
1240+
1241+
await app(scope, receive, send)
1242+
1243+
assert len(events) == 2
1244+
assert events[0]["type"] == "http.response.start"
1245+
assert events[1]["type"] == "http.response.pathsend"

tests/middleware/test_gzip.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
from __future__ import annotations
22

3+
from pathlib import Path
4+
5+
import pytest
6+
37
from starlette.applications import Starlette
48
from starlette.middleware import Middleware
59
from starlette.middleware.gzip import GZipMiddleware
610
from starlette.requests import Request
7-
from starlette.responses import ContentStream, PlainTextResponse, StreamingResponse
11+
from starlette.responses import ContentStream, FileResponse, PlainTextResponse, StreamingResponse
812
from starlette.routing import Route
13+
from starlette.types import Message
914
from tests.types import TestClientFactory
1015

1116

@@ -156,3 +161,42 @@ async def generator(bytes: bytes, count: int) -> ContentStream:
156161
assert response.text == "x" * 4000
157162
assert "Content-Encoding" not in response.headers
158163
assert "Content-Length" not in response.headers
164+
165+
166+
@pytest.mark.anyio
167+
async def test_gzip_ignored_for_pathsend_responses(tmpdir: Path) -> None:
168+
path = tmpdir / "example.txt"
169+
with path.open("w") as file:
170+
file.write("<file content>")
171+
172+
events: list[Message] = []
173+
174+
async def endpoint_with_pathsend(request: Request) -> FileResponse:
175+
_ = await request.body()
176+
return FileResponse(path)
177+
178+
app = Starlette(
179+
routes=[Route("/", endpoint=endpoint_with_pathsend)],
180+
middleware=[Middleware(GZipMiddleware)],
181+
)
182+
183+
scope = {
184+
"type": "http",
185+
"version": "3",
186+
"method": "GET",
187+
"path": "/",
188+
"headers": [(b"accept-encoding", b"gzip, text")],
189+
"extensions": {"http.response.pathsend": {}},
190+
}
191+
192+
async def receive() -> Message:
193+
return {"type": "http.request", "body": b"", "more_body": False}
194+
195+
async def send(message: Message) -> None:
196+
events.append(message)
197+
198+
await app(scope, receive, send)
199+
200+
assert len(events) == 2
201+
assert events[0]["type"] == "http.response.start"
202+
assert events[1]["type"] == "http.response.pathsend"

tests/test_responses.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,38 @@ def test_file_response_with_range_header(tmp_path: Path, test_client_factory: Te
354354
assert response.headers["content-range"] == f"bytes 0-4/{len(content)}"
355355

356356

357+
@pytest.mark.anyio
358+
async def test_file_response_with_pathsend(tmpdir: Path) -> None:
359+
path = tmpdir / "xyz"
360+
content = b"<file content>" * 1000
361+
with open(path, "wb") as file:
362+
file.write(content)
363+
364+
app = FileResponse(path=path, filename="example.png")
365+
366+
async def receive() -> Message: # type: ignore[empty-body]
367+
... # pragma: no cover
368+
369+
async def send(message: Message) -> None:
370+
if message["type"] == "http.response.start":
371+
assert message["status"] == status.HTTP_200_OK
372+
headers = Headers(raw=message["headers"])
373+
assert headers["content-type"] == "image/png"
374+
assert "content-length" in headers
375+
assert "content-disposition" in headers
376+
assert "last-modified" in headers
377+
assert "etag" in headers
378+
elif message["type"] == "http.response.pathsend": # pragma: no branch
379+
assert message["path"] == str(path)
380+
381+
# Since the TestClient doesn't support `pathsend`, we need to test this directly.
382+
await app(
383+
{"type": "http", "method": "get", "headers": [], "extensions": {"http.response.pathsend": {}}},
384+
receive,
385+
send,
386+
)
387+
388+
357389
def test_set_cookie(test_client_factory: TestClientFactory, monkeypatch: pytest.MonkeyPatch) -> None:
358390
# Mock time used as a reference for `Expires` by stdlib `SimpleCookie`.
359391
mocked_now = dt.datetime(2037, 1, 22, 12, 0, 0, tzinfo=dt.timezone.utc)

0 commit comments

Comments
 (0)