This is an automated email from the ASF dual-hosted git repository.
lihaopeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git
The following commit(s) were added to refs/heads/master by this push:
new 143c408 [Feature][Vectorized] support aggregate function
ndv()/approx_count_distinct() (#8044)
143c408 is described below
commit 143c4085ee58007954f3eef8910556f5b8ce6b39
Author: Pxl <[email protected]>
AuthorDate: Wed Feb 16 14:30:13 2022 +0800
[Feature][Vectorized] support aggregate function
ndv()/approx_count_distinct() (#8044)
---
be/src/vec/CMakeLists.txt | 1 +
.../aggregate_function_approx_count_distinct.cpp | 50 ++++++++++
.../aggregate_function_approx_count_distinct.h | 107 +++++++++++++++++++++
.../aggregate_functions/aggregate_function_avg.cpp | 4 -
.../aggregate_function_simple_factory.cpp | 5 +
.../aggregate_function_simple_factory.h | 7 +-
be/src/vec/common/string_ref.h | 6 +-
be/src/vec/functions/function_case.h | 55 ++---------
be/src/vec/functions/function_coalesce.cpp | 102 ++++++++------------
be/src/vec/functions/function_hash.cpp | 80 ++++++---------
be/src/vec/utils/template_helpers.hpp | 69 +++++++++++++
.../java/org/apache/doris/catalog/FunctionSet.java | 54 +++++++----
12 files changed, 352 insertions(+), 188 deletions(-)
diff --git a/be/src/vec/CMakeLists.txt b/be/src/vec/CMakeLists.txt
index c526cf2..7f8878f 100644
--- a/be/src/vec/CMakeLists.txt
+++ b/be/src/vec/CMakeLists.txt
@@ -33,6 +33,7 @@ set(VEC_FILES
aggregate_functions/aggregate_function_window.cpp
aggregate_functions/aggregate_function_stddev.cpp
aggregate_functions/aggregate_function_topn.cpp
+ aggregate_functions/aggregate_function_approx_count_distinct.cpp
aggregate_functions/aggregate_function_simple_factory.cpp
columns/collator.cpp
columns/column.cpp
diff --git
a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp
b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp
new file mode 100644
index 0000000..fc68d85
--- /dev/null
+++
b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.cpp
@@ -0,0 +1,50 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "vec/aggregate_functions/aggregate_function_approx_count_distinct.h"
+
+#include "vec/utils/template_helpers.hpp"
+
+namespace doris::vectorized {
+
+AggregateFunctionPtr create_aggregate_function_approx_count_distinct( // Creator for "approx_count_distinct": builds the aggregate matching the type of argument_types[0], or returns nullptr when no column type matches.
+ const std::string& name, const DataTypes& argument_types, const Array&
parameters,
+ const bool result_is_nullable) {
+ AggregateFunctionPtr res = nullptr;
+ WhichDataType which(argument_types[0]->is_nullable() // NOTE(review): `which` (the nullable-unwrapped type) is never read below; the dispatch receives *argument_types[0] as-is - confirm nullable arguments are unwrapped before this creator is called.
+ ? reinterpret_cast<const
DataTypeNullable*>(argument_types[0].get())
+ ->get_nested_type()
+ : argument_types[0]);
+ // create_class_with_type returns a raw pointer (or nullptr); res.reset() takes ownership of it. argument_types is forwarded to the aggregate's constructor.
+
res.reset(create_class_with_type<AggregateFunctionApproxCountDistinct>(*argument_types[0],
+
argument_types));
+ // An unsupported argument type is only logged; the caller still receives the null pointer.
+ if (!res) {
+ LOG(WARNING) << fmt::format("Illegal type {} of argument for aggregate
function {}",
+ argument_types[0]->get_name(), name);
+ }
+
+ return res;
+}
+
+void
register_aggregate_function_approx_count_distinct(AggregateFunctionSimpleFactory&
factory) { // Hooks the creator into the factory under "approx_count_distinct" and registers the "ndv" alias.
+ factory.register_function("approx_count_distinct",
+ create_aggregate_function_approx_count_distinct);
+ factory.register_alias("approx_count_distinct", "ndv"); // presumably makes "ndv" resolve to the same creator - confirm the argument order against register_alias
+}
+
+} // namespace doris::vectorized
diff --git
a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.h
b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.h
new file mode 100644
index 0000000..ed393af
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.h
@@ -0,0 +1,107 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "exprs/anyval_util.h"
+#include "olap/hll.h"
+#include "udf/udf.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/common/string_ref.h"
+#include "vec/io/io_helper.h"
+
+namespace doris::vectorized {
+
+struct AggregateFunctionApproxCountDistinctData { // Per-group aggregation state: a single HyperLogLog sketch plus the operations the aggregate function needs on it.
+ HyperLogLog hll_data;
+ // Hashes the raw value bytes with hash64_murmur (seed HashUtil::MURMUR_SEED) and feeds the hash to the sketch; a hash of exactly 0 is skipped.
+ void add(StringRef value) {
+ StringVal sv = value.to_string_val();
+ uint64_t hash_value = AnyValUtil::hash64_murmur(sv,
HashUtil::MURMUR_SEED);
+ if (hash_value != 0) {
+ hll_data.update(hash_value);
+ }
+ }
+ // Folds another state's sketch into this one.
+ void merge(const AggregateFunctionApproxCountDistinctData& rhs) {
+ hll_data.merge(rhs.hll_data);
+ }
+ // Serializes the sketch into a temporary string sized for the worst case, trims it to the bytes actually written, then emits it with write_binary.
+ void write(BufferWritable& buf) const {
+ std::string result;
+ result.resize(hll_data.max_serialized_size());
+ int size = hll_data.serialize((uint8_t*)result.data());
+ result.resize(size);
+ write_binary(result, buf);
+ }
+ // Counterpart of write(): reads one binary blob and rebuilds the sketch from it. NOTE(review): if deserialize() reports failure, that result is not checked here - confirm.
+ void read(BufferReadable& buf) {
+ StringRef result;
+ read_binary(result, buf);
+ Slice data = Slice(result.data, result.size);
+ hll_data.deserialize(data);
+ }
+ // Returns the sketch's cardinality estimate as int64_t.
+ int64_t get() const { return hll_data.estimate_cardinality(); }
+ // Clears the sketch so the state can be reused.
+ void reset() { hll_data.clear(); }
+};
+
+template <typename ColumnDataType> // ColumnDataType: concrete column class that columns[0] is cast to in add().
+class AggregateFunctionApproxCountDistinct final // Aggregate function that forwards every operation to AggregateFunctionApproxCountDistinctData and reports the estimate as Int64.
+ : public IAggregateFunctionDataHelper<
+ AggregateFunctionApproxCountDistinctData,
+ AggregateFunctionApproxCountDistinct<ColumnDataType>> {
+public:
+ String get_name() const override { return "approx_count_distinct"; }
+ // Passes the argument types to the base helper; the second base argument is an empty brace list (presumably the aggregate parameters - confirm).
+ AggregateFunctionApproxCountDistinct(const DataTypes& argument_types_)
+ :
IAggregateFunctionDataHelper<AggregateFunctionApproxCountDistinctData,
+
AggregateFunctionApproxCountDistinct<ColumnDataType>>(
+ argument_types_, {}) {}
+ // The result column type is always Int64.
+ DataTypePtr get_return_type() const override { return
std::make_shared<DataTypeInt64>(); }
+ // Feeds the value at row_num of the first argument column into the state; the column is cast (unchecked static_cast) to ColumnDataType.
+ void add(AggregateDataPtr __restrict place, const IColumn** columns,
size_t row_num,
+ Arena*) const override {
+ this->data(place).add(static_cast<const
ColumnDataType*>(columns[0])->get_data_at(row_num));
+ }
+ // Clears the state stored at place.
+ void reset(AggregateDataPtr place) const override {
this->data(place).reset(); }
+ // Merges the state at rhs into the state at place.
+ void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+ Arena*) const override {
+ this->data(place).merge(this->data(rhs));
+ }
+ // Writes the state at place to buf via the state's write().
+ void serialize(ConstAggregateDataPtr __restrict place, BufferWritable&
buf) const override {
+ this->data(place).write(buf);
+ }
+ // Rebuilds the state at place from buf via the state's read().
+ void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+ Arena*) const override {
+ this->data(place).read(buf);
+ }
+ // Appends the cardinality estimate to the result column, which is cast (unchecked) to ColumnInt64 to match get_return_type().
+ void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn&
to) const override {
+ auto& column = static_cast<ColumnInt64&>(to);
+ column.get_data().push_back(this->data(place).get());
+ }
+};
+
+} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp
b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp
index bb7605b..61687af 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_avg.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_avg.cpp
@@ -27,8 +27,6 @@
namespace doris::vectorized {
-namespace {
-
template <typename T>
struct Avg {
using FieldType = std::conditional_t<IsDecimalNumber<T>, Decimal128,
NearestFieldType<T>>;
@@ -60,8 +58,6 @@ AggregateFunctionPtr create_aggregate_function_avg(const
std::string& name,
return res;
}
-} // namespace
-
void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory) {
factory.register_function("avg", create_aggregate_function_avg);
}
diff --git
a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
index 4844000..87a52b9 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
@@ -23,7 +23,9 @@
#include "vec/aggregate_functions/aggregate_function_reader.h"
namespace doris::vectorized {
+
class AggregateFunctionSimpleFactory;
+
void register_aggregate_function_sum(AggregateFunctionSimpleFactory& factory);
void
register_aggregate_function_combinator_null(AggregateFunctionSimpleFactory&
factory);
void register_aggregate_function_minmax(AggregateFunctionSimpleFactory&
factory);
@@ -37,6 +39,8 @@ void
register_aggregate_function_window_rank(AggregateFunctionSimpleFactory& fac
void
register_aggregate_function_window_lead_lag(AggregateFunctionSimpleFactory&
factory);
void
register_aggregate_function_stddev_variance(AggregateFunctionSimpleFactory&
factory);
void register_aggregate_function_topn(AggregateFunctionSimpleFactory& factory);
+void
register_aggregate_function_approx_count_distinct(AggregateFunctionSimpleFactory&
factory);
+
AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
static std::once_flag oc;
static AggregateFunctionSimpleFactory instance;
@@ -53,6 +57,7 @@ AggregateFunctionSimpleFactory&
AggregateFunctionSimpleFactory::instance() {
register_aggregate_function_window_rank(instance);
register_aggregate_function_stddev_variance(instance);
register_aggregate_function_topn(instance);
+ register_aggregate_function_approx_count_distinct(instance);
// if you only register function with no nullable, and wants to add
nullable automatically, you should place function above this line
register_aggregate_function_combinator_null(instance);
diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
index 1bac4f1..833e52d 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.h
@@ -46,6 +46,7 @@ private:
AggregateFunctions aggregate_functions;
AggregateFunctions nullable_aggregate_functions;
std::unordered_map<std::string, std::string> function_alias;
+
public:
void register_nullable_function_combinator(const Creator& creator) {
for (const auto& entity : aggregate_functions) {
@@ -86,13 +87,13 @@ public:
if (nullable) {
return nullable_aggregate_functions.find(name_str) ==
nullable_aggregate_functions.end()
? nullptr
- : nullable_aggregate_functions[name_str](name_str,
argument_types, parameters,
-
result_is_nullable);
+ : nullable_aggregate_functions[name_str](name_str,
argument_types,
+
parameters, result_is_nullable);
} else {
return aggregate_functions.find(name_str) ==
aggregate_functions.end()
? nullptr
: aggregate_functions[name_str](name_str,
argument_types, parameters,
- result_is_nullable);
+ result_is_nullable);
}
}
diff --git a/be/src/vec/common/string_ref.h b/be/src/vec/common/string_ref.h
index 727996e..5dd146e 100644
--- a/be/src/vec/common/string_ref.h
+++ b/be/src/vec/common/string_ref.h
@@ -55,9 +55,13 @@ struct StringRef {
explicit operator std::string() const { return to_string(); }
- StringVal to_string_val() const {
+ StringVal to_string_val() {
return StringVal(reinterpret_cast<uint8_t*>(const_cast<char*>(data)),
size);
}
+
+ static StringRef from_string_val(StringVal sv) {
+ return StringRef(reinterpret_cast<char*>(sv.ptr), sv.len);
+ }
};
using StringRefs = std::vector<StringRef>;
diff --git a/be/src/vec/functions/function_case.h
b/be/src/vec/functions/function_case.h
index 1113b5f..1b728e9 100644
--- a/be/src/vec/functions/function_case.h
+++ b/be/src/vec/functions/function_case.h
@@ -17,11 +17,11 @@
#pragma once
-#include "vec/data_types/data_type_decimal.h"
#include "vec/data_types/data_type_nullable.h"
#include "vec/functions/function.h"
#include "vec/functions/function_helpers.h"
#include "vec/functions/simple_function_factory.h"
+#include "vec/utils/template_helpers.hpp"
namespace doris::vectorized {
@@ -311,51 +311,14 @@ public:
? reinterpret_cast<const
DataTypeNullable*>(data_type.get())
->get_nested_type()
: data_type);
-
- // TODO: use template traits here.
- if (which.is_uint8()) {
- return execute_get_when_null<ColumnUInt8>(data_type, block,
arguments, result,
- input_rows_count);
- } else if (which.is_int16()) {
- return execute_get_when_null<ColumnInt16>(data_type, block,
arguments, result,
- input_rows_count);
- } else if (which.is_uint32()) {
- return execute_get_when_null<ColumnUInt32>(data_type, block,
arguments, result,
- input_rows_count);
- } else if (which.is_uint64()) {
- return execute_get_when_null<ColumnUInt64>(data_type, block,
arguments, result,
- input_rows_count);
- } else if (which.is_int8()) {
- return execute_get_when_null<ColumnInt8>(data_type, block,
arguments, result,
- input_rows_count);
- } else if (which.is_int16()) {
- return execute_get_when_null<ColumnInt16>(data_type, block,
arguments, result,
- input_rows_count);
- } else if (which.is_int32()) {
- return execute_get_when_null<ColumnInt32>(data_type, block,
arguments, result,
- input_rows_count);
- } else if (which.is_int64()) {
- return execute_get_when_null<ColumnInt64>(data_type, block,
arguments, result,
- input_rows_count);
- } else if (which.is_date_or_datetime()) {
- return execute_get_when_null<ColumnVector<DateTime>>(data_type,
block, arguments,
- result,
input_rows_count);
- } else if (which.is_float32()) {
- return execute_get_when_null<ColumnFloat32>(data_type, block,
arguments, result,
- input_rows_count);
- } else if (which.is_float64()) {
- return execute_get_when_null<ColumnFloat64>(data_type, block,
arguments, result,
- input_rows_count);
- } else if (which.is_decimal()) {
- return execute_get_when_null<ColumnDecimal<Decimal128>>(data_type,
block, arguments,
- result,
input_rows_count);
- } else if (which.is_string()) {
- return execute_get_when_null<ColumnString>(data_type, block,
arguments, result,
- input_rows_count);
- } else {
- return Status::NotSupported(fmt::format("Unexpected type {} of
argument of function {}",
- data_type->get_name(),
get_name()));
- }
+#define DISPATCH(TYPE, COLUMN_TYPE)
\
+ if (which.idx == TypeIndex::TYPE)
\
+ return execute_get_when_null<COLUMN_TYPE>(data_type, block, arguments,
result, \
+ input_rows_count);
+ TYPE_TO_COLUMN_TYPE(DISPATCH)
+#undef DISPATCH
+ return Status::NotSupported(
+ fmt::format("argument_type {} not supported",
data_type->get_name()));
}
Status execute_impl(FunctionContext* context, Block& block, const
ColumnNumbers& arguments,
diff --git a/be/src/vec/functions/function_coalesce.cpp
b/be/src/vec/functions/function_coalesce.cpp
index 4991fa8..91d6304 100644
--- a/be/src/vec/functions/function_coalesce.cpp
+++ b/be/src/vec/functions/function_coalesce.cpp
@@ -16,11 +16,10 @@
// under the License.
#include "udf/udf.h"
-#include "vec/data_types/data_type_nothing.h"
-#include "vec/data_types/data_type_number.h"
#include "vec/data_types/get_least_supertype.h"
#include "vec/functions/function_helpers.h"
#include "vec/functions/simple_function_factory.h"
+#include "vec/utils/template_helpers.hpp"
#include "vec/utils/util.hpp"
namespace doris::vectorized {
@@ -53,11 +52,9 @@ public:
res = res ? res : arguments[0];
- const ColumnsWithTypeAndName is_not_null_col{
- {nullptr, make_nullable(res), ""}
- };
- func_is_not_null = SimpleFunctionFactory::instance().
- get_function("is_not_null_pred", is_not_null_col,
std::make_shared<DataTypeUInt8>());
+ const ColumnsWithTypeAndName is_not_null_col {{nullptr,
make_nullable(res), ""}};
+ func_is_not_null = SimpleFunctionFactory::instance().get_function(
+ "is_not_null_pred", is_not_null_col,
std::make_shared<DataTypeUInt8>());
return res;
}
@@ -74,7 +71,8 @@ public:
filtered_args.push_back(arguments[i]);
if (!arg_type->is_nullable()) {
if (i == 0) { //if the first column not null, return it's
directly
- block.get_by_position(result).column =
block.get_by_position(arguments[0]).column;
+ block.get_by_position(result).column =
+ block.get_by_position(arguments[0]).column;
return Status::OK();
} else {
break;
@@ -84,8 +82,12 @@ public:
size_t remaining_rows = input_rows_count;
size_t argument_size = filtered_args.size();
- std::vector<uint32_t> record_idx(input_rows_count, 0); //used to save
column idx, record the result data of each row from which column
- std::vector<uint8_t> filled_flags(input_rows_count, 0); //used to save
filled flag, in order to check current row whether have filled data
+ std::vector<uint32_t> record_idx(
+ input_rows_count,
+ 0); //used to save column idx, record the result data of each
row from which column
+ std::vector<uint8_t> filled_flags(
+ input_rows_count,
+ 0); //used to save filled flag, in order to check current row
whether have filled data
MutableColumnPtr result_column;
if (!result_type->is_nullable()) {
@@ -104,7 +106,8 @@ public:
}
auto return_type = std::make_shared<DataTypeUInt8>();
- auto null_map = ColumnUInt8::create(input_rows_count, 1); //if
null_map_data==1, the current row should be null
+ auto null_map = ColumnUInt8::create(
+ input_rows_count, 1); //if null_map_data==1, the current row
should be null
auto* __restrict null_map_data = null_map->get_data().data();
ColumnPtr argument_columns[argument_size]; //use to save nested_column
if is nullable column
@@ -119,17 +122,17 @@ public:
}
Block temporary_block {
- ColumnsWithTypeAndName {
- block.get_by_position(filtered_args[0]),
- {nullptr, std::make_shared<DataTypeUInt8>(), ""}
- }
- };
+ ColumnsWithTypeAndName
{block.get_by_position(filtered_args[0]),
+ {nullptr,
std::make_shared<DataTypeUInt8>(), ""}}};
for (size_t i = 0; i < argument_size && remaining_rows; ++i) {
- temporary_block.get_by_position(0).column =
block.get_by_position(filtered_args[i]).column;
+ temporary_block.get_by_position(0).column =
+ block.get_by_position(filtered_args[i]).column;
func_is_not_null->execute(context, temporary_block, {0}, 1,
input_rows_count);
- auto res_column =
(*temporary_block.get_by_position(1).column->convert_to_full_column_if_const()).mutate();
+ auto res_column =
+
(*temporary_block.get_by_position(1).column->convert_to_full_column_if_const())
+ .mutate();
auto& res_map =
assert_cast<ColumnVector<UInt8>*>(res_column.get())->get_data();
auto* __restrict res = res_map.data();
@@ -152,7 +155,8 @@ public:
if (is_same_column_count == input_rows_count) {
if (result_type->is_nullable()) {
- block.get_by_position(result).column =
make_nullable(argument_columns[i], false);
+ block.get_by_position(result).column =
+ make_nullable(argument_columns[i], false);
} else {
block.get_by_position(result).column =
argument_columns[i];
}
@@ -170,7 +174,7 @@ public:
}
if (is_string_result) {
- //if string type, should according to the record results, fill in
result one by one,
+ //if string type, should according to the record results, fill in
result one by one,
for (size_t row = 0; row < input_rows_count; ++row) {
if (null_map_data[row]) { //should be null
result_column->insert_default();
@@ -181,7 +185,8 @@ public:
}
if (result_type->is_nullable()) {
- block.replace_by_position(result,
ColumnNullable::create(std::move(result_column), std::move(null_map)));
+ block.replace_by_position(
+ result, ColumnNullable::create(std::move(result_column),
std::move(null_map)));
} else {
block.replace_by_position(result, std::move(result_column));
}
@@ -198,18 +203,17 @@ public:
auto* __restrict column_raw_data =
reinterpret_cast<const
ColumnType*>(argument_column.get())->get_data().data();
-
// Here it's SIMD thought the compiler automatically also
// true: null_map_data[row]==0 && filled_idx[row]==0
// if true, could filled current row data into result column
for (size_t row = 0; row < input_rows_count; ++row) {
- result_raw_data[row] += (!(null_map_data[row] | filled_flag[row]))
* column_raw_data[row];
+ result_raw_data[row] +=
+ (!(null_map_data[row] | filled_flag[row])) *
column_raw_data[row];
filled_flag[row] += (!(null_map_data[row] | filled_flag[row]));
}
return Status::OK();
}
- //TODO: this function is same as case when, should be replaced by macro
Status filled_result_column(const DataTypePtr& data_type,
MutableColumnPtr& result_column,
ColumnPtr& argument_column, UInt8* __restrict
null_map_data,
UInt8* __restrict filled_flag, const size_t
input_rows_count) {
@@ -217,46 +221,16 @@ public:
? reinterpret_cast<const
DataTypeNullable*>(data_type.get())
->get_nested_type()
: data_type);
- if (which.is_uint8()) {
- return insert_result_data<ColumnUInt8>(result_column,
argument_column, null_map_data,
- filled_flag,
input_rows_count);
- } else if (which.is_int16()) {
- return insert_result_data<ColumnInt16>(result_column,
argument_column, null_map_data,
- filled_flag,
input_rows_count);
- } else if (which.is_uint32()) {
- return insert_result_data<ColumnUInt32>(result_column,
argument_column, null_map_data,
- filled_flag,
input_rows_count);
- } else if (which.is_uint64()) {
- return insert_result_data<ColumnUInt64>(result_column,
argument_column, null_map_data,
- filled_flag,
input_rows_count);
- } else if (which.is_int8()) {
- return insert_result_data<ColumnInt8>(result_column,
argument_column, null_map_data,
- filled_flag,
input_rows_count);
- } else if (which.is_int16()) {
- return insert_result_data<ColumnInt16>(result_column,
argument_column, null_map_data,
- filled_flag,
input_rows_count);
- } else if (which.is_int32()) {
- return insert_result_data<ColumnInt32>(result_column,
argument_column, null_map_data,
- filled_flag,
input_rows_count);
- } else if (which.is_int64()) {
- return insert_result_data<ColumnInt64>(result_column,
argument_column, null_map_data,
- filled_flag,
input_rows_count);
- } else if (which.is_date_or_datetime()) {
- return insert_result_data<ColumnVector<DateTime>>(
- result_column, argument_column, null_map_data,
filled_flag, input_rows_count);
- } else if (which.is_float32()) {
- return insert_result_data<ColumnFloat32>(result_column,
argument_column, null_map_data,
- filled_flag,
input_rows_count);
- } else if (which.is_float64()) {
- return insert_result_data<ColumnFloat64>(result_column,
argument_column, null_map_data,
- filled_flag,
input_rows_count);
- } else if (which.is_decimal()) {
- return insert_result_data<ColumnDecimal<Decimal128>>(
- result_column, argument_column, null_map_data,
filled_flag, input_rows_count);
- } else {
- return Status::NotSupported(fmt::format("Unexpected type {} of
argument of function {}",
- data_type->get_name(),
get_name()));
- }
+#define DISPATCH(TYPE, COLUMN_TYPE)
\
+ if (which.idx == TypeIndex::TYPE)
\
+ return insert_result_data<COLUMN_TYPE>(result_column, argument_column,
null_map_data, \
+ filled_flag, input_rows_count);
+ NUMERIC_TYPE_TO_COLUMN_TYPE(DISPATCH)
+ DECIMAL_TYPE_TO_COLUMN_TYPE(DISPATCH)
+ TIME_TYPE_TO_COLUMN_TYPE(DISPATCH)
+#undef DISPATCH
+ return Status::NotSupported(
+ fmt::format("argument_type {} not supported",
data_type->get_name()));
}
};
diff --git a/be/src/vec/functions/function_hash.cpp
b/be/src/vec/functions/function_hash.cpp
index 18c7bcc..92e2a55 100644
--- a/be/src/vec/functions/function_hash.cpp
+++ b/be/src/vec/functions/function_hash.cpp
@@ -23,14 +23,14 @@
#include "util/hash_util.hpp"
#include "vec/functions/function_variadic_arguments.h"
#include "vec/functions/simple_function_factory.h"
+#include "vec/utils/template_helpers.hpp"
namespace doris::vectorized {
struct MurmurHash2Impl64 {
static constexpr auto name = "murmurHash2_64";
using ReturnType = UInt64;
- static Status empty_apply(IColumn& icolumn,
- size_t input_rows_count) {
+ static Status empty_apply(IColumn& icolumn, size_t input_rows_count) {
ColumnVector<ReturnType>& vec_to =
assert_cast<ColumnVector<ReturnType>&>(icolumn);
vec_to.get_data().assign(input_rows_count,
static_cast<ReturnType>(0xe28dbde7fe22e41c));
return Status::OK();
@@ -42,8 +42,8 @@ struct MurmurHash2Impl64 {
return Status::OK();
}
- static Status combine_apply(const IDataType* type, const IColumn* column,
size_t input_rows_count,
- IColumn& icolumn) {
+ static Status combine_apply(const IDataType* type, const IColumn* column,
+ size_t input_rows_count, IColumn& icolumn) {
execute_any<false>(type, column, icolumn, input_rows_count);
return Status::OK();
}
@@ -58,7 +58,7 @@ struct MurmurHash2Impl64 {
for (size_t i = 0; i < size; ++i) {
ReturnType val = HashUtil::murmur_hash2_64(
reinterpret_cast<const char*>(reinterpret_cast<const
char*>(&vec_from[i])),
- sizeof(vec_from[i]), 0);
+ sizeof(vec_from[i]), 0);
if (first)
col_to.insert_data(const_cast<const
char*>(reinterpret_cast<char*>(&val)), 0);
else
@@ -137,38 +137,20 @@ struct MurmurHash2Impl64 {
}
template <bool first>
- static Status execute_any(const IDataType* from_type, const IColumn*
icolumn,
- IColumn& col_to, size_t input_rows_count) {
+ static Status execute_any(const IDataType* from_type, const IColumn*
icolumn, IColumn& col_to,
+ size_t input_rows_count) {
WhichDataType which(from_type);
-
- if (which.is_uint8())
- execute_int_type<UInt8, first>(icolumn, col_to, input_rows_count);
- else if (which.is_int16())
- execute_int_type<UInt16, first>(icolumn, col_to, input_rows_count);
- else if (which.is_uint32())
- execute_int_type<UInt32, first>(icolumn, col_to, input_rows_count);
- else if (which.is_uint64())
- execute_int_type<UInt64, first>(icolumn, col_to, input_rows_count);
- else if (which.is_int8())
- execute_int_type<Int8, first>(icolumn, col_to, input_rows_count);
- else if (which.is_int16())
- execute_int_type<Int16, first>(icolumn, col_to, input_rows_count);
- else if (which.is_int32())
- execute_int_type<Int32, first>(icolumn, col_to, input_rows_count);
- else if (which.is_int64())
- execute_int_type<Int64, first>(icolumn, col_to, input_rows_count);
- else if (which.is_float32())
- execute_int_type<Float32, first>(icolumn, col_to,
input_rows_count);
- else if (which.is_float64())
- execute_int_type<Float64, first>(icolumn, col_to,
input_rows_count);
- else if (which.is_string())
- execute_string<first>(icolumn, col_to, input_rows_count);
- else {
- DCHECK(false);
- return Status::NotSupported(fmt::format("Illegal column {} of
argument of function {}",
- icolumn->get_name(),
name));
+ if (which.is_string()) {
+ return execute_string<first>(icolumn, col_to, input_rows_count);
}
- return Status::OK();
+
+#define DISPATCH(TYPE, COLUMN_TYPE) \
+ if (which.idx == TypeIndex::TYPE) \
+ return execute_int_type<TYPE, first>(icolumn, col_to,
input_rows_count);
+ NUMERIC_TYPE_TO_COLUMN_TYPE(DISPATCH)
+#undef DISPATCH
+ return Status::NotSupported(
+ fmt::format("argument_type {} not supported",
from_type->get_name()));
}
};
using FunctionMurmurHash2_64 = FunctionVariadicArgumentsBase<DataTypeUInt64,
MurmurHash2Impl64>;
@@ -177,8 +159,7 @@ struct MurmurHash3Impl32 {
static constexpr auto name = "murmur_hash3_32";
using ReturnType = Int32;
- static Status empty_apply(IColumn& icolumn,
- size_t input_rows_count) {
+ static Status empty_apply(IColumn& icolumn, size_t input_rows_count) {
ColumnVector<ReturnType>& vec_to =
assert_cast<ColumnVector<ReturnType>&>(icolumn);
vec_to.get_data().assign(input_rows_count,
static_cast<ReturnType>(0xe28dbde7fe22e41c));
return Status::OK();
@@ -189,8 +170,8 @@ struct MurmurHash3Impl32 {
return execute<true>(type, column, input_rows_count, icolumn);
}
- static Status combine_apply(const IDataType* type, const IColumn* column,
size_t input_rows_count,
- IColumn& icolumn) {
+ static Status combine_apply(const IDataType* type, const IColumn* column,
+ size_t input_rows_count, IColumn& icolumn) {
return execute<false>(type, column, input_rows_count, icolumn);
}
@@ -207,15 +188,14 @@ struct MurmurHash3Impl32 {
if (first) {
UInt32 val = HashUtil::murmur_hash3_32(
reinterpret_cast<const
char*>(&data[current_offset]),
- offsets[i] - current_offset - 1,
- HashUtil::MURMUR3_32_SEED);
+ offsets[i] - current_offset - 1,
HashUtil::MURMUR3_32_SEED);
col_to.insert_data(const_cast<const
char*>(reinterpret_cast<char*>(&val)), 0);
} else {
assert_cast<ColumnVector<ReturnType>&>(col_to).get_data()[i] =
HashUtil::murmur_hash3_32(
- reinterpret_cast<const
char*>(&data[current_offset]),
- offsets[i] - current_offset - 1,
- ext::bit_cast<UInt32>(col_to[i]));
+ reinterpret_cast<const
char*>(&data[current_offset]),
+ offsets[i] - current_offset - 1,
+ ext::bit_cast<UInt32>(col_to[i]));
}
current_offset = offsets[i];
}
@@ -224,17 +204,13 @@ struct MurmurHash3Impl32 {
String value = col_from_const->get_value<String>().data();
for (size_t i = 0; i < input_rows_count; ++i) {
if (first) {
- UInt32 val = HashUtil::murmur_hash3_32(
- value.data(),
- value.size(),
- HashUtil::MURMUR3_32_SEED);
+ UInt32 val = HashUtil::murmur_hash3_32(value.data(),
value.size(),
+
HashUtil::MURMUR3_32_SEED);
col_to.insert_data(const_cast<const
char*>(reinterpret_cast<char*>(&val)), 0);
} else {
assert_cast<ColumnVector<ReturnType>&>(col_to).get_data()[i] =
- HashUtil::murmur_hash3_32(
- value.data(),
- value.size(),
- ext::bit_cast<UInt32>(col_to[i]));
+ HashUtil::murmur_hash3_32(value.data(),
value.size(),
+
ext::bit_cast<UInt32>(col_to[i]));
}
}
} else {
diff --git a/be/src/vec/utils/template_helpers.hpp
b/be/src/vec/utils/template_helpers.hpp
new file mode 100644
index 0000000..4d4e1e2
--- /dev/null
+++ b/be/src/vec/utils/template_helpers.hpp
@@ -0,0 +1,69 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+// This file is copied from
+//
https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/Helpers.h
+// and modified by Doris
+
+#pragma once
+
+#include "http/http_status.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/columns/columns_number.h"
+#include "vec/data_types/data_type.h"
+#include "vec/functions/function.h"
+
+#define NUMERIC_TYPE_TO_COLUMN_TYPE(M) \
+ M(UInt8, ColumnUInt8) \
+ M(Int8, ColumnInt8) \
+ M(Int16, ColumnInt16) \
+ M(Int32, ColumnInt32) \
+ M(Int64, ColumnInt64) \
+ M(Int128, ColumnInt128) \
+ M(Float32, ColumnFloat32) \
+ M(Float64, ColumnFloat64) // List macro: applies M(type name, column class) to each numeric type. NOTE(review): there are no UInt16/UInt32/UInt64 entries, while the if-chains removed by this commit had is_uint32()/is_uint64() branches - confirm those are no longer needed.
+ // Decimal types, each paired with the matching ColumnDecimal instantiation.
+#define DECIMAL_TYPE_TO_COLUMN_TYPE(M) \
+ M(Decimal32, ColumnDecimal<Decimal32>) \
+ M(Decimal64, ColumnDecimal<Decimal64>) \
+ M(Decimal128, ColumnDecimal<Decimal128>)
+ // String type, paired with ColumnString.
+#define STRING_TYPE_TO_COLUMN_TYPE(M) M(String, ColumnString)
+ // Date and DateTime are both paired with ColumnInt64.
+#define TIME_TYPE_TO_COLUMN_TYPE(M) \
+ M(Date, ColumnInt64) \
+ M(DateTime, ColumnInt64)
+ // Union of the four lists above, in this order.
+#define TYPE_TO_COLUMN_TYPE(M) \
+ NUMERIC_TYPE_TO_COLUMN_TYPE(M) \
+ DECIMAL_TYPE_TO_COLUMN_TYPE(M) \
+ STRING_TYPE_TO_COLUMN_TYPE(M) \
+ TIME_TYPE_TO_COLUMN_TYPE(M)
+
+namespace doris::vectorized {
+
+template <template <typename> typename ClassTemplate, typename... TArgs> // ClassTemplate: class template instantiated with a column type; TArgs: constructor arguments, perfectly forwarded.
+IAggregateFunction* create_class_with_type(const IDataType& argument_type,
TArgs&&... args) { // Returns a heap-allocated ClassTemplate<Column> for the first TYPE_TO_COLUMN_TYPE entry whose TypeIndex equals argument_type's; the caller owns the raw pointer.
+ WhichDataType which(argument_type);
+#define DISPATCH(TYPE, COLUMN_TYPE) \
+ if (which.idx == TypeIndex::TYPE) \
+ return new ClassTemplate<COLUMN_TYPE>(std::forward<TArgs>(args)...); // one `if` per list entry is expanded on the next line
+ TYPE_TO_COLUMN_TYPE(DISPATCH)
+#undef DISPATCH
+ return nullptr; // no listed type matched; callers must handle the null result
+}
+
+} // namespace doris::vectorized
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
index 5f435d1..ca97cbc 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
@@ -1469,25 +1469,43 @@ public class
FunctionSet<min_initIN9doris_udf12DecimalV2ValEEEvPNS2_15FunctionCo
// NDV
// ndv return string
- addBuiltin(AggregateFunction.createBuiltin("ndv",
- Lists.newArrayList(t), Type.BIGINT, Type.VARCHAR,
-
"_ZN5doris12HllFunctions8hll_initEPN9doris_udf15FunctionContextEPNS1_9StringValE",
- "_ZN5doris12HllFunctions" + HLL_UPDATE_SYMBOL.get(t),
-
"_ZN5doris12HllFunctions9hll_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_",
-
"_ZN5doris12HllFunctions13hll_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
-
"_ZN5doris12HllFunctions12hll_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
- true, false, true));
+ addBuiltin(AggregateFunction.createBuiltin("ndv",
Lists.newArrayList(t), Type.BIGINT, Type.VARCHAR,
+
"_ZN5doris12HllFunctions8hll_initEPN9doris_udf15FunctionContextEPNS1_9StringValE",
+ "_ZN5doris12HllFunctions" +
HLL_UPDATE_SYMBOL.get(t),
+
"_ZN5doris12HllFunctions9hll_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_",
+
"_ZN5doris12HllFunctions13hll_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
+
"_ZN5doris12HllFunctions12hll_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
+ true, false, true));
- //APPROX_COUNT_DISTINCT
- //alias of ndv, compute approx count distinct use HyperLogLog
- addBuiltin(AggregateFunction.createBuiltin("approx_count_distinct",
- Lists.newArrayList(t), Type.BIGINT, Type.VARCHAR,
-
"_ZN5doris12HllFunctions8hll_initEPN9doris_udf15FunctionContextEPNS1_9StringValE",
- "_ZN5doris12HllFunctions" + HLL_UPDATE_SYMBOL.get(t),
-
"_ZN5doris12HllFunctions9hll_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_",
-
"_ZN5doris12HllFunctions13hll_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
-
"_ZN5doris12HllFunctions12hll_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
- true, false, true));
+ // vectorized
+ addBuiltin(AggregateFunction.createBuiltin("ndv",
Lists.newArrayList(t), Type.BIGINT, Type.VARCHAR,
+
"_ZN5doris12HllFunctions8hll_initEPN9doris_udf15FunctionContextEPNS1_9StringValE",
+ "_ZN5doris12HllFunctions" +
HLL_UPDATE_SYMBOL.get(t),
+
"_ZN5doris12HllFunctions9hll_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_",
+
"_ZN5doris12HllFunctions13hll_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
+
"_ZN5doris12HllFunctions12hll_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
+ true, false, true, true));
+
+ // APPROX_COUNT_DISTINCT
+ // alias of ndv, compute approx count distinct use HyperLogLog
+
addBuiltin(AggregateFunction.createBuiltin("approx_count_distinct",
Lists.newArrayList(t), Type.BIGINT,
+ Type.VARCHAR,
+
"_ZN5doris12HllFunctions8hll_initEPN9doris_udf15FunctionContextEPNS1_9StringValE",
+ "_ZN5doris12HllFunctions" +
HLL_UPDATE_SYMBOL.get(t),
+
"_ZN5doris12HllFunctions9hll_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_",
+
"_ZN5doris12HllFunctions13hll_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
+
"_ZN5doris12HllFunctions12hll_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
+ true, false, true));
+
+ // vectorized
+
addBuiltin(AggregateFunction.createBuiltin("approx_count_distinct",
Lists.newArrayList(t), Type.BIGINT,
+ Type.VARCHAR,
+
"_ZN5doris12HllFunctions8hll_initEPN9doris_udf15FunctionContextEPNS1_9StringValE",
+ "_ZN5doris12HllFunctions" +
HLL_UPDATE_SYMBOL.get(t),
+
"_ZN5doris12HllFunctions9hll_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_",
+
"_ZN5doris12HllFunctions13hll_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
+
"_ZN5doris12HllFunctions12hll_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE",
+ true, false, true, true));
// BITMAP_UNION_INT
addBuiltin(AggregateFunction.createBuiltin(BITMAP_UNION_INT,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]