github-actions[bot] commented on code in PR #15339:
URL: https://github.com/apache/doris/pull/15339#discussion_r1111501259


##########
be/src/vec/aggregate_functions/aggregate_function_collect.cpp:
##########
@@ -18,78 +18,88 @@
 #include "vec/aggregate_functions/aggregate_function_collect.h"
 
 #include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/aggregate_functions/helpers.h"
 
 namespace doris::vectorized {
 
-template <typename T>
-AggregateFunctionPtr create_agg_function_collect(bool distinct, const 
DataTypes& argument_types) {
+#define FOR_DECIMAL_TYPES(M) \
+    M(Decimal32)             \
+    M(Decimal64)             \
+    M(Decimal128)            \
+    M(Decimal128I)
+
+template <typename T, typename HasLimit, typename... TArgs>
+AggregateFunctionPtr do_create_agg_function_collect(bool distinct, const 
DataTypePtr& argument_type,
+                                                    TArgs... args) {
     if (distinct) {
         return AggregateFunctionPtr(
-                new 
AggregateFunctionCollect<AggregateFunctionCollectSetData<T>>(argument_types));
+                new 
AggregateFunctionCollect<AggregateFunctionCollectSetData<T, HasLimit>,
+                                             HasLimit>(argument_type,
+                                                       
std::forward<TArgs>(args)...));
     } else {
         return AggregateFunctionPtr(
-                new 
AggregateFunctionCollect<AggregateFunctionCollectListData<T>>(argument_types));
+                new 
AggregateFunctionCollect<AggregateFunctionCollectListData<T, HasLimit>,
+                                             HasLimit>(argument_type,
+                                                       
std::forward<TArgs>(args)...));
     }
 }
 
-AggregateFunctionPtr create_aggregate_function_collect(const std::string& name,
-                                                       const DataTypes& 
argument_types,
-                                                       const bool 
result_is_nullable) {
-    if (argument_types.size() != 1) {
-        LOG(WARNING) << fmt::format("Illegal number {} of argument for 
aggregate function {}",
-                                    argument_types.size(), name);
-        return nullptr;
-    }
-
+template <typename HasLimit, typename... TArgs>
+AggregateFunctionPtr create_aggregate_function_collect_impl(const std::string& 
name,
+                                                            const DataTypePtr& 
argument_type,
+                                                            TArgs... args) {
     bool distinct = false;
     if (name == "collect_set") {
         distinct = true;
     }
 
-    WhichDataType type(argument_types[0]);
-    if (type.is_uint8()) {
-        return create_agg_function_collect<UInt8>(distinct, argument_types);
-    } else if (type.is_int8()) {
-        return create_agg_function_collect<Int8>(distinct, argument_types);
-    } else if (type.is_int16()) {
-        return create_agg_function_collect<Int16>(distinct, argument_types);
-    } else if (type.is_int32()) {
-        return create_agg_function_collect<Int32>(distinct, argument_types);
-    } else if (type.is_int64()) {
-        return create_agg_function_collect<Int64>(distinct, argument_types);
-    } else if (type.is_int128()) {
-        return create_agg_function_collect<Int128>(distinct, argument_types);
-    } else if (type.is_float32()) {
-        return create_agg_function_collect<Float32>(distinct, argument_types);
-    } else if (type.is_float64()) {
-        return create_agg_function_collect<Float64>(distinct, argument_types);
-    } else if (type.is_decimal32()) {
-        return create_agg_function_collect<Decimal32>(distinct, 
argument_types);
-    } else if (type.is_decimal64()) {
-        return create_agg_function_collect<Decimal64>(distinct, 
argument_types);
-    } else if (type.is_decimal128()) {
-        return create_agg_function_collect<Decimal128>(distinct, 
argument_types);
-    } else if (type.is_decimal128i()) {
-        return create_agg_function_collect<Decimal128I>(distinct, 
argument_types);
-    } else if (type.is_date()) {
-        return create_agg_function_collect<Int64>(distinct, argument_types);
-    } else if (type.is_date_time()) {
-        return create_agg_function_collect<Int64>(distinct, argument_types);
-    } else if (type.is_date_v2()) {
-        return create_agg_function_collect<UInt32>(distinct, argument_types);
-    } else if (type.is_date_time_v2()) {
-        return create_agg_function_collect<UInt64>(distinct, argument_types);
-    } else if (type.is_string()) {
-        return create_agg_function_collect<StringRef>(distinct, 
argument_types);
+    WhichDataType which(argument_type);
+#define DISPATCH(TYPE)                                                         
        \
+    if (which.idx == TypeIndex::TYPE)                                          
        \
+        return do_create_agg_function_collect<TYPE, HasLimit>(distinct, 
argument_type, \
+                                                              
std::forward<TArgs>(args)...);
+    FOR_NUMERIC_TYPES(DISPATCH)
+    FOR_DECIMAL_TYPES(DISPATCH)
+#undef DISPATCH
+    if (which.is_date_or_datetime()) {
+        return do_create_agg_function_collect<Int64, HasLimit>(distinct, 
argument_type,
+                                                               
std::forward<TArgs>(args)...);
+    } else if (which.is_date_v2()) {
+        return do_create_agg_function_collect<UInt32, HasLimit>(distinct, 
argument_type,
+                                                                
std::forward<TArgs>(args)...);
+    } else if (which.is_date_time_v2()) {
+        return do_create_agg_function_collect<UInt64, HasLimit>(distinct, 
argument_type,
+                                                                
std::forward<TArgs>(args)...);
+    } else if (which.is_string()) {
+        return do_create_agg_function_collect<StringRef, HasLimit>(distinct, 
argument_type,
+                                                                   
std::forward<TArgs>(args)...);
     }
 
     LOG(WARNING) << fmt::format("unsupported input type {} for aggregate 
function {}",
-                                argument_types[0]->get_name(), name);
+                                argument_type->get_name(), name);
+    return nullptr;
+}
+
+AggregateFunctionPtr create_aggregate_function_collect(const std::string& name,
+                                                       const DataTypes& 
argument_types,
+                                                       const bool 
result_is_nullable) {
+    if (argument_types.size() == 1) {
+        return create_aggregate_function_collect_impl<std::false_type>(name, 
argument_types[0],
+                                                                       
parameters);
+    }
+    if (argument_types.size() == 2) {
+        return create_aggregate_function_collect_impl<std::true_type>(name, 
argument_types[0],
+                                                                      
parameters);

Review Comment:
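   warning: use of undeclared identifier 'parameters' [clang-diagnostic-error]. The enclosing function `create_aggregate_function_collect` only declares `name`, `argument_types`, and `result_is_nullable`, so nothing named `parameters` is in scope at this `argument_types.size() == 2` call.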
   ```cpp
                                                                         
parameters);
                                                                         ^
   ```
   



##########
be/src/vec/aggregate_functions/aggregate_function_collect.cpp:
##########
@@ -18,78 +18,88 @@
 #include "vec/aggregate_functions/aggregate_function_collect.h"
 
 #include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/aggregate_functions/helpers.h"
 
 namespace doris::vectorized {
 
-template <typename T>
-AggregateFunctionPtr create_agg_function_collect(bool distinct, const 
DataTypes& argument_types) {
+#define FOR_DECIMAL_TYPES(M) \
+    M(Decimal32)             \
+    M(Decimal64)             \
+    M(Decimal128)            \
+    M(Decimal128I)
+
+template <typename T, typename HasLimit, typename... TArgs>
+AggregateFunctionPtr do_create_agg_function_collect(bool distinct, const 
DataTypePtr& argument_type,
+                                                    TArgs... args) {
     if (distinct) {
         return AggregateFunctionPtr(
-                new 
AggregateFunctionCollect<AggregateFunctionCollectSetData<T>>(argument_types));
+                new 
AggregateFunctionCollect<AggregateFunctionCollectSetData<T, HasLimit>,
+                                             HasLimit>(argument_type,
+                                                       
std::forward<TArgs>(args)...));
     } else {
         return AggregateFunctionPtr(
-                new 
AggregateFunctionCollect<AggregateFunctionCollectListData<T>>(argument_types));
+                new 
AggregateFunctionCollect<AggregateFunctionCollectListData<T, HasLimit>,
+                                             HasLimit>(argument_type,
+                                                       
std::forward<TArgs>(args)...));
     }
 }
 
-AggregateFunctionPtr create_aggregate_function_collect(const std::string& name,
-                                                       const DataTypes& 
argument_types,
-                                                       const bool 
result_is_nullable) {
-    if (argument_types.size() != 1) {
-        LOG(WARNING) << fmt::format("Illegal number {} of argument for 
aggregate function {}",
-                                    argument_types.size(), name);
-        return nullptr;
-    }
-
+template <typename HasLimit, typename... TArgs>
+AggregateFunctionPtr create_aggregate_function_collect_impl(const std::string& 
name,
+                                                            const DataTypePtr& 
argument_type,
+                                                            TArgs... args) {
     bool distinct = false;
     if (name == "collect_set") {
         distinct = true;
     }
 
-    WhichDataType type(argument_types[0]);
-    if (type.is_uint8()) {
-        return create_agg_function_collect<UInt8>(distinct, argument_types);
-    } else if (type.is_int8()) {
-        return create_agg_function_collect<Int8>(distinct, argument_types);
-    } else if (type.is_int16()) {
-        return create_agg_function_collect<Int16>(distinct, argument_types);
-    } else if (type.is_int32()) {
-        return create_agg_function_collect<Int32>(distinct, argument_types);
-    } else if (type.is_int64()) {
-        return create_agg_function_collect<Int64>(distinct, argument_types);
-    } else if (type.is_int128()) {
-        return create_agg_function_collect<Int128>(distinct, argument_types);
-    } else if (type.is_float32()) {
-        return create_agg_function_collect<Float32>(distinct, argument_types);
-    } else if (type.is_float64()) {
-        return create_agg_function_collect<Float64>(distinct, argument_types);
-    } else if (type.is_decimal32()) {
-        return create_agg_function_collect<Decimal32>(distinct, 
argument_types);
-    } else if (type.is_decimal64()) {
-        return create_agg_function_collect<Decimal64>(distinct, 
argument_types);
-    } else if (type.is_decimal128()) {
-        return create_agg_function_collect<Decimal128>(distinct, 
argument_types);
-    } else if (type.is_decimal128i()) {
-        return create_agg_function_collect<Decimal128I>(distinct, 
argument_types);
-    } else if (type.is_date()) {
-        return create_agg_function_collect<Int64>(distinct, argument_types);
-    } else if (type.is_date_time()) {
-        return create_agg_function_collect<Int64>(distinct, argument_types);
-    } else if (type.is_date_v2()) {
-        return create_agg_function_collect<UInt32>(distinct, argument_types);
-    } else if (type.is_date_time_v2()) {
-        return create_agg_function_collect<UInt64>(distinct, argument_types);
-    } else if (type.is_string()) {
-        return create_agg_function_collect<StringRef>(distinct, 
argument_types);
+    WhichDataType which(argument_type);
+#define DISPATCH(TYPE)                                                         
        \
+    if (which.idx == TypeIndex::TYPE)                                          
        \
+        return do_create_agg_function_collect<TYPE, HasLimit>(distinct, 
argument_type, \
+                                                              
std::forward<TArgs>(args)...);
+    FOR_NUMERIC_TYPES(DISPATCH)
+    FOR_DECIMAL_TYPES(DISPATCH)
+#undef DISPATCH
+    if (which.is_date_or_datetime()) {
+        return do_create_agg_function_collect<Int64, HasLimit>(distinct, 
argument_type,
+                                                               
std::forward<TArgs>(args)...);
+    } else if (which.is_date_v2()) {
+        return do_create_agg_function_collect<UInt32, HasLimit>(distinct, 
argument_type,
+                                                                
std::forward<TArgs>(args)...);
+    } else if (which.is_date_time_v2()) {
+        return do_create_agg_function_collect<UInt64, HasLimit>(distinct, 
argument_type,
+                                                                
std::forward<TArgs>(args)...);
+    } else if (which.is_string()) {
+        return do_create_agg_function_collect<StringRef, HasLimit>(distinct, 
argument_type,
+                                                                   
std::forward<TArgs>(args)...);
     }
 
     LOG(WARNING) << fmt::format("unsupported input type {} for aggregate 
function {}",
-                                argument_types[0]->get_name(), name);
+                                argument_type->get_name(), name);
+    return nullptr;
+}
+
+AggregateFunctionPtr create_aggregate_function_collect(const std::string& name,
+                                                       const DataTypes& 
argument_types,
+                                                       const bool 
result_is_nullable) {
+    if (argument_types.size() == 1) {
+        return create_aggregate_function_collect_impl<std::false_type>(name, 
argument_types[0],
+                                                                       
parameters);

Review Comment:
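   warning: use of undeclared identifier 'parameters' [clang-diagnostic-error]. The enclosing function `create_aggregate_function_collect` only declares `name`, `argument_types`, and `result_is_nullable`, so nothing named `parameters` is in scope at this `argument_types.size() == 1` call.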
   ```cpp
                                                                          
parameters);
                                                                          ^
   ```
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to