HappenLee commented on code in PR #16955:
URL: https://github.com/apache/doris/pull/16955#discussion_r1113769789


##########
be/src/vec/aggregate_functions/helpers.h:
##########
@@ -48,150 +55,145 @@ namespace doris::vectorized {
 
 /** Create an aggregate function with a numeric type in the template 
parameter, depending on the type of the argument.
   */
-template <template <typename> class AggregateFunctionTemplate, typename... 
TArgs>
-static IAggregateFunction* create_with_numeric_type(const IDataType& 
argument_type,
-                                                    TArgs&&... args) {
-    WhichDataType which(argument_type);
-#define DISPATCH(TYPE)                \
-    if (which.idx == TypeIndex::TYPE) \
-        return new 
AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...);
-    FOR_NUMERIC_TYPES(DISPATCH)
-#undef DISPATCH
-    return nullptr;
-}
-
-template <template <typename> class AggregateFunctionTemplate, typename... 
TArgs>
-static IAggregateFunction* create_with_numeric_type_null(const DataTypes& 
argument_types,
-                                                         TArgs&&... args) {
-    WhichDataType which(argument_types[0]);
-#define DISPATCH(TYPE)                                                         
             \
-    if (which.idx == TypeIndex::TYPE)                                          
             \
-        return new 
AggregateFunctionNullUnaryInline<AggregateFunctionTemplate<TYPE>, true>( \
-                new 
AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...),          \
-                argument_types);
-    FOR_NUMERIC_TYPES(DISPATCH)
-#undef DISPATCH
-    return nullptr;
-}
-
-template <template <typename, bool> class AggregateFunctionTemplate, bool 
bool_param,
-          typename... TArgs>
-static IAggregateFunction* create_with_numeric_type(const IDataType& 
argument_type,
-                                                    TArgs&&... args) {
-    WhichDataType which(argument_type);
-#define DISPATCH(TYPE)                \
-    if (which.idx == TypeIndex::TYPE) \
-        return new AggregateFunctionTemplate<TYPE, 
bool_param>(std::forward<TArgs>(args)...);
-    FOR_NUMERIC_TYPES(DISPATCH)
-#undef DISPATCH
-    return nullptr;
-}
-
-template <template <typename, typename> class AggregateFunctionTemplate, 
typename Data,
-          typename... TArgs>
-static IAggregateFunction* create_with_numeric_type(const IDataType& 
argument_type,
-                                                    TArgs&&... args) {
-    WhichDataType which(argument_type);
-#define DISPATCH(TYPE)                \
-    if (which.idx == TypeIndex::TYPE) \
-        return new AggregateFunctionTemplate<TYPE, 
Data>(std::forward<TArgs>(args)...);
-    FOR_NUMERIC_TYPES(DISPATCH)
-#undef DISPATCH
-    return nullptr;
-}
-
-template <template <typename, typename> class AggregateFunctionTemplate,
-          template <typename> class Data, typename... TArgs>
-static IAggregateFunction* create_with_numeric_type(const IDataType& 
argument_type,
-                                                    TArgs&&... args) {
-    WhichDataType which(argument_type);
-#define DISPATCH(TYPE)                \
-    if (which.idx == TypeIndex::TYPE) \
-        return new AggregateFunctionTemplate<TYPE, 
Data<TYPE>>(std::forward<TArgs>(args)...);
-    FOR_NUMERIC_TYPES(DISPATCH)
-#undef DISPATCH
-    return nullptr;
-}
-
+template <template <typename> class AggregateFunctionTemplate, typename Type>
+struct BuilerDirect {
+    using T = AggregateFunctionTemplate<Type>;
+};
 template <template <typename> class AggregateFunctionTemplate, template 
<typename> class Data,
-          typename... TArgs>
-static IAggregateFunction* create_with_numeric_type(const IDataType& 
argument_type,
-                                                    TArgs&&... args) {
-    WhichDataType which(argument_type);
-#define DISPATCH(TYPE)                \
-    if (which.idx == TypeIndex::TYPE) \
-        return new 
AggregateFunctionTemplate<Data<TYPE>>(std::forward<TArgs>(args)...);
-    FOR_NUMERIC_TYPES(DISPATCH)
-#undef DISPATCH
-    return nullptr;
-}
-
-template <template <typename> class AggregateFunctionTemplate, typename... 
TArgs>
-static IAggregateFunction* create_with_decimal_type(const IDataType& 
argument_type,
-                                                    TArgs&&... args) {
-    WhichDataType which(argument_type);
-#define DISPATCH(TYPE)                \
-    if (which.idx == TypeIndex::TYPE) \
-        return new 
AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...);
-    FOR_DECIMAL_TYPES(DISPATCH)
-#undef DISPATCH
-    return nullptr;
-}
-
-template <template <typename> class AggregateFunctionTemplate, typename... 
TArgs>
-static IAggregateFunction* create_with_decimal_type_null(const DataTypes& 
argument_types,
-                                                         TArgs&&... args) {
-    WhichDataType which(argument_types[0]);
-#define DISPATCH(TYPE)                                                         
             \
-    if (which.idx == TypeIndex::TYPE)                                          
             \
-        return new 
AggregateFunctionNullUnaryInline<AggregateFunctionTemplate<TYPE>, true>( \
-                new 
AggregateFunctionTemplate<TYPE>(std::forward<TArgs>(args)...),          \
-                argument_types);
-    FOR_DECIMAL_TYPES(DISPATCH)
-#undef DISPATCH
-    return nullptr;
-}
-
-template <template <typename, typename> class AggregateFunctionTemplate, 
typename Data,
-          typename... TArgs>
-static IAggregateFunction* create_with_decimal_type(const IDataType& 
argument_type,
-                                                    TArgs&&... args) {
-    WhichDataType which(argument_type);
-#define DISPATCH(TYPE)                \
-    if (which.idx == TypeIndex::TYPE) \
-        return new AggregateFunctionTemplate<TYPE, 
Data>(std::forward<TArgs>(args)...);
-    FOR_DECIMAL_TYPES(DISPATCH)
-#undef DISPATCH
-    return nullptr;
-}
-
-/** For template with two arguments.
-  */
-template <typename FirstType, template <typename, typename> class 
AggregateFunctionTemplate,
-          typename... TArgs>
-static IAggregateFunction* create_with_two_numeric_types_second(const 
IDataType& second_type,
-                                                                TArgs&&... 
args) {
-    WhichDataType which(second_type);
-#define DISPATCH(TYPE)                \
-    if (which.idx == TypeIndex::TYPE) \
-        return new AggregateFunctionTemplate<FirstType, 
TYPE>(std::forward<TArgs>(args)...);
-    FOR_NUMERIC_TYPES(DISPATCH)
-#undef DISPATCH
-    return nullptr;
-}
-
-template <template <typename, typename> class AggregateFunctionTemplate, 
typename... TArgs>
-static IAggregateFunction* create_with_two_numeric_types(const IDataType& 
first_type,
-                                                         const IDataType& 
second_type,
-                                                         TArgs&&... args) {
-    WhichDataType which(first_type);
-#define DISPATCH(TYPE)                                                         
       \
-    if (which.idx == TypeIndex::TYPE)                                          
       \
-        return create_with_two_numeric_types_second<TYPE, 
AggregateFunctionTemplate>( \
-                second_type, std::forward<TArgs>(args)...);
-    FOR_NUMERIC_TYPES(DISPATCH)
+          typename Type>
+struct BuilerData {
+    using T = AggregateFunctionTemplate<Data<Type>>;
+};
+template <template <typename, typename> class AggregateFunctionTemplate,
+          template <typename> class Data, typename Type>
+struct BuilerDirectAndData {
+    using T = AggregateFunctionTemplate<Type, Data<Type>>;
+};
+
+template <template <typename> class AggregateFunctionTemplate>
+struct CurryDirect {
+    template <typename Type>
+    using Builder = BuilerDirect<AggregateFunctionTemplate, Type>;
+};
+template <template <typename> class AggregateFunctionTemplate, template 
<typename> class Data>
+struct CurryData {
+    template <typename Type>
+    using Builder = BuilerData<AggregateFunctionTemplate, Data, Type>;
+};
+template <template <typename, typename> class AggregateFunctionTemplate,
+          template <typename> class Data>
+struct CurryBuilerDirectAndData {
+    template <typename Type>
+    using Builder = BuilerDirectAndData<AggregateFunctionTemplate, Data, Type>;
+};
+
+template <bool allow_integer, bool allow_float, bool allow_decimal, int 
define_index = 0>
+struct creator_with_type_base {
+    template <typename Class, typename... TArgs>
+    static IAggregateFunction* create_base(const bool result_is_nullable,
+                                           const DataTypes& argument_types, 
TArgs&&... args) {
+        WhichDataType which(remove_nullable(argument_types[define_index]));
+#define DISPATCH(TYPE)                                                         
                   \
+    if (which.idx == TypeIndex::TYPE) {                                        
                   \
+        using T = typename Class::template Builder<TYPE>::T;                   
                   \
+        if (argument_types[define_index]->is_nullable()) {                     
                   \
+            IAggregateFunction* result = nullptr;                              
                   \
+            if (argument_types.size() > 1) {                                   
                   \
+                std::visit(                                                    
                   \
+                        [&](auto result_is_nullable) {                         
                   \
+                            result = new 
AggregateFunctionNullVariadicInline<T,                   \
+                                                                             
result_is_nullable>( \
+                                    new T(std::forward<TArgs>(args)...,        
                   \
+                                          remove_nullable(argument_types)),    
                   \
+                                    argument_types);                           
                   \
+                        },                                                     
                   \
+                        make_bool_variant(result_is_nullable));                
                   \
+            } else {                                                           
                   \
+                std::visit(                                                    
                   \
+                        [&](auto result_is_nullable) {                         
                   \
+                            result = new AggregateFunctionNullUnaryInline<T, 
result_is_nullable>( \
+                                    new T(std::forward<TArgs>(args)...,        
                   \
+                                          remove_nullable(argument_types)),    
                   \
+                                    argument_types);                           
                   \
+                        },                                                     
                   \
+                        make_bool_variant(result_is_nullable));                
                   \
+            }                                                                  
                   \
+            return result;                                                     
                   \
+        } else {                                                               
                   \
+            return new T(std::forward<TArgs>(args)..., argument_types);        
                   \
+        }                                                                      
                   \
+    }
+
+        if constexpr (allow_integer) {
+            FOR_INTEGER_TYPES(DISPATCH);
+        }
+        if constexpr (allow_float) {
+            FOR_FLOAT_TYPES(DISPATCH);
+        }
+        if constexpr (allow_decimal) {
+            FOR_DECIMAL_TYPES(DISPATCH);
+        }
 #undef DISPATCH
-    return nullptr;
-}
-
+        return nullptr;
+    }
+
+    template <template <typename> class AggregateFunctionTemplate, typename... 
TArgs>
+    static IAggregateFunction* create(TArgs&&... args) {
+        return 
create_base<CurryDirect<AggregateFunctionTemplate>>(std::forward<TArgs>(args)...);
+    }
+
+    template <template <typename> class AggregateFunctionTemplate, template 
<typename> class Data,
+              typename... TArgs>
+    static IAggregateFunction* create(TArgs&&... args) {
+        return create_base<CurryData<AggregateFunctionTemplate, Data>>(
+                std::forward<TArgs>(args)...);
+    }
+
+    template <template <typename, typename> class AggregateFunctionTemplate,
+              template <typename> class Data, typename... TArgs>
+    static IAggregateFunction* create(TArgs&&... args) {
+        return create_base<CurryBuilerDirectAndData<AggregateFunctionTemplate, 
Data>>(
+                std::forward<TArgs>(args)...);
+    }
+};
+
+using creator_with_integer_type = creator_with_type_base<true, false, false>;
+using creator_with_numeric_type = creator_with_type_base<true, true, false>;
+using creator_with_decimal_type = creator_with_type_base<false, false, true>;
+using creator_with_type = creator_with_type_base<true, true, true>;
+
+template <int define_index = 0>
+struct creator_with_no_type {

Review Comment:
   Suggest naming this `not_number_type` (e.g. `creator_with_not_number_type`) instead of `creator_with_no_type`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to