This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new 33e7941b6f1 branch-3.1: [behavior change](agg)  The array type returned by foreach is always array<nullable<T>> #52679 (#52704)
33e7941b6f1 is described below

commit 33e7941b6f1535379a004eafbfd58864c4abc768
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Tue Jul 8 14:12:56 2025 +0800

    branch-3.1: [behavior change](agg)  The array type returned by foreach is always array<nullable<T>> #52679 (#52704)
    
    Cherry-picked from #52679
    
    Co-authored-by: Mryange <[email protected]>
---
 .../aggregate_function_foreachv2.cpp               | 108 +++++++++++++++++++++
 .../aggregate_function_simple_factory.cpp          |   5 +-
 .../aggregate_function_simple_factory.h            |  10 +-
 be/src/vec/exprs/vectorized_agg_fn.cpp             |   6 +-
 .../java/org/apache/doris/catalog/Function.java    |   5 +-
 .../functions/combinator/ForEachCombinator.java    |   2 +-
 6 files changed, 130 insertions(+), 6 deletions(-)

diff --git a/be/src/vec/aggregate_functions/aggregate_function_foreachv2.cpp b/be/src/vec/aggregate_functions/aggregate_function_foreachv2.cpp
new file mode 100644
index 00000000000..fb7f48bf38e
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_foreachv2.cpp
@@ -0,0 +1,108 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+// This file is copied from
+// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/Combinators/AggregateFunctionForEach.cpp
+// and modified by Doris
+
+#include <memory>
+
+#include "common/logging.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/aggregate_function_foreach.h"
+#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/aggregate_functions/helpers.h"
+#include "vec/data_types/data_type_array.h"
+#include "vec/data_types/data_type_nullable.h"
+
+namespace doris::vectorized {
+#include "common/compile_check_begin.h"
+
+// The difference between AggregateFunctionForEachV2 and AggregateFunctionForEach is that its return value array is always an Array<Nullable<T>>.
+// For example, AggregateFunctionForEach's count_foreach([1,2,3]) returns Array<Int64>, which is not ideal
+// because we may have already assumed that the array's elements are always nullable types, and many places have such checks.
+// V1 code is kept to ensure compatibility during upgrades and downgrades.
+// V2 code differs from V1 only in the return type and insert_into logic; all other logic is exactly the same.
+class AggregateFunctionForEachV2 : public AggregateFunctionForEach {
+public:
+    constexpr static auto AGG_FOREACH_SUFFIX = "_foreachv2";
+    AggregateFunctionForEachV2(AggregateFunctionPtr nested_function_, const DataTypes& arguments)
+            : AggregateFunctionForEach(nested_function_, arguments) {}
+
+    DataTypePtr get_return_type() const override {
+        return std::make_shared<DataTypeArray>(make_nullable(nested_function->get_return_type()));
+    }
+
+    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
+        const AggregateFunctionForEachData& state = data(place);
+
+        auto& arr_to = assert_cast<ColumnArray&>(to);
+        auto& offsets_to = arr_to.get_offsets();
+        IColumn& elems_nullable = arr_to.get_data();
+
+        DCHECK(elems_nullable.is_nullable());
+        auto& elems_to = assert_cast<ColumnNullable&>(elems_nullable).get_nested_column();
+        auto& elements_null_map =
+                assert_cast<ColumnNullable&>(elems_nullable).get_null_map_column();
+
+        if (nested_function->get_return_type()->is_nullable()) {
+            char* nested_state = state.array_of_aggregate_datas;
+            for (size_t i = 0; i < state.dynamic_array_size; ++i) {
+                nested_function->insert_result_into(nested_state, elems_nullable);
+                nested_state += nested_size_of_data;
+            }
+        } else {
+            char* nested_state = state.array_of_aggregate_datas;
+            for (size_t i = 0; i < state.dynamic_array_size; ++i) {
+                nested_function->insert_result_into(nested_state, elems_to);
+                elements_null_map.insert_default(); // not null
+                nested_state += nested_size_of_data;
+            }
+        }
+        offsets_to.push_back(offsets_to.back() + state.dynamic_array_size);
+    }
+};
+
+void register_aggregate_function_combinator_foreachv2(AggregateFunctionSimpleFactory& factory) {
+    AggregateFunctionCreator creator =
+            [&](const std::string& name, const DataTypes& types, const bool result_is_nullable,
+                const AggregateFunctionAttr& attr) -> AggregateFunctionPtr {
+        const std::string& suffix = AggregateFunctionForEachV2::AGG_FOREACH_SUFFIX;
+        DataTypes transform_arguments;
+        for (const auto& t : types) {
+            auto item_type =
+                    assert_cast<const DataTypeArray*>(remove_nullable(t).get())->get_nested_type();
+            transform_arguments.push_back((item_type));
+        }
+        auto nested_function_name = name.substr(0, name.size() - suffix.size());
+        auto nested_function = factory.get(nested_function_name, transform_arguments, true,
+                                           BeExecVersionManager::get_newest_version(), attr);
+        if (!nested_function) {
+            throw Exception(
+                    ErrorCode::INTERNAL_ERROR,
+                    "The combiner did not find a foreach combiner function. nested function "
+                    "name {} , args {}",
+                    nested_function_name, types_name(types));
+        }
+        return creator_without_type::create<AggregateFunctionForEachV2>(types, result_is_nullable,
+                                                                        nested_function);
+    };
+    factory.register_foreach_function_combinator(
+            creator, AggregateFunctionForEachV2::AGG_FOREACH_SUFFIX, true);
+    factory.register_foreach_function_combinator(
+            creator, AggregateFunctionForEachV2::AGG_FOREACH_SUFFIX, false);
+}
+} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
index 7a1981263fc..c41fbfd5cab 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
@@ -28,6 +28,7 @@ namespace doris::vectorized {
 
 void register_aggregate_function_combinator_distinct(AggregateFunctionSimpleFactory& factory);
 void register_aggregate_function_combinator_foreach(AggregateFunctionSimpleFactory& factory);
+void register_aggregate_function_combinator_foreachv2(AggregateFunctionSimpleFactory& factory);
 
 void register_aggregate_function_sum(AggregateFunctionSimpleFactory& factory);
 void register_aggregate_function_sum0(AggregateFunctionSimpleFactory& factory);
@@ -121,11 +122,13 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
         register_aggregate_functions_corr_welford(instance);
         register_aggregate_function_covar_pop(instance);
         register_aggregate_function_covar_samp(instance);
-        register_aggregate_function_combinator_foreach(instance);
         register_aggregate_function_skewness(instance);
         register_aggregate_function_kurtosis(instance);
         register_aggregate_function_approx_top_k(instance);
         register_aggregate_function_approx_top_sum(instance);
+        // Register foreach and foreachv2 functions
+        register_aggregate_function_combinator_foreach(instance);
+        register_aggregate_function_combinator_foreachv2(instance);
     });
     return instance;
 }
diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
index 0a0ec6abe16..f32d60fddde 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
@@ -55,7 +55,7 @@ public:
 
 private:
     using AggregateFunctions = std::unordered_map<std::string, Creator>;
-    constexpr static std::string_view combiner_names[] = {"_foreach"};
+    constexpr static std::string_view combiner_names[] = {"_foreach", "_foreachv2"};
     AggregateFunctions aggregate_functions;
     AggregateFunctions nullable_aggregate_functions;
     std::unordered_map<std::string, std::string> function_alias;
@@ -69,6 +69,14 @@ public:
         return name.substr(name.length() - suffix.length()) == suffix;
     }
 
+    static bool is_foreachv2(const std::string& name) {
+        constexpr std::string_view suffix = "_foreachv2";
+        if (name.length() < suffix.length()) {
+            return false;
+        }
+        return name.substr(name.length() - suffix.length()) == suffix;
+    }
+
     static bool result_nullable_by_foreach(DataTypePtr& data_type) {
         // The return value of the 'foreach' function is 'null' or 'array<type>'.
         // The internal function's nullable should depend on whether 'type' is nullable
diff --git a/be/src/vec/exprs/vectorized_agg_fn.cpp b/be/src/vec/exprs/vectorized_agg_fn.cpp
index 8a3c8558abe..067ba3a7125 100644
--- a/be/src/vec/exprs/vectorized_agg_fn.cpp
+++ b/be/src/vec/exprs/vectorized_agg_fn.cpp
@@ -205,6 +205,7 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const RowDescriptor& desc,
                                          _fn.name.function_name);
         }
     } else {
+        // Here, only foreachv1 needs special treatment, and v2 can follow the normal code logic.
         if (AggregateFunctionSimpleFactory::is_foreach(_fn.name.function_name)) {
             _function = AggregateFunctionSimpleFactory::instance().get(
                     _fn.name.function_name, argument_types,
@@ -229,7 +230,10 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const RowDescriptor& desc,
                                                    _sort_description, state);
     }
 
-    if (!AggregateFunctionSimpleFactory::is_foreach(_fn.name.function_name)) {
+    // Foreachv2, like foreachv1, does not check the return type,
+    // because its return type is related to the internal agg.
+    if (!AggregateFunctionSimpleFactory::is_foreach(_fn.name.function_name) &&
+        !AggregateFunctionSimpleFactory::is_foreachv2(_fn.name.function_name)) {
         if (state->be_exec_version() >= BE_VERSION_THAT_SUPPORT_NULLABLE_CHECK) {
             RETURN_IF_ERROR(
                     _function->verify_result_type(_without_key, argument_types, _data_type));
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java
index 513ca12495d..b3b45655243 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java
@@ -922,13 +922,14 @@ public class Function implements Writable {
 
     public static FunctionCallExpr convertForEachCombinator(FunctionCallExpr fnCall) {
         Function aggFunction = fnCall.getFn();
-        aggFunction.setName(new FunctionName(aggFunction.getFunctionName().getFunction() + Expr.AGG_FOREACH_SUFFIX));
+        aggFunction.setName(new FunctionName(aggFunction.getFunctionName().getFunction()
+                + Expr.AGG_FOREACH_SUFFIX + "v2"));
         List<Type> argTypes = new ArrayList();
         for (Type type : aggFunction.argTypes) {
             argTypes.add(new ArrayType(type));
         }
         aggFunction.setArgs(argTypes);
-        aggFunction.setReturnType(new ArrayType(aggFunction.getReturnType(), fnCall.isNullable()));
+        aggFunction.setReturnType(new ArrayType(aggFunction.getReturnType(), true));
         aggFunction.setNullableMode(NullableMode.ALWAYS_NULLABLE);
         return fnCall;
     }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/ForEachCombinator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/ForEachCombinator.java
index ddd92f894e1..a4e50e14110 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/ForEachCombinator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/ForEachCombinator.java
@@ -109,7 +109,7 @@ public class ForEachCombinator extends NullableAggregateFunction
 
     @Override
     public DataType getDataType() {
-        return ArrayType.of(nested.getDataType(), nested.nullable());
+        return ArrayType.of(nested.getDataType(), true);
     }
 
     @Override


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to