This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 0a59580aa4 [Enhancement](function) fix compatibility issues of
sum/count during upgrade process (#20890)
0a59580aa4 is described below
commit 0a59580aa47baee2236b3d26eb82cd764a882b75
Author: zhangstar333 <[email protected]>
AuthorDate: Sat Jun 17 12:51:01 2023 +0800
[Enhancement](function) fix compatibility issues of sum/count during
upgrade process (#20890)
This fixes an incompatibility of the sum/count aggregate functions during the
upgrade process.
PR #20370 ([refactor](agg_state) refactor agg_state type to support fixed
length object type) changed the serialized type and serialize column of
sum/count: before it was ColumnVector, now sum/count use
ColumnFixedLengthObject. As a result, old and newer BEs are not compatible
with each other while both exist during the upgrade.
---
be/src/agent/be_exec_version_manager.h | 2 +
.../aggregate_function_count_old.cpp | 51 +++++
.../aggregate_function_count_old.h | 248 +++++++++++++++++++++
.../aggregate_function_simple_factory.cpp | 4 +
.../aggregate_function_simple_factory.h | 28 ++-
.../aggregate_function_sum_old.cpp | 33 +++
.../aggregate_function_sum_old.h | 185 +++++++++++++++
be/src/vec/exprs/vectorized_agg_fn.cpp | 3 +-
8 files changed, 552 insertions(+), 2 deletions(-)
diff --git a/be/src/agent/be_exec_version_manager.h
b/be/src/agent/be_exec_version_manager.h
index 23c692bb86..0491a038c8 100644
--- a/be/src/agent/be_exec_version_manager.h
+++ b/be/src/agent/be_exec_version_manager.h
@@ -54,6 +54,8 @@ private:
* b. runtime filter use new hash method.
* 2: start from doris 2.0
* a. function month/day/hour/minute/second's return type is changed to
smaller type.
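+ * b. sum/count fall back to their old implementations (intermediate state
+ *    serialized as ColumnVector) when be_exec_version < 2, so old and new BEs
+ *    stay compatible during the upgrade process.
+ *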
*/
inline const int BeExecVersionManager::max_be_exec_version = 2;
inline const int BeExecVersionManager::min_be_exec_version = 0;
diff --git a/be/src/vec/aggregate_functions/aggregate_function_count_old.cpp
b/be/src/vec/aggregate_functions/aggregate_function_count_old.cpp
new file mode 100644
index 0000000000..c480c09a5a
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_count_old.cpp
@@ -0,0 +1,51 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+// This file is copied from
+//
https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionCount.cpp
+// and modified by Doris
+
+#include "vec/aggregate_functions/aggregate_function_count_old.h"
+
+#include <string>
+
+#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/aggregate_functions/factory_helpers.h"
+
+namespace doris::vectorized {
+
+/// Creator for the pre-2.0 `count` implementation (AggregateFunctionCountOld).
+/// At most one argument is accepted; `result_is_nullable` is not used here.
+AggregateFunctionPtr create_aggregate_function_count_old(const std::string& name,
+                                                         const DataTypes& argument_types,
+                                                         const bool result_is_nullable) {
+    // Reject calls with more than one argument.
+    assert_arity_at_most<1>(name, argument_types);
+
+    auto function = std::make_shared<AggregateFunctionCountOld>(argument_types);
+    return function;
+}
+
+/// Creator for the pre-2.0 `count` over a single nullable argument
+/// (AggregateFunctionCountNotNullUnaryOld), which counts only non-NULL rows.
+/// At most one argument is accepted; `result_is_nullable` is not used here.
+AggregateFunctionPtr create_aggregate_function_count_not_null_unary_old(
+        const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) {
+    // Reject calls with more than one argument.
+    assert_arity_at_most<1>(name, argument_types);
+
+    auto function = std::make_shared<AggregateFunctionCountNotNullUnaryOld>(argument_types);
+    return function;
+}
+
+/// Registers the pre-2.0 `count` implementations as alternatives of "count".
+/// The factory only redirects to them for queries whose be_exec_version is older than
+/// AGG_FUNCTION_NEW.
+void register_aggregate_function_count_old(AggregateFunctionSimpleFactory& factory) {
+    const std::string function_name = "count";
+    // Variant registered with nullable = false.
+    factory.register_alternative_function(function_name, create_aggregate_function_count_old,
+                                          false);
+    // Variant registered with nullable = true: counts non-NULL values only.
+    factory.register_alternative_function(function_name,
+                                          create_aggregate_function_count_not_null_unary_old, true);
+}
+
+} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_count_old.h
b/be/src/vec/aggregate_functions/aggregate_function_count_old.h
new file mode 100644
index 0000000000..ff935b6a67
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_count_old.h
@@ -0,0 +1,248 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+// This file is copied from
+// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionCount.h
+// and modified by Doris
+
+#pragma once
+
+#include <stddef.h>
+
+#include <algorithm>
+#include <boost/iterator/iterator_facade.hpp>
+#include <memory>
+#include <vector>
+
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/columns/column.h"
+#include "vec/columns/column_nullable.h"
+#include "vec/columns/column_vector.h"
+#include "vec/columns/columns_number.h"
+#include "vec/common/assert_cast.h"
+#include "vec/core/types.h"
+#include "vec/data_types/data_type.h"
+#include "vec/data_types/data_type_number.h"
+#include "vec/io/var_int.h"
+
+namespace doris {
+namespace vectorized {
+class Arena;
+class BufferReadable;
+class BufferWritable;
+} // namespace vectorized
+} // namespace doris
+
+/*
+ * Pre-2.0 versions of count, kept to fix the incompatibility of sum/count during the upgrade process.
+ * PR #20370 changed the serialized type and serialize column of sum/count:
+ * before it was ColumnVector, now sum/count use ColumnFixedLengthObject.
+ * Without these old versions, old and newer BEs are not compatible while both exist during the upgrade.
+ */
+
+namespace doris::vectorized {
+/// Aggregation state of the pre-2.0 count implementations: the running count.
+struct AggregateFunctionCountDataOld {
+    UInt64 count {0};
+};
+/// Pre-2.0 `count`: counts every row passed to add().
+/// Its intermediate state is serialized as a plain ColumnUInt64 (one count per row),
+/// the format used before PR #20370.
+class AggregateFunctionCountOld final
+        : public IAggregateFunctionDataHelper<AggregateFunctionCountDataOld,
+                                              AggregateFunctionCountOld> {
+public:
+    AggregateFunctionCountOld(const DataTypes& argument_types_)
+            : IAggregateFunctionDataHelper(argument_types_) {}
+
+    String get_name() const override { return "count"; }
+
+    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeInt64>(); }
+
+    void add(AggregateDataPtr __restrict place, const IColumn**, size_t, Arena*) const override {
+        ++data(place).count;
+    }
+
+    void reset(AggregateDataPtr place) const override { data(place).count = 0; }
+
+    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+               Arena*) const override {
+        data(place).count += data(rhs).count;
+    }
+
+    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
+        write_var_uint(data(place).count, buf);
+    }
+
+    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+                     Arena*) const override {
+        read_var_uint(data(place).count, buf);
+    }
+
+    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
+        assert_cast<ColumnInt64&>(to).get_data().push_back(data(place).count);
+    }
+
+    void deserialize_from_column(AggregateDataPtr places, const IColumn& column, Arena* arena,
+                                 size_t num_rows) const override {
+        // `places` is treated as num_rows consecutive states.
+        const auto* src = assert_cast<const ColumnUInt64&>(column).get_data().data();
+        auto* dst = reinterpret_cast<Data*>(places);
+        for (size_t i = 0; i != num_rows; ++i) {
+            dst[i].count = src[i];
+        }
+    }
+
+    void serialize_to_column(const std::vector<AggregateDataPtr>& places, size_t offset,
+                             MutableColumnPtr& dst, const size_t num_rows) const override {
+        auto& col = assert_cast<ColumnUInt64&>(*dst);
+        col.resize(num_rows);
+        auto* out = col.get_data().data();
+        for (size_t i = 0; i != num_rows; ++i) {
+            out[i] = data(places[i] + offset).count;
+        }
+    }
+
+    void streaming_agg_serialize_to_column(const IColumn** columns, MutableColumnPtr& dst,
+                                           const size_t num_rows, Arena* arena) const override {
+        // Every input row contributes a partial count of exactly one.
+        // assign(n, x) resizes the column to n before filling, so no separate resize() is
+        // needed; the literal is converted with static_cast (assert_cast is a checked cast
+        // for references/pointers, not a numeric conversion).
+        auto& col = assert_cast<ColumnUInt64&>(*dst);
+        col.get_data().assign(num_rows, static_cast<UInt64>(1));
+    }
+
+    void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, const IColumn& column,
+                                           Arena* arena) const override {
+        const auto* src = assert_cast<const ColumnUInt64&>(column).get_data().data();
+        const size_t num_rows = column.size();
+        auto& state = data(place);
+        for (size_t i = 0; i != num_rows; ++i) {
+            state.count += src[i];
+        }
+    }
+
+    void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
+                                         IColumn& to) const override {
+        // Single-row output. Write the UInt64 element directly instead of reinterpreting
+        // the column buffer as Data*.
+        auto& col = assert_cast<ColumnUInt64&>(to);
+        col.resize(1);
+        col.get_data()[0] = data(place).count;
+    }
+
+    MutableColumnPtr create_serialize_column() const override {
+        return ColumnVector<UInt64>::create();
+    }
+
+    DataTypePtr get_serialized_type() const override { return std::make_shared<DataTypeUInt64>(); }
+};
+
+// TODO: Maybe AggregateFunctionCountNotNullUnaryOld should be a subclass of AggregateFunctionCountOld
+/// Pre-2.0 `count` over a single nullable argument: counts only the rows whose argument
+/// is not NULL. Its intermediate state is serialized as a plain ColumnUInt64,
+/// the format used before PR #20370.
+class AggregateFunctionCountNotNullUnaryOld final
+        : public IAggregateFunctionDataHelper<AggregateFunctionCountDataOld,
+                                              AggregateFunctionCountNotNullUnaryOld> {
+public:
+    AggregateFunctionCountNotNullUnaryOld(const DataTypes& argument_types_)
+            : IAggregateFunctionDataHelper(argument_types_) {}
+
+    String get_name() const override { return "count"; }
+
+    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeInt64>(); }
+
+    void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
+             Arena*) const override {
+        data(place).count += !assert_cast<const ColumnNullable&>(*columns[0]).is_null_at(row_num);
+    }
+
+    void reset(AggregateDataPtr place) const override { data(place).count = 0; }
+
+    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+               Arena*) const override {
+        data(place).count += data(rhs).count;
+    }
+
+    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
+        write_var_uint(data(place).count, buf);
+    }
+
+    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+                     Arena*) const override {
+        read_var_uint(data(place).count, buf);
+    }
+
+    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
+        // The result column may be nullable; the value pushed is always marked non-NULL.
+        if (to.is_nullable()) {
+            auto& null_column = assert_cast<ColumnNullable&>(to);
+            null_column.get_null_map_data().push_back(0);
+            assert_cast<ColumnInt64&>(null_column.get_nested_column())
+                    .get_data()
+                    .push_back(data(place).count);
+        } else {
+            assert_cast<ColumnInt64&>(to).get_data().push_back(data(place).count);
+        }
+    }
+
+    void deserialize_from_column(AggregateDataPtr places, const IColumn& column, Arena* arena,
+                                 size_t num_rows) const override {
+        // `places` is treated as num_rows consecutive states.
+        const auto* src = assert_cast<const ColumnUInt64&>(column).get_data().data();
+        auto* dst = reinterpret_cast<Data*>(places);
+        for (size_t i = 0; i != num_rows; ++i) {
+            dst[i].count = src[i];
+        }
+    }
+
+    void serialize_to_column(const std::vector<AggregateDataPtr>& places, size_t offset,
+                             MutableColumnPtr& dst, const size_t num_rows) const override {
+        auto& col = assert_cast<ColumnUInt64&>(*dst);
+        col.resize(num_rows);
+        auto* out = col.get_data().data();
+        for (size_t i = 0; i != num_rows; ++i) {
+            out[i] = data(places[i] + offset).count;
+        }
+    }
+
+    void streaming_agg_serialize_to_column(const IColumn** columns, MutableColumnPtr& dst,
+                                           const size_t num_rows, Arena* arena) const override {
+        // Each input row becomes a partial count: 1 when its value is non-NULL, else 0.
+        auto& col = assert_cast<ColumnUInt64&>(*dst);
+        col.resize(num_rows);
+        auto& out = col.get_data();
+        const auto& input_col = assert_cast<const ColumnNullable&>(*columns[0]);
+        for (size_t i = 0; i < num_rows; i++) {
+            out[i] = !input_col.is_null_at(i);
+        }
+    }
+
+    void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, const IColumn& column,
+                                           Arena* arena) const override {
+        const auto* src = assert_cast<const ColumnUInt64&>(column).get_data().data();
+        const size_t num_rows = column.size();
+        auto& state = data(place);
+        for (size_t i = 0; i != num_rows; ++i) {
+            state.count += src[i];
+        }
+    }
+
+    void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
+                                         IColumn& to) const override {
+        // Single-row output. Write the UInt64 element directly instead of reinterpreting
+        // the column buffer as Data*.
+        auto& col = assert_cast<ColumnUInt64&>(to);
+        col.resize(1);
+        col.get_data()[0] = data(place).count;
+    }
+
+    MutableColumnPtr create_serialize_column() const override {
+        return ColumnVector<UInt64>::create();
+    }
+
+    DataTypePtr get_serialized_type() const override { return std::make_shared<DataTypeUInt64>(); }
+};
+
+} // namespace doris::vectorized
\ No newline at end of file
diff --git
a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
index 8c0ae92c07..5ab8eb874a 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
@@ -57,6 +57,8 @@ void
register_aggregate_function_collect_list(AggregateFunctionSimpleFactory& fa
void
register_aggregate_function_sequence_match(AggregateFunctionSimpleFactory&
factory);
void register_aggregate_function_avg_weighted(AggregateFunctionSimpleFactory&
factory);
void register_aggregate_function_histogram(AggregateFunctionSimpleFactory&
factory);
+void register_aggregate_function_count_old(AggregateFunctionSimpleFactory&
factory);
+void register_aggregate_function_sum_old(AggregateFunctionSimpleFactory&
factory);
AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
static std::once_flag oc;
@@ -68,6 +70,8 @@ AggregateFunctionSimpleFactory&
AggregateFunctionSimpleFactory::instance() {
register_aggregate_function_max_by(instance);
register_aggregate_function_avg(instance);
register_aggregate_function_count(instance);
+ register_aggregate_function_count_old(instance);
+ register_aggregate_function_sum_old(instance);
register_aggregate_function_uniq(instance);
register_aggregate_function_bit(instance);
register_aggregate_function_bitmap(instance);
diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
index bff49e9d9e..618340dd88 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
@@ -27,6 +27,7 @@
#include <utility>
#include <vector>
+#include "agent/be_exec_version_manager.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/data_types/data_type.h"
@@ -46,6 +47,11 @@ private:
AggregateFunctions aggregate_functions;
AggregateFunctions nullable_aggregate_functions;
std::unordered_map<std::string, std::string> function_alias;
+    /// @TEMPORARY: for be_exec_version < AGG_FUNCTION_NEW (2).
+    /// Keeps the old sum/count implementations selectable so that old and new BEs stay
+    /// compatible during the upgrade process.
+ constexpr static int AGG_FUNCTION_NEW = 2;
+ /// @TEMPORARY: for be_exec_version < AGG_FUNCTION_NEW. replace function
to old version.
+ std::unordered_map<std::string, std::string> function_to_replace;
public:
void register_nullable_function_combinator(const Creator& creator) {
@@ -73,7 +79,8 @@ public:
}
AggregateFunctionPtr get(const std::string& name, const DataTypes&
argument_types,
- const bool result_is_nullable = false) {
+ const bool result_is_nullable = false,
+ int be_version =
BeExecVersionManager::get_newest_version()) {
bool nullable = false;
for (const auto& type : argument_types) {
if (type->is_nullable()) {
@@ -82,6 +89,8 @@ public:
}
std::string name_str = name;
+ temporary_function_update(be_version, name_str);
+
if (function_alias.count(name)) {
name_str = function_alias[name];
}
@@ -116,6 +125,23 @@ public:
function_alias[alias] = name;
}
+    /// @TEMPORARY: for be_exec_version < AGG_FUNCTION_NEW
+    /// Registers `creator` under `name` plus a version suffix, and records that `name`
+    /// should be redirected to that suffixed entry by temporary_function_update().
+    void register_alternative_function(const std::string& name, const Creator& creator,
+                                       bool nullable = false) {
+        static const std::string old_version_suffix {"_old_for_version_before_2_0"};
+        const std::string alternative_name = name + old_version_suffix;
+        register_function(alternative_name, creator, nullable);
+        function_to_replace[name] = alternative_name;
+    }
+
+    /// @TEMPORARY: for be_exec_version < AGG_FUNCTION_NEW
+    /// If `be_exec_version` is older than AGG_FUNCTION_NEW and `name` has a registered
+    /// alternative, rewrites `name` in place to that alternative's registered name.
+    /// @param be_exec_version exec version the query runs with (get() passes its be_version)
+    /// @param name            function name; replaced in place when a redirect applies
+    void temporary_function_update(int be_exec_version, std::string& name) {
+        if (be_exec_version >= AGG_FUNCTION_NEW) {
+            return;
+        }
+        // Single hash lookup: reuse the iterator instead of find() followed by operator[].
+        if (auto it = function_to_replace.find(name); it != function_to_replace.end()) {
+            name = it->second;
+        }
+    }
+
public:
static AggregateFunctionSimpleFactory& instance();
};
diff --git a/be/src/vec/aggregate_functions/aggregate_function_sum_old.cpp
b/be/src/vec/aggregate_functions/aggregate_function_sum_old.cpp
new file mode 100644
index 0000000000..9b5d3cf85a
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_sum_old.cpp
@@ -0,0 +1,33 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+// This file is copied from
+//
https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionSum.cpp
+// and modified by Doris
+
+#include "vec/aggregate_functions/aggregate_function_sum_old.h"
+
+#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/aggregate_functions/helpers.h"
+
+namespace doris::vectorized {
+/// Registers the pre-2.0 `sum` (AggregateFunctionSumSimpleOld) as the alternative
+/// implementation of "sum", first with nullable = false, then with nullable = true.
+void register_aggregate_function_sum_old(AggregateFunctionSimpleFactory& factory) {
+    for (const bool nullable : {false, true}) {
+        factory.register_alternative_function(
+                "sum", creator_with_type::creator<AggregateFunctionSumSimpleOld>, nullable);
+    }
+}
+} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_sum_old.h
b/be/src/vec/aggregate_functions/aggregate_function_sum_old.h
new file mode 100644
index 0000000000..a3b7280fd2
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_sum_old.h
@@ -0,0 +1,185 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+// This file is copied from
+// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionSum.h
+// and modified by Doris
+
+#pragma once
+
+#include <stddef.h>
+
+#include <memory>
+#include <type_traits>
+#include <vector>
+
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/aggregate_function_sum.h"
+#include "vec/columns/column.h"
+#include "vec/common/assert_cast.h"
+#include "vec/core/field.h"
+#include "vec/core/types.h"
+#include "vec/data_types/data_type.h"
+#include "vec/data_types/data_type_decimal.h"
+#include "vec/io/io_helper.h"
+namespace doris {
+namespace vectorized {
+class Arena;
+class BufferReadable;
+class BufferWritable;
+template <typename T>
+class ColumnDecimal;
+template <typename T>
+class DataTypeNumber;
+template <typename>
+class ColumnVector;
+} // namespace vectorized
+} // namespace doris
+
+namespace doris::vectorized {
+
+/*
+ * Pre-2.0 version of sum, kept to fix the incompatibility of sum/count during the upgrade process.
+ * PR #20370 changed the serialized type and serialize column of sum/count:
+ * before it was ColumnVector, now sum/count use ColumnFixedLengthObject.
+ * Without this old version, old and newer BEs are not compatible while both exist during the upgrade.
+ */
+
+/// Pre-2.0 `sum`: its intermediate state is serialized as a plain column of the result
+/// type (ColVecResult), the format used before PR #20370.
+/// @tparam T       element type of the input column
+/// @tparam TResult accumulator / result element type
+/// @tparam Data    aggregation state; must expose `sum`, add(), merge(), write(), read(), get()
+template <typename T, typename TResult, typename Data>
+class AggregateFunctionOldSum final
+        : public IAggregateFunctionDataHelper<Data, AggregateFunctionOldSum<T, TResult, Data>> {
+public:
+    using ResultDataType = std::conditional_t<IsDecimalNumber<T>, DataTypeDecimal<TResult>,
+                                              DataTypeNumber<TResult>>;
+    using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
+    using ColVecResult =
+            std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<TResult>, ColumnVector<TResult>>;
+
+    String get_name() const override { return "sum"; }
+
+    AggregateFunctionOldSum(const DataTypes& argument_types_)
+            : IAggregateFunctionDataHelper<Data, AggregateFunctionOldSum<T, TResult, Data>>(
+                      argument_types_),
+              scale(get_decimal_scale(*argument_types_[0])) {}
+
+    DataTypePtr get_return_type() const override {
+        if constexpr (IsDecimalNumber<T>) {
+            // Decimal results use the result type's max precision and the argument's scale.
+            return std::make_shared<ResultDataType>(ResultDataType::max_precision(), scale);
+        } else {
+            return std::make_shared<ResultDataType>();
+        }
+    }
+
+    void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
+             Arena*) const override {
+        const auto& column = assert_cast<const ColVecType&>(*columns[0]);
+        this->data(place).add(column.get_data()[row_num]);
+    }
+
+    void reset(AggregateDataPtr place) const override { this->data(place).sum = {}; }
+
+    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+               Arena*) const override {
+        this->data(place).merge(this->data(rhs));
+    }
+
+    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
+        this->data(place).write(buf);
+    }
+
+    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+                     Arena*) const override {
+        this->data(place).read(buf);
+    }
+
+    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
+        auto& column = assert_cast<ColVecResult&>(to);
+        column.get_data().push_back(this->data(place).get());
+    }
+
+    void deserialize_from_column(AggregateDataPtr places, const IColumn& column, Arena* arena,
+                                 size_t num_rows) const override {
+        // `places` is treated as num_rows consecutive states.
+        const auto* src = assert_cast<const ColVecResult&>(column).get_data().data();
+        auto* dst = reinterpret_cast<Data*>(places);
+        for (size_t i = 0; i != num_rows; ++i) {
+            dst[i].sum = src[i];
+        }
+    }
+
+    void serialize_to_column(const std::vector<AggregateDataPtr>& places, size_t offset,
+                             MutableColumnPtr& dst, const size_t num_rows) const override {
+        auto& col = assert_cast<ColVecResult&>(*dst);
+        col.resize(num_rows);
+        auto* out = col.get_data().data();
+        for (size_t i = 0; i != num_rows; ++i) {
+            out[i] = this->data(places[i] + offset).sum;
+        }
+    }
+
+    void streaming_agg_serialize_to_column(const IColumn** columns, MutableColumnPtr& dst,
+                                           const size_t num_rows, Arena* arena) const override {
+        // Each input value becomes its own partial sum (converted from T to TResult).
+        auto& col = assert_cast<ColVecResult&>(*dst);
+        const auto& src = assert_cast<const ColVecType&>(*columns[0]);
+        col.resize(num_rows);
+        const auto* src_data = src.get_data().data();
+        auto* dst_data = col.get_data().data();
+        for (size_t i = 0; i != num_rows; ++i) {
+            dst_data[i] = src_data[i];
+        }
+    }
+
+    void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, const IColumn& column,
+                                           Arena* arena) const override {
+        const auto* src = assert_cast<const ColVecResult&>(column).get_data().data();
+        const size_t num_rows = column.size();
+        auto& state = this->data(place);
+        for (size_t i = 0; i != num_rows; ++i) {
+            state.sum += src[i];
+        }
+    }
+
+    void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
+                                         IColumn& to) const override {
+        // Single-row output. Write the TResult element directly instead of reinterpreting
+        // the column buffer as Data*.
+        auto& col = assert_cast<ColVecResult&>(to);
+        col.resize(1);
+        col.get_data()[0] = this->data(place).sum;
+    }
+
+    MutableColumnPtr create_serialize_column() const override {
+        return get_return_type()->create_column();
+    }
+
+    DataTypePtr get_serialized_type() const override { return get_return_type(); }
+
+private:
+    // Decimal scale of the first argument type (from get_decimal_scale).
+    UInt32 scale;
+};
+
+/// Bundles the result type, state type and function class of the pre-2.0 `sum`
+/// for an argument element type T.
+/// @tparam level_up when true, the result type is DisposeDecimal<T, NearestFieldType<T>>;
+///                  when false, T itself is kept as the result type.
+template <typename T, bool level_up>
+struct OldSumSimple {
+    using LeveledUpType = DisposeDecimal<T, NearestFieldType<T>>;
+    using ResultType = std::conditional_t<level_up, LeveledUpType, T>;
+    using AggregateDataType = AggregateFunctionSumData<ResultType>;
+    using Function = AggregateFunctionOldSum<T, ResultType, AggregateDataType>;
+};
+
+/// Pre-2.0 `sum` with a leveled-up result type; this is what is registered as the
+/// alternative "sum".
+template <typename T>
+using AggregateFunctionSumSimpleOld = typename OldSumSimple<T, /*level_up=*/true>::Function;
+
+/// Pre-2.0 `sum` that keeps T as the result type (no level-up), for the agg reader.
+template <typename T>
+using AggregateFunctionSumSimpleReaderOld = typename OldSumSimple<T, /*level_up=*/false>::Function;
+
+} // namespace doris::vectorized
diff --git a/be/src/vec/exprs/vectorized_agg_fn.cpp
b/be/src/vec/exprs/vectorized_agg_fn.cpp
index 40db922312..a6895ff728 100644
--- a/be/src/vec/exprs/vectorized_agg_fn.cpp
+++ b/be/src/vec/exprs/vectorized_agg_fn.cpp
@@ -188,7 +188,8 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const
RowDescriptor& desc,
}
} else {
_function = AggregateFunctionSimpleFactory::instance().get(
- _fn.name.function_name, argument_types,
_data_type->is_nullable());
+ _fn.name.function_name, argument_types,
_data_type->is_nullable(),
+ state->be_exec_version());
}
if (_function == nullptr) {
return Status::InternalError("Agg Function {} is not implemented",
_fn.signature);
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]