
Commit 8494d55

tommyadams5 authored and pytorchmergebot committed
Propagate callable parameter types using ParamSpec (pytorch#142306) (pytorch#151014)
Partially addresses pytorch#142306
Pull Request resolved: pytorch#151014
Approved by: https://github.com/Skylion007
1 parent 3f0931b commit 8494d55

File tree: 12 files changed, +71 -44 lines

torch/_dynamo/external_utils.py (+2 -2)

@@ -190,11 +190,11 @@ def call_module_hooks_from_backward_state(
 
 
 # used for torch._dynamo.disable(recursive=False)
-def get_nonrecursive_disable_wrapper(fn: Callable[..., Any]) -> Callable[..., Any]:
+def get_nonrecursive_disable_wrapper(fn: Callable[_P, _R]) -> Callable[_P, _R]:
     # wrap function to get the right error message
     # this function is in external_utils so that convert_frame doesn't skip it.
     @functools.wraps(fn)
-    def nonrecursive_disable_wrapper(*args, **kwargs):  # type: ignore[no-untyped-def]
+    def nonrecursive_disable_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
         return fn(*args, **kwargs)
 
     return nonrecursive_disable_wrapper
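Every hunk in this commit applies the same recipe: Callable[..., Any] erases the wrapped function's signature, while a ParamSpec threads it through the wrapper intact, so call sites are checked against the original parameters. A minimal, self-contained sketch of the recipe (the passthrough/add names are illustrative, not from the commit):

import functools
from typing import Callable, TypeVar

from typing_extensions import ParamSpec

_P = ParamSpec("_P")
_R = TypeVar("_R")


def passthrough(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    # The wrapper accepts exactly the arguments fn does, and the checker knows it.
    @functools.wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        return fn(*args, **kwargs)

    return wrapper


@passthrough
def add(x: int, y: int) -> int:
    return x + y


add(1, 2)  # OK: checked against (int, int) -> int
# add("a", 2)  # rejected by the checker; Callable[..., Any] let this through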

torch/_dynamo/testing.py (+8 -8)

@@ -465,31 +465,31 @@ def make_test_cls_with_patches(
 
 
 # test Python 3.11+ specific features
-def skipIfNotPy311(fn: Callable[..., Any]) -> Callable[..., Any]:
+def skipIfNotPy311(fn: Callable[_P, _T]) -> Callable[_P, _T]:
     if sys.version_info >= (3, 11):
         return fn
     return unittest.skip(fn)
 
 
-def skipIfNotPy312(fn: Callable[..., Any]) -> Callable[..., Any]:
+def skipIfNotPy312(fn: Callable[_P, _T]) -> Callable[_P, _T]:
     if sys.version_info >= (3, 12):
         return fn
     return unittest.skip("Requires Python 3.12+")(fn)
 
 
-def xfailIfPy312(fn: Callable[..., Any]) -> Callable[..., Any]:
+def xfailIfPy312(fn: Callable[_P, _T]) -> Callable[_P, _T]:
     if sys.version_info >= (3, 12):
         return unittest.expectedFailure(fn)
     return fn
 
 
-def skipIfPy312(fn: Callable[..., Any]) -> Callable[..., Any]:
+def skipIfPy312(fn: Callable[_P, _T]) -> Callable[_P, _T]:
     if sys.version_info >= (3, 12):
         return unittest.skip("Not supported in Python 3.12+")(fn)
     return fn
 
 
-def requiresPy310(fn: Callable[..., Any]) -> Callable[..., Any]:
+def requiresPy310(fn: Callable[_P, _T]) -> Callable[_P, _T]:
     if sys.version_info >= (3, 10):
         return fn
     else:
@@ -498,19 +498,19 @@ def requiresPy310(fn: Callable[..., Any]) -> Callable[..., Any]:
 
 # Controls tests generated in test/inductor/test_torchinductor_dynamic_shapes.py
 # and test/dynamo/test_dynamic_shapes.py
-def expectedFailureDynamic(fn: Callable[..., Any]) -> Callable[..., Any]:
+def expectedFailureDynamic(fn: Callable[_P, _T]) -> Callable[_P, _T]:
     fn._expected_failure_dynamic = True  # type: ignore[attr-defined]
     return fn
 
 
 # Controls tests generated in test/inductor/test_torchinductor_codegen_dynamic_shapes.py
-def expectedFailureCodegenDynamic(fn: Callable[..., Any]) -> Callable[..., Any]:
+def expectedFailureCodegenDynamic(fn: Callable[_P, _T]) -> Callable[_P, _T]:
     fn._expected_failure_codegen_dynamic = True  # type: ignore[attr-defined]
     return fn
 
 
 # Controls test generated in test/inductor/test_cpp_wrapper.py
-def expectedFailureDynamicWrapper(fn: Callable[..., Any]) -> Callable[..., Any]:
+def expectedFailureDynamicWrapper(fn: Callable[_P, _T]) -> Callable[_P, _T]:
     fn._expected_failure_dynamic_wrapper = True  # type: ignore[attr-defined]
     return fn

torch/_dynamo/variables/torch.py (+1 -1)

@@ -203,7 +203,7 @@ def get_overridable_functions():
     from torch.overrides import get_overridable_functions as get_overridable_functions_
 
     funcs = set(chain.from_iterable(get_overridable_functions_().values()))
-    more = {
+    more: set[Callable[..., Any]] = {
         torch.ones,
         torch.ones_like,
         torch.zeros,

torch/_inductor/utils.py (+12 -10)

@@ -1649,8 +1649,8 @@ def save_output_code(code: str) -> None:
 
 
 def run_and_get_kernels(
-    fn: Callable[..., Any], *args: Any, **kwargs: Any
-) -> tuple[Any, list[str]]:
+    fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
+) -> tuple[_T, list[str]]:
     result, source_codes = run_and_get_code(fn, *args, **kwargs)
     kernels = []
     for code in source_codes:
@@ -1667,7 +1667,7 @@ def run_with_backward() -> Any:
     return run_and_get_code(run_with_backward)
 
 
-def get_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> list[str]:
+def get_code(fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs) -> list[str]:
     """Get the inductor-generated code, but skip any actual compilation or running."""
     from .graph import GraphLowering
 
@@ -1711,7 +1711,7 @@ def call(self, *args: Any, **kwargs: Any) -> None:
     return source_codes
 
 
-def get_triton_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> str:
+def get_triton_code(fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs) -> str:
     source_codes = get_code(fn, *args, **kwargs)
     # Can have two outputs if backwards was eagerly compiled
     assert 1 <= len(source_codes) <= 2, (
@@ -1720,7 +1720,9 @@ def get_triton_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> str:
     return source_codes[0]
 
 
-def run_and_get_triton_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> str:
+def run_and_get_triton_code(
+    fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
+) -> str:
     _, source_codes = run_and_get_code(fn, *args, **kwargs)
     # Can have two outputs if backwards was eagerly compiled
     assert 1 <= len(source_codes) <= 2, (
@@ -1730,7 +1732,7 @@ def run_and_get_triton_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> str:
 
 
 def run_and_get_graph_lowering(
-    fn: Callable[..., Any], *args: Any, **kwargs: Any
+    fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
 ) -> tuple[Any, list[GraphLowering]]:
     from torch._inductor.graph import GraphLowering
     from torch._inductor.output_code import CompiledFxGraph
@@ -2386,8 +2388,8 @@ def maybe_get_suppress_shape_guards_ctx() -> contextlib.AbstractContextManager[None]:
 
 
 def run_and_get_cpp_code(
-    fn: Callable[..., Any], *args: Any, **kwargs: Any
-) -> tuple[Any, str]:
+    fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
+) -> tuple[_T, str]:
     # We use the patch context manager instead of using it as a decorator.
     # In this way, we can ensure that the attribute is patched and unpatched correctly
     # even if this run_and_get_cpp_code function is called multiple times.
@@ -2431,9 +2433,9 @@ def shape_env_from_inputs(inputs: Sequence[InputType]) -> Optional[ShapeEnv]:
 
 
 def align_inputs_from_check_idxs(
-    model: Callable[[list[InputType]], Any],
+    model: Callable[[list[InputType]], _T],
     inputs_to_check: Sequence[int],
-) -> Callable[[list[InputType]], Any]:
+) -> Callable[[list[InputType]], _T]:
     if len(inputs_to_check) == 0:
         return model
 
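ParamSpec also types non-decorator helpers like the run_and_get_* family above: when a function takes fn plus the arguments to call it with, annotating it as (fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs) makes the checker validate the forwarded arguments against fn's own signature, and _T carries the real result type into returns like tuple[_T, list[str]]. A reduced sketch of the same shape (call_and_time and scale are illustrative, not inductor APIs):

import time
from typing import Callable, TypeVar

from typing_extensions import ParamSpec

P = ParamSpec("P")
_T = TypeVar("_T")


def call_and_time(
    fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
) -> tuple[_T, float]:
    # Forwarded arguments are checked against fn's real signature.
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


def scale(x: float, factor: float = 2.0) -> float:
    return x * factor


out, elapsed = call_and_time(scale, 3.0)  # out is inferred as float, not Any
# call_and_time(scale, "3")  # rejected: str is not float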

torch/_prims_common/wrappers.py (+6 -3)

@@ -5,7 +5,7 @@
 from collections.abc import Sequence
 from functools import wraps
 from types import GenericAlias
-from typing import Callable, NamedTuple, Optional, overload, TypeVar
+from typing import Callable, NamedTuple, Optional, overload, TypeVar, Union
 from typing_extensions import ParamSpec
 
 import torch
@@ -285,7 +285,8 @@ def _out_wrapper(fn: Callable[_P, _T]) -> Callable[_P, _T]:
     is_factory_fn = all(p in sig.parameters for p in factory_kwargs)
 
     @wraps(fn)
-    def _fn(*args: _P.args, out=None, **kwargs: _P.kwargs):
+    def _fn(*args: _P.args, **kwargs: _P.kwargs):
+        out = kwargs.pop("out", None)
     if is_factory_fn and out is not None:
         for k in factory_kwargs:
             out_attr = getattr(out, k)
@@ -450,7 +451,9 @@ def _autograd_impl(*args, **kwargs):
 # TODO: when tracing this will add torch tensors and not TensorMeta objects
 # to the trace -- we should fix this by adding a tracing context and NumberMeta classes
 # TODO: this wrapper is currently untested
-def elementwise_unary_scalar_wrapper(fn: Callable) -> Callable:
+def elementwise_unary_scalar_wrapper(
+    fn: Callable[_P, _T],
+) -> Callable[_P, Union[_T, NumberType]]:
     """
     Allows unary operators that accept tensors to work with Python numbers.
     """

torch/distributed/elastic/control_plane.py (+1 -1)

@@ -23,8 +23,8 @@ def _worker_server(socket_path: str) -> Generator[None, None, None]:
         server.shutdown()
 
 
-@contextmanager
 @record
+@contextmanager
 def worker_main() -> Generator[None, None, None]:
     """
     This is a context manager that wraps your main entry function. This combines
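Decorators apply bottom-up, so this reorder changes what each decorator receives: @contextmanager now gets the raw generator function, and @record wraps the finished context-manager factory, which is the order that composes with record's new ParamSpec signature (see torch/distributed/elastic/multiprocessing/errors below). The decorated shape, with the application order spelled out as a sketch (body elided):

from collections.abc import Generator
from contextlib import contextmanager

from torch.distributed.elastic.multiprocessing.errors import record


@record            # applied second: wraps the context-manager factory
@contextmanager    # applied first: turns the generator function into a factory
def worker_main() -> Generator[None, None, None]:
    yield

# Equivalent composition after this commit: record(contextmanager(<generator fn>));
# before, it was contextmanager(record(<generator fn>)).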

torch/distributed/elastic/multiprocessing/errors/__init__.py (+8 -6)

@@ -58,7 +58,8 @@
 from datetime import datetime
 from functools import wraps
 from string import Template
-from typing import Any, Callable, Optional, TypeVar
+from typing import Any, Callable, Optional, TypeVar, Union
+from typing_extensions import ParamSpec
 
 from torch.distributed.elastic.utils.logging import get_logger
 
@@ -82,7 +83,8 @@
 _EMPTY_ERROR_DATA = {"message": "<NONE>"}
 _NOT_AVAILABLE = "<N/A>"
 
-T = TypeVar("T")
+_R = TypeVar("_R")
+_P = ParamSpec("_P")
 
 
 @dataclass
@@ -305,8 +307,8 @@ def _format_failure(
 
 
 def record(
-    fn: Callable[..., T], error_handler: Optional[ErrorHandler] = None
-) -> Callable[..., T]:
+    fn: Callable[_P, _R], error_handler: Optional[ErrorHandler] = None
+) -> Callable[_P, Union[_R, None]]:
     """
     Syntactic sugar to record errors/exceptions that happened in the decorated
     function using the provided ``error_handler``.
@@ -346,9 +348,9 @@ def main():
     if not error_handler:
         error_handler = get_error_handler()
 
-    def wrap(f):
+    def wrap(f: Callable[_P, _R]) -> Callable[_P, Union[_R, None]]:
         @wraps(f)
-        def wrapper(*args, **kwargs):
+        def wrapper(*args: _P.args, **kwargs: _P.kwargs):
             assert error_handler is not None  # assertion for mypy type checker
             error_handler.initialize()
             try:
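With _P and _R in place, @record now preserves the decorated entry point's parameter list for the type checker, while the return type widens to Union[_R, None]. A usage sketch of the caller-side effect (this main is hypothetical, in the spirit of the module's docstring example):

from torch.distributed.elastic.multiprocessing.errors import record


@record
def main(run_id: str, max_restarts: int = 3) -> int:
    # ... launch work ...
    return 0


main("job-1")  # OK: arguments checked against (str, int) -> int
# main(max_restarts=5)  # rejected: "run_id" is missing; Callable[..., T] hid this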

torch/distributed/elastic/timer/file_based_local_timer.py (+7 -3)

@@ -13,7 +13,8 @@
 import sys
 import threading
 import time
-from typing import Callable, Optional
+from typing import Callable, Optional, TypeVar
+from typing_extensions import ParamSpec
 
 from torch.distributed.elastic.timer.api import TimerClient, TimerRequest
 from torch.distributed.elastic.timer.debug_info_logging import (
@@ -22,6 +23,9 @@
 from torch.distributed.elastic.utils.logging import get_logger
 
 
+_P = ParamSpec("_P")
+_R = TypeVar("_R")
+
 __all__ = ["FileTimerClient", "FileTimerRequest", "FileTimerServer"]
 
 logger = get_logger(__name__)
@@ -36,8 +40,8 @@ def _retry(max_retries: int, sleep_time: float) -> Callable:
         sleep_time: float, the time to sleep between retries.
     """
 
-    def wrapper(func: Callable) -> Callable:
-        def wrapper(*args, **kwargs):
+    def wrapper(func: Callable[_P, _R]) -> Callable[_P, _R]:
+        def wrapper(*args: _P.args, **kwargs: _P.kwargs):
             for i in range(max_retries):
                 try:
                     return func(*args, **kwargs)

torch/fx/node.py (+4 -1)

@@ -6,6 +6,7 @@
 import types
 from collections.abc import Mapping, Sequence
 from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
+from typing_extensions import ParamSpec
 
 import torch
 from torch._C import _fx_map_aggregate, _fx_map_arg, _NodeBase
@@ -58,6 +59,8 @@
     ]
 ]
 ArgumentT = TypeVar("ArgumentT", bound=Argument)
+_P = ParamSpec("_P")
+_R = TypeVar("_R")
 
 _legal_ops = dict.fromkeys(
     [
@@ -102,7 +105,7 @@
 
 
 @compatibility(is_backward_compatible=False)
-def has_side_effect(fn: Callable) -> Callable:
+def has_side_effect(fn: Callable[_P, _R]) -> Callable[_P, _R]:
     _side_effectful_functions.add(fn)
     return fn
 
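has_side_effect returns its argument unchanged, so the identity typing Callable[_P, _R] -> Callable[_P, _R] lets registered functions keep their exact signatures instead of decaying to a bare Callable. The same identity-decorator pattern in isolation (register and _registry are illustrative names, not torch.fx APIs):

from typing import Callable, TypeVar

from typing_extensions import ParamSpec, reveal_type

_P = ParamSpec("_P")
_R = TypeVar("_R")

_registry: set[Callable[..., object]] = set()


def register(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    # Identity decorator: record the function, return it untouched.
    _registry.add(fn)
    return fn


@register
def repeat(s: str, n: int) -> str:
    return s * n


reveal_type(repeat)  # mypy: "def (s: str, n: int) -> str", not just Callable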

torch/onnx/_internal/exporter/_isolated.py (+10 -4)

@@ -1,18 +1,24 @@
 """Isolated calls to methods that may segfault."""
 
-# mypy: allow-untyped-defs
 from __future__ import annotations
 
 import multiprocessing
 import os
 import warnings
-from typing import Callable
+from typing import Any, Callable, TypeVar, TypeVarTuple, Union, Unpack
+from typing_extensions import ParamSpec
 
 
+_P = ParamSpec("_P")
+_R = TypeVar("_R")
+_Ts = TypeVarTuple("_Ts")
+
 _IS_WINDOWS = os.name == "nt"
 
 
-def _call_function_and_return_exception(func, args, kwargs):
+def _call_function_and_return_exception(
+    func: Callable[[Unpack[_Ts]], _R], args: tuple[Unpack[_Ts]], kwargs: dict[str, Any]
+) -> Union[_R, Exception]:
     """Call function and return a exception if there is one."""
 
     try:
@@ -21,7 +27,7 @@ def _call_function_and_return_exception(func, args, kwargs):
         return e
 
 
-def safe_call(func: Callable, *args, **kwargs):
+def safe_call(func: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs) -> _R:
     """Call a function in a separate process.
 
     Args:
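This file marks the limits of ParamSpec. safe_call forwards *args/**kwargs literally, so Callable[_P, _R] fits; _call_function_and_return_exception instead receives the arguments already packed into a tuple and a dict (they travel to the worker process), where _P.args/_P.kwargs cannot appear, so the positional part is typed with a TypeVarTuple. A condensed sketch of the two shapes (simplified: the packed form here drops the kwargs dict the real helper also carries; TypeVarTuple is in typing only on Python 3.11+, hence typing_extensions below):

from typing import Callable, TypeVar, Union

from typing_extensions import ParamSpec, TypeVarTuple, Unpack

_P = ParamSpec("_P")
_R = TypeVar("_R")
_Ts = TypeVarTuple("_Ts")


# Packed form: the arguments travel as a tuple, so a TypeVarTuple links the
# tuple's element types to the callee's positional parameters.
def call_packed(
    func: Callable[[Unpack[_Ts]], _R], args: tuple[Unpack[_Ts]]
) -> Union[_R, Exception]:
    try:
        return func(*args)
    except Exception as e:
        return e


# Forwarding form: *args/**kwargs flow straight through, so ParamSpec applies.
def call_forwarding(func: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs) -> _R:
    return func(*args, **kwargs)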

torch/onnx/_internal/registration.py (+4 -1)

@@ -4,6 +4,7 @@
 import warnings
 from collections.abc import Collection, Sequence
 from typing import Callable, Generic, Optional, TypeVar, Union
+from typing_extensions import ParamSpec
 
 from torch.onnx import _constants, errors
 
@@ -51,6 +52,8 @@ def _dispatch_opset_version(
 
 _K = TypeVar("_K")
 _V = TypeVar("_V")
+_R = TypeVar("_R")
+_P = ParamSpec("_P")
 
 
 class OverrideDict(Collection[_K], Generic[_K, _V]):
@@ -287,7 +290,7 @@ def symbolic_b(g: _C.Graph, x: _C.Value, y: _C.Value, arg1: bool) -> _C.Value: ...
         ValueError: If the separator '::' is not in the name.
     """
 
-    def wrapper(func: Callable) -> Callable:
+    def wrapper(func: Callable[_P, _R]) -> Callable[_P, _R]:
         decorated = func
         if decorate is not None:
             for decorate_func in decorate:

torch/overrides.py (+8 -4)

@@ -29,7 +29,8 @@
 import warnings
 from collections.abc import Iterable
 from functools import wraps
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, TypeVar
+from typing_extensions import ParamSpec
 
 import torch
 from torch._C import (
@@ -58,12 +59,15 @@
     "enable_reentrant_dispatch",
 ]
 
+_P = ParamSpec("_P")
+_R = TypeVar("_R")
+
 
 def _disable_user_warnings(
-    func: Callable,
+    func: Callable[_P, _R],
     regex: str = ".*is deprecated, please use.*",
     module: str = "torch",
-) -> Callable:
+) -> Callable[_P, _R]:
     """
     Decorator that temporarily disables ``UserWarning``s for the given ``module`` if the warning message matches the
     given ``regex`` pattern.
@@ -84,7 +88,7 @@ def _disable_user_warnings(
     """
 
     @wraps(func)
-    def wrapper(*args, **kwargs):
+    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
         with warnings.catch_warnings():
             warnings.filterwarnings(
                 "ignore", category=UserWarning, message=regex, module=module
