This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new e08e53efb23 [refine](type)use new type dispatch for agg function (#57112)
e08e53efb23 is described below

commit e08e53efb237582c595a1d8a9fdd8bdb45552071
Author: Mryange <[email protected]>
AuthorDate: Sat Oct 25 17:06:04 2025 +0800

    [refine](type)use new type dispatch for agg function (#57112)
---
 .../aggregate_function_array_agg.cpp               | 74 ++++----------------
 .../aggregate_function_collect.cpp                 | 79 ++++------------------
 .../aggregate_functions/aggregate_function_map.cpp | 59 +++-------------
 3 files changed, 37 insertions(+), 175 deletions(-)

diff --git a/be/src/vec/aggregate_functions/aggregate_function_array_agg.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_array_agg.cpp
index 4d974229ce7..5f0327c55a1 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_array_agg.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_array_agg.cpp
@@ -20,6 +20,7 @@
 #include "vec/aggregate_functions/aggregate_function_collect.h"
 #include "vec/aggregate_functions/aggregate_function_simple_factory.h"
 #include "vec/aggregate_functions/helpers.h"
+#include "vec/core/call_on_type_index.h"
 
 namespace doris::vectorized {
 #include "common/compile_check_begin.h"
@@ -43,69 +44,20 @@ AggregateFunctionPtr 
create_aggregate_function_array_agg(const std::string& name
                                                          const DataTypes& 
argument_types,
                                                          const bool 
result_is_nullable,
                                                          const 
AggregateFunctionAttr& attr) {
-    switch (argument_types[0]->get_primitive_type()) {
-    case PrimitiveType::TYPE_BOOLEAN:
-        return do_create_agg_function_collect<TYPE_BOOLEAN>(argument_types, 
result_is_nullable,
-                                                            attr);
-    case PrimitiveType::TYPE_TINYINT:
-        return do_create_agg_function_collect<TYPE_TINYINT>(argument_types, 
result_is_nullable,
-                                                            attr);
-    case PrimitiveType::TYPE_SMALLINT:
-        return do_create_agg_function_collect<TYPE_SMALLINT>(argument_types, 
result_is_nullable,
-                                                             attr);
-    case PrimitiveType::TYPE_INT:
-        return do_create_agg_function_collect<TYPE_INT>(argument_types, 
result_is_nullable, attr);
-    case PrimitiveType::TYPE_BIGINT:
-        return do_create_agg_function_collect<TYPE_BIGINT>(argument_types, 
result_is_nullable,
-                                                           attr);
-    case PrimitiveType::TYPE_LARGEINT:
-        return do_create_agg_function_collect<TYPE_LARGEINT>(argument_types, 
result_is_nullable,
-                                                             attr);
-    case PrimitiveType::TYPE_FLOAT:
-        return do_create_agg_function_collect<TYPE_FLOAT>(argument_types, 
result_is_nullable, attr);
-    case PrimitiveType::TYPE_DOUBLE:
-        return do_create_agg_function_collect<TYPE_DOUBLE>(argument_types, 
result_is_nullable,
-                                                           attr);
-    case PrimitiveType::TYPE_DECIMAL32:
-        return do_create_agg_function_collect<TYPE_DECIMAL32>(argument_types, 
result_is_nullable,
-                                                              attr);
-    case PrimitiveType::TYPE_DECIMAL64:
-        return do_create_agg_function_collect<TYPE_DECIMAL64>(argument_types, 
result_is_nullable,
-                                                              attr);
-    case PrimitiveType::TYPE_DECIMAL128I:
-        return 
do_create_agg_function_collect<TYPE_DECIMAL128I>(argument_types, 
result_is_nullable,
-                                                                attr);
-    case PrimitiveType::TYPE_DECIMALV2:
-        return do_create_agg_function_collect<TYPE_DECIMALV2>(argument_types, 
result_is_nullable,
-                                                              attr);
-    case PrimitiveType::TYPE_DECIMAL256:
-        return do_create_agg_function_collect<TYPE_DECIMAL256>(argument_types, 
result_is_nullable,
-                                                               attr);
-    case PrimitiveType::TYPE_DATE:
-        return do_create_agg_function_collect<TYPE_DATE>(argument_types, 
result_is_nullable, attr);
-    case PrimitiveType::TYPE_DATETIME:
-        return do_create_agg_function_collect<TYPE_DATETIME>(argument_types, 
result_is_nullable,
-                                                             attr);
-    case PrimitiveType::TYPE_DATEV2:
-        return do_create_agg_function_collect<TYPE_DATEV2>(argument_types, 
result_is_nullable,
-                                                           attr);
-    case PrimitiveType::TYPE_DATETIMEV2:
-        return do_create_agg_function_collect<TYPE_DATETIMEV2>(argument_types, 
result_is_nullable,
-                                                               attr);
-    case PrimitiveType::TYPE_IPV4:
-        return do_create_agg_function_collect<TYPE_IPV4>(argument_types, 
result_is_nullable, attr);
-    case PrimitiveType::TYPE_IPV6:
-        return do_create_agg_function_collect<TYPE_IPV6>(argument_types, 
result_is_nullable, attr);
-    case PrimitiveType::TYPE_STRING:
-    case PrimitiveType::TYPE_CHAR:
-    case PrimitiveType::TYPE_VARCHAR:
-        return do_create_agg_function_collect<TYPE_VARCHAR>(argument_types, 
result_is_nullable,
-                                                            attr);
-    default:
+    AggregateFunctionPtr agg_fn;
+    auto call = [&](const auto& type) -> bool {
+        using DispatcType = std::decay_t<decltype(type)>;
+        agg_fn = 
do_create_agg_function_collect<DispatcType::PType>(argument_types,
+                                                                    
result_is_nullable, attr);
+        return true;
+    };
+
+    if (!dispatch_switch_all(argument_types[0]->get_primitive_type(), call)) {
         // We do not care what the real type is.
-        return do_create_agg_function_collect<INVALID_TYPE>(argument_types, 
result_is_nullable,
-                                                            attr);
+        agg_fn = do_create_agg_function_collect<INVALID_TYPE>(argument_types, 
result_is_nullable,
+                                                              attr);
     }
+    return agg_fn;
 }
 
 void register_aggregate_function_array_agg(AggregateFunctionSimpleFactory& 
factory) {
diff --git a/be/src/vec/aggregate_functions/aggregate_function_collect.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_collect.cpp
index 91531a76747..5d84641c37c 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_collect.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_collect.cpp
@@ -22,6 +22,7 @@
 #include "vec/aggregate_functions/aggregate_function_simple_factory.h"
 #include "vec/aggregate_functions/factory_helpers.h"
 #include "vec/aggregate_functions/helpers.h"
+#include "vec/core/call_on_type_index.h"
 
 namespace doris::vectorized {
 #include "common/compile_check_begin.h"
@@ -55,74 +56,20 @@ AggregateFunctionPtr 
create_aggregate_function_collect_impl(const std::string& n
                                                             const 
AggregateFunctionAttr& attr) {
     bool distinct = name == "collect_set";
 
-    switch (argument_types[0]->get_primitive_type()) {
-    case PrimitiveType::TYPE_BOOLEAN:
-        return do_create_agg_function_collect<TYPE_BOOLEAN, 
HasLimit>(distinct, argument_types,
-                                                                      
result_is_nullable, attr);
-    case PrimitiveType::TYPE_TINYINT:
-        return do_create_agg_function_collect<TYPE_TINYINT, 
HasLimit>(distinct, argument_types,
-                                                                      
result_is_nullable, attr);
-    case PrimitiveType::TYPE_SMALLINT:
-        return do_create_agg_function_collect<TYPE_SMALLINT, 
HasLimit>(distinct, argument_types,
-                                                                       
result_is_nullable, attr);
-    case PrimitiveType::TYPE_INT:
-        return do_create_agg_function_collect<TYPE_INT, HasLimit>(distinct, 
argument_types,
-                                                                  
result_is_nullable, attr);
-    case PrimitiveType::TYPE_BIGINT:
-        return do_create_agg_function_collect<TYPE_BIGINT, HasLimit>(distinct, 
argument_types,
-                                                                     
result_is_nullable, attr);
-    case PrimitiveType::TYPE_LARGEINT:
-        return do_create_agg_function_collect<TYPE_LARGEINT, 
HasLimit>(distinct, argument_types,
-                                                                       
result_is_nullable, attr);
-    case PrimitiveType::TYPE_FLOAT:
-        return do_create_agg_function_collect<TYPE_FLOAT, HasLimit>(distinct, 
argument_types,
-                                                                    
result_is_nullable, attr);
-    case PrimitiveType::TYPE_DOUBLE:
-        return do_create_agg_function_collect<TYPE_DOUBLE, HasLimit>(distinct, 
argument_types,
-                                                                     
result_is_nullable, attr);
-    case PrimitiveType::TYPE_DECIMAL32:
-        return do_create_agg_function_collect<TYPE_DECIMAL32, 
HasLimit>(distinct, argument_types,
-                                                                        
result_is_nullable, attr);
-    case PrimitiveType::TYPE_DECIMAL64:
-        return do_create_agg_function_collect<TYPE_DECIMAL64, 
HasLimit>(distinct, argument_types,
-                                                                        
result_is_nullable, attr);
-    case PrimitiveType::TYPE_DECIMALV2:
-        return do_create_agg_function_collect<TYPE_DECIMALV2, 
HasLimit>(distinct, argument_types,
-                                                                        
result_is_nullable, attr);
-    case PrimitiveType::TYPE_DECIMAL128I:
-        return do_create_agg_function_collect<TYPE_DECIMAL128I, 
HasLimit>(distinct, argument_types,
-                                                                          
result_is_nullable, attr);
-    case PrimitiveType::TYPE_DECIMAL256:
-        return do_create_agg_function_collect<TYPE_DECIMAL256, 
HasLimit>(distinct, argument_types,
-                                                                         
result_is_nullable, attr);
-    case PrimitiveType::TYPE_DATE:
-        return do_create_agg_function_collect<TYPE_DATE, HasLimit>(distinct, 
argument_types,
-                                                                   
result_is_nullable, attr);
-    case PrimitiveType::TYPE_DATETIME:
-        return do_create_agg_function_collect<TYPE_DATETIME, 
HasLimit>(distinct, argument_types,
-                                                                       
result_is_nullable, attr);
-    case PrimitiveType::TYPE_DATEV2:
-        return do_create_agg_function_collect<TYPE_DATEV2, HasLimit>(distinct, 
argument_types,
-                                                                     
result_is_nullable, attr);
-    case PrimitiveType::TYPE_DATETIMEV2:
-        return do_create_agg_function_collect<TYPE_DATETIMEV2, 
HasLimit>(distinct, argument_types,
-                                                                         
result_is_nullable, attr);
-    case PrimitiveType::TYPE_IPV6:
-        return do_create_agg_function_collect<TYPE_IPV6, HasLimit>(distinct, 
argument_types,
-                                                                   
result_is_nullable, attr);
-    case PrimitiveType::TYPE_IPV4:
-        return do_create_agg_function_collect<TYPE_IPV4, HasLimit>(distinct, 
argument_types,
-                                                                   
result_is_nullable, attr);
-    case PrimitiveType::TYPE_STRING:
-    case PrimitiveType::TYPE_CHAR:
-    case PrimitiveType::TYPE_VARCHAR:
-        return do_create_agg_function_collect<TYPE_VARCHAR, 
HasLimit>(distinct, argument_types,
-                                                                      
result_is_nullable, attr);
-    default:
+    AggregateFunctionPtr agg_fn;
+    auto call = [&](const auto& type) -> bool {
+        using DispatcType = std::decay_t<decltype(type)>;
+        agg_fn = do_create_agg_function_collect<DispatcType::PType, HasLimit>(
+                distinct, argument_types, result_is_nullable, attr);
+        return true;
+    };
+
+    if (!dispatch_switch_all(argument_types[0]->get_primitive_type(), call)) {
         // We do not care what the real type is.
-        return do_create_agg_function_collect<INVALID_TYPE, 
HasLimit>(distinct, argument_types,
-                                                                      
result_is_nullable, attr);
+        agg_fn = do_create_agg_function_collect<INVALID_TYPE, 
HasLimit>(distinct, argument_types,
+                                                                        
result_is_nullable, attr);
     }
+    return agg_fn;
 }
 
 AggregateFunctionPtr create_aggregate_function_collect(const std::string& name,
diff --git a/be/src/vec/aggregate_functions/aggregate_function_map.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_map.cpp
index 1d35b877c32..c78520a0866 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_map.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_map.cpp
@@ -18,6 +18,7 @@
 #include "vec/aggregate_functions/aggregate_function_map.h"
 
 #include "vec/aggregate_functions/helpers.h"
+#include "vec/core/call_on_type_index.h"
 
 namespace doris::vectorized {
 #include "common/compile_check_begin.h"
@@ -35,58 +36,20 @@ AggregateFunctionPtr 
create_aggregate_function_map_agg(const std::string& name,
                                                        const DataTypes& 
argument_types,
                                                        const bool 
result_is_nullable,
                                                        const 
AggregateFunctionAttr& attr) {
-    switch (argument_types[0]->get_primitive_type()) {
-    case PrimitiveType::TYPE_BOOLEAN:
-        return create_agg_function_map_agg<TYPE_BOOLEAN>(argument_types, 
result_is_nullable, attr);
-    case PrimitiveType::TYPE_TINYINT:
-        return create_agg_function_map_agg<TYPE_TINYINT>(argument_types, 
result_is_nullable, attr);
-    case PrimitiveType::TYPE_SMALLINT:
-        return create_agg_function_map_agg<TYPE_SMALLINT>(argument_types, 
result_is_nullable, attr);
-    case PrimitiveType::TYPE_INT:
-        return create_agg_function_map_agg<TYPE_INT>(argument_types, 
result_is_nullable, attr);
-    case PrimitiveType::TYPE_BIGINT:
-        return create_agg_function_map_agg<TYPE_BIGINT>(argument_types, 
result_is_nullable, attr);
-    case PrimitiveType::TYPE_LARGEINT:
-        return create_agg_function_map_agg<TYPE_LARGEINT>(argument_types, 
result_is_nullable, attr);
-    case PrimitiveType::TYPE_FLOAT:
-        return create_agg_function_map_agg<TYPE_FLOAT>(argument_types, 
result_is_nullable, attr);
-    case PrimitiveType::TYPE_DOUBLE:
-        return create_agg_function_map_agg<TYPE_DOUBLE>(argument_types, 
result_is_nullable, attr);
-    case PrimitiveType::TYPE_DECIMAL32:
-        return create_agg_function_map_agg<TYPE_DECIMAL32>(argument_types, 
result_is_nullable,
-                                                           attr);
-    case PrimitiveType::TYPE_DECIMAL64:
-        return create_agg_function_map_agg<TYPE_DECIMAL64>(argument_types, 
result_is_nullable,
-                                                           attr);
-    case PrimitiveType::TYPE_DECIMAL128I:
-        return create_agg_function_map_agg<TYPE_DECIMAL128I>(argument_types, 
result_is_nullable,
-                                                             attr);
-    case PrimitiveType::TYPE_DECIMALV2:
-        return create_agg_function_map_agg<TYPE_DECIMALV2>(argument_types, 
result_is_nullable,
-                                                           attr);
-    case PrimitiveType::TYPE_DECIMAL256:
-        return create_agg_function_map_agg<TYPE_DECIMAL256>(argument_types, 
result_is_nullable,
-                                                            attr);
-    case PrimitiveType::TYPE_STRING:
-        return create_agg_function_map_agg<TYPE_STRING>(argument_types, 
result_is_nullable, attr);
-    case PrimitiveType::TYPE_CHAR:
-        return create_agg_function_map_agg<TYPE_CHAR>(argument_types, 
result_is_nullable, attr);
-    case PrimitiveType::TYPE_VARCHAR:
-        return create_agg_function_map_agg<TYPE_VARCHAR>(argument_types, 
result_is_nullable, attr);
-    case PrimitiveType::TYPE_DATE:
-        return create_agg_function_map_agg<TYPE_DATE>(argument_types, 
result_is_nullable, attr);
-    case PrimitiveType::TYPE_DATETIME:
-        return create_agg_function_map_agg<TYPE_DATETIME>(argument_types, 
result_is_nullable, attr);
-    case PrimitiveType::TYPE_DATEV2:
-        return create_agg_function_map_agg<TYPE_DATEV2>(argument_types, 
result_is_nullable, attr);
-    case PrimitiveType::TYPE_DATETIMEV2:
-        return create_agg_function_map_agg<TYPE_DATETIMEV2>(argument_types, 
result_is_nullable,
-                                                            attr);
-    default:
+    AggregateFunctionPtr agg_fn;
+    auto call = [&](const auto& type) -> bool {
+        using DispatcType = std::decay_t<decltype(type)>;
+        agg_fn = 
create_agg_function_map_agg<DispatcType::PType>(argument_types, 
result_is_nullable,
+                                                                 attr);
+        return true;
+    };
+
+    if (!dispatch_switch_all(argument_types[0]->get_primitive_type(), call)) {
         LOG(WARNING) << fmt::format("unsupported input type {} for aggregate 
function {}",
                                     argument_types[0]->get_name(), name);
         return nullptr;
     }
+    return agg_fn;
 }
 
 void register_aggregate_function_map_agg(AggregateFunctionSimpleFactory& 
factory) {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to