Skip to content

Commit 0c186d0

Browse files
authored
float sum aggregation has been fixed (#19466)
1 parent ac19098 commit 0c186d0

File tree

3 files changed

+222
-1
lines changed

3 files changed

+222
-1
lines changed

ydb/core/formats/arrow/program/functions.cpp

Lines changed: 169 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,176 @@
11
#include "functions.h"
22

33
#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/api.h>
4+
#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h>
5+
#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/kernels/codegen_internal.h>
6+
#include <contrib/libs/apache/arrow/cpp/src/arrow/compute/registry_internal.h>
47
#include <contrib/libs/apache/arrow/cpp/src/arrow/table.h>
58

69
namespace NKikimr::NArrow::NSSA {
10+
11+
namespace internal {
12+
13+
// Find the largest compatible primitive type for a primitive type.
14+
template <typename I, typename Enable = void>
15+
struct FindAccumulatorType {};
16+
17+
template <typename I>
18+
struct FindAccumulatorType<I, arrow::enable_if_boolean<I>> {
19+
using Type = arrow::UInt64Type;
20+
};
21+
22+
template <typename I>
23+
struct FindAccumulatorType<I, arrow::enable_if_signed_integer<I>> {
24+
using Type = arrow::Int64Type;
25+
};
26+
27+
template <typename I>
28+
struct FindAccumulatorType<I, arrow::enable_if_unsigned_integer<I>> {
29+
using Type = arrow::UInt64Type;
30+
};
31+
32+
template <typename I>
33+
struct FindAccumulatorType<I, arrow::enable_if_floating_point<I>> {
34+
using Type = arrow::DoubleType;
35+
};
36+
37+
template <>
38+
struct FindAccumulatorType<arrow::FloatType, void> {
39+
using Type = arrow::FloatType;
40+
};
41+
42+
template <typename ArrowType, arrow::compute::SimdLevel::type SimdLevel>
43+
struct SumImpl : public arrow::compute::ScalarAggregator {
44+
using ThisType = SumImpl<ArrowType, SimdLevel>;
45+
using CType = typename ArrowType::c_type;
46+
using SumType = typename FindAccumulatorType<ArrowType>::Type;
47+
using OutputType = typename arrow::TypeTraits<SumType>::ScalarType;
48+
49+
arrow::Status Consume(arrow::compute::KernelContext*, const arrow::compute::ExecBatch& batch) override {
50+
if (batch[0].is_array()) {
51+
const auto& data = batch[0].array();
52+
this->Count += data->length - data->GetNullCount();
53+
if (arrow::is_boolean_type<ArrowType>::value) {
54+
this->Sum +=
55+
static_cast<typename SumType::c_type>(arrow::BooleanArray(data).true_count());
56+
} else {
57+
this->Sum +=
58+
arrow::compute::detail::SumArray<CType, typename SumType::c_type, SimdLevel>(
59+
*data);
60+
}
61+
} else {
62+
const auto& data = *batch[0].scalar();
63+
this->Count += data.is_valid * batch.length;
64+
if (data.is_valid) {
65+
this->Sum += arrow::compute::internal::UnboxScalar<ArrowType>::Unbox(data) * batch.length;
66+
}
67+
}
68+
return arrow::Status::OK();
69+
}
70+
71+
arrow::Status MergeFrom(arrow::compute::KernelContext*, arrow::compute::KernelState&& src) override {
72+
const auto& other = arrow::checked_cast<const ThisType&>(src);
73+
this->Count += other.Count;
74+
this->Sum += other.Sum;
75+
return arrow::Status::OK();
76+
}
77+
78+
arrow::Status Finalize(arrow::compute::KernelContext*, arrow::Datum* out) override {
79+
if (this->Count < Options.min_count) {
80+
out->value = std::make_shared<OutputType>();
81+
} else {
82+
out->value = arrow::MakeScalar(this->Sum);
83+
}
84+
return arrow::Status::OK();
85+
}
86+
87+
size_t Count = 0;
88+
typename SumType::c_type Sum = 0;
89+
arrow::compute::ScalarAggregateOptions Options;
90+
};
91+
92+
template <typename ArrowType>
93+
struct SumImplDefault : public SumImpl<ArrowType, arrow::compute::SimdLevel::NONE> {
94+
explicit SumImplDefault(const arrow::compute::ScalarAggregateOptions& options) {
95+
this->Options = options;
96+
}
97+
};
98+
99+
// Registers, for every type in `types`, a kernel on `func` with the signature
// scalar[type] -> scalar[out_ty], initialised by `init`, at SimdLevel::NONE.
void AddScalarAggKernels(arrow::compute::KernelInit init,
                         const std::vector<std::shared_ptr<arrow::DataType>>& types,
                         std::shared_ptr<arrow::DataType> out_ty,
                         arrow::compute::ScalarAggregateFunction* func) {
    // The output descriptor is the same for every input type.
    const auto outputDescr = arrow::ValueDescr::Scalar(out_ty);
    for (const auto& inputType : types) {
        AddAggKernel(
            arrow::compute::KernelSignature::Make({arrow::compute::InputType::Scalar(inputType)}, outputDescr),
            init, func, arrow::compute::SimdLevel::NONE);
    }
}
109+
110+
// Registers kernels for each type in `types` on `func`: first the kernels added by
// Arrow's AddBasicAggKernels (at `simd_level`), then the scalar-input kernels
// added by AddScalarAggKernels.
void AddArrayScalarAggKernels(
    arrow::compute::KernelInit init,
    const std::vector<std::shared_ptr<arrow::DataType>>& types,
    std::shared_ptr<arrow::DataType> out_ty,
    arrow::compute::ScalarAggregateFunction* func,
    arrow::compute::SimdLevel::type simd_level = arrow::compute::SimdLevel::NONE)
{
    arrow::compute::aggregate::AddBasicAggKernels(init, types, out_ty, func, simd_level);
    AddScalarAggKernels(init, types, out_ty, func);
}
118+
119+
// Kernel init callback: creates the SumImplDefault state matching the type of the
// first input, using the ScalarAggregateOptions carried in `args.options`.
arrow::Result<std::unique_ptr<arrow::compute::KernelState>> SumInit(arrow::compute::KernelContext* ctx,
                                                                     const arrow::compute::KernelInitArgs& args) {
    const auto& inputType = *args.inputs[0].type;
    const auto& aggOptions = static_cast<const arrow::compute::ScalarAggregateOptions&>(*args.options);
    return arrow::compute::aggregate::SumLikeInit<SumImplDefault>(ctx, inputType, aggOptions).Create();
}
126+
127+
static std::unique_ptr<arrow::compute::FunctionRegistry> CreateCustomRegistry() {
128+
arrow::compute::FunctionRegistry* defaultRegistry = arrow::compute::GetFunctionRegistry();
129+
auto registry = arrow::compute::FunctionRegistry::Make();
130+
for (const auto& func : defaultRegistry->GetFunctionNames()) {
131+
if (func == "sum") {
132+
auto aggregateFunc = dynamic_cast<arrow::compute::ScalarAggregateFunction*>(defaultRegistry->GetFunction(func)->get());
133+
if (!aggregateFunc) {
134+
DCHECK_OK(registry->AddFunction(*defaultRegistry->GetFunction(func)));
135+
continue;
136+
}
137+
arrow::compute::ScalarAggregateFunction newFunc(func, aggregateFunc->arity(), &aggregateFunc->doc(), aggregateFunc->default_options());
138+
for (const arrow::compute::ScalarAggregateKernel* kernel : aggregateFunc->kernels()) {
139+
auto shouldReplaceKernel = [](const arrow::compute::ScalarAggregateKernel& kernel) {
140+
const auto& params = kernel.signature->in_types();
141+
if (params.empty()) {
142+
return false;
143+
}
144+
145+
if (params[0].kind() == arrow::compute::InputType::Kind::EXACT_TYPE) {
146+
auto type = params[0].type();
147+
return type->id() == arrow::Type::FLOAT;
148+
}
149+
150+
return false;
151+
};
152+
153+
if (shouldReplaceKernel(*kernel)) {
154+
AddArrayScalarAggKernels(SumInit, {arrow::float32()}, arrow::float32(), &newFunc);
155+
} else {
156+
DCHECK_OK(newFunc.AddKernel(*kernel));
157+
}
158+
}
159+
DCHECK_OK(registry->AddFunction(std::make_shared<arrow::compute::ScalarAggregateFunction>(std::move(newFunc))));
160+
} else {
161+
DCHECK_OK(registry->AddFunction(*defaultRegistry->GetFunction(func)));
162+
}
163+
}
164+
165+
return registry;
166+
}
167+
// Returns the custom registry, created by CreateCustomRegistry() on the first call
// and reused afterwards. The returned pointer is non-owning.
arrow::compute::FunctionRegistry* GetCustomFunctionRegistry() {
    static const std::unique_ptr<arrow::compute::FunctionRegistry> customRegistry = internal::CreateCustomRegistry();
    return customRegistry.get();
}
171+
172+
} // namespace internal
173+
7174
TConclusion<arrow::Datum> TInternalFunction::Call(
8175
const TExecFunctionContext& context, const std::shared_ptr<TAccessorsCollection>& resources) const {
9176
auto funcNames = GetRegistryFunctionNames();
@@ -16,7 +183,8 @@ TConclusion<arrow::Datum> TInternalFunction::Call(
16183
if (GetContext() && GetContext()->func_registry()->GetFunction(funcName).ok()) {
17184
result = arrow::compute::CallFunction(funcName, *arguments, FunctionOptions.get(), GetContext());
18185
} else {
19-
result = arrow::compute::CallFunction(funcName, *arguments, FunctionOptions.get());
186+
arrow::compute::ExecContext defaultContext(arrow::default_memory_pool(), nullptr, internal::GetCustomFunctionRegistry());
187+
result = arrow::compute::CallFunction(funcName, *arguments, FunctionOptions.get(), &defaultContext);
20188
}
21189

22190
if (result.ok() && funcName == "count"sv) {

ydb/core/formats/arrow/program/ya.make

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,8 @@ GENERATE_ENUM_SERIALIZATION(execution.h)
5757

5858
YQL_LAST_ABI_VERSION()
5959

60+
CFLAGS(
61+
-Wno-unused-parameter
62+
)
63+
6064
END()

ydb/core/kqp/ut/olap/aggregations_ut.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1375,6 +1375,55 @@ Y_UNIT_TEST_SUITE(KqpOlapAggregations) {
13751375

13761376
TestTableWithNulls({ testCase }, /* generic */ true);
13771377
}
1378+
1379+
// Checks SUM over a Float column of a column-store table (STORE = COLUMN):
// creates the table, inserts five rows and compares the YSON-formatted result of
// a SUM query with the expected value.
// NOTE(review): the SELECT filters on `id = 1`, so the checked sum covers a single
// row (0.4f); the other four inserted rows do not contribute to the asserted value.
Y_UNIT_TEST(FloatSum) {
1380+
NKikimrConfig::TAppConfig appConfig;
1381+
appConfig.MutableTableServiceConfig()->SetEnableOlapSink(true);
1382+
auto settings = TKikimrSettings()
1383+
.SetAppConfig(appConfig)
1384+
.SetWithSampleTables(false);
1385+
TKikimrRunner kikimr(settings);
1386+
1387+
auto queryClient = kikimr.GetQueryClient();
1388+
// Step 1: create the column-store table with a nullable Float value column.
{
1389+
auto status = queryClient.ExecuteQuery(
1390+
R"(
1391+
CREATE TABLE `olap_table` (
1392+
id Uint64 NOT NULL,
1393+
value Float,
1394+
PRIMARY KEY (id)
1395+
) WITH (STORE = COLUMN);
1396+
)", NYdb::NQuery::TTxControl::NoTx()
1397+
).GetValueSync();
1398+
UNIT_ASSERT_C(status.IsSuccess(), status.GetIssues().ToString());
1399+
}
1400+
1401+
// Step 2: insert five rows with float values in a single committed transaction.
{
1402+
auto status = queryClient.ExecuteQuery(
1403+
R"(
1404+
INSERT INTO `olap_table` (id, value) VALUES (1u, 0.4f);
1405+
INSERT INTO `olap_table` (id, value) VALUES (2u, 0.85f);
1406+
INSERT INTO `olap_table` (id, value) VALUES (3u, 11.3f);
1407+
INSERT INTO `olap_table` (id, value) VALUES (4u, 7.15f);
1408+
INSERT INTO `olap_table` (id, value) VALUES (5u, 0.3f);
1409+
)", NYdb::NQuery::TTxControl::BeginTx().CommitTx()
1410+
).GetValueSync();
1411+
UNIT_ASSERT_C(status.IsSuccess(), status.GetIssues().ToString());
1412+
}
1413+
1414+
// Step 3: run the aggregation and compare the first result set against the expected YSON.
{
1415+
auto status = queryClient.ExecuteQuery(R"(
1416+
--!syntax_v1
1417+
SELECT SUM(value) FROM `olap_table`
1418+
WHERE id = 1
1419+
)", NYdb::NQuery::TTxControl::BeginTx().CommitTx()
1420+
).GetValueSync();
1421+
1422+
UNIT_ASSERT_C(status.IsSuccess(), status.GetIssues().ToString());
1423+
TString result = FormatResultSetYson(status.GetResultSet(0));
1424+
CompareYson(result, R"([[[0.400000006;]]])");
1425+
}
1426+
}
13781427
}
13791428

13801429
}

0 commit comments

Comments
 (0)