This is an automated email from the ASF dual-hosted git repository.
zclll pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 346e2c4dce8 [enhance](agg) Support max_by/min_by agg functions for
some complex type (#58736)
346e2c4dce8 is described below
commit 346e2c4dce8f04e7275c99ec9e62577f4ca5a404
Author: admiring_xm <[email protected]>
AuthorDate: Wed Dec 17 17:17:31 2025 +0800
[enhance](agg) Support max_by/min_by agg functions for some complex type
(#58736)
Issue Number: https://github.com/apache/doris/issues/58417
Problem Summary:
Support the max_by/min_by aggregate functions for some complex types, for example:
```sql
max_by(map, int), max_by(struct, int), max_by(array, int), max_by(int,
array)
min_by(map, int), min_by(struct, int), min_by(array, int), min_by(int,
array)
```
---
.../aggregate_function_max_by.cpp | 32 -------
.../aggregate_function_min_by.cpp | 32 -------
.../aggregate_function_min_max.cpp | 19 +---
.../aggregate_function_min_max_by.cpp | 16 +++-
.../aggregate_function_min_max_by.h | 103 +++++++++++++++------
.../aggregate_function_simple_factory.cpp | 6 +-
.../aggregate_functions/agg_min_max_by_test.cpp | 6 +-
.../trees/expressions/functions/agg/MaxBy.java | 7 +-
.../trees/expressions/functions/agg/MinBy.java | 7 +-
.../test_aggregate_all_functions2.out | 67 +++++++++++++-
.../test_aggregate_all_functions2.groovy | 34 +++++--
11 files changed, 199 insertions(+), 130 deletions(-)
diff --git a/be/src/vec/aggregate_functions/aggregate_function_max_by.cpp
b/be/src/vec/aggregate_functions/aggregate_function_max_by.cpp
deleted file mode 100644
index 50e611b11d8..00000000000
--- a/be/src/vec/aggregate_functions/aggregate_function_max_by.cpp
+++ /dev/null
@@ -1,32 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-#include <memory>
-
-#include "vec/aggregate_functions/aggregate_function_min_max_by.h"
-#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
-
-namespace doris::vectorized {
-#include "common/compile_check_begin.h"
-
-void register_aggregate_function_max_by(AggregateFunctionSimpleFactory&
factory) {
- factory.register_function_both(
- "max_by",
create_aggregate_function_min_max_by<AggregateFunctionsMinMaxBy,
-
AggregateFunctionMaxByData>);
-}
-
-} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_by.cpp
b/be/src/vec/aggregate_functions/aggregate_function_min_by.cpp
deleted file mode 100644
index 0af34292bbb..00000000000
--- a/be/src/vec/aggregate_functions/aggregate_function_min_by.cpp
+++ /dev/null
@@ -1,32 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-#include <memory>
-
-#include "vec/aggregate_functions/aggregate_function_min_max_by.h"
-#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
-
-namespace doris::vectorized {
-#include "common/compile_check_begin.h"
-
-void register_aggregate_function_min_by(AggregateFunctionSimpleFactory&
factory) {
- factory.register_function_both(
- "min_by",
create_aggregate_function_min_max_by<AggregateFunctionsMinMaxBy,
-
AggregateFunctionMinByData>);
-}
-
-} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp
b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp
index 1964ec0adf8..f127cbb4c51 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp
@@ -134,30 +134,13 @@ AggregateFunctionPtr
create_aggregate_function_single_value(const String& name,
return creator_without_type::create_unary_arguments<
AggregateFunctionsSingleValue<Data<SingleValueDataDecimal<TYPE_DECIMAL256>>>>(
argument_types, result_is_nullable, attr);
+ // For Complex type. Currently, only type_array supports min and max.
case PrimitiveType::TYPE_ARRAY:
- return creator_without_type::create_unary_arguments<
-
AggregateFunctionsSingleValue<Data<SingleValueDataComplexType>>>(
- argument_types, result_is_nullable, attr);
case PrimitiveType::TYPE_MAP:
- return creator_without_type::create_unary_arguments<
-
AggregateFunctionsSingleValue<Data<SingleValueDataComplexType>>>(
- argument_types, result_is_nullable, attr);
case PrimitiveType::TYPE_STRUCT:
- return creator_without_type::create_unary_arguments<
-
AggregateFunctionsSingleValue<Data<SingleValueDataComplexType>>>(
- argument_types, result_is_nullable, attr);
case PrimitiveType::TYPE_AGG_STATE:
- return creator_without_type::create_unary_arguments<
-
AggregateFunctionsSingleValue<Data<SingleValueDataComplexType>>>(
- argument_types, result_is_nullable, attr);
case PrimitiveType::TYPE_BITMAP:
- return creator_without_type::create_unary_arguments<
-
AggregateFunctionsSingleValue<Data<SingleValueDataComplexType>>>(
- argument_types, result_is_nullable, attr);
case PrimitiveType::TYPE_HLL:
- return creator_without_type::create_unary_arguments<
-
AggregateFunctionsSingleValue<Data<SingleValueDataComplexType>>>(
- argument_types, result_is_nullable, attr);
case PrimitiveType::TYPE_QUANTILE_STATE:
return creator_without_type::create_unary_arguments<
AggregateFunctionsSingleValue<Data<SingleValueDataComplexType>>>(
diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp
b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp
index bc77a327527..659a2181c26 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp
@@ -17,9 +17,11 @@
#include "vec/aggregate_functions/aggregate_function_min_max_by.h"
+#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+
namespace doris::vectorized {
#include "common/compile_check_begin.h"
-std::unique_ptr<MaxMinValueBase> create_max_min_value(const DataTypePtr& type)
{
+std::unique_ptr<MaxMinValueBase> create_max_min_value(const DataTypePtr& type,
int be_version) {
switch (type->get_primitive_type()) {
case PrimitiveType::TYPE_BOOLEAN:
return
std::make_unique<MaxMinValue<SingleValueDataFixed<TYPE_BOOLEAN>>>();
@@ -61,6 +63,11 @@ std::unique_ptr<MaxMinValueBase> create_max_min_value(const
DataTypePtr& type) {
return
std::make_unique<MaxMinValue<SingleValueDataFixed<TYPE_DATETIMEV2>>>();
case PrimitiveType::TYPE_BITMAP:
return std::make_unique<MaxMinValue<BitmapValueData>>();
+ case PrimitiveType::TYPE_ARRAY:
+ case PrimitiveType::TYPE_MAP:
+ case PrimitiveType::TYPE_STRUCT:
+ return
std::make_unique<MaxMinValue<SingleValueDataComplexType>>(DataTypes {type},
+
be_version);
default:
throw doris::Exception(ErrorCode::INTERNAL_ERROR,
"Illegal type {} of argument of aggregate
function min/max_by",
@@ -69,6 +76,13 @@ std::unique_ptr<MaxMinValueBase> create_max_min_value(const
DataTypePtr& type) {
}
}
+// Registers the "min_by" and "max_by" aggregate functions with the factory.
+// Both go through create_aggregate_function_min_max_by and differ only in the
+// per-state Data template (AggregateFunctionMinByData / AggregateFunctionMaxByData).
+void register_aggregate_function_max_min_by(AggregateFunctionSimpleFactory&
factory) {
+ factory.register_function_both(
+ "min_by",
create_aggregate_function_min_max_by<AggregateFunctionMinByData>);
+ factory.register_function_both(
+ "max_by",
create_aggregate_function_min_max_by<AggregateFunctionMaxByData>);
+}
+
} // namespace doris::vectorized
#include "common/compile_check_end.h"
diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h
b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h
index 686505ad99a..d2397d0c255 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h
@@ -47,6 +47,9 @@ struct MaxMinValue : public MaxMinValueBase {
MaxMinValue() = default;
+ // Forwards the argument types and BE exec version to value holders that need
+ // them to be constructed, e.g. SingleValueDataComplexType (see create_max_min_value).
+ MaxMinValue(const DataTypes& argument_types, int be_version)
+ : value(argument_types, be_version) {}
+
~MaxMinValue() override = default;
void write(BufferWritable& buf) const override { value.write(buf); }
@@ -67,7 +70,7 @@ struct MaxMinValue : public MaxMinValueBase {
}
};
-std::unique_ptr<MaxMinValueBase> create_max_min_value(const DataTypePtr& type);
+std::unique_ptr<MaxMinValueBase> create_max_min_value(const DataTypePtr& type,
int be_version);
/// For bitmap value
struct BitmapValueData {
@@ -120,6 +123,25 @@ public:
}
};
+/**
+ * The template parameter KT is introduced here primarily for performance
reasons.
+ *
+ * Without using a template parameter, the key type would have to be
+ * std::unique_ptr<MaxMinValueBase>. Since MaxMinValueBase is a polymorphic
base
+ * class with virtual methods, comparing keys would inevitably involve virtual
+ * function calls, which can introduce significant runtime overhead.
+ *
+ * By making KT a template parameter, the concrete key type is known at compile
+ * time, allowing static dispatch and avoiding virtual function calls. This
+ * substantially reduces the cost of key comparisons.
+ *
+ * In contrast, the value type VT is intentionally not made a template
parameter.
+ * On one hand, templating both key and value types would lead to an n × n
+ * explosion in template instantiations, increasing compile time and code size.
+ * On the other hand, value objects typically only invoke the change method;
for
+ * random data, this method is called approximately log(x) times (where x is
the
+ * data size), making the overhead acceptable.
+ */
template <typename KT>
struct AggregateFunctionMinMaxByBaseData {
protected:
@@ -127,8 +149,18 @@ protected:
KT key;
public:
- AggregateFunctionMinMaxByBaseData(const DataTypes argument_types) {
- value = create_max_min_value(argument_types[0]);
+ AggregateFunctionMinMaxByBaseData() {}
+
+ // Constructor for complex-type keys: SingleValueDataComplexType is built from
+ // the key's data type (argument_types[1]) and the BE exec version.
+ // argument_types is taken by const reference because a state is constructed
+ // for every aggregation state created; a by-value parameter copied the list each time.
+ AggregateFunctionMinMaxByBaseData(const DataTypes& argument_types, int be_version)
+ requires(std::is_same_v<KT, SingleValueDataComplexType>)
+ : key(SingleValueDataComplexType(DataTypes {argument_types[1]}, be_version)) {
+ value = create_max_min_value(argument_types[0], be_version);
+ }
+
+ // Constructor for every other key type: the key is default-constructed and only
+ // the polymorphic value holder depends on argument_types[0].
+ AggregateFunctionMinMaxByBaseData(const DataTypes& argument_types, int be_version)
+ requires(!std::is_same_v<KT, SingleValueDataComplexType>)
+ {
+ value = create_max_min_value(argument_types[0], be_version);
}
void insert_result_into(IColumn& to) const {
value->insert_result_into(to); }
@@ -152,8 +184,10 @@ template <typename KT>
struct AggregateFunctionMaxByData : public
AggregateFunctionMinMaxByBaseData<KT> {
using Self = AggregateFunctionMaxByData;
- AggregateFunctionMaxByData(const DataTypes argument_types)
- : AggregateFunctionMinMaxByBaseData<KT>(argument_types) {}
+ AggregateFunctionMaxByData() {}
+
+ // argument_types is taken by const reference so constructing an aggregation
+ // state does not copy the argument-type list before forwarding it to the base.
+ AggregateFunctionMaxByData(const DataTypes& argument_types, int be_version)
+ : AggregateFunctionMinMaxByBaseData<KT>(argument_types, be_version) {}
void change_if_better(const IColumn& value_column, const IColumn&
key_column, size_t row_num,
Arena& arena) {
@@ -188,8 +222,11 @@ template <typename KT>
struct AggregateFunctionMinByData : public
AggregateFunctionMinMaxByBaseData<KT> {
using Self = AggregateFunctionMinByData;
- AggregateFunctionMinByData(const DataTypes argument_types)
- : AggregateFunctionMinMaxByBaseData<KT>(argument_types) {}
+ AggregateFunctionMinByData() {}
+
+ // argument_types is taken by const reference so constructing an aggregation
+ // state does not copy the argument-type list before forwarding it to the base.
+ AggregateFunctionMinByData(const DataTypes& argument_types, int be_version)
+ : AggregateFunctionMinMaxByBaseData<KT>(argument_types, be_version) {}
+
void change_if_better(const IColumn& value_column, const IColumn&
key_column, size_t row_num,
Arena& arena) {
if (this->key.change_if_less(key_column, row_num, arena)) {
@@ -221,7 +258,7 @@ struct AggregateFunctionMinByData : public
AggregateFunctionMinMaxByBaseData<KT>
template <typename Data>
class AggregateFunctionsMinMaxBy final
- : public IAggregateFunctionDataHelper<Data,
AggregateFunctionsMinMaxBy<Data>, true>,
+ : public IAggregateFunctionDataHelper<Data,
AggregateFunctionsMinMaxBy<Data>>,
MultiExpression,
NullableAggregateFunction {
private:
@@ -230,11 +267,15 @@ private:
public:
AggregateFunctionsMinMaxBy(const DataTypes& arguments)
- : IAggregateFunctionDataHelper<Data,
AggregateFunctionsMinMaxBy<Data>, true>(
+ : IAggregateFunctionDataHelper<Data,
AggregateFunctionsMinMaxBy<Data>>(
{arguments[0], arguments[1]}),
value_type(this->argument_types[0]),
key_type(this->argument_types[1]) {}
+ // Placement-constructs the aggregation state at `place`, passing the argument
+ // types and BE version that the Data (min_by/max_by) constructors take.
+ void create(AggregateDataPtr __restrict place) const override {
+ new (place) Data(IAggregateFunction::argument_types,
IAggregateFunction::version);
+ }
+
String get_name() const override { return Data::name(); }
DataTypePtr get_return_type() const override { return value_type; }
@@ -270,7 +311,7 @@ public:
}
};
-template <template <typename> class AggregateFunctionTemplate, template
<typename> class Data>
+template <template <typename> class Data>
AggregateFunctionPtr create_aggregate_function_min_max_by(const String& name,
const DataTypes&
argument_types,
const DataTypePtr&
result_type,
@@ -283,77 +324,81 @@ AggregateFunctionPtr
create_aggregate_function_min_max_by(const String& name,
switch (argument_types[1]->get_primitive_type()) {
case PrimitiveType::TYPE_BOOLEAN:
return creator_without_type::create_multi_arguments<
-
AggregateFunctionTemplate<Data<SingleValueDataFixed<TYPE_BOOLEAN>>>>(
+
AggregateFunctionsMinMaxBy<Data<SingleValueDataFixed<TYPE_BOOLEAN>>>>(
argument_types, result_is_nullable, attr);
case PrimitiveType::TYPE_TINYINT:
return creator_without_type::create_multi_arguments<
-
AggregateFunctionTemplate<Data<SingleValueDataFixed<TYPE_TINYINT>>>>(
+
AggregateFunctionsMinMaxBy<Data<SingleValueDataFixed<TYPE_TINYINT>>>>(
argument_types, result_is_nullable, attr);
case PrimitiveType::TYPE_SMALLINT:
return creator_without_type::create_multi_arguments<
-
AggregateFunctionTemplate<Data<SingleValueDataFixed<TYPE_SMALLINT>>>>(
+
AggregateFunctionsMinMaxBy<Data<SingleValueDataFixed<TYPE_SMALLINT>>>>(
argument_types, result_is_nullable, attr);
case PrimitiveType::TYPE_INT:
return creator_without_type::create_multi_arguments<
-
AggregateFunctionTemplate<Data<SingleValueDataFixed<TYPE_INT>>>>(
+
AggregateFunctionsMinMaxBy<Data<SingleValueDataFixed<TYPE_INT>>>>(
argument_types, result_is_nullable, attr);
case PrimitiveType::TYPE_BIGINT:
return creator_without_type::create_multi_arguments<
-
AggregateFunctionTemplate<Data<SingleValueDataFixed<TYPE_BIGINT>>>>(
+
AggregateFunctionsMinMaxBy<Data<SingleValueDataFixed<TYPE_BIGINT>>>>(
argument_types, result_is_nullable, attr);
case PrimitiveType::TYPE_LARGEINT:
return creator_without_type::create_multi_arguments<
-
AggregateFunctionTemplate<Data<SingleValueDataFixed<TYPE_LARGEINT>>>>(
+
AggregateFunctionsMinMaxBy<Data<SingleValueDataFixed<TYPE_LARGEINT>>>>(
argument_types, result_is_nullable, attr);
case PrimitiveType::TYPE_FLOAT:
return creator_without_type::create_multi_arguments<
-
AggregateFunctionTemplate<Data<SingleValueDataFixed<TYPE_FLOAT>>>>(
+
AggregateFunctionsMinMaxBy<Data<SingleValueDataFixed<TYPE_FLOAT>>>>(
argument_types, result_is_nullable, attr);
case PrimitiveType::TYPE_DOUBLE:
return creator_without_type::create_multi_arguments<
-
AggregateFunctionTemplate<Data<SingleValueDataFixed<TYPE_DOUBLE>>>>(
+
AggregateFunctionsMinMaxBy<Data<SingleValueDataFixed<TYPE_DOUBLE>>>>(
argument_types, result_is_nullable, attr);
case PrimitiveType::TYPE_DECIMAL32:
return creator_without_type::create_multi_arguments<
-
AggregateFunctionTemplate<Data<SingleValueDataDecimal<TYPE_DECIMAL32>>>>(
+
AggregateFunctionsMinMaxBy<Data<SingleValueDataDecimal<TYPE_DECIMAL32>>>>(
argument_types, result_is_nullable, attr);
case PrimitiveType::TYPE_DECIMAL64:
return creator_without_type::create_multi_arguments<
-
AggregateFunctionTemplate<Data<SingleValueDataDecimal<TYPE_DECIMAL64>>>>(
+
AggregateFunctionsMinMaxBy<Data<SingleValueDataDecimal<TYPE_DECIMAL64>>>>(
argument_types, result_is_nullable, attr);
case PrimitiveType::TYPE_DECIMAL128I:
return creator_without_type::create_multi_arguments<
-
AggregateFunctionTemplate<Data<SingleValueDataDecimal<TYPE_DECIMAL128I>>>>(
+
AggregateFunctionsMinMaxBy<Data<SingleValueDataDecimal<TYPE_DECIMAL128I>>>>(
argument_types, result_is_nullable, attr);
case PrimitiveType::TYPE_DECIMALV2:
return creator_without_type::create_multi_arguments<
-
AggregateFunctionTemplate<Data<SingleValueDataDecimal<TYPE_DECIMALV2>>>>(
+
AggregateFunctionsMinMaxBy<Data<SingleValueDataDecimal<TYPE_DECIMALV2>>>>(
argument_types, result_is_nullable, attr);
case PrimitiveType::TYPE_DECIMAL256:
return creator_without_type::create_multi_arguments<
-
AggregateFunctionTemplate<Data<SingleValueDataDecimal<TYPE_DECIMAL256>>>>(
+
AggregateFunctionsMinMaxBy<Data<SingleValueDataDecimal<TYPE_DECIMAL256>>>>(
argument_types, result_is_nullable, attr);
case PrimitiveType::TYPE_CHAR:
case PrimitiveType::TYPE_VARCHAR:
case PrimitiveType::TYPE_STRING:
return creator_without_type::create_multi_arguments<
-
AggregateFunctionTemplate<Data<SingleValueDataString>>>(argument_types,
-
result_is_nullable, attr);
+
AggregateFunctionsMinMaxBy<Data<SingleValueDataString>>>(argument_types,
+
result_is_nullable, attr);
case PrimitiveType::TYPE_DATE:
return creator_without_type::create_multi_arguments<
-
AggregateFunctionTemplate<Data<SingleValueDataFixed<TYPE_DATE>>>>(
+
AggregateFunctionsMinMaxBy<Data<SingleValueDataFixed<TYPE_DATE>>>>(
argument_types, result_is_nullable, attr);
case PrimitiveType::TYPE_DATETIME:
return creator_without_type::create_multi_arguments<
-
AggregateFunctionTemplate<Data<SingleValueDataFixed<TYPE_DATETIME>>>>(
+
AggregateFunctionsMinMaxBy<Data<SingleValueDataFixed<TYPE_DATETIME>>>>(
argument_types, result_is_nullable, attr);
case PrimitiveType::TYPE_DATEV2:
return creator_without_type::create_multi_arguments<
-
AggregateFunctionTemplate<Data<SingleValueDataFixed<TYPE_DATEV2>>>>(
+
AggregateFunctionsMinMaxBy<Data<SingleValueDataFixed<TYPE_DATEV2>>>>(
argument_types, result_is_nullable, attr);
case PrimitiveType::TYPE_DATETIMEV2:
return creator_without_type::create_multi_arguments<
-
AggregateFunctionTemplate<Data<SingleValueDataFixed<TYPE_DATETIMEV2>>>>(
+
AggregateFunctionsMinMaxBy<Data<SingleValueDataFixed<TYPE_DATETIMEV2>>>>(
+ argument_types, result_is_nullable, attr);
+ case PrimitiveType::TYPE_ARRAY:
+ return creator_without_type::create_multi_arguments<
+ AggregateFunctionsMinMaxBy<Data<SingleValueDataComplexType>>>(
argument_types, result_is_nullable, attr);
default:
return nullptr;
diff --git
a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
index aa4410d347e..8ba0696b667 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
@@ -33,8 +33,7 @@ void
register_aggregate_function_combinator_foreachv2(AggregateFunctionSimpleFac
void register_aggregate_function_sum(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_minmax(AggregateFunctionSimpleFactory&
factory);
-void register_aggregate_function_min_by(AggregateFunctionSimpleFactory&
factory);
-void register_aggregate_function_max_by(AggregateFunctionSimpleFactory&
factory);
+void register_aggregate_function_max_min_by(AggregateFunctionSimpleFactory&
factory);
void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_count(AggregateFunctionSimpleFactory&
factory);
void register_aggregate_function_count_by_enum(AggregateFunctionSimpleFactory&
factory);
@@ -87,8 +86,7 @@ AggregateFunctionSimpleFactory&
AggregateFunctionSimpleFactory::instance() {
std::call_once(oc, [&]() {
register_aggregate_function_sum(instance);
register_aggregate_function_minmax(instance);
- register_aggregate_function_min_by(instance);
- register_aggregate_function_max_by(instance);
+ register_aggregate_function_max_min_by(instance);
register_aggregate_function_avg(instance);
register_aggregate_function_count(instance);
register_aggregate_function_count_by_enum(instance);
diff --git a/be/test/vec/aggregate_functions/agg_min_max_by_test.cpp
b/be/test/vec/aggregate_functions/agg_min_max_by_test.cpp
index e6c110bebad..1f16517c5c1 100644
--- a/be/test/vec/aggregate_functions/agg_min_max_by_test.cpp
+++ b/be/test/vec/aggregate_functions/agg_min_max_by_test.cpp
@@ -45,8 +45,7 @@ const int agg_test_batch_size = 4096;
namespace doris::vectorized {
// declare function
-void register_aggregate_function_min_by(AggregateFunctionSimpleFactory&
factory);
-void register_aggregate_function_max_by(AggregateFunctionSimpleFactory&
factory);
+void register_aggregate_function_max_min_by(AggregateFunctionSimpleFactory&
factory);
class AggMinMaxByTest : public ::testing::TestWithParam<std::string> {};
@@ -78,8 +77,7 @@ TEST_P(AggMinMaxByTest, min_max_by_test) {
// Prepare test function and parameters.
AggregateFunctionSimpleFactory factory;
- register_aggregate_function_min_by(factory);
- register_aggregate_function_max_by(factory);
+ register_aggregate_function_max_min_by(factory);
// Test on 2 kind of key types (int32, string).
for (int i = 0; i < 2; i++) {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MaxBy.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MaxBy.java
index cc0324bcf39..63ca76c4800 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MaxBy.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MaxBy.java
@@ -66,11 +66,16 @@ public class MaxBy extends NullableAggregateFunction
@Override
public void checkLegalityBeforeTypeCoercion() {
- if (getArgumentType(1).isOnlyMetricType()) {
+ // Metric-only types are rejected as the ordering key (argument 1), except
+ // arrays, which the BE now accepts as max_by keys.
+ if (getArgumentType(1).isOnlyMetricType() &&
!getArgumentType(1).isArrayType()) {
throw new AnalysisException(Type.OnlyMetricTypeErrorMsg);
}
}
+ /** Repeats the key-type legality check after the expression is rewritten. */
+ @Override
+ public void checkLegalityAfterRewrite() {
+ checkLegalityBeforeTypeCoercion();
+ }
+
/**
* withDistinctAndChildren.
*/
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MinBy.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MinBy.java
index 632abea0322..ab3f7a68e74 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MinBy.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MinBy.java
@@ -66,11 +66,16 @@ public class MinBy extends NullableAggregateFunction
@Override
public void checkLegalityBeforeTypeCoercion() {
- if (getArgumentType(1).isOnlyMetricType()) {
+ // Metric-only types are rejected as the ordering key (argument 1), except
+ // arrays, which the BE now accepts as min_by keys.
+ if (getArgumentType(1).isOnlyMetricType() &&
!getArgumentType(1).isArrayType()) {
throw new AnalysisException(Type.OnlyMetricTypeErrorMsg);
}
}
+ /** Repeats the key-type legality check after the expression is rewritten. */
+ @Override
+ public void checkLegalityAfterRewrite() {
+ checkLegalityBeforeTypeCoercion();
+ }
+
/**
* withDistinctAndChildren.
*/
diff --git
a/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions2.out
b/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions2.out
index 7df61ecbf6e..41ca9183b9c 100644
---
a/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions2.out
+++
b/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions2.out
@@ -375,13 +375,78 @@ true
-- !maxmin_array_1 --
\N \N
+-- !maxminby_array_1 --
+\N \N
+
+-- !maxminby_map_1 --
+\N \N
+
+-- !maxminby_struct_1 --
+\N \N
+
-- !maxmin_array_2 --
[7] [1, 2, 3]
-[1, 2, 3, 4] [1, 2]
+[1, 2, 5] [1, 2]
[11, 22, 33, 44] [3, 1]
[10] []
[11, null, null, 55] [1, null, null, 4]
+-- !maxminby_array_2 --
+[1, 2, 5] [11, 22, 33, 44]
+
+-- !maxminby_array_3 --
+{"A":10, "B":1} {"x":50, "y":60}
+
+-- !maxminby_array_4 --
+{"a":10, "b":"tt"} {"a":4, "b":"delta"}
+
+-- !maxminby_array_5 --
+1 [5, 6] [7]
+2 [1, 2, 5] [1, 2]
+3 [3, 1] [11, 22, 33, 44]
+4 [3, 1] [10]
+5 [1, null, 3, 4] [1, 2, 3, 4]
+
+-- !maxminby_map_2 --
+{"foo":1, "bar":2} {"A":10, "B":1}
+
+-- !maxminby_map_3 --
+1 {"k1":30, "k2":15} {"x":100, "y":200}
+2 {"foo":1, "bar":2} {"foo":2, "bar":1}
+3 {"key1":99, "key2":98} {"A":10, "B":1}
+4 {"key1":99, "key2":98} {"A":5, "B":10}
+5 {"A":null, "B":null} {"A":null, "B":5}
+
+-- !maxminby_map_4 --
+1 {"x":100, "y":200} {"k1":10, "k2":20}
+2 {"foo":1, "bar":2} {"foo":2, "bar":1}
+3 {"A":10, "B":1} {"key1":99, "key2":98}
+4 {"A":5, "B":10} {"x":50, "y":60}
+5 {"A":10, "B":5} {"A":null, "B":null}
+
+-- !maxminby_struct_2 --
+{"a":5, "b":"echo"} {"a":10, "b":"tt"}
+
+-- !maxminby_struct_3 --
+1 {"a":2, "b":"beta"} {"a":3, "b":"gamma"}
+2 {"a":5, "b":"echo"} {"a":6, "b":"zulu"}
+3 {"a":8, "b":"eight"} {"a":10, "b":"tt"}
+4 {"a":8, "b":"eight"} {"a":9, "b":"nine"}
+5 {"a":null, "b":null} {"a":null, "b":"ten"}
+
+-- !maxminby_struct_4 --
+1 {"a":3, "b":"gamma"} {"a":1, "b":"alpha"}
+2 {"a":5, "b":"echo"} {"a":6, "b":"zulu"}
+3 {"a":10, "b":"tt"} {"a":8, "b":"eight"}
+4 {"a":9, "b":"nine"} {"a":4, "b":"delta"}
+5 {"a":10, "b":"ten"} {"a":null, "b":null}
+
-- !maxmin_array_3 --
[[3, 4], [3, 4]] [[1, 2], [3, 4]]
+-- !maxminby_array_6 --
+3 1
+
+-- !maxminby_array_7 --
+[[3, 4], [3, 4]] [[1, 2], [3, 4]]
+
diff --git
a/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions2.groovy
b/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions2.groovy
index e2f186a61f3..3d9258f6568 100644
---
a/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions2.groovy
+++
b/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions2.groovy
@@ -287,31 +287,49 @@ suite("test_aggregate_all_functions2") {
"""
qt_maxmin_array_1 """SELECT max(arr), min(arr) from test_maxmin"""
+ qt_maxminby_array_1 """SELECT max_by(arr, weight), min_by(arr, weight)
from test_maxmin"""
+ qt_maxminby_map_1 """SELECT max_by(mp, weight), min_by(mp, weight) from
test_maxmin"""
+ qt_maxminby_struct_1 """SELECT max_by(st, weight), min_by(st, weight) from
test_maxmin"""
sql """
INSERT INTO test_maxmin (id, arr, mp, st, weight) VALUES
(1, [1,2,3], {"k1": 10, "k2": 20},
NAMED_STRUCT("a", 1, "b", "alpha"), 5),
(1, [5,6], {"k1": 30, "k2": 15},
NAMED_STRUCT("a", 2, "b", "beta"), 10),
(1, [7], {"x": 100, "y": 200},
NAMED_STRUCT("a", 3, "b", "gamma"), 3),
- (2, [1,2,3], {"foo": 1, "bar": 2},
NAMED_STRUCT("a", 5, "b", "echo"), 15),
+ (2, [1,2,5], {"foo": 1, "bar": 2},
NAMED_STRUCT("a", 5, "b", "echo"), 15),
(2, [1,2], {"foo": 2, "bar": 1},
NAMED_STRUCT("a", 6, "b", "zulu"), 7),
(2, [1,2,3,4], {"key1": -1, "key2": -5},
NAMED_STRUCT("a", 7, "b", "seven"), 12),
(3, [3,1], {"key1": 99, "key2": 98},
NAMED_STRUCT("a", 8, "b", "eight"), 9),
(3, [10], {"A": 5, "B": 10},
NAMED_STRUCT("a", 9, "b", "nine"), 6),
- (3, [11,22,33,44], {"A": 10, "B": 5},
NAMED_STRUCT("a", 10,"b", "ten"), 1),
+ (3, [11,22,33,44], {"A": 10, "B": 1},
NAMED_STRUCT("a", 10,"b", "tt"), 1),
(3, null, null, null,
17),
+ (3, null, null, null,
-1),
(4, [3,1], {"key1": 99, "key2": 98},
NAMED_STRUCT("a", 8, "b", "eight"), 9),
(4, [10], {"A": 5, "B": 10},
NAMED_STRUCT("a", 9, "b", "nine"), 6),
(4, [], {"x": 50, "y": 60},
NAMED_STRUCT("a", 4, "b", "delta"), 8),
- (5, [1,2,3,4], {"A": null, "B": 5},
NAMED_STRUCT("a", null,"b", "ten"), 1),
- (5, [1,2,null,4], {"A": 100, "B": null},
NAMED_STRUCT("a", 10,"b", null), 1),
- (5, [1,null,null,4], {"A": null, "B": null},
NAMED_STRUCT("a", null,"b", null), 1),
- (5, [1,null,3,4], {"A": null, "B": null},
NAMED_STRUCT("a", null,"b", null), 1),
- (5, [11,null,null,55], {"A": 10, "B": 5},
NAMED_STRUCT("a", 10,"b", "ten"), 1);
+ (5, [1,2,3,4], {"A": null, "B": 5},
NAMED_STRUCT("a", null,"b", "ten"), 2),
+ (5, [1,2,null,4], {"A": 100, "B": null},
NAMED_STRUCT("a", 10,"b", null), 3),
+ (5, [1,null,null,4], {"A": null, "B": null},
NAMED_STRUCT("a", null,"b", null), 4),
+ (5, [1,null,3,4], {"A": null, "B": null},
NAMED_STRUCT("a", null,"b", null), 5),
+ (5, [11,null,null,55], {"A": 10, "B": 5},
NAMED_STRUCT("a", 10,"b", "ten"), null);
"""
qt_maxmin_array_2 """SELECT max(arr), min(arr) from test_maxmin group by
id order by id"""
+ qt_maxminby_array_2 """SELECT max_by(arr, weight), min_by(arr, weight)
from test_maxmin"""
+ qt_maxminby_array_3 """SELECT max_by(mp, arr), min_by(mp, arr) from
test_maxmin"""
+ qt_maxminby_array_4 """SELECT max_by(st, arr), min_by(st, arr) from
test_maxmin"""
+ qt_maxminby_array_5 """SELECT id, max_by(arr, weight), min_by(arr, weight)
from test_maxmin group by id order by id"""
+
+ qt_maxminby_map_2 """SELECT max_by(mp, weight), min_by(mp, weight) from
test_maxmin"""
+ qt_maxminby_map_3 """SELECT id, max_by(mp, weight), min_by(mp, weight)
from test_maxmin group by id order by id"""
+ qt_maxminby_map_4 """SELECT id, max_by(mp, arr), min_by(mp, arr) from
test_maxmin group by id order by id"""
+
+ qt_maxminby_struct_2 """SELECT max_by(st, weight), min_by(st, weight) from
test_maxmin"""
+ qt_maxminby_struct_3 """SELECT id, max_by(st, weight), min_by(st, weight)
from test_maxmin group by id order by id"""
+ qt_maxminby_struct_4 """SELECT id, max_by(st, arr), min_by(st, arr) from
test_maxmin group by id order by id"""
+
+ sql "DROP TABLE IF EXISTS test_nested_maxmin";
sql """
CREATE TABLE test_nested_maxmin (
id INT,
@@ -330,4 +348,6 @@ suite("test_aggregate_all_functions2") {
"""
qt_maxmin_array_3 """SELECT max(arr), min(arr) from test_nested_maxmin"""
+ qt_maxminby_array_6 """SELECT max_by(weight, arr), min_by(weight, arr)
from test_nested_maxmin"""
+ qt_maxminby_array_7 """SELECT max_by(arr, weight), min_by(arr, weight)
from test_nested_maxmin"""
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]