This is an automated email from the ASF dual-hosted git repository.
lihaopeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new ebf2b88f039 [optimization](agg) add more type in agg percentile (#34423)
ebf2b88f039 is described below
commit ebf2b88f03955a9bbc35af8ec55293f3104c069f
Author: Mryange <[email protected]>
AuthorDate: Wed May 8 14:32:40 2024 +0800
[optimization](agg) add more type in agg percentile (#34423)
---
be/src/util/counts.h | 19 ++---
.../aggregate_function_percentile.cpp | 8 +-
.../aggregate_function_percentile.h | 35 ++++----
be/test/util/counts_test.cpp | 12 +--
.../expressions/functions/agg/Percentile.java | 11 ++-
.../expressions/functions/agg/PercentileArray.java | 15 +++-
.../data/mv_p0/mv_percentile/mv_percentile.out | 6 +-
.../test_aggregate_percentile_no_cast.out | 36 ++++++++
.../mv_p0/mv_percentile/mv_percentile.groovy | 2 +-
.../test_aggregate_percentile_no_cast.groovy | 97 ++++++++++++++++++++++
10 files changed, 201 insertions(+), 40 deletions(-)
diff --git a/be/src/util/counts.h b/be/src/util/counts.h
index 70469d6fa72..e479f04c620 100644
--- a/be/src/util/counts.h
+++ b/be/src/util/counts.h
@@ -138,8 +138,7 @@ public:
private:
std::unordered_map<int64_t, uint32_t> _counts;
};
-
-// #TODO use template to reduce the Counts memery. Eg: Int do not need use
int64_t
+template <typename Ty>
class Counts {
public:
Counts() = default;
@@ -150,7 +149,7 @@ public:
}
}
- void increment(int64_t key, uint32_t i) {
+ void increment(Ty key, uint32_t i) {
auto old_size = _nums.size();
_nums.resize(_nums.size() + i);
for (uint32_t j = 0; j < i; ++j) {
@@ -163,7 +162,7 @@ public:
pdqsort(_nums.begin(), _nums.end());
size_t size = _nums.size();
write_binary(size, buf);
- buf.write(reinterpret_cast<const char*>(_nums.data()),
sizeof(int64_t) * size);
+ buf.write(reinterpret_cast<const char*>(_nums.data()), sizeof(Ty)
* size);
} else {
// convert _sorted_nums_vec to _nums and do seiralize again
_convert_sorted_num_vec_to_nums();
@@ -175,7 +174,7 @@ public:
size_t size;
read_binary(size, buf);
_nums.resize(size);
- auto buff = buf.read(sizeof(int64_t) * size);
+ auto buff = buf.read(sizeof(Ty) * size);
memcpy(_nums.data(), buff.data, buff.size);
}
@@ -231,7 +230,7 @@ public:
private:
struct Node {
- int64_t value;
+ Ty value;
int array_index;
int64_t element_index;
@@ -265,8 +264,8 @@ private:
_sorted_nums_vec.clear();
}
- std::pair<int64_t, int64_t> _merge_sort_and_get_numbers(int64_t target,
bool reverse) {
- int64_t first_number = 0, second_number = 0;
+ std::pair<Ty, Ty> _merge_sort_and_get_numbers(int64_t target, bool
reverse) {
+ Ty first_number = 0, second_number = 0;
size_t count = 0;
if (reverse) {
std::priority_queue<Node> max_heap;
@@ -321,8 +320,8 @@ private:
return {first_number, second_number};
}
- vectorized::PODArray<int64_t> _nums;
- std::vector<vectorized::PODArray<int64_t>> _sorted_nums_vec;
+ vectorized::PODArray<Ty> _nums;
+ std::vector<vectorized::PODArray<Ty>> _sorted_nums_vec;
};
} // namespace doris
diff --git a/be/src/vec/aggregate_functions/aggregate_function_percentile.cpp
b/be/src/vec/aggregate_functions/aggregate_function_percentile.cpp
index 079b1da83ff..afadb5b8dca 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_percentile.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_percentile.cpp
@@ -19,6 +19,7 @@
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/helpers.h"
+#include "vec/core/types.h"
namespace doris::vectorized {
@@ -50,9 +51,10 @@ AggregateFunctionPtr
create_aggregate_function_percentile_approx(const std::stri
void register_aggregate_function_percentile(AggregateFunctionSimpleFactory&
factory) {
factory.register_function_both("percentile",
-
creator_without_type::creator<AggregateFunctionPercentile>);
- factory.register_function_both("percentile_array",
-
creator_without_type::creator<AggregateFunctionPercentileArray>);
+
creator_with_integer_type::creator<AggregateFunctionPercentile>);
+ factory.register_function_both(
+ "percentile_array",
+
creator_with_integer_type::creator<AggregateFunctionPercentileArray>);
}
void
register_aggregate_function_percentile_approx(AggregateFunctionSimpleFactory&
factory) {
diff --git a/be/src/vec/aggregate_functions/aggregate_function_percentile.h
b/be/src/vec/aggregate_functions/aggregate_function_percentile.h
index 6322a80c934..231057158ce 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_percentile.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_percentile.h
@@ -283,8 +283,9 @@ public:
}
};
+template <typename T>
struct PercentileState {
- mutable std::vector<Counts> vec_counts;
+ mutable std::vector<Counts<T>> vec_counts;
std::vector<double> vec_quantile {-1};
bool inited_flag = false;
@@ -317,7 +318,7 @@ struct PercentileState {
}
}
- void add(int64_t source, const PaddedPODArray<Float64>& quantiles, int
arg_size) {
+ void add(T source, const PaddedPODArray<Float64>& quantiles, int arg_size)
{
if (!inited_flag) {
vec_counts.resize(arg_size);
vec_quantile.resize(arg_size, -1);
@@ -346,7 +347,7 @@ struct PercentileState {
if (vec_quantile[i] == -1.0) {
vec_quantile[i] = rhs.vec_quantile[i];
}
- vec_counts[i].merge(const_cast<Counts*>(&(rhs.vec_counts[i])));
+ vec_counts[i].merge(const_cast<Counts<T>*>(&(rhs.vec_counts[i])));
}
}
@@ -366,12 +367,13 @@ struct PercentileState {
}
};
+template <typename T>
class AggregateFunctionPercentile final
- : public IAggregateFunctionDataHelper<PercentileState,
AggregateFunctionPercentile> {
+ : public IAggregateFunctionDataHelper<PercentileState<T>,
AggregateFunctionPercentile<T>> {
public:
- AggregateFunctionPercentile(const DataTypes& argument_types_)
- : IAggregateFunctionDataHelper<PercentileState,
AggregateFunctionPercentile>(
- argument_types_) {}
+ using ColVecType = ColumnVector<T>;
+ using Base = IAggregateFunctionDataHelper<PercentileState<T>,
AggregateFunctionPercentile<T>>;
+ AggregateFunctionPercentile(const DataTypes& argument_types_) :
Base(argument_types_) {}
String get_name() const override { return "percentile"; }
@@ -379,10 +381,10 @@ public:
void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
- const auto& sources = assert_cast<const
ColumnVector<Int64>&>(*columns[0]);
+ const auto& sources = assert_cast<const ColVecType&>(*columns[0]);
const auto& quantile = assert_cast<const
ColumnVector<Float64>&>(*columns[1]);
- AggregateFunctionPercentile::data(place).add(sources.get_int(row_num),
quantile.get_data(),
- 1);
+
AggregateFunctionPercentile::data(place).add(sources.get_data()[row_num],
+ quantile.get_data(), 1);
}
void reset(AggregateDataPtr __restrict place) const override {
@@ -409,12 +411,15 @@ public:
}
};
+template <typename T>
class AggregateFunctionPercentileArray final
- : public IAggregateFunctionDataHelper<PercentileState,
AggregateFunctionPercentileArray> {
+ : public IAggregateFunctionDataHelper<PercentileState<T>,
+
AggregateFunctionPercentileArray<T>> {
public:
- AggregateFunctionPercentileArray(const DataTypes& argument_types_)
- : IAggregateFunctionDataHelper<PercentileState,
AggregateFunctionPercentileArray>(
- argument_types_) {}
+ using ColVecType = ColumnVector<T>;
+ using Base =
+ IAggregateFunctionDataHelper<PercentileState<T>,
AggregateFunctionPercentileArray<T>>;
+ AggregateFunctionPercentileArray(const DataTypes& argument_types_) :
Base(argument_types_) {}
String get_name() const override { return "percentile_array"; }
@@ -424,7 +429,7 @@ public:
void add(AggregateDataPtr __restrict place, const IColumn** columns,
ssize_t row_num,
Arena*) const override {
- const auto& sources = assert_cast<const
ColumnVector<Int64>&>(*columns[0]);
+ const auto& sources = assert_cast<const ColVecType&>(*columns[0]);
const auto& quantile_array = assert_cast<const
ColumnArray&>(*columns[1]);
const auto& offset_column_data = quantile_array.get_offsets();
const auto& nested_column =
diff --git a/be/test/util/counts_test.cpp b/be/test/util/counts_test.cpp
index 20d9ea54c97..d60f235e788 100644
--- a/be/test/util/counts_test.cpp
+++ b/be/test/util/counts_test.cpp
@@ -20,6 +20,8 @@
#include <gtest/gtest-message.h>
#include <gtest/gtest-test-part.h>
+#include <cstdint>
+
#include "gtest/gtest_pred_impl.h"
namespace doris {
@@ -27,7 +29,7 @@ namespace doris {
class TCountsTest : public testing::Test {};
TEST_F(TCountsTest, TotalTest) {
- Counts counts;
+ Counts<int64_t> counts;
// 1 1 1 2 5 7 7 9 9 19
// >>> import numpy as np
// >>> a = np.array([1,1,1,2,5,7,7,9,9,19])
@@ -48,14 +50,14 @@ TEST_F(TCountsTest, TotalTest) {
counts.serialize(bw);
bw.commit();
- Counts other;
+ Counts<int64_t> other;
StringRef res(cs->get_chars().data(), cs->get_chars().size());
vectorized::BufferReadable br(res);
other.unserialize(br);
double result1 = other.terminate(0.2);
EXPECT_EQ(result, result1);
- Counts other1;
+ Counts<int64_t> other1;
other1.increment(1, 1);
other1.increment(100, 3);
other1.increment(50, 3);
@@ -66,11 +68,11 @@ TEST_F(TCountsTest, TotalTest) {
cs->clear();
other1.serialize(bw);
bw.commit();
- Counts other1_deserialized;
+ Counts<int64_t> other1_deserialized;
vectorized::BufferReadable br1(res);
other1_deserialized.unserialize(br1);
- Counts merge_res;
+ Counts<int64_t> merge_res;
merge_res.merge(&other);
merge_res.merge(&other1_deserialized);
// 1 1 1 1 2 5 7 7 9 9 10 19 50 50 50 99 99 100 100 100
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Percentile.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Percentile.java
index d85b69516a8..abc0498f6e1 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Percentile.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Percentile.java
@@ -26,6 +26,10 @@ import
org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DoubleType;
+import org.apache.doris.nereids.types.IntegerType;
+import org.apache.doris.nereids.types.LargeIntType;
+import org.apache.doris.nereids.types.SmallIntType;
+import org.apache.doris.nereids.types.TinyIntType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@@ -39,7 +43,12 @@ public class Percentile extends AggregateFunction
implements BinaryExpression, ExplicitlyCastableSignature,
PropagateNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
-
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE,
DoubleType.INSTANCE)
+
FunctionSignature.ret(DoubleType.INSTANCE).args(LargeIntType.INSTANCE,
DoubleType.INSTANCE),
+
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE,
DoubleType.INSTANCE),
+
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE,
DoubleType.INSTANCE),
+
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE,
DoubleType.INSTANCE),
+
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE,
DoubleType.INSTANCE)
+
);
/**
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java
index d4d8ed6c39a..61fcaf3b4c4 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/PercentileArray.java
@@ -26,6 +26,10 @@ import
org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DoubleType;
+import org.apache.doris.nereids.types.IntegerType;
+import org.apache.doris.nereids.types.LargeIntType;
+import org.apache.doris.nereids.types.SmallIntType;
+import org.apache.doris.nereids.types.TinyIntType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@@ -40,8 +44,15 @@ public class PercentileArray extends AggregateFunction
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE))
- .args(BigIntType.INSTANCE,
ArrayType.of(DoubleType.INSTANCE))
- );
+ .args(LargeIntType.INSTANCE,
ArrayType.of(DoubleType.INSTANCE)),
+ FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE))
+ .args(BigIntType.INSTANCE,
ArrayType.of(DoubleType.INSTANCE)),
+ FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE))
+ .args(IntegerType.INSTANCE,
ArrayType.of(DoubleType.INSTANCE)),
+ FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE))
+ .args(SmallIntType.INSTANCE,
ArrayType.of(DoubleType.INSTANCE)),
+ FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE))
+ .args(TinyIntType.INSTANCE,
ArrayType.of(DoubleType.INSTANCE)));
/**
* constructor with 2 arguments.
diff --git a/regression-test/data/mv_p0/mv_percentile/mv_percentile.out
b/regression-test/data/mv_p0/mv_percentile/mv_percentile.out
index 32e5595dac7..858d558c346 100644
--- a/regression-test/data/mv_p0/mv_percentile/mv_percentile.out
+++ b/regression-test/data/mv_p0/mv_percentile/mv_percentile.out
@@ -1,9 +1,9 @@
-- This file is automatically generated. You should know what you did if you
want to edit this
-- !select_star --
\N 4 \N d
--4 -4 -4.000000 d
-1 1 1.000000 a
-2 2 2.000000 b
+-4 -4 -4 d
+1 1 1 a
+2 2 2 b
3 -3 \N c
-- !select_mv --
diff --git
a/regression-test/data/nereids_p0/sql_functions/aggregate_functions/test_aggregate_percentile_no_cast.out
b/regression-test/data/nereids_p0/sql_functions/aggregate_functions/test_aggregate_percentile_no_cast.out
new file mode 100644
index 00000000000..1764ba21dde
--- /dev/null
+++
b/regression-test/data/nereids_p0/sql_functions/aggregate_functions/test_aggregate_percentile_no_cast.out
@@ -0,0 +1,36 @@
+-- This file is automatically generated. You should know what you did if you
want to edit this
+-- !select --
+1 10.0 10.0 10.0 [10, 10, 10]
+2 61.0 66.30000000000001 93.33 [61, 66.30000000000001, 93.33]
+3 10.0 10.0 10.0 [10, 10, 10]
+5 29.0 29.0 29.0 [29, 29, 29]
+6 101.0 101.0 101.0 [101, 101, 101]
+
+-- !select --
+1 10.0 10.0 10.0 [10, 10, 10]
+2 61.0 66.30000000000001 93.33 [61, 66.30000000000001, 93.33]
+3 10.0 10.0 10.0 [10, 10, 10]
+5 29.0 29.0 29.0 [29, 29, 29]
+6 101.0 101.0 101.0 [101, 101, 101]
+
+-- !select --
+1 10.0 10.0 10.0 [10, 10, 10]
+2 61.0 66.30000000000001 93.33 [61, 66.30000000000001, 93.33]
+3 10.0 10.0 10.0 [10, 10, 10]
+5 29.0 29.0 29.0 [29, 29, 29]
+6 101.0 101.0 101.0 [101, 101, 101]
+
+-- !select --
+1 10.0 10.0 10.0 [10, 10, 10]
+2 61.0 66.30000000000001 93.33 [61, 66.30000000000001, 93.33]
+3 10.0 10.0 10.0 [10, 10, 10]
+5 29.0 29.0 29.0 [29, 29, 29]
+6 101.0 101.0 101.0 [101, 101, 101]
+
+-- !select --
+1 10.0 10.0 10.0 [10, 10, 10]
+2 61.0 66.30000000000001 93.33 [61, 66.30000000000001, 93.33]
+3 10.0 10.0 10.0 [10, 10, 10]
+5 29.0 29.0 29.0 [29, 29, 29]
+6 101.0 101.0 101.0 [101, 101, 101]
+
diff --git a/regression-test/suites/mv_p0/mv_percentile/mv_percentile.groovy
b/regression-test/suites/mv_p0/mv_percentile/mv_percentile.groovy
index dd6cb453305..e4624d29f00 100644
--- a/regression-test/suites/mv_p0/mv_percentile/mv_percentile.groovy
+++ b/regression-test/suites/mv_p0/mv_percentile/mv_percentile.groovy
@@ -26,7 +26,7 @@ suite ("mv_percentile") {
create table d_table(
k1 int null,
k2 int not null,
- k3 decimal(28,6) null,
+ k3 bigint null,
k4 varchar(100) null
)
duplicate key (k1,k2,k3)
diff --git
a/regression-test/suites/nereids_p0/sql_functions/aggregate_functions/test_aggregate_percentile_no_cast.groovy
b/regression-test/suites/nereids_p0/sql_functions/aggregate_functions/test_aggregate_percentile_no_cast.groovy
new file mode 100644
index 00000000000..ef76aee4405
--- /dev/null
+++
b/regression-test/suites/nereids_p0/sql_functions/aggregate_functions/test_aggregate_percentile_no_cast.groovy
@@ -0,0 +1,97 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_aggregate_percentile_no_cast") {
+ sql "SET enable_nereids_planner=true"
+ sql "SET enable_fallback_to_original_planner=false"
+
+ sql "set batch_size = 4096"
+
+ sql "DROP TABLE IF EXISTS percentile_test_db"
+ sql """
+ CREATE TABLE IF NOT EXISTS percentile_test_db (
+ id int,
+ level tinyint
+ )
+ DISTRIBUTED BY HASH(id) BUCKETS 1
+ PROPERTIES (
+ "replication_num" = "1"
+ )
+ """
+ sql "INSERT INTO percentile_test_db values(1,10), (2,8), (2,114) ,(3,10)
,(5,29) ,(6,101)"
+
+ qt_select "select id,percentile(level,0.5) , percentile(level,0.55) ,
percentile(level,0.805) , percentile_array(level,[0.5,0.55,0.805])from
percentile_test_db group by id order by id"
+
+ sql "DROP TABLE IF EXISTS percentile_test_db"
+ sql """
+ CREATE TABLE IF NOT EXISTS percentile_test_db (
+ id int,
+ level smallint
+ )
+ DISTRIBUTED BY HASH(id) BUCKETS 1
+ PROPERTIES (
+ "replication_num" = "1"
+ )
+ """
+ sql "INSERT INTO percentile_test_db values(1,10), (2,8), (2,114) ,(3,10)
,(5,29) ,(6,101)"
+
+ qt_select "select id,percentile(level,0.5) , percentile(level,0.55) ,
percentile(level,0.805) , percentile_array(level,[0.5,0.55,0.805])from
percentile_test_db group by id order by id"
+
+ sql "DROP TABLE IF EXISTS percentile_test_db"
+ sql """
+ CREATE TABLE IF NOT EXISTS percentile_test_db (
+ id int,
+ level int
+ )
+ DISTRIBUTED BY HASH(id) BUCKETS 1
+ PROPERTIES (
+ "replication_num" = "1"
+ )
+ """
+ sql "INSERT INTO percentile_test_db values(1,10), (2,8), (2,114) ,(3,10)
,(5,29) ,(6,101)"
+
+ qt_select "select id,percentile(level,0.5) , percentile(level,0.55) ,
percentile(level,0.805) , percentile_array(level,[0.5,0.55,0.805])from
percentile_test_db group by id order by id"
+
+ sql "DROP TABLE IF EXISTS percentile_test_db"
+ sql """
+ CREATE TABLE IF NOT EXISTS percentile_test_db (
+ id int,
+ level bigint
+ )
+ DISTRIBUTED BY HASH(id) BUCKETS 1
+ PROPERTIES (
+ "replication_num" = "1"
+ )
+ """
+ sql "INSERT INTO percentile_test_db values(1,10), (2,8), (2,114) ,(3,10)
,(5,29) ,(6,101)"
+ qt_select "select id,percentile(level,0.5) , percentile(level,0.55) ,
percentile(level,0.805) , percentile_array(level,[0.5,0.55,0.805])from
percentile_test_db group by id order by id"
+
+ sql "DROP TABLE IF EXISTS percentile_test_db"
+ sql """
+ CREATE TABLE IF NOT EXISTS percentile_test_db (
+ id int,
+ level largeint
+ )
+ DISTRIBUTED BY HASH(id) BUCKETS 1
+ PROPERTIES (
+ "replication_num" = "1"
+ )
+ """
+ sql "INSERT INTO percentile_test_db values(1,10), (2,8), (2,114) ,(3,10)
,(5,29) ,(6,101)"
+ qt_select "select id,percentile(level,0.5) , percentile(level,0.55) ,
percentile(level,0.805) , percentile_array(level,[0.5,0.55,0.805])from
percentile_test_db group by id order by id"
+
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org