This is an automated email from the ASF dual-hosted git repository.

changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 55671cbf7c [GLUTEN-6975][CH] Rewrite decimal arithmetic (#7196)
55671cbf7c is described below

commit 55671cbf7c8bcd945b97869110872bf949589438
Author: Shuai li <[email protected]>
AuthorDate: Mon Sep 23 14:22:21 2024 +0800

    [GLUTEN-6975][CH] Rewrite decimal arithmetic (#7196)
    
    * enable tpchq1
---
 .../backendsapi/clickhouse/CHListenerApi.scala     |   4 +-
 .../execution/GlutenClickHouseDecimalSuite.scala   |   7 +-
 .../AggregateFunctionSparkAvg.cpp                  |  36 +-
 cpp-ch/local-engine/Common/GlutenDecimalUtils.h    |   7 -
 .../SparkFunctionDecimalBinaryArithmetic.cpp       | 591 +++++++++++++++++++++
 .../SparkFunctionDecimalBinaryArithmetic.h         | 276 ++++++++++
 cpp-ch/local-engine/Parser/SerializedPlanParser.h  |   1 +
 .../Parser/scalar_function_parser/arithmetic.cpp   | 160 ++++--
 .../gluten/expression/ExpressionConverter.scala    |  16 +-
 9 files changed, 1042 insertions(+), 56 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
index 58acda88fb..8065d35c85 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
@@ -25,7 +25,7 @@ import org.apache.gluten.extension.ExpressionExtensionTrait
 import org.apache.gluten.jni.JniLibLoader
 import org.apache.gluten.vectorized.CHNativeExpressionEvaluator
 
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext}
 import org.apache.spark.api.plugin.PluginContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.listener.CHGlutenSQLAppStatusListener
@@ -85,7 +85,7 @@ class CHListenerApi extends ListenerApi with Logging {
     conf.setCHConfig(
       "timezone" -> conf.get("spark.sql.session.timeZone", 
TimeZone.getDefault.getID),
       "local_engine.settings.log_processors_profiles" -> "true")
-
+    conf.setCHSettings("spark_version", SPARK_VERSION)
     // add memory limit for external sort
     val externalSortKey = 
CHConf.runtimeSettings("max_bytes_before_external_sort")
     if (conf.getLong(externalSortKey, -1) < 0) {
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala
index 40f442bc29..f772c909c8 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala
@@ -69,7 +69,7 @@ class GlutenClickHouseDecimalSuite
     (DecimalType.apply(18, 8), Seq()),
     // 3/10: all value is null and compare with limit
     // 1 Spark 3.5
-    (DecimalType.apply(38, 19), if (isSparkVersionLE("3.3")) Seq(3, 10) else 
Seq(1, 3, 10))
+    (DecimalType.apply(38, 19), if (isSparkVersionLE("3.3")) Seq(3, 10) else 
Seq(3, 10))
   )
 
   private def createDecimalTables(dataType: DecimalType): Unit = {
@@ -309,7 +309,10 @@ class GlutenClickHouseDecimalSuite
       "insert into decimals_test values(1, 100.0, 999.0)" +
         ", (2, 12345.123, 12345.123)" +
         ", (3, 0.1234567891011, 1234.1)" +
-        ", (4, 123456789123456789.0, 1.123456789123456789)"
+        ", (4, 123456789123456789.0, 1.123456789123456789)" +
+        ", (5, 0, 0)" +
+        ", (6, 0, 1.23)" +
+        ", (7, 1.23, 0)"
     spark.sql(createSql)
 
     try {
diff --git 
a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp 
b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp
index 0aa2331457..524a39d3a1 100644
--- a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp
+++ b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp
@@ -24,6 +24,7 @@
 
 #include <Common/CHUtil.h>
 #include <Common/GlutenDecimalUtils.h>
+#include <Common/GlutenSettings.h>
 
 namespace DB
 {
@@ -47,7 +48,7 @@ DataTypePtr getSparkAvgReturnType(const DataTypePtr & 
arg_type)
     return createDecimal<DataTypeDecimal>(precision_value, scale_value);
 }
 
-template <typename T>
+template <typename T, bool SPARK35>
 requires is_decimal<T>
 class AggregateFunctionSparkAvg final : public AggregateFunctionAvg<T>
 {
@@ -61,7 +62,7 @@ public:
     {
     }
 
-    DataTypePtr createResultType(const DataTypes & argument_types_, UInt32 
num_scale_, UInt32 round_scale_)
+    DataTypePtr createResultType(const DataTypes & argument_types_, UInt32 
num_scale_, UInt32 /*round_scale_*/)
     {
         const DataTypePtr & data_type = argument_types_[0];
         const UInt32 precision_value = 
std::min<size_t>(getDecimalPrecision(*data_type) + 4, 
DecimalUtils::max_precision<Decimal128>);
@@ -82,7 +83,7 @@ public:
         else if (which.isDecimal64())
         {
             assert_cast<ColumnDecimal<Decimal64> &>(to).getData().push_back(
-          divideDecimalAndUInt(this->data(place), num_scale, result_scale, 
round_scale));
+                divideDecimalAndUInt(this->data(place), num_scale, 
result_scale, round_scale));
         }
         else if (which.isDecimal128())
         {
@@ -116,6 +117,9 @@ private:
 
         auto result = value / avg.denominator;
 
+        if constexpr (SPARK35)
+            return result;
+
         if (round_scale > result_scale)
             return result;
 
@@ -128,8 +132,21 @@ private:
     UInt32 round_scale;
 };
 
-AggregateFunctionPtr
-createAggregateFunctionSparkAvg(const std::string & name, const DataTypes & 
argument_types, const Array & parameters, const Settings * settings)
+template <bool Data, typename... TArgs>
+static IAggregateFunction * createWithDecimalType(const IDataType & 
argument_type, TArgs && ... args)
+{
+    WhichDataType which(argument_type);
+    if (which.idx == TypeIndex::Decimal32) return new 
AggregateFunctionSparkAvg<Decimal32, Data>(args...);
+    if (which.idx == TypeIndex::Decimal64) return new 
AggregateFunctionSparkAvg<Decimal64, Data>(args...);
+    if (which.idx == TypeIndex::Decimal128) return new 
AggregateFunctionSparkAvg<Decimal128, Data>(args...);
+    if (which.idx == TypeIndex::Decimal256) return new 
AggregateFunctionSparkAvg<Decimal256, Data>(args...);
+    if constexpr (AggregateFunctionSparkAvg<DateTime64, 
Data>::DateTime64Supported)
+        if (which.idx == TypeIndex::DateTime64) return new 
AggregateFunctionSparkAvg<DateTime64, Data>(args...);
+    return nullptr;
+}
+
+AggregateFunctionPtr createAggregateFunctionSparkAvg(
+    const std::string & name, const DataTypes & argument_types, const Array & 
parameters, const Settings * settings)
 {
     assertNoParameters(name, parameters);
     assertUnary(name, argument_types);
@@ -140,13 +157,20 @@ createAggregateFunctionSparkAvg(const std::string & name, 
const DataTypes & argu
         throw Exception(
             ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument 
for aggregate function {}", data_type->getName(), name);
 
+    std::string version;
+    if (tryGetString(*settings, "spark_version", version) && 
version.starts_with("3.5"))
+    {
+        res.reset(createWithDecimalType<true>(*data_type, argument_types, 
getDecimalScale(*data_type), 0));
+        return res;
+    }
+
     bool allowPrecisionLoss = 
settings->get(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS).safeGet<bool>();
     const UInt32 p1 = DB::getDecimalPrecision(*data_type);
     const UInt32 s1 = DB::getDecimalScale(*data_type);
     auto [p2, s2] = GlutenDecimalUtils::LONG_DECIMAL;
     auto [_, round_scale] = GlutenDecimalUtils::dividePrecisionScale(p1, s1, 
p2, s2, allowPrecisionLoss);
 
-    res.reset(createWithDecimalType<AggregateFunctionSparkAvg>(*data_type, 
argument_types, getDecimalScale(*data_type), round_scale));
+    res.reset(createWithDecimalType<false>(*data_type, argument_types, 
getDecimalScale(*data_type), round_scale));
     return res;
 }
 
diff --git a/cpp-ch/local-engine/Common/GlutenDecimalUtils.h 
b/cpp-ch/local-engine/Common/GlutenDecimalUtils.h
index 32af66ec59..cf600a5cc9 100644
--- a/cpp-ch/local-engine/Common/GlutenDecimalUtils.h
+++ b/cpp-ch/local-engine/Common/GlutenDecimalUtils.h
@@ -95,13 +95,6 @@ public:
         }
     }
 
-    static std::tuple<size_t, size_t> widerDecimalType(const size_t p1, const 
size_t s1, const size_t p2, const size_t s2)
-    {
-        // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
-        auto scale = std::max(s1, s2);
-        auto range = std::max(p1 - s1, p2 - s2);
-        return std::tuple(range + scale, scale);
-    }
 
 };
 
diff --git 
a/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp 
b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp
new file mode 100644
index 0000000000..26d8e0deb3
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp
@@ -0,0 +1,591 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "SparkFunctionDecimalBinaryArithmetic.h"
+
+#include <Columns/ColumnDecimal.h>
+#include <Columns/ColumnNullable.h>
+#include <Columns/ColumnsNumber.h>
+#include <DataTypes/DataTypeNullable.h>
+#include <DataTypes/DataTypesDecimal.h>
+#include <Functions/FunctionFactory.h>
+#include <Functions/FunctionHelpers.h>
+#include <Functions/IFunction.h>
+#include <Functions/castTypeToEither.h>
+#include <Common/CurrentThread.h>
+#include <Common/GlutenDecimalUtils.h>
+#include <Common/ProfileEvents.h>
+#include <Common/Stopwatch.h>
+
+namespace DB
+{
+namespace ErrorCodes
+{
+extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+extern const int ILLEGAL_COLUMN;
+extern const int TYPE_MISMATCH;
+extern const int LOGICAL_ERROR;
+}
+
+}
+
+namespace local_engine
+{
+using namespace DB;
+
+namespace
+{
+enum class OpCase : uint8_t
+{
+    Vector,
+    LeftConstant,
+    RightConstant
+};
+
+enum class OpMode : uint8_t
+{
+    Default,
+    Effect
+};
+
+template <bool is_plus_minus, bool is_multiply, bool is_division, bool 
is_modulo>
+bool calculateWith256(const IDataType & left, const IDataType & right)
+{
+    const size_t p1 = getDecimalPrecision(left);
+    const size_t s1 = getDecimalScale(left);
+    const size_t p2 = getDecimalPrecision(right);
+    const size_t s2 = getDecimalScale(right);
+
+    size_t precision;
+    if constexpr (is_plus_minus)
+        precision = std::max(s1, s2) + std::max(p1 - s1, p2 - s2) + 1;
+    else if constexpr (is_multiply)
+        precision = p1 + p2 + 1;
+    else if constexpr (is_division)
+        precision = p1 - s1 + s2 + std::max(static_cast<size_t>(6), s1 + p2 + 
1);
+    else if constexpr (is_modulo)
+        precision = std::min(p1 - s1, p2 - s2) + std::max(s1, s2);
+    else
+        throw Exception(ErrorCodes::LOGICAL_ERROR, "Not supported.");
+
+    return precision > DataTypeDecimal128::maxPrecision();
+}
+
+template <typename Operation, OpMode Mode>
+struct SparkDecimalBinaryOperation
+{
+private:
+    static constexpr bool is_plus_minus = SparkIsOperation<Operation>::plus || 
SparkIsOperation<Operation>::minus;
+    static constexpr bool is_multiply = SparkIsOperation<Operation>::multiply;
+    static constexpr bool is_division = SparkIsOperation<Operation>::division;
+    static constexpr bool is_modulo = SparkIsOperation<Operation>::modulo;
+
+public:
+    template <typename A, typename B, typename R>
+    static ColumnPtr executeDecimal(const ColumnsWithTypeAndName & arguments, 
const A & left, const B & right, const R & result)
+    {
+        using LeftDataType = std::decay_t<decltype(left)>; // e.g. 
DataTypeDecimal<Decimal32>
+        using RightDataType = std::decay_t<decltype(right)>; // e.g. 
DataTypeDecimal<Decimal32>
+        using ResultDataType = std::decay_t<decltype(result)>; // e.g. 
DataTypeDecimal<Decimal32>
+
+        using ColVecLeft = ColumnVectorOrDecimal<typename 
LeftDataType::FieldType>;
+        using ColVecRight = ColumnVectorOrDecimal<typename 
RightDataType::FieldType>;
+
+        const ColumnPtr left_col = arguments[0].column;
+        const ColumnPtr right_col = arguments[1].column;
+
+        const auto * const col_left_raw = left_col.get();
+        const auto * const col_right_raw = right_col.get();
+
+        const size_t col_left_size = col_left_raw->size();
+
+        const ColumnConst * const col_left_const = 
checkAndGetColumnConst<ColVecLeft>(col_left_raw);
+        const ColumnConst * const col_right_const = 
checkAndGetColumnConst<ColVecRight>(col_right_raw);
+
+        const ColVecLeft * const col_left = 
checkAndGetColumn<ColVecLeft>(col_left_raw);
+        const ColVecRight * const col_right = 
checkAndGetColumn<ColVecRight>(col_right_raw);
+
+        if constexpr (Mode == OpMode::Effect)
+        {
+            return executeDecimalImpl<LeftDataType, RightDataType, 
ResultDataType>(
+                left, right, col_left_const, col_right_const, col_left, 
col_right, col_left_size, result);
+        }
+
+        if (calculateWith256<is_plus_minus, is_multiply, is_division, 
is_modulo>(*arguments[0].type.get(), *arguments[1].type.get()))
+        {
+            return executeDecimalImpl<LeftDataType, RightDataType, 
ResultDataType, true>(
+                left, right, col_left_const, col_right_const, col_left, 
col_right, col_left_size, result);
+        }
+
+        return executeDecimalImpl<LeftDataType, RightDataType, ResultDataType>(
+            left, right, col_left_const, col_right_const, col_left, col_right, 
col_left_size, result);
+    }
+
+private:
+    // ResultDataType e.g. DataTypeDecimal<Decimal32>
+    template <class LeftDataType, class RightDataType, class ResultDataType, 
bool CalculateWith256 = false>
+    static ColumnPtr executeDecimalImpl(
+        const auto & left,
+        const auto & right,
+        const ColumnConst * const col_left_const,
+        const ColumnConst * const col_right_const,
+        const auto * const col_left,
+        const auto * const col_right,
+        size_t col_left_size,
+        const ResultDataType & resultDataType)
+    {
+        using LeftFieldType = typename LeftDataType::FieldType;
+        using RightFieldType = typename RightDataType::FieldType;
+        using ResultFieldType = typename ResultDataType::FieldType;
+
+        using NativeResultType = NativeType<ResultFieldType>;
+        using ColVecResult = ColumnVectorOrDecimal<ResultFieldType>;
+
+        size_t max_scale;
+        if constexpr (is_multiply)
+            max_scale = left.getScale() + right.getScale();
+        else
+            max_scale = std::max(resultDataType.getScale(), 
std::max(left.getScale(), right.getScale()));
+
+        NativeResultType scale_left = [&]
+        {
+            if constexpr (is_multiply)
+                return NativeResultType{1};
+
+            // cast scale same to left
+            auto diff_scale = max_scale - left.getScale();
+            if constexpr (is_division)
+                return 
DecimalUtils::scaleMultiplier<NativeResultType>(diff_scale + max_scale);
+            else
+                return 
DecimalUtils::scaleMultiplier<NativeResultType>(diff_scale);
+        }();
+
+        const NativeResultType scale_right = [&]
+        {
+            if constexpr (is_multiply)
+                return NativeResultType{1};
+            else
+                return 
DecimalUtils::scaleMultiplier<NativeResultType>(max_scale - right.getScale());
+        }();
+
+
+        bool calculate_with_256 = false;
+        if constexpr (CalculateWith256)
+            calculate_with_256 = true;
+        else
+        {
+            auto p1 = left.getPrecision();
+            auto p2 = right.getPrecision();
+            if (DataTypeDecimal<LeftFieldType>::maxPrecision() < p1 + 
max_scale - left.getScale()
+                || DataTypeDecimal<RightFieldType>::maxPrecision() < p2 + 
max_scale - right.getScale())
+                calculate_with_256 = true;
+        }
+
+        ColumnUInt8::MutablePtr col_null_map_to = 
ColumnUInt8::create(col_left_size, false);
+        ColumnUInt8::Container * vec_null_map_to = &col_null_map_to->getData();
+
+        typename ColVecResult::MutablePtr col_res = ColVecResult::create(0, 
resultDataType.getScale());
+        auto & vec_res = col_res->getData();
+        vec_res.resize(col_left_size);
+
+        if (col_left && col_right)
+        {
+            if (calculate_with_256)
+            {
+                process<OpCase::Vector, true>(
+                    col_left->getData(),
+                    col_right->getData(),
+                    vec_res,
+                    scale_left,
+                    scale_right,
+                    *vec_null_map_to,
+                    resultDataType,
+                    max_scale);
+            }
+            else
+            {
+                process<OpCase::Vector, false>(
+                    col_left->getData(),
+                    col_right->getData(),
+                    vec_res,
+                    scale_left,
+                    scale_right,
+                    *vec_null_map_to,
+                    resultDataType,
+                    max_scale);
+            }
+        }
+        else if (col_left_const && col_right)
+        {
+            LeftFieldType const_left = 
col_left_const->getValue<LeftFieldType>();
+
+            if (calculate_with_256)
+            {
+                process<OpCase::LeftConstant, true>(
+                    const_left, col_right->getData(), vec_res, scale_left, 
scale_right, *vec_null_map_to, resultDataType, max_scale);
+            }
+            else
+            {
+                process<OpCase::LeftConstant, false>(
+                    const_left, col_right->getData(), vec_res, scale_left, 
scale_right, *vec_null_map_to, resultDataType, max_scale);
+            }
+        }
+        else if (col_left && col_right_const)
+        {
+            RightFieldType const_right = 
col_right_const->getValue<RightFieldType>();
+            if (calculate_with_256)
+            {
+                process<OpCase::RightConstant, true>(
+                    col_left->getData(), const_right, vec_res, scale_left, 
scale_right, *vec_null_map_to, resultDataType, max_scale);
+            }
+            else
+            {
+                process<OpCase::RightConstant, false>(
+                    col_left->getData(), const_right, vec_res, scale_left, 
scale_right, *vec_null_map_to, resultDataType, max_scale);
+            }
+        }
+        else
+        {
+            throw Exception(ErrorCodes::LOGICAL_ERROR, "Not supported.");
+        }
+
+        return ColumnNullable::create(std::move(col_res), 
std::move(col_null_map_to));
+    }
+
+    template <OpCase op_case, bool CalculateWith256, typename 
ResultContainerType, typename NativeResultType, typename ResultDataType>
+    static void NO_INLINE process(
+        const auto & a,
+        const auto & b,
+        ResultContainerType & result_container,
+        const NativeResultType & scale_a,
+        const NativeResultType & scale_b,
+        ColumnUInt8::Container & vec_null_map_to,
+        const ResultDataType & resultDataType,
+        const size_t & max_scale)
+    {
+        size_t size;
+        if constexpr (op_case == OpCase::LeftConstant)
+            size = b.size();
+        else
+            size = a.size();
+
+        if constexpr (op_case == OpCase::Vector)
+        {
+            for (size_t i = 0; i < size; ++i)
+            {
+                NativeResultType res;
+                if (calculate<CalculateWith256>(
+                        unwrap<op_case, OpCase::LeftConstant>(a, i),
+                        unwrap<op_case, OpCase::RightConstant>(b, i),
+                        scale_a,
+                        scale_b,
+                        res,
+                        resultDataType,
+                        max_scale))
+                    result_container[i] = res;
+                else
+                    vec_null_map_to[i] = static_cast<UInt8>(1);
+            }
+        }
+        else if constexpr (op_case == OpCase::LeftConstant)
+        {
+            auto scaled_a = applyScaled(unwrap<op_case, 
OpCase::LeftConstant>(a, 0), scale_a);
+            for (size_t i = 0; i < size; ++i)
+            {
+                NativeResultType res;
+                if (calculate<CalculateWith256>(
+                        scaled_a,
+                        unwrap<op_case, OpCase::RightConstant>(b, i),
+                        static_cast<NativeResultType>(0),
+                        scale_b,
+                        res,
+                        resultDataType,
+                        max_scale))
+                    result_container[i] = res;
+                else
+                    vec_null_map_to[i] = static_cast<UInt8>(1);
+            }
+        }
+        else if constexpr (op_case == OpCase::RightConstant)
+        {
+            auto scaled_b = applyScaled(unwrap<op_case, 
OpCase::RightConstant>(b, 0), scale_b);
+
+            for (size_t i = 0; i < size; ++i)
+            {
+                NativeResultType res;
+                if (calculate<CalculateWith256>(
+                        unwrap<op_case, OpCase::LeftConstant>(a, i),
+                        scaled_b,
+                        scale_a,
+                        static_cast<NativeResultType>(0),
+                        res,
+                        resultDataType,
+                        max_scale))
+                    result_container[i] = res;
+                else
+                    vec_null_map_to[i] = static_cast<UInt8>(1);
+            }
+        }
+    }
+
+    // ResultNativeType = Int32/64/128/256
+    template <bool CalculateWith256, typename LeftNativeType, typename 
RightNativeType, typename NativeResultType, typename ResultDataType>
+    static NO_SANITIZE_UNDEFINED bool calculate(
+        const LeftNativeType l,
+        const RightNativeType r,
+        const NativeResultType & scale_left,
+        const NativeResultType & scale_right,
+        NativeResultType & res,
+        const ResultDataType & resultDataType,
+        const size_t & max_scale)
+    {
+        if constexpr (CalculateWith256)
+            return calculateImpl<Int256>(l, r, scale_left, scale_right, res, 
resultDataType, max_scale);
+        else if (is_division)
+            return calculateImpl<Int128>(l, r, scale_left, scale_right, res, 
resultDataType, max_scale);
+        else
+            return calculateImpl<NativeResultType>(l, r, scale_left, 
scale_right, res, resultDataType, max_scale);
+    }
+
+    template <typename CalcType, typename LeftNativeType, typename 
RightNativeType, typename NativeResultType, typename ResultDataType>
+    static NO_SANITIZE_UNDEFINED bool calculateImpl(
+        const LeftNativeType & l,
+        const RightNativeType & r,
+        const NativeResultType & scale_left,
+        const NativeResultType & scale_right,
+        NativeResultType & res,
+        const ResultDataType & resultDataType,
+        const size_t & max_scale)
+    {
+        CalcType scaled_l = applyScaled(static_cast<CalcType>(l), 
static_cast<CalcType>(scale_left));
+        CalcType scaled_r = applyScaled(static_cast<CalcType>(r), 
static_cast<CalcType>(scale_right));
+
+        CalcType c_res = 0;
+        auto success = Operation::template apply<CalcType>(scaled_l, scaled_r, 
c_res);
+        if (!success)
+            return false;
+
+        auto result_scale = resultDataType.getScale();
+        auto scale_diff = max_scale - result_scale;
+        chassert(scale_diff >= 0);
+        if (scale_diff)
+        {
+            auto scaled_diff = 
DecimalUtils::scaleMultiplier<CalcType>(scale_diff);
+            DecimalDivideImpl::apply<CalcType>(c_res, scaled_diff, c_res);
+        }
+
+        // check overflow
+        if constexpr (std::is_same_v<CalcType, Int256> || is_division)
+        {
+            auto max_value = 
intExp10OfSize<CalcType>(resultDataType.getPrecision());
+            if (c_res <= -max_value || c_res >= max_value)
+                return false;
+        }
+
+        res = static_cast<NativeResultType>(c_res);
+
+        return true;
+    }
+
+    template <OpCase op_case, OpCase target, class E>
+    static auto unwrap(const E & elem, size_t i)
+    {
+        if constexpr (op_case == target)
+            return elem.value;
+        else
+            return elem[i].value;
+    }
+
+    template <typename NativeType, typename ResultNativeType>
+    static ResultNativeType applyScaled(const NativeType & l, const 
ResultNativeType & scale)
+    {
+        if (scale > 1)
+            return common::mulIgnoreOverflow(l, scale);
+
+        return static_cast<ResultNativeType>(l);
+    }
+};
+
+
+template <class Operation, typename Name, OpMode Mode = OpMode::Default>
+class SparkFunctionDecimalBinaryArithmetic final : public IFunction
+{
+    static constexpr bool is_plus_minus = SparkIsOperation<Operation>::plus || 
SparkIsOperation<Operation>::minus;
+    static constexpr bool is_multiply = SparkIsOperation<Operation>::multiply;
+    static constexpr bool is_division = SparkIsOperation<Operation>::division;
+    static constexpr bool is_modulo = SparkIsOperation<Operation>::modulo;
+
+public:
+    static constexpr auto name = Name::name;
+
+    static FunctionPtr create(ContextPtr context_) { return 
std::make_shared<SparkFunctionDecimalBinaryArithmetic>(context_); }
+
+    explicit SparkFunctionDecimalBinaryArithmetic(ContextPtr context_) : 
context(context_) { }
+
+    String getName() const override { return name; }
+    size_t getNumberOfArguments() const override { return 3; }
+    bool isSuitableForShortCircuitArgumentsExecution(const 
DataTypesWithConstInfo & /*arguments*/) const override { return false; }
+    bool useDefaultImplementationForConstants() const override { return true; }
+    ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return 
{2}; }
+
+    DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) 
const override
+    {
+        if (arguments.size() != 3)
+            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, 
"Function '{}' expects 3 arguments", getName());
+
+        if (!isDecimal(arguments[0].type) || !isDecimal(arguments[1].type) || 
!isDecimal(arguments[2].type))
+            throw Exception(
+                ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
+                "Illegal type {} {} {} of argument of function {}",
+                arguments[0].type->getName(),
+                arguments[1].type->getName(),
+                arguments[2].type->getName(),
+                getName());
+
+        return std::make_shared<DataTypeNullable>(arguments[2].type);
+    }
+
+    // executeImpl2
+    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const 
DataTypePtr &, size_t) const override
+    {
+        const auto & left_argument = arguments[0];
+        const auto & right_argument = arguments[1];
+
+        const auto * left_generic = left_argument.type.get();
+        const auto * right_generic = right_argument.type.get();
+
+        ColumnPtr res;
+        const bool valid = castBothTypes(
+            left_generic,
+            right_generic,
+            removeNullable(arguments[2].type).get(),
+            [&](const auto & left, const auto & right, const auto & result) {
+                return (res = SparkDecimalBinaryOperation<Operation, 
Mode>::template executeDecimal(arguments, left, right, result))
+                    != nullptr;
+            });
+
+        if (!valid)
+        {
+            // This is a logical error, because the types should have been 
checked
+            // by getReturnTypeImpl().
+            throw Exception(
+                ErrorCodes::LOGICAL_ERROR,
+                "Arguments of '{}' have incorrect data types: '{}' of type 
'{}',"
+                " '{}' of type '{}'",
+                getName(),
+                left_argument.name,
+                left_argument.type->getName(),
+                right_argument.name,
+                right_argument.type->getName());
+        }
+
+        return res;
+    }
+
+private:
+    template <typename F>
+    static bool castBothTypes(const IDataType * left, const IDataType * right, 
const IDataType * result, F && f)
+    {
+        return castType(
+            left,
+            [&](const auto & left_)
+            {
+                return castType(
+                    right,
+                    [&](const auto & right_) { return castType(result, 
[&](const auto & result_) { return f(left_, right_, result_); }); });
+            });
+    }
+
+    static bool castType(const IDataType * type, auto && f)
+    {
+        using Types = TypeList<DataTypeDecimal32, DataTypeDecimal64, 
DataTypeDecimal128, DataTypeDecimal256>;
+        return castTypeToEither(Types{}, type, std::forward<decltype(f)>(f));
+    }
+
+    ContextPtr context;
+};
+
+struct NameSparkDecimalPlus
+{
+    static constexpr auto name = "sparkDecimalPlus";
+};
+struct NameSparkDecimalPlusEffect
+{
+    static constexpr auto name = "sparkDecimalPlusEffect";
+};
+struct NameSparkDecimalMinus
+{
+    static constexpr auto name = "sparkDecimalMinus";
+};
+struct NameSparkDecimalMinusEffect
+{
+    static constexpr auto name = "sparkDecimalMinusEffect";
+};
+struct NameSparkDecimalMultiply
+{
+    static constexpr auto name = "sparkDecimalMultiply";
+};
+struct NameSparkDecimalMultiplyEffect
+{
+    static constexpr auto name = "sparkDecimalMultiplyEffect";
+};
+struct NameSparkDecimalDivide
+{
+    static constexpr auto name = "sparkDecimalDivide";
+};
+struct NameSparkDecimalDivideEffect
+{
+    static constexpr auto name = "sparkDecimalDivideEffect";
+};
+struct NameSparkDecimalModulo
+{
+    static constexpr auto name = "NameSparkDecimalModulo";
+};
+struct NameSparkDecimalModuloEffect
+{
+    static constexpr auto name = "NameSparkDecimalModuloEffect";
+};
+
+
+using DecimalPlus = SparkFunctionDecimalBinaryArithmetic<DecimalPlusImpl, 
NameSparkDecimalPlus>;
+using DecimalMinus = SparkFunctionDecimalBinaryArithmetic<DecimalMinusImpl, 
NameSparkDecimalMinus>;
+using DecimalMultiply = 
SparkFunctionDecimalBinaryArithmetic<DecimalMultiplyImpl, 
NameSparkDecimalMultiply>;
+using DecimalDivide = SparkFunctionDecimalBinaryArithmetic<DecimalDivideImpl, 
NameSparkDecimalDivide>;
+using DecimalModulo = SparkFunctionDecimalBinaryArithmetic<DecimalModuloImpl, 
NameSparkDecimalModulo>;
+
+using DecimalPlusEffect = 
SparkFunctionDecimalBinaryArithmetic<DecimalPlusImpl, 
NameSparkDecimalPlusEffect, OpMode::Effect>;
+using DecimalMinusEffect = 
SparkFunctionDecimalBinaryArithmetic<DecimalMinusImpl, 
NameSparkDecimalMinusEffect, OpMode::Effect>;
+using DecimalMultiplyEffect = 
SparkFunctionDecimalBinaryArithmetic<DecimalMultiplyImpl, 
NameSparkDecimalMultiplyEffect, OpMode::Effect>;
+using DecimalDivideEffect = 
SparkFunctionDecimalBinaryArithmetic<DecimalDivideImpl, 
NameSparkDecimalDivideEffect, OpMode::Effect>;
+using DecimalModuloEffect = 
SparkFunctionDecimalBinaryArithmetic<DecimalModuloImpl, 
NameSparkDecimalModuloEffect, OpMode::Effect>;
+}
+
+REGISTER_FUNCTION(SparkDecimalFunctionArithmetic)
+{
+    factory.registerFunction<DecimalPlus>();
+    factory.registerFunction<DecimalMinus>();
+    factory.registerFunction<DecimalMultiply>();
+    factory.registerFunction<DecimalDivide>();
+    factory.registerFunction<DecimalModulo>();
+
+    factory.registerFunction<DecimalPlusEffect>();
+    factory.registerFunction<DecimalMinusEffect>();
+    factory.registerFunction<DecimalMultiplyEffect>();
+    factory.registerFunction<DecimalDivideEffect>();
+    factory.registerFunction<DecimalModuloEffect>();
+}
+}
diff --git 
a/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.h 
b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.h
new file mode 100644
index 0000000000..05e5f4ff9f
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.h
@@ -0,0 +1,276 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <base/arithmeticOverflow.h>
+
+namespace local_engine
+{
+
+static bool canCastLower(const Int256 & a, const Int256 & b)
+{
+    if (a.items[2] == 0 && a.items[3] == 0 && b.items[2] == 0 && b.items[3] == 
0)
+        return true;
+
+    return false;
+}
+
+static bool canCastLower(const Int128 & a, const Int128 & b)
+{
+    if (a.items[1] == 0 && b.items[1] == 0)
+        return true;
+
+    return false;
+}
+
+struct DecimalPlusImpl
+{
+    template <typename A>
+    static bool apply(A & a, A & b, A & r)
+    {
+        return !common::addOverflow(a, b, r);
+    }
+
+    template <>
+    static bool apply(Int128 & a, Int128 & b, Int128 & r)
+    {
+        if (canCastLower(a, b))
+        {
+            UInt64 low_result;
+            if (common::addOverflow(static_cast<UInt64>(a), 
static_cast<UInt64>(b), low_result))
+                return !common::addOverflow(a, b, r);
+
+            r = static_cast<Int128>(low_result);
+            return true;
+        }
+
+        return !common::addOverflow(a, b, r);
+    }
+
+
+    template <>
+    static bool apply(Int256 & a, Int256 & b, Int256 & r)
+    {
+        if (canCastLower(a, b))
+        {
+            UInt128 low_result;
+            if (common::addOverflow(static_cast<UInt128>(a), 
static_cast<UInt128>(b), low_result))
+                return !common::addOverflow(a, b, r);
+
+            r = static_cast<Int256>(low_result);
+            return true;
+        }
+
+        return !common::addOverflow(a, b, r);
+    }
+
+
+#if USE_EMBEDDED_COMPILER
+    static constexpr bool compilable = true;
+
+    static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * left, 
llvm::Value * right, bool)
+    {
+        return left->getType()->isIntegerTy() ? b.CreateAdd(left, right) : 
b.CreateFAdd(left, right);
+    }
+#endif
+};
+
+struct DecimalMinusImpl
+{
+    /// Apply the operation and check for overflow. Used for Decimal operations. @returns true on success, false if the operation overflowed.
+    template <typename A>
+    static bool apply(A & a, A & b, A & r)
+    {
+        return !common::subOverflow(a, b, r);
+    }
+
+    template <>
+    static bool apply(Int128 & a, Int128 & b, Int128 & r)
+    {
+        if (canCastLower(a, b))
+        {
+            UInt64 low_result;
+            if (common::subOverflow(static_cast<UInt64>(a), 
static_cast<UInt64>(b), low_result))
+                return !common::subOverflow(a, b, r);
+
+            r = static_cast<Int128>(low_result);
+            return true;
+        }
+
+        return !common::subOverflow(a, b, r);
+    }
+
+    template <>
+    static bool apply(Int256 & a, Int256 & b, Int256 & r)
+    {
+        if (canCastLower(a, b))
+        {
+            UInt128 low_result;
+            if (common::subOverflow(static_cast<UInt128>(a), 
static_cast<UInt128>(b), low_result))
+                return !common::subOverflow(a, b, r);
+
+            r = static_cast<Int256>(low_result);
+            return true;
+        }
+
+        return !common::subOverflow(a, b, r);
+    }
+
+#if USE_EMBEDDED_COMPILER
+    static constexpr bool compilable = true;
+
+    static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * left, 
llvm::Value * right, bool)
+    {
+        return left->getType()->isIntegerTy() ? b.CreateSub(left, right) : 
b.CreateFSub(left, right);
+    }
+#endif
+};
+
+
+struct DecimalMultiplyImpl
+{
+    /// Apply the operation and check for overflow. Used for Decimal operations. @returns true on success, false if the operation overflowed.
+    template <typename A>
+    static bool apply(A & a, A & b, A & c)
+    {
+        return !common::mulOverflow(a, b, c);
+    }
+
+    template <>
+    static bool apply(Int128 & a, Int128 & b, Int128 & r)
+    {
+        if (canCastLower(a, b))
+        {
+            UInt64 low_result = 0;
+            if (common::mulOverflow(static_cast<UInt64>(a), 
static_cast<UInt64>(b), low_result))
+                return !common::mulOverflow(a, b, r);
+
+            r = static_cast<Int128>(low_result);
+            return true;
+        }
+
+        return !common::mulOverflow(a, b, r);
+    }
+
+#if USE_EMBEDDED_COMPILER
+    static constexpr bool compilable = true;
+
+    static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * left, 
llvm::Value * right, bool)
+    {
+        return left->getType()->isIntegerTy() ? b.CreateMul(left, right) : 
b.CreateFMul(left, right);
+    }
+#endif
+};
+
+struct DecimalDivideImpl
+{
+    template <typename A>
+    static bool apply(A & a, A & b, A & r)
+    {
+        if (b == 0)
+            return false;
+
+        r = a / b;
+        return true;
+    }
+
+    template <>
+    static bool apply(Int128 & a, Int128 & b, Int128 & r)
+    {
+        if (b == 0)
+            return false;
+
+        if (canCastLower(a, b))
+        {
+            r = static_cast<Int128>(static_cast<UInt64>(a) / 
static_cast<UInt64>(b));
+            return true;
+        }
+
+        r = a / b;
+        return true;
+    }
+
+    template <>
+    static bool apply(Int256 & a, Int256 & b, Int256 & r)
+    {
+        if (b == 0)
+            return false;
+
+        if (canCastLower(a, b))
+        {
+            UInt128 low_result = 0;
+            UInt128 low_a = static_cast<UInt128>(a);
+            UInt128 low_b = static_cast<UInt128>(b);
+            apply(low_a, low_b, low_result);
+            r = static_cast<Int256>(low_result);
+            return true;
+        }
+
+        r = a / b;
+        return true;
+    }
+
+#if USE_EMBEDDED_COMPILER
+    static constexpr bool compilable = true;
+
+    static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * left, 
llvm::Value * right, bool)
+    {
+        return left->getType()->isIntegerTy() ? b.CreateSDiv(left, right) : b.CreateFDiv(left, right);
+    }
+#endif
+};
+
+
+// ModuloImpl
+struct DecimalModuloImpl
+{
+    template <typename A>
+    static bool apply(A & a, A & b, A & r)
+    {
+        if (b == 0)
+            return false;
+
+        r = a % b;
+        return true;
+    }
+
+#if USE_EMBEDDED_COMPILER
+    static constexpr bool compilable = true;
+
+    static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * left, 
llvm::Value * right, bool)
+    {
+        return left->getType()->isIntegerTy() ? b.CreateSRem(left, right) : b.CreateFRem(left, right);
+    }
+#endif
+};
+
+template <typename Op1, typename Op2>
+struct IsSameOperation
+{
+    static constexpr bool value = std::is_same_v<Op1, Op2>;
+};
+
+template <typename Op>
+struct SparkIsOperation
+{
+    static constexpr bool plus = IsSameOperation<Op, DecimalPlusImpl>::value;
+    static constexpr bool minus = IsSameOperation<Op, DecimalMinusImpl>::value;
+    static constexpr bool multiply = IsSameOperation<Op, 
DecimalMultiplyImpl>::value;
+    static constexpr bool division = IsSameOperation<Op, 
DecimalDivideImpl>::value;
+    static constexpr bool modulo = IsSameOperation<Op, 
DecimalModuloImpl>::value;
+};
+}
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h 
b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
index c9a48106a9..9e1ffc0dd6 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h
@@ -135,6 +135,7 @@ public:
     IQueryPlanStep * addRollbackFilterHeaderStep(QueryPlanPtr & query_plan, 
const Block & input_header);
 
     static std::pair<DataTypePtr, Field> parseLiteral(const 
substrait::Expression_Literal & literal);
+    ContextPtr getContext() const { return context; }
 
     std::vector<QueryPlanPtr> extra_plan_holder;
 
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp 
b/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp
index 6aba310bf0..f73305df02 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp
@@ -22,6 +22,7 @@
 #include <Parser/TypeParser.h>
 #include <Common/BlockTypeUtils.h>
 #include <Common/CHUtil.h>
+#include <Common/GlutenSettings.h>
 
 namespace DB::ErrorCodes
 {
@@ -136,8 +137,11 @@ protected:
         return toFunctionNode(actions_dag, "checkDecimalOverflowSparkOrNull", 
overflow_args);
     }
 
-    virtual const DB::ActionsDAG::Node *
-    createFunctionNode(DB::ActionsDAG & actions_dag, const String & func_name, 
const DB::ActionsDAG::NodeRawConstPtrs & args) const
+    virtual const DB::ActionsDAG::Node * createFunctionNode(
+        DB::ActionsDAG & actions_dag,
+        const String & func_name,
+        const DB::ActionsDAG::NodeRawConstPtrs & args,
+        DataTypePtr result_type) const
     {
         return toFunctionNode(actions_dag, func_name, args);
     }
@@ -154,33 +158,8 @@ public:
 
         const auto left_type = DB::removeNullable(parsed_args[0]->result_type);
         const auto right_type = 
DB::removeNullable(parsed_args[1]->result_type);
-        const bool converted = isDecimal(left_type) && isDecimal(right_type);
-
-        if (converted)
-        {
-            const DecimalType evalType = getDecimalType(left_type, right_type);
-            parsed_args = convertBinaryArithmeticFunDecimalArgs(actions_dag, 
parsed_args, evalType, substrait_func);
-        }
-
-        const auto * func_node = createFunctionNode(actions_dag, ch_func_name, 
parsed_args);
-
-        if (converted)
-        {
-            const auto parsed_output_type = 
removeNullable(TypeParser::parseType(substrait_func.output_type()));
-            assert(isDecimal(parsed_output_type));
-            const Int32 parsed_precision = 
getDecimalPrecision(*parsed_output_type);
-            const Int32 parsed_scale = getDecimalScale(*parsed_output_type);
-            func_node = checkDecimalOverflow(actions_dag, func_node, 
parsed_precision, parsed_scale);
-#ifndef NDEBUG
-            const auto output_type = removeNullable(func_node->result_type);
-            const Int32 output_precision = getDecimalPrecision(*output_type);
-            const Int32 output_scale = getDecimalScale(*output_type);
-            if (output_precision != parsed_precision || output_scale != 
parsed_scale)
-                throw Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Function {} 
has wrong output type", getName());
-#endif
-
-            return func_node;
-        }
+        const auto result_type = 
removeNullable(TypeParser::parseType(substrait_func.output_type()));
+        const auto * func_node = createFunctionNode(actions_dag, ch_func_name, 
parsed_args, result_type);
         return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag);
     }
 };
@@ -199,6 +178,32 @@ protected:
     {
         return DecimalType::evalAddSubstractDecimalType(p1, s1, p2, s2);
     }
+
+    const DB::ActionsDAG::Node * createFunctionNode(
+        DB::ActionsDAG & actions_dag,
+        const String & func_name,
+        const DB::ActionsDAG::NodeRawConstPtrs & new_args,
+        DataTypePtr result_type) const override
+    {
+        const auto * left_arg = new_args[0];
+        const auto * right_arg = new_args[1];
+
+        if (isDecimal(removeNullable(left_arg->result_type)) && 
isDecimal(removeNullable(right_arg->result_type)))
+        {
+            const ActionsDAG::Node * type_node = 
&actions_dag.addColumn(ColumnWithTypeAndName(
+                result_type->createColumnConstWithDefaultValue(1), 
result_type, getUniqueName(result_type->getName())));
+
+            const auto & settings = 
plan_parser->getContext()->getSettingsRef();
+            auto function_name
+                = settings.has("arithmetic.decimal.mode") && 
settingsEqual(settings, "arithmetic.decimal.mode", "EFFECT")
+                ? "sparkDecimalPlusEffect"
+                : "sparkDecimalPlus";
+
+            return toFunctionNode(actions_dag, function_name, {left_arg, 
right_arg, type_node});
+        }
+
+        return toFunctionNode(actions_dag, "plus", {left_arg, right_arg});
+    }
 };
 
 class FunctionParserMinus final : public FunctionParserBinaryArithmetic
@@ -215,6 +220,32 @@ protected:
     {
         return DecimalType::evalAddSubstractDecimalType(p1, s1, p2, s2);
     }
+
+    const DB::ActionsDAG::Node * createFunctionNode(
+        DB::ActionsDAG & actions_dag,
+        const String & func_name,
+        const DB::ActionsDAG::NodeRawConstPtrs & new_args,
+        DataTypePtr result_type) const override
+    {
+        const auto * left_arg = new_args[0];
+        const auto * right_arg = new_args[1];
+
+        if (isDecimal(removeNullable(left_arg->result_type)) && 
isDecimal(removeNullable(right_arg->result_type)))
+        {
+            const ActionsDAG::Node * type_node = 
&actions_dag.addColumn(ColumnWithTypeAndName(
+                result_type->createColumnConstWithDefaultValue(1), 
result_type, getUniqueName(result_type->getName())));
+
+            const auto & settings = 
plan_parser->getContext()->getSettingsRef();
+            auto function_name
+                = settings.has("arithmetic.decimal.mode") && 
settingsEqual(settings, "arithmetic.decimal.mode", "EFFECT")
+                ? "sparkDecimalMinusEffect"
+                : "sparkDecimalMinus";
+
+            return toFunctionNode(actions_dag, function_name, {left_arg, 
right_arg, type_node});
+        }
+
+        return toFunctionNode(actions_dag, "minus", {left_arg, right_arg});
+    }
 };
 
 class FunctionParserMultiply final : public FunctionParserBinaryArithmetic
@@ -230,6 +261,32 @@ protected:
     {
         return DecimalType::evalMultiplyDecimalType(p1, s1, p2, s2);
     }
+
+    const DB::ActionsDAG::Node * createFunctionNode(
+        DB::ActionsDAG & actions_dag,
+        const String & func_name,
+        const DB::ActionsDAG::NodeRawConstPtrs & new_args,
+        DataTypePtr result_type) const override
+    {
+        const auto * left_arg = new_args[0];
+        const auto * right_arg = new_args[1];
+
+        if (isDecimal(removeNullable(left_arg->result_type)) && 
isDecimal(removeNullable(right_arg->result_type)))
+        {
+            const ActionsDAG::Node * type_node = 
&actions_dag.addColumn(ColumnWithTypeAndName(
+                result_type->createColumnConstWithDefaultValue(1), 
result_type, getUniqueName(result_type->getName())));
+
+            const auto & settings = 
plan_parser->getContext()->getSettingsRef();
+            auto function_name
+                = settings.has("arithmetic.decimal.mode") && 
settingsEqual(settings, "arithmetic.decimal.mode", "EFFECT")
+                ? "sparkDecimalMultiplyEffect"
+                : "sparkDecimalMultiply";
+
+            return toFunctionNode(actions_dag, function_name, {left_arg, 
right_arg, type_node});
+        }
+
+        return toFunctionNode(actions_dag, "multiply", {left_arg, right_arg});
+    }
 };
 
 class FunctionParserModulo final : public FunctionParserBinaryArithmetic
@@ -245,6 +302,33 @@ protected:
     {
         return DecimalType::evalModuloDecimalType(p1, s1, p2, s2);
     }
+
+    const DB::ActionsDAG::Node * createFunctionNode(
+        DB::ActionsDAG & actions_dag,
+        const String & func_name,
+        const DB::ActionsDAG::NodeRawConstPtrs & new_args,
+        DataTypePtr result_type) const override
+    {
+        assert(func_name == name);
+        const auto * left_arg = new_args[0];
+        const auto * right_arg = new_args[1];
+
+        if (isDecimal(removeNullable(left_arg->result_type)) || 
isDecimal(removeNullable(right_arg->result_type)))
+        {
+            const ActionsDAG::Node * type_node = 
&actions_dag.addColumn(ColumnWithTypeAndName(
+                result_type->createColumnConstWithDefaultValue(1), 
result_type, getUniqueName(result_type->getName())));
+
+            const auto & settings = 
plan_parser->getContext()->getSettingsRef();
+            auto function_name
+                = settings.has("arithmetic.decimal.mode") && 
settingsEqual(settings, "arithmetic.decimal.mode", "EFFECT")
+                ? "sparkDecimalModuloEffect"
+                : "sparkDecimalModulo";
+            return toFunctionNode(actions_dag, function_name, {left_arg, 
right_arg, type_node});
+        }
+
+        return toFunctionNode(actions_dag, "modulo", {left_arg, right_arg});
+    }
 };
 
 class FunctionParserDivide final : public FunctionParserBinaryArithmetic
@@ -262,14 +346,28 @@ protected:
     }
 
     const DB::ActionsDAG::Node * createFunctionNode(
-        DB::ActionsDAG & actions_dag, const String & func_name, const 
DB::ActionsDAG::NodeRawConstPtrs & new_args) const override
+        DB::ActionsDAG & actions_dag,
+        const String & func_name,
+        const DB::ActionsDAG::NodeRawConstPtrs & new_args,
+        DataTypePtr result_type) const override
     {
         assert(func_name == name);
         const auto * left_arg = new_args[0];
         const auto * right_arg = new_args[1];
 
         if (isDecimal(removeNullable(left_arg->result_type)) || 
isDecimal(removeNullable(right_arg->result_type)))
-            return toFunctionNode(actions_dag, "sparkDivideDecimal", 
{left_arg, right_arg});
+        {
+            const ActionsDAG::Node * type_node = 
&actions_dag.addColumn(ColumnWithTypeAndName(
+                result_type->createColumnConstWithDefaultValue(1), 
result_type, getUniqueName(result_type->getName())));
+
+            const auto & settings = 
plan_parser->getContext()->getSettingsRef();
+            auto function_name
+                = settings.has("arithmetic.decimal.mode") && 
settingsEqual(settings, "arithmetic.decimal.mode", "EFFECT")
+                ? "sparkDecimalDivideEffect"
+                : "sparkDecimalDivide";
+            return toFunctionNode(actions_dag, function_name, {left_arg, 
right_arg, type_node});
+        }
 
         return toFunctionNode(actions_dag, "sparkDivide", {left_arg, 
right_arg});
     }
diff --git 
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
 
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index 0aa6761584..d14a591277 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -446,6 +446,10 @@ object ExpressionConverter extends SQLConfHelper with 
Logging {
             LiteralTransformer(m.nullOnOverflow)),
           m
         )
+      case PromotePrecision(_ @Cast(child, _: DecimalType, _, _))
+          if child.dataType
+            .isInstanceOf[DecimalType] && 
!BackendsApiManager.getSettings.transformCheckOverflow =>
+        replaceWithExpressionTransformer0(child, attributeSeq, expressionsMap)
       case _: NormalizeNaNAndZero | _: PromotePrecision | _: TaggingExpression 
=>
         ChildTransformer(
           substraitExprName,
@@ -466,16 +470,12 @@ object ExpressionConverter extends SQLConfHelper with 
Logging {
           if !BackendsApiManager.getSettings.transformCheckOverflow &&
             DecimalArithmeticUtil.isDecimalArithmetic(b) =>
         DecimalArithmeticUtil.checkAllowDecimalArithmetic()
-        val leftChild =
+        val arithmeticExprName = getAndCheckSubstraitName(b, expressionsMap)
+        val left =
           replaceWithExpressionTransformer0(b.left, attributeSeq, 
expressionsMap)
-        val rightChild =
+        val right =
           replaceWithExpressionTransformer0(b.right, attributeSeq, 
expressionsMap)
-        DecimalArithmeticExpressionTransformer(
-          getAndCheckSubstraitName(b, expressionsMap),
-          leftChild,
-          rightChild,
-          decimalType,
-          b)
+        DecimalArithmeticExpressionTransformer(arithmeticExprName, left, 
right, decimalType, b)
       case c: CheckOverflow =>
         CheckOverflowTransformer(
           substraitExprName,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to