diff --git a/kernels/portable/cpu/op_glu.cpp b/kernels/portable/cpu/op_glu.cpp
index edc82c55eb8..be76a158182 100644
--- a/kernels/portable/cpu/op_glu.cpp
+++ b/kernels/portable/cpu/op_glu.cpp
@@ -8,6 +8,7 @@
 #include <c10/util/irange.h>
 #include <cmath>
 
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
@@ -23,93 +24,6 @@ using ScalarType = executorch::aten::ScalarType;
 
 namespace {
 
-double exp_overload(double d) {
-  return exp(d);
-}
-
-float exp_overload(float f) {
-  return expf(f);
-}
-
-/**
- * In-place element-wise sigmoid function , i.e., f(x) = 1 / (1 + e^{-x})
- */
-// TODO: T146333648, refactor this as a common helper function
-template <typename CTYPE_OUT>
-void sigmoid_tensor(Tensor& out) {
-  CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-  for (const auto i : c10::irange(out.numel())) {
-    out_data[i] = 1.0 / (1.0 + exp_overload(-out_data[i]));
-  }
-}
-
-/**
- * Element-wise multiplication of the first half of `in` along the specified
- * dimension and `out`, overwriting `out`.
- */
-template <typename CTYPE_IN, typename CTYPE_OUT>
-void mul_tensors(const Tensor& in, int64_t dim, Tensor& out) {
-  size_t num_values = static_cast<size_t>(in.size(dim)) / 2;
-  size_t dim_length_in = static_cast<size_t>(in.size(dim));
-  size_t dim_length_out = static_cast<size_t>(out.size(dim));
-  size_t leading_dims = getLeadingDims(in, dim);
-  size_t trailing_dims = getTrailingDims(in, dim);
-
-  const CTYPE_IN* input_data_base = in.const_data_ptr<CTYPE_IN>();
-  CTYPE_OUT* output_data_base = out.mutable_data_ptr<CTYPE_OUT>();
-
-  for (const auto i : c10::irange(leading_dims)) {
-    const CTYPE_IN* input_data =
-        input_data_base + i * dim_length_in * trailing_dims;
-    CTYPE_OUT* output_data =
-        output_data_base + i * dim_length_out * trailing_dims;
-    for ([[maybe_unused]] const auto j : c10::irange(num_values)) {
-      for (const auto k : c10::irange(trailing_dims)) {
-        output_data[k] = static_cast<CTYPE_OUT>(input_data[k]) * output_data[k];
-      }
-      input_data += trailing_dims;
-      output_data += trailing_dims;
-    }
-  }
-}
-
-/**
- * Slice the tensor in the given dim, from start to end, assume tensor in and
- * out have same shape and dtype, the dim is a non-negative number and start,
- * end are valid non-negative number
- */
-template <typename CTYPE_IN, typename CTYPE_OUT>
-void slice_tensor(
-    const Tensor& in,
-    int64_t dim,
-    int64_t start,
-    int64_t end,
-    Tensor& out) {
-  size_t num_values = static_cast<size_t>(end - start);
-  size_t dim_length_in = static_cast<size_t>(in.size(dim));
-  size_t dim_length_out = static_cast<size_t>(out.size(dim));
-  size_t non_negative_start = static_cast<size_t>(start);
-  size_t leading_dims = getLeadingDims(in, dim);
-  size_t trailing_dims = getTrailingDims(in, dim);
-
-  const CTYPE_IN* input_data_base = in.const_data_ptr<CTYPE_IN>();
-  CTYPE_OUT* output_data_base = out.mutable_data_ptr<CTYPE_OUT>();
-
-  for (const auto i : c10::irange(leading_dims)) {
-    const CTYPE_IN* input_data = input_data_base +
-        (i * dim_length_in + non_negative_start) * trailing_dims;
-    CTYPE_OUT* output_data =
-        output_data_base + i * dim_length_out * trailing_dims;
-    for ([[maybe_unused]] const auto j : c10::irange(num_values)) {
-      for (const auto k : c10::irange(trailing_dims)) {
-        output_data[k] = static_cast<CTYPE_OUT>(input_data[k]);
-      }
-      input_data += trailing_dims;
-      output_data += trailing_dims;
-    }
-  }
-}
-
 /**
  * Applies the gated linear unit function
  *
@@ -120,11 +34,63 @@ void slice_tensor(
  * 2. The output shall be in float types (Float, Double)
  */
 template <typename CTYPE_IN, typename CTYPE_OUT>
-Tensor& glu_out_tensor(const Tensor& self, int64_t dim, Tensor& out) {
+Tensor& glu_out_tensor(
+    KernelRuntimeContext& ctx,
+    const Tensor& self,
+    int64_t dim,
+    Tensor& out) {
   const auto self_size = self.size(dim);
-  slice_tensor<CTYPE_IN, CTYPE_OUT>(self, dim, self_size / 2, self_size, out);
-  sigmoid_tensor<CTYPE_OUT>(out);
-  mul_tensors<CTYPE_IN, CTYPE_OUT>(self, dim, out);
+  ET_KERNEL_CHECK(
+      ctx,
+      self.dim() <= static_cast<ssize_t>(kTensorDimensionLimit),
+      InvalidArgument,
+      out);
+  std::array<executorch::aten::SizesType, kTensorDimensionLimit> half_sizes;
+  std::copy(self.sizes().begin(), self.sizes().end(), half_sizes.begin());
+  half_sizes[dim] /= 2;
+  TensorImpl first_half_impl(
+      self.scalar_type(),
+      self.dim(),
+      half_sizes.data(),
+      self.mutable_data_ptr(),
+      const_cast<executorch::aten::DimOrderType*>(self.dim_order().data()),
+      const_cast<executorch::aten::StridesType*>(self.strides().data()),
+      self.shape_dynamism());
+  TensorImpl second_half_impl(
+      self.scalar_type(),
+      self.dim(),
+      half_sizes.data(),
+      reinterpret_cast<char*>(self.mutable_data_ptr()) +
+          self.strides()[dim] * self_size / 2 * self.element_size(),
+      const_cast<executorch::aten::DimOrderType*>(self.dim_order().data()),
+      const_cast<executorch::aten::StridesType*>(self.strides().data()),
+      self.shape_dynamism());
+  Tensor first_half(&first_half_impl);
+  Tensor second_half(&second_half_impl);
+  ScalarType compute_type =
+      executorch::runtime::isFloatingType(self.scalar_type())
+      ? self.scalar_type()
+      : ScalarType::Float;
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "glu.out";
+  ET_SWITCH_FLOATHBF16_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    utils::apply_bitensor_elementwise_fn<
+        CTYPE_COMPUTE,
+        op_name,
+        utils::SupportedTensorDtypes::FLOATHBF16>(
+        [](const auto val_a, const auto val_b) -> CTYPE_COMPUTE {
+          // TODO: rewrite this to be vectorization-capable.
+          const auto one = static_cast<decltype(val_a)>(1.0);
+          return val_a * (one / (one + std::exp(-val_b)));
+        },
+        ctx,
+        first_half,
+        utils::SupportedTensorDtypes::FLOATHBF16,
+        second_half,
+        utils::SupportedTensorDtypes::FLOATHBF16,
+        out,
+        utils::internal::SupportNoncontiguousTensors());
+  });
   return out;
 }
 } // namespace
@@ -158,7 +124,7 @@ Tensor& glu_out(
 
   ET_SWITCH_FLOATHBF16_TYPES(in_dtype, ctx, "glu", CTYPE_IN, [&]() {
     ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, "glu", CTYPE_OUT, [&]() {
-      glu_out_tensor<CTYPE_IN, CTYPE_OUT>(self, non_negative_dim, out);
+      glu_out_tensor<CTYPE_IN, CTYPE_OUT>(ctx, self, non_negative_dim, out);
     });
   });
 
diff --git a/kernels/portable/cpu/util/broadcast_indexes_range.h b/kernels/portable/cpu/util/broadcast_indexes_range.h
index 4d3ba46b51b..d372767819a 100644
--- a/kernels/portable/cpu/util/broadcast_indexes_range.h
+++ b/kernels/portable/cpu/util/broadcast_indexes_range.h
@@ -43,7 +43,7 @@ inline bool sizes_match_ignoring_leading_1s(
       std::equal(lhs_begin, lhs_end, rhs_begin);
 }
 
-template <std::size_t kNumInputs>
+template <std::size_t kNumInputs, bool support_noncontiguous_tensors = false>
 class BroadcastIndexesIterator {
  public:
   using difference_type = ssize_t;
@@ -57,8 +57,11 @@ class BroadcastIndexesIterator {
   template <typename... Args>
   explicit BroadcastIndexesIterator(const Tensor& output, const Args&... args)
       : output_dim_or_zero_if_no_broadcasting_(
-            (sizes_match_ignoring_leading_1s(args.sizes(), output.sizes()) &&
-             ...)
+            !support_noncontiguous_tensors &&
+                (sizes_match_ignoring_leading_1s(
+                     args.sizes(),
+                     output.sizes()) &&
+                 ...)
                 ? 0
                 : output.dim()),
         output_shape_(output.sizes()) {
@@ -66,7 +69,8 @@ class BroadcastIndexesIterator {
         sizeof...(args) == kNumInputs && (std::is_same_v<Args, Tensor> && ...),
         "BroadcastIndexesIterator constructor requires kNumInputs input tensor"
         "arguments!");
-    if (output_dim_or_zero_if_no_broadcasting_ != 0) {
+    if (support_noncontiguous_tensors ||
+        output_dim_or_zero_if_no_broadcasting_ != 0) {
       effective_input_broadcast_strides_ = {
           effective_input_broadcast_stride(output, args)...};
     }
@@ -249,11 +253,17 @@ class BroadcastIndexesIterator {
  * Unlike looping using delinearize_index() and
  * linearize_access_indexes(), BroadcastIndexesRange avoids expensive
  * division and modulo operations on each iteration.
+ *
+ * The support_noncontiguous_tensors argument disables an optimization
+ * that causes the iterators not to respect strides in some
+ * cases. This optimization is normally safe because ExecuTorch
+ * tensors are contiguous.
  */
-template <std::size_t kNumInputs>
+template <std::size_t kNumInputs, bool support_noncontiguous_tensors = false>
 class BroadcastIndexesRange {
  public:
-  using iterator = internal::BroadcastIndexesIterator<kNumInputs>;
+  using iterator = internal::
+      BroadcastIndexesIterator<kNumInputs, support_noncontiguous_tensors>;
 
   template <typename... Args>
   BroadcastIndexesRange(const Tensor& output, const Args&... args)
diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h
index e30b8af7d89..722483ec363 100644
--- a/kernels/portable/cpu/util/elementwise_util.h
+++ b/kernels/portable/cpu/util/elementwise_util.h
@@ -51,9 +51,19 @@ inline int64_t scalar_to(const Scalar& s) {
 }
 
 namespace internal {
+/**
+ * Causes these utility functions to make sure to respect Tensor
+ * strides; normally, this is not strictly necessary because ExecuTorch
+ * Tensors are contiguous.
+ */
+struct SupportNoncontiguousTensors {
+  explicit SupportNoncontiguousTensors() = default;
+};
+
 template <
     typename CTYPE_COMPUTE,
     typename CTYPE_OUT,
+    bool support_noncontiguous_tensors,
     typename Op,
     typename... Args>
 inline void dtype_specialized_elementwise_fn_impl(
@@ -75,7 +85,8 @@ inline void dtype_specialized_elementwise_fn_impl(
         CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
 
         const auto range =
-            BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...);
+            BroadcastIndexesRange<kNumInputs, support_noncontiguous_tensors>(
+                out, (*inputs.first)...);
         auto begin_it = range.begin();
         begin_it += begin;
         for (; (*begin_it)[0] < end; ++begin_it) {
@@ -117,6 +128,7 @@ inline bool validate_elementwise_fn_inputs(
 template <
     typename CTYPE_COMPUTE,
     const char* op_name,
+    bool support_noncontiguous_tensors,
     typename Op,
     typename... Args>
 inline void apply_elementwise_fn_generic_impl(
@@ -151,7 +163,8 @@ inline void apply_elementwise_fn_generic_impl(
       ::executorch::extension::internal::GRAIN_SIZE,
       [&](const auto begin, const auto end) {
         const auto range =
-            BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...);
+            BroadcastIndexesRange<kNumInputs, support_noncontiguous_tensors>(
+                out, (*inputs.first)...);
         auto begin_it = range.begin();
         begin_it += begin;
         for (; (*begin_it)[0] < end; ++begin_it) {
@@ -187,7 +200,10 @@ inline void apply_elementwise_fn_runtime_out_dtypes(
     return;
   }
 
-  apply_elementwise_fn_generic_impl<CTYPE_COMPUTE, op_name>(
+  apply_elementwise_fn_generic_impl<
+      CTYPE_COMPUTE,
+      op_name,
+      /*support_noncontiguous_tensors*/ false>(
       compute_fun, ctx, out, out_dtypes, inputs...);
 }
 
@@ -195,6 +211,7 @@ template <
     typename CTYPE_COMPUTE,
     const char* op_name,
     SupportedTensorDtypes out_dtypes,
+    bool support_noncontiguous_tensors,
     typename Op,
     typename... Args>
 inline void apply_elementwise_fn(
@@ -218,12 +235,17 @@ inline void apply_elementwise_fn(
       out.scalar_type() == out_specialized_scalar_type) {
     using CTYPE_OUT =
         typename ScalarTypeToCppType<out_specialized_scalar_type>::type;
-    dtype_specialized_elementwise_fn_impl<CTYPE_COMPUTE, CTYPE_OUT>(
-        compute_fun, ctx, out, inputs...);
+    dtype_specialized_elementwise_fn_impl<
+        CTYPE_COMPUTE,
+        CTYPE_OUT,
+        support_noncontiguous_tensors>(compute_fun, ctx, out, inputs...);
     return;
   }
 
-  apply_elementwise_fn_generic_impl<CTYPE_COMPUTE, op_name>(
+  apply_elementwise_fn_generic_impl<
+      CTYPE_COMPUTE,
+      op_name,
+      support_noncontiguous_tensors>(
       compute_fun, ctx, out, out_dtypes, inputs...);
 }
 
@@ -251,7 +273,31 @@ inline void apply_unitensor_elementwise_fn(
     const Tensor& a,
     SupportedTensorDtypes a_dtypes,
     const Tensor& out) {
-  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>(
+  internal::apply_elementwise_fn<
+      CTYPE_COMPUTE,
+      op_name,
+      out_dtypes,
+      /*support_noncontiguous_tensors*/ false>(
+      compute_fun, ctx, out, std::make_pair(&a, a_dtypes));
+}
+
+template <
+    typename CTYPE_COMPUTE,
+    const char* op_name,
+    SupportedTensorDtypes out_dtypes,
+    typename Op>
+inline void apply_unitensor_elementwise_fn(
+    const Op& compute_fun,
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    SupportedTensorDtypes a_dtypes,
+    const Tensor& out,
+    SupportNoncontiguousTensors) {
+  internal::apply_elementwise_fn<
+      CTYPE_COMPUTE,
+      op_name,
+      out_dtypes,
+      /*support_noncontiguous_tensors*/ true>(
       compute_fun, ctx, out, std::make_pair(&a, a_dtypes));
 }
 
@@ -295,7 +341,37 @@ inline void apply_bitensor_elementwise_fn(
     const Tensor& b,
     SupportedTensorDtypes b_dtypes,
     const Tensor& out) {
-  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>(
+  internal::apply_elementwise_fn<
+      CTYPE_COMPUTE,
+      op_name,
+      out_dtypes,
+      /*support_noncontiguous_tensors*/ false>(
+      compute_fun,
+      ctx,
+      out,
+      std::make_pair(&a, a_dtypes),
+      std::make_pair(&b, b_dtypes));
+}
+
+template <
+    typename CTYPE_COMPUTE,
+    const char* op_name,
+    SupportedTensorDtypes out_dtypes,
+    typename Op>
+inline void apply_bitensor_elementwise_fn(
+    const Op& compute_fun,
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    SupportedTensorDtypes a_dtypes,
+    const Tensor& b,
+    SupportedTensorDtypes b_dtypes,
+    const Tensor& out,
+    SupportNoncontiguousTensors) {
+  internal::apply_elementwise_fn<
+      CTYPE_COMPUTE,
+      op_name,
+      out_dtypes,
+      /*support_noncontiguous_tensors*/ true>(
       compute_fun,
       ctx,
       out,
@@ -363,7 +439,40 @@ inline void apply_tritensor_elementwise_fn(
     const Tensor& c,
     SupportedTensorDtypes c_dtypes,
     const Tensor& out) {
-  internal::apply_elementwise_fn<CTYPE_COMPUTE, op_name, out_dtypes>(
+  internal::apply_elementwise_fn<
+      CTYPE_COMPUTE,
+      op_name,
+      out_dtypes,
+      /*support_noncontiguous_tensors*/ false>(
+      compute_fun,
+      ctx,
+      out,
+      std::make_pair(&a, a_dtypes),
+      std::make_pair(&b, b_dtypes),
+      std::make_pair(&c, c_dtypes));
+}
+
+template <
+    typename CTYPE_COMPUTE,
+    const char* op_name,
+    SupportedTensorDtypes out_dtypes,
+    typename Op>
+inline void apply_tritensor_elementwise_fn(
+    const Op& compute_fun,
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    SupportedTensorDtypes a_dtypes,
+    const Tensor& b,
+    SupportedTensorDtypes b_dtypes,
+    const Tensor& c,
+    SupportedTensorDtypes c_dtypes,
+    const Tensor& out,
+    SupportNoncontiguousTensors) {
+  internal::apply_elementwise_fn<
+      CTYPE_COMPUTE,
+      op_name,
+      out_dtypes,
+      /*support_noncontiguous_tensors*/ true>(
       compute_fun,
       ctx,
       out,
diff --git a/kernels/test/op_glu_test.cpp b/kernels/test/op_glu_test.cpp
index b18117eaa4e..ac931302f98 100644
--- a/kernels/test/op_glu_test.cpp
+++ b/kernels/test/op_glu_test.cpp
@@ -28,9 +28,10 @@ class OpGluOutTest : public OperatorTest {
     return torch::executor::aten::glu_outf(context_, self, dim, out);
   }
 
-  template <ScalarType DTYPE>
+  template <ScalarType DTYPE, ScalarType OUT_DTYPE>
   void expect_tensor_close(Tensor actual, Tensor expected) {
-    if (DTYPE == ScalarType::Half || DTYPE == ScalarType::BFloat16) {
+    if (DTYPE == ScalarType::Half || DTYPE == ScalarType::BFloat16 ||
+        OUT_DTYPE == ScalarType::Half || OUT_DTYPE == ScalarType::BFloat16) {
       EXPECT_TENSOR_CLOSE_WITH_TOL(
           actual,
           expected,
@@ -54,14 +55,14 @@ class OpGluOutTest : public OperatorTest {
     Tensor in = tf.make(sizes, {0, 1, 2, 3, 4, 5, 6, 7});
     Tensor out = tf_out.zeros(out_sizes_1);
     op_glu_out(in, 0, out);
-    expect_tensor_close<DTYPE>(
+    expect_tensor_close<DTYPE, OUT_DTYPE>(
         out,
         tf_out.make(
            out_sizes_1, /*data=*/{0, 0.99330717, 1.99505484, 2.99726701}));
     const std::vector<int32_t> out_sizes_2 = {4, 1};
     out = tf_out.zeros(out_sizes_2);
     op_glu_out(in, 1, out);
-    expect_tensor_close<DTYPE>(
+    expect_tensor_close<DTYPE, OUT_DTYPE>(
         out,
         tf_out.make(
            out_sizes_2, /*data=*/{0, 1.90514827, 3.97322869, 5.99453402}));
diff --git a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
index a731ce5c674..96941590dd4 100644
--- a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -618,6 +618,7 @@ ATEN_OPS = (
         name = "op_glu",
         deps = [
             "//executorch/kernels/portable/cpu/util:activation_ops_util",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
            "//executorch/runtime/core/exec_aten/util:scalar_type_util",
            "//executorch/runtime/core/exec_aten/util:tensor_util",
        ],
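
Note for reviewers: the new glu_out_tensor builds two half-width views of the input and combines them elementwise as first_half * sigmoid(second_half), which is exactly the value the lambda above computes. The standalone C++ sketch below reproduces that arithmetic on plain row-major buffers for what appears to be the {4, 2} / dim = 1 case exercised in op_glu_test.cpp; glu_last_dim and its layout are illustrative assumptions only and are not part of the ExecuTorch API.

// Illustrative only: GLU along the last dimension of a row-major buffer,
// i.e. out[r][c] = in[r][c] * sigmoid(in[r][half + c]).
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

std::vector<float> glu_last_dim(
    const std::vector<float>& in, std::size_t rows, std::size_t cols) {
  const std::size_t half = cols / 2;  // cols is assumed to be even
  std::vector<float> out(rows * half);
  for (std::size_t r = 0; r < rows; ++r) {
    for (std::size_t c = 0; c < half; ++c) {
      const float a = in[r * cols + c];         // first half along the dim
      const float b = in[r * cols + half + c];  // second half along the dim
      out[r * half + c] = a * (1.0f / (1.0f + std::exp(-b)));
    }
  }
  return out;
}

int main() {
  // Same data as the test above, viewed as a 4x2 tensor and split on dim 1.
  const std::vector<float> in = {0, 1, 2, 3, 4, 5, 6, 7};
  for (float v : glu_last_dim(in, /*rows=*/4, /*cols=*/2)) {
    std::cout << v << '\n';  // ~0, 1.90514827, 3.97322869, 5.99453402
  }
  return 0;
}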