Skip to content

Commit 9286c83

Browse files
committed
Add support for ASGI pathsend extension in BaseHTTPMiddleware
1 parent 8fa5837 commit 9286c83

File tree

3 files changed

+58
-3
lines changed

3 files changed

+58
-3
lines changed

docs/middleware.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,6 @@ around explicitly, rather than mutating the middleware instance.
264264
Currently, the `BaseHTTPMiddleware` has some known limitations:
265265

266266
- Using `BaseHTTPMiddleware` will prevent changes to [`contextlib.ContextVar`](https://docs.python.org/3/library/contextvars.html#contextvars.ContextVar)s from propagating upwards. That is, if you set a value for a `ContextVar` in your endpoint and try to read it from a middleware you will find that the value is not the same value you set in your endpoint (see [this test](https://github.com/encode/starlette/blob/621abc747a6604825190b93467918a0ec6456a24/tests/middleware/test_base.py#L192-L223) for an example of this behavior).
267-
- Using `BaseHTTPMiddleware` will prevent [ASGI pathsend extension](https://asgi.readthedocs.io/en/latest/extensions.html#path-send) to work properly. Thus, if you run your Starlette application with a server implementing this extension, routes returning [FileResponse](responses.md#fileresponse) should avoid the usage of this middleware.
268267

269268
To overcome these limitations, use [pure ASGI middleware](#pure-asgi-middleware), as shown below.
270269

starlette/middleware/base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
1414
DispatchFunction = typing.Callable[[Request, RequestResponseEndpoint], typing.Awaitable[Response]]
15+
BodyStreamGenerator = typing.AsyncGenerator[typing.Union[bytes, typing.MutableMapping[str, typing.Any]], None]
1516
T = typing.TypeVar("T")
1617

1718

@@ -165,9 +166,12 @@ async def coro() -> None:
165166

166167
assert message["type"] == "http.response.start"
167168

168-
async def body_stream() -> typing.AsyncGenerator[bytes, None]:
169+
async def body_stream() -> BodyStreamGenerator:
169170
async with recv_stream:
170171
async for message in recv_stream:
172+
if message["type"] == "http.response.pathsend":
173+
yield message
174+
break
171175
assert message["type"] == "http.response.body"
172176
body = message.get("body", b"")
173177
if body:

tests/middleware/test_base.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import contextvars
44
from contextlib import AsyncExitStack
5+
from pathlib import Path
56
from typing import (
67
Any,
78
AsyncGenerator,
@@ -18,7 +19,7 @@
1819
from starlette.middleware import Middleware, _MiddlewareClass
1920
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
2021
from starlette.requests import ClientDisconnect, Request
21-
from starlette.responses import PlainTextResponse, Response, StreamingResponse
22+
from starlette.responses import FileResponse, PlainTextResponse, Response, StreamingResponse
2223
from starlette.routing import Route, WebSocketRoute
2324
from starlette.testclient import TestClient
2425
from starlette.types import ASGIApp, Message, Receive, Scope, Send
@@ -1132,3 +1133,54 @@ async def send(message: Message) -> None:
11321133
{"type": "http.response.body", "body": b"good!", "more_body": True},
11331134
{"type": "http.response.body", "body": b"", "more_body": False},
11341135
]
1136+
1137+
1138+
@pytest.mark.anyio
1139+
async def test_asgi_pathsend_events(tmpdir: Path) -> None:
1140+
path = tmpdir / "example.txt"
1141+
with path.open("w") as file:
1142+
file.write("<file content>")
1143+
1144+
request_body_sent = False
1145+
response_complete = anyio.Event()
1146+
events: list[Message] = []
1147+
1148+
async def endpoint_with_pathsend(_: Request) -> FileResponse:
1149+
return FileResponse(path)
1150+
1151+
async def passthrough(
1152+
request: Request, call_next: RequestResponseEndpoint
1153+
) -> Response:
1154+
return await call_next(request)
1155+
1156+
app = Starlette(
1157+
middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)],
1158+
routes=[Route("/", endpoint_with_pathsend)],
1159+
)
1160+
1161+
scope = {
1162+
"type": "http",
1163+
"version": "3",
1164+
"method": "GET",
1165+
"path": "/",
1166+
"extensions": {"http.response.pathsend": {}},
1167+
}
1168+
1169+
async def receive() -> Message:
1170+
nonlocal request_body_sent
1171+
if not request_body_sent:
1172+
request_body_sent = True
1173+
return {"type": "http.request", "body": b"", "more_body": False}
1174+
await response_complete.wait()
1175+
return {"type": "http.disconnect"}
1176+
1177+
async def send(message: Message) -> None:
1178+
events.append(message)
1179+
if message["type"] == "http.response.pathsend":
1180+
response_complete.set()
1181+
1182+
await app(scope, receive, send)
1183+
1184+
assert len(events) == 2
1185+
assert events[0]["type"] == "http.response.start"
1186+
assert events[1]["type"] == "http.response.pathsend"

0 commit comments

Comments
 (0)