Commit 04710d4

pytorchbot and swolchok authored
Reapply #11294 and #11295 (improve GLU test and implement using internal views to avoid copying) (#11539)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #11509 by @swolchok
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/swolchok/451/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/swolchok/451/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/swolchok/451/orig
@diff-train-skip-merge

Co-authored-by: Scott Wolchok <swolchok@meta.com>
1 parent c6c3616 commit 04710d4
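What the change in kernels/portable/cpu/op_glu.cpp below amounts to: instead of slicing the second half of the input into `out`, applying sigmoid in place, and then multiplying by the first half (two extra passes plus a copy), the new kernel wraps the two halves of the input in internal views and computes `a * sigmoid(b)` in a single elementwise pass. A minimal standalone sketch of that computation, assuming contiguous float data split along the last dimension (the `glu_last_dim` helper is hypothetical; the real kernel handles arbitrary dims, dtypes, and dim orders through `TensorImpl` views):

#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical illustration: GLU over the last dimension of a contiguous
// [rows, 2 * half] buffer. The two halves are addressed as pointer-offset
// "views" into the input; nothing is copied before the elementwise pass.
std::vector<float> glu_last_dim(
    const std::vector<float>& in, std::size_t rows, std::size_t half) {
  std::vector<float> out(rows * half);
  for (std::size_t r = 0; r < rows; ++r) {
    const float* a = in.data() + r * 2 * half; // first-half view
    const float* b = a + half; // second-half view
    float* o = out.data() + r * half;
    for (std::size_t i = 0; i < half; ++i) {
      o[i] = a[i] * (1.0f / (1.0f + std::exp(-b[i]))); // a * sigmoid(b)
    }
  }
  return out;
}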

File tree: 5 files changed (+221, -113 lines)

kernels/portable/cpu/op_glu.cpp

Lines changed: 78 additions & 91 deletions
@@ -8,6 +8,7 @@

 #include <c10/util/irange.h>
 #include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
 #include <cinttypes>
@@ -23,92 +24,46 @@ using ScalarType = executorch::aten::ScalarType;

 namespace {

-double exp_overload(double d) {
-  return exp(d);
-}
-
-float exp_overload(float f) {
-  return expf(f);
-}
-
-/**
- * In-place element-wise sigmoid function , i.e., f(x) = 1 / (1 + e^{-x})
- */
-// TODO: T146333648, refactor this as a common helper function
-template <typename CTYPE_OUT>
-void sigmoid_tensor(Tensor& out) {
-  CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-  for (const auto i : c10::irange(out.numel())) {
-    out_data[i] = 1.0 / (1.0 + exp_overload(-out_data[i]));
-  }
-}
-
-/**
- * Element-wise multiplication of the first half of `in` along the specified
- * dimension and `out`, overwriting `out`.
- */
-template <typename CTYPE_IN, typename CTYPE_OUT>
-void mul_tensors(const Tensor& in, int64_t dim, Tensor& out) {
-  size_t num_values = static_cast<size_t>(in.size(dim)) / 2;
-  size_t dim_length_in = static_cast<size_t>(in.size(dim));
-  size_t dim_length_out = static_cast<size_t>(out.size(dim));
-  size_t leading_dims = getLeadingDims(in, dim);
-  size_t trailing_dims = getTrailingDims(in, dim);
-
-  const CTYPE_IN* input_data_base = in.const_data_ptr<CTYPE_IN>();
-  CTYPE_OUT* output_data_base = out.mutable_data_ptr<CTYPE_OUT>();
-
-  for (const auto i : c10::irange(leading_dims)) {
-    const CTYPE_IN* input_data =
-        input_data_base + i * dim_length_in * trailing_dims;
-    CTYPE_OUT* output_data =
-        output_data_base + i * dim_length_out * trailing_dims;
-    for ([[maybe_unused]] const auto j : c10::irange(num_values)) {
-      for (const auto k : c10::irange(trailing_dims)) {
-        output_data[k] = static_cast<CTYPE_OUT>(input_data[k]) * output_data[k];
-      }
-      input_data += trailing_dims;
-      output_data += trailing_dims;
-    }
+struct SplitGLUInputTensor {
+  explicit SplitGLUInputTensor(const Tensor& self, int64_t dim);
+  using SizesArray =
+      std::array<executorch::aten::SizesType, kTensorDimensionLimit>;
+  SizesArray half_sizes;
+  TensorImpl first_half_impl;
+  TensorImpl second_half_impl;
+  Tensor first_half;
+  Tensor second_half;
+
+ private:
+  static SizesArray get_half_sizes(const Tensor& self, int64_t dim) {
+    SizesArray half_sizes;
+    std::copy(self.sizes().begin(), self.sizes().end(), half_sizes.begin());
+    half_sizes[dim] /= 2;
+    return half_sizes;
   }
-}
-
-/**
- * Slice the tensor in the given dim, from start to end, assume tensor in and
- * out have same shape and dtype, the dim is a non-negative number and start,
- * end are valid non-negative number
- */
-template <typename CTYPE_IN, typename CTYPE_OUT>
-void slice_tensor(
-    const Tensor& in,
-    int64_t dim,
-    int64_t start,
-    int64_t end,
-    Tensor& out) {
-  size_t num_values = static_cast<size_t>(end - start);
-  size_t dim_length_in = static_cast<size_t>(in.size(dim));
-  size_t dim_length_out = static_cast<size_t>(out.size(dim));
-  size_t non_negative_start = static_cast<size_t>(start);
-  size_t leading_dims = getLeadingDims(in, dim);
-  size_t trailing_dims = getTrailingDims(in, dim);
-
-  const CTYPE_IN* input_data_base = in.const_data_ptr<CTYPE_IN>();
-  CTYPE_OUT* output_data_base = out.mutable_data_ptr<CTYPE_OUT>();
-
-  for (const auto i : c10::irange(leading_dims)) {
-    const CTYPE_IN* input_data = input_data_base +
-        (i * dim_length_in + non_negative_start) * trailing_dims;
-    CTYPE_OUT* output_data =
-        output_data_base + i * dim_length_out * trailing_dims;
-    for ([[maybe_unused]] const auto j : c10::irange(num_values)) {
-      for (const auto k : c10::irange(trailing_dims)) {
-        output_data[k] = static_cast<CTYPE_OUT>(input_data[k]);
-      }
-      input_data += trailing_dims;
-      output_data += trailing_dims;
-    }
-  }
-}
+};
+
+SplitGLUInputTensor::SplitGLUInputTensor(const Tensor& self, int64_t dim)
+    : half_sizes(get_half_sizes(self, dim)),
+      first_half_impl(
+          self.scalar_type(),
+          self.dim(),
+          half_sizes.data(),
+          self.mutable_data_ptr(),
+          const_cast<executorch::aten::DimOrderType*>(self.dim_order().data()),
+          const_cast<executorch::aten::StridesType*>(self.strides().data()),
+          self.shape_dynamism()),
+      second_half_impl(
+          self.scalar_type(),
+          self.dim(),
+          half_sizes.data(),
+          reinterpret_cast<char*>(self.mutable_data_ptr()) +
+              self.strides()[dim] * self.size(dim) / 2 * self.element_size(),
+          const_cast<executorch::aten::DimOrderType*>(self.dim_order().data()),
+          const_cast<executorch::aten::StridesType*>(self.strides().data()),
+          self.shape_dynamism()),
+      first_half(&first_half_impl),
+      second_half(&second_half_impl) {}

 /**
  * Applies the gated linear unit function
@@ -120,11 +75,43 @@ void slice_tensor(
  * 2. The output shall be in float types (Float, Double)
  */
 template <typename CTYPE_IN, typename CTYPE_OUT>
-Tensor& glu_out_tensor(const Tensor& self, int64_t dim, Tensor& out) {
-  const auto self_size = self.size(dim);
-  slice_tensor<CTYPE_IN, CTYPE_OUT>(self, dim, self_size / 2, self_size, out);
-  sigmoid_tensor<CTYPE_OUT>(out);
-  mul_tensors<CTYPE_IN, CTYPE_OUT>(self, dim, out);
+Tensor& glu_out_tensor(
+    KernelRuntimeContext& ctx,
+    const Tensor& self,
+    int64_t dim,
+    Tensor& out) {
+  ET_KERNEL_CHECK(
+      ctx,
+      self.dim() <= static_cast<ssize_t>(kTensorDimensionLimit),
+      InvalidArgument,
+      out);
+  SplitGLUInputTensor split_input(self, dim);
+  ScalarType compute_type =
+      executorch::runtime::isFloatingType(self.scalar_type())
+      ? self.scalar_type()
+      : ScalarType::Float;
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "glu.out";
+  ET_SWITCH_FLOATHBF16_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    utils::apply_bitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::FLOATHBF16>(
+        [](const auto val_a, const auto val_b) -> CTYPE_COMPUTE {
+          // TODO: rewrite this to be vectorization-capable? the
+          // tensors might not be contiguous; need to have
+          // apply_bitensor_elementwise_fn check that.
+          const auto one = static_cast<decltype(val_a)>(1.0);
+          return val_a * (one / (one + std::exp(-val_b)));
+        },
+        ctx,
+        split_input.first_half,
+        utils::SupportedTensorDtypes::FLOATHBF16,
+        split_input.second_half,
+        utils::SupportedTensorDtypes::FLOATHBF16,
+        out,
+        utils::internal::SupportNoncontiguousTensors());
+  });
   return out;
 }
 } // namespace
@@ -158,7 +145,7 @@ Tensor& glu_out(

   ET_SWITCH_FLOATHBF16_TYPES(in_dtype, ctx, "glu", CTYPE_IN, [&]() {
     ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, "glu", CTYPE_OUT, [&]() {
-      glu_out_tensor<CTYPE_IN, CTYPE_OUT>(self, non_negative_dim, out);
+      glu_out_tensor<CTYPE_IN, CTYPE_OUT>(ctx, self, non_negative_dim, out);
     });
   });

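For reference, the second-half view constructed above reuses the input's scalar type, dim order, and strides, halves the size along `dim`, and only offsets the data pointer by `strides()[dim] * size(dim) / 2 * element_size()` bytes. A small standalone sketch of that offset arithmetic for a concrete, hypothetical shape (plain C++, not ExecuTorch code):

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical contiguous float tensor of shape [4, 6], split along dim = 1.
  const int64_t sizes[] = {4, 6};
  const int64_t strides[] = {6, 1}; // row-major contiguous strides
  const int64_t dim = 1;
  const int64_t element_size = 4; // sizeof(float)

  // Same formula as in the SplitGLUInputTensor constructor:
  // strides[dim] * sizes[dim] / 2 * element_size.
  const int64_t byte_offset = strides[dim] * sizes[dim] / 2 * element_size;
  // Prints 12: the second-half view starts 12 bytes (3 floats) past the start
  // of the buffer, at element [0][3]; the shared strides then pick out each
  // row's second half.
  std::printf("%lld\n", static_cast<long long>(byte_offset));
  return 0;
}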
kernels/portable/cpu/util/broadcast_indexes_range.h

Lines changed: 16 additions & 6 deletions
@@ -43,7 +43,7 @@ inline bool sizes_match_ignoring_leading_1s(
       std::equal(lhs_begin, lhs_end, rhs_begin);
 }

-template <std::size_t kNumInputs>
+template <std::size_t kNumInputs, bool support_noncontiguous_tensors = false>
 class BroadcastIndexesIterator {
  public:
   using difference_type = ssize_t;
@@ -57,16 +57,20 @@ class BroadcastIndexesIterator {
   template <typename... Args>
   explicit BroadcastIndexesIterator(const Tensor& output, const Args&... args)
       : output_dim_or_zero_if_no_broadcasting_(
-            (sizes_match_ignoring_leading_1s(args.sizes(), output.sizes()) &&
-             ...)
+            !support_noncontiguous_tensors &&
+                    (sizes_match_ignoring_leading_1s(
+                         args.sizes(),
+                         output.sizes()) &&
+                     ...)
                 ? 0
                 : output.dim()),
         output_shape_(output.sizes()) {
     static_assert(
         sizeof...(args) == kNumInputs && (std::is_same_v<Args, Tensor> && ...),
         "BroadcastIndexesIterator constructor requires kNumInputs input tensor"
         "arguments!");
-    if (output_dim_or_zero_if_no_broadcasting_ != 0) {
+    if (support_noncontiguous_tensors ||
+        output_dim_or_zero_if_no_broadcasting_ != 0) {
       effective_input_broadcast_strides_ = {
           effective_input_broadcast_stride(output, args)...};
     }
@@ -249,11 +253,17 @@
  * Unlike looping using delinearize_index() and
  * linearize_access_indexes(), BroadcastIndexesRange avoids expensive
  * division and modulo operations on each iteration.
+ *
+ * The support_noncontiguous_tensors argument disables an optimization
+ * that causes the iterators not to respect strides in some
+ * cases. This optimization is normally safe because ExecuTorch
+ * tensors are contiguous.
  */
-template <std::size_t kNumInputs>
+template <std::size_t kNumInputs, bool support_noncontiguous_tensors = false>
 class BroadcastIndexesRange {
  public:
-  using iterator = internal::BroadcastIndexesIterator<kNumInputs>;
+  using iterator = internal::
+      BroadcastIndexesIterator<kNumInputs, support_noncontiguous_tensors>;

   template <typename... Args>
   BroadcastIndexesRange(const Tensor& output, const Args&... args)

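Why the new support_noncontiguous_tensors flag matters for this commit: the half views built in op_glu.cpp keep the original strides but halve the size along `dim`, so they are generally not contiguous, and the iterator's no-broadcasting shortcut (which walks the data linearly) would visit the wrong elements. A standalone sketch of how a stride-respecting walk and a flat linear walk diverge for such a view (hypothetical 2x4 input, plain C++, not the actual iterator code):

#include <cstdio>

int main() {
  // Hypothetical contiguous [2, 4] input buffer, split along dim = 1.
  // The first-half view has sizes {2, 2} but keeps the original strides {4, 1}.
  const int data[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  const int sizes[2] = {2, 2};
  const int strides[2] = {4, 1};

  // Stride-respecting walk over the view: prints 0 1 4 5, the view's elements.
  for (int i = 0; i < sizes[0]; ++i) {
    for (int j = 0; j < sizes[1]; ++j) {
      std::printf("%d ", data[i * strides[0] + j * strides[1]]);
    }
  }
  std::printf("\n");
  // A flat linear walk over the first sizes[0] * sizes[1] = 4 buffer entries
  // would visit 0 1 2 3 instead, mixing in second-half elements. That is the
  // shortcut the support_noncontiguous_tensors flag turns off.
  return 0;
}

The GLU kernel opts in by passing utils::internal::SupportNoncontiguousTensors() to apply_bitensor_elementwise_fn in op_glu.cpp, which appears to be how the flag reaches BroadcastIndexesRange.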