11
11
12
12
namespace vkcompute {
13
13
14
+ //
15
+ // sym_size
16
+ //
17
+
18
+ void sym_size_impl (ComputeGraph* graph, const std::vector<ValueRef>& args) {
19
+ const ValueRef in_tensor = args.at (0 );
20
+ const ValueRef dim = args.at (1 );
21
+ const ValueRef out_symint = args.at (2 );
22
+
23
+ const int64_t dim_val = graph->extract_scalar <int64_t >(dim);
24
+ const int64_t size_at_dim = graph->size_at <int64_t >(dim_val, in_tensor);
25
+
26
+ graph->set_symint (out_symint, static_cast <int32_t >(size_at_dim));
27
+ }
28
+
14
29
void resize_sym_size_node (
15
30
ComputeGraph* graph,
16
31
const std::vector<ArgGroup>& args,
17
- const std::vector<ValueRef>& extra_args ) {
32
+ const std::vector<ValueRef>& resize_args ) {
18
33
(void )args; // Unused parameter
19
-
20
- ValueRef out_symint_ref = extra_args[0 ];
21
- ValueRef in_tensor_ref = extra_args[1 ];
22
-
23
- int64_t dim = graph->extract_scalar <int64_t >(extra_args[2 ]);
24
- int64_t size_at_dim = graph->size_at <int64_t >(dim, in_tensor_ref);
25
-
26
- graph->set_symint (out_symint_ref, static_cast <int32_t >(size_at_dim));
34
+ sym_size_impl (graph, resize_args);
27
35
}
28
36
29
37
/*
@@ -32,21 +40,50 @@ void resize_sym_size_node(
32
40
* specified dimension.
33
41
*/
34
42
void sym_size_int (ComputeGraph& graph, const std::vector<ValueRef>& args) {
35
- ValueRef in_tensor = args[0 ];
36
- ValueRef dim = args[1 ];
37
- ValueRef out_symint = args[2 ];
43
+ sym_size_impl (&graph, args);
44
+
45
+ graph.execute_nodes ().emplace_back (
46
+ new ExecuteNode (resize_sym_size_node, args));
47
+ }
38
48
39
- int64_t dim_val = graph.extract_scalar <int64_t >(dim);
49
+ //
50
+ // binary operators
51
+ //
40
52
41
- int64_t size_at_dim = graph.size_at <int64_t >(dim_val, in_tensor);
42
- graph.set_symint (out_symint, static_cast <int32_t >(size_at_dim));
53
+ void sym_add_impl (ComputeGraph* graph, const std::vector<ValueRef>& args) {
54
+ const ValueRef a = args.at (0 );
55
+ const ValueRef b = args.at (1 );
56
+ const ValueRef out = args.at (2 );
57
+
58
+ const int32_t a_val = graph->read_symint (a);
59
+ const int32_t b_val = graph->read_symint (b);
60
+ const int32_t result = a_val + b_val;
61
+
62
+ graph->set_symint (out, result);
63
+ }
64
+
65
+ void resize_sym_add_node (
66
+ ComputeGraph* graph,
67
+ const std::vector<ArgGroup>& args,
68
+ const std::vector<ValueRef>& resize_args) {
69
+ (void )args; // Unused parameter
70
+ sym_add_impl (graph, resize_args);
71
+ }
72
+
73
+ /*
74
+ * This operator takes two symints as inputs and produces a symint as output.
75
+ * The output symint's value is the sum of the two input symints.
76
+ */
77
+ void sym_add (ComputeGraph& graph, const std::vector<ValueRef>& args) {
78
+ sym_add_impl (&graph, args);
43
79
44
80
graph.execute_nodes ().emplace_back (
45
- new ExecuteNode (resize_sym_size_node, {out_symint, in_tensor, dim} ));
81
+ new ExecuteNode (resize_sym_add_node, args ));
46
82
}
47
83
48
84
REGISTER_OPERATORS {
49
85
VK_REGISTER_OP (sym_size.int , sym_size_int);
86
+ VK_REGISTER_OP (add, sym_add);
50
87
}
51
88
52
89
} // namespace vkcompute
0 commit comments