This is an automated email from the ASF dual-hosted git repository.
zhangzc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 7fc385dda [GLUTEN-6334][CH] Support ntile window function (#6335)
7fc385dda is described below
commit 7fc385dda81d4a659d36d287b49d232b9504c0b0
Author: Zhichao Zhang <[email protected]>
AuthorDate: Thu Jul 4 18:45:28 2024 +0800
[GLUTEN-6334][CH] Support ntile window function (#6335)
[CH] Support ntile window function
Close #6334.
---
.../gluten/backendsapi/clickhouse/CHBackend.scala | 4 +--
.../clickhouse/CHSparkPlanExecApi.scala | 18 +++++++++-
.../GlutenClickHouseTPCHSaltNullParquetSuite.scala | 16 +++++++++
.../CommonAggregateFunctionParser.cpp | 3 --
.../aggregate_function_parser/NtileParser.cpp | 42 ++++++++++++++++++++++
.../Parser/aggregate_function_parser/NtileParser.h | 34 ++++++++++++++++++
6 files changed, 111 insertions(+), 6 deletions(-)
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index cdca1b031..d369b8c16 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -26,7 +26,7 @@ import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat._
import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.catalog.BucketSpec
-import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, DenseRank, Expression, Lag, Lead, Literal, NamedExpression, Rank, RowNumber}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
import org.apache.spark.sql.execution.SparkPlan
@@ -237,7 +237,7 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
}
wExpression.windowFunction match {
- case _: RowNumber | _: AggregateExpression | _: Rank | _: DenseRank =>
+ case _: RowNumber | _: AggregateExpression | _: Rank | _: DenseRank | _: NTile =>
allSupported = allSupported
case l: Lag =>
checkLagOrLead(l.third)
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 44aeba021..add82cbb5 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -704,7 +704,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
val columnName = s"${aliasExpr.name}_${aliasExpr.exprId.id}"
val wExpression = aliasExpr.child.asInstanceOf[WindowExpression]
wExpression.windowFunction match {
- case wf @ (RowNumber() | Rank(_) | DenseRank(_) | CumeDist() | PercentRank(_)) =>
+ case wf @ (RowNumber() | Rank(_) | DenseRank(_)) =>
val aggWindowFunc = wf.asInstanceOf[AggregateWindowFunction]
val frame = aggWindowFunc.frame.asInstanceOf[SpecifiedWindowFrame]
val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
@@ -795,6 +795,22 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
originalInputAttributes.asJava
)
windowExpressionNodes.add(windowFunctionNode)
+ case wf @ NTile(buckets: Expression) =>
+ val frame = wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
+ val childrenNodeList = new JArrayList[ExpressionNode]()
+ val literal = buckets.asInstanceOf[Literal]
+ childrenNodeList.add(LiteralTransformer(literal).doTransform(args))
+ val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
+ WindowFunctionsBuilder.create(args, wf).toInt,
+ childrenNodeList,
+ columnName,
+ ConverterUtils.getTypeNode(wf.dataType, wf.nullable),
+ frame.upper,
+ frame.lower,
+ frame.frameType.sql,
+ originalInputAttributes.asJava
+ )
+ windowExpressionNodes.add(windowFunctionNode)
case _ =>
throw new GlutenNotSupportException(
"unsupported window function type: " +
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index c0f37b086..b0d3e1bdb 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -978,6 +978,22 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
compareResultsAgainstVanillaSpark(sql, true, { _ => })
}
+ test("window ntile") {
+ val sql =
+ """
+ | select n_regionkey, n_nationkey,
+ | first_value(n_nationkey) over (partition by n_regionkey order by n_nationkey) as
+ | first_v,
+ | ntile(4) over (partition by n_regionkey order by n_nationkey) as ntile_v
+ | from
+ | (
+ | select n_regionkey, if(n_nationkey = 1, null, n_nationkey) as n_nationkey from nation
+ | ) as t
+ | order by n_regionkey, n_nationkey
+ """.stripMargin
+ compareResultsAgainstVanillaSpark(sql, true, { _ => })
+ }
+
test("window first value with nulls") {
val sql =
"""
diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp b/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp
index 1619c7410..e7d6e1b9b 100644
--- a/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp
+++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp
@@ -42,8 +42,5 @@ REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(LastIgnoreNull, last_ignore_null, last
REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(DenseRank, dense_rank, dense_rank)
REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(Rank, rank, rank)
REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(RowNumber, row_number, row_number)
-REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(Ntile, ntile, ntile)
-REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(PercentRank, percent_rank, percent_rank)
-REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(CumeDist, cume_dist, cume_dist)
REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(CountDistinct, count_distinct, uniqExact)
}
diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.cpp b/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.cpp
new file mode 100644
index 000000000..49a59c657
--- /dev/null
+++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.cpp
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "NtileParser.h"
+#include <DataTypes/DataTypeNullable.h>
+#include <DataTypes/DataTypesNumber.h>
+#include <Interpreters/ActionsDAG.h>
+
+namespace local_engine
+{
+DB::ActionsDAG::NodeRawConstPtrs
+NtileParser::parseFunctionArguments(const CommonFunctionInfo & func_info, const String & /*ch_func_name*/, DB::ActionsDAGPtr & actions_dag) const
+{
+ if (func_info.arguments.size() != 1)
+ throw Exception(ErrorCodes::BAD_ARGUMENTS, "Function ntile takes exactly one argument");
+ DB::ActionsDAG::NodeRawConstPtrs args;
+
+ const auto & arg0 = func_info.arguments[0].value();
+ auto [data_type, field] = parseLiteral(arg0.literal());
+ if (!(DB::WhichDataType(data_type).isInt32()))
+ throw Exception(ErrorCodes::BAD_ARGUMENTS, "ntile's argument must be i32");
+ Int32 field_index = static_cast<Int32>(field.get<Int32>());
+ // For CH, the data type of the args[0] must be the UInt32
+ const auto * index_node = addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeUInt32>(), field_index);
+ args.emplace_back(index_node);
+ return args;
+}
+AggregateFunctionParserRegister<NtileParser> ntile_register;
+}
diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.h b/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.h
new file mode 100644
index 000000000..441de2353
--- /dev/null
+++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.h
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include <Parser/AggregateFunctionParser.h>
+
+namespace local_engine
+{
+class NtileParser : public AggregateFunctionParser
+{
+public:
+ explicit NtileParser(SerializedPlanParser * plan_parser_) : AggregateFunctionParser(plan_parser_) { }
+ ~NtileParser() override = default;
+ static constexpr auto name = "ntile";
+ String getName() const override { return name; }
+ String getCHFunctionName(const CommonFunctionInfo &) const override { return "ntile"; }
+ String getCHFunctionName(DB::DataTypes &) const override { return "ntile"; }
+ DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments(
+ const CommonFunctionInfo & func_info, const String & ch_func_name, DB::ActionsDAGPtr & actions_dag) const override;
+};
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]