This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new e08e53efb23 [refine](type)use new type dispatch for agg function (#57112)
e08e53efb23 is described below
commit e08e53efb237582c595a1d8a9fdd8bdb45552071
Author: Mryange <[email protected]>
AuthorDate: Sat Oct 25 17:06:04 2025 +0800
[refine](type)use new type dispatch for agg function (#57112)
---
.../aggregate_function_array_agg.cpp | 74 ++++----------------
.../aggregate_function_collect.cpp | 79 ++++------------------
.../aggregate_functions/aggregate_function_map.cpp | 59 +++-------------
3 files changed, 37 insertions(+), 175 deletions(-)
diff --git a/be/src/vec/aggregate_functions/aggregate_function_array_agg.cpp b/be/src/vec/aggregate_functions/aggregate_function_array_agg.cpp
index 4d974229ce7..5f0327c55a1 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_array_agg.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_array_agg.cpp
@@ -20,6 +20,7 @@
#include "vec/aggregate_functions/aggregate_function_collect.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/helpers.h"
+#include "vec/core/call_on_type_index.h"
namespace doris::vectorized {
#include "common/compile_check_begin.h"
@@ -43,69 +44,20 @@ AggregateFunctionPtr create_aggregate_function_array_agg(const std::string& name
const DataTypes&
argument_types,
const bool
result_is_nullable,
const
AggregateFunctionAttr& attr) {
- switch (argument_types[0]->get_primitive_type()) {
- case PrimitiveType::TYPE_BOOLEAN:
- return do_create_agg_function_collect<TYPE_BOOLEAN>(argument_types,
result_is_nullable,
- attr);
- case PrimitiveType::TYPE_TINYINT:
- return do_create_agg_function_collect<TYPE_TINYINT>(argument_types,
result_is_nullable,
- attr);
- case PrimitiveType::TYPE_SMALLINT:
- return do_create_agg_function_collect<TYPE_SMALLINT>(argument_types,
result_is_nullable,
- attr);
- case PrimitiveType::TYPE_INT:
- return do_create_agg_function_collect<TYPE_INT>(argument_types,
result_is_nullable, attr);
- case PrimitiveType::TYPE_BIGINT:
- return do_create_agg_function_collect<TYPE_BIGINT>(argument_types,
result_is_nullable,
- attr);
- case PrimitiveType::TYPE_LARGEINT:
- return do_create_agg_function_collect<TYPE_LARGEINT>(argument_types,
result_is_nullable,
- attr);
- case PrimitiveType::TYPE_FLOAT:
- return do_create_agg_function_collect<TYPE_FLOAT>(argument_types,
result_is_nullable, attr);
- case PrimitiveType::TYPE_DOUBLE:
- return do_create_agg_function_collect<TYPE_DOUBLE>(argument_types,
result_is_nullable,
- attr);
- case PrimitiveType::TYPE_DECIMAL32:
- return do_create_agg_function_collect<TYPE_DECIMAL32>(argument_types,
result_is_nullable,
- attr);
- case PrimitiveType::TYPE_DECIMAL64:
- return do_create_agg_function_collect<TYPE_DECIMAL64>(argument_types,
result_is_nullable,
- attr);
- case PrimitiveType::TYPE_DECIMAL128I:
- return
do_create_agg_function_collect<TYPE_DECIMAL128I>(argument_types,
result_is_nullable,
- attr);
- case PrimitiveType::TYPE_DECIMALV2:
- return do_create_agg_function_collect<TYPE_DECIMALV2>(argument_types,
result_is_nullable,
- attr);
- case PrimitiveType::TYPE_DECIMAL256:
- return do_create_agg_function_collect<TYPE_DECIMAL256>(argument_types,
result_is_nullable,
- attr);
- case PrimitiveType::TYPE_DATE:
- return do_create_agg_function_collect<TYPE_DATE>(argument_types,
result_is_nullable, attr);
- case PrimitiveType::TYPE_DATETIME:
- return do_create_agg_function_collect<TYPE_DATETIME>(argument_types,
result_is_nullable,
- attr);
- case PrimitiveType::TYPE_DATEV2:
- return do_create_agg_function_collect<TYPE_DATEV2>(argument_types,
result_is_nullable,
- attr);
- case PrimitiveType::TYPE_DATETIMEV2:
- return do_create_agg_function_collect<TYPE_DATETIMEV2>(argument_types,
result_is_nullable,
- attr);
- case PrimitiveType::TYPE_IPV4:
- return do_create_agg_function_collect<TYPE_IPV4>(argument_types,
result_is_nullable, attr);
- case PrimitiveType::TYPE_IPV6:
- return do_create_agg_function_collect<TYPE_IPV6>(argument_types,
result_is_nullable, attr);
- case PrimitiveType::TYPE_STRING:
- case PrimitiveType::TYPE_CHAR:
- case PrimitiveType::TYPE_VARCHAR:
- return do_create_agg_function_collect<TYPE_VARCHAR>(argument_types,
result_is_nullable,
- attr);
- default:
+ AggregateFunctionPtr agg_fn;
+ auto call = [&](const auto& type) -> bool {
+ using DispatcType = std::decay_t<decltype(type)>;
+ agg_fn =
do_create_agg_function_collect<DispatcType::PType>(argument_types,
+
result_is_nullable, attr);
+ return true;
+ };
+
+ if (!dispatch_switch_all(argument_types[0]->get_primitive_type(), call)) {
// We do not care what the real type is.
- return do_create_agg_function_collect<INVALID_TYPE>(argument_types,
result_is_nullable,
- attr);
+ agg_fn = do_create_agg_function_collect<INVALID_TYPE>(argument_types,
result_is_nullable,
+ attr);
}
+ return agg_fn;
}
void register_aggregate_function_array_agg(AggregateFunctionSimpleFactory&
factory) {
diff --git a/be/src/vec/aggregate_functions/aggregate_function_collect.cpp b/be/src/vec/aggregate_functions/aggregate_function_collect.cpp
index 91531a76747..5d84641c37c 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_collect.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_collect.cpp
@@ -22,6 +22,7 @@
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/factory_helpers.h"
#include "vec/aggregate_functions/helpers.h"
+#include "vec/core/call_on_type_index.h"
namespace doris::vectorized {
#include "common/compile_check_begin.h"
@@ -55,74 +56,20 @@ AggregateFunctionPtr create_aggregate_function_collect_impl(const std::string& n
const
AggregateFunctionAttr& attr) {
bool distinct = name == "collect_set";
- switch (argument_types[0]->get_primitive_type()) {
- case PrimitiveType::TYPE_BOOLEAN:
- return do_create_agg_function_collect<TYPE_BOOLEAN,
HasLimit>(distinct, argument_types,
-
result_is_nullable, attr);
- case PrimitiveType::TYPE_TINYINT:
- return do_create_agg_function_collect<TYPE_TINYINT,
HasLimit>(distinct, argument_types,
-
result_is_nullable, attr);
- case PrimitiveType::TYPE_SMALLINT:
- return do_create_agg_function_collect<TYPE_SMALLINT,
HasLimit>(distinct, argument_types,
-
result_is_nullable, attr);
- case PrimitiveType::TYPE_INT:
- return do_create_agg_function_collect<TYPE_INT, HasLimit>(distinct,
argument_types,
-
result_is_nullable, attr);
- case PrimitiveType::TYPE_BIGINT:
- return do_create_agg_function_collect<TYPE_BIGINT, HasLimit>(distinct,
argument_types,
-
result_is_nullable, attr);
- case PrimitiveType::TYPE_LARGEINT:
- return do_create_agg_function_collect<TYPE_LARGEINT,
HasLimit>(distinct, argument_types,
-
result_is_nullable, attr);
- case PrimitiveType::TYPE_FLOAT:
- return do_create_agg_function_collect<TYPE_FLOAT, HasLimit>(distinct,
argument_types,
-
result_is_nullable, attr);
- case PrimitiveType::TYPE_DOUBLE:
- return do_create_agg_function_collect<TYPE_DOUBLE, HasLimit>(distinct,
argument_types,
-
result_is_nullable, attr);
- case PrimitiveType::TYPE_DECIMAL32:
- return do_create_agg_function_collect<TYPE_DECIMAL32,
HasLimit>(distinct, argument_types,
-
result_is_nullable, attr);
- case PrimitiveType::TYPE_DECIMAL64:
- return do_create_agg_function_collect<TYPE_DECIMAL64,
HasLimit>(distinct, argument_types,
-
result_is_nullable, attr);
- case PrimitiveType::TYPE_DECIMALV2:
- return do_create_agg_function_collect<TYPE_DECIMALV2,
HasLimit>(distinct, argument_types,
-
result_is_nullable, attr);
- case PrimitiveType::TYPE_DECIMAL128I:
- return do_create_agg_function_collect<TYPE_DECIMAL128I,
HasLimit>(distinct, argument_types,
-
result_is_nullable, attr);
- case PrimitiveType::TYPE_DECIMAL256:
- return do_create_agg_function_collect<TYPE_DECIMAL256,
HasLimit>(distinct, argument_types,
-
result_is_nullable, attr);
- case PrimitiveType::TYPE_DATE:
- return do_create_agg_function_collect<TYPE_DATE, HasLimit>(distinct,
argument_types,
-
result_is_nullable, attr);
- case PrimitiveType::TYPE_DATETIME:
- return do_create_agg_function_collect<TYPE_DATETIME,
HasLimit>(distinct, argument_types,
-
result_is_nullable, attr);
- case PrimitiveType::TYPE_DATEV2:
- return do_create_agg_function_collect<TYPE_DATEV2, HasLimit>(distinct,
argument_types,
-
result_is_nullable, attr);
- case PrimitiveType::TYPE_DATETIMEV2:
- return do_create_agg_function_collect<TYPE_DATETIMEV2,
HasLimit>(distinct, argument_types,
-
result_is_nullable, attr);
- case PrimitiveType::TYPE_IPV6:
- return do_create_agg_function_collect<TYPE_IPV6, HasLimit>(distinct,
argument_types,
-
result_is_nullable, attr);
- case PrimitiveType::TYPE_IPV4:
- return do_create_agg_function_collect<TYPE_IPV4, HasLimit>(distinct,
argument_types,
-
result_is_nullable, attr);
- case PrimitiveType::TYPE_STRING:
- case PrimitiveType::TYPE_CHAR:
- case PrimitiveType::TYPE_VARCHAR:
- return do_create_agg_function_collect<TYPE_VARCHAR,
HasLimit>(distinct, argument_types,
-
result_is_nullable, attr);
- default:
+ AggregateFunctionPtr agg_fn;
+ auto call = [&](const auto& type) -> bool {
+ using DispatcType = std::decay_t<decltype(type)>;
+ agg_fn = do_create_agg_function_collect<DispatcType::PType, HasLimit>(
+ distinct, argument_types, result_is_nullable, attr);
+ return true;
+ };
+
+ if (!dispatch_switch_all(argument_types[0]->get_primitive_type(), call)) {
// We do not care what the real type is.
- return do_create_agg_function_collect<INVALID_TYPE,
HasLimit>(distinct, argument_types,
-
result_is_nullable, attr);
+ agg_fn = do_create_agg_function_collect<INVALID_TYPE,
HasLimit>(distinct, argument_types,
+
result_is_nullable, attr);
}
+ return agg_fn;
}
AggregateFunctionPtr create_aggregate_function_collect(const std::string& name,
diff --git a/be/src/vec/aggregate_functions/aggregate_function_map.cpp b/be/src/vec/aggregate_functions/aggregate_function_map.cpp
index 1d35b877c32..c78520a0866 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_map.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_map.cpp
@@ -18,6 +18,7 @@
#include "vec/aggregate_functions/aggregate_function_map.h"
#include "vec/aggregate_functions/helpers.h"
+#include "vec/core/call_on_type_index.h"
namespace doris::vectorized {
#include "common/compile_check_begin.h"
@@ -35,58 +36,20 @@ AggregateFunctionPtr create_aggregate_function_map_agg(const std::string& name,
const DataTypes&
argument_types,
const bool
result_is_nullable,
const
AggregateFunctionAttr& attr) {
- switch (argument_types[0]->get_primitive_type()) {
- case PrimitiveType::TYPE_BOOLEAN:
- return create_agg_function_map_agg<TYPE_BOOLEAN>(argument_types,
result_is_nullable, attr);
- case PrimitiveType::TYPE_TINYINT:
- return create_agg_function_map_agg<TYPE_TINYINT>(argument_types,
result_is_nullable, attr);
- case PrimitiveType::TYPE_SMALLINT:
- return create_agg_function_map_agg<TYPE_SMALLINT>(argument_types,
result_is_nullable, attr);
- case PrimitiveType::TYPE_INT:
- return create_agg_function_map_agg<TYPE_INT>(argument_types,
result_is_nullable, attr);
- case PrimitiveType::TYPE_BIGINT:
- return create_agg_function_map_agg<TYPE_BIGINT>(argument_types,
result_is_nullable, attr);
- case PrimitiveType::TYPE_LARGEINT:
- return create_agg_function_map_agg<TYPE_LARGEINT>(argument_types,
result_is_nullable, attr);
- case PrimitiveType::TYPE_FLOAT:
- return create_agg_function_map_agg<TYPE_FLOAT>(argument_types,
result_is_nullable, attr);
- case PrimitiveType::TYPE_DOUBLE:
- return create_agg_function_map_agg<TYPE_DOUBLE>(argument_types,
result_is_nullable, attr);
- case PrimitiveType::TYPE_DECIMAL32:
- return create_agg_function_map_agg<TYPE_DECIMAL32>(argument_types,
result_is_nullable,
- attr);
- case PrimitiveType::TYPE_DECIMAL64:
- return create_agg_function_map_agg<TYPE_DECIMAL64>(argument_types,
result_is_nullable,
- attr);
- case PrimitiveType::TYPE_DECIMAL128I:
- return create_agg_function_map_agg<TYPE_DECIMAL128I>(argument_types,
result_is_nullable,
- attr);
- case PrimitiveType::TYPE_DECIMALV2:
- return create_agg_function_map_agg<TYPE_DECIMALV2>(argument_types,
result_is_nullable,
- attr);
- case PrimitiveType::TYPE_DECIMAL256:
- return create_agg_function_map_agg<TYPE_DECIMAL256>(argument_types,
result_is_nullable,
- attr);
- case PrimitiveType::TYPE_STRING:
- return create_agg_function_map_agg<TYPE_STRING>(argument_types,
result_is_nullable, attr);
- case PrimitiveType::TYPE_CHAR:
- return create_agg_function_map_agg<TYPE_CHAR>(argument_types,
result_is_nullable, attr);
- case PrimitiveType::TYPE_VARCHAR:
- return create_agg_function_map_agg<TYPE_VARCHAR>(argument_types,
result_is_nullable, attr);
- case PrimitiveType::TYPE_DATE:
- return create_agg_function_map_agg<TYPE_DATE>(argument_types,
result_is_nullable, attr);
- case PrimitiveType::TYPE_DATETIME:
- return create_agg_function_map_agg<TYPE_DATETIME>(argument_types,
result_is_nullable, attr);
- case PrimitiveType::TYPE_DATEV2:
- return create_agg_function_map_agg<TYPE_DATEV2>(argument_types,
result_is_nullable, attr);
- case PrimitiveType::TYPE_DATETIMEV2:
- return create_agg_function_map_agg<TYPE_DATETIMEV2>(argument_types,
result_is_nullable,
- attr);
- default:
+ AggregateFunctionPtr agg_fn;
+ auto call = [&](const auto& type) -> bool {
+ using DispatcType = std::decay_t<decltype(type)>;
+ agg_fn =
create_agg_function_map_agg<DispatcType::PType>(argument_types,
result_is_nullable,
+ attr);
+ return true;
+ };
+
+ if (!dispatch_switch_all(argument_types[0]->get_primitive_type(), call)) {
LOG(WARNING) << fmt::format("unsupported input type {} for aggregate
function {}",
argument_types[0]->get_name(), name);
return nullptr;
}
+ return agg_fn;
}
void register_aggregate_function_map_agg(AggregateFunctionSimpleFactory&
factory) {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]