 #include <c10/util/irange.h>
 #include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
 #include <cinttypes>
@@ -23,92 +24,46 @@ using ScalarType = executorch::aten::ScalarType;
 
 namespace {
 
-double exp_overload(double d) {
-  return exp(d);
-}
-
-float exp_overload(float f) {
-  return expf(f);
-}
-
-/**
- * In-place element-wise sigmoid function, i.e., f(x) = 1 / (1 + e^{-x})
- */
-// TODO: T146333648, refactor this as a common helper function
-template <typename CTYPE_OUT>
-void sigmoid_tensor(Tensor& out) {
-  CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-  for (const auto i : c10::irange(out.numel())) {
-    out_data[i] = 1.0 / (1.0 + exp_overload(-out_data[i]));
-  }
-}
-
-/**
- * Element-wise multiplication of the first half of `in` along the specified
- * dimension and `out`, overwriting `out`.
- */
-template <typename CTYPE_IN, typename CTYPE_OUT>
-void mul_tensors(const Tensor& in, int64_t dim, Tensor& out) {
-  size_t num_values = static_cast<size_t>(in.size(dim)) / 2;
-  size_t dim_length_in = static_cast<size_t>(in.size(dim));
-  size_t dim_length_out = static_cast<size_t>(out.size(dim));
-  size_t leading_dims = getLeadingDims(in, dim);
-  size_t trailing_dims = getTrailingDims(in, dim);
-
-  const CTYPE_IN* input_data_base = in.const_data_ptr<CTYPE_IN>();
-  CTYPE_OUT* output_data_base = out.mutable_data_ptr<CTYPE_OUT>();
-
-  for (const auto i : c10::irange(leading_dims)) {
-    const CTYPE_IN* input_data =
-        input_data_base + i * dim_length_in * trailing_dims;
-    CTYPE_OUT* output_data =
-        output_data_base + i * dim_length_out * trailing_dims;
-    for ([[maybe_unused]] const auto j : c10::irange(num_values)) {
-      for (const auto k : c10::irange(trailing_dims)) {
-        output_data[k] = static_cast<CTYPE_OUT>(input_data[k]) * output_data[k];
-      }
-      input_data += trailing_dims;
-      output_data += trailing_dims;
-    }
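+// Non-owning views of the first and second halves of the input tensor along
+// `dim`; both views share the input's storage, so no data is copied.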
+struct SplitGLUInputTensor {
+  explicit SplitGLUInputTensor(const Tensor& self, int64_t dim);
+  using SizesArray =
+      std::array<executorch::aten::SizesType, kTensorDimensionLimit>;
+  SizesArray half_sizes;
+  TensorImpl first_half_impl;
+  TensorImpl second_half_impl;
+  Tensor first_half;
+  Tensor second_half;
+
+ private:
+  static SizesArray get_half_sizes(const Tensor& self, int64_t dim) {
+    SizesArray half_sizes;
+    std::copy(self.sizes().begin(), self.sizes().end(), half_sizes.begin());
+    half_sizes[dim] /= 2;
+    return half_sizes;
   }
-}
-
-/**
- * Slice the tensor in the given dim, from start to end, assume tensor in and
- * out have same shape and dtype, the dim is a non-negative number and start,
- * end are valid non-negative number
- */
-template <typename CTYPE_IN, typename CTYPE_OUT>
-void slice_tensor(
-    const Tensor& in,
-    int64_t dim,
-    int64_t start,
-    int64_t end,
-    Tensor& out) {
-  size_t num_values = static_cast<size_t>(end - start);
-  size_t dim_length_in = static_cast<size_t>(in.size(dim));
-  size_t dim_length_out = static_cast<size_t>(out.size(dim));
-  size_t non_negative_start = static_cast<size_t>(start);
-  size_t leading_dims = getLeadingDims(in, dim);
-  size_t trailing_dims = getTrailingDims(in, dim);
-
-  const CTYPE_IN* input_data_base = in.const_data_ptr<CTYPE_IN>();
-  CTYPE_OUT* output_data_base = out.mutable_data_ptr<CTYPE_OUT>();
-
-  for (const auto i : c10::irange(leading_dims)) {
-    const CTYPE_IN* input_data = input_data_base +
-        (i * dim_length_in + non_negative_start) * trailing_dims;
-    CTYPE_OUT* output_data =
-        output_data_base + i * dim_length_out * trailing_dims;
-    for ([[maybe_unused]] const auto j : c10::irange(num_values)) {
-      for (const auto k : c10::irange(trailing_dims)) {
-        output_data[k] = static_cast<CTYPE_OUT>(input_data[k]);
-      }
-      input_data += trailing_dims;
-      output_data += trailing_dims;
-    }
-  }
-}
+};
+
+SplitGLUInputTensor::SplitGLUInputTensor(const Tensor& self, int64_t dim)
+    : half_sizes(get_half_sizes(self, dim)),
+      first_half_impl(
+          self.scalar_type(),
+          self.dim(),
+          half_sizes.data(),
+          self.mutable_data_ptr(),
+          const_cast<executorch::aten::DimOrderType*>(self.dim_order().data()),
+          const_cast<executorch::aten::StridesType*>(self.strides().data()),
+          self.shape_dynamism()),
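+      // The second half starts halfway along `dim`: offset the base pointer
+      // by stride(dim) * size(dim) / 2 elements, converted to bytes.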
+      second_half_impl(
+          self.scalar_type(),
+          self.dim(),
+          half_sizes.data(),
+          reinterpret_cast<char*>(self.mutable_data_ptr()) +
+              self.strides()[dim] * self.size(dim) / 2 * self.element_size(),
+          const_cast<executorch::aten::DimOrderType*>(self.dim_order().data()),
+          const_cast<executorch::aten::StridesType*>(self.strides().data()),
+          self.shape_dynamism()),
+      first_half(&first_half_impl),
+      second_half(&second_half_impl) {}
 
 /**
  * Applies the gated linear unit function
@@ -120,11 +75,43 @@ void slice_tensor(
  * 2. The output shall be in float types (Float, Double)
  */
 template <typename CTYPE_IN, typename CTYPE_OUT>
-Tensor& glu_out_tensor(const Tensor& self, int64_t dim, Tensor& out) {
-  const auto self_size = self.size(dim);
-  slice_tensor<CTYPE_IN, CTYPE_OUT>(self, dim, self_size / 2, self_size, out);
-  sigmoid_tensor<CTYPE_OUT>(out);
-  mul_tensors<CTYPE_IN, CTYPE_OUT>(self, dim, out);
+Tensor& glu_out_tensor(
+    KernelRuntimeContext& ctx,
+    const Tensor& self,
+    int64_t dim,
+    Tensor& out) {
+  ET_KERNEL_CHECK(
+      ctx,
+      self.dim() <= static_cast<ssize_t>(kTensorDimensionLimit),
+      InvalidArgument,
+      out);
+  SplitGLUInputTensor split_input(self, dim);
+  ScalarType compute_type =
+      executorch::runtime::isFloatingType(self.scalar_type())
+      ? self.scalar_type()
+      : ScalarType::Float;
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "glu.out";
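+  // GLU(a, b) = a * sigmoid(b), where a and b are the first and second halves
+  // of the input along `dim`; the result is written elementwise into `out`.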
+  ET_SWITCH_FLOATHBF16_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    utils::apply_bitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::FLOATHBF16>(
+        [](const auto val_a, const auto val_b) -> CTYPE_COMPUTE {
+          // TODO: rewrite this to be vectorization-capable? the
+          // tensors might not be contiguous; need to have
+          // apply_bitensor_elementwise_fn check that.
+          const auto one = static_cast<decltype(val_a)>(1.0);
+          return val_a * (one / (one + std::exp(-val_b)));
+        },
+        ctx,
+        split_input.first_half,
+        utils::SupportedTensorDtypes::FLOATHBF16,
+        split_input.second_half,
+        utils::SupportedTensorDtypes::FLOATHBF16,
+        out,
+        utils::internal::SupportNoncontiguousTensors());
+  });
   return out;
 }
 } // namespace
@@ -158,7 +145,7 @@ Tensor& glu_out(
 
   ET_SWITCH_FLOATHBF16_TYPES(in_dtype, ctx, "glu", CTYPE_IN, [&]() {
     ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, "glu", CTYPE_OUT, [&]() {
-      glu_out_tensor<CTYPE_IN, CTYPE_OUT>(self, non_negative_dim, out);
+      glu_out_tensor<CTYPE_IN, CTYPE_OUT>(ctx, self, non_negative_dim, out);
     });
   });