@@ -1649,8 +1649,8 @@ def save_output_code(code: str) -> None:
 
 
 def run_and_get_kernels(
-    fn: Callable[..., Any], *args: Any, **kwargs: Any
-) -> tuple[Any, list[str]]:
+    fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
+) -> tuple[_T, list[str]]:
     result, source_codes = run_and_get_code(fn, *args, **kwargs)
     kernels = []
     for code in source_codes:
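These hunks all apply the same fix: the helpers previously typed the wrapped function as `Callable[..., Any]`, which discards both its parameter list and its return type, and they now use a `ParamSpec`/`TypeVar` pair so both flow through to the call site. The declarations of `P` and `_T` are not shown in the diff; a minimal sketch of what they presumably look like at module level (the `typing_extensions` import is an assumption — `typing.ParamSpec` also works on Python 3.10+):

```python
# Sketch of the module-level declarations the annotations above rely on;
# the names match the diff, the import source is an assumption.
from typing import TypeVar

from typing_extensions import ParamSpec

P = ParamSpec("P")  # captures the full parameter signature of the wrapped fn
_T = TypeVar("_T")  # captures the return type of the wrapped fn
```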
@@ -1667,7 +1667,7 @@ def run_with_backward() -> Any:
     return run_and_get_code(run_with_backward)
 
 
-def get_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> list[str]:
+def get_code(fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs) -> list[str]:
     """Get the inductor-generated code, but skip any actual compilation or running."""
     from .graph import GraphLowering
 
@@ -1711,7 +1711,7 @@ def call(self, *args: Any, **kwargs: Any) -> None:
     return source_codes
 
 
-def get_triton_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> str:
+def get_triton_code(fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs) -> str:
     source_codes = get_code(fn, *args, **kwargs)
     # Can have two outputs if backwards was eagerly compiled
     assert 1 <= len(source_codes) <= 2, (
@@ -1720,7 +1720,9 @@ def get_triton_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> str:
     return source_codes[0]
 
 
-def run_and_get_triton_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> str:
+def run_and_get_triton_code(
+    fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
+) -> str:
     _, source_codes = run_and_get_code(fn, *args, **kwargs)
     # Can have two outputs if backwards was eagerly compiled
     assert 1 <= len(source_codes) <= 2, (
@@ -1730,7 +1732,7 @@ def run_and_get_triton_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -
 
 
 def run_and_get_graph_lowering(
-    fn: Callable[..., Any], *args: Any, **kwargs: Any
+    fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
 ) -> tuple[Any, list[GraphLowering]]:
     from torch._inductor.graph import GraphLowering
     from torch._inductor.output_code import CompiledFxGraph
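The payoff of the new annotations is that type checkers stop collapsing results to `Any`. A hypothetical call site (the model function below is illustrative, and the `torch._inductor.utils` import path is inferred from the relative imports in these hunks):

```python
import torch
from torch._inductor.utils import run_and_get_kernels


@torch.compile
def my_model(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x)


# With Callable[P, _T], the checker infers `result: torch.Tensor` (previously
# Any), and a mistyped call like run_and_get_kernels(my_model, "oops") is now
# flagged as an argument-type error instead of passing silently.
result, kernels = run_and_get_kernels(my_model, torch.randn(4))
```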
@@ -2386,8 +2388,8 @@ def maybe_get_suppress_shape_guards_ctx() -> contextlib.AbstractContextManager[N
 
 
 def run_and_get_cpp_code(
-    fn: Callable[..., Any], *args: Any, **kwargs: Any
-) -> tuple[Any, str]:
+    fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
+) -> tuple[_T, str]:
     # We use the patch context manager instead of using it as a decorator.
     # In this way, we can ensure that the attribute is patched and unpatched correctly
     # even if this run_and_get_cpp_code function is called multiple times.
@@ -2431,9 +2433,9 @@ def shape_env_from_inputs(inputs: Sequence[InputType]) -> Optional[ShapeEnv]:
 
 
 def align_inputs_from_check_idxs(
-    model: Callable[[list[InputType]], Any],
+    model: Callable[[list[InputType]], _T],
     inputs_to_check: Sequence[int],
-) -> Callable[[list[InputType]], Any]:
+) -> Callable[[list[InputType]], _T]:
     if len(inputs_to_check) == 0:
         return model
 
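`align_inputs_from_check_idxs` is the one case where only the return type changes: its argument is always `list[InputType]`, so no `ParamSpec` is needed, and typing both the parameter and the return as `Callable[[list[InputType]], _T]` is what keeps the `return model` fast path well-typed. A standalone sketch of that shape (names and body are illustrative, not the PR's implementation):

```python
from collections.abc import Callable, Sequence
from typing import TypeVar

_T = TypeVar("_T")


def align_wrapper(
    model: Callable[[list[float]], _T],  # list[float] stands in for list[InputType]
    inputs_to_check: Sequence[int],
) -> Callable[[list[float]], _T]:
    if len(inputs_to_check) == 0:
        # Only well-typed because the wrapper advertises the exact same
        # signature, including the generic return type, as `model` itself.
        return model

    def run(new_inputs: list[float]) -> _T:
        # A real implementation would realign the flagged inputs here.
        return model(new_inputs)

    return run
```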