This is an automated email from the ASF dual-hosted git repository.

changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 6b1d63c81 [GLUTEN-6975][CH] Fix decimal cast overflow exception
6b1d63c81 is described below

commit 6b1d63c81f20a25fee6bdc715b3308237cbb29ed
Author: Shuai li <[email protected]>
AuthorDate: Wed Sep 25 10:52:18 2024 +0800

    [GLUTEN-6975][CH] Fix decimal cast overflow exception
---
 .../execution/GlutenClickHouseDecimalSuite.scala   | 18 +++++
 .../SparkFunctionDecimalBinaryArithmetic.cpp       | 82 +++++++++++++---------
 .../local-engine/Parser/SerializedPlanParser.cpp   |  7 ++
 3 files changed, 73 insertions(+), 34 deletions(-)

diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala
index f772c909c..c0a32a034 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala
@@ -335,6 +335,24 @@ class GlutenClickHouseDecimalSuite
       spark.sql("drop table if exists decimals_test")
     }
   }
+
+  test("test castornull") {
+    // prepare
+    val createSql =
+      "create table decimals_cast_test(a decimal(18,8)) using parquet"
+    val inserts =
+      "insert into decimals_cast_test values(123456789.12345678)"
+    spark.sql(createSql)
+
+    try {
+      spark.sql(inserts)
+      val q1 = "select cast(a as decimal(9,2)) from decimals_cast_test"
+      compareResultsAgainstVanillaSpark(q1, compareResult = true, _ => {})
+    } finally {
+      spark.sql("drop table if exists decimals_cast_test")
+    }
+  }
+
   // FIXME: Support AVG for Decimal Type
   Seq("true", "false").foreach {
     allowPrecisionLoss =>
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp
index 26d8e0deb..a2d5dec73 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp
@@ -121,23 +121,32 @@ public:
 
         if constexpr (Mode == OpMode::Effect)
         {
-            return executeDecimalImpl<LeftDataType, RightDataType, 
ResultDataType>(
+            return executeDecimalImpl<LeftDataType, RightDataType, 
ResultDataType, NativeType<typename ResultDataType::FieldType>>(
                 left, right, col_left_const, col_right_const, col_left, 
col_right, col_left_size, result);
         }
 
         if (calculateWith256<is_plus_minus, is_multiply, is_division, 
is_modulo>(*arguments[0].type.get(), *arguments[1].type.get()))
         {
-            return executeDecimalImpl<LeftDataType, RightDataType, 
ResultDataType, true>(
+            return executeDecimalImpl<LeftDataType, RightDataType, 
ResultDataType, Int256, true>(
                 left, right, col_left_const, col_right_const, col_left, 
col_right, col_left_size, result);
         }
 
-        return executeDecimalImpl<LeftDataType, RightDataType, ResultDataType>(
-            left, right, col_left_const, col_right_const, col_left, col_right, 
col_left_size, result);
+        size_t max_scale = getMaxScaled(left.getScale(), right.getScale(), 
result.getScale());
+        if (is_division && max_scale - left.getScale() + max_scale > 
DataTypeDecimal<typename ResultDataType::FieldType>::maxPrecision())
+        {
+            return executeDecimalImpl<LeftDataType, RightDataType, 
ResultDataType, Int256, true>(
+                left, right, col_left_const, col_right_const, col_left, 
col_right, col_left_size, result);
+        }
+        else
+        {
+            return executeDecimalImpl<LeftDataType, RightDataType, 
ResultDataType, NativeType<typename ResultDataType::FieldType>>(
+                left, right, col_left_const, col_right_const, col_left, 
col_right, col_left_size, result);
+        }
     }
 
 private:
     // ResultDataType e.g. DataTypeDecimal<Decimal32>
-    template <class LeftDataType, class RightDataType, class ResultDataType, 
bool CalculateWith256 = false>
+    template <class LeftDataType, class RightDataType, class ResultDataType, 
class ScaleDataType, bool CalculateWith256 = false>
     static ColumnPtr executeDecimalImpl(
         const auto & left,
         const auto & right,
@@ -152,34 +161,29 @@ private:
         using RightFieldType = typename RightDataType::FieldType;
         using ResultFieldType = typename ResultDataType::FieldType;
 
-        using NativeResultType = NativeType<ResultFieldType>;
         using ColVecResult = ColumnVectorOrDecimal<ResultFieldType>;
 
-        size_t max_scale;
-        if constexpr (is_multiply)
-            max_scale = left.getScale() + right.getScale();
-        else
-            max_scale = std::max(resultDataType.getScale(), 
std::max(left.getScale(), right.getScale()));
+        size_t max_scale = getMaxScaled(left.getScale(), right.getScale(), 
resultDataType.getScale());
 
-        NativeResultType scale_left = [&]
+        ScaleDataType scale_left = [&]
         {
             if constexpr (is_multiply)
-                return NativeResultType{1};
+                return ScaleDataType{1};
 
             // cast scale same to left
             auto diff_scale = max_scale - left.getScale();
             if constexpr (is_division)
-                return 
DecimalUtils::scaleMultiplier<NativeResultType>(diff_scale + max_scale);
+                return DecimalUtils::scaleMultiplier<ScaleDataType>(diff_scale 
+ max_scale);
             else
-                return 
DecimalUtils::scaleMultiplier<NativeResultType>(diff_scale);
+                return 
DecimalUtils::scaleMultiplier<ScaleDataType>(diff_scale);
         }();
 
-        const NativeResultType scale_right = [&]
+        const ScaleDataType scale_right = [&]
         {
             if constexpr (is_multiply)
-                return NativeResultType{1};
+                return ScaleDataType{1};
             else
-                return 
DecimalUtils::scaleMultiplier<NativeResultType>(max_scale - right.getScale());
+                return DecimalUtils::scaleMultiplier<ScaleDataType>(max_scale 
- right.getScale());
         }();
 
 
@@ -266,17 +270,19 @@ private:
         return ColumnNullable::create(std::move(col_res), 
std::move(col_null_map_to));
     }
 
-    template <OpCase op_case, bool CalculateWith256, typename 
ResultContainerType, typename NativeResultType, typename ResultDataType>
+    template <OpCase op_case, bool CalculateWith256, typename 
ResultContainerType, typename ResultDataType, typename ScaleDataType>
     static void NO_INLINE process(
         const auto & a,
         const auto & b,
         ResultContainerType & result_container,
-        const NativeResultType & scale_a,
-        const NativeResultType & scale_b,
+        const ScaleDataType & scale_a,
+        const ScaleDataType & scale_b,
         ColumnUInt8::Container & vec_null_map_to,
         const ResultDataType & resultDataType,
         const size_t & max_scale)
     {
+        using NativeResultType = NativeType<typename 
ResultDataType::FieldType>;
+
         size_t size;
         if constexpr (op_case == OpCase::LeftConstant)
             size = b.size();
@@ -303,14 +309,14 @@ private:
         }
         else if constexpr (op_case == OpCase::LeftConstant)
         {
-            auto scaled_a = applyScaled(unwrap<op_case, 
OpCase::LeftConstant>(a, 0), scale_a);
+            ScaleDataType scaled_a = applyScaled(unwrap<op_case, 
OpCase::LeftConstant>(a, 0), scale_a);
             for (size_t i = 0; i < size; ++i)
             {
                 NativeResultType res;
                 if (calculate<CalculateWith256>(
                         scaled_a,
                         unwrap<op_case, OpCase::RightConstant>(b, i),
-                        static_cast<NativeResultType>(0),
+                        static_cast<ScaleDataType>(0),
                         scale_b,
                         res,
                         resultDataType,
@@ -322,7 +328,7 @@ private:
         }
         else if constexpr (op_case == OpCase::RightConstant)
         {
-            auto scaled_b = applyScaled(unwrap<op_case, 
OpCase::RightConstant>(b, 0), scale_b);
+            ScaleDataType scaled_b = applyScaled(unwrap<op_case, 
OpCase::RightConstant>(b, 0), scale_b);
 
             for (size_t i = 0; i < size; ++i)
             {
@@ -331,7 +337,7 @@ private:
                         unwrap<op_case, OpCase::LeftConstant>(a, i),
                         scaled_b,
                         scale_a,
-                        static_cast<NativeResultType>(0),
+                        static_cast<ScaleDataType>(0),
                         res,
                         resultDataType,
                         max_scale))
@@ -343,12 +349,12 @@ private:
     }
 
     // ResultNativeType = Int32/64/128/256
-    template <bool CalculateWith256, typename LeftNativeType, typename 
RightNativeType, typename NativeResultType, typename ResultDataType>
+    template <bool CalculateWith256, typename LeftNativeType, typename 
RightNativeType, typename NativeResultType, typename ResultDataType, typename 
ScaleDataType>
     static NO_SANITIZE_UNDEFINED bool calculate(
         const LeftNativeType l,
         const RightNativeType r,
-        const NativeResultType & scale_left,
-        const NativeResultType & scale_right,
+        const ScaleDataType & scale_left,
+        const ScaleDataType & scale_right,
         NativeResultType & res,
         const ResultDataType & resultDataType,
         const size_t & max_scale)
@@ -361,12 +367,12 @@ private:
             return calculateImpl<NativeResultType>(l, r, scale_left, 
scale_right, res, resultDataType, max_scale);
     }
 
-    template <typename CalcType, typename LeftNativeType, typename 
RightNativeType, typename NativeResultType, typename ResultDataType>
+    template <typename CalcType, typename LeftNativeType, typename 
RightNativeType, typename NativeResultType, typename ResultDataType, typename 
ScaleDataType>
     static NO_SANITIZE_UNDEFINED bool calculateImpl(
         const LeftNativeType & l,
         const RightNativeType & r,
-        const NativeResultType & scale_left,
-        const NativeResultType & scale_right,
+        const ScaleDataType & scale_left,
+        const ScaleDataType & scale_right,
         NativeResultType & res,
         const ResultDataType & resultDataType,
         const size_t & max_scale)
@@ -410,13 +416,21 @@ private:
             return elem[i].value;
     }
 
-    template <typename NativeType, typename ResultNativeType>
-    static ResultNativeType applyScaled(const NativeType & l, const 
ResultNativeType & scale)
+    template <typename NativeType, typename ScaleType>
+    static ScaleType applyScaled(const NativeType & l, const ScaleType & scale)
     {
         if (scale > 1)
             return common::mulIgnoreOverflow(l, scale);
 
-        return static_cast<ResultNativeType>(l);
+        return static_cast<ScaleType>(l);
+    }
+
+    static size_t getMaxScaled(const size_t left_scale, const size_t 
right_scale, const size_t result_scale)
+    {
+        if constexpr (is_multiply)
+            return left_scale + right_scale;
+        else
+            return std::max(result_scale, std::max(left_scale, right_scale));
     }
 };
 
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
index 75ba2a115..9f90ae364 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
@@ -1032,6 +1032,13 @@ const ActionsDAG::Node * SerializedPlanParser::parseExpression(ActionsDAG & acti
                 String function_name = "sparkCastFloatTo" + 
non_nullable_output_type->getName();
                 function_node = toFunctionNode(actions_dag, function_name, 
args);
             }
+            else if ((isDecimal(non_nullable_input_type) && 
substrait_type.has_decimal()))
+            {
+                args.emplace_back(addColumn(actions_dag, 
std::make_shared<DataTypeInt32>(), substrait_type.decimal().precision()));
+                args.emplace_back(addColumn(actions_dag, 
std::make_shared<DataTypeInt32>(), substrait_type.decimal().scale()));
+
+                function_node = toFunctionNode(actions_dag, 
"checkDecimalOverflowSparkOrNull", args);
+            }
             else
             {
                 if (isString(non_nullable_input_type) && 
isInt(non_nullable_output_type))


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to