This is an automated email from the ASF dual-hosted git repository.

jacktengg pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 6fdec9f50c2 [improve](cast) improve cast performance (#54218)
6fdec9f50c2 is described below

commit 6fdec9f50c214de299e974ab16e7599b6888edb5
Author: TengJianPing <[email protected]>
AuthorDate: Sat Aug 2 09:51:27 2025 +0800

    [improve](cast) improve cast performance (#54218)
    
    #54214 had a compile error and was reverted by #54208.
    
    This PR fixes the compile error and recommits the change.
---
 be/src/vec/data_types/number_traits.h              |   3 +-
 .../functions/cast/cast_to_basic_number_common.h   |  35 ++-
 be/src/vec/functions/cast/cast_to_decimal.h        | 299 ++++++++++++++-------
 be/src/vec/functions/cast/cast_to_float.h          |  11 +-
 be/src/vec/functions/cast/cast_to_int.h            |  13 +-
 5 files changed, 251 insertions(+), 110 deletions(-)

diff --git a/be/src/vec/data_types/number_traits.h 
b/be/src/vec/data_types/number_traits.h
index 7e2778c4cd8..a5d0672bbde 100644
--- a/be/src/vec/data_types/number_traits.h
+++ b/be/src/vec/data_types/number_traits.h
@@ -264,9 +264,10 @@ constexpr int max_ascii_len() {
     return 0;
 }
 
+// bool type
 template <>
 inline constexpr int max_ascii_len<uint8_t>() {
-    return 3;
+    return 1;
 }
 
 template <>
diff --git a/be/src/vec/functions/cast/cast_to_basic_number_common.h 
b/be/src/vec/functions/cast/cast_to_basic_number_common.h
index 6cda24a9ed9..3ed5cc0368e 100644
--- a/be/src/vec/functions/cast/cast_to_basic_number_common.h
+++ b/be/src/vec/functions/cast/cast_to_basic_number_common.h
@@ -181,11 +181,20 @@ struct CastToInt {
         requires(IsCppTypeInt<ToCppT> && IsDecimalNumber<FromCppT>)
     static inline bool from_decimal(FromCppT from, UInt32 from_precision, 
UInt32 from_scale,
                                     ToCppT& to, CastParameters& params) {
+        typename FromCppT::NativeType scale_multiplier =
+                
DataTypeDecimal<FromCppT::PType>::get_scale_multiplier(from_scale);
         constexpr UInt32 to_max_digits = NumberTraits::max_ascii_len<ToCppT>();
         bool narrow_integral = (from_precision - from_scale) >= to_max_digits;
+        return _from_decimal(from, from_precision, from_scale, to, 
scale_multiplier,
+                             narrow_integral, params);
+    }
 
-        typename FromCppT::NativeType scale_multiplier =
-                
DataTypeDecimal<FromCppT::PType>::get_scale_multiplier(from_scale);
+    template <typename FromCppT, typename ToCppT>
+        requires(IsCppTypeInt<ToCppT> && IsDecimalNumber<FromCppT>)
+    static inline bool _from_decimal(FromCppT from, UInt32 from_precision, 
UInt32 from_scale,
+                                     ToCppT& to,
+                                     const typename FromCppT::NativeType& 
scale_multiplier,
+                                     bool narrow_integral, CastParameters& 
params) {
         constexpr auto min_result = std::numeric_limits<ToCppT>::lowest();
         constexpr auto max_result = std::numeric_limits<ToCppT>::max();
         auto tmp = from.value / scale_multiplier;
@@ -274,16 +283,24 @@ struct CastToFloat {
                                     CastParameters& params) {
         if constexpr (IsDecimalV2<FromCppT>) {
             to = binary_cast<int128_t, DecimalV2Value>(from);
+            return true;
         } else {
             typename FromCppT::NativeType scale_multiplier =
                     
DataTypeDecimal<FromCppT::PType>::get_scale_multiplier(from_scale);
-            if constexpr (IsDecimal256<FromCppT>) {
-                to = static_cast<ToCppT>(static_cast<long double>(from.value) /
-                                         static_cast<long 
double>(scale_multiplier));
-            } else {
-                to = static_cast<ToCppT>(static_cast<double>(from.value) /
-                                         
static_cast<double>(scale_multiplier));
-            }
+            return _from_decimalv3(from, from_scale, to, scale_multiplier, 
params);
+        }
+    }
+    template <typename FromCppT, typename ToCppT>
+        requires(IsCppTypeFloat<ToCppT> && IsDecimalNumber<FromCppT>)
+    static inline bool _from_decimalv3(const FromCppT& from, UInt32 
from_scale, ToCppT& to,
+                                       const typename FromCppT::NativeType& 
scale_multiplier,
+                                       CastParameters& params) {
+        if constexpr (IsDecimal256<FromCppT>) {
+            to = static_cast<ToCppT>(static_cast<long double>(from.value) /
+                                     static_cast<long 
double>(scale_multiplier));
+        } else {
+            to = static_cast<ToCppT>(static_cast<double>(from.value) /
+                                     static_cast<double>(scale_multiplier));
         }
         return true;
     }
diff --git a/be/src/vec/functions/cast/cast_to_decimal.h 
b/be/src/vec/functions/cast/cast_to_decimal.h
index cd1937d5803..b464bb43b0b 100644
--- a/be/src/vec/functions/cast/cast_to_decimal.h
+++ b/be/src/vec/functions/cast/cast_to_decimal.h
@@ -68,13 +68,14 @@ struct CastToDecimal {
 
     // cast int to decimal
     template <typename FromCppT, typename ToCppT,
-              typename MaxFieldType =
+              typename MaxNativeType =
                       std::conditional_t<(sizeof(FromCppT) > sizeof(typename 
ToCppT::NativeType)),
                                          FromCppT, typename 
ToCppT::NativeType>>
-        requires(IsDecimalNumber<ToCppT> && IsCppTypeInt<FromCppT>)
+        requires(IsDecimalNumber<ToCppT> &&
+                 (IsCppTypeInt<FromCppT> || std::is_same_v<FromCppT, 
vectorized::UInt8>))
     static inline bool from_int(const FromCppT& from, ToCppT& to, UInt32 
to_precision,
                                 UInt32 to_scale, CastParameters& params) {
-        MaxFieldType scale_multiplier =
+        MaxNativeType scale_multiplier =
                 DataTypeDecimal<ToCppT::PType>::get_scale_multiplier(to_scale);
         typename ToCppT::NativeType max_result =
                 
DataTypeDecimal<ToCppT::PType>::get_max_digits_number(to_precision);
@@ -102,36 +103,13 @@ struct CastToDecimal {
 
     // cast bool to decimal
     template <typename FromCppT, typename ToCppT,
-              typename MaxFieldType =
+              typename MaxNativeType =
                       std::conditional_t<(sizeof(FromCppT) > sizeof(typename 
ToCppT::NativeType)),
                                          FromCppT, typename 
ToCppT::NativeType>>
         requires(IsDecimalNumber<ToCppT> && std::is_same_v<FromCppT, 
vectorized::UInt8>)
     static inline bool from_bool(const FromCppT& from, ToCppT& to, UInt32 
to_precision,
                                  UInt32 to_scale, CastParameters& params) {
-        MaxFieldType scale_multiplier =
-                DataTypeDecimal<ToCppT::PType>::get_scale_multiplier(to_scale);
-        typename ToCppT::NativeType max_result =
-                
DataTypeDecimal<ToCppT::PType>::get_max_digits_number(to_precision);
-        typename ToCppT::NativeType min_result = -max_result;
-
-        UInt32 from_precision = NumberTraits::max_ascii_len<FromCppT>();
-        constexpr UInt32 from_scale = 0;
-        constexpr UInt32 to_max_digits = NumberTraits::max_ascii_len<typename 
ToCppT::NativeType>();
-
-        auto from_max_int_digit_count = from_precision - from_scale;
-        auto to_max_int_digit_count = to_precision - to_scale;
-        bool narrow_integral = (to_max_int_digit_count < 
from_max_int_digit_count);
-        bool multiply_may_overflow = false;
-        if (to_scale > from_scale) {
-            multiply_may_overflow = (from_precision + to_scale - from_scale) 
>= to_max_digits;
-        }
-        return std::visit(
-                [&](auto multiply_may_overflow, auto narrow_integral) {
-                    return _from_int<FromCppT, ToCppT, multiply_may_overflow, 
narrow_integral>(
-                            from, to, to_precision, to_scale, 
scale_multiplier, min_result,
-                            max_result, params);
-                },
-                make_bool_variant(multiply_may_overflow), 
make_bool_variant(narrow_integral));
+        return from_int<FromCppT, ToCppT, MaxNativeType>(from, to, 
to_precision, to_scale, params);
     }
 
     template <typename FromCppT, typename ToCppT>
@@ -144,6 +122,18 @@ struct CastToDecimal {
                 
DataTypeDecimal<ToCppT::PType>::get_max_digits_number(to_precision);
         typename ToCppT::NativeType min_result = -max_result;
 
+        return _from_float<FromCppT, ToCppT>(from, to, to_precision, to_scale, 
scale_multiplier,
+                                             min_result, max_result, params);
+    }
+
+    template <typename FromCppT, typename ToCppT>
+        requires(IsDecimalNumber<ToCppT> && IsCppTypeFloat<FromCppT>)
+    static inline bool _from_float(const FromCppT& from, ToCppT& to, UInt32 
to_precision,
+                                   UInt32 to_scale,
+                                   const typename ToCppT::NativeType& 
scale_multiplier,
+                                   const typename ToCppT::NativeType& 
min_result,
+                                   const typename ToCppT::NativeType& 
max_result,
+                                   CastParameters& params) {
         if (!std::isfinite(from)) {
             params.status = Status(ErrorCode::ARITHMETIC_OVERFLOW_ERRROR,
                                    "Decimal convert overflow. Cannot convert 
infinity or NaN "
@@ -255,35 +245,58 @@ struct CastToDecimal {
                 
DataTypeDecimal<ToCppT::PType>::get_max_digits_number(to_precision);
         typename ToCppT::NativeType min_result = -max_result;
 
+        MaxNativeType multiplier {};
+        if (from_scale < to_scale) {
+            multiplier = 
DataTypeDecimal<MaxFieldType::PType>::get_scale_multiplier(to_scale -
+                                                                               
     from_scale);
+        } else if (from_scale > to_scale) {
+            multiplier = 
DataTypeDecimal<MaxFieldType::PType>::get_scale_multiplier(from_scale -
+                                                                               
     to_scale);
+        }
+
         return std::visit(
                 [&](auto multiply_may_overflow, auto narrow_integral) {
-                    if (from_scale < to_scale) {
-                        MaxNativeType multiplier =
-                                
DataTypeDecimal<MaxFieldType::PType>::get_scale_multiplier(
-                                        to_scale - from_scale);
-                        return _from_decimal_smaller_scale<FromCppT, ToCppT, 
multiply_may_overflow,
-                                                           narrow_integral>(
-                                from, from_precision, from_scale, to, 
to_precision, to_scale,
-                                multiplier, min_result, max_result, params);
-                    } else if (from_scale == to_scale) {
-                        return _from_decimal_same_scale<FromCppT, ToCppT, 
MaxNativeType,
-                                                        narrow_integral>(
-                                from, from_precision, from_scale, to, 
to_precision, to_scale,
-                                min_result, max_result, params);
-                    } else {
-                        MaxNativeType multiplier =
-                                
DataTypeDecimal<MaxFieldType::PType>::get_scale_multiplier(
-                                        from_scale - to_scale);
-                        return _from_decimal_bigger_scale<FromCppT, ToCppT, 
multiply_may_overflow,
-                                                          narrow_integral>(
-                                from, from_precision, from_scale, to, 
to_precision, to_scale,
-                                multiplier, min_result, max_result, params);
-                    }
-                    return true;
+                    return _from_decimal<FromCppT, ToCppT, 
multiply_may_overflow, narrow_integral>(
+                            from, from_precision, from_scale, to, 
to_precision, to_scale,
+                            min_result, max_result, multiplier, params);
                 },
                 make_bool_variant(multiply_may_overflow), 
make_bool_variant(narrow_integral));
     }
 
+    template <typename FromCppT, typename ToCppT, bool multiply_may_overflow, 
bool narrow_integral,
+              typename MaxFieldType = std::conditional_t<
+                      (sizeof(FromCppT) == sizeof(ToCppT)) &&
+                              (std::is_same_v<ToCppT, Decimal128V3> ||
+                               std::is_same_v<FromCppT, Decimal128V3>),
+                      Decimal128V3,
+                      std::conditional_t<(sizeof(FromCppT) > sizeof(ToCppT)), 
FromCppT, ToCppT>>>
+        requires(IsDecimalNumber<ToCppT> && IsDecimalNumber<FromCppT>)
+    static inline bool _from_decimal(const FromCppT& from, const UInt32 
from_precision,
+                                     const UInt32 from_scale, ToCppT& to, 
UInt32 to_precision,
+                                     UInt32 to_scale, const 
ToCppT::NativeType& min_result,
+                                     const ToCppT::NativeType& max_result,
+                                     const typename MaxFieldType::NativeType& 
scale_multiplier,
+                                     CastParameters& params) {
+        using MaxNativeType = typename MaxFieldType::NativeType;
+
+        if (from_scale < to_scale) {
+            return _from_decimal_smaller_scale<FromCppT, ToCppT, 
multiply_may_overflow,
+                                               narrow_integral>(
+                    from, from_precision, from_scale, to, to_precision, 
to_scale, scale_multiplier,
+                    min_result, max_result, params);
+        } else if (from_scale == to_scale) {
+            return _from_decimal_same_scale<FromCppT, ToCppT, MaxNativeType, 
narrow_integral>(
+                    from, from_precision, from_scale, to, to_precision, 
to_scale, min_result,
+                    max_result, params);
+        } else {
+            return _from_decimal_bigger_scale<FromCppT, ToCppT, 
multiply_may_overflow,
+                                              narrow_integral>(
+                    from, from_precision, from_scale, to, to_precision, 
to_scale, scale_multiplier,
+                    min_result, max_result, params);
+        }
+        return true;
+    }
+
     template <
             typename FromCppT, typename ToCppT, bool multiply_may_overflow, 
bool narrow_integral,
             typename MaxNativeType = std::conditional_t<
@@ -513,6 +526,7 @@ public:
                         uint32_t result, size_t input_rows_count,
                         const NullMap::value_type* null_map = nullptr) const 
override {
         using FromFieldType = typename FromDataType::FieldType;
+        using ToFieldType = typename ToDataType::FieldType;
         const ColumnWithTypeAndName& named_from = 
block.get_by_position(arguments[0]);
         const auto* col_from =
                 check_and_get_column<typename 
FromDataType::ColumnType>(named_from.column.get());
@@ -536,6 +550,21 @@ public:
         bool narrow_integral = (to_max_int_digit_count < 
from_max_int_digit_count);
         bool result_is_nullable = (CastMode == CastModeType::NonStrictMode) && 
narrow_integral;
 
+        constexpr UInt32 to_max_digits =
+                NumberTraits::max_ascii_len<typename 
ToFieldType::NativeType>();
+        bool multiply_may_overflow = false;
+        if (to_scale > from_scale) {
+            multiply_may_overflow = (from_precision + to_scale - from_scale) 
>= to_max_digits;
+        }
+        using MaxNativeType = std::conditional_t<(sizeof(FromFieldType) >
+                                                  sizeof(typename 
ToFieldType::NativeType)),
+                                                 FromFieldType, typename 
ToFieldType::NativeType>;
+        MaxNativeType scale_multiplier =
+                
DataTypeDecimal<ToFieldType::PType>::get_scale_multiplier(to_scale);
+        typename ToFieldType::NativeType max_result =
+                
DataTypeDecimal<ToFieldType::PType>::get_max_digits_number(to_precision);
+        typename ToFieldType::NativeType min_result = -max_result;
+
         ColumnUInt8::MutablePtr col_null_map_to;
         NullMap::value_type* null_map_data = nullptr;
         if (result_is_nullable) {
@@ -552,29 +581,25 @@ public:
         CastParameters params;
         params.is_strict = (CastMode == CastModeType::StrictMode);
         size_t size = vec_from.size();
-        for (size_t i = 0; i < size; i++) {
-            if constexpr (IsDataTypeBool<FromDataType>) {
-                if (!CastToDecimal::from_bool<typename FromDataType::FieldType,
-                                              typename ToDataType::FieldType>(
-                            vec_from_data[i], vec_to_data[i], to_precision, 
to_scale, params)) {
-                    if (result_is_nullable) {
-                        null_map_data[i] = 1;
-                    } else {
-                        return params.status;
-                    }
-                }
-            } else {
-                if (!CastToDecimal::from_int<typename FromDataType::FieldType,
-                                             typename ToDataType::FieldType>(
-                            vec_from_data[i], vec_to_data[i], to_precision, 
to_scale, params)) {
-                    if (result_is_nullable) {
-                        null_map_data[i] = 1;
-                    } else {
-                        return params.status;
+
+        RETURN_IF_ERROR(std::visit(
+                [&](auto multiply_may_overflow, auto narrow_integral) {
+                    for (size_t i = 0; i < size; i++) {
+                        if (!CastToDecimal::_from_int<typename 
FromDataType::FieldType,
+                                                      typename 
ToDataType::FieldType,
+                                                      multiply_may_overflow, 
narrow_integral>(
+                                    vec_from_data[i], vec_to_data[i], 
to_precision, to_scale,
+                                    scale_multiplier, min_result, max_result, 
params)) {
+                            if (result_is_nullable) {
+                                null_map_data[i] = 1;
+                            } else {
+                                return params.status;
+                            }
+                        }
                     }
-                }
-            }
-        }
+                    return Status::OK();
+                },
+                make_bool_variant(multiply_may_overflow), 
make_bool_variant(narrow_integral)));
 
         if (result_is_nullable) {
             block.get_by_position(result).column =
@@ -595,6 +620,7 @@ public:
                         uint32_t result, size_t input_rows_count,
                         const NullMap::value_type* null_map = nullptr) const 
override {
         using FromFieldType = typename FromDataType::FieldType;
+        using ToFieldType = typename ToDataType::FieldType;
         const ColumnWithTypeAndName& named_from = 
block.get_by_position(arguments[0]);
         const auto* col_from =
                 check_and_get_column<typename 
FromDataType::ColumnType>(named_from.column.get());
@@ -636,10 +662,17 @@ public:
         CastParameters params;
         params.is_strict = (CastMode == CastModeType::StrictMode);
         size_t size = vec_from.size();
+
+        typename ToFieldType::NativeType scale_multiplier =
+                
DataTypeDecimal<ToFieldType::PType>::get_scale_multiplier(to_scale);
+        typename ToFieldType::NativeType max_result =
+                
DataTypeDecimal<ToFieldType::PType>::get_max_digits_number(to_precision);
+        typename ToFieldType::NativeType min_result = -max_result;
         for (size_t i = 0; i < size; i++) {
-            if (!CastToDecimal::from_float<typename FromDataType::FieldType,
-                                           typename ToDataType::FieldType>(
-                        vec_from_data[i], vec_to_data[i], to_precision, 
to_scale, params)) {
+            if (!CastToDecimal::_from_float<typename FromDataType::FieldType,
+                                            typename ToDataType::FieldType>(
+                        vec_from_data[i], vec_to_data[i], to_precision, 
to_scale, scale_multiplier,
+                        min_result, max_result, params)) {
                 if (result_is_nullable) {
                     null_map_data[i] = 1;
                 } else {
@@ -681,6 +714,8 @@ public:
     Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
                         uint32_t result, size_t input_rows_count,
                         const NullMap::value_type* null_map = nullptr) const 
override {
+        using FromFieldType = typename FromDataType::FieldType;
+        using ToFieldType = typename ToDataType::FieldType;
         const ColumnWithTypeAndName& named_from = 
block.get_by_position(arguments[0]);
         const auto* col_from =
                 check_and_get_column<typename 
FromDataType::ColumnType>(named_from.column.get());
@@ -724,17 +759,53 @@ public:
         const auto* vec_from_data = vec_from.data();
         auto& vec_to = col_to->get_data();
         auto* vec_to_data = vec_to.data();
-        for (size_t i = 0; i < size; i++) {
-            if (!CastToDecimal::from_decimalv2(vec_from_data[i], 
from_precision, from_scale,
-                                               from_original_precision, 
from_original_scale,
-                                               vec_to_data[i], to_precision, 
to_scale, params)) {
-                if (result_is_nullable) {
-                    null_map_data[i] = 1;
-                } else {
-                    return params.status;
-                }
-            }
+
+        using MaxFieldType =
+                std::conditional_t<(sizeof(FromFieldType) == 
sizeof(ToFieldType)) &&
+                                           (std::is_same_v<ToFieldType, 
Decimal128V3> ||
+                                            std::is_same_v<FromFieldType, 
Decimal128V3>),
+                                   Decimal128V3,
+                                   std::conditional_t<(sizeof(FromFieldType) > 
sizeof(ToFieldType)),
+                                                      FromFieldType, 
ToFieldType>>;
+        using MaxNativeType = typename MaxFieldType::NativeType;
+
+        constexpr UInt32 to_max_digits =
+                NumberTraits::max_ascii_len<typename 
ToFieldType::NativeType>();
+        bool multiply_may_overflow = false;
+        if (to_scale > from_scale) {
+            multiply_may_overflow = (from_precision + to_scale - from_scale) 
>= to_max_digits;
+        }
+
+        typename ToFieldType::NativeType max_result =
+                
DataTypeDecimal<ToFieldType::PType>::get_max_digits_number(to_precision);
+        typename ToFieldType::NativeType min_result = -max_result;
+
+        MaxNativeType multiplier {};
+        if (from_scale < to_scale) {
+            multiplier = 
DataTypeDecimal<MaxFieldType::PType>::get_scale_multiplier(to_scale -
+                                                                               
     from_scale);
+        } else if (from_scale > to_scale) {
+            multiplier = 
DataTypeDecimal<MaxFieldType::PType>::get_scale_multiplier(from_scale -
+                                                                               
     to_scale);
         }
+        RETURN_IF_ERROR(std::visit(
+                [&](auto multiply_may_overflow, auto narrow_integral) {
+                    for (size_t i = 0; i < size; i++) {
+                        if (!CastToDecimal::_from_decimal<FromFieldType, 
ToFieldType,
+                                                          
multiply_may_overflow, narrow_integral>(
+                                    vec_from_data[i], from_precision, 
from_scale, vec_to_data[i],
+                                    to_precision, to_scale, min_result, 
max_result, multiplier,
+                                    params)) {
+                            if (result_is_nullable) {
+                                null_map_data[i] = 1;
+                            } else {
+                                return params.status;
+                            }
+                        }
+                    }
+                    return Status::OK();
+                },
+                make_bool_variant(multiply_may_overflow), 
make_bool_variant(narrow_integral)));
         if (result_is_nullable) {
             block.get_by_position(result).column =
                     ColumnNullable::create(std::move(col_to), 
std::move(col_null_map_to));
@@ -753,6 +824,8 @@ public:
     Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
                         uint32_t result, size_t input_rows_count,
                         const NullMap::value_type* null_map = nullptr) const 
override {
+        using FromFieldType = typename FromDataType::FieldType;
+        using ToFieldType = typename ToDataType::FieldType;
         const ColumnWithTypeAndName& named_from = 
block.get_by_position(arguments[0]);
         const auto* col_from =
                 check_and_get_column<typename 
FromDataType::ColumnType>(named_from.column.get());
@@ -794,16 +867,52 @@ public:
         const auto* vec_from_data = vec_from.data();
         auto& vec_to = col_to->get_data();
         auto* vec_to_data = vec_to.data();
-        for (size_t i = 0; i < size; i++) {
-            if (!CastToDecimal::from_decimalv3(vec_from_data[i], 
from_precision, from_scale,
-                                               vec_to_data[i], to_precision, 
to_scale, params)) {
-                if (result_is_nullable) {
-                    null_map_data[i] = 1;
-                } else {
-                    return params.status;
-                }
-            }
+
+        using MaxFieldType =
+                std::conditional_t<(sizeof(FromFieldType) == 
sizeof(ToFieldType)) &&
+                                           (std::is_same_v<ToFieldType, 
Decimal128V3> ||
+                                            std::is_same_v<FromFieldType, 
Decimal128V3>),
+                                   Decimal128V3,
+                                   std::conditional_t<(sizeof(FromFieldType) > 
sizeof(ToFieldType)),
+                                                      FromFieldType, 
ToFieldType>>;
+        using MaxNativeType = typename MaxFieldType::NativeType;
+
+        UInt32 to_max_digits = NumberTraits::max_ascii_len<typename 
ToFieldType::NativeType>();
+        bool multiply_may_overflow = false;
+        if (to_scale > from_scale) {
+            multiply_may_overflow = (from_precision + to_scale - from_scale) 
>= to_max_digits;
+        }
+
+        typename ToFieldType::NativeType max_result =
+                
DataTypeDecimal<ToFieldType::PType>::get_max_digits_number(to_precision);
+        typename ToFieldType::NativeType min_result = -max_result;
+
+        MaxNativeType multiplier {};
+        if (from_scale < to_scale) {
+            multiplier = 
DataTypeDecimal<MaxFieldType::PType>::get_scale_multiplier(to_scale -
+                                                                               
     from_scale);
+        } else if (from_scale > to_scale) {
+            multiplier = 
DataTypeDecimal<MaxFieldType::PType>::get_scale_multiplier(from_scale -
+                                                                               
     to_scale);
         }
+        RETURN_IF_ERROR(std::visit(
+                [&](auto multiply_may_overflow, auto narrow_integral) {
+                    for (size_t i = 0; i < size; i++) {
+                        if (!CastToDecimal::_from_decimal<FromFieldType, 
ToFieldType,
+                                                          
multiply_may_overflow, narrow_integral>(
+                                    vec_from_data[i], from_precision, 
from_scale, vec_to_data[i],
+                                    to_precision, to_scale, min_result, 
max_result, multiplier,
+                                    params)) {
+                            if (result_is_nullable) {
+                                null_map_data[i] = 1;
+                            } else {
+                                return params.status;
+                            }
+                        }
+                    }
+                    return Status::OK();
+                },
+                make_bool_variant(multiply_may_overflow), 
make_bool_variant(narrow_integral)));
         if (result_is_nullable) {
             block.get_by_position(result).column =
                     ColumnNullable::create(std::move(col_to), 
std::move(col_null_map_to));
diff --git a/be/src/vec/functions/cast/cast_to_float.h 
b/be/src/vec/functions/cast/cast_to_float.h
index 0387f743563..771b38768e0 100644
--- a/be/src/vec/functions/cast/cast_to_float.h
+++ b/be/src/vec/functions/cast/cast_to_float.h
@@ -47,6 +47,7 @@ public:
     Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
                         uint32_t result, size_t input_rows_count,
                         const NullMap::value_type* null_map = nullptr) const 
override {
+        using FromFieldType = typename FromDataType::FieldType;
         const ColumnWithTypeAndName& named_from = 
block.get_by_position(arguments[0]);
         const auto* col_from =
                 check_and_get_column<typename 
FromDataType::ColumnType>(named_from.column.get());
@@ -67,8 +68,16 @@ public:
         CastParameters params;
         params.is_strict = (CastMode == CastModeType::StrictMode);
         size_t size = vec_from.size();
+
+        typename FromFieldType::NativeType scale_multiplier =
+                
DataTypeDecimal<FromFieldType::PType>::get_scale_multiplier(from_scale);
         for (size_t i = 0; i < size; ++i) {
-            CastToFloat::from_decimal(vec_from_data[i], from_scale, 
vec_to_data[i], params);
+            if constexpr (IsDecimalV2<FromFieldType>) {
+                vec_to_data[i] = binary_cast<int128_t, 
DecimalV2Value>(vec_from_data[i]);
+            } else {
+                CastToFloat::_from_decimalv3(vec_from_data[i], from_scale, 
vec_to_data[i],
+                                             scale_multiplier, params);
+            }
         }
 
         block.get_by_position(result).column = std::move(col_to);
diff --git a/be/src/vec/functions/cast/cast_to_int.h 
b/be/src/vec/functions/cast/cast_to_int.h
index edf85ccf026..a1a84b4f0de 100644
--- a/be/src/vec/functions/cast/cast_to_int.h
+++ b/be/src/vec/functions/cast/cast_to_int.h
@@ -149,6 +149,7 @@ public:
                         uint32_t result, size_t input_rows_count,
                         const NullMap::value_type* null_map = nullptr) const 
override {
         using ToFieldType = typename ToDataType::FieldType;
+        using FromFieldType = typename FromDataType::FieldType;
 
         const ColumnWithTypeAndName& named_from = 
block.get_by_position(arguments[0]);
         const auto* col_from =
@@ -162,7 +163,8 @@ public:
         UInt32 from_precision = from_decimal_type.get_precision();
         UInt32 from_scale = from_decimal_type.get_scale();
 
-        UInt32 to_max_digits = NumberTraits::max_ascii_len<ToFieldType>();
+        constexpr UInt32 to_max_digits = 
NumberTraits::max_ascii_len<ToFieldType>();
+        bool narrow_integral = (from_precision - from_scale) >= to_max_digits;
 
         // may overflow if integer part of decimal is larger than to_max_digits
         bool may_overflow = (from_precision - from_scale) >= to_max_digits;
@@ -184,10 +186,13 @@ public:
         CastParameters params;
         params.is_strict = (CastMode == CastModeType::StrictMode);
         size_t size = vec_from.size();
+        typename FromFieldType::NativeType scale_multiplier =
+                
DataTypeDecimal<FromFieldType::PType>::get_scale_multiplier(from_scale);
         for (size_t i = 0; i < size; i++) {
-            if (!CastToInt::from_decimal<typename FromDataType::FieldType,
-                                         typename ToDataType::FieldType>(
-                        vec_from_data[i], from_precision, from_scale, 
vec_to_data[i], params)) {
+            if (!CastToInt::_from_decimal<typename FromDataType::FieldType,
+                                          typename ToDataType::FieldType>(
+                        vec_from_data[i], from_precision, from_scale, 
vec_to_data[i],
+                        scale_multiplier, narrow_integral, params)) {
                 if (result_is_nullable) {
                     null_map_data[i] = 1;
                 } else {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to