This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.0 by this push:
     new c5111e28208 [fix](decimal) fix wrong decimal overflow of cast caused by uninitialized nested column of null value (#29960) (#30272)
c5111e28208 is described below

commit c5111e282082243fa0c738c32166f9477ecc495b
Author: TengJianPing <[email protected]>
AuthorDate: Thu Jan 25 10:34:59 2024 +0800

    [fix](decimal) fix wrong decimal overflow of cast caused by uninitialized nested column of null value (#29960) (#30272)
---
 be/src/vec/data_types/number_traits.h |  11 +++-
 be/src/vec/functions/function_cast.h  | 116 ++++++++++++++++++++++++++++++----
 2 files changed, 115 insertions(+), 12 deletions(-)

diff --git a/be/src/vec/data_types/number_traits.h b/be/src/vec/data_types/number_traits.h
index 37240d67568..0a8ee5781fc 100644
--- a/be/src/vec/data_types/number_traits.h
+++ b/be/src/vec/data_types/number_traits.h
@@ -211,7 +211,6 @@ template <typename T>
 /// Returns the maximum ascii string length for this type.
 /// e.g. the max/min int8_t has 3 characters.
 int max_ascii_len() {
-    LOG(FATAL) << "Not implemented.";
     return 0;
 }
 
@@ -259,6 +258,16 @@ template <>
 inline int max_ascii_len<__int128>() {
     return 39;
 }
+
+template <>
+inline int max_ascii_len<float>() {
+    return INT_MAX;
+}
+
+template <>
+inline int max_ascii_len<double>() {
+    return INT_MAX;
+}
 } // namespace NumberTraits
 
 } // namespace doris::vectorized
diff --git a/be/src/vec/functions/function_cast.h b/be/src/vec/functions/function_cast.h
index d9ef6268dd1..58656309606 100644
--- a/be/src/vec/functions/function_cast.h
+++ b/be/src/vec/functions/function_cast.h
@@ -2067,6 +2067,94 @@ private:
         return wrapper;
     }
 
+    static bool need_replace_null_data_to_default(FunctionContext* context,
+                                                  const DataTypePtr& from_type,
+                                                  const DataTypePtr& to_type) {
+        if (from_type->equals(*to_type)) {
+            return false;
+        }
+
+        auto make_default_wrapper = [&](const auto& types) -> bool {
+            using Types = std::decay_t<decltype(types)>;
+            using ToDataType = typename Types::LeftType;
+
+            if constexpr (!(IsDataTypeDecimalOrNumber<ToDataType> || 
IsTimeType<ToDataType> ||
+                            IsTimeV2Type<ToDataType> ||
+                            std::is_same_v<ToDataType, DataTypeTimeV2> ||
+                            std::is_same_v<ToDataType, DataTypeTime>)) {
+                return false;
+            }
+            return call_on_index_and_data_type<
+                    ToDataType>(from_type->get_type_id(), [&](const auto& 
types2) -> bool {
+                using Types2 = std::decay_t<decltype(types2)>;
+                using FromDataType = typename Types2::LeftType;
+                if constexpr (!(IsDataTypeDecimalOrNumber<FromDataType> ||
+                                IsTimeType<FromDataType> || 
IsTimeV2Type<FromDataType> ||
+                                std::is_same_v<FromDataType, DataTypeTimeV2> ||
+                                std::is_same_v<FromDataType, DataTypeTime>)) {
+                    return false;
+                }
+                if constexpr (IsDataTypeDecimal<FromDataType> || 
IsDataTypeDecimal<ToDataType>) {
+                    using FromFieldType = typename FromDataType::FieldType;
+                    using ToFieldType = typename ToDataType::FieldType;
+                    UInt32 from_precision = 
NumberTraits::max_ascii_len<FromFieldType>();
+                    UInt32 from_scale = 0;
+
+                    if constexpr (IsDataTypeDecimal<FromDataType>) {
+                        const auto* from_decimal_type =
+                                
check_and_get_data_type<FromDataType>(from_type.get());
+                        from_precision =
+                                NumberTraits::max_ascii_len<typename 
FromFieldType::NativeType>();
+                        from_scale = from_decimal_type->get_scale();
+                    }
+
+                    UInt32 to_max_digits = 0;
+                    UInt32 to_precision = 0;
+                    UInt32 to_scale = 0;
+
+                    ToFieldType max_result {0};
+                    ToFieldType min_result {0};
+                    if constexpr (IsDataTypeDecimal<ToDataType>) {
+                        to_max_digits =
+                                NumberTraits::max_ascii_len<typename 
ToFieldType::NativeType>();
+
+                        const auto* to_decimal_type =
+                                
check_and_get_data_type<ToDataType>(to_type.get());
+                        to_precision = to_decimal_type->get_precision();
+                        ToDataType::check_type_precision(to_precision);
+
+                        to_scale = to_decimal_type->get_scale();
+                        ToDataType::check_type_scale(to_scale);
+
+                        max_result = 
ToDataType::get_max_digits_number(to_precision);
+                        min_result = -max_result;
+                    }
+                    if constexpr (std::is_integral_v<ToFieldType> ||
+                                  std::is_floating_point_v<ToFieldType>) {
+                        max_result = type_limit<ToFieldType>::max();
+                        min_result = type_limit<ToFieldType>::min();
+                        to_max_digits = 
NumberTraits::max_ascii_len<ToFieldType>();
+                        to_precision = to_max_digits;
+                    }
+
+                    bool narrow_integral =
+                            context->check_overflow_for_decimal() &&
+                            (to_precision - to_scale) <= (from_precision - 
from_scale);
+
+                    bool multiply_may_overflow = 
context->check_overflow_for_decimal();
+                    if (to_scale > from_scale) {
+                        multiply_may_overflow &=
+                                (from_precision + to_scale - from_scale) >= 
to_max_digits;
+                    }
+                    return narrow_integral || multiply_may_overflow;
+                }
+                return false;
+            });
+        };
+
+        return call_on_index_and_data_type<void>(to_type->get_type_id(), 
make_default_wrapper);
+    }
+
     WrapperType prepare_remove_nullable(FunctionContext* context, const 
DataTypePtr& from_type,
                                         const DataTypePtr& to_type,
                                         bool skip_not_null_check) const {
@@ -2074,13 +2162,19 @@ private:
         bool source_is_nullable = from_type->is_nullable();
         bool result_is_nullable = to_type->is_nullable();
 
-        auto wrapper = prepare_impl(context, remove_nullable(from_type), 
remove_nullable(to_type),
+        auto from_type_not_nullable = remove_nullable(from_type);
+        auto to_type_not_nullable = remove_nullable(to_type);
+
+        bool replace_null_data_to_default = need_replace_null_data_to_default(
+                context, from_type_not_nullable, to_type_not_nullable);
+
+        auto wrapper = prepare_impl(context, from_type_not_nullable, 
to_type_not_nullable,
                                     result_is_nullable);
 
         if (result_is_nullable) {
-            return [wrapper, source_is_nullable](FunctionContext* context, 
Block& block,
-                                                 const ColumnNumbers& 
arguments,
-                                                 const size_t result, size_t 
input_rows_count) {
+            return [wrapper, source_is_nullable, replace_null_data_to_default](
+                           FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
+                           const size_t result, size_t input_rows_count) {
                 /// Create a temporary block on which to perform the operation.
                 auto& res = block.get_by_position(result);
                 const auto& ret_type = res.type;
@@ -2090,8 +2184,8 @@ private:
                 Block tmp_block;
                 size_t tmp_res_index = 0;
                 if (source_is_nullable) {
-                    auto [t_block, tmp_args] =
-                            create_block_with_nested_columns(block, arguments, 
true);
+                    auto [t_block, tmp_args] = 
create_block_with_nested_columns(
+                            block, arguments, true, 
replace_null_data_to_default);
                     tmp_block = std::move(t_block);
                     tmp_res_index = tmp_block.columns();
                     tmp_block.insert({nullptr, nested_type, ""});
@@ -2121,11 +2215,11 @@ private:
         } else if (source_is_nullable) {
             /// Conversion from Nullable to non-Nullable.
 
-            return [wrapper, skip_not_null_check](FunctionContext* context, 
Block& block,
-                                                  const ColumnNumbers& 
arguments,
-                                                  const size_t result, size_t 
input_rows_count) {
-                auto [tmp_block, tmp_args, tmp_res] =
-                        create_block_with_nested_columns(block, arguments, 
result);
+            return [wrapper, skip_not_null_check, 
replace_null_data_to_default](
+                           FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
+                           const size_t result, size_t input_rows_count) {
+                auto [tmp_block, tmp_args, tmp_res] = 
create_block_with_nested_columns(
+                        block, arguments, result, 
replace_null_data_to_default);
 
                 /// Check that all values are not-NULL.
                 /// Check can be skipped in case if LowCardinality dictionary 
is transformed.


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to