This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.1 by this push:
     new de2272ce487 [fix](round) fix round decimal128 overflow (#37733) (#37963)
de2272ce487 is described below

commit de2272ce4870ad519c074bb120f003cea15aae23
Author: camby <[email protected]>
AuthorDate: Thu Jul 18 23:50:23 2024 +0800

    [fix](round) fix round decimal128 overflow (#37733) (#37963)
    
    cherry-pick #37733 to branch-2.1
---
 be/src/vec/exec/format/format_common.h             |  4 +++-
 be/src/vec/functions/round.h                       | 26 +++++++++++++---------
 .../math_functions/test_round_overflow.out         | 10 +++++++++
 .../math_functions/test_round_overflow.groovy      | 25 +++++++++++++++++++++
 4 files changed, 53 insertions(+), 12 deletions(-)

diff --git a/be/src/vec/exec/format/format_common.h b/be/src/vec/exec/format/format_common.h
index 4227a2128d2..3edf021ad27 100644
--- a/be/src/vec/exec/format/format_common.h
+++ b/be/src/vec/exec/format/format_common.h
@@ -33,7 +33,7 @@ struct DecimalScaleParams {
     int64_t scale_factor = 1;
 
     template <typename DecimalPrimitiveType>
-    static inline constexpr DecimalPrimitiveType get_scale_factor(int32_t n) {
+    static inline constexpr DecimalPrimitiveType::NativeType 
get_scale_factor(int32_t n) {
         if constexpr (std::is_same_v<DecimalPrimitiveType, Decimal32>) {
             return common::exp10_i32(n);
         } else if constexpr (std::is_same_v<DecimalPrimitiveType, Decimal64>) {
@@ -42,6 +42,8 @@ struct DecimalScaleParams {
             return common::exp10_i128(n);
         } else if constexpr (std::is_same_v<DecimalPrimitiveType, 
Decimal128V3>) {
             return common::exp10_i128(n);
+        } else if constexpr (std::is_same_v<DecimalPrimitiveType, Decimal256>) 
{
+            return common::exp10_i256(n);
         } else {
             static_assert(!sizeof(DecimalPrimitiveType),
                           "All types must be matched with if constexpr.");
diff --git a/be/src/vec/functions/round.h b/be/src/vec/functions/round.h
index a17865914c4..5d0b0f05159 100644
--- a/be/src/vec/functions/round.h
+++ b/be/src/vec/functions/round.h
@@ -32,6 +32,7 @@
 #include "vec/core/types.h"
 #include "vec/data_types/data_type.h"
 #include "vec/data_types/data_type_nullable.h"
+#include "vec/exec/format/format_common.h"
 #include "vec/functions/function.h"
 #if defined(__SSE4_1__) || defined(__aarch64__)
 #include "util/sse_util.hpp"
@@ -39,6 +40,7 @@
 #include <fenv.h>
 #endif
 #include <algorithm>
+#include <type_traits>
 
 #include "vec/columns/column.h"
 #include "vec/columns/column_decimal.h"
@@ -74,7 +76,7 @@ enum class TieBreakingMode {
 };
 
 template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode,
-          TieBreakingMode tie_breaking_mode>
+          TieBreakingMode tie_breaking_mode, typename U>
 struct IntegerRoundingComputation {
     static const size_t data_count = 1;
 
@@ -126,7 +128,7 @@ struct IntegerRoundingComputation {
         __builtin_unreachable();
     }
 
-    static ALWAYS_INLINE T compute(T x, T scale, size_t target_scale) {
+    static ALWAYS_INLINE T compute(T x, T scale, T target_scale) {
         switch (scale_mode) {
         case ScaleMode::Zero:
         case ScaleMode::Positive:
@@ -138,10 +140,10 @@ struct IntegerRoundingComputation {
         __builtin_unreachable();
     }
 
-    static ALWAYS_INLINE void compute(const T* __restrict in, size_t scale, T* 
__restrict out,
-                                      size_t target_scale) {
+    static ALWAYS_INLINE void compute(const T* __restrict in, U scale, T* 
__restrict out,
+                                      U target_scale) {
         if constexpr (sizeof(T) <= sizeof(scale) && scale_mode == 
ScaleMode::Negative) {
-            if (scale > size_t(std::numeric_limits<T>::max())) {
+            if (scale >= std::numeric_limits<T>::max()) {
                 *out = 0;
                 return;
             }
@@ -155,7 +157,7 @@ class DecimalRoundingImpl {
 private:
     using NativeType = typename T::NativeType;
     using Op = IntegerRoundingComputation<NativeType, rounding_mode, 
ScaleMode::Negative,
-                                          tie_breaking_mode>;
+                                          tie_breaking_mode, NativeType>;
     using Container = typename ColumnDecimal<T>::Container;
 
 public:
@@ -163,15 +165,16 @@ public:
                                 Int16 out_scale) {
         Int16 scale_arg = in_scale - out_scale;
         if (scale_arg > 0) {
-            size_t scale = int_exp10(scale_arg);
+            auto scale = DecimalScaleParams::get_scale_factor<T>(scale_arg);
 
             const NativeType* __restrict p_in = reinterpret_cast<const 
NativeType*>(in.data());
             const NativeType* end_in = reinterpret_cast<const 
NativeType*>(in.data()) + in.size();
             NativeType* __restrict p_out = 
reinterpret_cast<NativeType*>(out.data());
 
             if (out_scale < 0) {
+                auto negative_scale = 
DecimalScaleParams::get_scale_factor<T>(-out_scale);
                 while (p_in < end_in) {
-                    Op::compute(p_in, scale, p_out, int_exp10(-out_scale));
+                    Op::compute(p_in, scale, p_out, negative_scale);
                     ++p_in;
                     ++p_out;
                 }
@@ -191,9 +194,10 @@ public:
                                 Int16 out_scale) {
         Int16 scale_arg = in_scale - out_scale;
         if (scale_arg > 0) {
-            size_t scale = int_exp10(scale_arg);
+            auto scale = DecimalScaleParams::get_scale_factor<T>(scale_arg);
             if (out_scale < 0) {
-                Op::compute(&in, scale, &out, int_exp10(-out_scale));
+                auto negative_scale = 
DecimalScaleParams::get_scale_factor<T>(-out_scale);
+                Op::compute(&in, scale, &out, negative_scale);
             } else {
                 Op::compute(&in, scale, &out, 1);
             }
@@ -350,7 +354,7 @@ template <typename T, RoundingMode rounding_mode, ScaleMode 
scale_mode,
           TieBreakingMode tie_breaking_mode>
 struct IntegerRoundingImpl {
 private:
-    using Op = IntegerRoundingComputation<T, rounding_mode, scale_mode, 
tie_breaking_mode>;
+    using Op = IntegerRoundingComputation<T, rounding_mode, scale_mode, 
tie_breaking_mode, size_t>;
     using Container = typename ColumnVector<T>::Container;
 
 public:
diff --git a/regression-test/data/nereids_p0/sql_functions/math_functions/test_round_overflow.out b/regression-test/data/nereids_p0/sql_functions/math_functions/test_round_overflow.out
new file mode 100644
index 00000000000..a18cc094872
--- /dev/null
+++ b/regression-test/data/nereids_p0/sql_functions/math_functions/test_round_overflow.out
@@ -0,0 +1,10 @@
+-- This file is automatically generated. You should know what you did if you 
want to edit this
+-- !select1 --
+186
+
+-- !select2 --
+0
+
+-- !select3 --
+20000000000000000000000
+
diff --git a/regression-test/suites/nereids_p0/sql_functions/math_functions/test_round_overflow.groovy b/regression-test/suites/nereids_p0/sql_functions/math_functions/test_round_overflow.groovy
new file mode 100644
index 00000000000..12311537c27
--- /dev/null
+++ b/regression-test/suites/nereids_p0/sql_functions/math_functions/test_round_overflow.groovy
@@ -0,0 +1,25 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_round_overflow") {
+    sql "SET enable_nereids_planner=true"
+    sql "SET enable_fallback_to_original_planner=false"
+
+    qt_select1 "select 
round(coalesce(186,-33280029.8473323000000000000000000));"
+    qt_select2 "select 
round(coalesce(186,-33280029.8473323000000000000000000), -20);"
+    qt_select3 "select 
round(coalesce(18618500001234567890123.0,-33280029.847332300000000),-22);"
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to