Skip to content

Commit 875c718

Browse files
authored
[ET-VK] Add support for binary symint ops (#11348)
## Changes: Add an implementation for binary operators on symbolic integers (currently `operator.add`, which sums two symints). ## Motivation: Support executing llama models with dynamic shapes — this operator shows up when exporting with dynamic shapes. Differential Revision: [D75238029](https://our.internmc.facebook.com/intern/diff/D75238029/)
1 parent bacc2d6 commit 875c718

File tree

4 files changed

+71
-29
lines changed

4 files changed

+71
-29
lines changed

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import logging
8-
from copy import deepcopy
98
from typing import Any, Optional, Set
109

1110
import executorch.backends.vulkan.utils as utils
@@ -22,6 +21,7 @@
2221
from executorch.exir.dialects._ops import ops as exir_ops
2322

2423
from executorch.exir.pass_base import ExportPass, PassResult
24+
from executorch.exir.tensor import TensorSpec
2525

2626
logger: logging.Logger = logging.getLogger("")
2727
logger.setLevel(logging.INFO)
@@ -52,7 +52,7 @@ def insert_transition_node(
5252
(arg,),
5353
)
5454
clone_node.meta["val"] = arg.meta["val"]
55-
clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
55+
clone_node.meta["spec"] = TensorSpec.from_tensor(clone_node.meta["val"])
5656
clone_node.meta["spec"].const = False
5757
set_memory_metadata(clone_node, storage, layout)
5858
arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)

backends/vulkan/op_registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ def update_features_impl(op: OpKey):
230230
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
231231
# Symbolic integer ops
232232
torch.ops.aten.sym_size.int,
233+
operator.add,
233234
]
234235
)
235236
def register_ephemeral_op(features: OpFeatures):

backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,27 @@
1111

1212
namespace vkcompute {
1313

14+
//
15+
// sym_size
16+
//
17+
18+
void sym_size_impl(ComputeGraph* graph, const std::vector<ValueRef>& args) {
19+
const ValueRef in_tensor = args.at(0);
20+
const ValueRef dim = args.at(1);
21+
const ValueRef out_symint = args.at(2);
22+
23+
const int64_t dim_val = graph->extract_scalar<int64_t>(dim);
24+
const int64_t size_at_dim = graph->size_at<int64_t>(dim_val, in_tensor);
25+
26+
graph->set_symint(out_symint, static_cast<int32_t>(size_at_dim));
27+
}
28+
1429
void resize_sym_size_node(
1530
ComputeGraph* graph,
1631
const std::vector<ArgGroup>& args,
17-
const std::vector<ValueRef>& extra_args) {
32+
const std::vector<ValueRef>& resize_args) {
1833
(void)args; // Unused parameter
19-
20-
ValueRef out_symint_ref = extra_args[0];
21-
ValueRef in_tensor_ref = extra_args[1];
22-
23-
int64_t dim = graph->extract_scalar<int64_t>(extra_args[2]);
24-
int64_t size_at_dim = graph->size_at<int64_t>(dim, in_tensor_ref);
25-
26-
graph->set_symint(out_symint_ref, static_cast<int32_t>(size_at_dim));
34+
sym_size_impl(graph, resize_args);
2735
}
2836

2937
/*
@@ -32,21 +40,50 @@ void resize_sym_size_node(
3240
* specified dimension.
3341
*/
3442
void sym_size_int(ComputeGraph& graph, const std::vector<ValueRef>& args) {
35-
ValueRef in_tensor = args[0];
36-
ValueRef dim = args[1];
37-
ValueRef out_symint = args[2];
43+
sym_size_impl(&graph, args);
44+
45+
graph.execute_nodes().emplace_back(
46+
new ExecuteNode(resize_sym_size_node, args));
47+
}
3848

39-
int64_t dim_val = graph.extract_scalar<int64_t>(dim);
49+
//
50+
// binary operators
51+
//
4052

41-
int64_t size_at_dim = graph.size_at<int64_t>(dim_val, in_tensor);
42-
graph.set_symint(out_symint, static_cast<int32_t>(size_at_dim));
53+
void sym_add_impl(ComputeGraph* graph, const std::vector<ValueRef>& args) {
54+
const ValueRef a = args.at(0);
55+
const ValueRef b = args.at(1);
56+
const ValueRef out = args.at(2);
57+
58+
const int32_t a_val = graph->read_symint(a);
59+
const int32_t b_val = graph->read_symint(b);
60+
const int32_t result = a_val + b_val;
61+
62+
graph->set_symint(out, result);
63+
}
64+
65+
void resize_sym_add_node(
66+
ComputeGraph* graph,
67+
const std::vector<ArgGroup>& args,
68+
const std::vector<ValueRef>& resize_args) {
69+
(void)args; // Unused parameter
70+
sym_add_impl(graph, resize_args);
71+
}
72+
73+
/*
74+
* This operator takes two symints as inputs and produces a symint as output.
75+
* The output symint's value is the sum of the two input symints.
76+
*/
77+
void sym_add(ComputeGraph& graph, const std::vector<ValueRef>& args) {
78+
sym_add_impl(&graph, args);
4379

4480
graph.execute_nodes().emplace_back(
45-
new ExecuteNode(resize_sym_size_node, {out_symint, in_tensor, dim}));
81+
new ExecuteNode(resize_sym_add_node, args));
4682
}
4783

4884
REGISTER_OPERATORS {
4985
VK_REGISTER_OP(sym_size.int, sym_size_int);
86+
VK_REGISTER_OP(add, sym_add);
5087
}
5188

5289
} // namespace vkcompute

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -340,17 +340,21 @@ def process_call_function_node(self, node) -> None:
340340

341341
self.seen_ops.add(node.target)
342342

343-
for i, schema_arg in enumerate(node.target._schema.arguments):
344-
if not schema_arg.kwarg_only and i < len(node.args):
345-
function_arg = node.args[i]
346-
elif schema_arg.name in node.kwargs:
347-
function_arg = node.kwargs[schema_arg.name]
348-
else:
349-
function_arg = schema_arg.default_value
350-
351-
# Create a Value for each function argument. If the argument has been
352-
# previously encountered, then use the existing Value id.
353-
operator_call_args.append(self.get_or_create_value_for(function_arg))
343+
if hasattr(node.target, "_schema"):
344+
for i, schema_arg in enumerate(node.target._schema.arguments):
345+
if not schema_arg.kwarg_only and i < len(node.args):
346+
function_arg = node.args[i]
347+
elif schema_arg.name in node.kwargs:
348+
function_arg = node.kwargs[schema_arg.name]
349+
else:
350+
function_arg = schema_arg.default_value
351+
352+
# Create a Value for each function argument. If the argument has been
353+
# previously encountered, then use the existing Value id.
354+
operator_call_args.append(self.get_or_create_value_for(function_arg))
355+
else:
356+
for _, arg_node in enumerate(node.args):
357+
operator_call_args.append(self.get_or_create_value_for(arg_node))
354358

355359
# Add output node
356360
operator_call_args.append(self.create_node_value(node))

0 commit comments

Comments
(0)