This is an automated email from the ASF dual-hosted git repository.

taiyangli pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 27647b6cec [GLUTEN-8704][CH] try to accelerate some spark* functions by 
optimizing tight loops (#8708)
27647b6cec is described below

commit 27647b6cec1152ebfcd507838ef52194ed55cb2a
Author: 李扬 <[email protected]>
AuthorDate: Tue Feb 18 14:56:35 2025 +0800

    [GLUTEN-8704][CH] try to accelerate some spark* functions by optimizing tight 
loops (#8708)
    
    * improve some loops
    
    * commit again
    
    * fix style
    
    * fix uts
    
    * add benchmark
    
    * add all benchmarks
    
    * vector
    
    * commit again
    
    * improve decimalcheckoverflow
    
    * commit again
    
    * optimize checkoverflow from int/float
    
    * commit again
    
    * fix bugs
    
    * fix failed ut
    
    * commit again
---
 .../Functions/SparkFunctionCastFloatToInt.cpp      |  32 +--
 .../Functions/SparkFunctionCastFloatToInt.h        |  91 ++++---
 .../SparkFunctionCheckDecimalOverflow.cpp          | 283 +++++++++++----------
 .../SparkFunctionDecimalBinaryArithmetic.h         |   1 -
 .../local-engine/Functions/SparkFunctionDivide.h   | 190 ++++++++------
 cpp-ch/local-engine/Parser/ExpressionParser.cpp    |   9 +-
 cpp-ch/local-engine/tests/CMakeLists.txt           |   2 +-
 ..._function.cpp => benchmark_spark_functions.cpp} | 229 ++++++++++++++++-
 .../utils/clickhouse/ClickHouseTestSettings.scala  |   1 +
 9 files changed, 550 insertions(+), 288 deletions(-)

diff --git a/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.cpp 
b/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.cpp
index c378f9fbf7..7cdc85cb17 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.cpp
@@ -26,31 +26,15 @@ using namespace DB;
 namespace local_engine
 {
 
-struct NameToUInt8 { static constexpr auto name = "sparkCastFloatToUInt8"; };
-struct NameToUInt16 { static constexpr auto name = "sparkCastFloatToUInt16"; };
-struct NameToUInt32 { static constexpr auto name = "sparkCastFloatToUInt32"; };
-struct NameToUInt64 { static constexpr auto name = "sparkCastFloatToUInt64"; };
-struct NameToUInt128 { static constexpr auto name = "sparkCastFloatToUInt128"; 
};
-struct NameToUInt256 { static constexpr auto name = "sparkCastFloatToUInt256"; 
};
 struct NameToInt8 { static constexpr auto name = "sparkCastFloatToInt8"; };
 struct NameToInt16 { static constexpr auto name = "sparkCastFloatToInt16"; };
 struct NameToInt32 { static constexpr auto name = "sparkCastFloatToInt32"; };
 struct NameToInt64 { static constexpr auto name = "sparkCastFloatToInt64"; };
-struct NameToInt128 { static constexpr auto name = "sparkCastFloatToInt128"; };
-struct NameToInt256 { static constexpr auto name = "sparkCastFloatToInt256"; };
 
-using SparkFunctionCastFloatToInt8 = 
local_engine::SparkFunctionCastFloatToInt<Int8, NameToInt8, INT8_MAX, INT8_MIN>;
-using SparkFunctionCastFloatToInt16 = 
local_engine::SparkFunctionCastFloatToInt<Int16, NameToInt16, INT16_MAX, 
INT16_MIN>;
-using SparkFunctionCastFloatToInt32 = 
local_engine::SparkFunctionCastFloatToInt<Int32, NameToInt32, INT32_MAX, 
INT32_MIN>;
-using SparkFunctionCastFloatToInt64 = 
local_engine::SparkFunctionCastFloatToInt<Int64, NameToInt64, INT64_MAX, 
INT64_MIN>;
-using SparkFunctionCastFloatToInt128 = 
local_engine::SparkFunctionCastFloatToInt<Int128, NameToInt128, 
std::numeric_limits<Int128>::max(), std::numeric_limits<Int128>::min()>;
-using SparkFunctionCastFloatToInt256 = 
local_engine::SparkFunctionCastFloatToInt<Int256, NameToInt256, 
std::numeric_limits<Int256>::max(), std::numeric_limits<Int256>::min()>;
-using SparkFunctionCastFloatToUInt8 = 
local_engine::SparkFunctionCastFloatToInt<UInt8, NameToUInt8, UINT8_MAX, 0>;
-using SparkFunctionCastFloatToUInt16 = 
local_engine::SparkFunctionCastFloatToInt<UInt16, NameToUInt16, UINT16_MAX, 0>;
-using SparkFunctionCastFloatToUInt32 = 
local_engine::SparkFunctionCastFloatToInt<UInt32, NameToUInt32, UINT32_MAX, 0>;
-using SparkFunctionCastFloatToUInt64 = 
local_engine::SparkFunctionCastFloatToInt<UInt64, NameToUInt64, UINT64_MAX, 0>;
-using SparkFunctionCastFloatToUInt128 = 
local_engine::SparkFunctionCastFloatToInt<UInt128, NameToUInt128, 
std::numeric_limits<UInt128>::max(), 0>;
-using SparkFunctionCastFloatToUInt256 = 
local_engine::SparkFunctionCastFloatToInt<UInt256, NameToUInt256, 
std::numeric_limits<UInt256>::max(), 0>;
+using SparkFunctionCastFloatToInt8 = 
local_engine::SparkFunctionCastFloatToInt<Int8, NameToInt8>;
+using SparkFunctionCastFloatToInt16 = 
local_engine::SparkFunctionCastFloatToInt<Int16, NameToInt16>;
+using SparkFunctionCastFloatToInt32 = 
local_engine::SparkFunctionCastFloatToInt<Int32, NameToInt32>;
+using SparkFunctionCastFloatToInt64 = 
local_engine::SparkFunctionCastFloatToInt<Int64, NameToInt64>;
 
 REGISTER_FUNCTION(SparkFunctionCastToInt)
 {
@@ -58,13 +42,5 @@ REGISTER_FUNCTION(SparkFunctionCastToInt)
     factory.registerFunction<SparkFunctionCastFloatToInt16>();
     factory.registerFunction<SparkFunctionCastFloatToInt32>();
     factory.registerFunction<SparkFunctionCastFloatToInt64>();
-    factory.registerFunction<SparkFunctionCastFloatToInt128>();
-    factory.registerFunction<SparkFunctionCastFloatToInt256>();
-    factory.registerFunction<SparkFunctionCastFloatToUInt8>();
-    factory.registerFunction<SparkFunctionCastFloatToUInt16>();
-    factory.registerFunction<SparkFunctionCastFloatToUInt32>();
-    factory.registerFunction<SparkFunctionCastFloatToUInt64>();
-    factory.registerFunction<SparkFunctionCastFloatToUInt128>();
-    factory.registerFunction<SparkFunctionCastFloatToUInt256>();
 }
 }
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.h 
b/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.h
index abe66d536d..00614798ce 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.h
+++ b/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.h
@@ -40,8 +40,7 @@ namespace ErrorCodes
 namespace local_engine
 {
 
-/// TODO(taiyang-li): remove int_max_value and int_min_value for it is 
determined by T
-template <is_integer T, typename Name, T int_max_value, T int_min_value>
+template <is_integer T, typename Name>
 class SparkFunctionCastFloatToInt : public DB::IFunction
 {
 public:
@@ -61,19 +60,16 @@ public:
         if (arguments.size() != 1)
             throw 
DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {}'s 
arguments number must be 1", name);
 
+        if (!isFloat(removeNullable(arguments[0])))
+            throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, 
"Function {}'s 1st argument must be float type", name);
+
         return makeNullable(std::make_shared<const DB::DataTypeNumber<T>>());
     }
 
-    DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, 
const DB::DataTypePtr & result_type, size_t) const override
+    DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, 
const DB::DataTypePtr & result_type, size_t input_rows_count) const override
     {
-        if (arguments.size() != 1)
-            throw 
DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {}'s 
arguments number must be 1", name);
-
-        if (!isFloat(removeNullable(arguments[0].type)))
-            throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, 
"Function {}'s 1st argument must be float type", name);
-
         DB::ColumnPtr src_col = arguments[0].column;
-        size_t size = src_col->size();
+        size_t size = input_rows_count;
 
         auto res_col = DB::ColumnVector<T>::create(size, 0);
         auto null_map_col = DB::ColumnUInt8::create(size, 0);
@@ -94,23 +90,60 @@ public:
         return DB::ColumnNullable::create(std::move(res_col), 
std::move(null_map_col));
     }
 
+    MULTITARGET_FUNCTION_AVX2_SSE42(
+        MULTITARGET_FUNCTION_HEADER(template <typename F> static void 
NO_SANITIZE_UNDEFINED NO_INLINE),
+        vectorImpl,
+        MULTITARGET_FUNCTION_BODY(
+            (F int_min,
+             F int_max,
+             const DB::PaddedPODArray<F> & src_data,
+             DB::PaddedPODArray<T> & data,
+             DB::PaddedPODArray<UInt8> & null_map_data,
+             size_t rows) /// NOLINT
+            {
+                for (size_t i = 0; i < rows; ++i)
+                {
+                    null_map_data[i] = !isFinite(src_data[i]);
+                    data[i] = static_cast<T>(std::fmax(int_min, 
std::fmin(int_max, src_data[i])));
+                }
+            }))
+
     template <typename F>
-    void executeInternal(const DB::ColumnPtr & src, DB::PaddedPODArray<T> & 
data, DB::PaddedPODArray<UInt8> & null_map_data) const
+    static void NO_INLINE vector(
+        F int_min,
+        F int_max,
+        const DB::PaddedPODArray<F> & src_data,
+        DB::PaddedPODArray<T> & data,
+        DB::PaddedPODArray<UInt8> & null_map_data,
+        size_t rows)
     {
-        const DB::ColumnVector<F> * src_vec = assert_cast<const 
DB::ColumnVector<F> *>(src.get());
-        /// TODO(taiyang-li): try to vectorize below loop
-        for (size_t i = 0; i < src_vec->size(); ++i)
+#if USE_MULTITARGET_CODE
+        if (isArchSupported(DB::TargetArch::AVX2))
         {
-            F element = src_vec->getElement(i);
-            if (isNaN(element) || !isFinite(element))
-                null_map_data[i] = 1;
-            else if (element > int_max_value)
-                data[i] = int_max_value;
-            else if (element < int_min_value)
-                data[i] = int_min_value;
-            else
-                data[i] = static_cast<T>(element);
+            vectorImplAVX2(int_min, int_max, src_data, data, null_map_data, 
rows);
+            return;
         }
+
+        if (isArchSupported(DB::TargetArch::SSE42))
+        {
+            vectorImplSSE42(int_min, int_max, src_data, data, null_map_data, 
rows);
+            return;
+        }
+#endif
+
+        vectorImpl(int_min, int_max, src_data, data, null_map_data, rows);
+    }
+
+    template <typename F>
+    void executeInternal(const DB::ColumnPtr & src, DB::PaddedPODArray<T> & 
data, DB::PaddedPODArray<UInt8> & null_map_data) const
+    {
+        const DB::ColumnVector<F> * src_vec = assert_cast<const 
DB::ColumnVector<F> *>(src.get());
+
+        size_t rows = src_vec->size();
+        const auto & src_data = src_vec->getData();
+        const auto int_min = static_cast<F>(std::numeric_limits<T>::min());
+        const auto int_max = static_cast<F>(std::numeric_limits<T>::max());
+        vector(int_min, int_max, src_data, data, null_map_data, rows);
     }
 
 #if USE_EMBEDDED_COMPILER
@@ -125,6 +158,7 @@ public:
         return true;
     }
 
+
     llvm::Value *
     compileImpl(llvm::IRBuilderBase & builder, const DB::ValuesWithType & 
arguments, const DB::DataTypePtr & result_type) const override
     {
@@ -140,16 +174,7 @@ public:
             b.CreateFCmpOEQ(src_value, 
llvm::ConstantFP::getInfinity(float_type, true)));
 
         bool is_signed = std::is_signed_v<T>;
-        llvm::Value * max_value = llvm::ConstantInt::get(int_type, 
static_cast<UInt64>(int_max_value), is_signed);
-        llvm::Value * min_value = llvm::ConstantInt::get(int_type, 
static_cast<UInt64>(int_min_value), is_signed);
-        llvm::Value * clamped_value = b.CreateSelect(
-            b.CreateFCmpOGT(src_value, llvm::ConstantFP::get(float_type, 
static_cast<Float64>(int_max_value))),
-            max_value,
-            b.CreateSelect(
-                b.CreateFCmpOLT(src_value, llvm::ConstantFP::get(float_type, 
static_cast<Float64>(int_min_value))),
-                min_value,
-                is_signed_v<T> ? b.CreateFPToSI(src_value, int_type) : 
b.CreateFPToUI(src_value, int_type)));
-        llvm::Value * result_value = b.CreateSelect(b.CreateOr(is_nan, 
is_inf), llvm::Constant::getNullValue(int_type), clamped_value);
+        llvm::Value * result_value = is_signed_v<T> ? 
b.CreateFPToSI(src_value, int_type) : b.CreateFPToUI(src_value, int_type);
         llvm::Value * result_is_null = b.CreateOr(is_nan, is_inf);
 
         auto * nullable_structure_type = toNativeType(b, result_type);
diff --git 
a/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.cpp 
b/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.cpp
index 95853b187a..3b870b22d7 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.cpp
@@ -16,7 +16,6 @@
  */
 #include "SparkFunctionCheckDecimalOverflow.h"
 
-#include <typeinfo>
 #include <Columns/ColumnDecimal.h>
 #include <Columns/ColumnNullable.h>
 #include <Columns/ColumnsNumber.h>
@@ -27,6 +26,8 @@
 #include <Functions/FunctionFactory.h>
 #include <Functions/FunctionHelpers.h>
 #include <Functions/IFunction.h>
+#include "Columns/ColumnsCommon.h"
+#include <iostream>
 
 namespace DB
 {
@@ -45,6 +46,8 @@ namespace local_engine
 {
 using namespace DB;
 
+namespace
+{
 struct CheckDecimalOverflowSpark
 {
     static constexpr auto name = "checkDecimalOverflowSpark";
@@ -54,14 +57,19 @@ struct CheckDecimalOverflowSparkOrNull
     static constexpr auto name = "checkDecimalOverflowSparkOrNull";
 };
 
-enum class CheckExceptionMode
+enum class CheckExceptionMode: uint8_t
 {
     Throw, /// Throw exception if value cannot be parsed.
     Null /// Return ColumnNullable with NULLs when value cannot be parsed.
 };
 
-namespace
+enum class ScaleDirection: int8_t
 {
+    Up = 1,
+    Down = -1,
+    None = 0
+};
+
 /// Returns received decimal value if and Decimal value has less digits then 
it's Precision allow, 0 otherwise.
 /// Precision could be set as second argument or omitted. If omitted function 
uses Decimal precision of the first argument.
 template <typename Name, CheckExceptionMode mode>
@@ -102,193 +110,210 @@ public:
 
     ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const 
DataTypePtr &, size_t input_rows_count) const override
     {
-        const auto & src_column = arguments[0];
-        UInt32 precision = extractArgument(arguments[1]);
-        UInt32 scale = extractArgument(arguments[2]);
+        UInt32 to_precision = extractArgument(arguments[1]);
+        UInt32 to_scale = extractArgument(arguments[2]);
 
-        ColumnPtr result_column;
+        const auto & src_col = arguments[0];
+        ColumnPtr dst_col;
 
         auto call = [&](const auto & types) -> bool
         {
             using Types = std::decay_t<decltype(types)>;
             using FromDataType = typename Types::LeftType;
             using ToDataType = typename Types::RightType;
+
             if constexpr (IsDataTypeDecimal<FromDataType> || 
IsDataTypeNumber<FromDataType>)
             {
                 using FromFieldType = typename FromDataType::FieldType;
-                if (const ColumnVectorOrDecimal<FromFieldType> * col_vec = 
checkAndGetColumn<ColumnVectorOrDecimal<FromFieldType>>(src_column.column.get()))
+
+                /// Fast path
+                if constexpr (IsDataTypeDecimal<FromDataType>)
+                {
+                    auto from_precision = getDecimalPrecision(*src_col.type);
+                    auto from_scale = getDecimalScale(*src_col.type);
+                    if (from_precision == to_precision && from_scale == 
to_scale)
+                    {
+                        dst_col = src_col.column;
+                        return true;
+                    }
+                }
+
+                if (const ColumnVectorOrDecimal<FromFieldType> * col_vec = 
checkAndGetColumn<ColumnVectorOrDecimal<FromFieldType>>(src_col.column.get()))
                 {
-                    executeInternal<FromDataType, ToDataType>(*col_vec, 
result_column, input_rows_count, precision, scale);
+                    executeInternal<FromDataType, ToDataType>(*col_vec, 
dst_col, input_rows_count, to_precision, to_scale);
                     return true;
                 }
             }
+
             throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal 
column while execute function {}", getName());
         };
 
-        if (precision <= DecimalUtils::max_precision<Decimal32>)
-            
callOnIndexAndDataType<DataTypeDecimal<Decimal32>>(src_column.type->getTypeId(),
 call);
-        else if (precision <= DecimalUtils::max_precision<Decimal64>)
-            
callOnIndexAndDataType<DataTypeDecimal<Decimal64>>(src_column.type->getTypeId(),
 call);
-        else if (precision <= DecimalUtils::max_precision<Decimal128>)
-            
callOnIndexAndDataType<DataTypeDecimal<Decimal128>>(src_column.type->getTypeId(),
 call);
+        if (to_precision <= DecimalUtils::max_precision<Decimal32>)
+            
callOnIndexAndDataType<DataTypeDecimal<Decimal32>>(src_col.type->getTypeId(), 
call);
+        else if (to_precision <= DecimalUtils::max_precision<Decimal64>)
+            
callOnIndexAndDataType<DataTypeDecimal<Decimal64>>(src_col.type->getTypeId(), 
call);
+        else if (to_precision <= DecimalUtils::max_precision<Decimal128>)
+            
callOnIndexAndDataType<DataTypeDecimal<Decimal128>>(src_col.type->getTypeId(), 
call);
         else
-            
callOnIndexAndDataType<DataTypeDecimal<Decimal256>>(src_column.type->getTypeId(),
 call);
-
+            
callOnIndexAndDataType<DataTypeDecimal<Decimal256>>(src_col.type->getTypeId(), 
call);
 
-        if (!result_column)
-            throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Wrong call for {} 
with {}", getName(), src_column.type->getName());
+        if (!dst_col)
+            throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Wrong call for {} 
with {}", getName(), src_col.type->getName());
 
-        return result_column;
+        return dst_col;
     }
 
 private:
     template <typename FromDataType, typename ToDataType>
     requires(IsDataTypeDecimal<ToDataType> && (IsDataTypeDecimal<FromDataType> 
|| IsDataTypeNumber<FromDataType>))
-    static void executeInternal(
-        const ColumnVectorOrDecimal<typename FromDataType::FieldType> & 
col_source, ColumnPtr & result_column, size_t input_rows_count, UInt32 
precision, UInt32 scale_to)
+    static void
+    executeInternal(const FromDataType::ColumnType & src_col, ColumnPtr & 
dst_col, size_t rows, UInt32 to_precision, UInt32 to_scale)
     {
         using ToFieldType = typename ToDataType::FieldType;
+        using ToNativeType = typename ToFieldType::NativeType;
         using ToColumnType = typename ToDataType::ColumnType;
-        using T = typename FromDataType::FieldType;
-
-        ColumnUInt8::MutablePtr col_null_map_to;
-        ColumnUInt8::Container * vec_null_map_to [[maybe_unused]] = nullptr;
-        UInt32 scale_from = 0;
-        using ToFieldNativeType = typename ToFieldType::NativeType;
-        ToFieldNativeType decimal_int_part_max = 0;
-        ToFieldNativeType decimal_int_part_min = 0;
-        if constexpr (IsDataTypeDecimal<FromDataType>)
-            scale_from = col_source.getScale();
-        else
-        {
-            decimal_int_part_max = 
DecimalUtils::scaleMultiplier<ToFieldNativeType>(precision - scale_to) - 1;
-            decimal_int_part_min = 1 - 
DecimalUtils::scaleMultiplier<ToFieldNativeType>(precision - scale_to);
-        }
-        if constexpr (exception_mode == CheckExceptionMode::Null)
-        {
-            col_null_map_to = ColumnUInt8::create(input_rows_count, false);
-            vec_null_map_to = &col_null_map_to->getData();
-        }
+        using FromFieldType = typename FromDataType::FieldType;
 
-        typename ToColumnType::MutablePtr col_to = 
ToColumnType::create(input_rows_count, scale_to);
-        auto & vec_to = col_to->getData();
-        vec_to.resize_exact(input_rows_count);
+        using MaxFieldType = std::conditional_t<
+            is_decimal<FromFieldType>,
+            std::conditional_t<(sizeof(FromFieldType) > sizeof(ToFieldType)), 
FromFieldType, ToFieldType>,
+            ToFieldType>;
+        using MaxNativeType = typename MaxFieldType::NativeType;
 
-        auto & datas = col_source.getData();
-        for (size_t i = 0; i < input_rows_count; ++i)
+        /// Calculate const parameters for decimal conversion outside the loop 
to avoid unnecessary calculations.
+        ScaleDirection scale_direction;
+        UInt32 from_scale = 0;
+        MaxNativeType scale_multiplier = 0;
+        MaxNativeType pow10_to_precision = 
DecimalUtils::scaleMultiplier<MaxNativeType>(to_precision);
+        if constexpr (IsDataTypeDecimal<FromDataType>)
         {
-            ToFieldType result;
-            bool success = convertToDecimalImpl<FromDataType, 
ToDataType>(datas[i], precision, scale_from, scale_to, decimal_int_part_max, 
decimal_int_part_min, result);
-            if constexpr (exception_mode == CheckExceptionMode::Null)
+            from_scale = src_col.getScale();
+            if (to_scale > from_scale)
             {
-                vec_to[i] = static_cast<ToFieldType>(result);
-                (*vec_null_map_to)[i] = !success;
+                scale_direction = ScaleDirection::Up;
+                scale_multiplier = 
DecimalUtils::scaleMultiplier<MaxNativeType>(to_scale - from_scale);
+            }
+            else if (to_scale < from_scale)
+            {
+                scale_direction = ScaleDirection::Down;
+                scale_multiplier = 
DecimalUtils::scaleMultiplier<MaxNativeType>(from_scale - to_scale);
             }
             else
             {
-                if (success)
-                    vec_to[i] = static_cast<ToFieldType>(result);
-                else
-                    throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Decimal 
value is overflow.");
+                scale_direction = ScaleDirection::None;
+                scale_multiplier = 1;
             }
         }
+        else
+        {
+            scale_multiplier = 
DecimalUtils::scaleMultiplier<MaxNativeType>(to_scale);
+        }
 
-        if constexpr (exception_mode == CheckExceptionMode::Null)
-            result_column = ColumnNullable::create(std::move(col_to), 
std::move(col_null_map_to));
+        auto & src_data = src_col.getData();
+
+        auto res_data_col = ToColumnType::create(rows, to_scale);
+        auto & res_data = res_data_col->getData();
+        auto res_nullmap_col = ColumnUInt8::create(rows, 0);
+        auto & res_nullmap_data = res_nullmap_col->getData();
+
+        if constexpr (IsDataTypeDecimal<FromDataType>)
+        {
+            if (scale_direction == ScaleDirection::Up)
+                for (size_t i = 0; i < rows; ++i)
+                    res_nullmap_data[i] = 
!convertDecimalToDecimalImpl<ScaleDirection::Up, FromDataType, ToDataType, 
MaxNativeType>(
+                        src_data[i], scale_multiplier, pow10_to_precision, 
res_data[i]);
+            else if (scale_direction == ScaleDirection::Down)
+                for (size_t i = 0; i < rows; ++i)
+                    res_nullmap_data[i] = 
!convertDecimalToDecimalImpl<ScaleDirection::Down, FromDataType, ToDataType, 
MaxNativeType>(
+                        src_data[i], scale_multiplier, pow10_to_precision, 
res_data[i]);
+            else
+                for (size_t i = 0; i < rows; ++i)
+                    res_nullmap_data[i] = 
!convertDecimalToDecimalImpl<ScaleDirection::None, FromDataType, ToDataType, 
MaxNativeType>(
+                        src_data[i], scale_multiplier, pow10_to_precision, 
res_data[i]);
+        }
         else
-            result_column = std::move(col_to);
-    }
+        {
+            for (size_t i = 0; i < rows; ++i)
+                res_nullmap_data[i]
+                    = !convertNumberToDecimalImpl<FromDataType, 
ToDataType>(src_data[i], scale_multiplier, pow10_to_precision, res_data[i]);
+        }
 
-    template <typename FromDataType, typename ToDataType>
-    requires(IsDataTypeDecimal<ToDataType>)
-    static bool convertToDecimalImpl(
-        const FromDataType::FieldType & value,
-        UInt32 precision_to,
-        UInt32 scale_from,
-        UInt32 scale_to,
-        typename ToDataType::FieldType::NativeType decimal_int_part_max,
-        typename ToDataType::FieldType::NativeType decimal_int_part_min,
-        typename ToDataType::FieldType & result)
-    {
-        using FromFieldType = typename FromDataType::FieldType;
-        if constexpr (std::is_same_v<FromFieldType, Decimal32>)
-            return convertDecimalsImpl<DataTypeDecimal<Decimal32>, 
ToDataType>(value, precision_to, scale_from, scale_to, result);
-        else if constexpr (std::is_same_v<FromFieldType, Decimal64>)
-            return convertDecimalsImpl<DataTypeDecimal<Decimal64>, 
ToDataType>(value, precision_to, scale_from, scale_to, result);
-        else if constexpr (std::is_same_v<FromFieldType, Decimal128>)
-            return convertDecimalsImpl<DataTypeDecimal<Decimal128>, 
ToDataType>(value, precision_to, scale_from, scale_to, result);
-        else if constexpr (std::is_same_v<FromFieldType, Decimal256>)
-            return convertDecimalsImpl<DataTypeDecimal<Decimal256>, 
ToDataType>(value, precision_to, scale_from, scale_to, result);
-        else if constexpr (IsDataTypeNumber<FromDataType> && 
!std::is_same_v<FromFieldType, BFloat16>)
-            return convertNumberToDecimalImpl<DataTypeNumber<FromFieldType>, 
ToDataType>(value, scale_to, decimal_int_part_max, decimal_int_part_min, 
result);
+        if constexpr (exception_mode == CheckExceptionMode::Throw)
+        {
+            if (!memoryIsZero(res_nullmap_data.data(), 0, rows))
+                throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Decimal value 
is overflow.");
+
+            dst_col = std::move(res_data_col);
+        }
         else
-            throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Convert from {} type 
to decimal type is not implemented.", typeid(value).name());
+            dst_col = ColumnNullable::create(std::move(res_data_col), 
std::move(res_nullmap_col));
     }
 
     template <typename FromDataType, typename ToDataType>
     requires(IsDataTypeNumber<FromDataType> && IsDataTypeDecimal<ToDataType>)
-    static inline bool convertNumberToDecimalImpl(
-        const typename FromDataType::FieldType & value,
-        UInt32 scale,
-        typename ToDataType::FieldType::NativeType decimal_int_part_max,
-        typename ToDataType::FieldType::NativeType decimal_int_part_min,
-        typename ToDataType::FieldType & result)
+    static ALWAYS_INLINE bool convertNumberToDecimalImpl(
+        const typename FromDataType::FieldType & from,
+        const typename ToDataType::FieldType::NativeType & scale_multiplier,
+        const typename ToDataType::FieldType::NativeType & pow10_to_precision,
+        typename ToDataType::FieldType & to)
     {
         using FromFieldType = typename FromDataType::FieldType;
-        using ToFieldNativeType = typename ToDataType::FieldType::NativeType;
-        ToFieldNativeType int_part = 0;
-        if constexpr (std::is_same_v<FromFieldType, Float32> || 
std::is_same_v<FromFieldType, Float64>)
-            int_part = static_cast<ToFieldNativeType>(value);
+        using ToNativeType = typename ToDataType::FieldType::NativeType;
+
+        bool ok = false;
+        if constexpr (std::is_floating_point_v<FromFieldType>)
+        {
+            /// float to decimal
+            auto converted = from * 
static_cast<FromFieldType>(scale_multiplier);
+            auto float_pow10_to_precision = 
static_cast<FromFieldType>(pow10_to_precision);
+            ok = isFinite(from) && converted < float_pow10_to_precision && 
converted > -float_pow10_to_precision;
+            to = ok ? static_cast<ToNativeType>(converted) : 
static_cast<ToNativeType>(0);
+        }
         else
-            int_part = value;
+        {
+            /// signed integer to decimal
+            using MaxNativeType = std::conditional_t<(sizeof(FromFieldType) > 
sizeof(ToNativeType)), FromFieldType, ToNativeType>;
 
-        return int_part >= decimal_int_part_min && int_part <= 
decimal_int_part_max && tryConvertToDecimal<FromDataType, ToDataType>(value, 
scale, result);
+            MaxNativeType converted = 0;
+            ok = !common::mulOverflow(static_cast<MaxNativeType>(from), 
static_cast<MaxNativeType>(scale_multiplier), converted) && converted < 
pow10_to_precision
+                && converted > -pow10_to_precision;
+            to = ok ? static_cast<ToNativeType>(converted) : 
static_cast<ToNativeType>(0);
+        }
+        return ok;
     }
 
-    template <typename FromDataType, typename ToDataType>
+    template <ScaleDirection scale_direction, typename FromDataType, typename 
ToDataType, typename MaxNativeType>
     requires(IsDataTypeDecimal<FromDataType> && IsDataTypeDecimal<ToDataType>)
-    static bool convertDecimalsImpl(
-        const typename FromDataType::FieldType & value,
-        UInt32 precision_to,
-        UInt32 scale_from,
-        UInt32 scale_to,
-        typename ToDataType::FieldType & result)
+    static ALWAYS_INLINE bool convertDecimalToDecimalImpl(
+        const typename FromDataType::FieldType & from,
+        const MaxNativeType & scale_multiplier,
+        const MaxNativeType & pow10_to_precision,
+        typename ToDataType::FieldType & to)
     {
         using FromFieldType = typename FromDataType::FieldType;
         using ToFieldType = typename ToDataType::FieldType;
-        using MaxFieldType = std::conditional_t<(sizeof(FromFieldType) > 
sizeof(ToFieldType)), FromFieldType, ToFieldType>;
-        using MaxNativeType = typename MaxFieldType::NativeType;
+        using ToNativeType = typename ToFieldType::NativeType;
 
-
-        auto false_value = []() -> bool
+        MaxNativeType converted;
+        bool ok = false;
+        if constexpr (scale_direction == ScaleDirection::Up)
         {
-            if constexpr (exception_mode == CheckExceptionMode::Null)
-                return false;
-            else
-                throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Decimal value 
is overflow.");
-        };
-
-        MaxNativeType converted_value;
-        if (scale_to > scale_from)
+            ok = !common::mulOverflow(static_cast<MaxNativeType>(from.value), 
scale_multiplier, converted)
+                && converted < pow10_to_precision && converted > 
-pow10_to_precision;
+        }
+        else if constexpr (scale_direction == ScaleDirection::None)
         {
-            converted_value = 
DecimalUtils::scaleMultiplier<MaxNativeType>(scale_to - scale_from);
-            if (common::mulOverflow(static_cast<MaxNativeType>(value.value), 
converted_value, converted_value))
-                return false_value();
+            converted = from.value;
+            ok = converted < pow10_to_precision && converted > 
-pow10_to_precision;
         }
-        else if (scale_to == scale_from)
-            converted_value = value.value;
         else
-            converted_value = value.value / 
DecimalUtils::scaleMultiplier<MaxNativeType>(scale_from - scale_to);
-
-        // if constexpr (sizeof(FromFieldType) > sizeof(ToFieldType))
-        // {
-        MaxNativeType pow10 = intExp10OfSize<MaxNativeType>(precision_to);
-        if (converted_value <= -pow10 || converted_value >= pow10)
-            return false_value();
-        // }
+        {
+            converted = from.value / scale_multiplier;
+            ok = converted < pow10_to_precision && converted > 
-pow10_to_precision;
+        }
 
-        result = static_cast<typename 
ToFieldType::NativeType>(converted_value);
-        return true;
+        to = ok ? static_cast<ToNativeType>(converted) : 
static_cast<ToNativeType>(0);
+        return ok;
     }
 };
 
diff --git 
a/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.h 
b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.h
index beef342e08..eb2c4df4f4 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.h
+++ b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.h
@@ -422,7 +422,6 @@ private:
     }
 };
 
-/// TODO(taiyang-li): implement JIT for binary deicmal arithmetic functions
 template <class Operation, typename Name, OpMode mode = OpMode::Default>
 class SparkFunctionDecimalBinaryArithmetic final : public IFunction
 {
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionDivide.h 
b/cpp-ch/local-engine/Functions/SparkFunctionDivide.h
index ccd9eeb61b..f924d559a5 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionDivide.h
+++ b/cpp-ch/local-engine/Functions/SparkFunctionDivide.h
@@ -81,91 +81,117 @@ public:
     DB::ColumnPtr
     executeImpl(const DB::ColumnsWithTypeAndName & arguments, const 
DB::DataTypePtr & result_type, size_t input_rows_count) const override
     {
-        if (arguments.size() != 2)
-            throw 
DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {}'s 
arguments number must be 2", name);
+        using L = Float64;
+        using R = Float64;
+        using T = Float64;
+
+        const DB::ColumnVector<L> * col_left = nullptr;
+        const DB::ColumnVector<R> * col_right = nullptr;
+        const DB::ColumnVector<L> * const_col_left = 
checkAndGetColumnConstData<DB::ColumnVector<L>>(arguments[0].column.get());
+        const DB::ColumnVector<R> * const_col_right = 
checkAndGetColumnConstData<DB::ColumnVector<R>>(arguments[1].column.get());
+
+        L left_const_val = 0;
+        if (const_col_left)
+            left_const_val = const_col_left->getElement(0);
+        else
+            col_left = assert_cast<const DB::ColumnVector<L> 
*>(arguments[0].column.get());
+
+        R right_const_val = 0;
+        if (const_col_right)
+        {
+            right_const_val = const_col_right->getElement(0);
+            if (right_const_val == 0)
+            {
+                auto data_col = DB::ColumnVector<T>::create(1, 0);
+                auto null_map_col = DB::ColumnVector<UInt8>::create(1, 1);
+                return 
DB::ColumnConst::create(DB::ColumnNullable::create(std::move(data_col), 
std::move(null_map_col)), input_rows_count);
+            }
+        }
+        else
+            col_right = assert_cast<const DB::ColumnVector<R> 
*>(arguments[1].column.get());
+
+        auto res_col = DB::ColumnVector<T>::create(input_rows_count, 0);
+        auto res_null_map = DB::ColumnVector<UInt8>::create(input_rows_count, 
0);
+        DB::PaddedPODArray<T> & res_data = res_col->getData();
+        DB::PaddedPODArray<UInt8> & res_null_map_data = 
res_null_map->getData();
+        vector(col_left, col_right, left_const_val, right_const_val, res_data, 
res_null_map_data, input_rows_count);
+        return DB::ColumnNullable::create(std::move(res_col), 
std::move(res_null_map));
+    }
 
-        if (!isNativeNumber(arguments[0].type) || 
!isNativeNumber(arguments[1].type))
-            throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, 
"Function {}'s arguments type must be native number", name);
-
-        using Types = TypeList<
-            DB::DataTypeFloat32,
-            DB::DataTypeFloat64,
-            DB::DataTypeUInt8,
-            DB::DataTypeUInt16,
-            DB::DataTypeUInt32,
-            DB::DataTypeUInt64,
-            DB::DataTypeInt8,
-            DB::DataTypeInt16,
-            DB::DataTypeInt32,
-            DB::DataTypeInt64>;
-
-        DB::ColumnPtr result = nullptr;
-        bool valid = castTypeToEither(
-            Types{},
-            arguments[0].type.get(),
-            [&](const auto & left_)
+    MULTITARGET_FUNCTION_AVX2_SSE42(
+        MULTITARGET_FUNCTION_HEADER(static void NO_SANITIZE_UNDEFINED 
NO_INLINE),
+        vectorImpl,
+        MULTITARGET_FUNCTION_BODY(
+            (const DB::ColumnVector<Float64> * col_left,
+             const DB::ColumnVector<Float64> * col_right,
+             Float64 left_const_val,
+             Float64 right_const_val,
+             DB::PaddedPODArray<Float64> & res_data,
+             DB::PaddedPODArray<UInt8> & res_null_map_data,
+             size_t input_rows_count) /// NOLINT
             {
-                return castTypeToEither(
-                    Types{},
-                    arguments[1].type.get(),
-                    [&](const auto & right_)
+                if (col_left && col_right)
+                {
+                    const auto & ldata = col_left->getData();
+                    const auto & rdata = col_right->getData();
+
+                    for (size_t i = 0; i < input_rows_count; ++i)
+                    {
+                        auto l = ldata[i];
+                        auto r = rdata[i];
+                        res_data[i] = SparkDivideFloatingImpl<Float64, 
Float64>::apply(l, r ? r : 1);
+                        res_null_map_data[i] = !rdata[i];
+                    }
+                }
+                else if (col_left)
+                {
+                    Float64 r = right_const_val;
+                    for (size_t i = 0; i < input_rows_count; ++i)
                     {
-                        using L = typename 
std::decay_t<decltype(left_)>::FieldType;
-                        using R = typename 
std::decay_t<decltype(right_)>::FieldType;
-                        using T = typename 
DB::NumberTraits::ResultOfFloatingPointDivision<L, R>::Type;
-
-                        const DB::ColumnVector<L> * col_left = nullptr;
-                        const DB::ColumnVector<R> * col_right = nullptr;
-                        const DB::ColumnVector<L> * const_col_left = 
checkAndGetColumnConstData<DB::ColumnVector<L>>(arguments[0].column.get());
-                        const DB::ColumnVector<R> * const_col_right
-                            = 
checkAndGetColumnConstData<DB::ColumnVector<R>>(arguments[1].column.get());
-
-                        L left_const_val = 0;
-                        if (const_col_left)
-                            left_const_val = const_col_left->getElement(0);
-                        else
-                            col_left = assert_cast<const DB::ColumnVector<L> 
*>(arguments[0].column.get());
-
-                        R right_const_val = 0;
-                        if (const_col_right)
-                        {
-                            right_const_val = const_col_right->getElement(0);
-                            if (right_const_val == 0)
-                            {
-                                /// TODO(taiyang-li): return const column 
instead
-                                auto data_col = 
DB::ColumnVector<T>::create(arguments[0].column->size(), 0);
-                                auto null_map_col = 
DB::ColumnVector<UInt8>::create(arguments[0].column->size(), 1);
-                                result = 
DB::ColumnNullable::create(std::move(data_col), std::move(null_map_col));
-                                return true;
-                            }
-                        }
-                        else
-                            col_right = assert_cast<const DB::ColumnVector<R> 
*>(arguments[1].column.get());
-
-                        auto res_values = 
DB::ColumnVector<T>::create(input_rows_count, 0);
-                        auto res_null_map = 
DB::ColumnVector<UInt8>::create(input_rows_count, 0);
-                        DB::PaddedPODArray<T> & res_data = 
res_values->getData();
-                        DB::PaddedPODArray<UInt8> & res_null_map_data = 
res_null_map->getData();
-                        for (size_t i = 0; i < input_rows_count; ++i)
-                        {
-                            L l = col_left ? col_left->getElement(i) : 
left_const_val;
-                            R r = col_right ? col_right->getElement(i) : 
right_const_val;
-
-                            /// TODO(taiyang-li): try to vectorize it
-                            if (r == 0)
-                                res_null_map_data[i] = 1;
-                            else
-                                res_data[i] = SparkDivideFloatingImpl<L, 
R>::apply(l, r);
-                        }
-
-                        result = 
DB::ColumnNullable::create(std::move(res_values), std::move(res_null_map));
-                        return true;
-                    });
-            });
-
-        if (!valid)
-            throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, 
"Function {}'s arguments type is not valid", name);
-        return result;
+                        Float64 l = col_left->getData()[i];
+
+                        /// r cannot be zero here: a constant zero divisor is handled by the early return in executeImpl.
+                        /// res_null_map_data[i] needs no assignment: the null map is created zero-filled in executeImpl.
+                        // res_null_map_data[i] = 0;
+                        res_data[i] = SparkDivideFloatingImpl<Float64, 
Float64>::apply(l, r);
+                    }
+                }
+                else if (col_right)
+                {
+                    Float64 l = left_const_val;
+                    for (size_t i = 0; i < input_rows_count; ++i)
+                    {
+                        Float64 r = col_right->getData()[i];
+                        res_null_map_data[i] = !r;
+                        res_data[i] = SparkDivideFloatingImpl<Float64, 
Float64>::apply(l, r ? r : 1);
+                    }
+                }
+            }))
+
+    static void NO_INLINE vector(
+        const DB::ColumnVector<Float64> * col_left,
+        const DB::ColumnVector<Float64> * col_right,
+        Float64 left_const_val,
+        Float64 right_const_val,
+        DB::PaddedPODArray<Float64> & res_data,
+        DB::PaddedPODArray<UInt8> & res_null_map_data,
+        size_t input_rows_count)
+    {
+#if USE_MULTITARGET_CODE
+        if (isArchSupported(DB::TargetArch::AVX2))
+        {
+            vectorImplAVX2(col_left, col_right, left_const_val, 
right_const_val, res_data, res_null_map_data, input_rows_count);
+            return;
+        }
+
+        if (isArchSupported(DB::TargetArch::SSE42))
+        {
+            vectorImplSSE42(col_left, col_right, left_const_val, 
right_const_val, res_data, res_null_map_data, input_rows_count);
+            return;
+        }
+#endif
+
+        vectorImpl(col_left, col_right, left_const_val, right_const_val, 
res_data, res_null_map_data, input_rows_count);
     }
 
 #if USE_EMBEDDED_COMPILER
diff --git a/cpp-ch/local-engine/Parser/ExpressionParser.cpp 
b/cpp-ch/local-engine/Parser/ExpressionParser.cpp
index 5c6cfd0316..99bc5b3ff4 100644
--- a/cpp-ch/local-engine/Parser/ExpressionParser.cpp
+++ b/cpp-ch/local-engine/Parser/ExpressionParser.cpp
@@ -337,11 +337,12 @@ ExpressionParser::NodeRawConstPtr 
ExpressionParser::parseExpression(ActionsDAG &
             }
             else if ((isDecimal(denull_input_type) || 
isNativeNumber(denull_input_type)) && substrait_type.has_decimal())
             {
-                int decimal_precision = substrait_type.decimal().precision();
-                if (decimal_precision)
+                int precision = substrait_type.decimal().precision();
+                int scale = substrait_type.decimal().scale();
+                if (precision)
                 {
-                    args.emplace_back(addConstColumn(actions_dag, 
std::make_shared<DataTypeInt32>(), decimal_precision));
-                    args.emplace_back(addConstColumn(actions_dag, 
std::make_shared<DataTypeInt32>(), substrait_type.decimal().scale()));
+                    args.emplace_back(addConstColumn(actions_dag, 
std::make_shared<DataTypeInt32>(), precision));
+                    args.emplace_back(addConstColumn(actions_dag, 
std::make_shared<DataTypeInt32>(), scale));
                     result_node = toFunctionNode(actions_dag, 
"checkDecimalOverflowSparkOrNull", args);
                 }
             }
diff --git a/cpp-ch/local-engine/tests/CMakeLists.txt 
b/cpp-ch/local-engine/tests/CMakeLists.txt
index 4617714b5e..56bf07aa06 100644
--- a/cpp-ch/local-engine/tests/CMakeLists.txt
+++ b/cpp-ch/local-engine/tests/CMakeLists.txt
@@ -97,7 +97,7 @@ if(ENABLE_BENCHMARKS)
     benchmark_parquet_read.cpp
     benchmark_spark_row.cpp
     benchmark_unix_timestamp_function.cpp
-    benchmark_spark_floor_function.cpp
+    benchmark_spark_functions.cpp
     benchmark_cast_float_function.cpp
     benchmark_to_datetime_function.cpp
     benchmark_spark_divide_function.cpp
diff --git a/cpp-ch/local-engine/tests/benchmark_spark_floor_function.cpp 
b/cpp-ch/local-engine/tests/benchmark_spark_functions.cpp
similarity index 82%
rename from cpp-ch/local-engine/tests/benchmark_spark_floor_function.cpp
rename to cpp-ch/local-engine/tests/benchmark_spark_functions.cpp
index 95da98d322..ce6f1b42ca 100644
--- a/cpp-ch/local-engine/tests/benchmark_spark_floor_function.cpp
+++ b/cpp-ch/local-engine/tests/benchmark_spark_functions.cpp
@@ -18,10 +18,7 @@
 #if defined(__x86_64__)
 
 #include <cstddef>
-#if USE_MULTITARGET_CODE
-#include <immintrin.h>
-#endif
-
+#include <Columns/ColumnsCommon.h>
 #include <Columns/IColumn.h>
 #include <Core/Block.h>
 #include <DataTypes/DataTypeArray.h>
@@ -35,7 +32,11 @@
 #include <benchmark/benchmark.h>
 #include <Common/QueryContext.h>
 #include <Common/TargetSpecific.h>
-#include <Columns/ColumnsCommon.h>
+#include <DataTypes/DataTypeNullable.h>
+
+#if USE_MULTITARGET_CODE
+#include <immintrin.h>
+#endif
 
 using namespace DB;
 
@@ -44,7 +45,7 @@ static IColumn::Offsets createOffsets(size_t rows)
     IColumn::Offsets offsets(rows, 0);
     for (size_t i = 0; i < rows; ++i)
         offsets[i] = offsets[i-1] + (rand() % 10);
-    return std::move(offsets);
+    return offsets;
 }
 
 static ColumnPtr createColumn(const DataTypePtr & type, size_t rows)
@@ -79,6 +80,11 @@ static ColumnPtr createColumn(const DataTypePtr & type, 
size_t rows)
             double d = i * 1.0;
             column->insert(d);
         }
+        else if (isDecimal(type_not_nullable))
+        {
+            Decimal128 d = Decimal128(i * i);
+            column->insert(d);
+        }
         else if (isString(type_not_nullable))
         {
             String s = "helloworld";
@@ -158,6 +164,212 @@ static void 
BM_SparkFloorFunction_For_Float64(benchmark::State & state)
     }
 }
 
+BENCHMARK(BM_CHFloorFunction_For_Int64);
+BENCHMARK(BM_CHFloorFunction_For_Float64);
+BENCHMARK(BM_SparkFloorFunction_For_Int64);
+BENCHMARK(BM_SparkFloorFunction_For_Float64);
+
+static void BM_OptSparkDivide_VectorVector(benchmark::State & state)
+{
+    using namespace DB;
+    auto & factory = FunctionFactory::instance();
+    auto function = factory.get("sparkDivide", 
local_engine::QueryContext::globalContext());
+    auto type = DataTypeFactory::instance().get("Nullable(Float64)");
+    auto left = createColumn(type, 65536);
+    auto right = createColumn(type, 65536);
+    auto block = Block({ColumnWithTypeAndName(left, type, "left"), 
ColumnWithTypeAndName(right, type, "right")});
+    auto executable = function->build(block.getColumnsWithTypeAndName());
+    for (auto _ : state)
+    {
+        auto result = executable->execute(block.getColumnsWithTypeAndName(), 
executable->getResultType(), block.rows(), false);
+        benchmark::DoNotOptimize(result);
+    }
+}
+
+static void BM_OptSparkDivide_VectorConstant(benchmark::State & state)
+{
+    using namespace DB;
+    auto & factory = FunctionFactory::instance();
+    auto function = factory.get("sparkDivide", 
local_engine::QueryContext::globalContext());
+    auto type = DataTypeFactory::instance().get("Nullable(Float64)");
+    auto left = createColumn(type, 65536);
+    auto right = createColumn(type, 1);
+    auto const_right = ColumnConst::create(std::move(right), 65536);
+    auto block = Block({ColumnWithTypeAndName(left, type, "left"), 
ColumnWithTypeAndName(std::move(const_right), type, "right")});
+    auto executable = function->build(block.getColumnsWithTypeAndName());
+    for (auto _ : state)
+    {
+        auto result = executable->execute(block.getColumnsWithTypeAndName(), 
executable->getResultType(), block.rows(), false);
+        benchmark::DoNotOptimize(result);
+    }
+}
+
+static void BM_OptSparkDivide_ConstantVector(benchmark::State & state)
+{
+    using namespace DB;
+    auto & factory = FunctionFactory::instance();
+    auto function = factory.get("sparkDivide", 
local_engine::QueryContext::globalContext());
+    auto type = DataTypeFactory::instance().get("Nullable(Float64)");
+    auto left = createColumn(type, 1);
+    auto const_left = ColumnConst::create(std::move(left), 65536);
+    auto right = createColumn(type, 65536);
+    auto block = Block({ColumnWithTypeAndName(std::move(const_left), type, 
"left"), ColumnWithTypeAndName(std::move(right), type, "right")});
+    auto executable = function->build(block.getColumnsWithTypeAndName());
+    for (auto _ : state)
+    {
+        auto result = executable->execute(block.getColumnsWithTypeAndName(), 
executable->getResultType(), block.rows(), false);
+        benchmark::DoNotOptimize(result);
+    }
+}
+
+BENCHMARK(BM_OptSparkDivide_VectorVector);
+BENCHMARK(BM_OptSparkDivide_VectorConstant);
+BENCHMARK(BM_OptSparkDivide_ConstantVector);
+
+static void BM_OptSparkCastFloatToInt(benchmark::State & state)
+{
+    using namespace DB;
+    auto & factory = FunctionFactory::instance();
+    auto function = factory.get("sparkCastFloatToInt32", 
local_engine::QueryContext::globalContext());
+    auto type = DataTypeFactory::instance().get("Nullable(Float64)");
+    auto input = createColumn(type, 65536);
+    auto block = Block({ColumnWithTypeAndName(std::move(input), type, 
"input")});
+    auto executable = function->build(block.getColumnsWithTypeAndName());
+    for (auto _ : state)
+    {
+        auto result = executable->execute(block.getColumnsWithTypeAndName(), 
executable->getResultType(), block.rows(), false);
+        benchmark::DoNotOptimize(result);
+    }
+}
+
+BENCHMARK(BM_OptSparkCastFloatToInt);
+
+/// decimal to decimal, scale down (scale 10 -> 5)
+static void BM_OptCheckDecimalOverflowSparkFromDecimal1(benchmark::State & 
state)
+{
+    using namespace DB;
+    auto & factory = FunctionFactory::instance();
+    auto function = factory.get("checkDecimalOverflowSparkOrNull", 
local_engine::QueryContext::globalContext());
+    auto type = DataTypeFactory::instance().get("Nullable(Decimal128(10))");
+
+    auto input = createColumn(type, 65536);
+    auto precision = ColumnConst::create(ColumnUInt32::create(1, 38), 65536);
+    auto scale = ColumnConst::create(ColumnUInt32::create(1, 5), 65536);
+
+    auto block = Block(
+        {ColumnWithTypeAndName(std::move(input), type, "input"),
+         ColumnWithTypeAndName(std::move(precision), 
std::make_shared<DataTypeUInt32>(), "precision"),
+         ColumnWithTypeAndName(std::move(scale), 
std::make_shared<DataTypeUInt32>(), "scale")});
+    auto executable = function->build(block.getColumnsWithTypeAndName());
+    for (auto _ : state)
+    {
+        auto result = executable->execute(block.getColumnsWithTypeAndName(), 
executable->getResultType(), block.rows(), false);
+        benchmark::DoNotOptimize(result);
+    }
+}
+
+/// decimal to decimal, scale up (scale 10 -> 15)
+static void BM_OptCheckDecimalOverflowSparkFromDecimal2(benchmark::State & 
state)
+{
+    using namespace DB;
+    auto & factory = FunctionFactory::instance();
+    auto function = factory.get("checkDecimalOverflowSparkOrNull", 
local_engine::QueryContext::globalContext());
+    auto type = DataTypeFactory::instance().get("Nullable(Decimal128(10))");
+
+    auto input = createColumn(type, 65536);
+    auto precision = ColumnConst::create(ColumnUInt32::create(1, 38), 65536);
+    auto scale = ColumnConst::create(ColumnUInt32::create(1, 15), 65536);
+
+    auto block = Block(
+        {ColumnWithTypeAndName(std::move(input), type, "input"),
+         ColumnWithTypeAndName(std::move(precision), 
std::make_shared<DataTypeUInt32>(), "precision"),
+         ColumnWithTypeAndName(std::move(scale), 
std::make_shared<DataTypeUInt32>(), "scale")});
+    auto executable = function->build(block.getColumnsWithTypeAndName());
+    for (auto _ : state)
+    {
+        auto result = executable->execute(block.getColumnsWithTypeAndName(), 
executable->getResultType(), block.rows(), false);
+        benchmark::DoNotOptimize(result);
+    }
+}
+
+/// decimal to decimal, scale doesn't change
+static void BM_OptCheckDecimalOverflowSparkFromDecimal3(benchmark::State & 
state)
+{
+    using namespace DB;
+    auto & factory = FunctionFactory::instance();
+    auto function = factory.get("checkDecimalOverflowSparkOrNull", 
local_engine::QueryContext::globalContext());
+    auto type = DataTypeFactory::instance().get("Nullable(Decimal(38, 10))");
+
+    auto input = createColumn(type, 65536);
+    auto precision = ColumnConst::create(ColumnUInt32::create(1, 38), 65536);
+    auto scale = ColumnConst::create(ColumnUInt32::create(1, 10), 65536);
+
+    auto block = Block(
+        {ColumnWithTypeAndName(std::move(input), type, "input"),
+         ColumnWithTypeAndName(std::move(precision), 
std::make_shared<DataTypeUInt32>(), "precision"),
+         ColumnWithTypeAndName(std::move(scale), 
std::make_shared<DataTypeUInt32>(), "scale")});
+    auto executable = function->build(block.getColumnsWithTypeAndName());
+    for (auto _ : state)
+    {
+        auto result = executable->execute(block.getColumnsWithTypeAndName(), 
executable->getResultType(), block.rows(), false);
+        benchmark::DoNotOptimize(result);
+    }
+}
+
+/// int to decimal
+static void BM_OptCheckDecimalOverflowSparkFromInt(benchmark::State & state)
+{
+    using namespace DB;
+    auto & factory = FunctionFactory::instance();
+    auto function = factory.get("checkDecimalOverflowSparkOrNull", 
local_engine::QueryContext::globalContext());
+    auto type = DataTypeFactory::instance().get("Nullable(Int64)");
+
+    auto input = createColumn(type, 65536);
+    auto precision = ColumnConst::create(ColumnUInt32::create(1, 38), 65536);
+    auto scale = ColumnConst::create(ColumnUInt32::create(1, 10), 65536);
+
+    auto block = Block(
+        {ColumnWithTypeAndName(std::move(input), type, "input"),
+         ColumnWithTypeAndName(std::move(precision), 
std::make_shared<DataTypeUInt32>(), "precision"),
+         ColumnWithTypeAndName(std::move(scale), 
std::make_shared<DataTypeUInt32>(), "scale")});
+    auto executable = function->build(block.getColumnsWithTypeAndName());
+    for (auto _ : state)
+    {
+        auto result = executable->execute(block.getColumnsWithTypeAndName(), 
executable->getResultType(), block.rows(), false);
+        benchmark::DoNotOptimize(result);
+    }
+}
+
+/// float to decimal
+static void BM_OptCheckDecimalOverflowSparkFromFloat(benchmark::State & state)
+{
+    using namespace DB;
+    auto & factory = FunctionFactory::instance();
+    auto function = factory.get("checkDecimalOverflowSparkOrNull", 
local_engine::QueryContext::globalContext());
+    auto type = DataTypeFactory::instance().get("Nullable(Float64)");
+
+    auto input = createColumn(type, 65536);
+    auto precision = ColumnConst::create(ColumnUInt32::create(1, 38), 65536);
+    auto scale = ColumnConst::create(ColumnUInt32::create(1, 10), 65536);
+
+    auto block = Block(
+        {ColumnWithTypeAndName(std::move(input), type, "input"),
+         ColumnWithTypeAndName(std::move(precision), 
std::make_shared<DataTypeUInt32>(), "precision"),
+         ColumnWithTypeAndName(std::move(scale), 
std::make_shared<DataTypeUInt32>(), "scale")});
+    auto executable = function->build(block.getColumnsWithTypeAndName());
+    for (auto _ : state)
+    {
+        auto result = executable->execute(block.getColumnsWithTypeAndName(), 
executable->getResultType(), block.rows(), false);
+        benchmark::DoNotOptimize(result);
+    }
+}
+
+BENCHMARK(BM_OptCheckDecimalOverflowSparkFromDecimal1);
+BENCHMARK(BM_OptCheckDecimalOverflowSparkFromDecimal2);
+BENCHMARK(BM_OptCheckDecimalOverflowSparkFromDecimal3);
+BENCHMARK(BM_OptCheckDecimalOverflowSparkFromInt);
+BENCHMARK(BM_OptCheckDecimalOverflowSparkFromFloat);
+
 static void nanInfToNullAutoOpt(float * data, uint8_t * null_map, size_t size)
 {
     for (size_t i = 0; i < size; ++i)
@@ -267,10 +479,6 @@ static void BMNanInfToNull(benchmark::State & state)
 }
 BENCHMARK(BMNanInfToNull);
 
-BENCHMARK(BM_CHFloorFunction_For_Int64);
-BENCHMARK(BM_CHFloorFunction_For_Float64);
-BENCHMARK(BM_SparkFloorFunction_For_Int64);
-BENCHMARK(BM_SparkFloorFunction_For_Float64);
 
 
 /*
@@ -1278,4 +1486,5 @@ BENCHMARK_TEMPLATE(BM_myFilterToIndicesAVX512, UInt32);
 BENCHMARK_TEMPLATE(BM_myFilterToIndicesAVX512, UInt64);
 */
 
+
 #endif
diff --git 
a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
 
b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index ff1d688472..3c1429a34f 100644
--- 
a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ 
b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -540,6 +540,7 @@ class ClickHouseTestSettings extends BackendTestSettings {
     )
     // test for sort node not present but gluten uses shuffle hash join
     .exclude("SPARK-41048: Improve output partitioning and ordering with AQE 
cache")
+    .exclude("SPARK-28224: Aggregate sum big decimal overflow")
     // Rewrite this test since it checks the physical operator which is 
changed in Gluten
     .excludeCH("SPARK-27439: Explain result should match collected result 
after view change")
     .excludeCH("SPARK-28067: Aggregate sum should not return wrong results for 
decimal overflow")


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to