1
1
#include " functions.h"
2
2
3
3
#include < contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h>
4
+ #include < contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h>
5
+ #include < contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/codegen_internal.h>
6
+ #include < contrib/libs/apache/arrow/cpp/src/arrow/compute/registry_internal.h>
4
7
#include < contrib/libs/apache/arrow/cpp/src/arrow/table.h>
5
8
6
9
namespace NKikimr ::NArrow::NSSA {
10
+
11
+ namespace internal {
12
+
13
// Maps a primitive Arrow type to the accumulator type used when summing it.
// Mirrors Arrow's internal accumulator selection, except that FloatType is
// pinned to FloatType by the full specialization below, so float32 sums are
// accumulated in single precision instead of being widened to double.
template <typename I, typename Enable = void>
struct FindAccumulatorType {};

// bool -> UInt64: a sum of booleans is a count of true values.
template <typename I>
struct FindAccumulatorType<I, arrow::enable_if_boolean<I>> {
    using Type = arrow::UInt64Type;
};

// Signed integers (int8/16/32/64) accumulate into Int64.
template <typename I>
struct FindAccumulatorType<I, arrow::enable_if_signed_integer<I>> {
    using Type = arrow::Int64Type;
};

// Unsigned integers (uint8/16/32/64) accumulate into UInt64.
template <typename I>
struct FindAccumulatorType<I, arrow::enable_if_unsigned_integer<I>> {
    using Type = arrow::UInt64Type;
};

// Floating point accumulates into Double. For FloatType this partial
// specialization is overridden by the full specialization below, which takes
// precedence in overload resolution.
template <typename I>
struct FindAccumulatorType<I, arrow::enable_if_floating_point<I>> {
    using Type = arrow::DoubleType;
};

// Full specialization: keep float32 accumulation in float32. This is the
// behavioral difference the custom registry exists to provide (the default
// "sum" FLOAT kernel widens to double).
template <>
struct FindAccumulatorType<arrow::FloatType, void> {
    using Type = arrow::FloatType;
};
41
+
42
// Scalar "sum" aggregator over a single input column, patterned after Arrow's
// built-in sum implementation but using the local FindAccumulatorType, so
// float32 input is summed in float32. State (Count/Sum) accumulates across
// Consume/MergeFrom calls; Finalize emits the result scalar.
template <typename ArrowType, arrow::compute::SimdLevel::type SimdLevel>
struct SumImpl : public arrow::compute::ScalarAggregator {
    using ThisType = SumImpl<ArrowType, SimdLevel>;
    using CType = typename ArrowType::c_type;
    using SumType = typename FindAccumulatorType<ArrowType>::Type;
    using OutputType = typename arrow::TypeTraits<SumType>::ScalarType;

    // Folds one exec batch into Count/Sum. The batch's single value may be an
    // array or a scalar (a scalar logically repeats batch.length times).
    arrow::Status Consume(arrow::compute::KernelContext*, const arrow::compute::ExecBatch& batch) override {
        if (batch[0].is_array()) {
            const auto& data = batch[0].array();
            // Only non-null slots participate in the count.
            this->Count += data->length - data->GetNullCount();
            if (arrow::is_boolean_type<ArrowType>::value) {
                // For booleans the sum is the number of set bits.
                this->Sum +=
                    static_cast<typename SumType::c_type>(arrow::BooleanArray(data).true_count());
            } else {
                this->Sum +=
                    arrow::compute::detail::SumArray<CType, typename SumType::c_type, SimdLevel>(
                        *data);
            }
        } else {
            const auto& data = *batch[0].scalar();
            // A null scalar contributes nothing; a valid one counts
            // batch.length times (is_valid is 0 or 1 here).
            this->Count += data.is_valid * batch.length;
            if (data.is_valid) {
                this->Sum += arrow::compute::internal::UnboxScalar<ArrowType>::Unbox(data) * batch.length;
            }
        }
        return arrow::Status::OK();
    }

    // Merges a sibling aggregator's partial state into this one (used when the
    // aggregation is split across threads/chunks).
    arrow::Status MergeFrom(arrow::compute::KernelContext*, arrow::compute::KernelState&& src) override {
        const auto& other = arrow::checked_cast<const ThisType&>(src);
        this->Count += other.Count;
        this->Sum += other.Sum;
        return arrow::Status::OK();
    }

    // Produces the result scalar: null when fewer than Options.min_count
    // non-null values were consumed, otherwise the accumulated sum.
    arrow::Status Finalize(arrow::compute::KernelContext*, arrow::Datum* out) override {
        // NOTE(review): size_t Count is compared against int64_t min_count
        // (signed/unsigned mix); assumes min_count is non-negative — confirm.
        if (this->Count < Options.min_count) {
            // A default-constructed scalar is the null scalar of OutputType.
            out->value = std::make_shared<OutputType>();
        } else {
            out->value = arrow::MakeScalar(this->Sum);
        }
        return arrow::Status::OK();
    }

    size_t Count = 0;
    typename SumType::c_type Sum = 0;
    arrow::compute::ScalarAggregateOptions Options;
};
91
+
92
+ template <typename ArrowType>
93
+ struct SumImplDefault : public SumImpl <ArrowType, arrow::compute::SimdLevel::NONE> {
94
+ explicit SumImplDefault (const arrow::compute::ScalarAggregateOptions& options) {
95
+ this ->Options = options;
96
+ }
97
+ };
98
+
99
+ void AddScalarAggKernels (arrow::compute::KernelInit init,
100
+ const std::vector<std::shared_ptr<arrow::DataType>>& types,
101
+ std::shared_ptr<arrow::DataType> out_ty,
102
+ arrow::compute::ScalarAggregateFunction* func) {
103
+ for (const auto & ty : types) {
104
+ // scalar[InT] -> scalar[OutT]
105
+ auto sig = arrow::compute::KernelSignature::Make ({arrow::compute::InputType::Scalar (ty)}, arrow::ValueDescr::Scalar (out_ty));
106
+ AddAggKernel (std::move (sig), init, func, arrow::compute::SimdLevel::NONE);
107
+ }
108
+ }
109
+
110
// Registers both array->scalar kernels (via Arrow's basic aggregate helper)
// and scalar->scalar kernels (via AddScalarAggKernels) on `func` for the
// given input types. Only the array kernels honor `simd_level`.
void AddArrayScalarAggKernels(arrow::compute::KernelInit init,
                              const std::vector<std::shared_ptr<arrow::DataType>>& types,
                              std::shared_ptr<arrow::DataType> out_ty,
                              arrow::compute::ScalarAggregateFunction* func,
                              arrow::compute::SimdLevel::type simd_level = arrow::compute::SimdLevel::NONE) {
    arrow::compute::aggregate::AddBasicAggKernels(init, types, out_ty, func, simd_level);
    AddScalarAggKernels(init, types, out_ty, func);
}
118
+
119
+ arrow::Result<std::unique_ptr<arrow::compute::KernelState>> SumInit (arrow::compute::KernelContext* ctx,
120
+ const arrow::compute::KernelInitArgs& args) {
121
+ arrow::compute::aggregate::SumLikeInit<SumImplDefault> visitor (
122
+ ctx, *args.inputs [0 ].type ,
123
+ static_cast <const arrow::compute::ScalarAggregateOptions&>(*args.options ));
124
+ return visitor.Create ();
125
+ }
126
+
127
+ static std::unique_ptr<arrow::compute::FunctionRegistry> CreateCustomRegistry () {
128
+ arrow::compute::FunctionRegistry* defaultRegistry = arrow::compute::GetFunctionRegistry ();
129
+ auto registry = arrow::compute::FunctionRegistry::Make ();
130
+ for (const auto & func : defaultRegistry->GetFunctionNames ()) {
131
+ if (func == " sum" ) {
132
+ auto aggregateFunc = dynamic_cast <arrow::compute::ScalarAggregateFunction*>(defaultRegistry->GetFunction (func)->get ());
133
+ if (!aggregateFunc) {
134
+ DCHECK_OK (registry->AddFunction (*defaultRegistry->GetFunction (func)));
135
+ continue ;
136
+ }
137
+ arrow::compute::ScalarAggregateFunction newFunc (func, aggregateFunc->arity (), &aggregateFunc->doc (), aggregateFunc->default_options ());
138
+ for (const arrow::compute::ScalarAggregateKernel* kernel : aggregateFunc->kernels ()) {
139
+ auto shouldReplaceKernel = [](const arrow::compute::ScalarAggregateKernel& kernel) {
140
+ const auto & params = kernel.signature ->in_types ();
141
+ if (params.empty ()) {
142
+ return false ;
143
+ }
144
+
145
+ if (params[0 ].kind () == arrow::compute::InputType::Kind::EXACT_TYPE) {
146
+ auto type = params[0 ].type ();
147
+ return type->id () == arrow::Type::FLOAT;
148
+ }
149
+
150
+ return false ;
151
+ };
152
+
153
+ if (shouldReplaceKernel (*kernel)) {
154
+ AddArrayScalarAggKernels (SumInit, {arrow::float32 ()}, arrow::float32 (), &newFunc);
155
+ } else {
156
+ DCHECK_OK (newFunc.AddKernel (*kernel));
157
+ }
158
+ }
159
+ DCHECK_OK (registry->AddFunction (std::make_shared<arrow::compute::ScalarAggregateFunction>(std::move (newFunc))));
160
+ } else {
161
+ DCHECK_OK (registry->AddFunction (*defaultRegistry->GetFunction (func)));
162
+ }
163
+ }
164
+
165
+ return registry;
166
+ }
167
+ arrow::compute::FunctionRegistry* GetCustomFunctionRegistry () {
168
+ static auto registry = internal::CreateCustomRegistry ();
169
+ return registry.get ();
170
+ }
171
+
172
+ } // namespace internal
173
+
7
174
TConclusion<arrow::Datum> TInternalFunction::Call (
8
175
const TExecFunctionContext& context, const std::shared_ptr<TAccessorsCollection>& resources) const {
9
176
auto funcNames = GetRegistryFunctionNames ();
@@ -16,7 +183,8 @@ TConclusion<arrow::Datum> TInternalFunction::Call(
16
183
if (GetContext () && GetContext ()->func_registry ()->GetFunction (funcName).ok ()) {
17
184
result = arrow::compute::CallFunction (funcName, *arguments, FunctionOptions.get (), GetContext ());
18
185
} else {
19
- result = arrow::compute::CallFunction (funcName, *arguments, FunctionOptions.get ());
186
+ arrow::compute::ExecContext defaultContext (arrow::default_memory_pool (), nullptr , internal::GetCustomFunctionRegistry ());
187
+ result = arrow::compute::CallFunction (funcName, *arguments, FunctionOptions.get (), &defaultContext);
20
188
}
21
189
22
190
if (result.ok () && funcName == " count" sv) {
0 commit comments