This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-2.0 by this push:
new c5111e28208 [fix](decimal) fix wrong decimal overflow of cast caused
by uninitialized nested column of null value (#29960) (#30272)
c5111e28208 is described below
commit c5111e282082243fa0c738c32166f9477ecc495b
Author: TengJianPing <[email protected]>
AuthorDate: Thu Jan 25 10:34:59 2024 +0800
[fix](decimal) fix wrong decimal overflow of cast caused by uninitialized
nested column of null value (#29960) (#30272)
---
be/src/vec/data_types/number_traits.h | 11 +++-
be/src/vec/functions/function_cast.h | 116 ++++++++++++++++++++++++++++++----
2 files changed, 115 insertions(+), 12 deletions(-)
diff --git a/be/src/vec/data_types/number_traits.h b/be/src/vec/data_types/number_traits.h
index 37240d67568..0a8ee5781fc 100644
--- a/be/src/vec/data_types/number_traits.h
+++ b/be/src/vec/data_types/number_traits.h
@@ -211,7 +211,6 @@ template <typename T>
/// Returns the maximum ascii string length for this type.
/// e.g. the max/min int8_t has 3 characters.
int max_ascii_len() {
- LOG(FATAL) << "Not implemented.";
return 0;
}
@@ -259,6 +258,16 @@ template <>
inline int max_ascii_len<__int128>() {
return 39;
}
+
+template <>
+inline int max_ascii_len<float>() {
+ return INT_MAX;
+}
+
+template <>
+inline int max_ascii_len<double>() {
+ return INT_MAX;
+}
} // namespace NumberTraits
} // namespace doris::vectorized
diff --git a/be/src/vec/functions/function_cast.h b/be/src/vec/functions/function_cast.h
index d9ef6268dd1..58656309606 100644
--- a/be/src/vec/functions/function_cast.h
+++ b/be/src/vec/functions/function_cast.h
@@ -2067,6 +2067,94 @@ private:
return wrapper;
}
+ static bool need_replace_null_data_to_default(FunctionContext* context,
+ const DataTypePtr& from_type,
+ const DataTypePtr& to_type) {
+ if (from_type->equals(*to_type)) {
+ return false;
+ }
+
+ auto make_default_wrapper = [&](const auto& types) -> bool {
+ using Types = std::decay_t<decltype(types)>;
+ using ToDataType = typename Types::LeftType;
+
+ if constexpr (!(IsDataTypeDecimalOrNumber<ToDataType> ||
IsTimeType<ToDataType> ||
+ IsTimeV2Type<ToDataType> ||
+ std::is_same_v<ToDataType, DataTypeTimeV2> ||
+ std::is_same_v<ToDataType, DataTypeTime>)) {
+ return false;
+ }
+ return call_on_index_and_data_type<
+ ToDataType>(from_type->get_type_id(), [&](const auto&
types2) -> bool {
+ using Types2 = std::decay_t<decltype(types2)>;
+ using FromDataType = typename Types2::LeftType;
+ if constexpr (!(IsDataTypeDecimalOrNumber<FromDataType> ||
+ IsTimeType<FromDataType> ||
IsTimeV2Type<FromDataType> ||
+ std::is_same_v<FromDataType, DataTypeTimeV2> ||
+ std::is_same_v<FromDataType, DataTypeTime>)) {
+ return false;
+ }
+ if constexpr (IsDataTypeDecimal<FromDataType> ||
IsDataTypeDecimal<ToDataType>) {
+ using FromFieldType = typename FromDataType::FieldType;
+ using ToFieldType = typename ToDataType::FieldType;
+ UInt32 from_precision =
NumberTraits::max_ascii_len<FromFieldType>();
+ UInt32 from_scale = 0;
+
+ if constexpr (IsDataTypeDecimal<FromDataType>) {
+ const auto* from_decimal_type =
+ check_and_get_data_type<FromDataType>(from_type.get());
+ from_precision =
+ NumberTraits::max_ascii_len<typename
FromFieldType::NativeType>();
+ from_scale = from_decimal_type->get_scale();
+ }
+
+ UInt32 to_max_digits = 0;
+ UInt32 to_precision = 0;
+ UInt32 to_scale = 0;
+
+ ToFieldType max_result {0};
+ ToFieldType min_result {0};
+ if constexpr (IsDataTypeDecimal<ToDataType>) {
+ to_max_digits =
+ NumberTraits::max_ascii_len<typename
ToFieldType::NativeType>();
+
+ const auto* to_decimal_type =
+ check_and_get_data_type<ToDataType>(to_type.get());
+ to_precision = to_decimal_type->get_precision();
+ ToDataType::check_type_precision(to_precision);
+
+ to_scale = to_decimal_type->get_scale();
+ ToDataType::check_type_scale(to_scale);
+
+ max_result =
ToDataType::get_max_digits_number(to_precision);
+ min_result = -max_result;
+ }
+ if constexpr (std::is_integral_v<ToFieldType> ||
+ std::is_floating_point_v<ToFieldType>) {
+ max_result = type_limit<ToFieldType>::max();
+ min_result = type_limit<ToFieldType>::min();
+ to_max_digits =
NumberTraits::max_ascii_len<ToFieldType>();
+ to_precision = to_max_digits;
+ }
+
+ bool narrow_integral =
+ context->check_overflow_for_decimal() &&
+ (to_precision - to_scale) <= (from_precision -
from_scale);
+
+ bool multiply_may_overflow =
context->check_overflow_for_decimal();
+ if (to_scale > from_scale) {
+ multiply_may_overflow &=
+ (from_precision + to_scale - from_scale) >=
to_max_digits;
+ }
+ return narrow_integral || multiply_may_overflow;
+ }
+ return false;
+ });
+ };
+
+ return call_on_index_and_data_type<void>(to_type->get_type_id(),
make_default_wrapper);
+ }
+
WrapperType prepare_remove_nullable(FunctionContext* context, const
DataTypePtr& from_type,
const DataTypePtr& to_type,
bool skip_not_null_check) const {
@@ -2074,13 +2162,19 @@ private:
bool source_is_nullable = from_type->is_nullable();
bool result_is_nullable = to_type->is_nullable();
- auto wrapper = prepare_impl(context, remove_nullable(from_type),
remove_nullable(to_type),
+ auto from_type_not_nullable = remove_nullable(from_type);
+ auto to_type_not_nullable = remove_nullable(to_type);
+
+ bool replace_null_data_to_default = need_replace_null_data_to_default(
+ context, from_type_not_nullable, to_type_not_nullable);
+
+ auto wrapper = prepare_impl(context, from_type_not_nullable,
to_type_not_nullable,
result_is_nullable);
if (result_is_nullable) {
- return [wrapper, source_is_nullable](FunctionContext* context,
Block& block,
- const ColumnNumbers&
arguments,
- const size_t result, size_t
input_rows_count) {
+ return [wrapper, source_is_nullable, replace_null_data_to_default](
+ FunctionContext* context, Block& block, const
ColumnNumbers& arguments,
+ const size_t result, size_t input_rows_count) {
/// Create a temporary block on which to perform the operation.
auto& res = block.get_by_position(result);
const auto& ret_type = res.type;
@@ -2090,8 +2184,8 @@ private:
Block tmp_block;
size_t tmp_res_index = 0;
if (source_is_nullable) {
- auto [t_block, tmp_args] =
- create_block_with_nested_columns(block, arguments,
true);
+ auto [t_block, tmp_args] =
create_block_with_nested_columns(
+ block, arguments, true,
replace_null_data_to_default);
tmp_block = std::move(t_block);
tmp_res_index = tmp_block.columns();
tmp_block.insert({nullptr, nested_type, ""});
@@ -2121,11 +2215,11 @@ private:
} else if (source_is_nullable) {
/// Conversion from Nullable to non-Nullable.
- return [wrapper, skip_not_null_check](FunctionContext* context,
Block& block,
- const ColumnNumbers&
arguments,
- const size_t result, size_t
input_rows_count) {
- auto [tmp_block, tmp_args, tmp_res] =
- create_block_with_nested_columns(block, arguments,
result);
+ return [wrapper, skip_not_null_check,
replace_null_data_to_default](
+ FunctionContext* context, Block& block, const
ColumnNumbers& arguments,
+ const size_t result, size_t input_rows_count) {
+ auto [tmp_block, tmp_args, tmp_res] =
create_block_with_nested_columns(
+ block, arguments, result,
replace_null_data_to_default);
/// Check that all values are not-NULL.
/// Check can be skipped in case if LowCardinality dictionary
is transformed.
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]