This is an automated email from the ASF dual-hosted git repository.

liuneng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new ae6a39cae [CH] Support nanvl function (#5199)
ae6a39cae is described below

commit ae6a39cae6805e775ca0ab38bcf4402fd9103bc6
Author: exmy <[email protected]>
AuthorDate: Tue Apr 2 18:11:44 2024 +0800

    [CH] Support nanvl function (#5199)
    
    What changes were proposed in this pull request?
    Support nanvl function
    
    How was this patch tested?
    PASS CI
---
 .../GlutenClickHouseTPCHSaltNullParquetSuite.scala | 13 ++++
 .../Parser/scalar_function_parser/nanvl.cpp        | 83 ++++++++++++++++++++++
 .../gluten/backendsapi/SparkPlanExecApi.scala      |  2 +-
 .../gluten/expression/ExpressionMappings.scala     |  2 +
 .../utils/clickhouse/ClickHouseTestSettings.scala  |  1 -
 .../utils/clickhouse/ClickHouseTestSettings.scala  |  1 -
 6 files changed, 99 insertions(+), 3 deletions(-)

diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index 8f67195a3..7ee193e88 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -1146,6 +1146,19 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends 
GlutenClickHouseTPCHAbstr
     compareResultsAgainstVanillaSpark(sql, true, { _ => })
   }
 
+  test("nanvl") {
+    val sql =
+      """
+        |SELECT nanvl(cast('nan' as float), 1f),
+        | nanvl(n_nationkey, cast('null' as double)),
+        | nanvl(cast('null' as double), n_nationkey),
+        | nanvl(n_nationkey, n_nationkey / 0.0d),
+        | nanvl(cast('nan' as float), n_nationkey)
+        | from nation
+        |""".stripMargin
+    runQueryAndCompare(sql)(checkOperatorMatch[ProjectExecTransformer])
+  }
+
   test("test 'sequence'") {
     runQueryAndCompare(
       "select sequence(id, id+10), sequence(id+10, id), sequence(id, id+10, 
3), " +
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/nanvl.cpp 
b/cpp-ch/local-engine/Parser/scalar_function_parser/nanvl.cpp
new file mode 100644
index 000000000..33755ca9e
--- /dev/null
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/nanvl.cpp
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <Parser/FunctionParser.h>
+#include <Common/CHUtil.h>
+#include <Core/Field.h>
+#include <DataTypes/IDataType.h>
+
// Forward declaration of the ClickHouse error code thrown by the nanvl parser
// when it receives the wrong number of arguments. Only declared here; the
// definition lives elsewhere (presumably ClickHouse's error-code table).
namespace DB::ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
+
+namespace local_engine
+{
+
+class FunctionParserNaNvl : public FunctionParser
+{
+public:
+    explicit FunctionParserNaNvl(SerializedPlanParser * plan_parser_) : 
FunctionParser(plan_parser_) {}
+    ~FunctionParserNaNvl() override = default;
+
+    static constexpr auto name = "nanvl";
+
+    String getName() const override { return name; }
+
+    const ActionsDAG::Node * parse(
+        const substrait::Expression_ScalarFunction & substrait_func,
+        ActionsDAGPtr & actions_dag) const override
+    {
+        /*
+            parse nanvl(e1, e2) as
+            if (isNull(e1))
+                null
+            else if (isNaN(e1))
+                e2
+            else
+                e1
+        */
+        auto parsed_args = parseFunctionArguments(substrait_func, "", 
actions_dag);
+        if (parsed_args.size() != 2)
+            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, 
"Function {} requires at least two arguments", getName());
+
+        const auto * e1 = parsed_args[0];
+        const auto * e2 = parsed_args[1];
+
+        auto result_type = e1->result_type;
+
+        const auto * e1_is_null_node = toFunctionNode(actions_dag, "isNull", 
{e1});
+        const auto * e1_is_nan_node = toFunctionNode(actions_dag, "isNaN", 
{e1});
+
+        const auto * null_const_node = addColumnToActionsDAG(actions_dag, 
makeNullable(result_type), Field());
+        const auto * result_node = toFunctionNode(actions_dag, "multiIf", {
+            e1_is_null_node,
+            null_const_node,
+            e1_is_nan_node,
+            e2,
+            e1
+        });
+        return convertNodeTypeIfNeeded(substrait_func, result_node, 
actions_dag);
+    }
+};
+
+static FunctionParserRegister<FunctionParserNaNvl> register_nanvl;
+}
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index 59244c380..11f02dede 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -205,7 +205,7 @@ trait SparkPlanExecApi {
       left: ExpressionTransformer,
       right: ExpressionTransformer,
       original: NaNvl): ExpressionTransformer = {
-    throw new GlutenNotSupportException("NaNvl is not supported")
+    GenericExpressionTransformer(substraitExprName, Seq(left, right), original)
   }
 
   def genUuidTransformer(substraitExprName: String, original: Uuid): 
ExpressionTransformer = {
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
index ae080efed..ec6fbd047 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
@@ -60,6 +60,8 @@ object ExpressionMappings {
     Sig[IsNull](IS_NULL),
     Sig[Not](NOT),
     Sig[IsNaN](IS_NAN),
+    Sig[NaNvl](NANVL),
+
     // SparkSQL String functions
     Sig[Ascii](ASCII),
     Sig[Chr](CHR),
diff --git 
a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
 
b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index be67c244e..b7c64fa5d 100644
--- 
a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ 
b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -856,7 +856,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("SparkPartitionID")
     .exclude("InputFileName")
   enableSuite[GlutenNullExpressionsSuite]
-    .exclude("nanvl")
     .exclude("AtLeastNNonNulls")
     .exclude("AtLeastNNonNulls should not throw 64KiB exception")
   enableSuite[GlutenPredicateSuite]
diff --git 
a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
 
b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 374474f87..60f911a1a 100644
--- 
a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ 
b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -867,7 +867,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("SparkPartitionID")
     .exclude("InputFileName")
   enableSuite[GlutenNullExpressionsSuite]
-    .exclude("nanvl")
     .exclude("AtLeastNNonNulls")
     .exclude("AtLeastNNonNulls should not throw 64KiB exception")
   enableSuite[GlutenPredicateSuite]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to