This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 45ad297a1d [Enchancement](function) change aggregate function creator
to return AggregateFunctionPtr (#18025)
45ad297a1d is described below
commit 45ad297a1dd6c610dc3266f1b529727ca3ebbbe8
Author: Pxl <[email protected]>
AuthorDate: Sun Mar 26 11:41:34 2023 +0800
[Enchancement](function) change aggregate function creator to return
AggregateFunctionPtr (#18025)
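Change the creator helpers (creator_with_type, creator_without_type, creator_with_integer_type, ...) to return AggregateFunctionPtr, and change their create() argument order to (argument_types, result_is_nullable, ...).
Remove some wrapper functions and register the creators directly.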
---
.../aggregate_function_approx_count_distinct.cpp | 17 ++---
.../aggregate_functions/aggregate_function_avg.cpp | 19 +-----
.../aggregate_function_avg_weighted.cpp | 12 +---
.../aggregate_functions/aggregate_function_bit.cpp | 26 +++-----
.../aggregate_function_bitmap.cpp | 42 ++++---------
.../aggregate_function_collect.cpp | 14 ++---
.../aggregate_function_distinct.cpp | 24 +++----
.../aggregate_function_group_concat.cpp | 14 ++---
.../aggregate_function_histogram.cpp | 14 ++---
.../aggregate_function_hll_union_agg.cpp | 21 ++-----
.../aggregate_function_min_max.cpp | 73 ++++++++--------------
.../aggregate_function_min_max.h | 16 ++---
.../aggregate_function_min_max_by.cpp | 48 ++++++--------
.../aggregate_function_orthogonal_bitmap.cpp | 68 +++++---------------
.../aggregate_function_percentile_approx.cpp | 39 ++++--------
.../aggregate_function_reader.cpp | 13 ++--
.../aggregate_function_retention.cpp | 12 +---
.../aggregate_function_sequence_match.cpp | 22 +++----
.../aggregate_function_stddev.cpp | 59 ++++++++---------
.../aggregate_functions/aggregate_function_sum.cpp | 50 +--------------
.../aggregate_functions/aggregate_function_sum.h | 17 ++++-
.../aggregate_function_topn.cpp | 52 +++++++--------
.../aggregate_function_uniq.cpp | 27 +++-----
.../aggregate_function_window.cpp | 37 ++---------
.../aggregate_function_window_funnel.cpp | 12 ++--
be/src/vec/aggregate_functions/helpers.h | 56 ++++++++++++++---
.../functions/array/function_array_aggregation.cpp | 11 ++--
27 files changed, 285 insertions(+), 530 deletions(-)
diff --git
a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp
b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp
index 2c22586d43..3083b9c7b1 100644
---
a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp
+++
b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp
@@ -18,28 +18,21 @@
#include "vec/aggregate_functions/aggregate_function_approx_count_distinct.h"
#include "vec/aggregate_functions/helpers.h"
-#include "vec/utils/template_helpers.hpp"
namespace doris::vectorized {
AggregateFunctionPtr create_aggregate_function_approx_count_distinct(
const std::string& name, const DataTypes& argument_types, const bool
result_is_nullable) {
- AggregateFunctionPtr res = nullptr;
WhichDataType which(remove_nullable(argument_types[0]));
-#define DISPATCH(TYPE, COLUMN_TYPE)
\
- if (which.idx == TypeIndex::TYPE)
\
-
res.reset(creator_without_type::create<AggregateFunctionApproxCountDistinct<COLUMN_TYPE>>(
\
- result_is_nullable, argument_types));
+#define DISPATCH(TYPE, COLUMN_TYPE)
\
+ if (which.idx == TypeIndex::TYPE)
\
+ return
creator_without_type::create<AggregateFunctionApproxCountDistinct<COLUMN_TYPE>>(
\
+ argument_types, result_is_nullable);
TYPE_TO_COLUMN_TYPE(DISPATCH)
#undef DISPATCH
- if (!res) {
- LOG(WARNING) << fmt::format("Illegal type {} of argument for aggregate
function {}",
- argument_types[0]->get_name(), name);
- }
-
- return res;
+ return nullptr;
}
void
register_aggregate_function_approx_count_distinct(AggregateFunctionSimpleFactory&
factory) {
diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp
b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp
index 4f493c9529..f6fe08a9e3 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp
@@ -20,9 +20,7 @@
#include "vec/aggregate_functions/aggregate_function_avg.h"
-#include "common/logging.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
-#include "vec/aggregate_functions/factory_helpers.h"
#include "vec/aggregate_functions/helpers.h"
namespace doris::vectorized {
@@ -36,22 +34,7 @@ struct Avg {
template <typename T>
using AggregateFuncAvg = typename Avg<T>::Function;
-AggregateFunctionPtr create_aggregate_function_avg(const std::string& name,
- const DataTypes&
argument_types,
- const bool
result_is_nullable) {
- assert_unary(name, argument_types);
-
- AggregateFunctionPtr res(
- creator_with_type::create<AggregateFuncAvg>(result_is_nullable,
argument_types));
-
- if (!res) {
- LOG(WARNING) << fmt::format("Illegal type {} of argument for aggregate
function {}",
- argument_types[0]->get_name(), name);
- }
- return res;
-}
-
void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory) {
- factory.register_function_both("avg", create_aggregate_function_avg);
+ factory.register_function_both("avg",
creator_with_type::creator<AggregateFuncAvg>);
}
} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.cpp
b/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.cpp
index c81bf4b42f..fc5df5303f 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_avg_weighted.cpp
@@ -19,18 +19,10 @@
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/helpers.h"
-#include "vec/data_types/data_type_nullable.h"
namespace doris::vectorized {
-
-AggregateFunctionPtr create_aggregate_function_avg_weight(const std::string&
name,
- const DataTypes&
argument_types,
- const bool
result_is_nullable) {
- return
AggregateFunctionPtr(creator_with_type::create<AggregateFunctionAvgWeight>(
- result_is_nullable, argument_types));
-}
-
void register_aggregate_function_avg_weighted(AggregateFunctionSimpleFactory&
factory) {
- factory.register_function_both("avg_weighted",
create_aggregate_function_avg_weight);
+ factory.register_function_both("avg_weighted",
+
creator_with_type::creator<AggregateFunctionAvgWeight>);
}
} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_bit.cpp
b/be/src/vec/aggregate_functions/aggregate_function_bit.cpp
index bdc51daaf9..97a6c0e92f 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_bit.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_bit.cpp
@@ -25,28 +25,16 @@
namespace doris::vectorized {
-template <template <typename> class Data>
-AggregateFunctionPtr createAggregateFunctionBitwise(const std::string& name,
- const DataTypes&
argument_types,
- const bool
result_is_nullable) {
- AggregateFunctionPtr
res(creator_with_integer_type::create<AggregateFunctionBitwise, Data>(
- result_is_nullable, argument_types));
- if (res) {
- return res;
- }
-
- LOG(WARNING) << fmt::format("Illegal type " +
argument_types[0]->get_name() +
- " of argument for aggregate function " + name);
- return nullptr;
-}
-
void register_aggregate_function_bit(AggregateFunctionSimpleFactory& factory) {
- factory.register_function_both("group_bit_or",
-
createAggregateFunctionBitwise<AggregateFunctionGroupBitOrData>);
factory.register_function_both(
- "group_bit_and",
createAggregateFunctionBitwise<AggregateFunctionGroupBitAndData>);
+ "group_bit_or",
creator_with_integer_type::creator<AggregateFunctionBitwise,
+
AggregateFunctionGroupBitOrData>);
+ factory.register_function_both(
+ "group_bit_and",
creator_with_integer_type::creator<AggregateFunctionBitwise,
+
AggregateFunctionGroupBitAndData>);
factory.register_function_both(
- "group_bit_xor",
createAggregateFunctionBitwise<AggregateFunctionGroupBitXorData>);
+ "group_bit_xor",
creator_with_integer_type::creator<AggregateFunctionBitwise,
+
AggregateFunctionGroupBitXorData>);
}
} // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp
b/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp
index 546dbc56fa..e896ab2996 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_bitmap.cpp
@@ -35,30 +35,6 @@ IAggregateFunction* create_with_int_data_type(const
DataTypes& argument_type) {
return nullptr;
}
-AggregateFunctionPtr create_aggregate_function_bitmap_union(const std::string&
name,
- const DataTypes&
argument_types,
- const bool
result_is_nullable) {
- return AggregateFunctionPtr(
-
creator_without_type::create<AggregateFunctionBitmapOp<AggregateFunctionBitmapUnionOp>>(
- result_is_nullable, argument_types));
-}
-
-AggregateFunctionPtr create_aggregate_function_bitmap_intersect(const
std::string& name,
- const
DataTypes& argument_types,
- const bool
result_is_nullable) {
- return AggregateFunctionPtr(creator_without_type::create<
-
AggregateFunctionBitmapOp<AggregateFunctionBitmapIntersectOp>>(
- result_is_nullable, argument_types));
-}
-
-AggregateFunctionPtr create_aggregate_function_group_bitmap_xor(const
std::string& name,
- const
DataTypes& argument_types,
- const bool
result_is_nullable) {
- return AggregateFunctionPtr(creator_without_type::create<
-
AggregateFunctionBitmapOp<AggregateFunctionGroupBitmapXorOp>>(
- result_is_nullable, argument_types));
-}
-
AggregateFunctionPtr create_aggregate_function_bitmap_union_count(const
std::string& name,
const
DataTypes& argument_types,
const bool
result_is_nullable) {
@@ -75,18 +51,26 @@ AggregateFunctionPtr
create_aggregate_function_bitmap_union_int(const std::strin
const bool
result_is_nullable) {
const bool arg_is_nullable = argument_types[0]->is_nullable();
if (arg_is_nullable) {
- return std::shared_ptr<IAggregateFunction>(
+ return AggregateFunctionPtr(
create_with_int_data_type<true,
AggregateFunctionBitmapCount>(argument_types));
} else {
- return std::shared_ptr<IAggregateFunction>(
+ return AggregateFunctionPtr(
create_with_int_data_type<false,
AggregateFunctionBitmapCount>(argument_types));
}
}
void register_aggregate_function_bitmap(AggregateFunctionSimpleFactory&
factory) {
- factory.register_function_both("bitmap_union",
create_aggregate_function_bitmap_union);
- factory.register_function_both("bitmap_intersect",
create_aggregate_function_bitmap_intersect);
- factory.register_function_both("group_bitmap_xor",
create_aggregate_function_group_bitmap_xor);
+ factory.register_function_both(
+ "bitmap_union", creator_without_type::creator<
+
AggregateFunctionBitmapOp<AggregateFunctionBitmapUnionOp>>);
+ factory.register_function_both(
+ "bitmap_intersect",
+ creator_without_type::creator<
+
AggregateFunctionBitmapOp<AggregateFunctionBitmapIntersectOp>>);
+ factory.register_function_both(
+ "group_bitmap_xor",
+ creator_without_type::creator<
+
AggregateFunctionBitmapOp<AggregateFunctionGroupBitmapXorOp>>);
factory.register_function_both("bitmap_union_count",
create_aggregate_function_bitmap_union_count);
factory.register_function_both("bitmap_union_int",
create_aggregate_function_bitmap_union_int);
diff --git a/be/src/vec/aggregate_functions/aggregate_function_collect.cpp
b/be/src/vec/aggregate_functions/aggregate_function_collect.cpp
index 110618581e..957ce4f3fa 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_collect.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_collect.cpp
@@ -26,15 +26,13 @@ template <typename T, typename HasLimit>
AggregateFunctionPtr do_create_agg_function_collect(bool distinct, const
DataTypes& argument_types,
const bool
result_is_nullable) {
if (distinct) {
- return AggregateFunctionPtr(
- creator_without_type::create<AggregateFunctionCollect<
- AggregateFunctionCollectSetData<T, HasLimit>,
HasLimit>>(result_is_nullable,
-
argument_types));
+ return creator_without_type::create<
+ AggregateFunctionCollect<AggregateFunctionCollectSetData<T,
HasLimit>, HasLimit>>(
+ argument_types, result_is_nullable);
} else {
- return AggregateFunctionPtr(
- creator_without_type::create<AggregateFunctionCollect<
- AggregateFunctionCollectListData<T, HasLimit>,
HasLimit>>(
- result_is_nullable, argument_types));
+ return creator_without_type::create<
+ AggregateFunctionCollect<AggregateFunctionCollectListData<T,
HasLimit>, HasLimit>>(
+ argument_types, result_is_nullable);
}
}
diff --git a/be/src/vec/aggregate_functions/aggregate_function_distinct.cpp
b/be/src/vec/aggregate_functions/aggregate_function_distinct.cpp
index a98a1b8e19..4d57c689b7 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_distinct.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_distinct.cpp
@@ -23,9 +23,6 @@
#include "vec/aggregate_functions/aggregate_function_combinator.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/helpers.h"
-#include "vec/common/typeid_cast.h"
-#include "vec/data_types/data_type.h"
-#include "vec/data_types/data_type_nullable.h"
namespace doris::vectorized {
@@ -53,26 +50,25 @@ public:
AggregateFunctionPtr res(
creator_with_numeric_type::create<AggregateFunctionDistinct,
AggregateFunctionDistinctSingleNumericData>(
- result_is_nullable, arguments, nested_function));
+ arguments, result_is_nullable, nested_function));
if (res) {
return res;
}
if
(arguments[0]->is_value_unambiguously_represented_in_contiguous_memory_region())
{
-
res.reset(creator_without_type::create<AggregateFunctionDistinct<
-
AggregateFunctionDistinctSingleGenericData<true>>>(
- result_is_nullable, arguments, nested_function));
+ res = creator_without_type::create<AggregateFunctionDistinct<
+ AggregateFunctionDistinctSingleGenericData<true>>>(
+ arguments, result_is_nullable, nested_function);
} else {
-
res.reset(creator_without_type::create<AggregateFunctionDistinct<
-
AggregateFunctionDistinctSingleGenericData<false>>>(
- result_is_nullable, arguments, nested_function));
+ res = creator_without_type::create<AggregateFunctionDistinct<
+ AggregateFunctionDistinctSingleGenericData<false>>>(
+ arguments, result_is_nullable, nested_function);
}
return res;
}
- return AggregateFunctionPtr(
- creator_without_type::create<
-
AggregateFunctionDistinct<AggregateFunctionDistinctMultipleGenericData>>(
- result_is_nullable, arguments, nested_function));
+ return creator_without_type::create<
+
AggregateFunctionDistinct<AggregateFunctionDistinctMultipleGenericData>>(
+ arguments, result_is_nullable, nested_function);
}
};
diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp
b/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp
index 5bd070ada3..92a840949f 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_group_concat.cpp
@@ -27,15 +27,13 @@ AggregateFunctionPtr
create_aggregate_function_group_concat(const std::string& n
const DataTypes&
argument_types,
const bool
result_is_nullable) {
if (argument_types.size() == 1) {
- return AggregateFunctionPtr(
- creator_without_type::create<
-
AggregateFunctionGroupConcat<AggregateFunctionGroupConcatImplStr>>(
- result_is_nullable, argument_types));
+ return creator_without_type::create<
+
AggregateFunctionGroupConcat<AggregateFunctionGroupConcatImplStr>>(
+ argument_types, result_is_nullable);
} else if (argument_types.size() == 2) {
- return AggregateFunctionPtr(
- creator_without_type::create<
-
AggregateFunctionGroupConcat<AggregateFunctionGroupConcatImplStrStr>>(
- result_is_nullable, argument_types));
+ return creator_without_type::create<
+
AggregateFunctionGroupConcat<AggregateFunctionGroupConcatImplStrStr>>(
+ argument_types, result_is_nullable);
}
LOG(WARNING) << fmt::format("Illegal number {} of argument for aggregate
function {}",
diff --git a/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp
b/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp
index 8fee319e4b..b396862074 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp
@@ -28,15 +28,13 @@ AggregateFunctionPtr create_agg_function_histogram(const
DataTypes& argument_typ
bool has_input_param = (argument_types.size() == 2);
if (has_input_param) {
- return AggregateFunctionPtr(
- creator_without_type::create<
-
AggregateFunctionHistogram<AggregateFunctionHistogramData<T>, T, true>>(
- result_is_nullable, argument_types));
+ return creator_without_type::create<
+ AggregateFunctionHistogram<AggregateFunctionHistogramData<T>,
T, true>>(
+ argument_types, result_is_nullable);
} else {
- return AggregateFunctionPtr(
- creator_without_type::create<
-
AggregateFunctionHistogram<AggregateFunctionHistogramData<T>, T, false>>(
- result_is_nullable, argument_types));
+ return creator_without_type::create<
+ AggregateFunctionHistogram<AggregateFunctionHistogramData<T>,
T, false>>(
+ argument_types, result_is_nullable);
}
}
diff --git
a/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.cpp
b/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.cpp
index dc575e6d25..ec3488e4a1 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.cpp
@@ -18,27 +18,18 @@
#include "vec/aggregate_functions/aggregate_function_hll_union_agg.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
-#include "vec/aggregate_functions/factory_helpers.h"
#include "vec/aggregate_functions/helpers.h"
namespace doris::vectorized {
-template <template <typename> class Impl>
-AggregateFunctionPtr create_aggregate_function_HLL(const std::string& name,
- const DataTypes&
argument_types,
- const bool
result_is_nullable) {
- assert_arity_at_most<1>(name, argument_types);
- return AggregateFunctionPtr(
-
creator_without_type::create<AggregateFunctionHLLUnion<Impl<AggregateFunctionHLLData>>>(
- result_is_nullable, argument_types));
-}
-
void register_aggregate_function_HLL_union_agg(AggregateFunctionSimpleFactory&
factory) {
- factory.register_function_both("hll_union_agg",
-
create_aggregate_function_HLL<AggregateFunctionHLLUnionAggImpl>);
+ factory.register_function_both(
+ "hll_union_agg",
creator_without_type::creator<AggregateFunctionHLLUnion<
+
AggregateFunctionHLLUnionAggImpl<AggregateFunctionHLLData>>>);
- factory.register_function_both("hll_union",
-
create_aggregate_function_HLL<AggregateFunctionHLLUnionImpl>);
+ factory.register_function_both(
+ "hll_union",
creator_without_type::creator<AggregateFunctionHLLUnion<
+
AggregateFunctionHLLUnionImpl<AggregateFunctionHLLData>>>);
factory.register_alias("hll_union", "hll_raw_agg");
}
diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp
b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp
index 3f3b7dc727..d09beb1d49 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_min_max.cpp
@@ -26,21 +26,21 @@
namespace doris::vectorized {
/// min, max, any
-template <template <typename> class AggregateFunctionTemplate, template
<typename> class Data>
-IAggregateFunction* create_aggregate_function_single_value(const String& name,
- const DataTypes&
argument_types,
- const bool
result_is_nullable) {
+template <template <typename> class Data>
+AggregateFunctionPtr create_aggregate_function_single_value(const String& name,
+ const DataTypes&
argument_types,
+ const bool
result_is_nullable) {
assert_unary(name, argument_types);
- IAggregateFunction*
res(creator_with_numeric_type::create<AggregateFunctionTemplate, Data,
-
SingleValueDataFixed>(
- result_is_nullable, argument_types));
+ AggregateFunctionPtr
res(creator_with_numeric_type::create<AggregateFunctionsSingleValue, Data,
+
SingleValueDataFixed>(
+ argument_types, result_is_nullable));
if (res) {
return res;
}
- res = creator_with_decimal_type::create<AggregateFunctionTemplate, Data,
-
SingleValueDataDecimal>(result_is_nullable,
-
argument_types);
+ res = creator_with_decimal_type::create<AggregateFunctionsSingleValue,
Data,
+
SingleValueDataDecimal>(argument_types,
+
result_is_nullable);
if (res) {
return res;
}
@@ -48,58 +48,35 @@ IAggregateFunction*
create_aggregate_function_single_value(const String& name,
WhichDataType which(argument_type);
if (which.idx == TypeIndex::String) {
- return
creator_without_type::create<AggregateFunctionTemplate<Data<SingleValueDataString>>>(
- result_is_nullable, argument_types);
+ return creator_without_type::create<
+
AggregateFunctionsSingleValue<Data<SingleValueDataString>>>(argument_types,
+
result_is_nullable);
}
if (which.idx == TypeIndex::DateTime || which.idx == TypeIndex::Date) {
return creator_without_type::create<
-
AggregateFunctionTemplate<Data<SingleValueDataFixed<Int64>>>>(result_is_nullable,
-
argument_types);
+
AggregateFunctionsSingleValue<Data<SingleValueDataFixed<Int64>>>>(
+ argument_types, result_is_nullable);
}
if (which.idx == TypeIndex::DateV2) {
return creator_without_type::create<
-
AggregateFunctionTemplate<Data<SingleValueDataFixed<UInt32>>>>(result_is_nullable,
-
argument_types);
+
AggregateFunctionsSingleValue<Data<SingleValueDataFixed<UInt32>>>>(
+ argument_types, result_is_nullable);
}
if (which.idx == TypeIndex::DateTimeV2) {
return creator_without_type::create<
-
AggregateFunctionTemplate<Data<SingleValueDataFixed<UInt64>>>>(result_is_nullable,
-
argument_types);
+
AggregateFunctionsSingleValue<Data<SingleValueDataFixed<UInt64>>>>(
+ argument_types, result_is_nullable);
}
return nullptr;
}
-AggregateFunctionPtr create_aggregate_function_max(const std::string& name,
- const DataTypes&
argument_types,
- const bool
result_is_nullable) {
- return AggregateFunctionPtr(
-
create_aggregate_function_single_value<AggregateFunctionsSingleValue,
-
AggregateFunctionMaxData>(name, argument_types,
-
result_is_nullable));
-}
-
-AggregateFunctionPtr create_aggregate_function_min(const std::string& name,
- const DataTypes&
argument_types,
- const bool
result_is_nullable) {
- return AggregateFunctionPtr(
-
create_aggregate_function_single_value<AggregateFunctionsSingleValue,
-
AggregateFunctionMinData>(name, argument_types,
-
result_is_nullable));
-}
-
-AggregateFunctionPtr create_aggregate_function_any(const std::string& name,
- const DataTypes&
argument_types,
- const bool
result_is_nullable) {
- return AggregateFunctionPtr(
-
create_aggregate_function_single_value<AggregateFunctionsSingleValue,
-
AggregateFunctionAnyData>(name, argument_types,
-
result_is_nullable));
-}
-
void register_aggregate_function_minmax(AggregateFunctionSimpleFactory&
factory) {
- factory.register_function_both("max", create_aggregate_function_max);
- factory.register_function_both("min", create_aggregate_function_min);
- factory.register_function_both("any", create_aggregate_function_any);
+ factory.register_function_both(
+ "max",
create_aggregate_function_single_value<AggregateFunctionMaxData>);
+ factory.register_function_both(
+ "min",
create_aggregate_function_single_value<AggregateFunctionMinData>);
+ factory.register_function_both(
+ "any",
create_aggregate_function_single_value<AggregateFunctionAnyData>);
factory.register_alias("any", "any_value");
}
diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max.h
b/be/src/vec/aggregate_functions/aggregate_function_min_max.h
index af99e9223d..f4692bc818 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_min_max.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_min_max.h
@@ -620,16 +620,8 @@ public:
}
};
-AggregateFunctionPtr create_aggregate_function_max(const std::string& name,
- const DataTypes&
argument_types,
- const bool
result_is_nullable);
-
-AggregateFunctionPtr create_aggregate_function_min(const std::string& name,
- const DataTypes&
argument_types,
- const bool
result_is_nullable);
-
-AggregateFunctionPtr create_aggregate_function_any(const std::string& name,
- const DataTypes&
argument_types,
- const bool
result_is_nullable);
-
+template <template <typename> class Data>
+AggregateFunctionPtr create_aggregate_function_single_value(const String& name,
+ const DataTypes&
argument_types,
+ const bool
result_is_nullable);
} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp
b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp
index 1526975b9e..7d0dd084be 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp
@@ -28,15 +28,15 @@ namespace doris::vectorized {
/// min_by, max_by
template <template <typename> class AggregateFunctionTemplate,
template <typename, typename> class Data, typename VT>
-IAggregateFunction* create_aggregate_function_min_max_by_impl(const DataTypes&
argument_types,
- const bool
result_is_nullable) {
+AggregateFunctionPtr create_aggregate_function_min_max_by_impl(const
DataTypes& argument_types,
+ const bool
result_is_nullable) {
WhichDataType which(remove_nullable(argument_types[1]));
#define DISPATCH(TYPE)
\
if (which.idx == TypeIndex::TYPE)
\
return creator_without_type::create<
\
AggregateFunctionTemplate<Data<VT,
SingleValueDataFixed<TYPE>>>>( \
- result_is_nullable, argument_types);
+ argument_types, result_is_nullable);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
@@ -44,29 +44,29 @@ IAggregateFunction*
create_aggregate_function_min_max_by_impl(const DataTypes& a
if (which.idx == TypeIndex::TYPE)
\
return creator_without_type::create<
\
AggregateFunctionTemplate<Data<VT,
SingleValueDataDecimal<TYPE>>>>( \
- result_is_nullable, argument_types);
+ argument_types, result_is_nullable);
FOR_DECIMAL_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::String) {
return creator_without_type::create<
- AggregateFunctionTemplate<Data<VT,
SingleValueDataString>>>(result_is_nullable,
-
argument_types);
+ AggregateFunctionTemplate<Data<VT,
SingleValueDataString>>>(argument_types,
+
result_is_nullable);
}
if (which.idx == TypeIndex::DateTime || which.idx == TypeIndex::Date) {
return creator_without_type::create<
AggregateFunctionTemplate<Data<VT,
SingleValueDataFixed<Int64>>>>(
- result_is_nullable, argument_types);
+ argument_types, result_is_nullable);
}
if (which.idx == TypeIndex::DateV2) {
return creator_without_type::create<
AggregateFunctionTemplate<Data<VT,
SingleValueDataFixed<UInt32>>>>(
- result_is_nullable, argument_types);
+ argument_types, result_is_nullable);
}
if (which.idx == TypeIndex::DateTimeV2) {
return creator_without_type::create<
AggregateFunctionTemplate<Data<VT,
SingleValueDataFixed<UInt64>>>>(
- result_is_nullable, argument_types);
+ argument_types, result_is_nullable);
}
return nullptr;
}
@@ -74,9 +74,9 @@ IAggregateFunction*
create_aggregate_function_min_max_by_impl(const DataTypes& a
/// min_by, max_by
template <template <typename> class AggregateFunctionTemplate,
template <typename, typename> class Data>
-IAggregateFunction* create_aggregate_function_min_max_by(const String& name,
- const DataTypes&
argument_types,
- const bool
result_is_nullable) {
+AggregateFunctionPtr create_aggregate_function_min_max_by(const String& name,
+ const DataTypes&
argument_types,
+ const bool
result_is_nullable) {
assert_binary(name, argument_types);
WhichDataType which(remove_nullable(argument_types[0]));
@@ -119,25 +119,13 @@ IAggregateFunction*
create_aggregate_function_min_max_by(const String& name,
return nullptr;
}
-AggregateFunctionPtr create_aggregate_function_max_by(const std::string& name,
- const DataTypes&
argument_types,
- const bool
result_is_nullable) {
- return
AggregateFunctionPtr(create_aggregate_function_min_max_by<AggregateFunctionsMinMaxBy,
-
AggregateFunctionMaxByData>(
- name, argument_types, result_is_nullable));
-}
-
-AggregateFunctionPtr create_aggregate_function_min_by(const std::string& name,
- const DataTypes&
argument_types,
- const bool
result_is_nullable) {
- return
AggregateFunctionPtr(create_aggregate_function_min_max_by<AggregateFunctionsMinMaxBy,
-
AggregateFunctionMinByData>(
- name, argument_types, result_is_nullable));
-}
-
void register_aggregate_function_min_max_by(AggregateFunctionSimpleFactory&
factory) {
- factory.register_function_both("max_by", create_aggregate_function_max_by);
- factory.register_function_both("min_by", create_aggregate_function_min_by);
+ factory.register_function_both(
+ "max_by",
create_aggregate_function_min_max_by<AggregateFunctionsMinMaxBy,
+
AggregateFunctionMaxByData>);
+ factory.register_function_both(
+ "min_by",
create_aggregate_function_min_max_by<AggregateFunctionsMinMaxBy,
+
AggregateFunctionMinByData>);
}
} // namespace doris::vectorized
diff --git
a/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp
b/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp
index b1ef0eb9d1..40876e2b87 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp
@@ -17,8 +17,6 @@
#include "vec/aggregate_functions/aggregate_function_orthogonal_bitmap.h"
-#include <memory>
-
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/helpers.h"
#include "vec/data_types/data_type_string.h"
@@ -40,14 +38,13 @@ AggregateFunctionPtr
create_aggregate_function_orthogonal(const std::string& nam
AggregateFunctionPtr res(
creator_with_type_base<true, true, false,
1>::create<AggFunctionOrthBitmapFunc,
-
Impl>(result_is_nullable,
-
argument_types));
+
Impl>(argument_types,
+
result_is_nullable));
if (res) {
return res;
} else if (which.is_string_or_fixed_string()) {
- res.reset(
-
creator_without_type::create<AggFunctionOrthBitmapFunc<Impl<std::string_view>>>(
- result_is_nullable, argument_types));
+ res =
creator_without_type::create<AggFunctionOrthBitmapFunc<Impl<std::string_view>>>(
+ argument_types, result_is_nullable);
return res;
}
@@ -58,57 +55,20 @@ AggregateFunctionPtr
create_aggregate_function_orthogonal(const std::string& nam
}
}
-AggregateFunctionPtr create_aggregate_function_orthogonal_bitmap_intersect(
- const std::string& name, const DataTypes& argument_types, bool
result_is_nullable) {
- return create_aggregate_function_orthogonal<AggOrthBitMapIntersect>(name,
argument_types,
-
result_is_nullable);
-}
-
-AggregateFunctionPtr
create_aggregate_function_orthogonal_bitmap_intersect_count(
- const std::string& name, const DataTypes& argument_types, bool
result_is_nullable) {
- return
create_aggregate_function_orthogonal<AggOrthBitMapIntersectCount>(name,
argument_types,
-
result_is_nullable);
-}
-
-AggregateFunctionPtr
create_aggregate_function_orthogonal_bitmap_expr_calculate(
- const std::string& name, const DataTypes& argument_types, bool
result_is_nullable) {
- return create_aggregate_function_orthogonal<AggOrthBitMapExprCal>(name,
argument_types,
-
result_is_nullable);
-}
-
-AggregateFunctionPtr
create_aggregate_function_orthogonal_bitmap_expr_calculate_count(
- const std::string& name, const DataTypes& argument_types, bool
result_is_nullable) {
- return
create_aggregate_function_orthogonal<AggOrthBitMapExprCalCount>(name,
argument_types,
-
result_is_nullable);
-}
-
-AggregateFunctionPtr create_aggregate_function_intersect_count(const
std::string& name,
- const
DataTypes& argument_types,
-
- bool
result_is_nullable) {
- return create_aggregate_function_orthogonal<AggIntersectCount>(name,
argument_types,
-
result_is_nullable);
-}
-
-AggregateFunctionPtr create_aggregate_function_orthogonal_bitmap_union_count(
- const std::string& name, const DataTypes& argument_types, const bool
result_is_nullable) {
- return
create_aggregate_function_orthogonal<OrthBitmapUnionCountData>(name,
argument_types,
-
result_is_nullable);
-}
-
void
register_aggregate_function_orthogonal_bitmap(AggregateFunctionSimpleFactory&
factory) {
factory.register_function_both("orthogonal_bitmap_intersect",
-
create_aggregate_function_orthogonal_bitmap_intersect);
- factory.register_function_both("orthogonal_bitmap_intersect_count",
-
create_aggregate_function_orthogonal_bitmap_intersect_count);
+
create_aggregate_function_orthogonal<AggOrthBitMapIntersect>);
+ factory.register_function_both(
+ "orthogonal_bitmap_intersect_count",
+ create_aggregate_function_orthogonal<AggOrthBitMapIntersectCount>);
factory.register_function_both("orthogonal_bitmap_union_count",
-
create_aggregate_function_orthogonal_bitmap_union_count);
- factory.register_function_both("intersect_count",
create_aggregate_function_intersect_count);
+
create_aggregate_function_orthogonal<OrthBitmapUnionCountData>);
+ factory.register_function_both("intersect_count",
+
create_aggregate_function_orthogonal<AggIntersectCount>);
factory.register_function_both("orthogonal_bitmap_expr_calculate",
-
create_aggregate_function_orthogonal_bitmap_expr_calculate);
- factory.register_function_both(
- "orthogonal_bitmap_expr_calculate_count",
- create_aggregate_function_orthogonal_bitmap_expr_calculate_count);
+
create_aggregate_function_orthogonal<AggOrthBitMapExprCal>);
+ factory.register_function_both("orthogonal_bitmap_expr_calculate_count",
+
create_aggregate_function_orthogonal<AggOrthBitMapExprCalCount>);
}
} // namespace doris::vectorized
diff --git
a/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.cpp
b/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.cpp
index 3dfe11388b..1c160e7289 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_percentile_approx.cpp
@@ -28,40 +28,25 @@ AggregateFunctionPtr
create_aggregate_function_percentile_approx(const std::stri
const
DataTypes& argument_types,
const bool
result_is_nullable) {
if (argument_types.size() == 1) {
- return AggregateFunctionPtr(
-
creator_without_type::create<AggregateFunctionPercentileApproxMerge<is_nullable>>(
- result_is_nullable, remove_nullable(argument_types)));
+ return
creator_without_type::create<AggregateFunctionPercentileApproxMerge<is_nullable>>(
+ remove_nullable(argument_types), result_is_nullable);
} else if (argument_types.size() == 2) {
- return AggregateFunctionPtr(creator_without_type::create<
-
AggregateFunctionPercentileApproxTwoParams<is_nullable>>(
- result_is_nullable, remove_nullable(argument_types)));
+ return creator_without_type::create<
+ AggregateFunctionPercentileApproxTwoParams<is_nullable>>(
+ remove_nullable(argument_types), result_is_nullable);
} else if (argument_types.size() == 3) {
- return AggregateFunctionPtr(creator_without_type::create<
-
AggregateFunctionPercentileApproxThreeParams<is_nullable>>(
- result_is_nullable, remove_nullable(argument_types)));
+ return creator_without_type::create<
+ AggregateFunctionPercentileApproxThreeParams<is_nullable>>(
+ remove_nullable(argument_types), result_is_nullable);
}
- LOG(WARNING) << fmt::format("Illegal number {} of argument for aggregate
function {}",
- argument_types.size(), name);
return nullptr;
}
-AggregateFunctionPtr create_aggregate_function_percentile(const std::string&
name,
- const DataTypes&
argument_types,
- const bool
result_is_nullable) {
- return
AggregateFunctionPtr(creator_without_type::create<AggregateFunctionPercentile>(
- result_is_nullable, argument_types));
-}
-
-AggregateFunctionPtr create_aggregate_function_percentile_array(const
std::string& name,
- const
DataTypes& argument_types,
- const bool
result_is_nullable) {
- return
AggregateFunctionPtr(creator_without_type::create<AggregateFunctionPercentileArray>(
- result_is_nullable, argument_types));
-}
-
void register_aggregate_function_percentile(AggregateFunctionSimpleFactory&
factory) {
- factory.register_function_both("percentile",
create_aggregate_function_percentile);
- factory.register_function_both("percentile_array",
create_aggregate_function_percentile_array);
+ factory.register_function_both("percentile",
+
creator_without_type::creator<AggregateFunctionPercentile>);
+ factory.register_function_both("percentile_array",
+
creator_without_type::creator<AggregateFunctionPercentileArray>);
}
void
register_aggregate_function_percentile_approx(AggregateFunctionSimpleFactory&
factory) {
diff --git a/be/src/vec/aggregate_functions/aggregate_function_reader.cpp
b/be/src/vec/aggregate_functions/aggregate_function_reader.cpp
index 46384cf48a..034eac8313 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_reader.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_reader.cpp
@@ -28,12 +28,15 @@ void
register_aggregate_function_reader_load(AggregateFunctionSimpleFactory& fac
factory.register_function_both(name + AGG_LOAD_SUFFIX, creator);
};
- register_function_both("sum", create_aggregate_function_sum_reader);
- register_function_both("max", create_aggregate_function_max);
- register_function_both("min", create_aggregate_function_min);
- register_function_both("bitmap_union",
create_aggregate_function_bitmap_union);
+ register_function_both("sum",
creator_with_type::creator<AggregateFunctionSumSimpleReader>);
+ register_function_both("max",
create_aggregate_function_single_value<AggregateFunctionMaxData>);
+ register_function_both("min",
create_aggregate_function_single_value<AggregateFunctionMinData>);
+ register_function_both("bitmap_union",
+ creator_without_type::creator<
+
AggregateFunctionBitmapOp<AggregateFunctionBitmapUnionOp>>);
register_function_both("hll_union",
-
create_aggregate_function_HLL<AggregateFunctionHLLUnionImpl>);
+
creator_without_type::creator<AggregateFunctionHLLUnion<
+
AggregateFunctionHLLUnionImpl<AggregateFunctionHLLData>>>);
register_function_both("quantile_union",
create_aggregate_function_quantile_state_union);
}
diff --git a/be/src/vec/aggregate_functions/aggregate_function_retention.cpp
b/be/src/vec/aggregate_functions/aggregate_function_retention.cpp
index c57c1d075c..38dd8f9de6 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_retention.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_retention.cpp
@@ -18,19 +18,11 @@
#include "vec/aggregate_functions/aggregate_function_retention.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
-#include "vec/aggregate_functions/factory_helpers.h"
#include "vec/aggregate_functions/helpers.h"
namespace doris::vectorized {
-
-AggregateFunctionPtr create_aggregate_function_retention(const std::string&
name,
- const DataTypes&
argument_types,
- const bool
result_is_nullable) {
- return
AggregateFunctionPtr(creator_without_type::create<AggregateFunctionRetention>(
- result_is_nullable, argument_types));
-}
-
void register_aggregate_function_retention(AggregateFunctionSimpleFactory&
factory) {
- factory.register_function_both("retention",
create_aggregate_function_retention);
+ factory.register_function_both("retention",
+
creator_without_type::creator<AggregateFunctionRetention>);
}
} // namespace doris::vectorized
\ No newline at end of file
diff --git
a/be/src/vec/aggregate_functions/aggregate_function_sequence_match.cpp
b/be/src/vec/aggregate_functions/aggregate_function_sequence_match.cpp
index ce8db857b3..4f157fa107 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_sequence_match.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_sequence_match.cpp
@@ -19,7 +19,6 @@
#include "common/logging.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
-#include "vec/aggregate_functions/factory_helpers.h"
#include "vec/aggregate_functions/helpers.h"
namespace doris::vectorized {
@@ -41,21 +40,18 @@ AggregateFunctionPtr
create_aggregate_function_sequence_base(const std::string&
}
if (WhichDataType(remove_nullable(argument_types[1])).is_date_time_v2()) {
- return AggregateFunctionPtr(creator_without_type::create<
-
AggregateFunction<DateV2Value<DateTimeV2ValueType>, UInt64>>(
- result_is_nullable, argument_types));
+ return creator_without_type::create<
+ AggregateFunction<DateV2Value<DateTimeV2ValueType>,
UInt64>>(argument_types,
+
result_is_nullable);
} else if
(WhichDataType(remove_nullable(argument_types[1])).is_date_time()) {
- return AggregateFunctionPtr(
-
creator_without_type::create<AggregateFunction<VecDateTimeValue, Int64>>(
- result_is_nullable, argument_types));
+ return
creator_without_type::create<AggregateFunction<VecDateTimeValue, Int64>>(
+ argument_types, result_is_nullable);
} else if (WhichDataType(remove_nullable(argument_types[1])).is_date_v2())
{
- return AggregateFunctionPtr(creator_without_type::create<
-
AggregateFunction<DateV2Value<DateV2ValueType>, UInt32>>(
- result_is_nullable, argument_types));
- } else {
- LOG(WARNING) << "Only support Date and DateTime type as timestamp
argument!";
- return nullptr;
+ return creator_without_type::create<
+ AggregateFunction<DateV2Value<DateV2ValueType>,
UInt32>>(argument_types,
+
result_is_nullable);
}
+ return nullptr;
}
void
register_aggregate_function_sequence_match(AggregateFunctionSimpleFactory&
factory) {
diff --git a/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp
b/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp
index 9eca714d24..7e96c2e5a4 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_stddev.cpp
@@ -19,79 +19,72 @@
#include "common/logging.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
-#include "vec/aggregate_functions/factory_helpers.h"
#include "vec/aggregate_functions/helpers.h"
namespace doris::vectorized {
template <template <typename, bool> class AggregateFunctionTemplate,
template <typename> class NameData, template <typename, typename>
class Data,
bool is_stddev, bool is_nullable = false>
-IAggregateFunction* create_function_single_value(const String& name,
- const DataTypes&
argument_types,
- const bool result_is_nullable,
- bool custom_nullable) {
- IAggregateFunction* res = nullptr;
+AggregateFunctionPtr create_function_single_value(const String& name,
+ const DataTypes&
argument_types,
+ const bool
result_is_nullable,
+ bool custom_nullable) {
WhichDataType which(remove_nullable(argument_types[0]));
-#define DISPATCH(TYPE)
\
- if (which.idx == TypeIndex::TYPE)
\
- res = creator_without_type::create<AggregateFunctionTemplate<
\
- NameData<Data<TYPE, BaseData<TYPE, is_stddev>>>,
is_nullable>>( \
- result_is_nullable,
\
- custom_nullable ? remove_nullable(argument_types) :
argument_types);
+#define DISPATCH(TYPE)
\
+ if (which.idx == TypeIndex::TYPE)
\
+ return creator_without_type::create<AggregateFunctionTemplate<
\
+ NameData<Data<TYPE, BaseData<TYPE, is_stddev>>>,
is_nullable>>( \
+ custom_nullable ? remove_nullable(argument_types) :
argument_types, \
+ result_is_nullable);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
#define DISPATCH(TYPE)
\
if (which.idx == TypeIndex::TYPE)
\
- res = creator_without_type::create<AggregateFunctionTemplate<
\
+ return creator_without_type::create<AggregateFunctionTemplate<
\
NameData<Data<TYPE, BaseDatadecimal<TYPE, is_stddev>>>,
is_nullable>>( \
- result_is_nullable,
\
- custom_nullable ? remove_nullable(argument_types) :
argument_types);
+ custom_nullable ? remove_nullable(argument_types) :
argument_types, \
+ result_is_nullable);
FOR_DECIMAL_TYPES(DISPATCH)
#undef DISPATCH
- if (res == nullptr) {
- LOG(WARNING) << fmt::format("create_function_single_value with
unknowed type {}",
- argument_types[0]->get_name());
- }
- return res;
+ LOG(WARNING) << fmt::format("create_function_single_value with unknowed
type {}",
+ argument_types[0]->get_name());
+ return nullptr;
}
template <bool is_stddev, bool is_nullable>
AggregateFunctionPtr create_aggregate_function_variance_samp(const
std::string& name,
const DataTypes&
argument_types,
const bool
result_is_nullable) {
- return AggregateFunctionPtr(
- create_function_single_value<AggregateFunctionSamp,
VarianceSampName, SampData,
- is_stddev, is_nullable>(name,
argument_types,
-
result_is_nullable, true));
+ return create_function_single_value<AggregateFunctionSamp,
VarianceSampName, SampData,
+ is_stddev, is_nullable>(name,
argument_types,
+
result_is_nullable, true);
}
template <bool is_stddev, bool is_nullable>
AggregateFunctionPtr create_aggregate_function_stddev_samp(const std::string&
name,
const DataTypes&
argument_types,
const bool
result_is_nullable) {
- return
AggregateFunctionPtr(create_function_single_value<AggregateFunctionSamp,
StddevSampName,
- SampData,
is_stddev, is_nullable>(
- name, argument_types, result_is_nullable, true));
+ return create_function_single_value<AggregateFunctionSamp, StddevSampName,
SampData, is_stddev,
+ is_nullable>(name, argument_types,
result_is_nullable,
+ true);
}
template <bool is_stddev>
AggregateFunctionPtr create_aggregate_function_variance_pop(const std::string&
name,
const DataTypes&
argument_types,
const bool
result_is_nullable) {
- return AggregateFunctionPtr(
- create_function_single_value<AggregateFunctionPop, VarianceName,
PopData, is_stddev>(
- name, argument_types, result_is_nullable, false));
+ return create_function_single_value<AggregateFunctionPop, VarianceName,
PopData, is_stddev>(
+ name, argument_types, result_is_nullable, false);
}
template <bool is_stddev>
AggregateFunctionPtr create_aggregate_function_stddev_pop(const std::string&
name,
const DataTypes&
argument_types,
const bool
result_is_nullable) {
- return AggregateFunctionPtr(
- create_function_single_value<AggregateFunctionPop, StddevName,
PopData, is_stddev>(
- name, argument_types, result_is_nullable, false));
+ return create_function_single_value<AggregateFunctionPop, StddevName,
PopData, is_stddev>(
+ name, argument_types, result_is_nullable, false);
}
void
register_aggregate_function_stddev_variance_pop(AggregateFunctionSimpleFactory&
factory) {
diff --git a/be/src/vec/aggregate_functions/aggregate_function_sum.cpp
b/be/src/vec/aggregate_functions/aggregate_function_sum.cpp
index 0f7b47193a..ede2425198 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_sum.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_sum.cpp
@@ -20,61 +20,13 @@
#include "vec/aggregate_functions/aggregate_function_sum.h"
-#include <fmt/format.h>
-
-#include "common/logging.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/helpers.h"
-#include "vec/data_types/data_type_nullable.h"
namespace doris::vectorized {
-template <typename T>
-struct SumSimple {
- /// @note It uses slow Decimal128 (cause we need such a variant).
sumWithOverflow is faster for Decimal32/64
- using ResultType = DisposeDecimal<T, NearestFieldType<T>>;
- // using ResultType = NearestFieldType<T>;
- using AggregateDataType = AggregateFunctionSumData<ResultType>;
- using Function = AggregateFunctionSum<T, ResultType, AggregateDataType>;
-};
-
-template <typename T>
-using AggregateFunctionSumSimple = typename SumSimple<T>::Function;
-
-template <template <typename> class Function>
-AggregateFunctionPtr create_aggregate_function_sum(const std::string& name,
- const DataTypes&
argument_types,
- const bool
result_is_nullable) {
- AggregateFunctionPtr res(
- creator_with_type::create<Function>(result_is_nullable,
argument_types));
- if (!res) {
- LOG(WARNING) << fmt::format("Illegal type {} of argument for aggregate
function {}",
- argument_types[0]->get_name(), name);
- }
- return res;
-}
-
-// do not level up return type for agg reader
-template <typename T>
-struct SumSimpleReader {
- using ResultType = T;
- using AggregateDataType = AggregateFunctionSumData<ResultType>;
- using Function = AggregateFunctionSum<T, ResultType, AggregateDataType>;
-};
-
-template <typename T>
-using AggregateFunctionSumSimpleReader = typename SumSimpleReader<T>::Function;
-
-AggregateFunctionPtr create_aggregate_function_sum_reader(const std::string&
name,
- const DataTypes&
argument_types,
- const bool
result_is_nullable) {
- return
create_aggregate_function_sum<AggregateFunctionSumSimpleReader>(name,
argument_types,
-
result_is_nullable);
-}
-
void register_aggregate_function_sum(AggregateFunctionSimpleFactory& factory) {
- factory.register_function_both("sum",
-
create_aggregate_function_sum<AggregateFunctionSumSimple>);
+ factory.register_function_both("sum",
creator_with_type::creator<AggregateFunctionSumSimple>);
}
} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_sum.h
b/be/src/vec/aggregate_functions/aggregate_function_sum.h
index 489d4c72fe..3710f4d915 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_sum.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_sum.h
@@ -153,8 +153,19 @@ private:
UInt32 scale;
};
-AggregateFunctionPtr create_aggregate_function_sum_reader(const std::string&
name,
- const DataTypes&
argument_types,
- const bool
result_is_nullable);
+template <typename T, bool level_up>
+struct SumSimple {
+ /// @note It uses slow Decimal128 (cause we need such a variant).
sumWithOverflow is faster for Decimal32/64
+ using ResultType = std::conditional_t<level_up, DisposeDecimal<T,
NearestFieldType<T>>, T>;
+ using AggregateDataType = AggregateFunctionSumData<ResultType>;
+ using Function = AggregateFunctionSum<T, ResultType, AggregateDataType>;
+};
+
+template <typename T>
+using AggregateFunctionSumSimple = typename SumSimple<T, true>::Function;
+
+// do not level up return type for agg reader
+template <typename T>
+using AggregateFunctionSumSimpleReader = typename SumSimple<T,
false>::Function;
} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_topn.cpp
b/be/src/vec/aggregate_functions/aggregate_function_topn.cpp
index c57ec934e5..8fa8200cb6 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_topn.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_topn.cpp
@@ -25,17 +25,12 @@ AggregateFunctionPtr create_aggregate_function_topn(const
std::string& name,
const DataTypes&
argument_types,
const bool
result_is_nullable) {
if (argument_types.size() == 2) {
- return AggregateFunctionPtr(
-
creator_without_type::create<AggregateFunctionTopN<AggregateFunctionTopNImplInt>>(
- result_is_nullable, argument_types));
+ return
creator_without_type::create<AggregateFunctionTopN<AggregateFunctionTopNImplInt>>(
+ argument_types, result_is_nullable);
} else if (argument_types.size() == 3) {
- return AggregateFunctionPtr(creator_without_type::create<
-
AggregateFunctionTopN<AggregateFunctionTopNImplIntInt>>(
- result_is_nullable, argument_types));
+ return
creator_without_type::create<AggregateFunctionTopN<AggregateFunctionTopNImplIntInt>>(
+ argument_types, result_is_nullable);
}
-
- LOG(WARNING) << fmt::format("Illegal number {} of argument for aggregate
function {}",
- argument_types.size(), name);
return nullptr;
}
@@ -45,39 +40,34 @@ AggregateFunctionPtr create_topn_array(const DataTypes&
argument_types,
const bool result_is_nullable) {
WhichDataType which(remove_nullable(argument_types[0]));
-#define DISPATCH(TYPE)
\
- if (which.idx == TypeIndex::TYPE)
\
- return AggregateFunctionPtr(
\
- creator_without_type::create<AggregateFunctionTopNArray<
\
- AggregateFunctionTemplate<TYPE, has_default_param>,
TYPE, is_weighted>>( \
- result_is_nullable, argument_types));
+#define DISPATCH(TYPE)
\
+ if (which.idx == TypeIndex::TYPE)
\
+ return creator_without_type::create<AggregateFunctionTopNArray<
\
+ AggregateFunctionTemplate<TYPE, has_default_param>, TYPE,
is_weighted>>( \
+ argument_types, result_is_nullable);
FOR_NUMERIC_TYPES(DISPATCH)
FOR_DECIMAL_TYPES(DISPATCH)
#undef DISPATCH
if (which.is_string_or_fixed_string()) {
- return AggregateFunctionPtr(
- creator_without_type::create<AggregateFunctionTopNArray<
- AggregateFunctionTemplate<std::string,
has_default_param>, std::string,
- is_weighted>>(result_is_nullable, argument_types));
+ return creator_without_type::create<AggregateFunctionTopNArray<
+ AggregateFunctionTemplate<std::string, has_default_param>,
std::string,
+ is_weighted>>(argument_types, result_is_nullable);
}
if (which.is_date_or_datetime()) {
- return AggregateFunctionPtr(
- creator_without_type::create<AggregateFunctionTopNArray<
- AggregateFunctionTemplate<Int64, has_default_param>,
Int64, is_weighted>>(
- result_is_nullable, argument_types));
+ return creator_without_type::create<AggregateFunctionTopNArray<
+ AggregateFunctionTemplate<Int64, has_default_param>, Int64,
is_weighted>>(
+ argument_types, result_is_nullable);
}
if (which.is_date_v2()) {
- return AggregateFunctionPtr(
- creator_without_type::create<AggregateFunctionTopNArray<
- AggregateFunctionTemplate<UInt32, has_default_param>,
UInt32, is_weighted>>(
- result_is_nullable, argument_types));
+ return creator_without_type::create<AggregateFunctionTopNArray<
+ AggregateFunctionTemplate<UInt32, has_default_param>, UInt32,
is_weighted>>(
+ argument_types, result_is_nullable);
}
if (which.is_date_time_v2()) {
- return AggregateFunctionPtr(
- creator_without_type::create<AggregateFunctionTopNArray<
- AggregateFunctionTemplate<UInt64, has_default_param>,
UInt64, is_weighted>>(
- result_is_nullable, argument_types));
+ return creator_without_type::create<AggregateFunctionTopNArray<
+ AggregateFunctionTemplate<UInt64, has_default_param>, UInt64,
is_weighted>>(
+ argument_types, result_is_nullable);
}
LOG(WARNING) << fmt::format("Illegal argument type for aggregate function
topn_array is: {}",
diff --git a/be/src/vec/aggregate_functions/aggregate_function_uniq.cpp
b/be/src/vec/aggregate_functions/aggregate_function_uniq.cpp
index 18bd119a21..1016d8593b 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_uniq.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_uniq.cpp
@@ -32,35 +32,26 @@ template <template <typename> class Data>
AggregateFunctionPtr create_aggregate_function_uniq(const std::string& name,
const DataTypes&
argument_types,
const bool
result_is_nullable) {
- if (argument_types.empty()) {
- LOG(WARNING) << "Incorrect number of arguments for aggregate function
" << name;
- return nullptr;
- }
-
if (argument_types.size() == 1) {
const IDataType& argument_type = *remove_nullable(argument_types[0]);
WhichDataType which(argument_type);
AggregateFunctionPtr
res(creator_with_numeric_type::create<AggregateFunctionUniq, Data>(
- result_is_nullable, argument_types));
+ argument_types, result_is_nullable));
if (res) {
return res;
} else if (which.is_decimal32()) {
- return AggregateFunctionPtr(
-
creator_without_type::create<AggregateFunctionUniq<Decimal32, Data<Int32>>>(
- result_is_nullable, argument_types));
+ return
creator_without_type::create<AggregateFunctionUniq<Decimal32, Data<Int32>>>(
+ argument_types, result_is_nullable);
} else if (which.is_decimal64()) {
- return AggregateFunctionPtr(
-
creator_without_type::create<AggregateFunctionUniq<Decimal64, Data<Int64>>>(
- result_is_nullable, argument_types));
+ return
creator_without_type::create<AggregateFunctionUniq<Decimal64, Data<Int64>>>(
+ argument_types, result_is_nullable);
} else if (which.is_decimal128() || which.is_decimal128i()) {
- return AggregateFunctionPtr(
-
creator_without_type::create<AggregateFunctionUniq<Decimal128, Data<Int128>>>(
- result_is_nullable, argument_types));
+ return
creator_without_type::create<AggregateFunctionUniq<Decimal128, Data<Int128>>>(
+ argument_types, result_is_nullable);
} else if (which.is_string_or_fixed_string()) {
- return AggregateFunctionPtr(
- creator_without_type::create<AggregateFunctionUniq<String,
Data<String>>>(
- result_is_nullable, argument_types));
+ return creator_without_type::create<AggregateFunctionUniq<String,
Data<String>>>(
+ argument_types, result_is_nullable);
}
}
diff --git a/be/src/vec/aggregate_functions/aggregate_function_window.cpp
b/be/src/vec/aggregate_functions/aggregate_function_window.cpp
index 4ef03fb316..a9bd6547f9 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_window.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_window.cpp
@@ -27,35 +27,6 @@
namespace doris::vectorized {
-AggregateFunctionPtr create_aggregate_function_dense_rank(const std::string&
name,
- const DataTypes&
argument_types,
- const bool
result_is_nullable) {
- return
AggregateFunctionPtr(creator_without_type::create<WindowFunctionDenseRank>(
- result_is_nullable, argument_types));
-}
-
-AggregateFunctionPtr create_aggregate_function_rank(const std::string& name,
- const DataTypes&
argument_types,
- const bool
result_is_nullable) {
- return AggregateFunctionPtr(
-
creator_without_type::create<WindowFunctionRank>(result_is_nullable,
argument_types));
-}
-
-AggregateFunctionPtr create_aggregate_function_row_number(const std::string&
name,
- const DataTypes&
argument_types,
- const bool
result_is_nullable) {
- return
AggregateFunctionPtr(creator_without_type::create<WindowFunctionRowNumber>(
- result_is_nullable, argument_types));
-}
-
-AggregateFunctionPtr create_aggregate_function_ntile(const std::string& name,
- const DataTypes&
argument_types,
- const bool
result_is_nullable) {
- assert_unary(name, argument_types);
- return AggregateFunctionPtr(
-
creator_without_type::create<WindowFunctionNTile>(result_is_nullable,
argument_types));
-}
-
template <template <typename> class AggregateFunctionTemplate,
template <typename ColVecType, bool, bool> class Data, template
<typename> class Impl,
bool result_is_nullable, bool arg_is_nullable>
@@ -110,10 +81,10 @@
CREATE_WINDOW_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_window_last,
WindowFunctionLastImpl);
void register_aggregate_function_window_rank(AggregateFunctionSimpleFactory&
factory) {
- factory.register_function("dense_rank",
create_aggregate_function_dense_rank);
- factory.register_function("rank", create_aggregate_function_rank);
- factory.register_function("row_number",
create_aggregate_function_row_number);
- factory.register_function("ntile", create_aggregate_function_ntile);
+ factory.register_function("dense_rank",
creator_without_type::creator<WindowFunctionDenseRank>);
+ factory.register_function("rank",
creator_without_type::creator<WindowFunctionRank>);
+ factory.register_function("row_number",
creator_without_type::creator<WindowFunctionRowNumber>);
+ factory.register_function("ntile",
creator_without_type::creator<WindowFunctionNTile>);
}
void register_aggregate_function_window_lead_lag_first_last(
diff --git
a/be/src/vec/aggregate_functions/aggregate_function_window_funnel.cpp
b/be/src/vec/aggregate_functions/aggregate_function_window_funnel.cpp
index 7617d30656..66106f54b4 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_window_funnel.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_window_funnel.cpp
@@ -31,14 +31,12 @@ AggregateFunctionPtr
create_aggregate_function_window_funnel(const std::string&
return nullptr;
}
if (WhichDataType(remove_nullable(argument_types[2])).is_date_time_v2()) {
- return AggregateFunctionPtr(
- creator_without_type::create<
-
AggregateFunctionWindowFunnel<DateV2Value<DateTimeV2ValueType>, UInt64>>(
- result_is_nullable, argument_types));
+ return creator_without_type::create<
+
AggregateFunctionWindowFunnel<DateV2Value<DateTimeV2ValueType>, UInt64>>(
+ argument_types, result_is_nullable);
} else if
(WhichDataType(remove_nullable(argument_types[2])).is_date_time()) {
- return AggregateFunctionPtr(creator_without_type::create<
-
AggregateFunctionWindowFunnel<VecDateTimeValue, Int64>>(
- result_is_nullable, argument_types));
+ return
creator_without_type::create<AggregateFunctionWindowFunnel<VecDateTimeValue,
Int64>>(
+ argument_types, result_is_nullable);
} else {
LOG(WARNING) << "Only support DateTime type as window argument!";
return nullptr;
diff --git a/be/src/vec/aggregate_functions/helpers.h
b/be/src/vec/aggregate_functions/helpers.h
index 7a811871ed..d4323ef07e 100644
--- a/be/src/vec/aggregate_functions/helpers.h
+++ b/be/src/vec/aggregate_functions/helpers.h
@@ -54,9 +54,15 @@ struct creator_without_type {
using NullableT = std::conditional_t<multi_arguments,
AggregateFunctionNullVariadicInline<T, f>,
AggregateFunctionNullUnaryInline<T,
f>>;
+ template <typename AggregateFunctionTemplate>
+ static AggregateFunctionPtr creator(const std::string& name, const
DataTypes& argument_types,
+ const bool result_is_nullable) {
+ return create<AggregateFunctionTemplate>(argument_types,
result_is_nullable);
+ }
+
template <typename AggregateFunctionTemplate, typename... TArgs>
- static IAggregateFunction* create(const bool result_is_nullable,
- const DataTypes& argument_types,
TArgs&&... args) {
+ static AggregateFunctionPtr create(const DataTypes& argument_types,
+ const bool result_is_nullable,
TArgs&&... args) {
IAggregateFunction* result(new
AggregateFunctionTemplate(std::forward<TArgs>(args)...,
remove_nullable(argument_types)));
if (have_nullable(argument_types)) {
@@ -68,7 +74,7 @@ struct creator_without_type {
make_bool_variant(argument_types.size() > 1),
make_bool_variant(result_is_nullable));
}
- return result;
+ return AggregateFunctionPtr(result);
}
};
@@ -98,13 +104,13 @@ struct CurryDirectAndData {
template <bool allow_integer, bool allow_float, bool allow_decimal, int
define_index = 0>
struct creator_with_type_base {
template <typename Class, typename... TArgs>
- static IAggregateFunction* create_base(const bool result_is_nullable,
- const DataTypes& argument_types,
TArgs&&... args) {
+ static AggregateFunctionPtr create_base(const DataTypes& argument_types,
+ const bool result_is_nullable,
TArgs&&... args) {
WhichDataType which(remove_nullable(argument_types[define_index]));
#define DISPATCH(TYPE)
\
if (which.idx == TypeIndex::TYPE) {
\
return creator_without_type::create<typename Class::template T<TYPE>>(
\
- result_is_nullable, argument_types,
std::forward<TArgs>(args)...); \
+ argument_types, result_is_nullable,
std::forward<TArgs>(args)...); \
}
if constexpr (allow_integer) {
@@ -120,28 +126,58 @@ struct creator_with_type_base {
return nullptr;
}
+ template <template <typename> class AggregateFunctionTemplate>
+ static AggregateFunctionPtr creator(const std::string& name, const
DataTypes& argument_types,
+ const bool result_is_nullable) {
+ return
create_base<CurryDirect<AggregateFunctionTemplate>>(argument_types,
+
result_is_nullable);
+ }
+
template <template <typename> class AggregateFunctionTemplate, typename...
TArgs>
- static IAggregateFunction* create(TArgs&&... args) {
+ static AggregateFunctionPtr create(TArgs&&... args) {
return
create_base<CurryDirect<AggregateFunctionTemplate>>(std::forward<TArgs>(args)...);
}
+ template <template <typename> class AggregateFunctionTemplate, template
<typename> class Data>
+ static AggregateFunctionPtr creator(const std::string& name, const
DataTypes& argument_types,
+ const bool result_is_nullable) {
+ return create_base<CurryData<AggregateFunctionTemplate,
Data>>(argument_types,
+
result_is_nullable);
+ }
+
template <template <typename> class AggregateFunctionTemplate, template
<typename> class Data,
typename... TArgs>
- static IAggregateFunction* create(TArgs&&... args) {
+ static AggregateFunctionPtr create(TArgs&&... args) {
return create_base<CurryData<AggregateFunctionTemplate, Data>>(
std::forward<TArgs>(args)...);
}
+ template <template <typename> class AggregateFunctionTemplate, template
<typename> class Data,
+ template <typename> class Impl>
+ static AggregateFunctionPtr creator(const std::string& name, const
DataTypes& argument_types,
+ const bool result_is_nullable) {
+ return create_base<CurryDataImpl<AggregateFunctionTemplate, Data,
Impl>>(
+ argument_types, result_is_nullable);
+ }
+
template <template <typename> class AggregateFunctionTemplate, template
<typename> class Data,
template <typename> class Impl, typename... TArgs>
- static IAggregateFunction* create(TArgs&&... args) {
+ static AggregateFunctionPtr create(TArgs&&... args) {
return create_base<CurryDataImpl<AggregateFunctionTemplate, Data,
Impl>>(
std::forward<TArgs>(args)...);
}
+ template <template <typename, typename> class AggregateFunctionTemplate,
+ template <typename> class Data>
+ static AggregateFunctionPtr creator(const std::string& name, const
DataTypes& argument_types,
+ const bool result_is_nullable) {
+ return create_base<CurryDirectAndData<AggregateFunctionTemplate,
Data>>(argument_types,
+
result_is_nullable);
+ }
+
template <template <typename, typename> class AggregateFunctionTemplate,
template <typename> class Data, typename... TArgs>
- static IAggregateFunction* create(TArgs&&... args) {
+ static AggregateFunctionPtr create(TArgs&&... args) {
return create_base<CurryDirectAndData<AggregateFunctionTemplate,
Data>>(
std::forward<TArgs>(args)...);
}
diff --git a/be/src/vec/functions/array/function_array_aggregation.cpp
b/be/src/vec/functions/array/function_array_aggregation.cpp
index 588c15edbe..c0d45acea1 100644
--- a/be/src/vec/functions/array/function_array_aggregation.cpp
+++ b/be/src/vec/functions/array/function_array_aggregation.cpp
@@ -117,8 +117,7 @@ struct AggregateFunction {
using Function = typename Derived::template TypeTraits<T>::Function;
static auto create(const DataTypePtr& data_type_ptr) ->
AggregateFunctionPtr {
- return AggregateFunctionPtr(creator_with_type::create<Function>(
- true, DataTypes {make_nullable(data_type_ptr)}));
+ return creator_with_type::create<Function>(DataTypes
{make_nullable(data_type_ptr)}, true);
}
};
@@ -225,8 +224,8 @@ struct NameArrayMin {
template <>
struct AggregateFunction<AggregateFunctionImpl<AggregateOperation::MIN>> {
static auto create(const DataTypePtr& data_type_ptr) ->
AggregateFunctionPtr {
- return AggregateFunctionPtr(create_aggregate_function_min(
- NameArrayMin::name, {make_nullable(data_type_ptr)}, true));
+ return
create_aggregate_function_single_value<AggregateFunctionMinData>(
+ NameArrayMin::name, {make_nullable(data_type_ptr)}, true);
}
};
@@ -237,8 +236,8 @@ struct NameArrayMax {
template <>
struct AggregateFunction<AggregateFunctionImpl<AggregateOperation::MAX>> {
static auto create(const DataTypePtr& data_type_ptr) ->
AggregateFunctionPtr {
- return AggregateFunctionPtr(create_aggregate_function_max(
- NameArrayMax::name, {make_nullable(data_type_ptr)}, true));
+ return
create_aggregate_function_single_value<AggregateFunctionMaxData>(
+ NameArrayMax::name, {make_nullable(data_type_ptr)}, true);
}
};
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]