
Commit f8218d1

Copy executorch codegen from pytorch torchgen to executorch repo
Differential Revision: D74865579
Pull Request resolved: #10939
1 parent: d9fcea1

File tree

12 files changed: +2353 −4 lines


codegen/api/__init__.py

Whitespace-only changes.

codegen/api/custom_ops.py

Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING

from torchgen import dest


# disable import sorting to avoid circular dependency.
from torchgen.api.types import DispatcherSignature  # usort: skip
from torchgen.context import method_with_native_function
from torchgen.model import BaseTy, BaseType, DispatchKey, NativeFunction, Variant
from torchgen.utils import concatMap, Target


if TYPE_CHECKING:
    from collections.abc import Sequence

    from executorch.codegen.model import ETKernelIndex
    from torchgen.selective_build.selector import SelectiveBuilder


# Generates RegisterKernelStub.cpp, which provides placeholder kernels for custom operators. This will be used at
# model authoring side.
@dataclass(frozen=True)
class ComputeNativeFunctionStub:
    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        if Variant.function not in f.variants:
            return None

        sig = DispatcherSignature.from_schema(
            f.func, prefix=f"wrapper_CPU_{f.func.name.overload_name}_", symint=False
        )
        assert sig is not None
        if len(f.func.returns) == 0:
            ret_name = ""
        elif len(f.func.returns) == 1:
            if f.func.arguments.out:
                ret_name = f.func.arguments.out[0].name
            else:
                ret_name = next(
                    (
                        a.name
                        for a in f.func.arguments.flat_non_out
                        if a.type == f.func.returns[0].type
                    ),
                    "",
                )
            if not ret_name:
                # if return type is tensor
                if f.func.returns[0].type == BaseType(BaseTy.Tensor):
                    # Returns an empty tensor
                    ret_name = "at::Tensor()"
                else:
                    raise Exception(  # noqa: TRY002
                        f"Can't handle this return type {f.func}"
                    )  # noqa: TRY002
        elif len(f.func.arguments.out) == len(f.func.returns):
            # Returns a tuple of out arguments
            tensor_type = "at::Tensor &"
            comma = ", "
            ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>(
                {comma.join([r.name for r in f.func.arguments.out])}
            )"""
        else:
            assert all(
                a.type == BaseType(BaseTy.Tensor) for a in f.func.returns
            ), f"Only support tensor returns but got {f.func.returns}"
            # Returns a tuple of empty tensors
            tensor_type = "at::Tensor"
            comma = ", "
            ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>(
                {comma.join(["at::Tensor()" for _ in f.func.returns])}
            )"""
        ret_str = f"return {ret_name};" if len(f.func.returns) > 0 else ""
        return f"""
{sig.defn()} {{
    {ret_str}
}}
    """


def gen_custom_ops_registration(
    *,
    native_functions: Sequence[NativeFunction],
    selector: SelectiveBuilder,
    kernel_index: ETKernelIndex,
    rocm: bool,
) -> tuple[str, str]:
    """
    Generate custom ops registration code for dest.RegisterDispatchKey.

    :param native_functions: a sequence of `NativeFunction`
    :param selector: for selective build.
    :param kernel_index: kernels for all the ops.
    :param rocm: bool for dest.RegisterDispatchKey.
    :return: generated C++ code to register custom operators into PyTorch
    """

    # convert kernel index to BackendIndex. This is because we can't handle ETKernelIndex yet.
    # TODO larryliu: evaluate if this code is still needed. If yes let it handle ETKernelIndex.

    dispatch_key = DispatchKey.CPU
    backend_index = kernel_index._to_backend_index()
    static_init_dispatch_registrations = ""
    ns_grouped_native_functions: dict[str, list[NativeFunction]] = defaultdict(list)
    for native_function in native_functions:
        ns_grouped_native_functions[native_function.namespace].append(native_function)

    for namespace, functions in ns_grouped_native_functions.items():
        if len(functions) == 0:
            continue
        dispatch_registrations_body = "\n".join(
            list(
                concatMap(
                    dest.RegisterDispatchKey(
                        backend_index,
                        Target.REGISTRATION,
                        selector,
                        rocm=rocm,
                        symint=False,
                        class_method_name=None,
                        skip_dispatcher_op_registration=False,
                    ),
                    functions,
                )
            )
        )
        static_init_dispatch_registrations += f"""
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
{dispatch_registrations_body}
}}"""
    anonymous_definition = "\n".join(
        list(
            concatMap(
                dest.RegisterDispatchKey(
                    backend_index,
                    Target.ANONYMOUS_DEFINITION,
                    selector,
                    rocm=rocm,
                    symint=False,
                    class_method_name=None,
                    skip_dispatcher_op_registration=False,
                ),
                native_functions,
            )
        )
    )
    return anonymous_definition, static_init_dispatch_registrations
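
For context, here is a minimal driver sketch showing how these two helpers might be wired together. It is not part of this commit: the module path executorch.codegen.api.custom_ops, the helper name emit_custom_ops_files, and the use of a no-op selector are illustrative assumptions; a real build would obtain the NativeFunction list and ETKernelIndex from the executorch codegen driver that parses the kernel YAML.

# Hypothetical driver sketch (not in this commit). Assumes the module path
# executorch.codegen.api.custom_ops and that the caller already holds parsed
# NativeFunction objects plus an ETKernelIndex.
from torchgen.selective_build.selector import SelectiveBuilder

from executorch.codegen.api.custom_ops import (
    ComputeNativeFunctionStub,
    gen_custom_ops_registration,
)


def emit_custom_ops_files(native_functions, kernel_index):
    # Placeholder kernel bodies for RegisterKernelStub.cpp: one stub per
    # custom op that has a function variant.
    stub_gen = ComputeNativeFunctionStub()
    stubs = [s for f in native_functions if (s := stub_gen(f)) is not None]

    # A no-op selector keeps every operator; a selective build would pass its
    # own SelectiveBuilder instead.
    selector = SelectiveBuilder.get_nop_selector()
    anonymous_defs, dispatch_registrations = gen_custom_ops_registration(
        native_functions=native_functions,
        selector=selector,
        kernel_index=kernel_index,
        rocm=False,
    )
    return "\n".join(stubs), anonymous_defs, dispatch_registrations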
