Skip to content

Commit 2ff7653

Browse files
authored
chore: improve type hints (#2867)
1 parent bc4b937 commit 2ff7653

35 files changed

+459
-427
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ source_pkgs = ["starlette", "tests"]
9191
[tool.coverage.report]
9292
exclude_lines = [
9393
"pragma: no cover",
94-
"if typing.TYPE_CHECKING:",
95-
"@typing.overload",
94+
"if TYPE_CHECKING:",
95+
"@overload",
9696
"raise NotImplementedError",
9797
]

starlette/_exception_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
import typing
3+
from typing import Any
44

55
from starlette._utils import is_async_callable
66
from starlette.concurrency import run_in_threadpool
@@ -9,7 +9,7 @@
99
from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
1010
from starlette.websockets import WebSocket
1111

12-
ExceptionHandlers = dict[typing.Any, ExceptionHandler]
12+
ExceptionHandlers = dict[Any, ExceptionHandler]
1313
StatusHandlers = dict[int, ExceptionHandler]
1414

1515

starlette/_utils.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import functools
44
import inspect
55
import sys
6-
import typing
7-
from contextlib import contextmanager
6+
from collections.abc import Awaitable, Generator
7+
from contextlib import AbstractAsyncContextManager, contextmanager
8+
from typing import Any, Callable, Generic, Protocol, TypeVar, overload
89

910
from starlette.types import Scope
1011

@@ -20,58 +21,58 @@
2021
except ImportError:
2122
has_exceptiongroups = False
2223

23-
T = typing.TypeVar("T")
24-
AwaitableCallable = typing.Callable[..., typing.Awaitable[T]]
24+
T = TypeVar("T")
25+
AwaitableCallable = Callable[..., Awaitable[T]]
2526

2627

27-
@typing.overload
28+
@overload
2829
def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]: ...
2930

3031

31-
@typing.overload
32-
def is_async_callable(obj: typing.Any) -> TypeGuard[AwaitableCallable[typing.Any]]: ...
32+
@overload
33+
def is_async_callable(obj: Any) -> TypeGuard[AwaitableCallable[Any]]: ...
3334

3435

35-
def is_async_callable(obj: typing.Any) -> typing.Any:
36+
def is_async_callable(obj: Any) -> Any:
3637
while isinstance(obj, functools.partial):
3738
obj = obj.func
3839

3940
return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__))
4041

4142

42-
T_co = typing.TypeVar("T_co", covariant=True)
43+
T_co = TypeVar("T_co", covariant=True)
4344

4445

45-
class AwaitableOrContextManager(typing.Awaitable[T_co], typing.AsyncContextManager[T_co], typing.Protocol[T_co]): ...
46+
class AwaitableOrContextManager(Awaitable[T_co], AbstractAsyncContextManager[T_co], Protocol[T_co]): ...
4647

4748

48-
class SupportsAsyncClose(typing.Protocol):
49+
class SupportsAsyncClose(Protocol):
4950
async def close(self) -> None: ... # pragma: no cover
5051

5152

52-
SupportsAsyncCloseType = typing.TypeVar("SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False)
53+
SupportsAsyncCloseType = TypeVar("SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False)
5354

5455

55-
class AwaitableOrContextManagerWrapper(typing.Generic[SupportsAsyncCloseType]):
56+
class AwaitableOrContextManagerWrapper(Generic[SupportsAsyncCloseType]):
5657
__slots__ = ("aw", "entered")
5758

58-
def __init__(self, aw: typing.Awaitable[SupportsAsyncCloseType]) -> None:
59+
def __init__(self, aw: Awaitable[SupportsAsyncCloseType]) -> None:
5960
self.aw = aw
6061

61-
def __await__(self) -> typing.Generator[typing.Any, None, SupportsAsyncCloseType]:
62+
def __await__(self) -> Generator[Any, None, SupportsAsyncCloseType]:
6263
return self.aw.__await__()
6364

6465
async def __aenter__(self) -> SupportsAsyncCloseType:
6566
self.entered = await self.aw
6667
return self.entered
6768

68-
async def __aexit__(self, *args: typing.Any) -> None | bool:
69+
async def __aexit__(self, *args: Any) -> None | bool:
6970
await self.entered.close()
7071
return None
7172

7273

7374
@contextmanager
74-
def collapse_excgroups() -> typing.Generator[None, None, None]:
75+
def collapse_excgroups() -> Generator[None, None, None]:
7576
try:
7677
yield
7778
except BaseException as exc:

starlette/applications.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22

33
import sys
4-
import typing
54
import warnings
5+
from collections.abc import Awaitable, Mapping, Sequence
6+
from typing import Any, Callable, TypeVar
67

78
if sys.version_info >= (3, 10): # pragma: no cover
89
from typing import ParamSpec
@@ -20,7 +21,7 @@
2021
from starlette.types import ASGIApp, ExceptionHandler, Lifespan, Receive, Scope, Send
2122
from starlette.websockets import WebSocket
2223

23-
AppType = typing.TypeVar("AppType", bound="Starlette")
24+
AppType = TypeVar("AppType", bound="Starlette")
2425
P = ParamSpec("P")
2526

2627

@@ -30,11 +31,11 @@ class Starlette:
3031
def __init__(
3132
self: AppType,
3233
debug: bool = False,
33-
routes: typing.Sequence[BaseRoute] | None = None,
34-
middleware: typing.Sequence[Middleware] | None = None,
35-
exception_handlers: typing.Mapping[typing.Any, ExceptionHandler] | None = None,
36-
on_startup: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
37-
on_shutdown: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
34+
routes: Sequence[BaseRoute] | None = None,
35+
middleware: Sequence[Middleware] | None = None,
36+
exception_handlers: Mapping[Any, ExceptionHandler] | None = None,
37+
on_startup: Sequence[Callable[[], Any]] | None = None,
38+
on_shutdown: Sequence[Callable[[], Any]] | None = None,
3839
lifespan: Lifespan[AppType] | None = None,
3940
) -> None:
4041
"""Initializes the application.
@@ -79,7 +80,7 @@ def __init__(
7980
def build_middleware_stack(self) -> ASGIApp:
8081
debug = self.debug
8182
error_handler = None
82-
exception_handlers: dict[typing.Any, typing.Callable[[Request, Exception], Response]] = {}
83+
exception_handlers: dict[Any, Callable[[Request, Exception], Response]] = {}
8384

8485
for key, value in self.exception_handlers.items():
8586
if key in (500, Exception):
@@ -102,7 +103,7 @@ def build_middleware_stack(self) -> ASGIApp:
102103
def routes(self) -> list[BaseRoute]:
103104
return self.router.routes
104105

105-
def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
106+
def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
106107
return self.router.url_path_for(name, **path_params)
107108

108109
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
@@ -111,7 +112,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
111112
self.middleware_stack = self.build_middleware_stack()
112113
await self.middleware_stack(scope, receive, send)
113114

114-
def on_event(self, event_type: str) -> typing.Callable: # type: ignore[type-arg]
115+
def on_event(self, event_type: str) -> Callable: # type: ignore[type-arg]
115116
return self.router.on_event(event_type) # pragma: no cover
116117

117118
def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None:
@@ -140,14 +141,14 @@ def add_exception_handler(
140141
def add_event_handler(
141142
self,
142143
event_type: str,
143-
func: typing.Callable, # type: ignore[type-arg]
144+
func: Callable, # type: ignore[type-arg]
144145
) -> None: # pragma: no cover
145146
self.router.add_event_handler(event_type, func)
146147

147148
def add_route(
148149
self,
149150
path: str,
150-
route: typing.Callable[[Request], typing.Awaitable[Response] | Response],
151+
route: Callable[[Request], Awaitable[Response] | Response],
151152
methods: list[str] | None = None,
152153
name: str | None = None,
153154
include_in_schema: bool = True,
@@ -157,19 +158,19 @@ def add_route(
157158
def add_websocket_route(
158159
self,
159160
path: str,
160-
route: typing.Callable[[WebSocket], typing.Awaitable[None]],
161+
route: Callable[[WebSocket], Awaitable[None]],
161162
name: str | None = None,
162163
) -> None: # pragma: no cover
163164
self.router.add_websocket_route(path, route, name=name)
164165

165-
def exception_handler(self, exc_class_or_status_code: int | type[Exception]) -> typing.Callable: # type: ignore[type-arg]
166+
def exception_handler(self, exc_class_or_status_code: int | type[Exception]) -> Callable: # type: ignore[type-arg]
166167
warnings.warn(
167168
"The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. "
168169
"Refer to https://www.starlette.io/exceptions/ for the recommended approach.",
169170
DeprecationWarning,
170171
)
171172

172-
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
173+
def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
173174
self.add_exception_handler(exc_class_or_status_code, func)
174175
return func
175176

@@ -181,7 +182,7 @@ def route(
181182
methods: list[str] | None = None,
182183
name: str | None = None,
183184
include_in_schema: bool = True,
184-
) -> typing.Callable: # type: ignore[type-arg]
185+
) -> Callable: # type: ignore[type-arg]
185186
"""
186187
We no longer document this decorator style API, and its usage is discouraged.
187188
Instead you should use the following approach:
@@ -195,7 +196,7 @@ def route(
195196
DeprecationWarning,
196197
)
197198

198-
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
199+
def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
199200
self.router.add_route(
200201
path,
201202
func,
@@ -207,7 +208,7 @@ def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-ar
207208

208209
return decorator
209210

210-
def websocket_route(self, path: str, name: str | None = None) -> typing.Callable: # type: ignore[type-arg]
211+
def websocket_route(self, path: str, name: str | None = None) -> Callable: # type: ignore[type-arg]
211212
"""
212213
We no longer document this decorator style API, and its usage is discouraged.
213214
Instead you should use the following approach:
@@ -221,13 +222,13 @@ def websocket_route(self, path: str, name: str | None = None) -> typing.Callable
221222
DeprecationWarning,
222223
)
223224

224-
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
225+
def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
225226
self.router.add_websocket_route(path, func, name=name)
226227
return func
227228

228229
return decorator
229230

230-
def middleware(self, middleware_type: str) -> typing.Callable: # type: ignore[type-arg]
231+
def middleware(self, middleware_type: str) -> Callable: # type: ignore[type-arg]
231232
"""
232233
We no longer document this decorator style API, and its usage is discouraged.
233234
Instead you should use the following approach:
@@ -242,7 +243,7 @@ def middleware(self, middleware_type: str) -> typing.Callable: # type: ignore[t
242243
)
243244
assert middleware_type == "http", 'Currently only middleware("http") is supported.'
244245

245-
def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg]
246+
def decorator(func: Callable) -> Callable: # type: ignore[type-arg]
246247
self.add_middleware(BaseHTTPMiddleware, dispatch=func)
247248
return func
248249

starlette/authentication.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import functools
44
import inspect
55
import sys
6-
import typing
6+
from collections.abc import Sequence
7+
from typing import Any, Callable
78
from urllib.parse import urlencode
89

910
if sys.version_info >= (3, 10): # pragma: no cover
@@ -20,23 +21,23 @@
2021
_P = ParamSpec("_P")
2122

2223

23-
def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
24+
def has_required_scope(conn: HTTPConnection, scopes: Sequence[str]) -> bool:
2425
for scope in scopes:
2526
if scope not in conn.auth.scopes:
2627
return False
2728
return True
2829

2930

3031
def requires(
31-
scopes: str | typing.Sequence[str],
32+
scopes: str | Sequence[str],
3233
status_code: int = 403,
3334
redirect: str | None = None,
34-
) -> typing.Callable[[typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]]:
35+
) -> Callable[[Callable[_P, Any]], Callable[_P, Any]]:
3536
scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
3637

3738
def decorator(
38-
func: typing.Callable[_P, typing.Any],
39-
) -> typing.Callable[_P, typing.Any]:
39+
func: Callable[_P, Any],
40+
) -> Callable[_P, Any]:
4041
sig = inspect.signature(func)
4142
for idx, parameter in enumerate(sig.parameters.values()):
4243
if parameter.name == "request" or parameter.name == "websocket":
@@ -62,7 +63,7 @@ async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
6263
elif is_async_callable(func):
6364
# Handle async request/response functions.
6465
@functools.wraps(func)
65-
async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
66+
async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
6667
request = kwargs.get("request", args[idx] if idx < len(args) else None)
6768
assert isinstance(request, Request)
6869

@@ -79,7 +80,7 @@ async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
7980
else:
8081
# Handle sync request/response functions.
8182
@functools.wraps(func)
82-
def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any:
83+
def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
8384
request = kwargs.get("request", args[idx] if idx < len(args) else None)
8485
assert isinstance(request, Request)
8586

@@ -106,7 +107,7 @@ async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, Bas
106107

107108

108109
class AuthCredentials:
109-
def __init__(self, scopes: typing.Sequence[str] | None = None):
110+
def __init__(self, scopes: Sequence[str] | None = None):
110111
self.scopes = [] if scopes is None else list(scopes)
111112

112113

starlette/background.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

33
import sys
4-
import typing
4+
from collections.abc import Sequence
5+
from typing import Any, Callable
56

67
if sys.version_info >= (3, 10): # pragma: no cover
78
from typing import ParamSpec
@@ -15,7 +16,7 @@
1516

1617

1718
class BackgroundTask:
18-
def __init__(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs) -> None:
19+
def __init__(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None:
1920
self.func = func
2021
self.args = args
2122
self.kwargs = kwargs
@@ -29,10 +30,10 @@ async def __call__(self) -> None:
2930

3031

3132
class BackgroundTasks(BackgroundTask):
32-
def __init__(self, tasks: typing.Sequence[BackgroundTask] | None = None):
33+
def __init__(self, tasks: Sequence[BackgroundTask] | None = None):
3334
self.tasks = list(tasks) if tasks else []
3435

35-
def add_task(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs) -> None:
36+
def add_task(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None:
3637
task = BackgroundTask(func, *args, **kwargs)
3738
self.tasks.append(task)
3839

0 commit comments

Comments
 (0)