This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 0a59580aa4 [Enhancement](function) fix compatibility issues of 
sum/count during upgrade process (#20890)
0a59580aa4 is described below

commit 0a59580aa47baee2236b3d26eb82cd764a882b75
Author: zhangstar333 <[email protected]>
AuthorDate: Sat Jun 17 12:51:01 2023 +0800

    [Enhancement](function) fix compatibility issues of sum/count during 
upgrade process (#20890)
    
    This fixes the incompatibility of the sum/count aggregate functions during the upgrade process.
    PR #20370 ("[refactor](agg_state) refactor agg_state type to support fixed length object type") changed the serialized type and the serialized column of sum/count:
    previously they used ColumnVector; now sum/count use ColumnFixedLengthObject.
    As a result, during the upgrade process, old BEs and newer BEs that coexist in a cluster are not compatible with each other.
---
 be/src/agent/be_exec_version_manager.h             |   2 +
 .../aggregate_function_count_old.cpp               |  51 +++++
 .../aggregate_function_count_old.h                 | 248 +++++++++++++++++++++
 .../aggregate_function_simple_factory.cpp          |   4 +
 .../aggregate_function_simple_factory.h            |  28 ++-
 .../aggregate_function_sum_old.cpp                 |  33 +++
 .../aggregate_function_sum_old.h                   | 185 +++++++++++++++
 be/src/vec/exprs/vectorized_agg_fn.cpp             |   3 +-
 8 files changed, 552 insertions(+), 2 deletions(-)

diff --git a/be/src/agent/be_exec_version_manager.h 
b/be/src/agent/be_exec_version_manager.h
index 23c692bb86..0491a038c8 100644
--- a/be/src/agent/be_exec_version_manager.h
+++ b/be/src/agent/be_exec_version_manager.h
@@ -54,6 +54,8 @@ private:
  *    b. runtime filter use new hash method.
  * 2: start from doris 2.0
  *    a. function month/day/hour/minute/second's return type is changed to 
smaller type.
+ *    b. fix the incompatibility of the sum/count aggregates during the upgrade process
+ *
 */
 inline const int BeExecVersionManager::max_be_exec_version = 2;
 inline const int BeExecVersionManager::min_be_exec_version = 0;
diff --git a/be/src/vec/aggregate_functions/aggregate_function_count_old.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_count_old.cpp
new file mode 100644
index 0000000000..c480c09a5a
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_count_old.cpp
@@ -0,0 +1,51 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+// This file is copied from
+// 
https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionCount.cpp
+// and modified by Doris
+
+#include "vec/aggregate_functions/aggregate_function_count_old.h"
+
+#include <string>
+
+#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/aggregate_functions/factory_helpers.h"
+
+namespace doris::vectorized {
+
+AggregateFunctionPtr create_aggregate_function_count_old(const std::string& 
name,
+                                                         const DataTypes& 
argument_types,
+                                                         const bool 
result_is_nullable) {
+    assert_arity_at_most<1>(name, argument_types);
+
+    return std::make_shared<AggregateFunctionCountOld>(argument_types);
+}
+
+AggregateFunctionPtr create_aggregate_function_count_not_null_unary_old(
+        const std::string& name, const DataTypes& argument_types, const bool 
result_is_nullable) {
+    assert_arity_at_most<1>(name, argument_types);
+
+    return 
std::make_shared<AggregateFunctionCountNotNullUnaryOld>(argument_types);
+}
+
+void register_aggregate_function_count_old(AggregateFunctionSimpleFactory& 
factory) {
+    factory.register_alternative_function("count", 
create_aggregate_function_count_old, false);
+    factory.register_alternative_function("count",
+                                          
create_aggregate_function_count_not_null_unary_old, true);
+}
+
+} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_count_old.h 
b/be/src/vec/aggregate_functions/aggregate_function_count_old.h
new file mode 100644
index 0000000000..ff935b6a67
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_count_old.h
@@ -0,0 +1,248 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+// This file is copied from
+// 
https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionCount.h
+// and modified by Doris
+
+#pragma once
+
+#include <stddef.h>
+
+#include <algorithm>
+#include <boost/iterator/iterator_facade.hpp>
+#include <memory>
+#include <vector>
+
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/columns/column.h"
+#include "vec/columns/column_nullable.h"
+#include "vec/columns/column_vector.h"
+#include "vec/columns/columns_number.h"
+#include "vec/common/assert_cast.h"
+#include "vec/core/types.h"
+#include "vec/data_types/data_type.h"
+#include "vec/data_types/data_type_number.h"
+#include "vec/io/var_int.h"
+
+namespace doris {
+namespace vectorized {
+class Arena;
+class BufferReadable;
+class BufferWritable;
+} // namespace vectorized
+} // namespace doris
+
+/*
+ * The functions in this file keep the sum/count aggregates compatible during the upgrade process.
+ * PR #20370 changed the serialized type and the serialized column of sum/count:
+ * previously they used ColumnVector; now sum/count use ColumnFixedLengthObject.
+ * So during the upgrade process, old BEs and newer BEs that coexist would be incompatible.
+ */
+
+namespace doris::vectorized {
+struct AggregateFunctionCountDataOld {
+    UInt64 count = 0;
+};
+/// Simply count number of calls.
+class AggregateFunctionCountOld final
+        : public IAggregateFunctionDataHelper<AggregateFunctionCountDataOld,
+                                              AggregateFunctionCountOld> {
+public:
+    AggregateFunctionCountOld(const DataTypes& argument_types_)
+            : IAggregateFunctionDataHelper(argument_types_) {}
+
+    String get_name() const override { return "count"; }
+
+    DataTypePtr get_return_type() const override { return 
std::make_shared<DataTypeInt64>(); }
+
+    void add(AggregateDataPtr __restrict place, const IColumn**, size_t, 
Arena*) const override {
+        ++data(place).count;
+    }
+
+    void reset(AggregateDataPtr place) const override {
+        AggregateFunctionCountOld::data(place).count = 0;
+    }
+
+    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+               Arena*) const override {
+        data(place).count += data(rhs).count;
+    }
+
+    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& 
buf) const override {
+        write_var_uint(data(place).count, buf);
+    }
+
+    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+                     Arena*) const override {
+        read_var_uint(data(place).count, buf);
+    }
+
+    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& 
to) const override {
+        assert_cast<ColumnInt64&>(to).get_data().push_back(data(place).count);
+    }
+
+    void deserialize_from_column(AggregateDataPtr places, const IColumn& 
column, Arena* arena,
+                                 size_t num_rows) const override {
+        auto data = assert_cast<const ColumnUInt64&>(column).get_data().data();
+        auto* dst_data = reinterpret_cast<Data*>(places);
+        for (size_t i = 0; i != num_rows; ++i) {
+            dst_data[i].count = data[i];
+        }
+    }
+
+    void serialize_to_column(const std::vector<AggregateDataPtr>& places, 
size_t offset,
+                             MutableColumnPtr& dst, const size_t num_rows) 
const override {
+        auto& col = assert_cast<ColumnUInt64&>(*dst);
+        col.resize(num_rows);
+        auto* data = col.get_data().data();
+        for (size_t i = 0; i != num_rows; ++i) {
+            data[i] = AggregateFunctionCountOld::data(places[i] + 
offset).count;
+        }
+    }
+
+    void streaming_agg_serialize_to_column(const IColumn** columns, 
MutableColumnPtr& dst,
+                                           const size_t num_rows, Arena* 
arena) const override {
+        auto& col = assert_cast<ColumnUInt64&>(*dst);
+        col.resize(num_rows);
+        col.get_data().assign(num_rows, assert_cast<UInt64>(1UL));
+    }
+
+    void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, 
const IColumn& column,
+                                           Arena* arena) const override {
+        auto data = assert_cast<const ColumnUInt64&>(column).get_data().data();
+        const size_t num_rows = column.size();
+        for (size_t i = 0; i != num_rows; ++i) {
+            AggregateFunctionCountOld::data(place).count += data[i];
+        }
+    }
+
+    void serialize_without_key_to_column(ConstAggregateDataPtr __restrict 
place,
+                                         IColumn& to) const override {
+        auto& col = assert_cast<ColumnUInt64&>(to);
+        col.resize(1);
+        reinterpret_cast<Data*>(col.get_data().data())->count =
+                AggregateFunctionCountOld::data(place).count;
+    }
+
+    MutableColumnPtr create_serialize_column() const override {
+        return ColumnVector<UInt64>::create();
+    }
+
+    DataTypePtr get_serialized_type() const override { return 
std::make_shared<DataTypeUInt64>(); }
+};
+
+// TODO: Maybe AggregateFunctionCountNotNullUnaryOld should be a subclass of 
AggregateFunctionCountOld
+// Simply count number of not-NULL values.
+class AggregateFunctionCountNotNullUnaryOld final
+        : public IAggregateFunctionDataHelper<AggregateFunctionCountDataOld,
+                                              
AggregateFunctionCountNotNullUnaryOld> {
+public:
+    AggregateFunctionCountNotNullUnaryOld(const DataTypes& argument_types_)
+            : IAggregateFunctionDataHelper(argument_types_) {}
+
+    String get_name() const override { return "count"; }
+
+    DataTypePtr get_return_type() const override { return 
std::make_shared<DataTypeInt64>(); }
+
+    void add(AggregateDataPtr __restrict place, const IColumn** columns, 
size_t row_num,
+             Arena*) const override {
+        data(place).count += !assert_cast<const 
ColumnNullable&>(*columns[0]).is_null_at(row_num);
+    }
+
+    void reset(AggregateDataPtr place) const override { data(place).count = 0; 
}
+
+    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+               Arena*) const override {
+        data(place).count += data(rhs).count;
+    }
+
+    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& 
buf) const override {
+        write_var_uint(data(place).count, buf);
+    }
+
+    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+                     Arena*) const override {
+        read_var_uint(data(place).count, buf);
+    }
+
+    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& 
to) const override {
+        if (to.is_nullable()) {
+            auto& null_column = assert_cast<ColumnNullable&>(to);
+            null_column.get_null_map_data().push_back(0);
+            assert_cast<ColumnInt64&>(null_column.get_nested_column())
+                    .get_data()
+                    .push_back(data(place).count);
+        } else {
+            
assert_cast<ColumnInt64&>(to).get_data().push_back(data(place).count);
+        }
+    }
+
+    void deserialize_from_column(AggregateDataPtr places, const IColumn& 
column, Arena* arena,
+                                 size_t num_rows) const override {
+        auto data = assert_cast<const ColumnUInt64&>(column).get_data().data();
+        auto* dst_data = reinterpret_cast<Data*>(places);
+        for (size_t i = 0; i != num_rows; ++i) {
+            dst_data[i].count = data[i];
+        }
+    }
+
+    void serialize_to_column(const std::vector<AggregateDataPtr>& places, 
size_t offset,
+                             MutableColumnPtr& dst, const size_t num_rows) 
const override {
+        auto& col = assert_cast<ColumnUInt64&>(*dst);
+        col.resize(num_rows);
+        auto* data = col.get_data().data();
+        for (size_t i = 0; i != num_rows; ++i) {
+            data[i] = AggregateFunctionCountNotNullUnaryOld::data(places[i] + 
offset).count;
+        }
+    }
+
+    void streaming_agg_serialize_to_column(const IColumn** columns, 
MutableColumnPtr& dst,
+                                           const size_t num_rows, Arena* 
arena) const override {
+        auto& col = assert_cast<ColumnUInt64&>(*dst);
+        col.resize(num_rows);
+        auto& data = col.get_data();
+        const ColumnNullable& input_col = assert_cast<const 
ColumnNullable&>(*columns[0]);
+        for (size_t i = 0; i < num_rows; i++) {
+            data[i] = !input_col.is_null_at(i);
+        }
+    }
+
+    void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, 
const IColumn& column,
+                                           Arena* arena) const override {
+        auto data = assert_cast<const ColumnUInt64&>(column).get_data().data();
+        const size_t num_rows = column.size();
+        for (size_t i = 0; i != num_rows; ++i) {
+            AggregateFunctionCountNotNullUnaryOld::data(place).count += 
data[i];
+        }
+    }
+
+    void serialize_without_key_to_column(ConstAggregateDataPtr __restrict 
place,
+                                         IColumn& to) const override {
+        auto& col = assert_cast<ColumnUInt64&>(to);
+        col.resize(1);
+        reinterpret_cast<Data*>(col.get_data().data())->count =
+                AggregateFunctionCountNotNullUnaryOld::data(place).count;
+    }
+
+    MutableColumnPtr create_serialize_column() const override {
+        return ColumnVector<UInt64>::create();
+    }
+
+    DataTypePtr get_serialized_type() const override { return 
std::make_shared<DataTypeUInt64>(); }
+};
+
+} // namespace doris::vectorized
\ No newline at end of file
diff --git 
a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
index 8c0ae92c07..5ab8eb874a 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
@@ -57,6 +57,8 @@ void 
register_aggregate_function_collect_list(AggregateFunctionSimpleFactory& fa
 void 
register_aggregate_function_sequence_match(AggregateFunctionSimpleFactory& 
factory);
 void register_aggregate_function_avg_weighted(AggregateFunctionSimpleFactory& 
factory);
 void register_aggregate_function_histogram(AggregateFunctionSimpleFactory& 
factory);
+void register_aggregate_function_count_old(AggregateFunctionSimpleFactory& 
factory);
+void register_aggregate_function_sum_old(AggregateFunctionSimpleFactory& 
factory);
 
 AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
     static std::once_flag oc;
@@ -68,6 +70,8 @@ AggregateFunctionSimpleFactory& 
AggregateFunctionSimpleFactory::instance() {
         register_aggregate_function_max_by(instance);
         register_aggregate_function_avg(instance);
         register_aggregate_function_count(instance);
+        register_aggregate_function_count_old(instance);
+        register_aggregate_function_sum_old(instance);
         register_aggregate_function_uniq(instance);
         register_aggregate_function_bit(instance);
         register_aggregate_function_bitmap(instance);
diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h 
b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
index bff49e9d9e..618340dd88 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
@@ -27,6 +27,7 @@
 #include <utility>
 #include <vector>
 
+#include "agent/be_exec_version_manager.h"
 #include "vec/aggregate_functions/aggregate_function.h"
 #include "vec/data_types/data_type.h"
 
@@ -46,6 +47,11 @@ private:
     AggregateFunctions aggregate_functions;
     AggregateFunctions nullable_aggregate_functions;
     std::unordered_map<std::string, std::string> function_alias;
+    /// @TEMPORARY: for be_exec_version=2
+    /// keeps the sum/count aggregates compatible during the upgrade process
+    constexpr static int AGG_FUNCTION_NEW = 2;
+    /// @TEMPORARY: for be_exec_version < AGG_FUNCTION_NEW. Maps a function name to the name of its old-version replacement.
+    std::unordered_map<std::string, std::string> function_to_replace;
 
 public:
     void register_nullable_function_combinator(const Creator& creator) {
@@ -73,7 +79,8 @@ public:
     }
 
     AggregateFunctionPtr get(const std::string& name, const DataTypes& 
argument_types,
-                             const bool result_is_nullable = false) {
+                             const bool result_is_nullable = false,
+                             int be_version = 
BeExecVersionManager::get_newest_version()) {
         bool nullable = false;
         for (const auto& type : argument_types) {
             if (type->is_nullable()) {
@@ -82,6 +89,8 @@ public:
         }
 
         std::string name_str = name;
+        temporary_function_update(be_version, name_str);
+
         if (function_alias.count(name)) {
             name_str = function_alias[name];
         }
@@ -116,6 +125,23 @@ public:
         function_alias[alias] = name;
     }
 
+    /// @TEMPORARY: for be_exec_version < AGG_FUNCTION_NEW
+    void register_alternative_function(const std::string& name, const Creator& 
creator,
+                                       bool nullable = false) {
+        static std::string suffix {"_old_for_version_before_2_0"};
+        register_function(name + suffix, creator, nullable);
+        function_to_replace[name] = name + suffix;
+    }
+
+    /// @TEMPORARY: for be_exec_version < AGG_FUNCTION_NEW
+    void temporary_function_update(int fe_version_now, std::string& name) {
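+        // replace with the old version when the version passed in is below AGG_FUNCTION_NEW (the caller supplies be_exec_version, despite the parameter's name).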
+        if (fe_version_now < AGG_FUNCTION_NEW &&
+            function_to_replace.find(name) != function_to_replace.end()) {
+            name = function_to_replace[name];
+        }
+    }
+
 public:
     static AggregateFunctionSimpleFactory& instance();
 };
diff --git a/be/src/vec/aggregate_functions/aggregate_function_sum_old.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_sum_old.cpp
new file mode 100644
index 0000000000..9b5d3cf85a
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_sum_old.cpp
@@ -0,0 +1,33 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+// This file is copied from
+// 
https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionSum.cpp
+// and modified by Doris
+
+#include "vec/aggregate_functions/aggregate_function_sum_old.h"
+
+#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/aggregate_functions/helpers.h"
+
+namespace doris::vectorized {
+void register_aggregate_function_sum_old(AggregateFunctionSimpleFactory& 
factory) {
+    factory.register_alternative_function(
+            "sum", creator_with_type::creator<AggregateFunctionSumSimpleOld>, 
false);
+    factory.register_alternative_function(
+            "sum", creator_with_type::creator<AggregateFunctionSumSimpleOld>, 
true);
+}
+} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_sum_old.h 
b/be/src/vec/aggregate_functions/aggregate_function_sum_old.h
new file mode 100644
index 0000000000..a3b7280fd2
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_sum_old.h
@@ -0,0 +1,185 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+// This file is copied from
+// 
https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionSum.h
+// and modified by Doris
+
+#pragma once
+
+#include <stddef.h>
+
+#include <memory>
+#include <type_traits>
+#include <vector>
+
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/aggregate_function_sum.h"
+#include "vec/columns/column.h"
+#include "vec/common/assert_cast.h"
+#include "vec/core/field.h"
+#include "vec/core/types.h"
+#include "vec/data_types/data_type.h"
+#include "vec/data_types/data_type_decimal.h"
+#include "vec/io/io_helper.h"
+namespace doris {
+namespace vectorized {
+class Arena;
+class BufferReadable;
+class BufferWritable;
+template <typename T>
+class ColumnDecimal;
+template <typename T>
+class DataTypeNumber;
+template <typename>
+class ColumnVector;
+} // namespace vectorized
+} // namespace doris
+
+namespace doris::vectorized {
+
+/*
+ * This function keeps the sum/count aggregates compatible during the upgrade process.
+ * PR #20370 changed the serialized type and the serialized column of sum/count:
+ * previously they used ColumnVector; now sum/count use ColumnFixedLengthObject.
+ * So during the upgrade process, old BEs and newer BEs that coexist would be incompatible.
+ */
+
+template <typename T, typename TResult, typename Data>
+class AggregateFunctionOldSum final
+        : public IAggregateFunctionDataHelper<Data, AggregateFunctionOldSum<T, 
TResult, Data>> {
+public:
+    using ResultDataType = std::conditional_t<IsDecimalNumber<T>, 
DataTypeDecimal<TResult>,
+                                              DataTypeNumber<TResult>>;
+    using ColVecType = std::conditional_t<IsDecimalNumber<T>, 
ColumnDecimal<T>, ColumnVector<T>>;
+    using ColVecResult =
+            std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<TResult>, 
ColumnVector<TResult>>;
+
+    String get_name() const override { return "sum"; }
+
+    AggregateFunctionOldSum(const DataTypes& argument_types_)
+            : IAggregateFunctionDataHelper<Data, AggregateFunctionOldSum<T, 
TResult, Data>>(
+                      argument_types_),
+              scale(get_decimal_scale(*argument_types_[0])) {}
+
+    DataTypePtr get_return_type() const override {
+        if constexpr (IsDecimalNumber<T>) {
+            return 
std::make_shared<ResultDataType>(ResultDataType::max_precision(), scale);
+        } else {
+            return std::make_shared<ResultDataType>();
+        }
+    }
+
+    void add(AggregateDataPtr __restrict place, const IColumn** columns, 
size_t row_num,
+             Arena*) const override {
+        const auto& column = assert_cast<const ColVecType&>(*columns[0]);
+        this->data(place).add(column.get_data()[row_num]);
+    }
+
+    void reset(AggregateDataPtr place) const override { this->data(place).sum 
= {}; }
+
+    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+               Arena*) const override {
+        this->data(place).merge(this->data(rhs));
+    }
+
+    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& 
buf) const override {
+        this->data(place).write(buf);
+    }
+
+    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+                     Arena*) const override {
+        this->data(place).read(buf);
+    }
+
+    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& 
to) const override {
+        auto& column = assert_cast<ColVecResult&>(to);
+        column.get_data().push_back(this->data(place).get());
+    }
+
+    void deserialize_from_column(AggregateDataPtr places, const IColumn& 
column, Arena* arena,
+                                 size_t num_rows) const override {
+        auto data = assert_cast<const ColVecResult&>(column).get_data().data();
+        auto dst_data = reinterpret_cast<Data*>(places);
+        for (size_t i = 0; i != num_rows; ++i) {
+            dst_data[i].sum = data[i];
+        }
+    }
+
+    void serialize_to_column(const std::vector<AggregateDataPtr>& places, 
size_t offset,
+                             MutableColumnPtr& dst, const size_t num_rows) 
const override {
+        auto& col = assert_cast<ColVecResult&>(*dst);
+        col.resize(num_rows);
+        auto* data = col.get_data().data();
+        for (size_t i = 0; i != num_rows; ++i) {
+            data[i] = this->data(places[i] + offset).sum;
+        }
+    }
+
+    void streaming_agg_serialize_to_column(const IColumn** columns, 
MutableColumnPtr& dst,
+                                           const size_t num_rows, Arena* 
arena) const override {
+        auto& col = assert_cast<ColVecResult&>(*dst);
+        auto& src = assert_cast<const ColVecType&>(*columns[0]);
+        col.resize(num_rows);
+        auto* src_data = src.get_data().data();
+        auto* dst_data = col.get_data().data();
+        for (size_t i = 0; i != num_rows; ++i) {
+            dst_data[i] = src_data[i];
+        }
+    }
+
+    void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, 
const IColumn& column,
+                                           Arena* arena) const override {
+        auto data = assert_cast<const ColVecResult&>(column).get_data().data();
+        const size_t num_rows = column.size();
+        for (size_t i = 0; i != num_rows; ++i) {
+            this->data(place).sum += data[i];
+        }
+    }
+
+    void serialize_without_key_to_column(ConstAggregateDataPtr __restrict 
place,
+                                         IColumn& to) const override {
+        auto& col = assert_cast<ColVecResult&>(to);
+        col.resize(1);
+        reinterpret_cast<Data*>(col.get_data().data())->sum = 
this->data(place).sum;
+    }
+
+    MutableColumnPtr create_serialize_column() const override {
+        return get_return_type()->create_column();
+    }
+
+    DataTypePtr get_serialized_type() const override { return 
get_return_type(); }
+
+private:
+    UInt32 scale;
+};
+
+template <typename T, bool level_up>
+struct OldSumSimple {
+    /// @note It uses slow Decimal128 (cause we need such a variant). 
sumWithOverflow is faster for Decimal32/64
+    using ResultType = std::conditional_t<level_up, DisposeDecimal<T, 
NearestFieldType<T>>, T>;
+    using AggregateDataType = AggregateFunctionSumData<ResultType>;
+    using Function = AggregateFunctionOldSum<T, ResultType, AggregateDataType>;
+};
+
+template <typename T>
+using AggregateFunctionSumSimpleOld = typename OldSumSimple<T, true>::Function;
+
+// do not level up return type for agg reader
+template <typename T>
+using AggregateFunctionSumSimpleReaderOld = typename OldSumSimple<T, 
false>::Function;
+
+} // namespace doris::vectorized
diff --git a/be/src/vec/exprs/vectorized_agg_fn.cpp 
b/be/src/vec/exprs/vectorized_agg_fn.cpp
index 40db922312..a6895ff728 100644
--- a/be/src/vec/exprs/vectorized_agg_fn.cpp
+++ b/be/src/vec/exprs/vectorized_agg_fn.cpp
@@ -188,7 +188,8 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const 
RowDescriptor& desc,
         }
     } else {
         _function = AggregateFunctionSimpleFactory::instance().get(
-                _fn.name.function_name, argument_types, 
_data_type->is_nullable());
+                _fn.name.function_name, argument_types, 
_data_type->is_nullable(),
+                state->be_exec_version());
     }
     if (_function == nullptr) {
         return Status::InternalError("Agg Function {} is not implemented", 
_fn.signature);


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to