
Commit 1f52982

Implement GLU using internal views to avoid copying (#11295)
GLU requires slicing the input Tensor into two halves. Currently, we accomplish this by copying; ExecuTorch does not support views in general because it requires Tensors to be contiguous. However, nothing stops us from implementing [the approach the ATen kernel uses, which relies on views](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/GatedLinearUnit.cpp#L35), entirely internally to the op. To support this, I added `support_noncontiguous_tensors` as an optional template argument to BroadcastIndexesRange and plumbed it through to the elementwise_util functions as an optional `SupportNoncontiguousTensors` parameter.
1 parent 0e35c30 commit 1f52982
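For reference, the decomposition this commit moves inside the op is GLU(x) = A ⊙ σ(B), where A and B are the first and second halves of x along `dim`. The sketch below is a minimal, standalone illustration of that math on a plain contiguous buffer; `glu_reference` and the `outer`/`half`/`inner` parameters are illustrative names, not ExecuTorch APIs.

```cpp
// Minimal reference sketch of GLU on a contiguous buffer viewed as
// [outer, 2 * half, inner], split along the middle dimension.
// Names are illustrative only; this does not use ExecuTorch types.
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> glu_reference(
    const std::vector<float>& in,
    std::size_t outer, // product of sizes before `dim`
    std::size_t half,  // in.size(dim) / 2
    std::size_t inner) { // product of sizes after `dim`
  std::vector<float> out(outer * half * inner);
  for (std::size_t o = 0; o < outer; ++o) {
    const float* a = in.data() + o * 2 * half * inner; // first half "view"
    const float* b = a + half * inner;                 // second half "view"
    float* dst = out.data() + o * half * inner;
    for (std::size_t i = 0; i < half * inner; ++i) {
      // GLU: first_half * sigmoid(second_half)
      dst[i] = a[i] * (1.0f / (1.0f + std::exp(-b[i])));
    }
  }
  return out;
}
```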

File tree

5 files changed: +198, -111 lines changed

kernels/portable/cpu/op_glu.cpp

Lines changed: 58 additions & 92 deletions

@@ -8,6 +8,7 @@
 
 #include <c10/util/irange.h>
 #include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
 #include <cinttypes>
@@ -23,93 +24,6 @@ using ScalarType = executorch::aten::ScalarType;
 
 namespace {
 
-double exp_overload(double d) {
-  return exp(d);
-}
-
-float exp_overload(float f) {
-  return expf(f);
-}
-
-/**
- * In-place element-wise sigmoid function , i.e., f(x) = 1 / (1 + e^{-x})
- */
-// TODO: T146333648, refactor this as a common helper function
-template <typename CTYPE_OUT>
-void sigmoid_tensor(Tensor& out) {
-  CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-  for (const auto i : c10::irange(out.numel())) {
-    out_data[i] = 1.0 / (1.0 + exp_overload(-out_data[i]));
-  }
-}
-
-/**
- * Element-wise multiplication of the first half of `in` along the specified
- * dimension and `out`, overwriting `out`.
- */
-template <typename CTYPE_IN, typename CTYPE_OUT>
-void mul_tensors(const Tensor& in, int64_t dim, Tensor& out) {
-  size_t num_values = static_cast<size_t>(in.size(dim)) / 2;
-  size_t dim_length_in = static_cast<size_t>(in.size(dim));
-  size_t dim_length_out = static_cast<size_t>(out.size(dim));
-  size_t leading_dims = getLeadingDims(in, dim);
-  size_t trailing_dims = getTrailingDims(in, dim);
-
-  const CTYPE_IN* input_data_base = in.const_data_ptr<CTYPE_IN>();
-  CTYPE_OUT* output_data_base = out.mutable_data_ptr<CTYPE_OUT>();
-
-  for (const auto i : c10::irange(leading_dims)) {
-    const CTYPE_IN* input_data =
-        input_data_base + i * dim_length_in * trailing_dims;
-    CTYPE_OUT* output_data =
-        output_data_base + i * dim_length_out * trailing_dims;
-    for ([[maybe_unused]] const auto j : c10::irange(num_values)) {
-      for (const auto k : c10::irange(trailing_dims)) {
-        output_data[k] = static_cast<CTYPE_OUT>(input_data[k]) * output_data[k];
-      }
-      input_data += trailing_dims;
-      output_data += trailing_dims;
-    }
-  }
-}
-
-/**
- * Slice the tensor in the given dim, from start to end, assume tensor in and
- * out have same shape and dtype, the dim is a non-negative number and start,
- * end are valid non-negative number
- */
-template <typename CTYPE_IN, typename CTYPE_OUT>
-void slice_tensor(
-    const Tensor& in,
-    int64_t dim,
-    int64_t start,
-    int64_t end,
-    Tensor& out) {
-  size_t num_values = static_cast<size_t>(end - start);
-  size_t dim_length_in = static_cast<size_t>(in.size(dim));
-  size_t dim_length_out = static_cast<size_t>(out.size(dim));
-  size_t non_negative_start = static_cast<size_t>(start);
-  size_t leading_dims = getLeadingDims(in, dim);
-  size_t trailing_dims = getTrailingDims(in, dim);
-
-  const CTYPE_IN* input_data_base = in.const_data_ptr<CTYPE_IN>();
-  CTYPE_OUT* output_data_base = out.mutable_data_ptr<CTYPE_OUT>();
-
-  for (const auto i : c10::irange(leading_dims)) {
-    const CTYPE_IN* input_data = input_data_base +
-        (i * dim_length_in + non_negative_start) * trailing_dims;
-    CTYPE_OUT* output_data =
-        output_data_base + i * dim_length_out * trailing_dims;
-    for ([[maybe_unused]] const auto j : c10::irange(num_values)) {
-      for (const auto k : c10::irange(trailing_dims)) {
-        output_data[k] = static_cast<CTYPE_OUT>(input_data[k]);
-      }
-      input_data += trailing_dims;
-      output_data += trailing_dims;
-    }
-  }
-}
-
 /**
  * Applies the gated linear unit function
  *
@@ -120,11 +34,63 @@ void slice_tensor(
  * 2. The output shall be in float types (Float, Double)
  */
 template <typename CTYPE_IN, typename CTYPE_OUT>
-Tensor& glu_out_tensor(const Tensor& self, int64_t dim, Tensor& out) {
+Tensor& glu_out_tensor(
+    KernelRuntimeContext& ctx,
+    const Tensor& self,
+    int64_t dim,
+    Tensor& out) {
   const auto self_size = self.size(dim);
-  slice_tensor<CTYPE_IN, CTYPE_OUT>(self, dim, self_size / 2, self_size, out);
-  sigmoid_tensor<CTYPE_OUT>(out);
-  mul_tensors<CTYPE_IN, CTYPE_OUT>(self, dim, out);
+  ET_KERNEL_CHECK(
+      ctx,
+      self.dim() <= static_cast<ssize_t>(kTensorDimensionLimit),
+      InvalidArgument,
+      out);
+  std::array<executorch::aten::SizesType, kTensorDimensionLimit> half_sizes;
+  std::copy(self.sizes().begin(), self.sizes().end(), half_sizes.begin());
+  half_sizes[dim] /= 2;
+  TensorImpl first_half_impl(
+      self.scalar_type(),
+      self.dim(),
+      half_sizes.data(),
+      self.mutable_data_ptr(),
+      const_cast<executorch::aten::DimOrderType*>(self.dim_order().data()),
+      const_cast<executorch::aten::StridesType*>(self.strides().data()),
+      self.shape_dynamism());
+  TensorImpl second_half_impl(
+      self.scalar_type(),
+      self.dim(),
+      half_sizes.data(),
+      reinterpret_cast<char*>(self.mutable_data_ptr()) +
+          self.strides()[dim] * self_size / 2 * self.element_size(),
+      const_cast<executorch::aten::DimOrderType*>(self.dim_order().data()),
+      const_cast<executorch::aten::StridesType*>(self.strides().data()),
+      self.shape_dynamism());
+  Tensor first_half(&first_half_impl);
+  Tensor second_half(&second_half_impl);
+  ScalarType compute_type =
+      executorch::runtime::isFloatingType(self.scalar_type())
+      ? self.scalar_type()
+      : ScalarType::Float;
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "glu.out";
+  ET_SWITCH_FLOATHBF16_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    utils::apply_bitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::FLOATHBF16>(
+        [](const auto val_a, const auto val_b) -> CTYPE_COMPUTE {
+          // TODO: rewrite this to be vectorization-capable.
+          const auto one = static_cast<decltype(val_a)>(1.0);
+          return val_a * (one / (one + std::exp(-val_b)));
+        },
+        ctx,
+        first_half,
+        utils::SupportedTensorDtypes::FLOATHBF16,
+        second_half,
+        utils::SupportedTensorDtypes::FLOATHBF16,
+        out,
+        utils::internal::SupportNoncontiguousTensors());
+  });
   return out;
 }
 } // namespace
@@ -158,7 +124,7 @@ Tensor& glu_out(
 
   ET_SWITCH_FLOATHBF16_TYPES(in_dtype, ctx, "glu", CTYPE_IN, [&]() {
     ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, "glu", CTYPE_OUT, [&]() {
-      glu_out_tensor<CTYPE_IN, CTYPE_OUT>(self, non_negative_dim, out);
+      glu_out_tensor<CTYPE_IN, CTYPE_OUT>(ctx, self, non_negative_dim, out);
     });
   });
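To make the pointer arithmetic for `second_half_impl` concrete: offsetting the base pointer by `strides()[dim] * (size(dim) / 2)` elements (times `element_size()` in bytes, since the kernel offsets a `char*`) lands on the first element of the second half, and the halved sizes plus the parent's strides then address it as a view. Below is a small standalone sketch using plain arrays and a made-up shape rather than ExecuTorch types.

```cpp
// Illustration of the second-half offset for a contiguous float tensor of
// shape [4, 6], split along dim = 1. Values and shapes are made up.
#include <cassert>
#include <cstddef>

int main() {
  float data[4 * 6];
  for (int i = 0; i < 4 * 6; ++i) {
    data[i] = static_cast<float>(i);
  }
  const std::ptrdiff_t strides[] = {6, 1}; // contiguous [4, 6]
  const std::ptrdiff_t dim = 1;
  const std::ptrdiff_t dim_size = 6;
  // Same formula as the diff: strides[dim] * (size(dim) / 2) elements past
  // the base pointer.
  float* second_half = data + strides[dim] * (dim_size / 2);
  // Read with the original strides and halved sizes [4, 3], element
  // (row, col) of the view is element (row, col + 3) of the parent.
  assert(second_half[0 * strides[0] + 0 * strides[1]] == data[3]);
  assert(second_half[2 * strides[0] + 1 * strides[1]] == data[2 * 6 + 4]);
  return 0;
}
```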

kernels/portable/cpu/util/broadcast_indexes_range.h

Lines changed: 16 additions & 6 deletions

@@ -43,7 +43,7 @@ inline bool sizes_match_ignoring_leading_1s(
       std::equal(lhs_begin, lhs_end, rhs_begin);
 }
 
-template <std::size_t kNumInputs>
+template <std::size_t kNumInputs, bool support_noncontiguous_tensors = false>
 class BroadcastIndexesIterator {
  public:
   using difference_type = ssize_t;
@@ -57,16 +57,20 @@ class BroadcastIndexesIterator {
   template <typename... Args>
   explicit BroadcastIndexesIterator(const Tensor& output, const Args&... args)
       : output_dim_or_zero_if_no_broadcasting_(
-            (sizes_match_ignoring_leading_1s(args.sizes(), output.sizes()) &&
-             ...)
+            !support_noncontiguous_tensors &&
+                    (sizes_match_ignoring_leading_1s(
+                         args.sizes(),
+                         output.sizes()) &&
+                     ...)
                 ? 0
                 : output.dim()),
         output_shape_(output.sizes()) {
     static_assert(
         sizeof...(args) == kNumInputs && (std::is_same_v<Args, Tensor> && ...),
         "BroadcastIndexesIterator constructor requires kNumInputs input tensor"
         "arguments!");
-    if (output_dim_or_zero_if_no_broadcasting_ != 0) {
+    if (support_noncontiguous_tensors ||
+        output_dim_or_zero_if_no_broadcasting_ != 0) {
       effective_input_broadcast_strides_ = {
           effective_input_broadcast_stride(output, args)...};
     }
@@ -249,11 +253,17 @@ class BroadcastIndexesIterator {
  * Unlike looping using delinearize_index() and
  * linearize_access_indexes(), BroadcastIndexesRange avoids expensive
  * division and modulo operations on each iteration.
+ *
+ * The support_noncontiguous_tensors argument disables an optimization
+ * that causes the iterators not to respect strides in some
+ * cases. This optimization is normally safe because ExecuTorch
+ * tensors are contiguous.
  */
-template <std::size_t kNumInputs>
+template <std::size_t kNumInputs, bool support_noncontiguous_tensors = false>
 class BroadcastIndexesRange {
  public:
-  using iterator = internal::BroadcastIndexesIterator<kNumInputs>;
+  using iterator = internal::
+      BroadcastIndexesIterator<kNumInputs, support_noncontiguous_tensors>;
 
   template <typename... Args>
   BroadcastIndexesRange(const Tensor& output, const Args&... args)
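A hedged sketch of how a kernel could opt into the new template parameter directly (the GLU change above goes through `apply_bitensor_elementwise_fn` instead). It assumes, as in the existing elementwise_util loops, that dereferencing the iterator yields one linear index per tensor with the output index first; namespaces and includes are omitted, and `mul_respecting_strides` is an illustrative helper, not part of the commit.

```cpp
// Hypothetical direct use of BroadcastIndexesRange with the new parameter.
// Tensor is the ExecuTorch Tensor type; qualification omitted for brevity.
template <typename CTYPE>
void mul_respecting_strides(const Tensor& a, const Tensor& b, Tensor& out) {
  const CTYPE* a_data = a.const_data_ptr<CTYPE>();
  const CTYPE* b_data = b.const_data_ptr<CTYPE>();
  CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
  // Passing true disables the "no broadcasting" fast path, so the computed
  // indexes honor each tensor's strides even when a tensor is an internal,
  // non-contiguous view like the GLU halves.
  for (const auto& indexes :
       BroadcastIndexesRange<2, /*support_noncontiguous_tensors=*/true>(
           out, a, b)) {
    out_data[indexes[0]] = a_data[indexes[1]] * b_data[indexes[2]];
  }
}
```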
