Skip to content

Commit 1eb9655

Browse files
committed
fix: Use anyio.SpooledTemporaryFile in UploadFile for proper async handling
1 parent 4a81176 commit 1eb9655

File tree

3 files changed

+109
-96
lines changed

3 files changed

+109
-96
lines changed

starlette/datastructures.py

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from shlex import shlex
55
from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit
66

7-
from starlette.concurrency import run_in_threadpool
7+
import anyio
8+
89
from starlette.types import Scope
910

1011

@@ -413,7 +414,7 @@ class UploadFile:
413414

414415
def __init__(
415416
self,
416-
file: typing.BinaryIO,
417+
file: anyio.SpooledTemporaryFile[bytes],
417418
*,
418419
size: int | None = None,
419420
filename: str | None = None,
@@ -428,37 +429,19 @@ def __init__(
428429
def content_type(self) -> str | None:
429430
return self.headers.get("content-type", None)
430431

431-
@property
432-
def _in_memory(self) -> bool:
433-
# check for SpooledTemporaryFile._rolled
434-
rolled_to_disk = getattr(self.file, "_rolled", True)
435-
return not rolled_to_disk
436-
437432
async def write(self, data: bytes) -> None:
438433
if self.size is not None:
439434
self.size += len(data)
440-
441-
if self._in_memory:
442-
self.file.write(data)
443-
else:
444-
await run_in_threadpool(self.file.write, data)
435+
await self.file.write(data)
445436

446437
async def read(self, size: int = -1) -> bytes:
447-
if self._in_memory:
448-
return self.file.read(size)
449-
return await run_in_threadpool(self.file.read, size)
438+
return await self.file.read(size)
450439

451440
async def seek(self, offset: int) -> None:
452-
if self._in_memory:
453-
self.file.seek(offset)
454-
else:
455-
await run_in_threadpool(self.file.seek, offset)
441+
await self.file.seek(offset)
456442

457443
async def close(self) -> None:
458-
if self._in_memory:
459-
self.file.close()
460-
else:
461-
await run_in_threadpool(self.file.close)
444+
await self.file.aclose()
462445

463446
def __repr__(self) -> str:
464447
return f"{self.__class__.__name__}(filename={self.filename!r}, size={self.size!r}, headers={self.headers!r})"

starlette/formparsers.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
import typing
44
from dataclasses import dataclass, field
55
from enum import Enum
6-
from tempfile import SpooledTemporaryFile
76
from urllib.parse import unquote_plus
87

8+
from anyio import SpooledTemporaryFile
9+
910
from starlette.datastructures import FormData, Headers, UploadFile
1011

1112
if typing.TYPE_CHECKING:
@@ -208,7 +209,7 @@ def on_headers_finished(self) -> None:
208209
tempfile = SpooledTemporaryFile(max_size=self.spool_max_size)
209210
self._files_to_close_on_error.append(tempfile)
210211
self._current_part.file = UploadFile(
211-
file=tempfile, # type: ignore[arg-type]
212+
file=tempfile,
212213
size=0,
213214
filename=filename,
214215
headers=Headers(raw=self._current_part.item_headers),
@@ -268,7 +269,7 @@ async def parse(self) -> FormData:
268269
except MultiPartException as exc:
269270
# Close all the files if there was an error.
270271
for file in self._files_to_close_on_error:
271-
file.close()
272+
await file.aclose()
272273
raise exc
273274

274275
parser.finalize()

tests/test_datastructures.py

Lines changed: 98 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
import io
2-
from tempfile import SpooledTemporaryFile
3-
from typing import BinaryIO
4-
1+
import anyio
52
import pytest
63

74
from starlette.datastructures import (
@@ -308,29 +305,41 @@ def test_queryparams() -> None:
308305
@pytest.mark.anyio
309306
async def test_upload_file_file_input() -> None:
310307
"""Test passing file/stream into the UploadFile constructor"""
311-
stream = io.BytesIO(b"data")
312-
file = UploadFile(filename="file", file=stream, size=4)
313-
assert await file.read() == b"data"
314-
assert file.size == 4
315-
await file.write(b" and more data!")
316-
assert await file.read() == b""
317-
assert file.size == 19
318-
await file.seek(0)
319-
assert await file.read() == b"data and more data!"
308+
async with anyio.SpooledTemporaryFile(max_size=1024 * 1024) as stream:
309+
await stream.write(b"data")
310+
await stream.seek(0)
311+
312+
file = UploadFile(filename="file", file=stream, size=4)
313+
try:
314+
assert await file.read() == b"data"
315+
assert file.size == 4
316+
await file.write(b" and more data!")
317+
assert await file.read() == b""
318+
assert file.size == 19
319+
await file.seek(0)
320+
assert await file.read() == b"data and more data!"
321+
finally:
322+
await file.close()
320323

321324

322325
@pytest.mark.anyio
323326
async def test_upload_file_without_size() -> None:
324327
"""Test passing file/stream into the UploadFile constructor without size"""
325-
stream = io.BytesIO(b"data")
326-
file = UploadFile(filename="file", file=stream)
327-
assert await file.read() == b"data"
328-
assert file.size is None
329-
await file.write(b" and more data!")
330-
assert await file.read() == b""
331-
assert file.size is None
332-
await file.seek(0)
333-
assert await file.read() == b"data and more data!"
328+
async with anyio.SpooledTemporaryFile(max_size=1024 * 1024) as stream:
329+
await stream.write(b"data")
330+
await stream.seek(0)
331+
332+
file = UploadFile(filename="file", file=stream)
333+
try:
334+
assert await file.read() == b"data"
335+
assert file.size is None
336+
await file.write(b" and more data!")
337+
assert await file.read() == b""
338+
assert file.size is None
339+
await file.seek(0)
340+
assert await file.read() == b"data and more data!"
341+
finally:
342+
await file.close()
334343

335344

336345
@pytest.mark.anyio
@@ -339,61 +348,81 @@ async def test_uploadfile_rolling(max_size: int) -> None:
339348
"""Test that we can r/w to a SpooledTemporaryFile
340349
managed by UploadFile before and after it rolls to disk
341350
"""
342-
stream: BinaryIO = SpooledTemporaryFile( # type: ignore[assignment]
343-
max_size=max_size
344-
)
345-
file = UploadFile(filename="file", file=stream, size=0)
346-
assert await file.read() == b""
347-
assert file.size == 0
348-
await file.write(b"data")
349-
assert await file.read() == b""
350-
assert file.size == 4
351-
await file.seek(0)
352-
assert await file.read() == b"data"
353-
await file.write(b" more")
354-
assert await file.read() == b""
355-
assert file.size == 9
356-
await file.seek(0)
357-
assert await file.read() == b"data more"
358-
assert file.size == 9
359-
await file.close()
360-
361-
362-
def test_formdata() -> None:
363-
stream = io.BytesIO(b"data")
364-
upload = UploadFile(filename="file", file=stream, size=4)
365-
form = FormData([("a", "123"), ("a", "456"), ("b", upload)])
366-
assert "a" in form
367-
assert "A" not in form
368-
assert "c" not in form
369-
assert form["a"] == "456"
370-
assert form.get("a") == "456"
371-
assert form.get("nope", default=None) is None
372-
assert form.getlist("a") == ["123", "456"]
373-
assert list(form.keys()) == ["a", "b"]
374-
assert list(form.values()) == ["456", upload]
375-
assert list(form.items()) == [("a", "456"), ("b", upload)]
376-
assert len(form) == 2
377-
assert list(form) == ["a", "b"]
378-
assert dict(form) == {"a": "456", "b": upload}
379-
assert repr(form) == "FormData([('a', '123'), ('a', '456'), ('b', " + repr(upload) + ")])"
380-
assert FormData(form) == form
381-
assert FormData({"a": "123", "b": "789"}) == FormData([("a", "123"), ("b", "789")])
382-
assert FormData({"a": "123", "b": "789"}) != {"a": "123", "b": "789"}
351+
async with anyio.SpooledTemporaryFile(max_size=max_size) as stream:
352+
file = UploadFile(filename="file", file=stream, size=0)
353+
try:
354+
assert await file.read() == b""
355+
assert file.size == 0
356+
await file.write(b"data")
357+
assert await file.read() == b""
358+
assert file.size == 4
359+
await file.seek(0)
360+
assert await file.read() == b"data"
361+
await file.write(b" more")
362+
assert await file.read() == b""
363+
assert file.size == 9
364+
await file.seek(0)
365+
assert await file.read() == b"data more"
366+
assert file.size == 9
367+
finally:
368+
await file.close()
369+
370+
371+
@pytest.mark.anyio
372+
async def test_formdata() -> None:
373+
async with anyio.SpooledTemporaryFile(max_size=1024) as stream:
374+
await stream.write(b"data")
375+
await stream.seek(0)
376+
377+
upload = UploadFile(filename="file", file=stream, size=4)
378+
379+
form = FormData([("a", "123"), ("a", "456"), ("b", upload)])
380+
381+
assert "a" in form
382+
assert "A" not in form
383+
assert "c" not in form
384+
assert form["a"] == "456"
385+
assert form.get("a") == "456"
386+
assert form.get("nope", default=None) is None
387+
assert form.getlist("a") == ["123", "456"]
388+
assert list(form.keys()) == ["a", "b"]
389+
assert list(form.values()) == ["456", upload]
390+
assert list(form.items()) == [("a", "456"), ("b", upload)]
391+
assert len(form) == 2
392+
assert list(form) == ["a", "b"]
393+
assert dict(form) == {"a": "456", "b": upload}
394+
assert repr(form) == "FormData([('a', '123'), ('a', '456'), ('b', " + repr(upload) + ")])"
395+
assert FormData(form) == form
396+
assert FormData({"a": "123", "b": "789"}) == FormData([("a", "123"), ("b", "789")])
397+
assert FormData({"a": "123", "b": "789"}) != {"a": "123", "b": "789"}
383398

384399

385400
@pytest.mark.anyio
386401
async def test_upload_file_repr() -> None:
387-
stream = io.BytesIO(b"data")
388-
file = UploadFile(filename="file", file=stream, size=4)
389-
assert repr(file) == "UploadFile(filename='file', size=4, headers=Headers({}))"
402+
"""Test the string representation of UploadFile"""
403+
async with anyio.SpooledTemporaryFile(max_size=1024 * 1024) as stream:
404+
await stream.write(b"data")
405+
await stream.seek(0)
406+
407+
file = UploadFile(filename="file", file=stream, size=4)
408+
try:
409+
assert repr(file) == "UploadFile(filename='file', size=4, headers=Headers({}))"
410+
finally:
411+
await file.close()
390412

391413

392414
@pytest.mark.anyio
393415
async def test_upload_file_repr_headers() -> None:
394-
stream = io.BytesIO(b"data")
395-
file = UploadFile(filename="file", file=stream, headers=Headers({"foo": "bar"}))
396-
assert repr(file) == "UploadFile(filename='file', size=None, headers=Headers({'foo': 'bar'}))"
416+
"""Test the string representation of UploadFile with custom headers"""
417+
async with anyio.SpooledTemporaryFile(max_size=1024 * 1024) as stream:
418+
await stream.write(b"data")
419+
await stream.seek(0)
420+
421+
file = UploadFile(filename="file", file=stream, headers=Headers({"foo": "bar"}))
422+
try:
423+
assert repr(file) == "UploadFile(filename='file', size=None, headers=Headers({'foo': 'bar'}))"
424+
finally:
425+
await file.close()
397426

398427

399428
def test_multidict() -> None:

0 commit comments

Comments
 (0)