This is an automated email from the ASF dual-hosted git repository.
taiyangli pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 38c83d5dfc [GLUTEN-6387][CH] support percentile function (#6396)
38c83d5dfc is described below
commit 38c83d5dfc998e998ea78b8edb1885d7ed0dfc24
Author: 李扬 <[email protected]>
AuthorDate: Tue Oct 29 14:58:10 2024 +0800
[GLUTEN-6387][CH] support percentile function (#6396)
* support percentile function
* finish dev percentile
* fix style
* fix failed uts
* fix building
* fix failed uts
---
.../execution/CHHashAggregateExecTransformer.scala | 9 ++
.../org/apache/gluten/utils/CHExpressionUtil.scala | 1 -
.../GlutenClickHouseTPCHSaltNullParquetSuite.scala | 13 +++
.../gluten/backendsapi/velox/VeloxBackend.scala | 5 +-
.../AggregateFunctionPartialMerge.cpp | 128 ++++++++++----------
.../AggregateFunctionPartialMerge.h | 4 +-
.../Parser/AggregateFunctionParser.cpp | 53 +++++----
.../local-engine/Parser/AggregateFunctionParser.h | 6 +-
.../Parser/RelParsers/AggregateRelParser.cpp | 2 +-
.../Parser/RelParsers/WindowRelParser.cpp | 2 +-
.../ApproxPercentileParser.cpp | 130 +++------------------
.../ApproxPercentileParser.h | 48 --------
.../BloomFilterAggParser.cpp | 2 +-
.../BloomFilterAggParser.h | 2 +-
.../aggregate_function_parser/PercentileParser.cpp | 50 ++++++++
...rcentileParser.cpp => PercentileParserBase.cpp} | 97 ++++++++-------
.../PercentileParserBase.h | 62 ++++++++++
.../gluten/expression/ExpressionMappings.scala | 3 +-
.../utils/clickhouse/ClickHouseTestSettings.scala | 2 +
.../aggregate/GlutenPercentileSuite.scala | 21 ++++
.../utils/clickhouse/ClickHouseTestSettings.scala | 2 +
.../aggregate/GlutenPercentileSuite.scala | 21 ++++
.../utils/clickhouse/ClickHouseTestSettings.scala | 2 +
.../aggregate/GlutenPercentileSuite.scala | 21 ++++
.../utils/clickhouse/ClickHouseTestSettings.scala | 2 +
.../aggregate/GlutenPercentileSuite.scala | 21 ++++
.../apache/gluten/expression/ExpressionNames.scala | 1 +
27 files changed, 406 insertions(+), 304 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
index b45f4cb97c..f5e64330cd 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
@@ -390,6 +390,15 @@ case class CHHashAggregateExecTransformer(
approxPercentile.percentageExpression.dataType,
approxPercentile.percentageExpression.nullable)
(makeStructType(fields), attr.nullable)
+ case percentile: Percentile =>
+ var fields = Seq[(DataType, Boolean)]()
+ // Use percentile.nullable as the nullable of the struct type
+ // to make sure it returns null when input is empty
+ fields = fields :+ (percentile.child.dataType,
percentile.nullable)
+ fields = fields :+ (
+ percentile.percentageExpression.dataType,
+ percentile.percentageExpression.nullable)
+ (makeStructType(fields), attr.nullable)
case _ =>
(makeStructTypeSingleOne(attr.dataType, attr.nullable),
attr.nullable)
}
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
index bb23612b13..f6d18d7a22 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
@@ -167,7 +167,6 @@ case class FormatStringValidator() extends
FunctionValidator {
}
object CHExpressionUtil {
-
final val CH_AGGREGATE_FUNC_BLACKLIST: Map[String, FunctionValidator] = Map(
MAX_BY -> DefaultValidator(),
MIN_BY -> DefaultValidator()
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index 4c9bca4422..739b040dba 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -2609,6 +2609,19 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends
GlutenClickHouseTPCHAbstr
runQueryAndCompare(sql2)({ _ => })
}
+ test("aggregate function percentile") {
+ // single percentage
+ val sql1 = "select l_linenumber % 10, percentile(l_extendedprice, 0.5) " +
+ "from lineitem group by l_linenumber % 10"
+ runQueryAndCompare(sql1)({ _ => })
+
+ // multiple percentages
+ val sql2 =
+ "select l_linenumber % 10, percentile(l_extendedprice, array(0.1, 0.2,
0.3)) " +
+ "from lineitem group by l_linenumber % 10"
+ runQueryAndCompare(sql2)({ _ => })
+ }
+
test("GLUTEN-5096: Bug fix regexp_extract diff") {
val tbl_create_sql = "create table test_tbl_5096(id bigint, data string)
using parquet"
val tbl_insert_sql = "insert into test_tbl_5096 values(1, 'abc'), (2,
'abc\n')"
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
index b9d9abf88e..251b93cc7c 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
@@ -33,7 +33,7 @@ import org.apache.gluten.utils._
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions.{Alias, CumeDist, DenseRank,
Descending, Expression, Lag, Lead, NamedExpression, NthValue, NTile,
PercentRank, RangeFrame, Rank, RowNumber, SortOrder, SpecialFrameBoundary,
SpecifiedWindowFrame}
-import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
ApproximatePercentile}
+import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
ApproximatePercentile, Percentile}
import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.execution.{ColumnarCachedBatchSerializer,
SparkPlan}
@@ -371,7 +371,8 @@ object VeloxBackendSettings extends BackendSettingsApi {
case _: RowNumber | _: Rank | _: CumeDist | _: DenseRank | _:
PercentRank |
_: NthValue | _: NTile | _: Lag | _: Lead =>
case aggrExpr: AggregateExpression
- if
!aggrExpr.aggregateFunction.isInstanceOf[ApproximatePercentile] =>
+ if
!aggrExpr.aggregateFunction.isInstanceOf[ApproximatePercentile]
+ && !aggrExpr.aggregateFunction.isInstanceOf[Percentile] =>
case _ =>
allSupported = false
}
diff --git
a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionPartialMerge.cpp
b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionPartialMerge.cpp
index 0ecb294100..ba5c8543e0 100644
--- a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionPartialMerge.cpp
+++ b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionPartialMerge.cpp
@@ -14,8 +14,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#include <AggregateFunctions/Combinators/AggregateFunctionCombinatorFactory.h>
#include <AggregateFunctions/AggregateFunctionPartialMerge.h>
+#include <AggregateFunctions/Combinators/AggregateFunctionCombinatorFactory.h>
#include <DataTypes/DataTypeAggregateFunction.h>
@@ -25,8 +25,8 @@ namespace DB
{
namespace ErrorCodes
{
- extern const int ILLEGAL_TYPE_OF_ARGUMENT;
- extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
}
@@ -34,79 +34,77 @@ namespace local_engine
{
namespace
{
- class AggregateFunctionCombinatorPartialMerge final : public
IAggregateFunctionCombinator
- {
- public:
- String getName() const override { return "PartialMerge"; }
-
- DataTypes transformArguments(const DataTypes & arguments) const
override
- {
- if (arguments.size() != 1)
- throw Exception(
- ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
- "Incorrect number of arguments for aggregate function with
{} suffix",
- getName());
-
- const DataTypePtr & argument = arguments[0];
+class AggregateFunctionCombinatorPartialMerge final : public
IAggregateFunctionCombinator
+{
+public:
+ String getName() const override { return "PartialMerge"; }
- const DataTypeAggregateFunction * function = typeid_cast<const
DataTypeAggregateFunction *>(argument.get());
- if (!function)
- throw Exception(
- ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
- "Illegal type {} of argument for aggregate function with
{} suffix must be AggregateFunction(...)",
- argument->getName(),
- getName());
+ DataTypes transformArguments(const DataTypes & arguments) const override
+ {
+ if (arguments.size() != 1)
+ throw Exception(
+ ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
+ "Incorrect number of arguments for aggregate function with {}
suffix",
+ getName());
+
+ const DataTypePtr & argument = arguments[0];
+
+ const DataTypeAggregateFunction * function = typeid_cast<const
DataTypeAggregateFunction *>(argument.get());
+ if (!function)
+ throw Exception(
+ ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
+ "Illegal type {} of argument for aggregate function with {}
suffix must be AggregateFunction(...)",
+ argument->getName(),
+ getName());
+
+ const DataTypeAggregateFunction * function2
+ = typeid_cast<const DataTypeAggregateFunction
*>(function->getArgumentsDataTypes()[0].get());
+ if (function2)
+ return transformArguments(function->getArgumentsDataTypes());
+ return function->getArgumentsDataTypes();
+ }
+
+ AggregateFunctionPtr transformAggregateFunction(
+ const AggregateFunctionPtr & nested_function,
+ const AggregateFunctionProperties &,
+ const DataTypes & arguments,
+ const Array & params) const override
+ {
+ DataTypePtr & argument = const_cast<DataTypePtr &>(arguments[0]);
- const DataTypeAggregateFunction * function2
- = typeid_cast<const DataTypeAggregateFunction
*>(function->getArgumentsDataTypes()[0].get());
- if (function2)
- {
- return transformArguments(function->getArgumentsDataTypes());
- }
- return function->getArgumentsDataTypes();
- }
+ const DataTypeAggregateFunction * function = typeid_cast<const
DataTypeAggregateFunction *>(argument.get());
+ if (!function)
+ throw Exception(
+ ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
+ "Illegal type {} of argument for aggregate function with {}
suffix must be AggregateFunction(...)",
+ argument->getName(),
+ getName());
- AggregateFunctionPtr transformAggregateFunction(
- const AggregateFunctionPtr & nested_function,
- const AggregateFunctionProperties &,
- const DataTypes & arguments,
- const Array & params) const override
+ while (nested_function->getName() != function->getFunctionName())
{
- DataTypePtr & argument = const_cast<DataTypePtr &>(arguments[0]);
-
- const DataTypeAggregateFunction * function = typeid_cast<const
DataTypeAggregateFunction *>(argument.get());
+ argument = function->getArgumentsDataTypes()[0];
+ function = typeid_cast<const DataTypeAggregateFunction
*>(function->getArgumentsDataTypes()[0].get());
if (!function)
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} of argument for aggregate function with
{} suffix must be AggregateFunction(...)",
argument->getName(),
getName());
-
- while (nested_function->getName() != function->getFunctionName())
- {
- argument = function->getArgumentsDataTypes()[0];
- function = typeid_cast<const DataTypeAggregateFunction
*>(function->getArgumentsDataTypes()[0].get());
- if (!function)
- throw Exception(
- ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
- "Illegal type {} of argument for aggregate function
with {} suffix must be AggregateFunction(...)",
- argument->getName(),
- getName());
- }
-
- if (nested_function->getName() != function->getFunctionName())
- throw Exception(
- ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
- "Illegal type {} of argument for aggregate function with
{} suffix, because it corresponds to different aggregate "
- "function: {} instead of {}",
- argument->getName(),
- getName(),
- function->getFunctionName(),
- nested_function->getName());
-
- return
std::make_shared<AggregateFunctionPartialMerge>(nested_function, argument,
params);
}
- };
+
+ if (nested_function->getName() != function->getFunctionName())
+ throw Exception(
+ ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
+ "Illegal type {} of argument for aggregate function with {}
suffix, because it corresponds to different aggregate "
+ "function: {} instead of {}",
+ argument->getName(),
+ getName(),
+ function->getFunctionName(),
+ nested_function->getName());
+
+ return
std::make_shared<AggregateFunctionPartialMerge>(nested_function, argument,
params);
+ }
+};
}
diff --git
a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionPartialMerge.h
b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionPartialMerge.h
index bc1f40f60e..822e08c8e7 100644
--- a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionPartialMerge.h
+++ b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionPartialMerge.h
@@ -16,7 +16,7 @@
*/
#pragma once
-#include <AggregateFunctions/IAggregateFunction_fwd.h>
+#include <AggregateFunctions/IAggregateFunction.h>
#include <Columns/ColumnAggregateFunction.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <Common/assert_cast.h>
@@ -41,8 +41,6 @@ struct Settings;
* this class is copied from AggregateFunctionMerge with little enhancement.
* we use this PartialMerge for both spark PartialMerge and Final
*/
-
-
class AggregateFunctionPartialMerge final : public
IAggregateFunctionHelper<AggregateFunctionPartialMerge>
{
private:
diff --git a/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp
b/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp
index 88e60b5931..42c4230e4a 100644
--- a/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp
+++ b/cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp
@@ -15,7 +15,6 @@
* limitations under the License.
*/
#include "AggregateFunctionParser.h"
-#include <type_traits>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/DataTypeTuple.h>
@@ -105,6 +104,7 @@ AggregateFunctionParser::parseFunctionArguments(const
CommonFunctionInfo & func_
collected_args.push_back(arg_node);
}
+
if (func_info.has_filter)
{
// With `If` combinator, the function take one more argument which
refers to the condition.
@@ -115,47 +115,46 @@ AggregateFunctionParser::parseFunctionArguments(const
CommonFunctionInfo & func_
}
std::pair<String, DB::DataTypes> AggregateFunctionParser::tryApplyCHCombinator(
- const CommonFunctionInfo & func_info, const String & ch_func_name, const
DB::DataTypes & arg_column_types) const
+ const CommonFunctionInfo & func_info, const String & ch_func_name, const
DB::DataTypes & argument_types) const
{
- auto get_aggregate_function = [](const String & name, const DB::DataTypes
& arg_types) -> DB::AggregateFunctionPtr
+ auto get_aggregate_function
+ = [](const String & name, const DB::DataTypes & argument_types, const
DB::Array & parameters) -> DB::AggregateFunctionPtr
{
DB::AggregateFunctionProperties properties;
- auto func = RelParser::getAggregateFunction(name, arg_types,
properties);
+ auto func = RelParser::getAggregateFunction(name, argument_types,
properties, parameters);
if (!func)
- {
throw Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknown aggregate
function {}", name);
- }
+
return func;
};
+
String combinator_function_name = ch_func_name;
- DB::DataTypes combinator_arg_column_types = arg_column_types;
+ DB::DataTypes combinator_argument_types = argument_types;
+
if (func_info.phase !=
substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE
&& func_info.phase !=
substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_RESULT)
{
- if (arg_column_types.size() != 1)
- {
+ if (argument_types.size() != 1)
throw Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Only support one
argument aggregate function in phase {}", func_info.phase);
- }
+
        // Add a check here for safety.
if (func_info.has_filter)
- {
- throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unspport apply
filter in phase {}", func_info.phase);
- }
+ throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Apply filter
in phase {} not supported", func_info.phase);
- const auto * agg_function_data =
DB::checkAndGetDataType<DB::DataTypeAggregateFunction>(arg_column_types[0].get());
- if (!agg_function_data)
+ const auto * aggr_func_type =
DB::checkAndGetDataType<DB::DataTypeAggregateFunction>(argument_types[0].get());
+ if (!aggr_func_type)
{
// FIXME. This is should be fixed. It's the case that
count(distinct(xxx)) with other aggregate functions.
// Gluten breaks the rule that intermediate result should have a
special format name here.
LOG_INFO(logger, "Intermediate aggregate function data is expected
in phase {} for {}", func_info.phase, ch_func_name);
- auto arg_type = DB::removeNullable(arg_column_types[0]);
+
+ auto arg_type = DB::removeNullable(argument_types[0]);
if (auto * tupe_type = typeid_cast<const DB::DataTypeTuple
*>(arg_type.get()))
- {
- combinator_arg_column_types = tupe_type->getElements();
- }
- auto agg_function = get_aggregate_function(ch_func_name,
arg_column_types);
+ combinator_argument_types = tupe_type->getElements();
+
+ auto agg_function = get_aggregate_function(ch_func_name,
argument_types, aggr_func_type->getParameters());
auto agg_intermediate_result_type = agg_function->getStateType();
- combinator_arg_column_types = {agg_intermediate_result_type};
+ combinator_argument_types = {agg_intermediate_result_type};
}
else
{
@@ -167,12 +166,12 @@ std::pair<String, DB::DataTypes>
AggregateFunctionParser::tryApplyCHCombinator(
// count(a),count(b), count(1), count(distinct(a)),
count(distinct(b))
// from values (1, null), (2,2) as data(a,b)
// with `first_value` enable
- if (endsWith(agg_function_data->getFunction()->getName(), "If") &&
ch_func_name != agg_function_data->getFunction()->getName())
+ if (endsWith(aggr_func_type->getFunction()->getName(), "If") &&
ch_func_name != aggr_func_type->getFunction()->getName())
{
- auto original_args_types =
agg_function_data->getArgumentsDataTypes();
- combinator_arg_column_types =
DataTypes(original_args_types.begin(), std::prev(original_args_types.end()));
- auto agg_function = get_aggregate_function(ch_func_name,
combinator_arg_column_types);
- combinator_arg_column_types = {agg_function->getStateType()};
+ auto original_args_types =
aggr_func_type->getArgumentsDataTypes();
+ combinator_argument_types =
DataTypes(original_args_types.begin(), std::prev(original_args_types.end()));
+ auto agg_function = get_aggregate_function(ch_func_name,
combinator_argument_types, aggr_func_type->getParameters());
+ combinator_argument_types = {agg_function->getStateType()};
}
}
combinator_function_name += "PartialMerge";
@@ -182,7 +181,7 @@ std::pair<String, DB::DataTypes>
AggregateFunctionParser::tryApplyCHCombinator(
// Apply `If` aggregate function combinator on the original aggregate
function.
combinator_function_name += "If";
}
- return {combinator_function_name, combinator_arg_column_types};
+ return {combinator_function_name, combinator_argument_types};
}
const DB::ActionsDAG::Node * AggregateFunctionParser::convertNodeTypeIfNeeded(
diff --git a/cpp-ch/local-engine/Parser/AggregateFunctionParser.h
b/cpp-ch/local-engine/Parser/AggregateFunctionParser.h
index 2986efd9b6..02b09fc256 100644
--- a/cpp-ch/local-engine/Parser/AggregateFunctionParser.h
+++ b/cpp-ch/local-engine/Parser/AggregateFunctionParser.h
@@ -92,7 +92,7 @@ public:
/// In most cases, arguments size and types are enough to determine the CH
function implementation.
/// It is only be used in TypeParser::buildBlockFromNamedStruct
- /// Users are allowed to modify arg types to make it fit for
ggregateFunctionFactory::instance().get(...) in
TypeParser::buildBlockFromNamedStruct
+ /// Users are allowed to modify arg types to make it fit for
AggregateFunctionFactory::instance().get(...) in
TypeParser::buildBlockFromNamedStruct
virtual String getCHFunctionName(DB::DataTypes & args) const = 0;
/// Do some preprojections for the function arguments, and return the
necessary arguments for the CH function.
@@ -114,8 +114,8 @@ public:
/// Parameters are only used in aggregate functions at present. e.g.
percentiles(0.5)(x).
/// 0.5 is the parameter of percentiles function.
- virtual DB::Array
- parseFunctionParameters(const CommonFunctionInfo & /*func_info*/,
DB::ActionsDAG::NodeRawConstPtrs & /*arg_nodes*/) const
+ virtual DB::Array parseFunctionParameters(
+ const CommonFunctionInfo & /*func_info*/,
DB::ActionsDAG::NodeRawConstPtrs & /*arg_nodes*/, DB::ActionsDAG &
/*actions_dag*/) const
{
return DB::Array();
}
diff --git a/cpp-ch/local-engine/Parser/RelParsers/AggregateRelParser.cpp
b/cpp-ch/local-engine/Parser/RelParsers/AggregateRelParser.cpp
index adde9ac182..269d55e645 100644
--- a/cpp-ch/local-engine/Parser/RelParsers/AggregateRelParser.cpp
+++ b/cpp-ch/local-engine/Parser/RelParsers/AggregateRelParser.cpp
@@ -226,7 +226,7 @@ void AggregateRelParser::addPreProjection()
{
auto arg_nodes =
agg_info.function_parser->parseFunctionArguments(agg_info.parser_func_info,
projection_action);
// This may remove elements from arg_nodes, because some of them are
converted to CH func parameters.
- agg_info.params =
agg_info.function_parser->parseFunctionParameters(agg_info.parser_func_info,
arg_nodes);
+ agg_info.params =
agg_info.function_parser->parseFunctionParameters(agg_info.parser_func_info,
arg_nodes, projection_action);
for (auto & arg_node : arg_nodes)
{
agg_info.arg_column_names.emplace_back(arg_node->result_name);
diff --git a/cpp-ch/local-engine/Parser/RelParsers/WindowRelParser.cpp
b/cpp-ch/local-engine/Parser/RelParsers/WindowRelParser.cpp
index 7b5b0147ab..d52f2543c8 100644
--- a/cpp-ch/local-engine/Parser/RelParsers/WindowRelParser.cpp
+++ b/cpp-ch/local-engine/Parser/RelParsers/WindowRelParser.cpp
@@ -324,7 +324,7 @@ void WindowRelParser::tryAddProjectionBeforeWindow()
{
auto arg_nodes =
win_info.function_parser->parseFunctionArguments(win_info.parser_func_info,
actions_dag);
// This may remove elements from arg_nodes, because some of them are
converted to CH func parameters.
- win_info.params =
win_info.function_parser->parseFunctionParameters(win_info.parser_func_info,
arg_nodes);
+ win_info.params =
win_info.function_parser->parseFunctionParameters(win_info.parser_func_info,
arg_nodes, actions_dag);
for (auto & arg_node : arg_nodes)
{
win_info.arg_column_names.emplace_back(arg_node->result_name);
diff --git
a/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.cpp
b/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.cpp
index ceddbd2aef..896157c4f2 100644
---
a/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.cpp
+++
b/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.cpp
@@ -14,126 +14,34 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
-#include <string>
-#include <DataTypes/DataTypeAggregateFunction.h>
-#include <DataTypes/DataTypeNullable.h>
-#include <Functions/FunctionHelpers.h>
-#include <Interpreters/ActionsDAG.h>
-#include <Parser/AggregateFunctionParser.h>
-#include <Parser/aggregate_function_parser/ApproxPercentileParser.h>
-#include <substrait/algebra.pb.h>
-
-namespace DB
-{
-namespace ErrorCodes
-{
- extern const int BAD_ARGUMENTS;
- extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
-}
-}
+#pragma once
+#include <Parser/aggregate_function_parser/PercentileParserBase.h>
namespace local_engine
{
-void ApproxPercentileParser::assertArgumentsSize(substrait::AggregationPhase
phase, size_t size, size_t expect) const
-{
- if (size != expect)
- throw Exception(
- DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
- "Function {} in phase {} requires exactly {} arguments but got {}
arguments",
- getName(),
- magic_enum::enum_name(phase),
- expect,
- size);
-}
-
-const substrait::Expression::Literal &
-ApproxPercentileParser::assertAndGetLiteral(substrait::AggregationPhase phase,
const substrait::Expression & expr) const
-{
- if (!expr.has_literal())
- throw Exception(
- DB::ErrorCodes::BAD_ARGUMENTS,
- "The argument of function {} in phase {} must be literal, but is
{}",
- getName(),
- magic_enum::enum_name(phase),
- expr.DebugString());
- return expr.literal();
-}
-
-String ApproxPercentileParser::getCHFunctionName(const CommonFunctionInfo &
func_info) const
-{
- const auto & output_type = func_info.output_type;
- return output_type.has_list() ? "quantilesGK" : "quantileGK";
-}
-
-String ApproxPercentileParser::getCHFunctionName(DB::DataTypes & types) const
-{
- /// Always invoked during second stage
- assertArgumentsSize(substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT,
types.size(), 2);
-
- auto type = removeNullable(types[1]);
- types.resize(1);
- return isArray(type) ? "quantilesGK" : "quantileGK";
-}
-
-DB::Array ApproxPercentileParser::parseFunctionParameters(
- const CommonFunctionInfo & func_info, DB::ActionsDAG::NodeRawConstPtrs &
arg_nodes) const
+/*
+spark: approx_percentile(col, percentage [, accuracy])
+1. When percentage is an array literal, spark returns an array of percentiles,
corresponding to CH: quantilesGK(accuracy, percentage[0], ...)(col)
+2. Otherwise spark returns a single percentile, corresponding to CH:
quantileGK(accuracy, percentage)(col)
+*/
+class ApproxPercentileParser : public PercentileParserBase
{
- if (func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE
- || func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_RESULT
|| func_info.phase == substrait::AGGREGATION_PHASE_UNSPECIFIED)
- {
- Array params;
- const auto & arguments = func_info.arguments;
- assertArgumentsSize(func_info.phase, arguments.size(), 3);
-
- const auto & accuracy_expr = arguments[2].value();
- const auto & accuracy_literal = assertAndGetLiteral(func_info.phase,
accuracy_expr);
- auto [type1, field1] = parseLiteral(accuracy_literal);
- params.emplace_back(std::move(field1));
-
- const auto & percentage_expr = arguments[1].value();
- const auto & percentage_literal = assertAndGetLiteral(func_info.phase,
percentage_expr);
- auto [type2, field2] = parseLiteral(percentage_literal);
- if (isArray(type2))
- {
- /// Multiple percentages for quantilesGK
- const Array & percentags = field2.safeGet<Array>();
- for (const auto & percentage : percentags)
- params.emplace_back(percentage);
- }
- else
- {
- /// Single percentage for quantileGK
- params.emplace_back(std::move(field2));
- }
+public:
+ static constexpr auto name = "approx_percentile";
- /// Delete percentage and accuracy argument for clickhouse
compatiability
- arg_nodes.resize(1);
- return params;
- }
- else
- {
- assertArgumentsSize(func_info.phase, arg_nodes.size(), 1);
- const auto & result_type = arg_nodes[0]->result_type;
- const auto * aggregate_function_type =
DB::checkAndGetDataType<DB::DataTypeAggregateFunction>(result_type.get());
- if (!aggregate_function_type)
- throw Exception(
- DB::ErrorCodes::BAD_ARGUMENTS,
- "The first argument type of function {} in phase {} must be
AggregateFunction, but is {}",
- getName(),
- magic_enum::enum_name(func_info.phase),
- result_type->getName());
+ explicit ApproxPercentileParser(ParserContextPtr parser_context_) :
PercentileParserBase(parser_context_) { }
- return aggregate_function_type->getParameters();
- }
-}
+ String getName() const override { return name; }
+ String getCHSingularName() const override { return "quantileGK"; }
+ String getCHPluralName() const override { return "quantilesGK"; }
-DB::Array ApproxPercentileParser::getDefaultFunctionParameters() const
-{
- return {10000, 1};
-}
+ size_t expectedArgumentsNumberInFirstStage() const override { return 3; }
+ size_t expectedTupleElementsNumberInSecondStage() const override { return
2; }
+ ColumnNumbers getArgumentsThatAreParameters() const override { return {2,
1}; }
+ DB::Array getDefaultFunctionParametersImpl() const override { return
{10000, 1}; }
+};
static const AggregateFunctionParserRegister<ApproxPercentileParser>
register_approx_percentile;
}
diff --git
a/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.h
b/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.h
deleted file mode 100644
index a58d4b2042..0000000000
---
a/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#pragma once
-#include <Parser/AggregateFunctionParser.h>
-
-
-/*
-spark: approx_percentile(col, percentage [, accuracy])
-1. When percentage is an array literal, spark returns an array of percentiles,
corresponding to CH: quantilesGK(accuracy, percentage[0], ...)(col)
-1. Otherwise spark return a single percentile, corresponding to CH:
quantileGK(accuracy, percentage)(col)
-*/
-
-namespace local_engine
-{
-class ApproxPercentileParser : public AggregateFunctionParser
-{
-public:
- explicit ApproxPercentileParser(ParserContextPtr parser_context_) :
AggregateFunctionParser(parser_context_) { }
- ~ApproxPercentileParser() override = default;
- String getName() const override { return name; }
- static constexpr auto name = "approx_percentile";
- String getCHFunctionName(const CommonFunctionInfo & func_info) const
override;
- String getCHFunctionName(DB::DataTypes & types) const override;
-
- DB::Array
- parseFunctionParameters(const CommonFunctionInfo & /*func_info*/,
DB::ActionsDAG::NodeRawConstPtrs & arg_nodes) const override;
-
- DB::Array getDefaultFunctionParameters() const override;
-
-private:
- void assertArgumentsSize(substrait::AggregationPhase phase, size_t size,
size_t expect) const;
- const substrait::Expression::Literal &
assertAndGetLiteral(substrait::AggregationPhase phase, const
substrait::Expression & expr) const;
-};
-}
diff --git
a/cpp-ch/local-engine/Parser/aggregate_function_parser/BloomFilterAggParser.cpp
b/cpp-ch/local-engine/Parser/aggregate_function_parser/BloomFilterAggParser.cpp
index 0d4bec5c27..56cc3bf2e5 100644
---
a/cpp-ch/local-engine/Parser/aggregate_function_parser/BloomFilterAggParser.cpp
+++
b/cpp-ch/local-engine/Parser/aggregate_function_parser/BloomFilterAggParser.cpp
@@ -52,7 +52,7 @@ DB::Array get_parameters(Int64 insert_num, Int64 bits_num)
}
DB::Array AggregateFunctionParserBloomFilterAgg::parseFunctionParameters(
- const CommonFunctionInfo & func_info, DB::ActionsDAG::NodeRawConstPtrs &
arg_nodes) const
+ const CommonFunctionInfo & func_info, DB::ActionsDAG::NodeRawConstPtrs &
arg_nodes, DB::ActionsDAG & /*actions_dag*/) const
{
if (func_info.phase ==
substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE || func_info.phase ==
substrait::AGGREGATION_PHASE_INITIAL_TO_RESULT)
{
diff --git
a/cpp-ch/local-engine/Parser/aggregate_function_parser/BloomFilterAggParser.h
b/cpp-ch/local-engine/Parser/aggregate_function_parser/BloomFilterAggParser.h
index 52e7dfcec8..67df407954 100644
---
a/cpp-ch/local-engine/Parser/aggregate_function_parser/BloomFilterAggParser.h
+++
b/cpp-ch/local-engine/Parser/aggregate_function_parser/BloomFilterAggParser.h
@@ -32,6 +32,6 @@ public:
String getCHFunctionName(DB::DataTypes &) const override { return
"groupBloomFilterState"; }
DB::Array
- parseFunctionParameters(const CommonFunctionInfo & /*func_info*/,
DB::ActionsDAG::NodeRawConstPtrs & arg_nodes) const override;
+ parseFunctionParameters(const CommonFunctionInfo & /*func_info*/,
DB::ActionsDAG::NodeRawConstPtrs & arg_nodes, DB::ActionsDAG & /*actions_dag*/)
const override;
};
}
diff --git
a/cpp-ch/local-engine/Parser/aggregate_function_parser/PercentileParser.cpp
b/cpp-ch/local-engine/Parser/aggregate_function_parser/PercentileParser.cpp
new file mode 100644
index 0000000000..3002b4bb67
--- /dev/null
+++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/PercentileParser.cpp
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <Parser/aggregate_function_parser/PercentileParserBase.h>
+
+namespace local_engine
+{
+
+/*
+spark: percentile(col, percentage [, frequency])
+1. When percentage is an array literal, spark returns an array of percentiles,
corresponding to CH: quantilesExact(percentage[0], ...)(col)
+2. Otherwise spark returns a single percentile, corresponding to CH:
quantileExact(percentage)(col)
+*/
+class PercentileParser : public PercentileParserBase
+{
+public:
+ static constexpr auto name = "percentile";
+
+ explicit PercentileParser(ParserContextPtr parser_context_) :
PercentileParserBase(parser_context_) { }
+
+ String getName() const override { return name; }
+ String getCHSingularName() const override { return
"quantileExactWeightedInterpolated"; }
+ String getCHPluralName() const override { return
"quantilesExactWeightedInterpolated"; }
+
+    /// spark percentile(col, percentage[s], frequency)
+ size_t expectedArgumentsNumberInFirstStage() const override { return 3; }
+
+ /// intermediate result: struct{col, percentile[s]}
+ size_t expectedTupleElementsNumberInSecondStage() const override { return
2; }
+
+ ColumnNumbers getArgumentsThatAreParameters() const override { return {1};
}
+ DB::Array getDefaultFunctionParametersImpl() const override { return {1}; }
+};
+
+static const AggregateFunctionParserRegister<PercentileParser>
register_percentile;
+}
diff --git
a/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.cpp
b/cpp-ch/local-engine/Parser/aggregate_function_parser/PercentileParserBase.cpp
similarity index 51%
copy from
cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.cpp
copy to
cpp-ch/local-engine/Parser/aggregate_function_parser/PercentileParserBase.cpp
index ceddbd2aef..82eb39503d 100644
---
a/cpp-ch/local-engine/Parser/aggregate_function_parser/ApproxPercentileParser.cpp
+++
b/cpp-ch/local-engine/Parser/aggregate_function_parser/PercentileParserBase.cpp
@@ -15,13 +15,14 @@
* limitations under the License.
*/
-#include <string>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/DataTypeNullable.h>
#include <Functions/FunctionHelpers.h>
#include <Interpreters/ActionsDAG.h>
#include <Parser/AggregateFunctionParser.h>
-#include <Parser/aggregate_function_parser/ApproxPercentileParser.h>
+#include <Parser/aggregate_function_parser/PercentileParserBase.h>
+#include <Common/CHUtil.h>
+
#include <substrait/algebra.pb.h>
namespace DB
@@ -36,7 +37,7 @@ namespace ErrorCodes
namespace local_engine
{
-void ApproxPercentileParser::assertArgumentsSize(substrait::AggregationPhase
phase, size_t size, size_t expect) const
+void PercentileParserBase::assertArgumentsSize(substrait::AggregationPhase
phase, size_t size, size_t expect) const
{
if (size != expect)
throw Exception(
@@ -49,7 +50,7 @@ void
ApproxPercentileParser::assertArgumentsSize(substrait::AggregationPhase pha
}
const substrait::Expression::Literal &
-ApproxPercentileParser::assertAndGetLiteral(substrait::AggregationPhase phase,
const substrait::Expression & expr) const
+PercentileParserBase::assertAndGetLiteral(substrait::AggregationPhase phase,
const substrait::Expression & expr) const
{
if (!expr.has_literal())
throw Exception(
@@ -61,55 +62,81 @@
ApproxPercentileParser::assertAndGetLiteral(substrait::AggregationPhase phase, c
return expr.literal();
}
-String ApproxPercentileParser::getCHFunctionName(const CommonFunctionInfo &
func_info) const
+String PercentileParserBase::getCHFunctionName(const CommonFunctionInfo &
func_info) const
{
const auto & output_type = func_info.output_type;
- return output_type.has_list() ? "quantilesGK" : "quantileGK";
+ return output_type.has_list() ? getCHPluralName() : getCHSingularName();
}
-String ApproxPercentileParser::getCHFunctionName(DB::DataTypes & types) const
+String PercentileParserBase::getCHFunctionName(DB::DataTypes & types) const
{
/// Always invoked during second stage
- assertArgumentsSize(substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT,
types.size(), 2);
+ assertArgumentsSize(substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT,
types.size(), expectedTupleElementsNumberInSecondStage());
- auto type = removeNullable(types[1]);
+ auto type = removeNullable(types[PERCENTAGE_INDEX]);
types.resize(1);
- return isArray(type) ? "quantilesGK" : "quantileGK";
+
+ if (getName() == "percentile")
+ {
+ /// Corresponding CH function requires two arguments:
quantileExactWeightedInterpolated(xxx)(col, weight)
+ types.push_back(std::make_shared<DataTypeUInt64>());
+ }
+
+ return isArray(type) ? getCHPluralName() : getCHSingularName();
}
-DB::Array ApproxPercentileParser::parseFunctionParameters(
- const CommonFunctionInfo & func_info, DB::ActionsDAG::NodeRawConstPtrs &
arg_nodes) const
+DB::Array PercentileParserBase::parseFunctionParameters(
+ const CommonFunctionInfo & func_info, DB::ActionsDAG::NodeRawConstPtrs &
arg_nodes, DB::ActionsDAG & actions_dag) const
{
if (func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE
|| func_info.phase == substrait::AGGREGATION_PHASE_INITIAL_TO_RESULT
|| func_info.phase == substrait::AGGREGATION_PHASE_UNSPECIFIED)
{
Array params;
const auto & arguments = func_info.arguments;
- assertArgumentsSize(func_info.phase, arguments.size(), 3);
-
- const auto & accuracy_expr = arguments[2].value();
- const auto & accuracy_literal = assertAndGetLiteral(func_info.phase,
accuracy_expr);
- auto [type1, field1] = parseLiteral(accuracy_literal);
- params.emplace_back(std::move(field1));
+ assertArgumentsSize(func_info.phase, arguments.size(),
expectedArgumentsNumberInFirstStage());
- const auto & percentage_expr = arguments[1].value();
- const auto & percentage_literal = assertAndGetLiteral(func_info.phase,
percentage_expr);
- auto [type2, field2] = parseLiteral(percentage_literal);
- if (isArray(type2))
+ auto param_indexes = getArgumentsThatAreParameters();
+ for (auto idx : param_indexes)
{
- /// Multiple percentages for quantilesGK
- const Array & percentags = field2.safeGet<Array>();
- for (const auto & percentage : percentags)
- params.emplace_back(percentage);
+ const auto & expr = arguments[idx].value();
+ const auto & literal = assertAndGetLiteral(func_info.phase, expr);
+ auto [type, field] = parseLiteral(literal);
+
+ if (idx == PERCENTAGE_INDEX && isArray(removeNullable(type)))
+ {
+ /// Multiple percentages for quantilesXXX
+ const Array & percentags = field.safeGet<Array>();
+ for (const auto & percentage : percentags)
+ params.emplace_back(percentage);
+ }
+ else
+ {
+ params.emplace_back(std::move(field));
+ }
}
- else
+
+ /// Collect arguments in substrait plan that are not CH parameters as
CH arguments
+ ActionsDAG::NodeRawConstPtrs new_arg_nodes;
+ for (size_t i = 0; i < arg_nodes.size(); ++i)
{
- /// Single percentage for quantileGK
- params.emplace_back(std::move(field2));
+ if (std::find(param_indexes.begin(), param_indexes.end(), i) ==
param_indexes.end())
+ {
+ if (getName() == "percentile" && i == 2)
+ {
+ /// In spark percentile(col, percentage, weight), the last
argument weight is a signed integer
+ /// But CH requires weight as an unsigned integer
+ DataTypePtr dst_type = std::make_shared<DataTypeUInt64>();
+ if (arg_nodes[i]->result_type->isNullable())
+ dst_type =
std::make_shared<DataTypeNullable>(dst_type);
+
+ arg_nodes[i] =
ActionsDAGUtil::convertNodeTypeIfNeeded(actions_dag, arg_nodes[i], dst_type);
+ }
+
+ new_arg_nodes.emplace_back(arg_nodes[i]);
+ }
}
+ new_arg_nodes.swap(arg_nodes);
- /// Delete percentage and accuracy argument for clickhouse
compatiability
- arg_nodes.resize(1);
return params;
}
else
@@ -128,12 +155,4 @@ DB::Array ApproxPercentileParser::parseFunctionParameters(
return aggregate_function_type->getParameters();
}
}
-
-DB::Array ApproxPercentileParser::getDefaultFunctionParameters() const
-{
- return {10000, 1};
-}
-
-
-static const AggregateFunctionParserRegister<ApproxPercentileParser>
register_approx_percentile;
}
diff --git
a/cpp-ch/local-engine/Parser/aggregate_function_parser/PercentileParserBase.h
b/cpp-ch/local-engine/Parser/aggregate_function_parser/PercentileParserBase.h
new file mode 100644
index 0000000000..f43203d429
--- /dev/null
+++
b/cpp-ch/local-engine/Parser/aggregate_function_parser/PercentileParserBase.h
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include <Parser/AggregateFunctionParser.h>
+
+namespace local_engine
+{
+class PercentileParserBase : public AggregateFunctionParser
+{
+public:
+ explicit PercentileParserBase(ParserContextPtr parser_context_) :
AggregateFunctionParser(parser_context_) { }
+
+ String getCHFunctionName(const CommonFunctionInfo & func_info) const
override;
+ String getCHFunctionName(DB::DataTypes & types) const override;
+ DB::Array parseFunctionParameters(
+ const CommonFunctionInfo & /*func_info*/,
+ DB::ActionsDAG::NodeRawConstPtrs & arg_nodes,
+ DB::ActionsDAG & actions_dag) const override;
+ DB::Array getDefaultFunctionParameters() const override { return
getDefaultFunctionParametersImpl(); }
+
+protected:
+ virtual String getCHSingularName() const = 0;
+ virtual String getCHPluralName() const = 0;
+
+ /// Expected arguments number of substrait function in first stage
+ virtual size_t expectedArgumentsNumberInFirstStage() const = 0;
+
+    /// Expected number of element types wrapped in struct type as intermediate
result type of current aggregate function in second stage
+    /// Refer to L327 in
backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
+ virtual size_t expectedTupleElementsNumberInSecondStage() const = 0;
+
+ /// Get argument indexes in first stage substrait function which should be
treated as parameters in CH aggregate function.
+ /// Note: the indexes are 0-based, and we must guarantee the order of
returned parameters matches the order of parameters in CH aggregate function
+ virtual ColumnNumbers getArgumentsThatAreParameters() const = 0;
+
+ virtual DB::Array getDefaultFunctionParametersImpl() const = 0;
+
+ /// Utils functions
+ void assertArgumentsSize(substrait::AggregationPhase phase, size_t size,
size_t expect) const;
+ const substrait::Expression::Literal &
assertAndGetLiteral(substrait::AggregationPhase phase, const
substrait::Expression & expr) const;
+
+    /// Percentage index in substrait function arguments (both in first and
second stage), which is always 1
+ /// All derived implementations must obey this rule
+ static constexpr size_t PERCENTAGE_INDEX = 1;
+};
+
+}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
index ea0259f3b0..f6cc01f8ac 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
@@ -321,7 +321,8 @@ object ExpressionMappings {
Sig[First](FIRST),
Sig[Skewness](SKEWNESS),
Sig[Kurtosis](KURTOSIS),
- Sig[ApproximatePercentile](APPROX_PERCENTILE)
+ Sig[ApproximatePercentile](APPROX_PERCENTILE),
+ Sig[Percentile](PERCENTILE)
) ++ SparkShimLoader.getSparkShims.aggregateExpressionMappings
/** Mapping Spark window expression to Substrait function name */
diff --git
a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 4bf153f8bf..27e26606f6 100644
---
a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++
b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -21,6 +21,7 @@ import org.apache.gluten.utils.{BackendTestSettings,
SQLQueryTestSettings}
import org.apache.spark.sql._
import org.apache.spark.sql.GlutenTestConstants.GLUTEN_TEST
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.connector._
import org.apache.spark.sql.execution._
import
org.apache.spark.sql.execution.adaptive.clickhouse.ClickHouseAdaptiveQueryExecSuite
@@ -1985,6 +1986,7 @@ class ClickHouseTestSettings extends BackendTestSettings {
enableSuite[SparkFunctionStatistics]
enableSuite[GlutenSparkSessionExtensionSuite]
enableSuite[GlutenHiveSQLQueryCHSuite]
+ enableSuite[GlutenPercentileSuite]
override def getSQLQueryTestSettings: SQLQueryTestSettings =
ClickHouseSQLQueryTestSettings
}
diff --git
a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/GlutenPercentileSuite.scala
b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/GlutenPercentileSuite.scala
new file mode 100644
index 0000000000..5f89c2810e
--- /dev/null
+++
b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/GlutenPercentileSuite.scala
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.GlutenTestsTrait
+
+class GlutenPercentileSuite extends PercentileSuite with GlutenTestsTrait {}
diff --git
a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 2eeec0b544..da950e2fc1 100644
---
a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++
b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -21,6 +21,7 @@ import org.apache.gluten.utils.{BackendTestSettings,
SQLQueryTestSettings}
import org.apache.spark.sql._
import org.apache.spark.sql.GlutenTestConstants.GLUTEN_TEST
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.connector._
import org.apache.spark.sql.errors._
import org.apache.spark.sql.execution._
@@ -1873,6 +1874,7 @@ class ClickHouseTestSettings extends BackendTestSettings {
.excludeGlutenTest("fallbackSummary with cached data and shuffle")
enableSuite[GlutenSparkSessionExtensionSuite]
enableSuite[GlutenHiveSQLQueryCHSuite]
+ enableSuite[GlutenPercentileSuite]
override def getSQLQueryTestSettings: SQLQueryTestSettings =
ClickHouseSQLQueryTestSettings
}
diff --git
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/GlutenPercentileSuite.scala
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/GlutenPercentileSuite.scala
new file mode 100644
index 0000000000..5f89c2810e
--- /dev/null
+++
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/GlutenPercentileSuite.scala
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.GlutenTestsTrait
+
+class GlutenPercentileSuite extends PercentileSuite with GlutenTestsTrait {}
diff --git
a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index b13bb2abc9..ac08fc5a80 100644
---
a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++
b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -21,6 +21,7 @@ import org.apache.gluten.utils.{BackendTestSettings,
SQLQueryTestSettings}
import org.apache.spark.sql._
import org.apache.spark.sql.GlutenTestConstants.GLUTEN_TEST
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.connector._
import org.apache.spark.sql.errors._
import org.apache.spark.sql.execution._
@@ -1715,6 +1716,7 @@ class ClickHouseTestSettings extends BackendTestSettings {
enableSuite[SparkFunctionStatistics]
enableSuite[GlutenSparkSessionExtensionSuite]
enableSuite[GlutenHiveSQLQueryCHSuite]
+ enableSuite[GlutenPercentileSuite]
override def getSQLQueryTestSettings: SQLQueryTestSettings =
ClickHouseSQLQueryTestSettings
}
diff --git
a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/GlutenPercentileSuite.scala
b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/GlutenPercentileSuite.scala
new file mode 100644
index 0000000000..5f89c2810e
--- /dev/null
+++
b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/GlutenPercentileSuite.scala
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.GlutenTestsTrait
+
+class GlutenPercentileSuite extends PercentileSuite with GlutenTestsTrait {}
diff --git
a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 9836cb27f5..9e4c81081d 100644
---
a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++
b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -21,6 +21,7 @@ import org.apache.gluten.utils.{BackendTestSettings,
SQLQueryTestSettings}
import org.apache.spark.sql._
import org.apache.spark.sql.GlutenTestConstants.GLUTEN_TEST
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.connector._
import org.apache.spark.sql.errors._
import org.apache.spark.sql.execution._
@@ -1717,6 +1718,7 @@ class ClickHouseTestSettings extends BackendTestSettings {
enableSuite[SparkFunctionStatistics]
enableSuite[GlutenSparkSessionExtensionSuite]
enableSuite[GlutenHiveSQLQueryCHSuite]
+ enableSuite[GlutenPercentileSuite]
override def getSQLQueryTestSettings: SQLQueryTestSettings =
ClickHouseSQLQueryTestSettings
}
diff --git
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/GlutenPercentileSuite.scala
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/GlutenPercentileSuite.scala
new file mode 100644
index 0000000000..5f89c2810e
--- /dev/null
+++
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/GlutenPercentileSuite.scala
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.GlutenTestsTrait
+
+class GlutenPercentileSuite extends PercentileSuite with GlutenTestsTrait {}
diff --git
a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
index 32ac1914e6..713a6ba0dc 100644
---
a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
+++
b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
@@ -47,6 +47,7 @@ object ExpressionNames {
final val FIRST_IGNORE_NULL = "first_ignore_null"
final val APPROX_DISTINCT = "approx_distinct"
final val APPROX_PERCENTILE = "approx_percentile"
+ final val PERCENTILE = "percentile"
final val SKEWNESS = "skewness"
final val KURTOSIS = "kurtosis"
final val REGR_SLOPE = "regr_slope"
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]