This is an automated email from the ASF dual-hosted git repository.

taiyangli pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 399a91b12 [CH] Fix GlutenLiteralExpressionSuite and 
GlutenMathExpressionsSuite (#7235)
399a91b12 is described below

commit 399a91b12d572375a7c8c7a617c7f04e2e3612e8
Author: 李扬 <[email protected]>
AuthorDate: Sun Sep 15 17:24:11 2024 +0800

    [CH] Fix GlutenLiteralExpressionSuite and GlutenMathExpressionsSuite (#7235)
    
    * fix failed uts
    
    * Update CommonScalarFunctionParser.cpp
    
    * override checkResult for ch backend
---
 .../org/apache/gluten/utils/CHExpressionUtil.scala |   3 +-
 .../CommonScalarFunctionParser.cpp                 |   2 -
 .../Parser/scalar_function_parser/shift.cpp        | 105 +++++++++++++++++++++
 .../scalar_function_parser/shiftRightUnsigned.cpp  |  19 ++--
 .../utils/clickhouse/ClickHouseTestSettings.scala  |  26 +----
 .../expressions/GlutenLiteralExpressionSuite.scala |  38 +++++++-
 .../expressions/GlutenMathExpressionsSuite.scala   |  51 ++++++++++
 7 files changed, 208 insertions(+), 36 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
index 3c9fa9888..645189310 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
@@ -214,6 +214,7 @@ object CHExpressionUtil {
     STACK -> DefaultValidator(),
     TRANSFORM_KEYS -> DefaultValidator(),
     TRANSFORM_VALUES -> DefaultValidator(),
-    RAISE_ERROR -> DefaultValidator()
+    RAISE_ERROR -> DefaultValidator(),
+    WIDTH_BUCKET -> DefaultValidator()
   )
 }
diff --git 
a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
 
b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
index 37282104c..88e5d7ea8 100644
--- 
a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
+++ 
b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
@@ -102,8 +102,6 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Sign, sign, sign);
 REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Radians, radians, radians);
 REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Greatest, greatest, sparkGreatest);
 REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Least, least, sparkLeast);
-REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ShiftLeft, shiftleft, bitShiftLeft);
-REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ShiftRight, shiftright, bitShiftRight);
 REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Rand, rand, randCanonical);
 REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Bin, bin, sparkBin);
 REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Rint, rint, sparkRint);
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/shift.cpp 
b/cpp-ch/local-engine/Parser/scalar_function_parser/shift.cpp
new file mode 100644
index 000000000..663bf5e26
--- /dev/null
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/shift.cpp
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <Parser/FunctionParser.h>
+#include <DataTypes/IDataType.h>
+#include <Common/CHUtil.h>
+#include <Core/Field.h>
+#include <DataTypes/DataTypeArray.h>
+#include <DataTypes/DataTypeNullable.h>
+#include <Functions/FunctionHelpers.h>
+
+namespace DB
+{
+namespace ErrorCodes
+{
+    extern const int BAD_ARGUMENTS;
+    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+}
+}
+
+namespace local_engine
+{
+
+class FunctionParserShiftBase : public FunctionParser
+{
+public:
+    explicit FunctionParserShiftBase(SerializedPlanParser * plan_parser_) : 
FunctionParser(plan_parser_) { }
+    ~FunctionParserShiftBase() override = default;
+
+    virtual String getCHFunctionName() const = 0;
+
+    const ActionsDAG::Node * parse(
+        const substrait::Expression_ScalarFunction & substrait_func,
+        ActionsDAG & actions_dag) const override
+    {
+        /// parse spark shiftxxx(expr, n) as
+        /// If expr has long type -> CH bitShiftxxx(expr, pmod(n, 64))
+        /// Otherwise             -> CH bitShiftxxx(expr, pmod(n, 32))
+        auto parsed_args = parseFunctionArguments(substrait_func, actions_dag);
+        if (parsed_args.size() != 2)
+            throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, 
"Function {} requires exactly two arguments", getName());
+
+
+        auto input_type = removeNullable(parsed_args[0]->result_type);
+        WhichDataType which(input_type);
+        const ActionsDAG::Node * base_node = nullptr;
+        if (which.isInt64())
+        {
+            base_node = addColumnToActionsDAG(actions_dag, 
std::make_shared<DataTypeInt32>(), 64);
+        }
+        else if (which.isInt32())
+        {
+            base_node = addColumnToActionsDAG(actions_dag, 
std::make_shared<DataTypeInt32>(), 32);
+        }
+        else
+            throw Exception(DB::ErrorCodes::BAD_ARGUMENTS, "First argument for 
function {} must be a long or integer", getName());
+
+        const auto * pmod_node = toFunctionNode(actions_dag, "pmod", 
{parsed_args[1], base_node});
+        auto ch_function_name = getCHFunctionName();
+        const auto * shift_node = toFunctionNode(actions_dag, 
ch_function_name, {parsed_args[0], pmod_node});
+        return convertNodeTypeIfNeeded(substrait_func, shift_node, 
actions_dag);
+    }
+};
+
+class FunctionParserShiftLeft : public FunctionParserShiftBase
+{
+public:
+    explicit FunctionParserShiftLeft(SerializedPlanParser * plan_parser_) : 
FunctionParserShiftBase(plan_parser_) { }
+    ~FunctionParserShiftLeft() override = default;
+
+    static constexpr auto name = "shiftleft";
+    String getName() const override { return name; }
+
+    String getCHFunctionName() const override { return "bitShiftLeft"; }
+};
+static FunctionParserRegister<FunctionParserShiftLeft> register_shiftleft;
+
+class FunctionParserShiftRight: public FunctionParserShiftBase
+{
+public:
+    explicit FunctionParserShiftRight(SerializedPlanParser * plan_parser_) : 
FunctionParserShiftBase(plan_parser_) { }
+    ~FunctionParserShiftRight() override = default;
+
+    static constexpr auto name = "shiftright";
+    String getName() const override { return name; }
+
+    String getCHFunctionName() const override { return "bitShiftRight"; }
+};
+static FunctionParserRegister<FunctionParserShiftRight> register_shiftright;
+
+
+}
diff --git 
a/cpp-ch/local-engine/Parser/scalar_function_parser/shiftRightUnsigned.cpp 
b/cpp-ch/local-engine/Parser/scalar_function_parser/shiftRightUnsigned.cpp
index 28288461a..ca88e8522 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/shiftRightUnsigned.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/shiftRightUnsigned.cpp
@@ -43,9 +43,9 @@ public:
     {
         /// parse shiftrightunsigned(a, b) as
         /// if (isInteger(a))
-        ///   bitShiftRight(a::UInt32, b::UInt32)
+        ///   bitShiftRight(a::UInt32, pmod(b, 32))
         /// else if (isLong(a))
-        ///   bitShiftRight(a::UInt64, b::UInt64)
+        ///   bitShiftRight(a::UInt64, pmod(b, 64))
         /// else
         ///   throw Exception
 
@@ -55,26 +55,27 @@ public:
 
         const auto * a = parsed_args[0];
         const auto * b = parsed_args[1];
-        const auto * new_a = a;
-        const auto * new_b = b;
 
         WhichDataType which(removeNullable(a->result_type));
+        const ActionsDAG::Node * base_node = nullptr;
+        const ActionsDAG::Node * unsigned_a_node = nullptr;
         if (which.isInt32())
         {
+            base_node = addColumnToActionsDAG(actions_dag, 
std::make_shared<DataTypeUInt32>(), 32);
             const auto * uint32_type_node = addColumnToActionsDAG(actions_dag, 
std::make_shared<DataTypeString>(), "Nullable(UInt32)");
-            new_a = toFunctionNode(actions_dag, "CAST", {a, uint32_type_node});
-            new_b = toFunctionNode(actions_dag, "CAST", {b, uint32_type_node});
+            unsigned_a_node = toFunctionNode(actions_dag, "CAST", {a, 
uint32_type_node});
         }
         else if (which.isInt64())
         {
+            base_node = addColumnToActionsDAG(actions_dag, 
std::make_shared<DataTypeUInt32>(), 64);
             const auto * uint64_type_node = addColumnToActionsDAG(actions_dag, 
std::make_shared<DataTypeString>(), "Nullable(UInt64)");
-            new_a = toFunctionNode(actions_dag, "CAST", {a, uint64_type_node});
-            new_b = toFunctionNode(actions_dag, "CAST", {b, uint64_type_node});
+            unsigned_a_node = toFunctionNode(actions_dag, "CAST", {a, 
uint64_type_node});
         }
         else
             throw Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Function {} 
requires integer or long as first argument", getName());
 
-        const auto * result = toFunctionNode(actions_dag, "bitShiftRight", 
{new_a, new_b});
+        const auto * pmod_node = toFunctionNode(actions_dag, "pmod", {b, 
base_node});
+        const auto * result = toFunctionNode(actions_dag, "bitShiftRight", 
{unsigned_a_node, pmod_node});
         return convertNodeTypeIfNeeded(substrait_func, result, actions_dag);
     }
 };
diff --git 
a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
 
b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index f40007957..c260d3f80 100644
--- 
a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ 
b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -788,32 +788,12 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("SPARK-35728: Check multiply/divide of day-time intervals of any 
fields by numeric")
     .exclude("SPARK-35778: Check multiply/divide of year-month intervals of 
any fields by numeric")
   enableSuite[GlutenLiteralExpressionSuite]
-    .exclude("null")
     .exclude("default")
-    .exclude("decimal")
-    .exclude("array")
-    .exclude("seq")
-    .exclude("map")
-    .exclude("struct")
-    .exclude("SPARK-35664: construct literals from java.time.LocalDateTime")
-    .exclude("SPARK-34605: construct literals from java.time.Duration")
-    .exclude("SPARK-34605: construct literals from arrays of 
java.time.Duration")
-    .exclude("SPARK-34615: construct literals from java.time.Period")
-    .exclude("SPARK-34615: construct literals from arrays of java.time.Period")
-    .exclude("SPARK-35871: Literal.create(value, dataType) should support 
fields")
     .exclude("SPARK-37967: Literal.create support ObjectType")
   enableSuite[GlutenMathExpressionsSuite]
-    .exclude("tanh")
-    .exclude("unhex")
-    .exclude("atan2")
-    .exclude("round/bround/floor/ceil")
-    .exclude("SPARK-36922: Support ANSI intervals for SIGN/SIGNUM")
-    .exclude("SPARK-35926: Support YearMonthIntervalType in width-bucket 
function")
-    .exclude("SPARK-35925: Support DayTimeIntervalType in width-bucket 
function")
-    .exclude("SPARK-37388: width_bucket")
-    .exclude("shift left")
-    .exclude("shift right")
-    .exclude("shift right unsigned")
+    .exclude("unhex") // https://github.com/apache/incubator-gluten/issues/7232
+    .exclude("round/bround/floor/ceil") // 
https://github.com/apache/incubator-gluten/issues/7233
+    .exclude("atan2") // https://github.com/apache/incubator-gluten/issues/7233
   enableSuite[GlutenMiscExpressionsSuite]
   enableSuite[GlutenNondeterministicSuite]
     .exclude("MonotonicallyIncreasingID")
diff --git 
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenLiteralExpressionSuite.scala
 
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenLiteralExpressionSuite.scala
index 556d185af..f81ef0b6f 100644
--- 
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenLiteralExpressionSuite.scala
+++ 
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenLiteralExpressionSuite.scala
@@ -17,5 +17,41 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.GlutenTestsTrait
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
 
-class GlutenLiteralExpressionSuite extends LiteralExpressionSuite with 
GlutenTestsTrait {}
+import java.nio.charset.StandardCharsets
+import java.time.{Instant, LocalDate}
+
+class GlutenLiteralExpressionSuite extends LiteralExpressionSuite with 
GlutenTestsTrait {
+  testGluten("default") {
+    checkEvaluation(Literal.default(BooleanType), false)
+    checkEvaluation(Literal.default(ByteType), 0.toByte)
+    checkEvaluation(Literal.default(ShortType), 0.toShort)
+    checkEvaluation(Literal.default(IntegerType), 0)
+    checkEvaluation(Literal.default(LongType), 0L)
+    checkEvaluation(Literal.default(FloatType), 0.0f)
+    checkEvaluation(Literal.default(DoubleType), 0.0)
+    checkEvaluation(Literal.default(StringType), "")
+    checkEvaluation(Literal.default(BinaryType), 
"".getBytes(StandardCharsets.UTF_8))
+    checkEvaluation(Literal.default(DecimalType.USER_DEFAULT), Decimal(0))
+    checkEvaluation(Literal.default(DecimalType.SYSTEM_DEFAULT), Decimal(0))
+    withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "false") {
+      checkEvaluation(Literal.default(DateType), DateTimeUtils.toJavaDate(0))
+      checkEvaluation(Literal.default(TimestampType), 
DateTimeUtils.toJavaTimestamp(0L))
+    }
+    withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
+      checkEvaluation(Literal.default(DateType), LocalDate.ofEpochDay(0))
+      checkEvaluation(Literal.default(TimestampType), Instant.ofEpochSecond(0))
+    }
+    checkEvaluation(Literal.default(CalendarIntervalType), new 
CalendarInterval(0, 0, 0L))
+    checkEvaluation(Literal.default(YearMonthIntervalType()), 0)
+    checkEvaluation(Literal.default(DayTimeIntervalType()), 0L)
+    checkEvaluation(Literal.default(ArrayType(StringType)), Array())
+    checkEvaluation(Literal.default(MapType(IntegerType, StringType)), Map())
+    checkEvaluation(Literal.default(StructType(StructField("a", StringType) :: 
Nil)), Row(""))
+  }
+}
diff --git 
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala
 
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala
index 7085b70ae..a8716b6ef 100644
--- 
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala
+++ 
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala
@@ -18,11 +18,44 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.gluten.utils.BackendTestUtils
 
+import org.apache.spark.sql.GlutenQueryTestUtil.isNaNOrInf
 import org.apache.spark.sql.GlutenTestsTrait
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.types._
 
+import org.apache.commons.math3.util.Precision
+
+import java.nio.charset.StandardCharsets
+
 class GlutenMathExpressionsSuite extends MathExpressionsSuite with 
GlutenTestsTrait {
+  override protected def checkResult(
+      result: Any,
+      expected: Any,
+      exprDataType: DataType,
+      exprNullable: Boolean): Boolean = {
+    if (BackendTestUtils.isVeloxBackendLoaded()) {
+      super.checkResult(result, expected, exprDataType, exprNullable)
+    } else {
+      // The result is null for a non-nullable expression
+      assert(result != null || exprNullable, "exprNullable should be true if 
result is null")
+      (result, expected) match {
+        case (result: Double, expected: Double) =>
+          if (
+            (isNaNOrInf(result) || isNaNOrInf(expected))
+            || (result == -0.0) || (expected == -0.0)
+          ) {
+            java.lang.Double.doubleToRawLongBits(result) ==
+              java.lang.Double.doubleToRawLongBits(expected)
+          } else {
+            Precision.equalsWithRelativeTolerance(result, expected, 0.00001d) 
||
+            Precision.equals(result, expected, 0.00001d)
+          }
+        case _ =>
+          super.checkResult(result, expected, exprDataType, exprNullable)
+      }
+    }
+  }
+
   testGluten("round/bround/floor/ceil") {
     val scales = -6 to 6
     val doublePi: Double = math.Pi
@@ -284,4 +317,22 @@ class GlutenMathExpressionsSuite extends 
MathExpressionsSuite with GlutenTestsTr
     checkEvaluation(checkDataTypeAndCast(RoundCeil(Literal(3.1411), 
Literal(-3))), Decimal(1000))
     checkEvaluation(checkDataTypeAndCast(RoundCeil(Literal(135.135), 
Literal(-2))), Decimal(200))
   }
+
+  testGluten("unhex") {
+    checkEvaluation(Unhex(Literal.create(null, StringType)), null)
+    checkEvaluation(Unhex(Literal("737472696E67")), 
"string".getBytes(StandardCharsets.UTF_8))
+    checkEvaluation(Unhex(Literal("")), new Array[Byte](0))
+    checkEvaluation(Unhex(Literal("F")), Array[Byte](15))
+    checkEvaluation(Unhex(Literal("ff")), Array[Byte](-1))
+
+//    checkEvaluation(Unhex(Literal("GG")), null)
+    checkEvaluation(Unhex(Literal("123")), Array[Byte](1, 35))
+    checkEvaluation(Unhex(Literal("12345")), Array[Byte](1, 35, 69))
+    // scalastyle:off
+    // Turn off scala style for non-ascii chars
+    checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), 
"三重的".getBytes(StandardCharsets.UTF_8))
+//    checkEvaluation(Unhex(Literal("三重的")), null)
+    // scalastyle:on
+    checkConsistencyBetweenInterpretedAndCodegen(Unhex, StringType)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to