This is an automated email from the ASF dual-hosted git repository.

zhangzc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 7fc385dda [GLUTEN-6334][CH] Support ntile window function (#6335)
7fc385dda is described below

commit 7fc385dda81d4a659d36d287b49d232b9504c0b0
Author: Zhichao Zhang <[email protected]>
AuthorDate: Thu Jul 4 18:45:28 2024 +0800

    [GLUTEN-6334][CH] Support ntile window function (#6335)
    
    [CH] Support ntile window function
    
    Close #6334.
---
 .../gluten/backendsapi/clickhouse/CHBackend.scala  |  4 +--
 .../clickhouse/CHSparkPlanExecApi.scala            | 18 +++++++++-
 .../GlutenClickHouseTPCHSaltNullParquetSuite.scala | 16 +++++++++
 .../CommonAggregateFunctionParser.cpp              |  3 --
 .../aggregate_function_parser/NtileParser.cpp      | 42 ++++++++++++++++++++++
 .../Parser/aggregate_function_parser/NtileParser.h | 34 ++++++++++++++++++
 6 files changed, 111 insertions(+), 6 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index cdca1b031..d369b8c16 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -26,7 +26,7 @@ import 
org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat._
 import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
-import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, 
DenseRank, Expression, Lag, Lead, Literal, NamedExpression, Rank, RowNumber}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, 
Partitioning}
 import org.apache.spark.sql.execution.SparkPlan
@@ -237,7 +237,7 @@ object CHBackendSettings extends BackendSettingsApi with 
Logging {
           }
 
           wExpression.windowFunction match {
-            case _: RowNumber | _: AggregateExpression | _: Rank | _: 
DenseRank =>
+            case _: RowNumber | _: AggregateExpression | _: Rank | _: 
DenseRank | _: NTile =>
               allSupported = allSupported
             case l: Lag =>
               checkLagOrLead(l.third)
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 44aeba021..add82cbb5 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -704,7 +704,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
         val columnName = s"${aliasExpr.name}_${aliasExpr.exprId.id}"
         val wExpression = aliasExpr.child.asInstanceOf[WindowExpression]
         wExpression.windowFunction match {
-          case wf @ (RowNumber() | Rank(_) | DenseRank(_) | CumeDist() | 
PercentRank(_)) =>
+          case wf @ (RowNumber() | Rank(_) | DenseRank(_)) =>
             val aggWindowFunc = wf.asInstanceOf[AggregateWindowFunction]
             val frame = aggWindowFunc.frame.asInstanceOf[SpecifiedWindowFrame]
             val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
@@ -795,6 +795,22 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
               originalInputAttributes.asJava
             )
             windowExpressionNodes.add(windowFunctionNode)
+          case wf @ NTile(buckets: Expression) =>
+            val frame = 
wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
+            val childrenNodeList = new JArrayList[ExpressionNode]()
+            val literal = buckets.asInstanceOf[Literal]
+            childrenNodeList.add(LiteralTransformer(literal).doTransform(args))
+            val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
+              WindowFunctionsBuilder.create(args, wf).toInt,
+              childrenNodeList,
+              columnName,
+              ConverterUtils.getTypeNode(wf.dataType, wf.nullable),
+              frame.upper,
+              frame.lower,
+              frame.frameType.sql,
+              originalInputAttributes.asJava
+            )
+            windowExpressionNodes.add(windowFunctionNode)
           case _ =>
             throw new GlutenNotSupportException(
               "unsupported window function type: " +
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index c0f37b086..b0d3e1bdb 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -978,6 +978,22 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends 
GlutenClickHouseTPCHAbstr
     compareResultsAgainstVanillaSpark(sql, true, { _ => })
   }
 
+  test("window ntile") { // same window spec as first_value; includes a NULL sort key
+    val ntileQuery =
+      """
+        | select n_regionkey, n_nationkey,
+        |   first_value(n_nationkey) over (partition by n_regionkey order by n_nationkey) as
+        |   first_v,
+        |   ntile(4) over (partition by n_regionkey order by n_nationkey) as ntile_v
+        | from
+        |   (
+        |     select n_regionkey, if(n_nationkey = 1, null, n_nationkey) as n_nationkey from nation
+        |   ) as t
+        | order by n_regionkey, n_nationkey
+      """.stripMargin
+    compareResultsAgainstVanillaSpark(ntileQuery, true, _ => {})
+  }
+
   test("window first value with nulls") {
     val sql =
       """
diff --git 
a/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp
 
b/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp
index 1619c7410..e7d6e1b9b 100644
--- 
a/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp
+++ 
b/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp
@@ -42,8 +42,5 @@ REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(LastIgnoreNull, 
last_ignore_null, last
 REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(DenseRank, dense_rank, dense_rank)
 REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(Rank, rank, rank)
 REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(RowNumber, row_number, row_number)
-REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(Ntile, ntile, ntile)
-REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(PercentRank, percent_rank, 
percent_rank)
-REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(CumeDist, cume_dist, cume_dist)
 REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(CountDistinct, count_distinct, 
uniqExact)
 }
diff --git 
a/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.cpp 
b/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.cpp
new file mode 100644
index 000000000..49a59c657
--- /dev/null
+++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.cpp
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "NtileParser.h"
+#include <DataTypes/DataTypeNullable.h>
+#include <DataTypes/DataTypesNumber.h>
+#include <Interpreters/ActionsDAG.h>
+
+namespace local_engine
+{
+DB::ActionsDAG::NodeRawConstPtrs
+NtileParser::parseFunctionArguments(const CommonFunctionInfo & func_info, const String & /*ch_func_name*/, DB::ActionsDAGPtr & actions_dag) const
+{
+    // Spark's NTile carries a single Int32 literal: the number of buckets.
+    if (func_info.arguments.size() != 1)
+        throw Exception(ErrorCodes::BAD_ARGUMENTS, "Function ntile takes exactly one argument");
+    const auto & arg0 = func_info.arguments[0].value();
+    auto [data_type, field] = parseLiteral(arg0.literal());
+    if (!DB::WhichDataType(data_type).isInt32())
+        throw Exception(ErrorCodes::BAD_ARGUMENTS, "ntile's argument must be i32");
+    const Int32 buckets = static_cast<Int32>(field.get<Int32>());
+    // CH's ntile takes a UInt32; reject non-positive counts instead of letting them wrap around.
+    if (buckets <= 0)
+        throw Exception(ErrorCodes::BAD_ARGUMENTS, "ntile's argument must be positive");
+    const auto * buckets_node = addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeUInt32>(), static_cast<UInt32>(buckets));
+    return {buckets_node};
+}
+AggregateFunctionParserRegister<NtileParser> ntile_register;
+}
diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.h 
b/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.h
new file mode 100644
index 000000000..441de2353
--- /dev/null
+++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/NtileParser.h
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include <Parser/AggregateFunctionParser.h>
+
+namespace local_engine
+{
+class NtileParser : public AggregateFunctionParser
+{
+public:
+    static constexpr auto name = "ntile"; // Spark and ClickHouse both call it ntile.
+    explicit NtileParser(SerializedPlanParser * plan_parser_) : AggregateFunctionParser(plan_parser_) { }
+    ~NtileParser() override = default;
+    String getName() const override { return name; }
+    String getCHFunctionName(const CommonFunctionInfo &) const override { return name; }
+    String getCHFunctionName(DB::DataTypes &) const override { return name; }
+    DB::ActionsDAG::NodeRawConstPtrs
+    parseFunctionArguments(const CommonFunctionInfo & func_info, const String & ch_func_name, DB::ActionsDAGPtr & actions_dag) const override;
+};
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to