This is an automated email from the ASF dual-hosted git repository.
panxiaolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 12bccada097 [Chore](function) reduce some template instantiation of create_topn_array (#52277)
12bccada097 is described below
commit 12bccada09791e7a7af873cdf0c1f24d7d0c721e
Author: Pxl <[email protected]>
AuthorDate: Mon Jun 30 10:33:04 2025 +0800
[Chore](function) reduce some template instantiation of create_topn_array (#52277)
reduce some template instantiation of create_topn_array
---
.../aggregate_function_approx_count_distinct.cpp | 79 ++------------
.../aggregate_function_approx_count_distinct.h | 2 -
.../aggregate_function_approx_top_sum.cpp | 4 +-
.../aggregate_function_histogram.cpp | 84 ++++-----------
.../aggregate_function_histogram.h | 33 ++----
.../aggregate_function_orthogonal_bitmap.cpp | 5 +-
.../aggregate_function_topn.cpp | 116 +++------------------
.../aggregate_functions/aggregate_function_topn.h | 53 ++++------
be/src/vec/aggregate_functions/helpers.h | 83 +++++++++------
.../trees/expressions/functions/agg/Histogram.java | 5 +-
.../nereids_function_p0/agg_function/agg.groovy | 2 -
11 files changed, 134 insertions(+), 332 deletions(-)
diff --git a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp
index e46fae2c587..ace7abb342b 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp
@@ -17,6 +17,7 @@
#include "vec/aggregate_functions/aggregate_function_approx_count_distinct.h"
+#include "common/status.h"
#include "vec/aggregate_functions/helpers.h"
#include "vec/columns/column_array.h"
#include "vec/columns/column_decimal.h"
@@ -25,8 +26,6 @@
#include "vec/columns/column_struct.h"
#include "vec/columns/column_variant.h"
#include "vec/data_types/data_type.h"
-#include "vec/data_types/data_type_nullable.h"
-#include "vec/functions/function.h"
namespace doris::vectorized {
#include "common/compile_check_begin.h"
@@ -35,63 +34,6 @@ AggregateFunctionPtr
create_aggregate_function_approx_count_distinct(
const std::string& name, const DataTypes& argument_types, const bool
result_is_nullable,
const AggregateFunctionAttr& attr) {
switch (argument_types[0]->get_primitive_type()) {
- case PrimitiveType::TYPE_BOOLEAN:
- return
creator_without_type::create<AggregateFunctionApproxCountDistinct<TYPE_BOOLEAN>>(
- argument_types, result_is_nullable);
- case PrimitiveType::TYPE_TINYINT:
- return
creator_without_type::create<AggregateFunctionApproxCountDistinct<TYPE_TINYINT>>(
- argument_types, result_is_nullable);
- case PrimitiveType::TYPE_SMALLINT:
- return
creator_without_type::create<AggregateFunctionApproxCountDistinct<TYPE_SMALLINT>>(
- argument_types, result_is_nullable);
- case PrimitiveType::TYPE_INT:
- return
creator_without_type::create<AggregateFunctionApproxCountDistinct<TYPE_INT>>(
- argument_types, result_is_nullable);
- case PrimitiveType::TYPE_BIGINT:
- return
creator_without_type::create<AggregateFunctionApproxCountDistinct<TYPE_BIGINT>>(
- argument_types, result_is_nullable);
- case PrimitiveType::TYPE_LARGEINT:
- return
creator_without_type::create<AggregateFunctionApproxCountDistinct<TYPE_LARGEINT>>(
- argument_types, result_is_nullable);
- case PrimitiveType::TYPE_FLOAT:
- return
creator_without_type::create<AggregateFunctionApproxCountDistinct<TYPE_FLOAT>>(
- argument_types, result_is_nullable);
- case PrimitiveType::TYPE_DOUBLE:
- return
creator_without_type::create<AggregateFunctionApproxCountDistinct<TYPE_DOUBLE>>(
- argument_types, result_is_nullable);
- case PrimitiveType::TYPE_DECIMAL32:
- return
creator_without_type::create<AggregateFunctionApproxCountDistinct<TYPE_DECIMAL32>>(
- argument_types, result_is_nullable);
- case PrimitiveType::TYPE_DECIMAL64:
- return
creator_without_type::create<AggregateFunctionApproxCountDistinct<TYPE_DECIMAL64>>(
- argument_types, result_is_nullable);
- case PrimitiveType::TYPE_DECIMAL128I:
- return
creator_without_type::create<AggregateFunctionApproxCountDistinct<TYPE_DECIMAL128I>>(
- argument_types, result_is_nullable);
- case PrimitiveType::TYPE_DECIMALV2:
- return
creator_without_type::create<AggregateFunctionApproxCountDistinct<TYPE_DECIMALV2>>(
- argument_types, result_is_nullable);
- case PrimitiveType::TYPE_DECIMAL256:
- return
creator_without_type::create<AggregateFunctionApproxCountDistinct<TYPE_DECIMAL256>>(
- argument_types, result_is_nullable);
- case PrimitiveType::TYPE_STRING:
- case PrimitiveType::TYPE_CHAR:
- case PrimitiveType::TYPE_VARCHAR:
- case PrimitiveType::TYPE_JSONB:
- return
creator_without_type::create<AggregateFunctionApproxCountDistinct<TYPE_STRING>>(
- argument_types, result_is_nullable);
- case PrimitiveType::TYPE_DATE:
- return
creator_without_type::create<AggregateFunctionApproxCountDistinct<TYPE_DATE>>(
- argument_types, result_is_nullable);
- case PrimitiveType::TYPE_DATETIME:
- return
creator_without_type::create<AggregateFunctionApproxCountDistinct<TYPE_DATETIME>>(
- argument_types, result_is_nullable);
- case PrimitiveType::TYPE_DATEV2:
- return
creator_without_type::create<AggregateFunctionApproxCountDistinct<TYPE_DATEV2>>(
- argument_types, result_is_nullable);
- case PrimitiveType::TYPE_DATETIMEV2:
- return
creator_without_type::create<AggregateFunctionApproxCountDistinct<TYPE_DATETIMEV2>>(
- argument_types, result_is_nullable);
case PrimitiveType::TYPE_IPV4:
return
creator_without_type::create<AggregateFunctionApproxCountDistinct<TYPE_IPV4>>(
argument_types, result_is_nullable);
@@ -110,18 +52,15 @@ AggregateFunctionPtr
create_aggregate_function_approx_count_distinct(
case PrimitiveType::TYPE_VARIANT:
return
creator_without_type::create<AggregateFunctionApproxCountDistinct<TYPE_VARIANT>>(
argument_types, result_is_nullable);
- case PrimitiveType::TYPE_BITMAP:
- return
creator_without_type::create<AggregateFunctionApproxCountDistinct<TYPE_BITMAP>>(
- argument_types, result_is_nullable);
- case PrimitiveType::TYPE_HLL:
- return
creator_without_type::create<AggregateFunctionApproxCountDistinct<TYPE_HLL>>(
- argument_types, result_is_nullable);
- case PrimitiveType::TYPE_QUANTILE_STATE:
- return creator_without_type::create<
-
AggregateFunctionApproxCountDistinct<TYPE_QUANTILE_STATE>>(argument_types,
-
result_is_nullable);
default:
- return nullptr;
+ auto res =
creator_with_any::create<AggregateFunctionApproxCountDistinct>(
+ argument_types, result_is_nullable);
+ if (!res) {
+ throw Exception(
+ ErrorCode::NOT_IMPLEMENTED_ERROR,
+ "Unsupported type for approx_count_distinct: " +
argument_types[0]->get_name());
+ }
+ return res;
}
}
diff --git a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.h b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.h
index 1569f486e41..be67927c0eb 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.h
@@ -20,8 +20,6 @@
#include <stddef.h>
#include <stdint.h>
-#include <algorithm>
-#include <boost/iterator/iterator_facade.hpp>
#include <memory>
#include <string>
diff --git a/be/src/vec/aggregate_functions/aggregate_function_approx_top_sum.cpp b/be/src/vec/aggregate_functions/aggregate_function_approx_top_sum.cpp
index 8c80a2c87c5..6fbd4d85faf 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_approx_top_sum.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_approx_top_sum.cpp
@@ -30,7 +30,7 @@ AggregateFunctionPtr
create_aggregate_function_multi_top_sum_impl(
const DataTypes& argument_types, const bool result_is_nullable,
const std::vector<std::string>& column_names) {
if (N == argument_types.size() - 3) {
- return creator_with_type_base<true, false, false, N>::template create<
+ return creator_with_type_base<true, false, false, false, false,
N>::template create<
AggregateFunctionApproxTopSumSimple>(argument_types,
result_is_nullable,
column_names);
} else {
@@ -43,7 +43,7 @@ template <>
AggregateFunctionPtr create_aggregate_function_multi_top_sum_impl<0>(
const DataTypes& argument_types, const bool result_is_nullable,
const std::vector<std::string>& column_names) {
- return creator_with_type_base<true, false, false, 0>::template create<
+ return creator_with_type_base<true, false, false, false, false,
0>::template create<
AggregateFunctionApproxTopSumSimple>(argument_types,
result_is_nullable, column_names);
}
diff --git a/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp b/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp
index 69eee5556d1..16298a9fc31 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_histogram.cpp
@@ -20,82 +20,40 @@
#include <fmt/format.h>
#include <glog/logging.h>
-#include <algorithm>
-
#include "vec/aggregate_functions/helpers.h"
-#include "vec/core/types.h"
#include "vec/data_types/data_type.h"
-#include "vec/data_types/data_type_nullable.h"
namespace doris::vectorized {
#include "common/compile_check_begin.h"
-template <PrimitiveType T>
-AggregateFunctionPtr create_agg_function_histogram(const DataTypes&
argument_types,
- const bool
result_is_nullable) {
- bool has_input_param = (argument_types.size() == 2);
+template <typename Data>
+using HistogramWithInputParam = AggregateFunctionHistogram<Data, true>;
- if (has_input_param) {
- return creator_without_type::create<
- AggregateFunctionHistogram<AggregateFunctionHistogramData<T>,
T, true>>(
- argument_types, result_is_nullable);
- } else {
- return creator_without_type::create<
- AggregateFunctionHistogram<AggregateFunctionHistogramData<T>,
T, false>>(
- argument_types, result_is_nullable);
- }
-}
+template <typename Data>
+using HistogramNormal = AggregateFunctionHistogram<Data, false>;
AggregateFunctionPtr create_aggregate_function_histogram(const std::string&
name,
const DataTypes&
argument_types,
const bool
result_is_nullable,
const
AggregateFunctionAttr& attr) {
- switch (argument_types[0]->get_primitive_type()) {
- case PrimitiveType::TYPE_BOOLEAN:
- return create_agg_function_histogram<TYPE_BOOLEAN>(argument_types,
result_is_nullable);
- case PrimitiveType::TYPE_TINYINT:
- return create_agg_function_histogram<TYPE_TINYINT>(argument_types,
result_is_nullable);
- case PrimitiveType::TYPE_SMALLINT:
- return create_agg_function_histogram<TYPE_SMALLINT>(argument_types,
result_is_nullable);
- case PrimitiveType::TYPE_INT:
- return create_agg_function_histogram<TYPE_INT>(argument_types,
result_is_nullable);
- case PrimitiveType::TYPE_BIGINT:
- return create_agg_function_histogram<TYPE_BIGINT>(argument_types,
result_is_nullable);
- case PrimitiveType::TYPE_LARGEINT:
- return create_agg_function_histogram<TYPE_LARGEINT>(argument_types,
result_is_nullable);
- case PrimitiveType::TYPE_FLOAT:
- return create_agg_function_histogram<TYPE_FLOAT>(argument_types,
result_is_nullable);
- case PrimitiveType::TYPE_DOUBLE:
- return create_agg_function_histogram<TYPE_DOUBLE>(argument_types,
result_is_nullable);
- case PrimitiveType::TYPE_DECIMAL32:
- return create_agg_function_histogram<TYPE_DECIMAL32>(argument_types,
result_is_nullable);
- case PrimitiveType::TYPE_DECIMAL64:
- return create_agg_function_histogram<TYPE_DECIMAL64>(argument_types,
result_is_nullable);
- case PrimitiveType::TYPE_DECIMAL128I:
- return create_agg_function_histogram<TYPE_DECIMAL128I>(argument_types,
result_is_nullable);
- case PrimitiveType::TYPE_DECIMALV2:
- return create_agg_function_histogram<TYPE_DECIMALV2>(argument_types,
result_is_nullable);
- case PrimitiveType::TYPE_DECIMAL256:
- return create_agg_function_histogram<TYPE_DECIMAL256>(argument_types,
result_is_nullable);
- case PrimitiveType::TYPE_CHAR:
- return create_agg_function_histogram<TYPE_CHAR>(argument_types,
result_is_nullable);
- case PrimitiveType::TYPE_VARCHAR:
- return create_agg_function_histogram<TYPE_VARCHAR>(argument_types,
result_is_nullable);
- case PrimitiveType::TYPE_STRING:
- return create_agg_function_histogram<TYPE_STRING>(argument_types,
result_is_nullable);
- case PrimitiveType::TYPE_DATE:
- return create_agg_function_histogram<TYPE_DATE>(argument_types,
result_is_nullable);
- case PrimitiveType::TYPE_DATETIME:
- return create_agg_function_histogram<TYPE_DATETIME>(argument_types,
result_is_nullable);
- case PrimitiveType::TYPE_DATEV2:
- return create_agg_function_histogram<TYPE_DATEV2>(argument_types,
result_is_nullable);
- case PrimitiveType::TYPE_DATETIMEV2:
- return create_agg_function_histogram<TYPE_DATETIMEV2>(argument_types,
result_is_nullable);
- default:
- LOG(WARNING) << fmt::format("unsupported input type {} for aggregate
function {}",
- argument_types[0]->get_name(), name);
- return nullptr;
+ AggregateFunctionPtr result;
+ if (argument_types.size() == 2) {
+ result = creator_with_any::create<HistogramWithInputParam,
AggregateFunctionHistogramData>(
+ argument_types, result_is_nullable);
+ } else if (argument_types.size() == 1) {
+ result = creator_with_any::create<HistogramNormal,
AggregateFunctionHistogramData>(
+ argument_types, result_is_nullable);
+ } else {
+ throw Exception(ErrorCode::INVALID_ARGUMENT,
+ "Aggregate function histogram requires 1 or 2
arguments, but got {}",
+ argument_types.size());
+ }
+ if (!result) {
+ throw Exception(ErrorCode::NOT_IMPLEMENTED_ERROR,
+ "Aggregate function histogram does not support type
{}",
+ argument_types[0]->get_primitive_type());
}
+ return result;
}
void register_aggregate_function_histogram(AggregateFunctionSimpleFactory&
factory) {
diff --git a/be/src/vec/aggregate_functions/aggregate_function_histogram.h b/be/src/vec/aggregate_functions/aggregate_function_histogram.h
index 4d0566e57a4..2ca9856e7d9 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_histogram.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_histogram.h
@@ -17,14 +17,9 @@
#pragma once
-#include <rapidjson/stringbuffer.h>
-#include <stddef.h>
-
-#include <iterator>
#include <map>
#include <memory>
#include <string>
-#include <type_traits>
#include <utility>
#include <vector>
@@ -44,19 +39,13 @@
namespace doris {
#include "common/compile_check_begin.h"
-namespace vectorized {
-class Arena;
-class BufferReadable;
-class BufferWritable;
-template <PrimitiveType T>
-class ColumnVector;
-} // namespace vectorized
} // namespace doris
namespace doris::vectorized {
template <PrimitiveType T>
struct AggregateFunctionHistogramData {
+ static constexpr auto Ptype = T;
using ColVecType = typename PrimitiveTypeTraits<T>::ColumnType;
const static size_t DEFAULT_BUCKET_NUM = 128;
const static size_t BUCKET_NUM_INIT_VALUE = 0;
@@ -106,7 +95,7 @@ struct AggregateFunctionHistogramData {
void write(BufferWritable& buf) const {
write_binary(max_num_buckets, buf);
- size_t element_number = (size_t)ordered_map.size();
+ auto element_number = (size_t)ordered_map.size();
write_binary(element_number, buf);
auto pair_vector = map_to_vector();
@@ -157,7 +146,7 @@ struct AggregateFunctionHistogramData {
buckets, ordered_map,
max_num_buckets == BUCKET_NUM_INIT_VALUE ? DEFAULT_BUCKET_NUM
: max_num_buckets);
histogram_to_json(buffer, buckets, data_type);
- return std::string(buffer.GetString());
+ return {buffer.GetString()};
}
std::vector<std::pair<size_t, typename
PrimitiveTypeTraits<T>::ColumnItemType>> map_to_vector()
@@ -174,17 +163,14 @@ private:
std::map<typename PrimitiveTypeTraits<T>::ColumnItemType, size_t>
ordered_map;
};
-template <typename Data, PrimitiveType T, bool has_input_param>
+template <typename Data, bool has_input_param>
class AggregateFunctionHistogram final
- : public IAggregateFunctionDataHelper<
- Data, AggregateFunctionHistogram<Data, T, has_input_param>> {
+ : public IAggregateFunctionDataHelper<Data,
+ AggregateFunctionHistogram<Data,
has_input_param>> {
public:
- using ColVecType = typename PrimitiveTypeTraits<T>::ColumnType;
-
AggregateFunctionHistogram() = default;
AggregateFunctionHistogram(const DataTypes& argument_types_)
- : IAggregateFunctionDataHelper<Data,
- AggregateFunctionHistogram<Data, T,
has_input_param>>(
+ : IAggregateFunctionDataHelper<Data,
AggregateFunctionHistogram<Data, has_input_param>>(
argument_types_),
_argument_type(argument_types_[0]) {}
@@ -207,13 +193,14 @@ public:
this->data(place).set_parameters(Data::DEFAULT_BUCKET_NUM);
}
- if constexpr (is_string_type(T)) {
+ if constexpr (is_string_type(Data::Ptype)) {
this->data(place).add(
assert_cast<const ColumnString&,
TypeCheckOnRelease::DISABLE>(*columns[0])
.get_data_at(row_num));
} else {
this->data(place).add(
- assert_cast<const ColVecType&,
TypeCheckOnRelease::DISABLE>(*columns[0])
+ assert_cast<const typename Data::ColVecType&,
TypeCheckOnRelease::DISABLE>(
+ *columns[0])
.get_data()[row_num]);
}
}
diff --git a/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp b/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp
index a1b3d172112..493577a9fd4 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.cpp
@@ -46,9 +46,8 @@ AggregateFunctionPtr
create_aggregate_function_orthogonal(const std::string& nam
argument_types, result_is_nullable);
} else {
AggregateFunctionPtr res(
- creator_with_type_base<true, true, false,
1>::create<AggFunctionOrthBitmapFunc,
-
Impl>(argument_types,
-
result_is_nullable));
+ creator_with_type_base<true, true, false, false, false,
1>::create<
+ AggFunctionOrthBitmapFunc, Impl>(argument_types,
result_is_nullable));
if (res) {
return res;
} else if (is_string_type(argument_types[1]->get_primitive_type())) {
diff --git a/be/src/vec/aggregate_functions/aggregate_function_topn.cpp b/be/src/vec/aggregate_functions/aggregate_function_topn.cpp
index 1bed46e0799..44567dc4a81 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_topn.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_topn.cpp
@@ -21,7 +21,6 @@
#include <glog/logging.h>
#include "vec/aggregate_functions/helpers.h"
-#include "vec/core/types.h"
#include "vec/data_types/data_type.h"
namespace doris::vectorized {
@@ -41,98 +40,10 @@ AggregateFunctionPtr create_aggregate_function_topn(const
std::string& name,
return nullptr;
}
-template <template <PrimitiveType, bool> class AggregateFunctionTemplate, bool
has_default_param,
- bool is_weighted>
-AggregateFunctionPtr create_topn_array(const DataTypes& argument_types,
- const bool result_is_nullable) {
- switch (argument_types[0]->get_primitive_type()) {
- case PrimitiveType::TYPE_BOOLEAN:
- return creator_without_type::create<AggregateFunctionTopNArray<
- AggregateFunctionTemplate<TYPE_BOOLEAN, has_default_param>,
TYPE_BOOLEAN,
- is_weighted>>(argument_types, result_is_nullable);
- case PrimitiveType::TYPE_TINYINT:
- return creator_without_type::create<AggregateFunctionTopNArray<
- AggregateFunctionTemplate<TYPE_TINYINT, has_default_param>,
TYPE_TINYINT,
- is_weighted>>(argument_types, result_is_nullable);
- case PrimitiveType::TYPE_SMALLINT:
- return creator_without_type::create<AggregateFunctionTopNArray<
- AggregateFunctionTemplate<TYPE_SMALLINT, has_default_param>,
TYPE_SMALLINT,
- is_weighted>>(argument_types, result_is_nullable);
- case PrimitiveType::TYPE_INT:
- return creator_without_type::create<AggregateFunctionTopNArray<
- AggregateFunctionTemplate<TYPE_INT, has_default_param>,
TYPE_INT, is_weighted>>(
- argument_types, result_is_nullable);
- case PrimitiveType::TYPE_BIGINT:
- return creator_without_type::create<AggregateFunctionTopNArray<
- AggregateFunctionTemplate<TYPE_BIGINT, has_default_param>,
TYPE_BIGINT,
- is_weighted>>(argument_types, result_is_nullable);
- case PrimitiveType::TYPE_LARGEINT:
- return creator_without_type::create<AggregateFunctionTopNArray<
- AggregateFunctionTemplate<TYPE_LARGEINT, has_default_param>,
TYPE_LARGEINT,
- is_weighted>>(argument_types, result_is_nullable);
- case PrimitiveType::TYPE_FLOAT:
- return creator_without_type::create<AggregateFunctionTopNArray<
- AggregateFunctionTemplate<TYPE_FLOAT, has_default_param>,
TYPE_FLOAT, is_weighted>>(
- argument_types, result_is_nullable);
- case PrimitiveType::TYPE_DOUBLE:
- return creator_without_type::create<AggregateFunctionTopNArray<
- AggregateFunctionTemplate<TYPE_DOUBLE, has_default_param>,
TYPE_DOUBLE,
- is_weighted>>(argument_types, result_is_nullable);
- case PrimitiveType::TYPE_DECIMAL32:
- return creator_without_type::create<AggregateFunctionTopNArray<
- AggregateFunctionTemplate<TYPE_DECIMAL32, has_default_param>,
TYPE_DECIMAL32,
- is_weighted>>(argument_types, result_is_nullable);
- case PrimitiveType::TYPE_DECIMAL64:
- return creator_without_type::create<AggregateFunctionTopNArray<
- AggregateFunctionTemplate<TYPE_DECIMAL64, has_default_param>,
TYPE_DECIMAL64,
- is_weighted>>(argument_types, result_is_nullable);
- case PrimitiveType::TYPE_DECIMAL128I:
- return creator_without_type::create<AggregateFunctionTopNArray<
- AggregateFunctionTemplate<TYPE_DECIMAL128I,
has_default_param>, TYPE_DECIMAL128I,
- is_weighted>>(argument_types, result_is_nullable);
- case PrimitiveType::TYPE_DECIMALV2:
- return creator_without_type::create<AggregateFunctionTopNArray<
- AggregateFunctionTemplate<TYPE_DECIMALV2, has_default_param>,
TYPE_DECIMALV2,
- is_weighted>>(argument_types, result_is_nullable);
- case PrimitiveType::TYPE_DECIMAL256:
- return creator_without_type::create<AggregateFunctionTopNArray<
- AggregateFunctionTemplate<TYPE_DECIMAL256, has_default_param>,
TYPE_DECIMAL256,
- is_weighted>>(argument_types, result_is_nullable);
- case PrimitiveType::TYPE_STRING:
- return creator_without_type::create<AggregateFunctionTopNArray<
- AggregateFunctionTemplate<TYPE_STRING, has_default_param>,
TYPE_STRING,
- is_weighted>>(argument_types, result_is_nullable);
- case PrimitiveType::TYPE_CHAR:
- return creator_without_type::create<AggregateFunctionTopNArray<
- AggregateFunctionTemplate<TYPE_CHAR, has_default_param>,
TYPE_CHAR, is_weighted>>(
- argument_types, result_is_nullable);
- case PrimitiveType::TYPE_VARCHAR:
- return creator_without_type::create<AggregateFunctionTopNArray<
- AggregateFunctionTemplate<TYPE_VARCHAR, has_default_param>,
TYPE_VARCHAR,
- is_weighted>>(argument_types, result_is_nullable);
- case PrimitiveType::TYPE_DATE:
- return creator_without_type::create<AggregateFunctionTopNArray<
- AggregateFunctionTemplate<TYPE_DATE, has_default_param>,
TYPE_DATE, is_weighted>>(
- argument_types, result_is_nullable);
- case PrimitiveType::TYPE_DATETIME:
- return creator_without_type::create<AggregateFunctionTopNArray<
- AggregateFunctionTemplate<TYPE_DATETIME, has_default_param>,
TYPE_DATETIME,
- is_weighted>>(argument_types, result_is_nullable);
- case PrimitiveType::TYPE_DATEV2:
- return creator_without_type::create<AggregateFunctionTopNArray<
- AggregateFunctionTemplate<TYPE_DATEV2, has_default_param>,
TYPE_DATEV2,
- is_weighted>>(argument_types, result_is_nullable);
- case PrimitiveType::TYPE_DATETIMEV2:
- return creator_without_type::create<AggregateFunctionTopNArray<
- AggregateFunctionTemplate<TYPE_DATETIMEV2, has_default_param>,
TYPE_DATETIMEV2,
- is_weighted>>(argument_types, result_is_nullable);
- default:
- LOG(WARNING) << fmt::format(
- "Illegal argument type for aggregate function topn_array is:
{}",
- remove_nullable(argument_types[0])->get_name());
- return nullptr;
- }
-}
+template <PrimitiveType T>
+using ImplArray = AggregateFunctionTopNImplArray<T, false>;
+template <PrimitiveType T>
+using ImplArrayWithDefault = AggregateFunctionTopNImplArray<T, true>;
AggregateFunctionPtr create_aggregate_function_topn_array(const std::string&
name,
const DataTypes&
argument_types,
@@ -140,25 +51,30 @@ AggregateFunctionPtr
create_aggregate_function_topn_array(const std::string& nam
const
AggregateFunctionAttr& attr) {
bool has_default_param = (argument_types.size() == 3);
if (has_default_param) {
- return create_topn_array<AggregateFunctionTopNImplArray, true,
false>(argument_types,
-
result_is_nullable);
+ return creator_with_any::create<AggregateFunctionTopNArray,
ImplArrayWithDefault>(
+ argument_types, result_is_nullable);
} else {
- return create_topn_array<AggregateFunctionTopNImplArray, false,
false>(argument_types,
+ return creator_with_any::create<AggregateFunctionTopNArray,
ImplArray>(argument_types,
result_is_nullable);
}
}
+template <PrimitiveType T>
+using ImplWeight = AggregateFunctionTopNImplWeight<T, false>;
+template <PrimitiveType T>
+using ImplWeightWithDefault = AggregateFunctionTopNImplWeight<T, true>;
+
AggregateFunctionPtr create_aggregate_function_topn_weighted(const
std::string& name,
const DataTypes&
argument_types,
const bool
result_is_nullable,
const
AggregateFunctionAttr& attr) {
bool has_default_param = (argument_types.size() == 4);
if (has_default_param) {
- return create_topn_array<AggregateFunctionTopNImplWeight, true,
true>(argument_types,
-
result_is_nullable);
+ return creator_with_any::create<AggregateFunctionTopNArray,
ImplWeightWithDefault>(
+ argument_types, result_is_nullable);
} else {
- return create_topn_array<AggregateFunctionTopNImplWeight, false,
true>(argument_types,
-
result_is_nullable);
+ return creator_with_any::create<AggregateFunctionTopNArray,
ImplWeight>(argument_types,
+
result_is_nullable);
}
}
diff --git a/be/src/vec/aggregate_functions/aggregate_function_topn.h b/be/src/vec/aggregate_functions/aggregate_function_topn.h
index b265053053c..edddb129f06 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_topn.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_topn.h
@@ -25,10 +25,8 @@
#include <algorithm>
#include <functional>
-#include <iterator>
#include <memory>
#include <string>
-#include <type_traits>
#include <utility>
#include <vector>
@@ -50,13 +48,6 @@
namespace doris {
#include "common/compile_check_begin.h"
-namespace vectorized {
-class Arena;
-class BufferReadable;
-class BufferWritable;
-template <PrimitiveType T>
-class ColumnDecimal;
-} // namespace vectorized
} // namespace doris
namespace doris::vectorized {
@@ -216,8 +207,8 @@ struct AggregateFunctionTopNData {
};
struct AggregateFunctionTopNImplInt {
- static void add(AggregateFunctionTopNData<TYPE_STRING>& __restrict place,
- const IColumn** columns, size_t row_num) {
+ using Data = AggregateFunctionTopNData<TYPE_STRING>;
+ static void add(Data& __restrict place, const IColumn** columns, size_t
row_num) {
place.set_paramenters(
assert_cast<const ColumnInt32*,
TypeCheckOnRelease::DISABLE>(columns[1])
->get_element(row_num));
@@ -227,8 +218,8 @@ struct AggregateFunctionTopNImplInt {
};
struct AggregateFunctionTopNImplIntInt {
- static void add(AggregateFunctionTopNData<TYPE_STRING>& __restrict place,
- const IColumn** columns, size_t row_num) {
+ using Data = AggregateFunctionTopNData<TYPE_STRING>;
+ static void add(Data& __restrict place, const IColumn** columns, size_t
row_num) {
place.set_paramenters(
assert_cast<const ColumnInt32*,
TypeCheckOnRelease::DISABLE>(columns[1])
->get_element(row_num),
@@ -241,7 +232,9 @@ struct AggregateFunctionTopNImplIntInt {
//for topn_array agg
template <PrimitiveType T, bool has_default_param>
struct AggregateFunctionTopNImplArray {
+ using Data = AggregateFunctionTopNData<T>;
using ColVecType = typename PrimitiveTypeTraits<T>::ColumnType;
+ static String get_name() { return "topn_array"; }
static void add(AggregateFunctionTopNData<T>& __restrict place, const
IColumn** columns,
size_t row_num) {
if constexpr (has_default_param) {
@@ -271,7 +264,9 @@ struct AggregateFunctionTopNImplArray {
//for topn_weighted agg
template <PrimitiveType T, bool has_default_param>
struct AggregateFunctionTopNImplWeight {
+ using Data = AggregateFunctionTopNData<T>;
using ColVecType = typename PrimitiveTypeTraits<T>::ColumnType;
+ static String get_name() { return "topn_weighted"; }
static void add(AggregateFunctionTopNData<T>& __restrict place, const
IColumn** columns,
size_t row_num) {
if constexpr (has_default_param) {
@@ -304,14 +299,14 @@ struct AggregateFunctionTopNImplWeight {
};
//base function
-template <typename Impl, PrimitiveType T>
+template <typename Impl>
class AggregateFunctionTopNBase
- : public IAggregateFunctionDataHelper<AggregateFunctionTopNData<T>,
- AggregateFunctionTopNBase<Impl,
T>> {
+ : public IAggregateFunctionDataHelper<typename Impl::Data,
+ AggregateFunctionTopNBase<Impl>>
{
public:
AggregateFunctionTopNBase(const DataTypes& argument_types_)
- : IAggregateFunctionDataHelper<AggregateFunctionTopNData<T>,
- AggregateFunctionTopNBase<Impl,
T>>(argument_types_) {}
+ : IAggregateFunctionDataHelper<typename Impl::Data,
AggregateFunctionTopNBase<Impl>>(
+ argument_types_) {}
void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
@@ -336,11 +331,11 @@ public:
};
//topn function return string
-template <typename Impl, PrimitiveType T = TYPE_STRING>
-class AggregateFunctionTopN final : public AggregateFunctionTopNBase<Impl, T> {
+template <typename Impl>
+class AggregateFunctionTopN final : public AggregateFunctionTopNBase<Impl> {
public:
AggregateFunctionTopN(const DataTypes& argument_types_)
- : AggregateFunctionTopNBase<Impl, T>(argument_types_) {}
+ : AggregateFunctionTopNBase<Impl>(argument_types_) {}
String get_name() const override { return "topn"; }
@@ -353,20 +348,14 @@ public:
};
//topn function return array
-template <typename Impl, PrimitiveType T, bool is_weighted>
-class AggregateFunctionTopNArray final : public
AggregateFunctionTopNBase<Impl, T> {
+template <typename Impl>
+class AggregateFunctionTopNArray final : public
AggregateFunctionTopNBase<Impl> {
public:
AggregateFunctionTopNArray(const DataTypes& argument_types_)
- : AggregateFunctionTopNBase<Impl, T>(argument_types_),
+ : AggregateFunctionTopNBase<Impl>(argument_types_),
_argument_type(argument_types_[0]) {}
- String get_name() const override {
- if constexpr (is_weighted) {
- return "topn_weighted";
- } else {
- return "topn_array";
- }
- }
+ String get_name() const override { return Impl::get_name(); }
DataTypePtr get_return_type() const override {
return std::make_shared<DataTypeArray>(make_nullable(_argument_type));
@@ -376,7 +365,7 @@ public:
auto& to_arr = assert_cast<ColumnArray&>(to);
auto& to_nested_col = to_arr.get_data();
if (to_nested_col.is_nullable()) {
- auto col_null = reinterpret_cast<ColumnNullable*>(&to_nested_col);
+ auto* col_null = assert_cast<ColumnNullable*>(&to_nested_col);
this->data(place).insert_result_into(col_null->get_nested_column());
col_null->get_null_map_data().resize_fill(col_null->get_nested_column().size(),
0);
} else {
diff --git a/be/src/vec/aggregate_functions/helpers.h b/be/src/vec/aggregate_functions/helpers.h
index 900d1b2806d..55a59cf6c5f 100644
--- a/be/src/vec/aggregate_functions/helpers.h
+++ b/be/src/vec/aggregate_functions/helpers.h
@@ -185,31 +185,30 @@ struct CurryDirectAndData {
using T = AggregateFunctionTemplate<Type, Data<Type>>;
};
-template <bool allow_integer, bool allow_float, bool allow_decimal, int
define_index = 0>
+template <bool allow_integer, bool allow_float, bool allow_decimal, bool
allow_stringlike,
+ bool allow_datelike, int define_index = 0>
struct creator_with_type_base {
template <typename Class, typename... TArgs>
static AggregateFunctionPtr create_base(const DataTypes& argument_types,
const bool result_is_nullable,
TArgs&&... args) {
+ auto create = [&]<PrimitiveType Ptype>() {
+ return creator_without_type::create<typename Class::template
T<Ptype>>(
+ argument_types, result_is_nullable,
std::forward<TArgs>(args)...);
+ };
if constexpr (allow_integer) {
switch (argument_types[define_index]->get_primitive_type()) {
case PrimitiveType::TYPE_BOOLEAN:
- return creator_without_type::create<typename Class::template
T<TYPE_BOOLEAN>>(
- argument_types, result_is_nullable,
std::forward<TArgs>(args)...);
+ return create.template
operator()<PrimitiveType::TYPE_BOOLEAN>();
case PrimitiveType::TYPE_TINYINT:
- return creator_without_type::create<typename Class::template
T<TYPE_TINYINT>>(
- argument_types, result_is_nullable,
std::forward<TArgs>(args)...);
+ return create.template
operator()<PrimitiveType::TYPE_TINYINT>();
case PrimitiveType::TYPE_SMALLINT:
- return creator_without_type::create<typename Class::template
T<TYPE_SMALLINT>>(
- argument_types, result_is_nullable,
std::forward<TArgs>(args)...);
+ return create.template
operator()<PrimitiveType::TYPE_SMALLINT>();
case PrimitiveType::TYPE_INT:
- return creator_without_type::create<typename Class::template
T<TYPE_INT>>(
- argument_types, result_is_nullable,
std::forward<TArgs>(args)...);
+ return create.template operator()<PrimitiveType::TYPE_INT>();
case PrimitiveType::TYPE_BIGINT:
- return creator_without_type::create<typename Class::template
T<TYPE_BIGINT>>(
- argument_types, result_is_nullable,
std::forward<TArgs>(args)...);
+ return create.template
operator()<PrimitiveType::TYPE_BIGINT>();
case PrimitiveType::TYPE_LARGEINT:
- return creator_without_type::create<typename Class::template
T<TYPE_LARGEINT>>(
- argument_types, result_is_nullable,
std::forward<TArgs>(args)...);
+ return create.template
operator()<PrimitiveType::TYPE_LARGEINT>();
default:
break;
}
@@ -217,11 +216,9 @@ struct creator_with_type_base {
if constexpr (allow_float) {
switch (argument_types[define_index]->get_primitive_type()) {
case PrimitiveType::TYPE_FLOAT:
- return creator_without_type::create<typename Class::template
T<TYPE_FLOAT>>(
- argument_types, result_is_nullable,
std::forward<TArgs>(args)...);
+ return create.template operator()<PrimitiveType::TYPE_FLOAT>();
case PrimitiveType::TYPE_DOUBLE:
- return creator_without_type::create<typename Class::template
T<TYPE_DOUBLE>>(
- argument_types, result_is_nullable,
std::forward<TArgs>(args)...);
+ return create.template
operator()<PrimitiveType::TYPE_DOUBLE>();
default:
break;
}
@@ -229,24 +226,47 @@ struct creator_with_type_base {
if constexpr (allow_decimal) {
switch (argument_types[define_index]->get_primitive_type()) {
case PrimitiveType::TYPE_DECIMAL32:
- return creator_without_type::create<typename Class::template
T<TYPE_DECIMAL32>>(
- argument_types, result_is_nullable,
std::forward<TArgs>(args)...);
+ return create.template
operator()<PrimitiveType::TYPE_DECIMAL32>();
case PrimitiveType::TYPE_DECIMAL64:
- return creator_without_type::create<typename Class::template
T<TYPE_DECIMAL64>>(
- argument_types, result_is_nullable,
std::forward<TArgs>(args)...);
+ return create.template
operator()<PrimitiveType::TYPE_DECIMAL64>();
case PrimitiveType::TYPE_DECIMALV2:
- return creator_without_type::create<typename Class::template
T<TYPE_DECIMALV2>>(
- argument_types, result_is_nullable,
std::forward<TArgs>(args)...);
+ return create.template
operator()<PrimitiveType::TYPE_DECIMALV2>();
case PrimitiveType::TYPE_DECIMAL128I:
- return creator_without_type::create<typename Class::template
T<TYPE_DECIMAL128I>>(
- argument_types, result_is_nullable,
std::forward<TArgs>(args)...);
+ return create.template
operator()<PrimitiveType::TYPE_DECIMAL128I>();
case PrimitiveType::TYPE_DECIMAL256:
- return creator_without_type::create<typename Class::template
T<TYPE_DECIMAL256>>(
- argument_types, result_is_nullable,
std::forward<TArgs>(args)...);
+ return create.template
operator()<PrimitiveType::TYPE_DECIMAL256>();
default:
break;
}
}
+
+ if constexpr (allow_stringlike) {
+ switch (argument_types[define_index]->get_primitive_type()) {
+ case PrimitiveType::TYPE_CHAR:
+ case PrimitiveType::TYPE_VARCHAR:
+ case PrimitiveType::TYPE_STRING:
+ case PrimitiveType::TYPE_JSONB:
+ return create.template
operator()<PrimitiveType::TYPE_VARCHAR>();
+ default:
+ break;
+ }
+ }
+
+ if constexpr (allow_datelike) {
+ switch (argument_types[define_index]->get_primitive_type()) {
+ case PrimitiveType::TYPE_DATE:
+ return create.template operator()<PrimitiveType::TYPE_DATE>();
+ case PrimitiveType::TYPE_DATETIME:
+ return create.template
operator()<PrimitiveType::TYPE_DATETIME>();
+ case PrimitiveType::TYPE_DATEV2:
+ return create.template
operator()<PrimitiveType::TYPE_DATEV2>();
+ case PrimitiveType::TYPE_DATETIMEV2:
+ return create.template
operator()<PrimitiveType::TYPE_DATETIMEV2>();
+ default:
+ break;
+ }
+ }
+
return nullptr;
}
@@ -312,10 +332,11 @@ struct creator_with_type_base {
}
};
-using creator_with_integer_type = creator_with_type_base<true, false, false>;
-using creator_with_numeric_type = creator_with_type_base<true, true, false>;
-using creator_with_decimal_type = creator_with_type_base<false, false, true>;
-using creator_with_type = creator_with_type_base<true, true, true>;
+using creator_with_integer_type = creator_with_type_base<true, false, false,
false, false>;
+using creator_with_numeric_type = creator_with_type_base<true, true, false,
false, false>;
+using creator_with_decimal_type = creator_with_type_base<false, false, true,
false, false>;
+using creator_with_type = creator_with_type_base<true, true, true, false,
false>;
+using creator_with_any = creator_with_type_base<true, true, true, true, true>;
} // namespace doris::vectorized
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Histogram.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Histogram.java
index 827c57facd7..bc4fa336960 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Histogram.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Histogram.java
@@ -23,7 +23,6 @@ import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSi
import org.apache.doris.nereids.trees.expressions.functions.SearchSignature;
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
-import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.VarcharType;
import org.apache.doris.nereids.types.coercion.AnyDataType;
@@ -43,9 +42,7 @@ public class Histogram extends NotNullableAggregateFunction
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
.args(AnyDataType.INSTANCE_WITHOUT_INDEX),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
- .args(AnyDataType.INSTANCE_WITHOUT_INDEX,
IntegerType.INSTANCE),
- FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
- .args(AnyDataType.INSTANCE_WITHOUT_INDEX,
DoubleType.INSTANCE, IntegerType.INSTANCE)
+ .args(AnyDataType.INSTANCE_WITHOUT_INDEX,
IntegerType.INSTANCE)
);
private Histogram(boolean distinct, List<Expression> args) {
diff --git a/regression-test/suites/nereids_function_p0/agg_function/agg.groovy b/regression-test/suites/nereids_function_p0/agg_function/agg.groovy
index e581626b574..b3dace65e72 100644
--- a/regression-test/suites/nereids_function_p0/agg_function/agg.groovy
+++ b/regression-test/suites/nereids_function_p0/agg_function/agg.groovy
@@ -1101,8 +1101,6 @@ suite("nereids_agg_fn") {
select histogram(kbool) from fn_test'''
sql '''
select histogram(kbool, 10) from fn_test'''
- sql '''
- select histogram(kbool, 10, 10) from fn_test'''
sql '''
select count(id), histogram(kbool) from fn_test group by id
order by id'''
sql '''
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]