This is an automated email from the ASF dual-hosted git repository. dataroaring pushed a commit to branch branch-3.0 in repository https://gitbox.apache.org/repos/asf/doris.git
commit 52f1e0cf1b1eec44231451107693e5a5f28f6fae Author: Mryange <[email protected]> AuthorDate: Tue May 28 11:17:37 2024 +0800 [feature])(function) add percentile_approx_weighted function (#35132) --- .../aggregate_function_percentile.cpp | 29 +++++- .../aggregate_function_percentile.h | 107 ++++++++++++++++--- .../doris/catalog/BuiltinAggregateFunctions.java | 14 +-- .../functions/agg/PercentileApproxWeighted.java | 115 +++++++++++++++++++++ .../visitor/AggregateFunctionVisitor.java | 5 + .../test_aggregate_percentile_approx_weighted.out | 7 ++ ...est_aggregate_percentile_approx_weighted.groovy | 62 +++++++++++ 7 files changed, 316 insertions(+), 23 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_percentile.cpp b/be/src/vec/aggregate_functions/aggregate_function_percentile.cpp index afadb5b8dca..75fce934f73 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_percentile.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_percentile.cpp @@ -32,10 +32,6 @@ AggregateFunctionPtr create_aggregate_function_percentile_approx(const std::stri if (which.idx != TypeIndex::Float64) { return nullptr; } - if (argument_types.size() == 1) { - return creator_without_type::create<AggregateFunctionPercentileApproxMerge<is_nullable>>( - remove_nullable(argument_types), result_is_nullable); - } if (argument_types.size() == 2) { return creator_without_type::create< AggregateFunctionPercentileApproxTwoParams<is_nullable>>( @@ -49,6 +45,27 @@ AggregateFunctionPtr create_aggregate_function_percentile_approx(const std::stri return nullptr; } +template <bool is_nullable> +AggregateFunctionPtr create_aggregate_function_percentile_approx_weighted( + const std::string& name, const DataTypes& argument_types, const bool result_is_nullable) { + const DataTypePtr& argument_type = remove_nullable(argument_types[0]); + WhichDataType which(argument_type); + if (which.idx != TypeIndex::Float64) { + return nullptr; + } + if (argument_types.size() == 3) { + 
return creator_without_type::create< + AggregateFunctionPercentileApproxWeightedThreeParams<is_nullable>>( + remove_nullable(argument_types), result_is_nullable); + } + if (argument_types.size() == 4) { + return creator_without_type::create< + AggregateFunctionPercentileApproxWeightedFourParams<is_nullable>>( + remove_nullable(argument_types), result_is_nullable); + } + return nullptr; +} + void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& factory) { factory.register_function_both("percentile", creator_with_integer_type::creator<AggregateFunctionPercentile>); @@ -62,5 +79,9 @@ void register_aggregate_function_percentile_approx(AggregateFunctionSimpleFactor create_aggregate_function_percentile_approx<false>, false); factory.register_function("percentile_approx", create_aggregate_function_percentile_approx<true>, true); + factory.register_function("percentile_approx_weighted", + create_aggregate_function_percentile_approx_weighted<false>, false); + factory.register_function("percentile_approx_weighted", + create_aggregate_function_percentile_approx_weighted<true>, true); } } // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/aggregate_functions/aggregate_function_percentile.h b/be/src/vec/aggregate_functions/aggregate_function_percentile.h index 5984a81109f..dd4af2559bb 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_percentile.h +++ b/be/src/vec/aggregate_functions/aggregate_function_percentile.h @@ -130,6 +130,11 @@ struct PercentileApproxState { target_quantile = quantile; } + void add_with_weight(double source, double weight, double quantile) { + digest->add(source, weight); + target_quantile = quantile; + } + void reset() { target_quantile = INIT_QUANTILE; init_flag = false; @@ -189,19 +194,6 @@ public: } }; -// only for merge -template <bool is_nullable> -class AggregateFunctionPercentileApproxMerge : public AggregateFunctionPercentileApprox { -public: - 
AggregateFunctionPercentileApproxMerge(const DataTypes& argument_types_) - : AggregateFunctionPercentileApprox(argument_types_) {} - void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, - Arena*) const override { - LOG(FATAL) << "AggregateFunctionPercentileApproxMerge do not support add()"; - __builtin_unreachable(); - } -}; - template <bool is_nullable> class AggregateFunctionPercentileApproxTwoParams : public AggregateFunctionPercentileApprox { public: @@ -283,6 +275,95 @@ public: } }; +template <bool is_nullable> +class AggregateFunctionPercentileApproxWeightedThreeParams + : public AggregateFunctionPercentileApprox { +public: + AggregateFunctionPercentileApproxWeightedThreeParams(const DataTypes& argument_types_) + : AggregateFunctionPercentileApprox(argument_types_) {} + + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, + Arena*) const override { + if constexpr (is_nullable) { + // sources quantile weight + double column_data[3] = {0, 0, 0}; + for (int i = 0; i < 3; ++i) { + const auto* nullable_column = check_and_get_column<ColumnNullable>(columns[i]); + if (nullable_column == nullptr) { //Not Nullable column + const auto& column = assert_cast<const ColumnVector<Float64>&>(*columns[i]); + column_data[i] = column.get_float64(row_num); + } else if (!nullable_column->is_null_at( + row_num)) { // Nullable column && Not null data + const auto& column = assert_cast<const ColumnVector<Float64>&>( + nullable_column->get_nested_column()); + column_data[i] = column.get_float64(row_num); + } else { // Nullable column && null data + if (i == 0) { + return; + } + } + } + this->data(place).init(); + this->data(place).add_with_weight(column_data[0], column_data[1], column_data[2]); + + } else { + const auto& sources = assert_cast<const ColumnVector<Float64>&>(*columns[0]); + const auto& weight = assert_cast<const ColumnVector<Float64>&>(*columns[1]); + const auto& quantile = assert_cast<const 
ColumnVector<Float64>&>(*columns[2]); + + this->data(place).init(); + this->data(place).add_with_weight(sources.get_float64(row_num), + weight.get_float64(row_num), + quantile.get_float64(row_num)); + } + } +}; + +template <bool is_nullable> +class AggregateFunctionPercentileApproxWeightedFourParams + : public AggregateFunctionPercentileApprox { +public: + AggregateFunctionPercentileApproxWeightedFourParams(const DataTypes& argument_types_) + : AggregateFunctionPercentileApprox(argument_types_) {} + void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, + Arena*) const override { + if constexpr (is_nullable) { + double column_data[4] = {0, 0, 0, 0}; + + for (int i = 0; i < 4; ++i) { + const auto* nullable_column = check_and_get_column<ColumnNullable>(columns[i]); + if (nullable_column == nullptr) { //Not Nullable column + const auto& column = assert_cast<const ColumnVector<Float64>&>(*columns[i]); + column_data[i] = column.get_float64(row_num); + } else if (!nullable_column->is_null_at( + row_num)) { // Nullable column && Not null data + const auto& column = assert_cast<const ColumnVector<Float64>&>( + nullable_column->get_nested_column()); + column_data[i] = column.get_float64(row_num); + } else { // Nullable column && null data + if (i == 0) { + return; + } + } + } + + this->data(place).init(column_data[3]); + this->data(place).add_with_weight(column_data[0], column_data[1], column_data[2]); + + } else { + const auto& sources = assert_cast<const ColumnVector<Float64>&>(*columns[0]); + const auto& weight = assert_cast<const ColumnVector<Float64>&>(*columns[1]); + const auto& quantile = assert_cast<const ColumnVector<Float64>&>(*columns[2]); + const auto& compression = assert_cast<const ColumnVector<Float64>&>(*columns[3]); + + this->data(place).init(compression.get_float64(row_num)); + this->data(place).add_with_weight(sources.get_float64(row_num), + weight.get_float64(row_num), + quantile.get_float64(row_num)); + } + } +}; + 
template <typename T> struct PercentileState { mutable std::vector<Counts<T>> vec_counts; diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java index 4933e83d916..28b1352eaf4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java @@ -58,6 +58,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.OrthogonalBitmap import org.apache.doris.nereids.trees.expressions.functions.agg.OrthogonalBitmapUnionCount; import org.apache.doris.nereids.trees.expressions.functions.agg.Percentile; import org.apache.doris.nereids.trees.expressions.functions.agg.PercentileApprox; +import org.apache.doris.nereids.trees.expressions.functions.agg.PercentileApproxWeighted; import org.apache.doris.nereids.trees.expressions.functions.agg.PercentileArray; import org.apache.doris.nereids.trees.expressions.functions.agg.QuantileUnion; import org.apache.doris.nereids.trees.expressions.functions.agg.Retention; @@ -126,12 +127,13 @@ public class BuiltinAggregateFunctions implements FunctionHelper { agg(Ndv.class, "approx_count_distinct", "ndv"), agg(OrthogonalBitmapIntersect.class, "orthogonal_bitmap_intersect"), agg(OrthogonalBitmapIntersectCount.class, "orthogonal_bitmap_intersect_count"), - agg(OrthogonalBitmapUnionCount.class, "orthogonal_bitmap_union_count"), - agg(Percentile.class, "percentile"), - agg(PercentileApprox.class, "percentile_approx"), - agg(PercentileArray.class, "percentile_array"), - agg(QuantileUnion.class, "quantile_union"), - agg(Retention.class, "retention"), + agg(OrthogonalBitmapUnionCount.class, "orthogonal_bitmap_union_count"), + agg(Percentile.class, "percentile"), + agg(PercentileApprox.class, "percentile_approx"), + agg(PercentileApproxWeighted.class, "percentile_approx_weighted"), + 
agg(PercentileArray.class, "percentile_array"), + agg(QuantileUnion.class, "quantile_union"), + agg(Retention.class, "retention"), agg(SequenceCount.class, "sequence_count"), agg(SequenceMatch.class, "sequence_match"), agg(Stddev.class, "stddev_pop", "stddev"), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileApproxWeighted.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileApproxWeighted.java new file mode 100644 index 00000000000..c0698014641 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileApproxWeighted.java @@ -0,0 +1,115 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.trees.expressions.functions.agg; + +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.exceptions.AnalysisException; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.DoubleType; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * AggregateFunction 'percentile_approx_weighted'. + */ +public class PercentileApproxWeighted extends AggregateFunction + implements ExplicitlyCastableSignature, AlwaysNullable { + + public static final List<FunctionSignature> SIGNATURES = ImmutableList.of( + + FunctionSignature.ret(DoubleType.INSTANCE) + .args(DoubleType.INSTANCE, DoubleType.INSTANCE, DoubleType.INSTANCE), + + FunctionSignature.ret(DoubleType.INSTANCE) + .args(DoubleType.INSTANCE, DoubleType.INSTANCE, DoubleType.INSTANCE, DoubleType.INSTANCE)); + + /** + * constructor with 3 arguments. + */ + public PercentileApproxWeighted(Expression arg0, Expression arg1, Expression arg2) { + super("percentile_approx_weighted", arg0, arg1, arg2); + } + + /** + * constructor with 3 arguments. + */ + public PercentileApproxWeighted(boolean distinct, Expression arg0, Expression arg1, Expression arg2) { + super("percentile_approx_weighted", distinct, arg0, arg1, arg2); + } + + /** + * constructor with 4 arguments. + */ + public PercentileApproxWeighted(Expression arg0, Expression arg1, Expression arg2, Expression arg3) { + super("percentile_approx_weighted", arg0, arg1, arg2, arg3); + } + + /** + * constructor with 5 arguments. 
+ */ + public PercentileApproxWeighted(boolean distinct, Expression arg0, Expression arg1, Expression arg2, + Expression arg3) { + super("percentile_approx_weighted", distinct, arg0, arg1, arg2, arg3); + } + + @Override + public void checkLegalityBeforeTypeCoercion() { + if (!getArgument(2).isConstant()) { + throw new AnalysisException( + "percentile_approx_weighted requires the third parameter must be a constant : " + this.toSql()); + } + if (arity() == 4) { + if (!getArgument(3).isConstant()) { + throw new AnalysisException( + "percentile_approx_weighted requires the fourth parameter must be a constant : " + + this.toSql()); + } + } + } + + /** + * withDistinctAndChildren. + */ + @Override + public PercentileApproxWeighted withDistinctAndChildren(boolean distinct, List<Expression> children) { + Preconditions.checkArgument(children.size() == 3 + || children.size() == 4); + if (children.size() == 3) { + return new PercentileApproxWeighted(distinct, children.get(0), children.get(1), children.get(2)); + } else { + return new PercentileApproxWeighted(distinct, children.get(0), children.get(1), children.get(2), + children.get(3)); + } + } + + @Override + public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) { + return visitor.visitPercentileApprox(this, context); + } + + @Override + public List<FunctionSignature> getSignatures() { + return SIGNATURES; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java index bde1ae61660..38af4be08a8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java @@ -60,6 +60,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.OrthogonalBitmap import 
org.apache.doris.nereids.trees.expressions.functions.agg.OrthogonalBitmapUnionCount; import org.apache.doris.nereids.trees.expressions.functions.agg.Percentile; import org.apache.doris.nereids.trees.expressions.functions.agg.PercentileApprox; +import org.apache.doris.nereids.trees.expressions.functions.agg.PercentileApproxWeighted; import org.apache.doris.nereids.trees.expressions.functions.agg.PercentileArray; import org.apache.doris.nereids.trees.expressions.functions.agg.QuantileUnion; import org.apache.doris.nereids.trees.expressions.functions.agg.Retention; @@ -253,6 +254,10 @@ public interface AggregateFunctionVisitor<R, C> { return visitAggregateFunction(percentileApprox, context); } + default R visitPercentileApprox(PercentileApproxWeighted percentileApprox, C context) { + return visitAggregateFunction(percentileApprox, context); + } + default R visitPercentileArray(PercentileArray percentileArray, C context) { return visitAggregateFunction(percentileArray, context); } diff --git a/regression-test/data/nereids_p0/sql_functions/aggregate_functions/test_aggregate_percentile_approx_weighted.out b/regression-test/data/nereids_p0/sql_functions/aggregate_functions/test_aggregate_percentile_approx_weighted.out new file mode 100644 index 00000000000..0f86d83c18f --- /dev/null +++ b/regression-test/data/nereids_p0/sql_functions/aggregate_functions/test_aggregate_percentile_approx_weighted.out @@ -0,0 +1,7 @@ +-- This file is automatically generated. 
You should know what you did if you want to edit this +-- !select -- +1.0 1.6437499523162842 2.5900001525878906 4.539999485015869 6.0 + +-- !select -- +1.0 1.6437499523162842 2.5900001525878906 4.539999485015869 6.0 + diff --git a/regression-test/suites/nereids_p0/sql_functions/aggregate_functions/test_aggregate_percentile_approx_weighted.groovy b/regression-test/suites/nereids_p0/sql_functions/aggregate_functions/test_aggregate_percentile_approx_weighted.groovy new file mode 100644 index 00000000000..99d8d688599 --- /dev/null +++ b/regression-test/suites/nereids_p0/sql_functions/aggregate_functions/test_aggregate_percentile_approx_weighted.groovy @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// The cases are copied from https://github.com/trinodb/trino/tree/master +// /testing/trino-product-tests/src/main/resources/sql-tests/testcases/aggregate +// and modified by Doris.
+ +suite("test_aggregate_percentile_approx_weighted") { + sql "SET enable_nereids_planner=true" + sql "SET enable_fallback_to_original_planner=false" + + + sql "DROP TABLE IF EXISTS quantile_weighted_table" + + sql """ + create table quantile_weighted_table ( + k int, + w int + ) + DUPLICATE key (k) + distributed by hash(k) buckets 1 + properties( + "replication_num" = "1" + ); + """ + sql """insert into quantile_weighted_table values(1,10),(2,6),(3,4),(4,2),(5,2),(6,1),(5,4);""" + + qt_select """ + select + percentile_approx_weighted(k,w,0.1), + percentile_approx_weighted(k,w,0.35), + percentile_approx_weighted(k,w,0.55), + percentile_approx_weighted(k,w,0.78), + percentile_approx_weighted(k,w,0.99) + from quantile_weighted_table; + """ + + qt_select """ + select + percentile_approx_weighted(k,w,0.1,2048), + percentile_approx_weighted(k,w,0.35,2048), + percentile_approx_weighted(k,w,0.55,2048), + percentile_approx_weighted(k,w,0.78,2048), + percentile_approx_weighted(k,w,0.99,2048) + from quantile_weighted_table; + """ + +} \ No newline at end of file --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
