
Commit ce86e3f

Implement GLU using internal views to avoid copying
GLU requires slicing the input Tensor into two halves. Currently, we accomplish this by copying; ExecuTorch does not support views in general because it requires Tensors to be contiguous. However, nothing stops us from implementing [the ATen approach, which uses views](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/GatedLinearUnit.cpp#L35), entirely internally to the op.

To support this, I added `support_noncontiguous_tensors` as an optional template argument to BroadcastIndexesRange and plumbed it through to the elementwise_util functions as an optional SupportNoncontiguousTensors parameter.

ghstack-source-id: fac946b
ghstack-comment-id: 2932190540
Pull-Request-resolved: #11295
1 parent 4c12f6f commit ce86e3f
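For reference, GLU splits the input along `dim` into two halves a and b and computes a * sigmoid(b). Below is a minimal standalone sketch of that math over a contiguous row-major buffer (plain C++, not the ExecuTorch kernel API; `glu_last_dim` is a hypothetical helper named here only for illustration):

#include <cmath>
#include <cstddef>

// GLU over the last dimension of a contiguous [rows, cols] float buffer.
// Assumes cols is even; out must have room for rows * (cols / 2) elements.
void glu_last_dim(const float* in, float* out, std::size_t rows, std::size_t cols) {
  const std::size_t half = cols / 2;
  for (std::size_t r = 0; r < rows; ++r) {
    const float* a = in + r * cols;  // first half along the last dim
    const float* b = a + half;       // second half along the last dim
    for (std::size_t c = 0; c < half; ++c) {
      out[r * half + c] = a[c] * (1.0f / (1.0f + std::exp(-b[c])));
    }
  }
}

The committed kernel computes the same thing, but expresses the two halves as non-owning views over the input and hands them to the existing elementwise utilities instead of looping by hand.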

File tree

4 files changed: +193 -112 lines changed


kernels/portable/cpu/op_glu.cpp

Lines changed: 55 additions & 92 deletions
@@ -8,6 +8,7 @@
 
 #include <c10/util/irange.h>
 #include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
 #include <cinttypes>
@@ -23,93 +24,6 @@ using ScalarType = executorch::aten::ScalarType;
 
 namespace {
 
-double exp_overload(double d) {
-  return exp(d);
-}
-
-float exp_overload(float f) {
-  return expf(f);
-}
-
-/**
- * In-place element-wise sigmoid function , i.e., f(x) = 1 / (1 + e^{-x})
- */
-// TODO: T146333648, refactor this as a common helper function
-template <typename CTYPE_OUT>
-void sigmoid_tensor(Tensor& out) {
-  CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-  for (const auto i : c10::irange(out.numel())) {
-    out_data[i] = 1.0 / (1.0 + exp_overload(-out_data[i]));
-  }
-}
-
-/**
- * Element-wise multiplication of the first half of `in` along the specified
- * dimension and `out`, overwriting `out`.
- */
-template <typename CTYPE_IN, typename CTYPE_OUT>
-void mul_tensors(const Tensor& in, int64_t dim, Tensor& out) {
-  size_t num_values = static_cast<size_t>(in.size(dim)) / 2;
-  size_t dim_length_in = static_cast<size_t>(in.size(dim));
-  size_t dim_length_out = static_cast<size_t>(out.size(dim));
-  size_t leading_dims = getLeadingDims(in, dim);
-  size_t trailing_dims = getTrailingDims(in, dim);
-
-  const CTYPE_IN* input_data_base = in.const_data_ptr<CTYPE_IN>();
-  CTYPE_OUT* output_data_base = out.mutable_data_ptr<CTYPE_OUT>();
-
-  for (const auto i : c10::irange(leading_dims)) {
-    const CTYPE_IN* input_data =
-        input_data_base + i * dim_length_in * trailing_dims;
-    CTYPE_OUT* output_data =
-        output_data_base + i * dim_length_out * trailing_dims;
-    for ([[maybe_unused]] const auto j : c10::irange(num_values)) {
-      for (const auto k : c10::irange(trailing_dims)) {
-        output_data[k] = static_cast<CTYPE_OUT>(input_data[k]) * output_data[k];
-      }
-      input_data += trailing_dims;
-      output_data += trailing_dims;
-    }
-  }
-}
-
-/**
- * Slice the tensor in the given dim, from start to end, assume tensor in and
- * out have same shape and dtype, the dim is a non-negative number and start,
- * end are valid non-negative number
- */
-template <typename CTYPE_IN, typename CTYPE_OUT>
-void slice_tensor(
-    const Tensor& in,
-    int64_t dim,
-    int64_t start,
-    int64_t end,
-    Tensor& out) {
-  size_t num_values = static_cast<size_t>(end - start);
-  size_t dim_length_in = static_cast<size_t>(in.size(dim));
-  size_t dim_length_out = static_cast<size_t>(out.size(dim));
-  size_t non_negative_start = static_cast<size_t>(start);
-  size_t leading_dims = getLeadingDims(in, dim);
-  size_t trailing_dims = getTrailingDims(in, dim);
-
-  const CTYPE_IN* input_data_base = in.const_data_ptr<CTYPE_IN>();
-  CTYPE_OUT* output_data_base = out.mutable_data_ptr<CTYPE_OUT>();
-
-  for (const auto i : c10::irange(leading_dims)) {
-    const CTYPE_IN* input_data = input_data_base +
-        (i * dim_length_in + non_negative_start) * trailing_dims;
-    CTYPE_OUT* output_data =
-        output_data_base + i * dim_length_out * trailing_dims;
-    for ([[maybe_unused]] const auto j : c10::irange(num_values)) {
-      for (const auto k : c10::irange(trailing_dims)) {
-        output_data[k] = static_cast<CTYPE_OUT>(input_data[k]);
-      }
-      input_data += trailing_dims;
-      output_data += trailing_dims;
-    }
-  }
-}
-
 /**
  * Applies the gated linear unit function
  *
@@ -120,11 +34,60 @@ void slice_tensor(
  * 2. The output shall be in float types (Float, Double)
  */
 template <typename CTYPE_IN, typename CTYPE_OUT>
-Tensor& glu_out_tensor(const Tensor& self, int64_t dim, Tensor& out) {
+Tensor& glu_out_tensor(
+    KernelRuntimeContext& ctx,
+    const Tensor& self,
+    int64_t dim,
+    Tensor& out) {
   const auto self_size = self.size(dim);
-  slice_tensor<CTYPE_IN, CTYPE_OUT>(self, dim, self_size / 2, self_size, out);
-  sigmoid_tensor<CTYPE_OUT>(out);
-  mul_tensors<CTYPE_IN, CTYPE_OUT>(self, dim, out);
+  ET_KERNEL_CHECK(
+      ctx, self.dim() <= kTensorDimensionLimit, InvalidArgument, out);
+  std::array<executorch::aten::SizesType, kTensorDimensionLimit> half_sizes;
+  std::copy(self.sizes().begin(), self.sizes().end(), half_sizes.begin());
+  half_sizes[dim] /= 2;
+  TensorImpl first_half_impl(
+      self.scalar_type(),
+      self.dim(),
+      half_sizes.data(),
+      self.mutable_data_ptr(),
+      const_cast<executorch::aten::DimOrderType*>(self.dim_order().data()),
+      const_cast<executorch::aten::StridesType*>(self.strides().data()),
+      self.shape_dynamism());
+  TensorImpl second_half_impl(
+      self.scalar_type(),
+      self.dim(),
+      half_sizes.data(),
+      reinterpret_cast<char*>(self.mutable_data_ptr()) +
+          self.strides()[dim] * self_size / 2 * self.element_size(),
+      const_cast<executorch::aten::DimOrderType*>(self.dim_order().data()),
+      const_cast<executorch::aten::StridesType*>(self.strides().data()),
+      self.shape_dynamism());
+  Tensor first_half(&first_half_impl);
+  Tensor second_half(&second_half_impl);
+  ScalarType compute_type =
+      executorch::runtime::isFloatingType(self.scalar_type())
+      ? self.scalar_type()
+      : ScalarType::Float;
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "glu.out";
+  ET_SWITCH_FLOATHBF16_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    utils::apply_bitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::FLOATHBF16>(
+        [](const auto val_a, const auto val_b) -> CTYPE_COMPUTE {
+          // TODO: rewrite this to be vectorization-capable.
+          const auto one = static_cast<decltype(val_a)>(1.0);
+          return val_a * (one / (one + std::exp(-val_b)));
+        },
+        ctx,
+        first_half,
+        utils::SupportedTensorDtypes::FLOATHBF16,
+        second_half,
+        utils::SupportedTensorDtypes::FLOATHBF16,
+        out,
+        utils::internal::SupportNoncontiguousTensors());
+  });
   return out;
 }
 } // namespace
@@ -158,7 +121,7 @@ Tensor& glu_out(
 
   ET_SWITCH_FLOATHBF16_TYPES(in_dtype, ctx, "glu", CTYPE_IN, [&]() {
     ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, "glu", CTYPE_OUT, [&]() {
-      glu_out_tensor<CTYPE_IN, CTYPE_OUT>(self, non_negative_dim, out);
+      glu_out_tensor<CTYPE_IN, CTYPE_OUT>(ctx, self, non_negative_dim, out);
     });
   });
 
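To see why the half tensors above are non-contiguous views rather than copies, consider a hypothetical [2, 6] float input split along dim 1. The arithmetic below is purely illustrative and mirrors the byte-offset computation in the kernel:

#include <cstdio>

int main() {
  // Hypothetical input: sizes [2, 6], row-major strides [6, 1], float32.
  const int sizes[2] = {2, 6};
  const int strides[2] = {6, 1};
  const int dim = 1;
  const int element_size = 4;

  // Each half view keeps the parent's strides but halves sizes[dim]:
  // sizes become [2, 3] while strides stay [6, 1]. A contiguous [2, 3]
  // tensor would have strides [3, 1], so the views are non-contiguous.
  const int half_extent = sizes[dim] / 2;

  // Byte offset of the second half, matching
  // strides[dim] * self_size / 2 * element_size in the kernel above.
  std::printf(
      "second half starts at byte offset %d\n",
      strides[dim] * half_extent * element_size);  // 1 * 3 * 4 = 12
  return 0;
}

Because the views keep the parent strides, the elementwise utilities must be told not to assume contiguity, which is what the SupportNoncontiguousTensors argument and the BroadcastIndexesRange change below provide.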
kernels/portable/cpu/util/broadcast_indexes_range.h

Lines changed: 15 additions & 6 deletions
@@ -43,7 +43,7 @@ inline bool sizes_match_ignoring_leading_1s(
       std::equal(lhs_begin, lhs_end, rhs_begin);
 }
 
-template <std::size_t kNumInputs>
+template <std::size_t kNumInputs, bool support_noncontiguous_tensors = false>
 class BroadcastIndexesIterator {
  public:
   using difference_type = ssize_t;
@@ -57,16 +57,19 @@ class BroadcastIndexesIterator {
   template <typename... Args>
   explicit BroadcastIndexesIterator(const Tensor& output, const Args&... args)
       : output_dim_or_zero_if_no_broadcasting_(
-            (sizes_match_ignoring_leading_1s(args.sizes(), output.sizes()) &&
-             ...)
+            !support_noncontiguous_tensors &&
+                    (sizes_match_ignoring_leading_1s(
+                         args.sizes(),
+                         output.sizes()) &&
+                     ...)
                 ? 0
                 : output.dim()),
         output_shape_(output.sizes()) {
     static_assert(
         sizeof...(args) == kNumInputs && (std::is_same_v<Args, Tensor> && ...),
         "BroadcastIndexesIterator constructor requires kNumInputs input tensor"
         "arguments!");
-    if (output_dim_or_zero_if_no_broadcasting_ != 0) {
+    if (support_noncontiguous_tensors || output_dim_or_zero_if_no_broadcasting_ != 0) {
       effective_input_broadcast_strides_ = {
           effective_input_broadcast_stride(output, args)...};
     }
@@ -249,11 +252,17 @@ class BroadcastIndexesIterator {
  * Unlike looping using delinearize_index() and
  * linearize_access_indexes(), BroadcastIndexesRange avoids expensive
  * division and modulo operations on each iteration.
+ *
+ * The support_noncontiguous_tensors argument disables an optimization
+ * that causes the iterators not to respect strides in some
+ * cases. This optimization is normally safe because ExecuTorch
+ * tensors are contiguous.
  */
-template <std::size_t kNumInputs>
+template <std::size_t kNumInputs, bool support_noncontiguous_tensors = false>
 class BroadcastIndexesRange {
  public:
-  using iterator = internal::BroadcastIndexesIterator<kNumInputs>;
+  using iterator =
+      internal::BroadcastIndexesIterator<kNumInputs, support_noncontiguous_tensors>;
 
   template <typename... Args>
   BroadcastIndexesRange(const Tensor& output, const Args&... args)
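What respecting strides buys here can be seen with a tiny standalone example (plain C++, not the BroadcastIndexesRange API): iterating the first-half view from the [2, 6] example above has to step by the parent strides rather than by a dense row length.

#include <cstdio>

int main() {
  // Backing data of a [2, 6] row-major float tensor.
  float data[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};

  // First-half view: sizes [2, 3], but it keeps the parent strides [6, 1].
  const int sizes[2] = {2, 3};
  const int strides[2] = {6, 1};

  // Stride-respecting iteration (what support_noncontiguous_tensors enables)
  // visits 0 1 2 6 7 8; assuming contiguity would wrongly visit 0..5 instead.
  for (int i = 0; i < sizes[0]; ++i) {
    for (int j = 0; j < sizes[1]; ++j) {
      std::printf("%g ", data[i * strides[0] + j * strides[1]]);
    }
  }
  std::printf("\n");
  return 0;
}

With the template argument left at its default of false, the iterator may fall back to plain linear indexing when input and output shapes match, which is only valid for contiguous tensors; that is the optimization the new flag disables.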
