This is an automated email from the ASF dual-hosted git repository.
taiyangli pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 399a91b12 [CH] Fix GlutenLiteralExpressionSuite and
GlutenMathExpressionsSuite (#7235)
399a91b12 is described below
commit 399a91b12d572375a7c8c7a617c7f04e2e3612e8
Author: 李扬 <[email protected]>
AuthorDate: Sun Sep 15 17:24:11 2024 +0800
[CH] Fix GlutenLiteralExpressionSuite and GlutenMathExpressionsSuite (#7235)
* fix failed uts
* Update CommonScalarFunctionParser.cpp
* override checkResult for ch backend
---
.../org/apache/gluten/utils/CHExpressionUtil.scala | 3 +-
.../CommonScalarFunctionParser.cpp | 2 -
.../Parser/scalar_function_parser/shift.cpp | 105 +++++++++++++++++++++
.../scalar_function_parser/shiftRightUnsigned.cpp | 19 ++--
.../utils/clickhouse/ClickHouseTestSettings.scala | 26 +----
.../expressions/GlutenLiteralExpressionSuite.scala | 38 +++++++-
.../expressions/GlutenMathExpressionsSuite.scala | 51 ++++++++++
7 files changed, 208 insertions(+), 36 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
index 3c9fa9888..645189310 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
@@ -214,6 +214,7 @@ object CHExpressionUtil {
STACK -> DefaultValidator(),
TRANSFORM_KEYS -> DefaultValidator(),
TRANSFORM_VALUES -> DefaultValidator(),
- RAISE_ERROR -> DefaultValidator()
+ RAISE_ERROR -> DefaultValidator(),
+ WIDTH_BUCKET -> DefaultValidator()
)
}
diff --git
a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
index 37282104c..88e5d7ea8 100644
---
a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
+++
b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
@@ -102,8 +102,6 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Sign, sign, sign);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Radians, radians, radians);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Greatest, greatest, sparkGreatest);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Least, least, sparkLeast);
-REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ShiftLeft, shiftleft, bitShiftLeft);
-REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ShiftRight, shiftright, bitShiftRight);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Rand, rand, randCanonical);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Bin, bin, sparkBin);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Rint, rint, sparkRint);
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/shift.cpp
b/cpp-ch/local-engine/Parser/scalar_function_parser/shift.cpp
new file mode 100644
index 000000000..663bf5e26
--- /dev/null
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/shift.cpp
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <Parser/FunctionParser.h>
+#include <DataTypes/IDataType.h>
+#include <Common/CHUtil.h>
+#include <Core/Field.h>
+#include <DataTypes/DataTypeArray.h>
+#include <DataTypes/DataTypeNullable.h>
+#include <Functions/FunctionHelpers.h>
+
+namespace DB
+{
+namespace ErrorCodes
+{
+ extern const int BAD_ARGUMENTS;
+ extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+}
+}
+
+namespace local_engine
+{
+
+class FunctionParserShiftBase : public FunctionParser
+{
+public:
+ explicit FunctionParserShiftBase(SerializedPlanParser * plan_parser_) :
FunctionParser(plan_parser_) { }
+ ~FunctionParserShiftBase() override = default;
+
+ virtual String getCHFunctionName() const = 0;
+
+ const ActionsDAG::Node * parse(
+ const substrait::Expression_ScalarFunction & substrait_func,
+ ActionsDAG & actions_dag) const override
+ {
+ /// parse spark shiftxxx(expr, n) as
+ /// If expr has long type -> CH bitShiftxxx(expr, pmod(n, 64))
+ /// If expr has int type -> CH bitShiftxxx(expr, pmod(n, 32)); any other type is rejected
+ auto parsed_args = parseFunctionArguments(substrait_func, actions_dag);
+ if (parsed_args.size() != 2)
+ throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Function {} requires exactly two arguments", getName());
+
+
+ auto input_type = removeNullable(parsed_args[0]->result_type);
+ WhichDataType which(input_type);
+ const ActionsDAG::Node * base_node = nullptr;
+ if (which.isInt64())
+ {
+ base_node = addColumnToActionsDAG(actions_dag,
std::make_shared<DataTypeInt32>(), 64);
+ }
+ else if (which.isInt32())
+ {
+ base_node = addColumnToActionsDAG(actions_dag,
std::make_shared<DataTypeInt32>(), 32);
+ }
+ else
+ throw Exception(DB::ErrorCodes::BAD_ARGUMENTS, "First argument for
function {} must be an long or integer", getName());
+
+ const auto * pmod_node = toFunctionNode(actions_dag, "pmod",
{parsed_args[1], base_node});
+ auto ch_function_name = getCHFunctionName();
+ const auto * shift_node = toFunctionNode(actions_dag,
ch_function_name, {parsed_args[0], pmod_node});
+ return convertNodeTypeIfNeeded(substrait_func, shift_node,
actions_dag);
+ }
+};
+
+class FunctionParserShiftLeft : public FunctionParserShiftBase
+{
+public:
+ explicit FunctionParserShiftLeft(SerializedPlanParser * plan_parser_) :
FunctionParserShiftBase(plan_parser_) { }
+ ~FunctionParserShiftLeft() override = default;
+
+ static constexpr auto name = "shiftleft";
+ String getName() const override { return name; }
+
+ String getCHFunctionName() const override { return "bitShiftLeft"; }
+};
+static FunctionParserRegister<FunctionParserShiftLeft> register_shiftleft;
+
+class FunctionParserShiftRight: public FunctionParserShiftBase
+{
+public:
+ explicit FunctionParserShiftRight(SerializedPlanParser * plan_parser_) :
FunctionParserShiftBase(plan_parser_) { }
+ ~FunctionParserShiftRight() override = default;
+
+ static constexpr auto name = "shiftright";
+ String getName() const override { return name; }
+
+ String getCHFunctionName() const override { return "bitShiftRight"; }
+};
+static FunctionParserRegister<FunctionParserShiftRight> register_shiftright;
+
+
+}
diff --git
a/cpp-ch/local-engine/Parser/scalar_function_parser/shiftRightUnsigned.cpp
b/cpp-ch/local-engine/Parser/scalar_function_parser/shiftRightUnsigned.cpp
index 28288461a..ca88e8522 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/shiftRightUnsigned.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/shiftRightUnsigned.cpp
@@ -43,9 +43,9 @@ public:
{
/// parse shiftrightunsigned(a, b) as
/// if (isInteger(a))
- /// bitShiftRight(a::UInt32, b::UInt32)
+ /// bitShiftRight(a::UInt32, pmod(b, 32))
/// else if (isLong(a))
- /// bitShiftRight(a::UInt64, b::UInt64)
+ /// bitShiftRight(a::UInt64, pmod(b, 64))
/// else
/// throw Exception
@@ -55,26 +55,27 @@ public:
const auto * a = parsed_args[0];
const auto * b = parsed_args[1];
- const auto * new_a = a;
- const auto * new_b = b;
WhichDataType which(removeNullable(a->result_type));
+ const ActionsDAG::Node * base_node = nullptr;
+ const ActionsDAG::Node * unsigned_a_node = nullptr;
if (which.isInt32())
{
+ base_node = addColumnToActionsDAG(actions_dag,
std::make_shared<DataTypeUInt32>(), 32);
const auto * uint32_type_node = addColumnToActionsDAG(actions_dag,
std::make_shared<DataTypeString>(), "Nullable(UInt32)");
- new_a = toFunctionNode(actions_dag, "CAST", {a, uint32_type_node});
- new_b = toFunctionNode(actions_dag, "CAST", {b, uint32_type_node});
+ unsigned_a_node = toFunctionNode(actions_dag, "CAST", {a,
uint32_type_node});
}
else if (which.isInt64())
{
+ base_node = addColumnToActionsDAG(actions_dag,
std::make_shared<DataTypeUInt32>(), 64);
const auto * uint64_type_node = addColumnToActionsDAG(actions_dag,
std::make_shared<DataTypeString>(), "Nullable(UInt64)");
- new_a = toFunctionNode(actions_dag, "CAST", {a, uint64_type_node});
- new_b = toFunctionNode(actions_dag, "CAST", {b, uint64_type_node});
+ unsigned_a_node = toFunctionNode(actions_dag, "CAST", {a,
uint64_type_node});
}
else
throw Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Function {}
requires integer or long as first argument", getName());
- const auto * result = toFunctionNode(actions_dag, "bitShiftRight",
{new_a, new_b});
+ const auto * pmod_node = toFunctionNode(actions_dag, "pmod", {b,
base_node});
+ const auto * result = toFunctionNode(actions_dag, "bitShiftRight",
{unsigned_a_node, pmod_node});
return convertNodeTypeIfNeeded(substrait_func, result, actions_dag);
}
};
diff --git
a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index f40007957..c260d3f80 100644
---
a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++
b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -788,32 +788,12 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("SPARK-35728: Check multiply/divide of day-time intervals of any
fields by numeric")
.exclude("SPARK-35778: Check multiply/divide of year-month intervals of
any fields by numeric")
enableSuite[GlutenLiteralExpressionSuite]
- .exclude("null")
.exclude("default")
- .exclude("decimal")
- .exclude("array")
- .exclude("seq")
- .exclude("map")
- .exclude("struct")
- .exclude("SPARK-35664: construct literals from java.time.LocalDateTime")
- .exclude("SPARK-34605: construct literals from java.time.Duration")
- .exclude("SPARK-34605: construct literals from arrays of
java.time.Duration")
- .exclude("SPARK-34615: construct literals from java.time.Period")
- .exclude("SPARK-34615: construct literals from arrays of java.time.Period")
- .exclude("SPARK-35871: Literal.create(value, dataType) should support
fields")
.exclude("SPARK-37967: Literal.create support ObjectType")
enableSuite[GlutenMathExpressionsSuite]
- .exclude("tanh")
- .exclude("unhex")
- .exclude("atan2")
- .exclude("round/bround/floor/ceil")
- .exclude("SPARK-36922: Support ANSI intervals for SIGN/SIGNUM")
- .exclude("SPARK-35926: Support YearMonthIntervalType in width-bucket
function")
- .exclude("SPARK-35925: Support DayTimeIntervalType in width-bucket
function")
- .exclude("SPARK-37388: width_bucket")
- .exclude("shift left")
- .exclude("shift right")
- .exclude("shift right unsigned")
+ .exclude("unhex") // https://github.com/apache/incubator-gluten/issues/7232
+ .exclude("round/bround/floor/ceil") //
https://github.com/apache/incubator-gluten/issues/7233
+ .exclude("atan2") // https://github.com/apache/incubator-gluten/issues/7233
enableSuite[GlutenMiscExpressionsSuite]
enableSuite[GlutenNondeterministicSuite]
.exclude("MonotonicallyIncreasingID")
diff --git
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenLiteralExpressionSuite.scala
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenLiteralExpressionSuite.scala
index 556d185af..f81ef0b6f 100644
---
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenLiteralExpressionSuite.scala
+++
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenLiteralExpressionSuite.scala
@@ -17,5 +17,41 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.GlutenTestsTrait
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
-class GlutenLiteralExpressionSuite extends LiteralExpressionSuite with
GlutenTestsTrait {}
+import java.nio.charset.StandardCharsets
+import java.time.{Instant, LocalDate}
+
+class GlutenLiteralExpressionSuite extends LiteralExpressionSuite with
GlutenTestsTrait {
+ testGluten("default") {
+ checkEvaluation(Literal.default(BooleanType), false)
+ checkEvaluation(Literal.default(ByteType), 0.toByte)
+ checkEvaluation(Literal.default(ShortType), 0.toShort)
+ checkEvaluation(Literal.default(IntegerType), 0)
+ checkEvaluation(Literal.default(LongType), 0L)
+ checkEvaluation(Literal.default(FloatType), 0.0f)
+ checkEvaluation(Literal.default(DoubleType), 0.0)
+ checkEvaluation(Literal.default(StringType), "")
+ checkEvaluation(Literal.default(BinaryType),
"".getBytes(StandardCharsets.UTF_8))
+ checkEvaluation(Literal.default(DecimalType.USER_DEFAULT), Decimal(0))
+ checkEvaluation(Literal.default(DecimalType.SYSTEM_DEFAULT), Decimal(0))
+ withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "false") {
+ checkEvaluation(Literal.default(DateType), DateTimeUtils.toJavaDate(0))
+ checkEvaluation(Literal.default(TimestampType),
DateTimeUtils.toJavaTimestamp(0L))
+ }
+ withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
+ checkEvaluation(Literal.default(DateType), LocalDate.ofEpochDay(0))
+ checkEvaluation(Literal.default(TimestampType), Instant.ofEpochSecond(0))
+ }
+ checkEvaluation(Literal.default(CalendarIntervalType), new
CalendarInterval(0, 0, 0L))
+ checkEvaluation(Literal.default(YearMonthIntervalType()), 0)
+ checkEvaluation(Literal.default(DayTimeIntervalType()), 0L)
+ checkEvaluation(Literal.default(ArrayType(StringType)), Array())
+ checkEvaluation(Literal.default(MapType(IntegerType, StringType)), Map())
+ checkEvaluation(Literal.default(StructType(StructField("a", StringType) ::
Nil)), Row(""))
+ }
+}
diff --git
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala
index 7085b70ae..a8716b6ef 100644
---
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala
+++
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala
@@ -18,11 +18,44 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.gluten.utils.BackendTestUtils
+import org.apache.spark.sql.GlutenQueryTestUtil.isNaNOrInf
import org.apache.spark.sql.GlutenTestsTrait
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types._
+import org.apache.commons.math3.util.Precision
+
+import java.nio.charset.StandardCharsets
+
class GlutenMathExpressionsSuite extends MathExpressionsSuite with
GlutenTestsTrait {
+ override protected def checkResult(
+ result: Any,
+ expected: Any,
+ exprDataType: DataType,
+ exprNullable: Boolean): Boolean = {
+ if (BackendTestUtils.isVeloxBackendLoaded()) {
+ super.checkResult(result, expected, exprDataType, exprNullable)
+ } else {
+ // The result is null for a non-nullable expression
+ assert(result != null || exprNullable, "exprNullable should be true if
result is null")
+ (result, expected) match {
+ case (result: Double, expected: Double) =>
+ if (
+ (isNaNOrInf(result) || isNaNOrInf(expected))
+ || (result == -0.0) || (expected == -0.0)
+ ) {
+ java.lang.Double.doubleToRawLongBits(result) ==
+ java.lang.Double.doubleToRawLongBits(expected)
+ } else {
+ Precision.equalsWithRelativeTolerance(result, expected, 0.00001d)
||
+ Precision.equals(result, expected, 0.00001d)
+ }
+ case _ =>
+ super.checkResult(result, expected, exprDataType, exprNullable)
+ }
+ }
+ }
+
testGluten("round/bround/floor/ceil") {
val scales = -6 to 6
val doublePi: Double = math.Pi
@@ -284,4 +317,22 @@ class GlutenMathExpressionsSuite extends
MathExpressionsSuite with GlutenTestsTr
checkEvaluation(checkDataTypeAndCast(RoundCeil(Literal(3.1411),
Literal(-3))), Decimal(1000))
checkEvaluation(checkDataTypeAndCast(RoundCeil(Literal(135.135),
Literal(-2))), Decimal(200))
}
+
+ testGluten("unhex") {
+ checkEvaluation(Unhex(Literal.create(null, StringType)), null)
+ checkEvaluation(Unhex(Literal("737472696E67")),
"string".getBytes(StandardCharsets.UTF_8))
+ checkEvaluation(Unhex(Literal("")), new Array[Byte](0))
+ checkEvaluation(Unhex(Literal("F")), Array[Byte](15))
+ checkEvaluation(Unhex(Literal("ff")), Array[Byte](-1))
+
+// checkEvaluation(Unhex(Literal("GG")), null)
+ checkEvaluation(Unhex(Literal("123")), Array[Byte](1, 35))
+ checkEvaluation(Unhex(Literal("12345")), Array[Byte](1, 35, 69))
+ // scalastyle:off
+ // Turn off scala style for non-ascii chars
+ checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")),
"三重的".getBytes(StandardCharsets.UTF_8))
+// checkEvaluation(Unhex(Literal("三重的")), null)
+ // scalastyle:on
+ checkConsistencyBetweenInterpretedAndCodegen(Unhex, StringType)
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]