This is an automated email from the ASF dual-hosted git repository.
zhangzc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 06450c14e [GLUTEN-5620][CH] Simplify Decimal process for Remainder(%) operator (#5977)
06450c14e is described below
commit 06450c14eea8c3c086808967198796a6840d4cfe
Author: Chang chen <[email protected]>
AuthorDate: Tue Jun 4 20:37:33 2024 +0800
[GLUTEN-5620][CH] Simplify Decimal process for Remainder(%) operator (#5977)
[CH] Simplify Decimal process for Remainder(%) operator
---
.../execution/GlutenClickHouseDecimalSuite.scala | 52 ++++++++++++++++++++++
.../local-engine/Parser/SerializedPlanParser.cpp | 2 +-
.../Parser/scalar_function_parser/arithmetic.cpp | 22 +++++++++
3 files changed, 75 insertions(+), 1 deletion(-)
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala
index c41ea0ccb..3aa498ea3 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala
@@ -300,6 +300,39 @@ class GlutenClickHouseDecimalSuite
customCheck = customCheck,
noFallBack = noFallBack)
}
+
+ test("from decimalArithmeticOperations.sql") {
+ // prepare
+ val createSql =
+ "create table decimals_test(id int, a decimal(38,18), b decimal(38,18))
using parquet"
+ val inserts =
+ "insert into decimals_test values(1, 100.0, 999.0)" +
+ ", (2, 12345.123, 12345.123)" +
+ ", (3, 0.1234567891011, 1234.1)" +
+ ", (4, 123456789123456789.0, 1.123456789123456789)"
+ spark.sql(createSql)
+
+ try {
+ spark.sql(inserts)
+
+ val q1 = "select id, a+b, a-b, a*b, a/b ,a%b from decimals_test order by
id"
+
+ // test operations between decimals and constants
+ val q2 = "select id, a*10, b/10 from decimals_test order by id"
+ // FIXME val q2 = "select id, a*10, b/10, a%20, b%30 from decimals_test
order by id"
+
+ Seq("true", "false").foreach {
+ allowPrecisionLoss =>
+ withSQLConf((SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key,
allowPrecisionLoss)) {
+ compareResultsAgainstVanillaSpark(q1, compareResult = true, _ =>
{})
+ compareResultsAgainstVanillaSpark(q2, compareResult = true, _ =>
{})
+ }
+ }
+ } finally {
+ spark.sql("drop table if exists decimals_test")
+ }
+ }
+
Seq("true", "false").foreach {
allowPrecisionLoss =>
Range
@@ -390,6 +423,25 @@ class GlutenClickHouseDecimalSuite
compareResultsAgainstVanillaSpark(sql_not_null, compareResult = true, _ =>
{})
}
+ test("bigint % 6.1") {
+ val sql =
+ s"""
+ | select
+ | s_suppkey,
+ | s_suppkey % 6.1
+ | from supplier
+ |""".stripMargin
+ spark.sql(s"use decimal_${9}_${4}")
+ withSQLConf(vanillaSparkConfs(): _*) {
+ val df2 = spark.sql(sql)
+ print(df2.queryExecution.executedPlan)
+ }
+ testFromRandomBase(
+ sql,
+ _ => {}
+ )
+ }
+
def testFromRandomBase(
sql: String,
customCheck: DataFrame => Unit,
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
index b0d3bbeca..25ea86e5b 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
@@ -1772,7 +1772,7 @@ QueryPlanPtr SerializedPlanParser::parse(const std::string & plan)
if (logger->debug())
{
auto out = PlanUtil::explainPlan(*res);
- LOG_ERROR(logger, "clickhouse plan:\n{}", out);
+ LOG_DEBUG(logger, "clickhouse plan:\n{}", out);
}
return res;
}
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp
index 2a6e43566..d58b22a87 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp
@@ -65,6 +65,13 @@ public:
return bounded_to_click_house(precision, scale);
}
+ static DecimalType evalModuloDecimalType(const Int32 p1, const Int32 s1,
const Int32 p2, const Int32 s2)
+ {
+ const Int32 scale = std::max(s1, s2);
+ const Int32 precision = std::min(p1 - s1, p2 - s2) + scale;
+ return bounded_to_click_house(precision, scale);
+ }
+
static DecimalType evalMultiplyDecimalType(const Int32 p1, const Int32 s1,
const Int32 p2, const Int32 s2)
{
const Int32 scale = s1;
@@ -221,6 +228,20 @@ protected:
}
};
+class FunctionParserModulo final : public FunctionParserBinaryArithmetic
+{
+public:
+ explicit FunctionParserModulo(SerializedPlanParser * plan_parser_) :
FunctionParserBinaryArithmetic(plan_parser_) { }
+ static constexpr auto name = "modulus";
+ String getName() const override { return name; }
+
+protected:
+ DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32
p2, const Int32 s2) const override
+ {
+ return DecimalType::evalModuloDecimalType(p1, s1, p2, s2);
+ }
+};
+
class FunctionParserDivide final : public FunctionParserBinaryArithmetic
{
public:
@@ -252,5 +273,6 @@ static FunctionParserRegister<FunctionParserPlus> register_plus;
static FunctionParserRegister<FunctionParserMinus> register_minus;
static FunctionParserRegister<FunctionParserMultiply> register_mltiply;
static FunctionParserRegister<FunctionParserDivide> register_divide;
+static FunctionParserRegister<FunctionParserModulo> register_modulo;
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]