This is an automated email from the ASF dual-hosted git repository.
lihaopeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 3c3973433e7 [feature](function) Support for aggregate function
foreach combiner (#31526)
3c3973433e7 is described below
commit 3c3973433e736b74023d849e20449b31e6f1881f
Author: Mryange <[email protected]>
AuthorDate: Wed Mar 6 10:22:05 2024 +0800
[feature](function) Support for aggregate function foreach combiner
(#31526)
---
.../vec/aggregate_functions/aggregate_function.h | 2 +-
.../aggregate_function_approx_count_distinct.h | 2 +-
.../aggregate_functions/aggregate_function_avg.h | 2 +-
.../aggregate_function_avg_weighted.h | 2 +-
.../aggregate_function_binary.h | 2 +-
.../aggregate_functions/aggregate_function_bit.h | 2 +-
.../aggregate_function_bitmap.h | 4 +-
.../aggregate_function_bitmap_agg.h | 2 +-
.../aggregate_function_collect.h | 2 +-
.../aggregate_functions/aggregate_function_count.h | 4 +-
.../aggregate_function_count_by_enum.h | 2 +-
.../aggregate_functions/aggregate_function_covar.h | 2 +-
.../aggregate_function_distinct.h | 2 +-
.../aggregate_function_foreach.cpp | 64 +++++
.../aggregate_function_foreach.h | 264 +++++++++++++++++++++
.../aggregate_function_group_concat.h | 2 +-
.../aggregate_function_histogram.h | 2 +-
.../aggregate_function_hll_union_agg.h | 2 +-
.../aggregate_function_java_udaf.h | 2 +-
.../aggregate_functions/aggregate_function_map.h | 2 +-
.../aggregate_function_min_max.h | 2 +-
.../aggregate_function_min_max_by.h | 2 +-
.../aggregate_functions/aggregate_function_null.h | 4 +-
.../aggregate_function_orthogonal_bitmap.h | 2 +-
.../aggregate_function_percentile_approx.h | 10 +-
.../aggregate_function_product.h | 2 +-
.../aggregate_function_quantile_state.h | 2 +-
.../aggregate_function_reader_first_last.h | 2 +-
.../aggregate_function_retention.h | 2 +-
.../aggregate_functions/aggregate_function_rpc.h | 2 +-
.../aggregate_function_sequence_match.h | 2 +-
.../aggregate_function_simple_factory.cpp | 3 +
.../aggregate_function_simple_factory.h | 26 +-
.../aggregate_functions/aggregate_function_sort.h | 2 +-
.../aggregate_function_state_union.h | 2 +-
.../aggregate_function_stddev.h | 2 +-
.../aggregate_functions/aggregate_function_sum.h | 2 +-
.../aggregate_functions/aggregate_function_topn.h | 2 +-
.../aggregate_functions/aggregate_function_uniq.h | 2 +-
.../aggregate_function_window.h | 14 +-
.../aggregate_function_window_funnel.h | 2 +-
.../data_types/serde/data_type_nullable_serde.cpp | 2 +-
.../main/java/org/apache/doris/analysis/Expr.java | 1 +
.../java/org/apache/doris/catalog/Function.java | 14 ++
.../org/apache/doris/catalog/FunctionRegistry.java | 10 +-
.../glue/translator/ExpressionTranslator.java | 11 +
...uilder.java => AggCombinerFunctionBuilder.java} | 34 ++-
...UnionCombinator.java => ForEachCombinator.java} | 38 +--
.../functions/combinator/MergeCombinator.java | 4 +-
.../functions/combinator/StateCombinator.java | 4 +-
.../functions/combinator/UnionCombinator.java | 4 +-
.../visitor/AggregateFunctionVisitor.java | 5 +
.../data/function_p0/test_agg_foreach.out | 28 +++
.../suites/function_p0/test_agg_foreach.groovy | 95 ++++++++
54 files changed, 624 insertions(+), 83 deletions(-)
diff --git a/be/src/vec/aggregate_functions/aggregate_function.h
b/be/src/vec/aggregate_functions/aggregate_function.h
index cc1b7d88f58..c24cd70ebea 100644
--- a/be/src/vec/aggregate_functions/aggregate_function.h
+++ b/be/src/vec/aggregate_functions/aggregate_function.h
@@ -106,7 +106,7 @@ public:
* row_num is number of row which should be added.
* Additional parameter arena should be used instead of standard memory
allocator if the addition requires memory allocation.
*/
- virtual void add(AggregateDataPtr __restrict place, const IColumn**
columns, size_t row_num,
+ virtual void add(AggregateDataPtr __restrict place, const IColumn**
columns, ssize_t row_num,
Arena* arena) const = 0;
virtual void add_many(AggregateDataPtr __restrict place, const IColumn**
columns,
diff --git
a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.h
b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.h
index 03e1cc3df13..d0f5bce81a0 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.h
@@ -95,7 +95,7 @@ public:
DataTypePtr get_return_type() const override { return
std::make_shared<DataTypeInt64>(); }
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
if constexpr (IsFixLenColumnType<ColumnDataType>::value) {
auto column = assert_cast<const ColumnDataType*>(columns[0]);
diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.h
b/be/src/vec/aggregate_functions/aggregate_function_avg.h
index 61eb04bb13b..ca155f9d72c 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_avg.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_avg.h
@@ -140,7 +140,7 @@ public:
}
}
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
#ifdef __clang__
#pragma clang fp reassociate(on)
diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.h
b/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.h
index fe6f50481ba..498ee20ccb8 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.h
@@ -106,7 +106,7 @@ public:
DataTypePtr get_return_type() const override { return
std::make_shared<DataTypeFloat64>(); }
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
const auto& column = assert_cast<const ColVecType&>(*columns[0]);
const auto& weight = assert_cast<const
ColumnVector<Float64>&>(*columns[1]);
diff --git a/be/src/vec/aggregate_functions/aggregate_function_binary.h
b/be/src/vec/aggregate_functions/aggregate_function_binary.h
index 422919c52af..ca06cc1bb81 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_binary.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_binary.h
@@ -69,7 +69,7 @@ struct AggregateFunctionBinary
bool allocates_memory_in_arena() const override { return false; }
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
this->data(place).add(
static_cast<ResultType>(
diff --git a/be/src/vec/aggregate_functions/aggregate_function_bit.h
b/be/src/vec/aggregate_functions/aggregate_function_bit.h
index 6d2e67b14e7..c0b2df85ba2 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_bit.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_bit.h
@@ -112,7 +112,7 @@ public:
DataTypePtr get_return_type() const override { return
std::make_shared<DataTypeNumber<T>>(); }
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
const auto& column = assert_cast<const ColumnVector<T>&>(*columns[0]);
this->data(place).add(column.get_data()[row_num]);
diff --git a/be/src/vec/aggregate_functions/aggregate_function_bitmap.h
b/be/src/vec/aggregate_functions/aggregate_function_bitmap.h
index aa167b6571c..e9973377697 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_bitmap.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_bitmap.h
@@ -301,7 +301,7 @@ public:
DataTypePtr get_return_type() const override { return
std::make_shared<DataTypeBitMap>(); }
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
const auto& column = assert_cast<const ColVecType&>(*columns[0]);
this->data(place).add(column.get_data()[row_num]);
@@ -361,7 +361,7 @@ public:
String get_name() const override { return "count"; }
DataTypePtr get_return_type() const override { return
std::make_shared<DataTypeInt64>(); }
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
if constexpr (arg_is_nullable) {
auto& nullable_column = assert_cast<const
ColumnNullable&>(*columns[0]);
diff --git a/be/src/vec/aggregate_functions/aggregate_function_bitmap_agg.h
b/be/src/vec/aggregate_functions/aggregate_function_bitmap_agg.h
index a4c08aefe2a..000a6dab36b 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_bitmap_agg.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_bitmap_agg.h
@@ -70,7 +70,7 @@ public:
std::string get_name() const override { return "bitmap_agg"; }
DataTypePtr get_return_type() const override { return
std::make_shared<DataTypeBitMap>(); }
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena* arena) const override {
DCHECK_LT(row_num, columns[0]->size());
if constexpr (arg_nullable) {
diff --git a/be/src/vec/aggregate_functions/aggregate_function_collect.h
b/be/src/vec/aggregate_functions/aggregate_function_collect.h
index 2188fe9b242..7e3c7207a7d 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_collect.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_collect.h
@@ -469,7 +469,7 @@ public:
bool allocates_memory_in_arena() const override { return ENABLE_ARENA; }
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena* arena) const override {
auto& data = this->data(place);
if constexpr (HasLimit::value) {
diff --git a/be/src/vec/aggregate_functions/aggregate_function_count.h
b/be/src/vec/aggregate_functions/aggregate_function_count.h
index 92d0f644d33..bf44b944bda 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_count.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_count.h
@@ -65,7 +65,7 @@ public:
DataTypePtr get_return_type() const override { return
std::make_shared<DataTypeInt64>(); }
- void add(AggregateDataPtr __restrict place, const IColumn**, size_t,
Arena*) const override {
+ void add(AggregateDataPtr __restrict place, const IColumn**, ssize_t,
Arena*) const override {
++data(place).count;
}
@@ -194,7 +194,7 @@ public:
DataTypePtr get_return_type() const override { return
std::make_shared<DataTypeInt64>(); }
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
data(place).count += !assert_cast<const
ColumnNullable&>(*columns[0]).is_null_at(row_num);
}
diff --git a/be/src/vec/aggregate_functions/aggregate_function_count_by_enum.h
b/be/src/vec/aggregate_functions/aggregate_function_count_by_enum.h
index 273fa2a1e4c..93a5103ef59 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_count_by_enum.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_count_by_enum.h
@@ -159,7 +159,7 @@ public:
DataTypePtr get_return_type() const override { return
std::make_shared<DataTypeString>(); }
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena* arena) const override {
for (int i = 0; i < arg_count; i++) {
const auto* nullable_column =
check_and_get_column<ColumnNullable>(columns[i]);
diff --git a/be/src/vec/aggregate_functions/aggregate_function_covar.h
b/be/src/vec/aggregate_functions/aggregate_function_covar.h
index 0c5dfd3f037..31f0d7d2830 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_covar.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_covar.h
@@ -273,7 +273,7 @@ public:
}
}
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
if constexpr (is_pop) {
this->data(place).add(columns[0], columns[1], row_num);
diff --git a/be/src/vec/aggregate_functions/aggregate_function_distinct.h
b/be/src/vec/aggregate_functions/aggregate_function_distinct.h
index 3b4968050ae..c0c7a5b66dd 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_distinct.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_distinct.h
@@ -201,7 +201,7 @@ public:
prefix_size = (sizeof(Data) + nested_size - 1) / nested_size *
nested_size;
}
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena* arena) const override {
this->data(place).add(columns, arguments_num, row_num, arena);
}
diff --git a/be/src/vec/aggregate_functions/aggregate_function_foreach.cpp
b/be/src/vec/aggregate_functions/aggregate_function_foreach.cpp
new file mode 100644
index 00000000000..e64e5900d01
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_foreach.cpp
@@ -0,0 +1,64 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+// This file is copied from
+//
https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/Combinators/AggregateFunctionForEach.cpp
+// and modified by Doris
+
+#include "vec/aggregate_functions/aggregate_function_foreach.h"
+
+#include <memory>
+#include <ostream>
+
+#include "common/logging.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/aggregate_functions/helpers.h"
+#include "vec/common/typeid_cast.h"
+#include "vec/data_types/data_type_array.h"
+#include "vec/data_types/data_type_nullable.h"
+
+namespace doris::vectorized {
+
+void register_aggregate_function_combinator_foreach(AggregateFunctionSimpleFactory& factory) {
+    // Builds `<nested>_foreach`: every argument must be a (possibly nullable) array. The
+    // nested aggregate is resolved against the element types of those arrays and then
+    // wrapped in AggregateFunctionForEach.
+    //
+    // The creator is handed to the factory and presumably invoked after this function
+    // returns, so it captures only `factory` itself (never stack locals).
+    AggregateFunctionCreator creator = [&factory](const std::string& name, const DataTypes& types,
+                                                  const bool result_is_nullable)
+            -> AggregateFunctionPtr {
+        const std::string& suffix = AggregateFunctionForEach::AGG_FOREACH_SUFFIX;
+
+        DataTypes transform_arguments;
+        transform_arguments.reserve(types.size());
+        for (const auto& t : types) {
+            // Validate explicitly instead of relying on assert_cast, which may not verify
+            // the type in every build: a non-array argument must yield a clear error.
+            const DataTypePtr non_nullable_type = remove_nullable(t);
+            const auto* array_type = typeid_cast<const DataTypeArray*>(non_nullable_type.get());
+            if (array_type == nullptr) {
+                throw Exception(ErrorCode::INVALID_ARGUMENT,
+                                "Aggregate function {} requires array arguments, but got {}", name,
+                                t->get_name());
+            }
+            transform_arguments.push_back(array_type->get_nested_type());
+        }
+
+        // Strip the "_foreach" suffix to obtain the nested aggregate's name.
+        auto nested_function_name = name.substr(0, name.size() - suffix.size());
+        auto nested_function =
+                factory.get(nested_function_name, transform_arguments, result_is_nullable);
+        if (!nested_function) {
+            throw Exception(
+                    ErrorCode::INTERNAL_ERROR,
+                    "The combiner did not find a foreach combiner function. nested function "
+                    "name {} , args {}",
+                    nested_function_name, types_name(types));
+        }
+        return creator_without_type::create<AggregateFunctionForEach>(transform_arguments, true,
+                                                                      nested_function);
+    };
+    // Register the same creator once for each value of the factory's boolean flag
+    // (presumably the nullable / non-nullable variants -- defined by the factory).
+    factory.register_foreach_function_combinator(
+            creator, AggregateFunctionForEach::AGG_FOREACH_SUFFIX, true);
+    factory.register_foreach_function_combinator(
+            creator, AggregateFunctionForEach::AGG_FOREACH_SUFFIX, false);
+}
+} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_foreach.h
b/be/src/vec/aggregate_functions/aggregate_function_foreach.h
new file mode 100644
index 00000000000..039c2d507b8
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_foreach.h
@@ -0,0 +1,264 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+// This file is copied from
+//
https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/Combinators/AggregateFunctionForEach.h
+// and modified by Doris
+
+#pragma once
+
+#include "common/logging.h"
+#include "common/status.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/columns/column_nullable.h"
+#include "vec/common/assert_cast.h"
+#include "vec/data_types/data_type_array.h"
+#include "vec/data_types/data_type_nullable.h"
+#include "vec/functions/array/function_array_utils.h"
+#include "vec/io/io_helper.h"
+
+namespace doris::vectorized {
+
+/// State of the -ForEach combinator for one group: a flat buffer that holds
+/// `dynamic_array_size` nested aggregate states stored back to back.
+/// A default-constructed value represents "no positions seen yet".
+struct AggregateFunctionForEachData {
+    size_t dynamic_array_size {0};            // number of nested states in the buffer
+    char* array_of_aggregate_datas {nullptr}; // start of the nested-state buffer
+};
+
+/** Adaptor for aggregate functions.
+ * Adding the -ForEach suffix (AGG_FOREACH_SUFFIX, i.e. "_foreach") to an aggregate
+ * function converts it into a function that accepts arrays, applies the aggregation
+ * to each group of corresponding array elements independently, and returns an array
+ * of the aggregated values at the corresponding positions.
+ *
+ * Example: sum_foreach of:
+ *  [1, 2],
+ *  [3, 4, 5],
+ *  [6, 7]
+ * will return:
+ *  [10, 13, 5]
+ *
+ * With several array arguments, add() requires the arrays of one row to have equal
+ * lengths and throws otherwise.
+ *
+ * TODO Allow variable number of arguments.
+ */
+class AggregateFunctionForEach : public IAggregateFunctionDataHelper<AggregateFunctionForEachData,
+                                                                     AggregateFunctionForEach> {
+protected:
+    using Base =
+            IAggregateFunctionDataHelper<AggregateFunctionForEachData, AggregateFunctionForEach>;
+
+    /// Aggregate function applied independently at every array position.
+    AggregateFunctionPtr nested_function;
+    /// nested_function->size_of_data(): the byte stride between consecutive nested states.
+    const size_t nested_size_of_data;
+    /// Number of arguments; add() treats every argument column as a ColumnArray.
+    const size_t num_arguments;
+
+    /// Makes sure `place` holds at least `new_size` nested states (it never shrinks) and
+    /// returns its data. Growing allocates a fresh buffer from `arena`.
+    ///
+    /// When reallocating, we can't copy the states to the new buffer with memcpy, because
+    /// they may contain pointers to themselves. In particular, this happens when a state
+    /// contains a PODArrayWithStackMemory, which stores a small number of elements inline.
+    /// This is why we create new empty states in the new buffer and merge the old states
+    /// into them.
+    ///
+    /// Exception safety: if creating or merging any nested state throws, every state
+    /// created in the new buffer is destroyed and `place` keeps its old, intact states.
+    /// @throws Exception if `new_size` is suspiciously large or the allocation overflows.
+    AggregateFunctionForEachData& ensure_aggregate_data(AggregateDataPtr __restrict place,
+                                                        size_t new_size, Arena& arena) const {
+        AggregateFunctionForEachData& state = data(place);
+
+        const size_t old_size = state.dynamic_array_size;
+        if (old_size >= new_size) {
+            return state;
+        }
+
+        static constexpr size_t MAX_ARRAY_SIZE = 100 * 1000000000ULL;
+        if (new_size > MAX_ARRAY_SIZE) {
+            throw Exception(ErrorCode::INTERNAL_ERROR,
+                            "Suspiciously large array size ({}) in -ForEach aggregate function",
+                            new_size);
+        }
+
+        size_t allocation_size = 0;
+        if (common::mul_overflow(new_size, nested_size_of_data, allocation_size)) {
+            throw Exception(ErrorCode::INTERNAL_ERROR,
+                            "Allocation size ({} * {}) overflows in -ForEach aggregate "
+                            "function, but it should've been prevented by previous checks",
+                            new_size, nested_size_of_data);
+        }
+
+        char* old_state = state.array_of_aggregate_datas;
+        char* new_state = arena.aligned_alloc(allocation_size, nested_function->align_of_data());
+
+        // Destroys the first `count` states of the new buffer (failure paths only).
+        auto destroy_new_states = [&](size_t count) noexcept {
+            for (size_t j = 0; j < count; ++j) {
+                nested_function->destroy(&new_state[j * nested_size_of_data]);
+            }
+        };
+
+        size_t created = 0;
+        try {
+            for (; created < new_size; ++created) {
+                nested_function->create(&new_state[created * nested_size_of_data]);
+            }
+        } catch (...) {
+            destroy_new_states(created);
+            throw;
+        }
+
+        // Merge every old state before destroying any of them: if a merge throws, the old
+        // buffer is still fully alive and `state` still points at it, so only the new
+        // buffer needs cleaning up (no leak, no later double destroy).
+        try {
+            for (size_t i = 0; i < old_size; ++i) {
+                nested_function->merge(&new_state[i * nested_size_of_data],
+                                       &old_state[i * nested_size_of_data], &arena);
+            }
+        } catch (...) {
+            destroy_new_states(new_size);
+            throw;
+        }
+
+        for (size_t i = 0; i < old_size; ++i) {
+            nested_function->destroy(&old_state[i * nested_size_of_data]);
+        }
+
+        state.array_of_aggregate_datas = new_state;
+        state.dynamic_array_size = new_size;
+        return state;
+    }
+
+public:
+    /// Name suffix selecting this combinator, e.g. "sum" + "_foreach".
+    constexpr static auto AGG_FOREACH_SUFFIX = "_foreach";
+
+    /// @param nested_function_ aggregate applied at every array position.
+    /// @param arguments argument types; must not be empty.
+    /// @throws Exception if `arguments` is empty.
+    AggregateFunctionForEach(AggregateFunctionPtr nested_function_, const DataTypes& arguments)
+            : Base(arguments),
+              nested_function {std::move(nested_function_)},
+              nested_size_of_data(nested_function->size_of_data()),
+              num_arguments(arguments.size()) {
+        if (arguments.empty()) {
+            throw Exception(ErrorCode::INTERNAL_ERROR,
+                            "Aggregate function {} require at least one argument", get_name());
+        }
+    }
+
+    /// Sets the version on this function and forwards it to the nested function.
+    void set_version(const int version_) override {
+        Base::set_version(version_);
+        nested_function->set_version(version_);
+    }
+
+    String get_name() const override { return nested_function->get_name() + AGG_FOREACH_SUFFIX; }
+
+    /// The result is an array of the nested function's result type.
+    DataTypePtr get_return_type() const override {
+        return std::make_shared<DataTypeArray>(nested_function->get_return_type());
+    }
+
+    /// Destroys every nested state; the buffer itself was obtained from an Arena.
+    void destroy(AggregateDataPtr __restrict place) const noexcept override {
+        AggregateFunctionForEachData& state = data(place);
+
+        char* nested_state = state.array_of_aggregate_datas;
+        for (size_t i = 0; i < state.dynamic_array_size; ++i) {
+            nested_function->destroy(nested_state);
+            nested_state += nested_size_of_data;
+        }
+    }
+
+    bool has_trivial_destructor() const override {
+        return std::is_trivially_destructible_v<Data> && nested_function->has_trivial_destructor();
+    }
+
+    /// Grows this state to at least rhs's length, then merges position by position.
+    /// Positions beyond rhs's length are left untouched.
+    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+               Arena* arena) const override {
+        const AggregateFunctionForEachData& rhs_state = data(rhs);
+        AggregateFunctionForEachData& state =
+                ensure_aggregate_data(place, rhs_state.dynamic_array_size, *arena);
+
+        const char* rhs_nested_state = rhs_state.array_of_aggregate_datas;
+        char* nested_state = state.array_of_aggregate_datas;
+
+        for (size_t i = 0; i < state.dynamic_array_size && i < rhs_state.dynamic_array_size; ++i) {
+            nested_function->merge(nested_state, rhs_nested_state, arena);
+
+            rhs_nested_state += nested_size_of_data;
+            nested_state += nested_size_of_data;
+        }
+    }
+
+    /// Format: dynamic_array_size, then every nested state's serialization in position order.
+    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
+        const AggregateFunctionForEachData& state = data(place);
+        write_binary(state.dynamic_array_size, buf);
+        const char* nested_state = state.array_of_aggregate_datas;
+        for (size_t i = 0; i < state.dynamic_array_size; ++i) {
+            nested_function->serialize(nested_state, buf);
+            nested_state += nested_size_of_data;
+        }
+    }
+
+    /// Reads the format written by serialize() into the first `new_size` nested states.
+    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+                     Arena* arena) const override {
+        size_t new_size = 0;
+        read_binary(new_size, buf);
+
+        AggregateFunctionForEachData& state = ensure_aggregate_data(place, new_size, *arena);
+
+        char* nested_state = state.array_of_aggregate_datas;
+        for (size_t i = 0; i < new_size; ++i) {
+            nested_function->deserialize(nested_state, buf, arena);
+            nested_state += nested_size_of_data;
+        }
+    }
+
+    /// Appends one array row to `to` (a ColumnArray): one nested result per position.
+    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
+        const AggregateFunctionForEachData& state = data(place);
+
+        auto& arr_to = assert_cast<ColumnArray&>(to);
+        auto& offsets_to = arr_to.get_offsets();
+        IColumn& elems_to = arr_to.get_data();
+
+        char* nested_state = state.array_of_aggregate_datas;
+        for (size_t i = 0; i < state.dynamic_array_size; ++i) {
+            nested_function->insert_result_into(nested_state, elems_to);
+            nested_state += nested_size_of_data;
+        }
+
+        offsets_to.push_back(offsets_to.back() + state.dynamic_array_size);
+    }
+
+    bool allocates_memory_in_arena() const override {
+        return nested_function->allocates_memory_in_arena();
+    }
+
+    bool is_state() const override { return nested_function->is_state(); }
+
+    /// Adds one row. For every position k of the row's arrays, element k of every
+    /// argument array is fed to the nested state at position k (the state grows as needed).
+    /// @throws Exception if the argument arrays of this row have different lengths.
+    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
+             Arena* arena) const override {
+        // Element columns of the array arguments; these are what the nested function sees.
+        const IColumn* nested[num_arguments];
+
+        for (size_t i = 0; i < num_arguments; ++i) {
+            nested[i] = &assert_cast<const ColumnArray&>(*columns[i]).get_data();
+        }
+
+        const auto& first_array_column = assert_cast<const ColumnArray&>(*columns[0]);
+        const auto& offsets = first_array_column.get_offsets();
+
+        // NOTE(review): for row_num == 0 this reads offsets[-1]. That relies on the offsets
+        // array permitting index -1 and yielding 0 (presumably why row_num is ssize_t).
+        size_t begin = offsets[row_num - 1];
+        size_t end = offsets[row_num];
+
+        /// Sanity check. NOTE We can implement specialization for a case with single argument,
+        /// if the check will hurt performance.
+        for (size_t i = 1; i < num_arguments; ++i) {
+            const auto& ith_column = assert_cast<const ColumnArray&>(*columns[i]);
+            const auto& ith_offsets = ith_column.get_offsets();
+
+            if (ith_offsets[row_num] != end ||
+                (row_num != 0 && ith_offsets[row_num - 1] != begin)) {
+                throw Exception(ErrorCode::INTERNAL_ERROR,
+                                "Arrays passed to {} aggregate function have different sizes",
+                                get_name());
+            }
+        }
+
+        AggregateFunctionForEachData& state = ensure_aggregate_data(place, end - begin, *arena);
+
+        char* nested_state = state.array_of_aggregate_datas;
+        for (size_t i = begin; i < end; ++i) {
+            nested_function->add(nested_state, nested, i, arena);
+            nested_state += nested_size_of_data;
+        }
+    }
+};
+} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_concat.h
b/be/src/vec/aggregate_functions/aggregate_function_group_concat.h
index 6438e65a20b..87ed907377e 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_group_concat.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_group_concat.h
@@ -124,7 +124,7 @@ public:
DataTypePtr get_return_type() const override { return
std::make_shared<DataTypeString>(); }
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
Impl::add(this->data(place), columns, row_num);
}
diff --git a/be/src/vec/aggregate_functions/aggregate_function_histogram.h
b/be/src/vec/aggregate_functions/aggregate_function_histogram.h
index 295a063bc30..cae2a88daf0 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_histogram.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_histogram.h
@@ -184,7 +184,7 @@ public:
DataTypePtr get_return_type() const override { return
std::make_shared<DataTypeString>(); }
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena* arena) const override {
if (columns[0]->is_null_at(row_num)) {
return;
diff --git a/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.h
b/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.h
index bb2ab75d6c5..f976e959f85 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.h
@@ -121,7 +121,7 @@ public:
this->data(place).insert_result_into(to);
}
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena* arena) const override {
this->data(place).add(columns[0], row_num);
}
diff --git a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h
b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h
index d79154a004c..4ef64aae558 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h
@@ -295,7 +295,7 @@ public:
DataTypePtr get_return_type() const override { return _return_type; }
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
int64_t places_address = reinterpret_cast<int64_t>(place);
Status st = this->data(_exec_place)
diff --git a/be/src/vec/aggregate_functions/aggregate_function_map.h
b/be/src/vec/aggregate_functions/aggregate_function_map.h
index a0617305830..e0a19a34207 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_map.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_map.h
@@ -172,7 +172,7 @@ public:
make_nullable(argument_types[1]));
}
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena* arena) const override {
if (columns[0]->is_nullable()) {
auto& nullable_col = assert_cast<const
ColumnNullable&>(*columns[0]);
diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max.h
b/be/src/vec/aggregate_functions/aggregate_function_min_max.h
index 56714c9ee80..dfc0cbae7f4 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_min_max.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_min_max.h
@@ -526,7 +526,7 @@ public:
DataTypePtr get_return_type() const override { return type; }
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena* arena) const override {
this->data(place).change_if_better(*columns[0], row_num, arena);
}
diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h
b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h
index b7a2f5c159d..634dc171f59 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h
@@ -167,7 +167,7 @@ public:
DataTypePtr get_return_type() const override { return value_type; }
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena* arena) const override {
this->data(place).change_if_better(*columns[0], *columns[1], row_num,
arena);
}
diff --git a/be/src/vec/aggregate_functions/aggregate_function_null.h
b/be/src/vec/aggregate_functions/aggregate_function_null.h
index 93939607382..a91a172fc05 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_null.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_null.h
@@ -200,7 +200,7 @@ public:
AggregateFunctionNullUnaryInline<NestFuction,
result_is_nullable>>(
nested_function_, arguments) {}
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena* arena) const override {
const ColumnNullable* column = assert_cast<const
ColumnNullable*>(columns[0]);
if (!column->is_null_at(row_num)) {
@@ -301,7 +301,7 @@ public:
}
}
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena* arena) const override {
/// This container stores the columns we really pass to the nested
function.
const IColumn* nested_columns[number_of_arguments];
diff --git
a/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.h
b/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.h
index 0204c08e020..f0fd67f4a85 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.h
@@ -344,7 +344,7 @@ public:
DataTypePtr get_return_type() const override { return
Impl::get_return_type(); }
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
this->data(place).init_add_key(columns, row_num, _argument_size);
this->data(place).add(columns, row_num);
diff --git
a/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h
b/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h
index 1a1285a9dc8..2eb7cc33098 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.h
@@ -200,7 +200,7 @@ class AggregateFunctionPercentileApproxMerge : public
AggregateFunctionPercentil
public:
AggregateFunctionPercentileApproxMerge(const DataTypes& argument_types_)
: AggregateFunctionPercentileApprox(argument_types_) {}
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
LOG(FATAL) << "AggregateFunctionPercentileApproxMerge do not support
add()";
}
@@ -211,7 +211,7 @@ class AggregateFunctionPercentileApproxTwoParams : public
AggregateFunctionPerce
public:
AggregateFunctionPercentileApproxTwoParams(const DataTypes&
argument_types_)
: AggregateFunctionPercentileApprox(argument_types_) {}
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
if constexpr (is_nullable) {
double column_data[2] = {0, 0};
@@ -251,7 +251,7 @@ class AggregateFunctionPercentileApproxThreeParams : public
AggregateFunctionPer
public:
AggregateFunctionPercentileApproxThreeParams(const DataTypes&
argument_types_)
: AggregateFunctionPercentileApprox(argument_types_) {}
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
if constexpr (is_nullable) {
double column_data[3] = {0, 0, 0};
@@ -386,7 +386,7 @@ public:
DataTypePtr get_return_type() const override { return
std::make_shared<DataTypeFloat64>(); }
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
const auto& sources = assert_cast<const
ColumnVector<Int64>&>(*columns[0]);
const auto& quantile = assert_cast<const
ColumnVector<Float64>&>(*columns[1]);
@@ -431,7 +431,7 @@ public:
return
std::make_shared<DataTypeArray>(make_nullable(std::make_shared<DataTypeFloat64>()));
}
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
const auto& sources = assert_cast<const
ColumnVector<Int64>&>(*columns[0]);
const auto& quantile_array = assert_cast<const
ColumnArray&>(*columns[1]);
diff --git a/be/src/vec/aggregate_functions/aggregate_function_product.h
b/be/src/vec/aggregate_functions/aggregate_function_product.h
index 4b0365a1b6d..22a217263b2 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_product.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_product.h
@@ -131,7 +131,7 @@ public:
}
}
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
const auto& column = assert_cast<const ColVecType&>(*columns[0]);
this->data(place).add(TResult(column.get_data()[row_num]), multiplier);
diff --git a/be/src/vec/aggregate_functions/aggregate_function_quantile_state.h
b/be/src/vec/aggregate_functions/aggregate_function_quantile_state.h
index 7c6a9cb9da0..14250087d2b 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_quantile_state.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_quantile_state.h
@@ -111,7 +111,7 @@ public:
return std::make_shared<DataTypeQuantileState>();
}
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
if constexpr (arg_is_nullable) {
auto& nullable_column = assert_cast<const
ColumnNullable&>(*columns[0]);
diff --git
a/be/src/vec/aggregate_functions/aggregate_function_reader_first_last.h
b/be/src/vec/aggregate_functions/aggregate_function_reader_first_last.h
index 9077a009a7a..bbf62b09222 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_reader_first_last.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_reader_first_last.h
@@ -215,7 +215,7 @@ public:
this->data(place).insert_result_into(to);
}
- void add(AggregateDataPtr place, const IColumn** columns, size_t row_num,
+ void add(AggregateDataPtr place, const IColumn** columns, ssize_t row_num,
Arena* arena) const override {
this->data(place).add(row_num, columns);
}
diff --git a/be/src/vec/aggregate_functions/aggregate_function_retention.h
b/be/src/vec/aggregate_functions/aggregate_function_retention.h
index f595a1ad726..f38f1cf45a0 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_retention.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_retention.h
@@ -124,7 +124,7 @@ public:
}
void reset(AggregateDataPtr __restrict place) const override {
this->data(place).reset(); }
- void add(AggregateDataPtr __restrict place, const IColumn** columns, const
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns, const
ssize_t row_num,
Arena*) const override {
for (int i = 0; i < get_argument_types().size(); i++) {
auto event = assert_cast<const
ColumnVector<UInt8>*>(columns[i])->get_data()[row_num];
diff --git a/be/src/vec/aggregate_functions/aggregate_function_rpc.h
b/be/src/vec/aggregate_functions/aggregate_function_rpc.h
index 21e5aa290d0..c92e96aaf9d 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_rpc.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_rpc.h
@@ -357,7 +357,7 @@ public:
DataTypePtr get_return_type() const override { return _return_type; }
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
static_cast<void>(
this->data(place).buffer_add(columns, row_num, row_num + 1,
argument_types));
diff --git a/be/src/vec/aggregate_functions/aggregate_function_sequence_match.h
b/be/src/vec/aggregate_functions/aggregate_function_sequence_match.h
index 064c1e9979a..101c2c16fd0 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_sequence_match.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_sequence_match.h
@@ -599,7 +599,7 @@ public:
void reset(AggregateDataPtr __restrict place) const override {
this->data(place).reset(); }
- void add(AggregateDataPtr __restrict place, const IColumn** columns, const
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns, const
ssize_t row_num,
Arena*) const override {
std::string pattern =
assert_cast<const
ColumnString*>(columns[0])->get_data_at(0).to_string();
diff --git
a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
index abb84491989..c33b8b50609 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
@@ -28,6 +28,7 @@ namespace doris::vectorized {
void
register_aggregate_function_combinator_sort(AggregateFunctionSimpleFactory&
factory);
void
register_aggregate_function_combinator_distinct(AggregateFunctionSimpleFactory&
factory);
+void
register_aggregate_function_combinator_foreach(AggregateFunctionSimpleFactory&
factory);
void register_aggregate_function_sum(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_minmax(AggregateFunctionSimpleFactory&
factory);
@@ -107,6 +108,8 @@ AggregateFunctionSimpleFactory&
AggregateFunctionSimpleFactory::instance() {
register_aggregate_functions_corr(instance);
register_aggregate_function_covar_pop(instance);
register_aggregate_function_covar_samp(instance);
+
+ register_aggregate_function_combinator_foreach(instance);
});
return instance;
}
diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
index dccbd9a4d57..635709f3594 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
@@ -37,6 +37,14 @@ using DataTypes = std::vector<DataTypePtr>;
using AggregateFunctionCreator =
std::function<AggregateFunctionPtr(const std::string&, const
DataTypes&, const bool)>;
+inline std::string types_name(const DataTypes& types) {
+ std::string name;
+ for (auto&& type : types) {
+ name += type->get_name();
+ }
+ return name;
+}
+
class AggregateFunctionSimpleFactory {
public:
using Creator = AggregateFunctionCreator;
@@ -78,6 +86,21 @@ public:
}
}
+ void register_foreach_function_combinator(const Creator& creator, const
std::string& suffix,
+ bool nullable = false) {
+ auto& functions = nullable ? nullable_aggregate_functions :
aggregate_functions;
+ std::vector<std::string> need_insert;
+ for (const auto& entity : aggregate_functions) {
+ std::string target_value = entity.first + suffix;
+ if (functions.find(target_value) == functions.end()) {
+ need_insert.emplace_back(std::move(target_value));
+ }
+ }
+ for (const auto& function_name : need_insert) {
+ register_function(function_name, creator, nullable);
+ }
+ }
+
AggregateFunctionPtr get(const std::string& name, const DataTypes&
argument_types,
const bool result_is_nullable = false,
int be_version =
BeExecVersionManager::get_newest_version(),
@@ -97,7 +120,7 @@ public:
}
temporary_function_update(be_version, name_str);
- if (function_alias.count(name)) {
+ if (function_alias.contains(name)) {
name_str = function_alias[name];
}
@@ -148,7 +171,6 @@ public:
}
}
-public:
static AggregateFunctionSimpleFactory& instance();
};
}; // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_sort.h
b/be/src/vec/aggregate_functions/aggregate_function_sort.h
index 02106b75e60..07b57e41359 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_sort.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_sort.h
@@ -138,7 +138,7 @@ public:
}
}
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena* arena) const override {
this->data(place).add(columns, _arguments.size(), row_num);
}
diff --git a/be/src/vec/aggregate_functions/aggregate_function_state_union.h
b/be/src/vec/aggregate_functions/aggregate_function_state_union.h
index 4134b7f79d1..3c9e2ed3767 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_state_union.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_state_union.h
@@ -53,7 +53,7 @@ public:
DataTypePtr get_return_type() const override { return _return_type; }
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena* arena) const override {
//the range is [begin, end]
_function->deserialize_and_merge_from_column_range(place, *columns[0],
row_num, row_num,
diff --git a/be/src/vec/aggregate_functions/aggregate_function_stddev.h
b/be/src/vec/aggregate_functions/aggregate_function_stddev.h
index c84e67a7d6d..456e91c3f6a 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_stddev.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_stddev.h
@@ -296,7 +296,7 @@ public:
}
}
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
if constexpr (is_pop) {
this->data(place).add(columns[0], row_num);
diff --git a/be/src/vec/aggregate_functions/aggregate_function_sum.h
b/be/src/vec/aggregate_functions/aggregate_function_sum.h
index 41677dd419b..b53d011e5f1 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_sum.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_sum.h
@@ -98,7 +98,7 @@ public:
}
}
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
const auto& column = assert_cast<const ColVecType&>(*columns[0]);
this->data(place).add(TResult(column.get_data()[row_num]));
diff --git a/be/src/vec/aggregate_functions/aggregate_function_topn.h
b/be/src/vec/aggregate_functions/aggregate_function_topn.h
index 633a36231a7..6c7502c99a3 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_topn.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_topn.h
@@ -287,7 +287,7 @@ public:
: IAggregateFunctionDataHelper<AggregateFunctionTopNData<T>,
AggregateFunctionTopNBase<Impl,
T>>(argument_types_) {}
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
Impl::add(this->data(place), columns, row_num);
}
diff --git a/be/src/vec/aggregate_functions/aggregate_function_uniq.h
b/be/src/vec/aggregate_functions/aggregate_function_uniq.h
index 3ef0359461b..727a145c45a 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_uniq.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_uniq.h
@@ -115,7 +115,7 @@ public:
DataTypePtr get_return_type() const override { return
std::make_shared<DataTypeInt64>(); }
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
detail::OneAdder<T, Data>::add(this->data(place), *columns[0],
row_num);
}
diff --git a/be/src/vec/aggregate_functions/aggregate_function_window.h
b/be/src/vec/aggregate_functions/aggregate_function_window.h
index 5ce46495464..3b0748d519f 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_window.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_window.h
@@ -65,7 +65,7 @@ public:
DataTypePtr get_return_type() const override { return
std::make_shared<DataTypeInt64>(); }
- void add(AggregateDataPtr place, const IColumn**, size_t, Arena*) const
override {
+ void add(AggregateDataPtr place, const IColumn**, ssize_t, Arena*) const
override {
++data(place).count;
}
@@ -103,7 +103,7 @@ public:
DataTypePtr get_return_type() const override { return
std::make_shared<DataTypeInt64>(); }
- void add(AggregateDataPtr place, const IColumn**, size_t, Arena*) const
override {
+ void add(AggregateDataPtr place, const IColumn**, ssize_t, Arena*) const
override {
++data(place).rank;
}
@@ -148,7 +148,7 @@ public:
DataTypePtr get_return_type() const override { return
std::make_shared<DataTypeInt64>(); }
- void add(AggregateDataPtr place, const IColumn**, size_t, Arena*) const
override {
+ void add(AggregateDataPtr place, const IColumn**, ssize_t, Arena*) const
override {
++data(place).rank;
}
@@ -197,7 +197,7 @@ public:
DataTypePtr get_return_type() const override { return
std::make_shared<DataTypeFloat64>(); }
- void add(AggregateDataPtr place, const IColumn**, size_t, Arena*) const
override {}
+ void add(AggregateDataPtr place, const IColumn**, ssize_t, Arena*) const
override {}
void add_range_single_place(int64_t partition_start, int64_t
partition_end, int64_t frame_start,
int64_t frame_end, AggregateDataPtr place,
const IColumn** columns,
@@ -255,7 +255,7 @@ public:
DataTypePtr get_return_type() const override { return
std::make_shared<DataTypeFloat64>(); }
- void add(AggregateDataPtr place, const IColumn**, size_t, Arena*) const
override {}
+ void add(AggregateDataPtr place, const IColumn**, ssize_t, Arena*) const
override {}
void add_range_single_place(int64_t partition_start, int64_t
partition_end, int64_t frame_start,
int64_t frame_end, AggregateDataPtr place,
const IColumn** columns,
@@ -299,7 +299,7 @@ public:
DataTypePtr get_return_type() const override { return
std::make_shared<DataTypeInt64>(); }
- void add(AggregateDataPtr place, const IColumn**, size_t, Arena*) const
override {}
+ void add(AggregateDataPtr place, const IColumn**, ssize_t, Arena*) const
override {}
void add_range_single_place(int64_t partition_start, int64_t
partition_end, int64_t frame_start,
int64_t frame_end, AggregateDataPtr place,
const IColumn** columns,
@@ -556,7 +556,7 @@ public:
this->data(place).insert_result_into(to);
}
- void add(AggregateDataPtr place, const IColumn** columns, size_t row_num,
+ void add(AggregateDataPtr place, const IColumn** columns, ssize_t row_num,
Arena* arena) const override {
LOG(FATAL) << "WindowFunctionLeadLagData do not support add";
}
diff --git a/be/src/vec/aggregate_functions/aggregate_function_window_funnel.h
b/be/src/vec/aggregate_functions/aggregate_function_window_funnel.h
index 253677bbc3c..d11b45caef6 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_window_funnel.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_window_funnel.h
@@ -279,7 +279,7 @@ public:
void reset(AggregateDataPtr __restrict place) const override {
this->data(place).reset(); }
- void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
const auto& window =
assert_cast<const
ColumnVector<Int64>&>(*columns[0]).get_data()[row_num];
diff --git a/be/src/vec/data_types/serde/data_type_nullable_serde.cpp
b/be/src/vec/data_types/serde/data_type_nullable_serde.cpp
index 6bdadfe23a7..07f3c5edbd4 100644
--- a/be/src/vec/data_types/serde/data_type_nullable_serde.cpp
+++ b/be/src/vec/data_types/serde/data_type_nullable_serde.cpp
@@ -286,7 +286,7 @@ template <bool is_binary_format>
Status DataTypeNullableSerDe::_write_column_to_mysql(const IColumn& column,
MysqlRowBuffer<is_binary_format>& result,
int row_idx, bool
col_const) const {
- auto& col = static_cast<const ColumnNullable&>(column);
+ auto& col = assert_cast<const ColumnNullable&>(column);
auto& nested_col = col.get_nested_column();
col_const = col_const || is_column_const(nested_col);
const auto col_index = index_check_const(row_idx, col_const);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
index d89ae3ffffa..e38b7143608 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
@@ -88,6 +88,7 @@ public abstract class Expr extends TreeNode<Expr> implements
ParseNode, Cloneabl
public static final String AGG_STATE_SUFFIX = "_state";
public static final String AGG_UNION_SUFFIX = "_union";
public static final String AGG_MERGE_SUFFIX = "_merge";
+ public static final String AGG_FOREACH_SUFFIX = "_foreach";
public static final String DEFAULT_EXPR_NAME = "expr";
protected boolean disableTableName = false;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java
b/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java
index 2ee5485d7f8..7dbf3a0ec0a 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java
@@ -43,6 +43,7 @@ import java.io.DataInput;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.IOException;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
@@ -878,4 +879,17 @@ public class Function implements Writable {
fnCall.setType(fnCall.getChildren().get(0).getType());
return fnCall;
}
+
+ public static FunctionCallExpr convertForEachCombinator(FunctionCallExpr
fnCall) {
+ Function aggFunction = fnCall.getFn();
+ aggFunction.setName(new
FunctionName(aggFunction.getFunctionName().getFunction() +
Expr.AGG_FOREACH_SUFFIX));
+ List<Type> argTypes = new ArrayList();
+ for (Type type : aggFunction.argTypes) {
+ argTypes.add(new ArrayType(type));
+ }
+ aggFunction.setArgs(argTypes);
+ aggFunction.setReturnType(new ArrayType(aggFunction.getReturnType(),
fnCall.isNullable()));
+ aggFunction.setNullableMode(NullableMode.ALWAYS_NULLABLE);
+ return fnCall;
+ }
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionRegistry.java
b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionRegistry.java
index fc8efef8d3e..82a09d7e04b 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionRegistry.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionRegistry.java
@@ -20,7 +20,7 @@ package org.apache.doris.catalog;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.exceptions.AnalysisException;
-import
org.apache.doris.nereids.trees.expressions.functions.AggStateFunctionBuilder;
+import
org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.udf.UdfBuilder;
import org.apache.doris.nereids.types.DataType;
@@ -93,13 +93,13 @@ public class FunctionRegistry {
if (StringUtils.isEmpty(dbName)) {
// search internal function only if dbName is empty
functionBuilders =
name2InternalBuiltinBuilders.get(name.toLowerCase());
- if (CollectionUtils.isEmpty(functionBuilders) &&
AggStateFunctionBuilder.isAggStateCombinator(name)) {
- String nestedName =
AggStateFunctionBuilder.getNestedName(name);
- String combinatorSuffix =
AggStateFunctionBuilder.getCombinatorSuffix(name);
+ if (CollectionUtils.isEmpty(functionBuilders) &&
AggCombinerFunctionBuilder.isAggStateCombinator(name)) {
+ String nestedName =
AggCombinerFunctionBuilder.getNestedName(name);
+ String combinatorSuffix =
AggCombinerFunctionBuilder.getCombinatorSuffix(name);
functionBuilders =
name2InternalBuiltinBuilders.get(nestedName.toLowerCase());
if (functionBuilders != null) {
functionBuilders = functionBuilders.stream()
- .map(builder -> new
AggStateFunctionBuilder(combinatorSuffix, builder))
+ .map(builder -> new
AggCombinerFunctionBuilder(combinatorSuffix, builder))
.filter(functionBuilder ->
functionBuilder.canApply(arguments))
.collect(Collectors.toList());
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
index 2d77124fa58..cc2034aa746 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
@@ -83,6 +83,7 @@ import
org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import
org.apache.doris.nereids.trees.expressions.functions.combinator.ForEachCombinator;
import
org.apache.doris.nereids.trees.expressions.functions.combinator.MergeCombinator;
import
org.apache.doris.nereids.trees.expressions.functions.combinator.StateCombinator;
import
org.apache.doris.nereids.trees.expressions.functions.combinator.UnionCombinator;
@@ -612,6 +613,16 @@ public class ExpressionTranslator extends
DefaultExpressionVisitor<Expr, PlanTra
new FunctionParams(false, arguments)));
}
+ @Override
+ public Expr visitForEachCombinator(ForEachCombinator combinator,
PlanTranslatorContext context) {
+ List<Expr> arguments = combinator.children().stream()
+ .map(arg -> new SlotRef(arg.getDataType().toCatalogDataType(),
arg.nullable()))
+ .collect(ImmutableList.toImmutableList());
+ return Function.convertForEachCombinator(
+ new
FunctionCallExpr(visitAggregateFunction(combinator.getNestedFunction(),
context).getFn(),
+ new FunctionParams(false, arguments)));
+ }
+
@Override
public Expr visitAggregateFunction(AggregateFunction function,
PlanTranslatorContext context) {
List<Expr> arguments = function.children().stream()
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggStateFunctionBuilder.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggCombinerFunctionBuilder.java
similarity index 74%
rename from
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggStateFunctionBuilder.java
rename to
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggCombinerFunctionBuilder.java
index 054aa4767b6..3c514475eed 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggStateFunctionBuilder.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggCombinerFunctionBuilder.java
@@ -18,11 +18,15 @@
package org.apache.doris.nereids.trees.expressions.functions;
import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import
org.apache.doris.nereids.trees.expressions.functions.combinator.ForEachCombinator;
import
org.apache.doris.nereids.trees.expressions.functions.combinator.MergeCombinator;
import
org.apache.doris.nereids.trees.expressions.functions.combinator.StateCombinator;
import
org.apache.doris.nereids.trees.expressions.functions.combinator.UnionCombinator;
import org.apache.doris.nereids.types.AggStateType;
+import org.apache.doris.nereids.types.ArrayType;
+import org.apache.doris.nereids.types.DataType;
import java.util.List;
import java.util.Objects;
@@ -31,28 +35,30 @@ import java.util.stream.Collectors;
/**
* This class used to resolve AggState's combinators
*/
-public class AggStateFunctionBuilder extends FunctionBuilder {
+public class AggCombinerFunctionBuilder extends FunctionBuilder {
public static final String COMBINATOR_LINKER = "_";
public static final String STATE = "state";
public static final String MERGE = "merge";
public static final String UNION = "union";
+ public static final String FOREACH = "foreach";
public static final String STATE_SUFFIX = COMBINATOR_LINKER + STATE;
public static final String MERGE_SUFFIX = COMBINATOR_LINKER + MERGE;
public static final String UNION_SUFFIX = COMBINATOR_LINKER + UNION;
+ public static final String FOREACH_SUFFIX = COMBINATOR_LINKER + FOREACH;
private final FunctionBuilder nestedBuilder;
private final String combinatorSuffix;
- public AggStateFunctionBuilder(String combinatorSuffix, FunctionBuilder
nestedBuilder) {
+ public AggCombinerFunctionBuilder(String combinatorSuffix, FunctionBuilder
nestedBuilder) {
this.combinatorSuffix = Objects.requireNonNull(combinatorSuffix,
"combinatorSuffix can not be null");
this.nestedBuilder = Objects.requireNonNull(nestedBuilder,
"nestedBuilder can not be null");
}
@Override
public boolean canApply(List<? extends Object> arguments) {
- if (combinatorSuffix.equals(STATE)) {
+ if (combinatorSuffix.equals(STATE) ||
combinatorSuffix.equals(FOREACH)) {
return nestedBuilder.canApply(arguments);
} else {
if (arguments.size() != 1) {
@@ -71,6 +77,23 @@ public class AggStateFunctionBuilder extends FunctionBuilder
{
return (AggregateFunction) nestedBuilder.build(nestedName, arguments);
}
+ private AggregateFunction buildForEach(String nestedName, List<? extends
Object> arguments) {
+ List<Expression> forEachargs = arguments.stream().map(expr -> {
+ if (!(expr instanceof SlotReference)) {
+ throw new IllegalStateException(
+ "Can not build foreach nested function: '" +
nestedName);
+ }
+ DataType arrayType = (((Expression) expr).getDataType());
+ if (!(arrayType instanceof ArrayType)) {
+ throw new IllegalStateException(
+ "foreach must be input array type: '" + nestedName);
+ }
+ DataType itemType = ((ArrayType) arrayType).getItemType();
+ return new SlotReference("mocked", itemType, (((ArrayType)
arrayType).containsNull()));
+ }).collect(Collectors.toList());
+ return (AggregateFunction) nestedBuilder.build(nestedName,
forEachargs);
+ }
+
private AggregateFunction buildMergeOrUnion(String nestedName, List<?
extends Object> arguments) {
if (arguments.size() != 1 || !(arguments.get(0) instanceof Expression)
|| !((Expression)
arguments.get(0)).getDataType().isAggStateType()) {
@@ -105,13 +128,16 @@ public class AggStateFunctionBuilder extends
FunctionBuilder {
} else if (combinatorSuffix.equals(UNION)) {
AggregateFunction nestedFunction = buildMergeOrUnion(nestedName,
arguments);
return new UnionCombinator((List<Expression>) arguments,
nestedFunction);
+ } else if (combinatorSuffix.equals(FOREACH)) {
+ AggregateFunction nestedFunction = buildForEach(nestedName,
arguments);
+ return new ForEachCombinator((List<Expression>) arguments,
nestedFunction);
}
return null;
}
public static boolean isAggStateCombinator(String name) {
return name.toLowerCase().endsWith(STATE_SUFFIX) ||
name.toLowerCase().endsWith(MERGE_SUFFIX)
- || name.toLowerCase().endsWith(UNION_SUFFIX);
+ || name.toLowerCase().endsWith(UNION_SUFFIX) ||
name.toLowerCase().endsWith(FOREACH_SUFFIX);
}
public static String getNestedName(String name) {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/UnionCombinator.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/ForEachCombinator.java
similarity index 67%
copy from
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/UnionCombinator.java
copy to
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/ForEachCombinator.java
index 67f09a50ebf..fbbf51eb909 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/UnionCombinator.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/ForEachCombinator.java
@@ -19,13 +19,13 @@ package
org.apache.doris.nereids.trees.expressions.functions.combinator;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
-import
org.apache.doris.nereids.trees.expressions.functions.AggStateFunctionBuilder;
-import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
+import
org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
+import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
-import org.apache.doris.nereids.types.AggStateType;
+import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import com.google.common.collect.ImmutableList;
@@ -34,41 +34,49 @@ import java.util.List;
import java.util.Objects;
/**
- * AggState combinator union
+ * combinator foreach
*/
-public class UnionCombinator extends AggregateFunction
- implements UnaryExpression, ExplicitlyCastableSignature,
AlwaysNotNullable {
+public class ForEachCombinator extends AggregateFunction
+ implements UnaryExpression, ExplicitlyCastableSignature,
AlwaysNullable {
private final AggregateFunction nested;
- private final AggStateType inputType;
- public UnionCombinator(List<Expression> arguments, AggregateFunction
nested) {
- super(nested.getName() + AggStateFunctionBuilder.UNION_SUFFIX,
arguments);
+ /**
+ * constructor of ForEachCombinator
+ */
+ public ForEachCombinator(List<Expression> arguments, AggregateFunction
nested) {
+ super(nested.getName() + AggCombinerFunctionBuilder.FOREACH_SUFFIX,
arguments);
this.nested = Objects.requireNonNull(nested, "nested can not be null");
- inputType = (AggStateType) arguments.get(0).getDataType();
+ }
+
+ public static ForEachCombinator create(AggregateFunction nested) {
+ return new ForEachCombinator(nested.getArguments(), nested);
}
@Override
- public UnionCombinator withChildren(List<Expression> children) {
- return new UnionCombinator(children, nested);
+ public ForEachCombinator withChildren(List<Expression> children) {
+ return new ForEachCombinator(children, nested);
}
@Override
public List<FunctionSignature> getSignatures() {
return nested.getSignatures().stream().map(sig -> {
- return sig.withArgumentTypes(false,
ImmutableList.of(inputType)).withReturnType(inputType);
+ return
sig.withReturnType(ArrayType.of(sig.returnType)).withArgumentTypes(false,
+ sig.argumentsTypes.stream().map(arg -> {
+ return ArrayType.of(arg);
+ }).collect(ImmutableList.toImmutableList()));
}).collect(ImmutableList.toImmutableList());
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
- return visitor.visitUnionCombinator(this, context);
+ return visitor.visitForEachCombinator(this, context);
}
@Override
public DataType getDataType() {
- return inputType;
+ return ArrayType.of(nested.getDataType(), nested.nullable());
}
public AggregateFunction getNestedFunction() {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/MergeCombinator.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/MergeCombinator.java
index b529ae2de3f..a9b2d13d0d5 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/MergeCombinator.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/MergeCombinator.java
@@ -19,7 +19,7 @@ package
org.apache.doris.nereids.trees.expressions.functions.combinator;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
-import
org.apache.doris.nereids.trees.expressions.functions.AggStateFunctionBuilder;
+import
org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.ComputeNullable;
import
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
@@ -43,7 +43,7 @@ public class MergeCombinator extends AggregateFunction
private final AggStateType inputType;
public MergeCombinator(List<Expression> arguments, AggregateFunction
nested) {
- super(nested.getName() + AggStateFunctionBuilder.MERGE_SUFFIX,
arguments);
+ super(nested.getName() + AggCombinerFunctionBuilder.MERGE_SUFFIX,
arguments);
this.nested = Objects.requireNonNull(nested, "nested can not be null");
inputType = (AggStateType) arguments.get(0).getDataType();
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/StateCombinator.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/StateCombinator.java
index db001a6793c..877824822c5 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/StateCombinator.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/StateCombinator.java
@@ -19,7 +19,7 @@ package
org.apache.doris.nereids.trees.expressions.functions.combinator;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
-import
org.apache.doris.nereids.trees.expressions.functions.AggStateFunctionBuilder;
+import
org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
@@ -47,7 +47,7 @@ public class StateCombinator extends ScalarFunction
* constructor of StateCombinator
*/
public StateCombinator(List<Expression> arguments, AggregateFunction
nested) {
- super(nested.getName() + AggStateFunctionBuilder.STATE_SUFFIX,
arguments);
+ super(nested.getName() + AggCombinerFunctionBuilder.STATE_SUFFIX,
arguments);
this.nested = Objects.requireNonNull(nested, "nested can not be null");
this.returnType = new AggStateType(nested.getName(),
arguments.stream().map(arg -> {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/UnionCombinator.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/UnionCombinator.java
index 67f09a50ebf..e1138dd4851 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/UnionCombinator.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/combinator/UnionCombinator.java
@@ -19,7 +19,7 @@ package
org.apache.doris.nereids.trees.expressions.functions.combinator;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
-import
org.apache.doris.nereids.trees.expressions.functions.AggStateFunctionBuilder;
+import
org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
@@ -43,7 +43,7 @@ public class UnionCombinator extends AggregateFunction
private final AggStateType inputType;
public UnionCombinator(List<Expression> arguments, AggregateFunction
nested) {
- super(nested.getName() + AggStateFunctionBuilder.UNION_SUFFIX,
arguments);
+ super(nested.getName() + AggCombinerFunctionBuilder.UNION_SUFFIX,
arguments);
this.nested = Objects.requireNonNull(nested, "nested can not be null");
inputType = (AggStateType) arguments.get(0).getDataType();
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
index 4a1830341b9..594f9c75433 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
@@ -72,6 +72,7 @@ import
org.apache.doris.nereids.trees.expressions.functions.agg.TopNWeighted;
import org.apache.doris.nereids.trees.expressions.functions.agg.Variance;
import org.apache.doris.nereids.trees.expressions.functions.agg.VarianceSamp;
import org.apache.doris.nereids.trees.expressions.functions.agg.WindowFunnel;
+import
org.apache.doris.nereids.trees.expressions.functions.combinator.ForEachCombinator;
import
org.apache.doris.nereids.trees.expressions.functions.combinator.MergeCombinator;
import
org.apache.doris.nereids.trees.expressions.functions.combinator.UnionCombinator;
import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdaf;
@@ -305,6 +306,10 @@ public interface AggregateFunctionVisitor<R, C> {
return visitAggregateFunction(combinator, context);
}
+ default R visitForEachCombinator(ForEachCombinator combinator, C context) {
+ return visitAggregateFunction(combinator, context);
+ }
+
default R visitJavaUdaf(JavaUdaf javaUdaf, C context) {
return visitAggregateFunction(javaUdaf, context);
}
diff --git a/regression-test/data/function_p0/test_agg_foreach.out
b/regression-test/data/function_p0/test_agg_foreach.out
new file mode 100644
index 00000000000..849d49bc9df
--- /dev/null
+++ b/regression-test/data/function_p0/test_agg_foreach.out
@@ -0,0 +1,28 @@
+-- This file is automatically generated. You should know what you did if you
want to edit this
+-- !sql --
+[1, 2, 3] [1, 2, 3] [100, 2, 3] [100, 2, 3]
[40.333333333333336, 2, 3] [85.95867768595042, 2, 3]
+
+-- !sql --
+[121, 4, 3] [42.897811391983879, 0, 0] [52.538874496255943, 0, null]
[1840.2222222222219, 0, 0] [2760.333333333333, 0, null]
+
+-- !sql --
+[1840.2222222222222, 0, 0] [2760.3333333333335, 0, null] [1, 0, 0]
+
+-- !sql --
+["{"20":1,"100":1,"1":1}", "{"2":2}", "{"3":1}"]
["{"20":1,"100":1,"1":1}", "{"2":2}", "{"3":1}"] [[100, 20, 1], [2],
[3]] [[100, 20, 1], [2], [3]]
+
+-- !sql --
+[3, 2, 1]
["[{"cbe":{"100":1,"1":1,"20":1},"notnull":3,"null":1,"all":4}]",
"[{"cbe":{"2":2},"notnull":2,"null":0,"all":2}]",
"[{"cbe":{"3":1},"notnull":1,"null":0,"all":1}]"]
+
+-- !sql --
+[100, 2, 3]
+
+-- !sql --
+[[1], [2, 2, 2], [3]]
+
+-- !sql --
+[null, null, null]
+
+-- !sql --
+[0, 2, 3] [117, 2, 3] [113, 0, 3]
+
diff --git a/regression-test/suites/function_p0/test_agg_foreach.groovy
b/regression-test/suites/function_p0/test_agg_foreach.groovy
new file mode 100644
index 00000000000..eec05fcde9e
--- /dev/null
+++ b/regression-test/suites/function_p0/test_agg_foreach.groovy
@@ -0,0 +1,95 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_agg_foreach") {
+ // for nereids_planner
+ // now support min min_by max max_by avg avg_weighted sum stddev
stddev_samp variance var_samp
+ // covar covar_samp corr
+ // topn topn_array topn_weighted
+ // count count_by_enum
+ // PERCENTILE PERCENTILE_ARRAY PERCENTILE_APPROX
+ // histogram
+ // GROUP_BIT_AND GROUP_BIT_OR GROUP_BIT_XOR
+
+ sql """ set enable_nereids_planner=true;"""
+ sql """ set enable_fallback_to_original_planner=false;"""
+
+ sql """
+ drop table if exists foreach_table;
+ """
+
+ sql """
+ CREATE TABLE IF NOT EXISTS foreach_table (
+ `id` INT(11) null COMMENT "",
+ `a` array<INT> null COMMENT "",
+ `b` array<array<INT>> null COMMENT "",
+ `s` array<String> null COMMENT ""
+ ) ENGINE=OLAP
+ DUPLICATE KEY(`id`)
+ DISTRIBUTED BY HASH(`id`) BUCKETS 1
+ PROPERTIES (
+ "replication_allocation" = "tag.location.default: 1",
+ "storage_format" = "V2"
+ );
+ """
+ sql """
+ insert into foreach_table values
+ (1,[1,2,3],[[1],[1,2,3],[2]],["ab","123"]),
+ (2,[20],[[2]],["cd"]),
+ (3,[100],[[1]],["efg"]) ,
+ (4,null,[null],null),
+ (5,[null,2],[[2],null],[null,'c']);
+ """
+
+
+ qt_sql """
+ select min_foreach(a),
min_by_foreach(a,a),max_foreach(a),max_by_foreach(a,a) ,
avg_foreach(a),avg_weighted_foreach(a,a) from foreach_table ;
+ """
+
+ qt_sql """
+ select sum_foreach(a) , stddev_foreach(a) ,stddev_samp_foreach(a) ,
variance_foreach(a) , var_samp_foreach(a) from foreach_table ;
+ """
+
+ qt_sql """
+ select covar_foreach(a,a) , covar_samp_foreach(a,a) , corr_foreach(a,a)
from foreach_table ;
+ """
+ qt_sql """
+ select topn_foreach(a,a) ,topn_foreach(a,a,a) , topn_array_foreach(a,a)
,topn_array_foreach(a,a,a)from foreach_table ;
+ """
+
+
+ qt_sql """
+ select count_foreach(a) , count_by_enum_foreach(a) from foreach_table;
+ """
+
+ qt_sql """
+ select PERCENTILE_foreach(a,a) from foreach_table;
+ """
+
+ qt_sql """
+ select PERCENTILE_ARRAY_foreach(a,b) from foreach_table where id = 1;
+ """
+
+ qt_sql """
+
+ select PERCENTILE_APPROX_foreach(a,a) from foreach_table;
+ """
+
+ qt_sql """
+ select GROUP_BIT_AND_foreach(a), GROUP_BIT_OR_foreach(a),
GROUP_BIT_XOR_foreach(a) from foreach_table;
+ """
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]