This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit cca077326169e0034d320df0eddb89a28a330551
Author: Mryange <[email protected]>
AuthorDate: Tue Mar 5 10:42:21 2024 +0800

    [feature](function) round function defaults to rounding normally
---
 be/src/vec/functions/round.h                       | 92 +++++++---------------
 .../sql-functions/numeric-functions/round.md       |  1 +
 .../sql-functions/numeric-functions/round.md       |  2 +-
 .../data/correctness/test_float_round_up.out       | 13 +++
 .../data/nereids_function_p0/scalar_function/D.out |  4 +-
 .../data/nereids_function_p0/scalar_function/R.out |  4 +-
 .../sql_functions/math_functions/test_round.out    | 24 +++---
 .../issuesWithTheMostComments3.out                 |  2 +-
 .../sql/issuesWithTheMostComments3.out             |  2 +-
 .../suites/correctness/test_float_round_up.groovy  | 38 +++++++++
 10 files changed, 99 insertions(+), 83 deletions(-)

diff --git a/be/src/vec/functions/round.h b/be/src/vec/functions/round.h
index 3c6dc1d441d..7e48b8e9306 100644
--- a/be/src/vec/functions/round.h
+++ b/be/src/vec/functions/round.h
@@ -59,7 +59,7 @@ enum class RoundingMode {
 };
 
 enum class TieBreakingMode {
-    Auto,    // use banker's rounding for floating point numbers, round up 
otherwise
+    Auto,    // round half away from zero (std::round for floats: 2.5 -> 3, -2.5 -> -3)
     Bankers, // use banker's rounding (round half to even via nearbyint: 2.5 -> 2, 3.5 -> 4)
 };
 
@@ -178,59 +178,16 @@ public:
     }
 };
 
-#if defined(__SSE4_1__) || defined(__aarch64__)
-
-template <typename T>
-class BaseFloatRoundingComputation;
-
-template <>
-class BaseFloatRoundingComputation<Float32> {
-public:
-    using ScalarType = Float32;
-    using VectorType = __m128;
-    static const size_t data_count = 4;
-
-    static VectorType load(const ScalarType* in) { return _mm_loadu_ps(in); }
-    static VectorType load1(const ScalarType in) { return _mm_load1_ps(&in); }
-    static void store(ScalarType* out, VectorType val) { _mm_storeu_ps(out, 
val); }
-    static VectorType multiply(VectorType val, VectorType scale) { return 
_mm_mul_ps(val, scale); }
-    static VectorType divide(VectorType val, VectorType scale) { return 
_mm_div_ps(val, scale); }
-    template <RoundingMode mode>
-    static VectorType apply(VectorType val) {
-        return _mm_round_ps(val, int(mode));
-    }
-
-    static VectorType prepare(size_t scale) { return load1(scale); }
-};
-
-template <>
-class BaseFloatRoundingComputation<Float64> {
-public:
-    using ScalarType = Float64;
-    using VectorType = __m128d;
-    static const size_t data_count = 2;
-
-    static VectorType load(const ScalarType* in) { return _mm_loadu_pd(in); }
-    static VectorType load1(const ScalarType in) { return _mm_load1_pd(&in); }
-    static void store(ScalarType* out, VectorType val) { _mm_storeu_pd(out, 
val); }
-    static VectorType multiply(VectorType val, VectorType scale) { return 
_mm_mul_pd(val, scale); }
-    static VectorType divide(VectorType val, VectorType scale) { return 
_mm_div_pd(val, scale); }
-    template <RoundingMode mode>
-    static VectorType apply(VectorType val) {
-        return _mm_round_pd(val, int(mode));
-    }
-
-    static VectorType prepare(size_t scale) { return load1(scale); }
-};
-
-#else
-
-/// Implementation for ARM. Not vectorized.
-
+template <TieBreakingMode tie_breaking_mode>
 inline float roundWithMode(float x, RoundingMode mode) {
     switch (mode) {
-    case RoundingMode::Round:
-        return nearbyintf(x);
+    case RoundingMode::Round: {
+        if constexpr (tie_breaking_mode == TieBreakingMode::Bankers) {
+            return nearbyintf(x);
+        } else {
+            return roundf(x);
+        }
+    }
     case RoundingMode::Floor:
         return floorf(x);
     case RoundingMode::Ceil:
@@ -243,10 +200,16 @@ inline float roundWithMode(float x, RoundingMode mode) {
     __builtin_unreachable();
 }
 
+template <TieBreakingMode tie_breaking_mode>
 inline double roundWithMode(double x, RoundingMode mode) {
     switch (mode) {
-    case RoundingMode::Round:
-        return nearbyint(x);
+    case RoundingMode::Round: {
+        if constexpr (tie_breaking_mode == TieBreakingMode::Bankers) {
+            return nearbyint(x);
+        } else {
+            return round(x);
+        }
+    }
     case RoundingMode::Floor:
         return floor(x);
     case RoundingMode::Ceil:
@@ -259,7 +222,7 @@ inline double roundWithMode(double x, RoundingMode mode) {
     __builtin_unreachable();
 }
 
-template <typename T>
+template <typename T, TieBreakingMode tie_breaking_mode>
 class BaseFloatRoundingComputation {
 public:
     using ScalarType = T;
@@ -273,19 +236,18 @@ public:
     static VectorType divide(VectorType val, VectorType scale) { return val / 
scale; }
     template <RoundingMode mode>
     static VectorType apply(VectorType val) {
-        return roundWithMode(val, mode);
+        return roundWithMode<tie_breaking_mode>(val, mode);
     }
 
     static VectorType prepare(size_t scale) { return load1(scale); }
 };
 
-#endif
-
 /** Implementation of low-level round-off functions for floating-point values.
   */
-template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
-class FloatRoundingComputation : public BaseFloatRoundingComputation<T> {
-    using Base = BaseFloatRoundingComputation<T>;
+template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode,
+          TieBreakingMode tie_breaking_mode>
+class FloatRoundingComputation : public BaseFloatRoundingComputation<T, 
tie_breaking_mode> {
+    using Base = BaseFloatRoundingComputation<T, tie_breaking_mode>;
 
 public:
     static inline void compute(const T* __restrict in, const typename 
Base::VectorType& scale,
@@ -312,12 +274,13 @@ public:
 
 /** Implementing high-level rounding functions.
   */
-template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
+template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode,
+          TieBreakingMode tie_breaking_mode>
 struct FloatRoundingImpl {
 private:
     static_assert(!IsDecimalNumber<T>);
 
-    using Op = FloatRoundingComputation<T, rounding_mode, scale_mode>;
+    using Op = FloatRoundingComputation<T, rounding_mode, scale_mode, 
tie_breaking_mode>;
     using Data = std::array<T, Op::data_count>;
     using ColumnType = ColumnVector<T>;
     using Container = typename ColumnType::Container;
@@ -433,7 +396,8 @@ struct Dispatcher {
     using FunctionRoundingImpl = std::conditional_t<
             IsDecimalNumber<T>, DecimalRoundingImpl<T, rounding_mode, 
tie_breaking_mode>,
             std::conditional_t<
-                    std::is_floating_point_v<T>, FloatRoundingImpl<T, 
rounding_mode, scale_mode>,
+                    std::is_floating_point_v<T>,
+                    FloatRoundingImpl<T, rounding_mode, scale_mode, 
tie_breaking_mode>,
                     IntegerRoundingImpl<T, rounding_mode, scale_mode, 
tie_breaking_mode>>>;
 
     static ColumnPtr apply(const IColumn* col_general, Int16 scale_arg) {
diff --git a/docs/en/docs/sql-manual/sql-functions/numeric-functions/round.md 
b/docs/en/docs/sql-manual/sql-functions/numeric-functions/round.md
index ad574d3ff27..f34519acb60 100644
--- a/docs/en/docs/sql-manual/sql-functions/numeric-functions/round.md
+++ b/docs/en/docs/sql-manual/sql-functions/numeric-functions/round.md
@@ -29,6 +29,7 @@ under the License.
 
 `T round(T x[, d])`
 Rounds the argument `x` to `d` decimal places. `d` defaults to 0 if not 
specified. If d is negative, the left d digits of the decimal point are 0. If x 
or d is null, null is returned.
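+Values exactly halfway between two candidates are rounded away from zero; for example, 2.5 rounds to 3. To round halfway values to the nearest even number instead (banker's rounding: 2.5 rounds to 2, 3.5 rounds to 4), use the round_bankers function.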
 
 ### example
 
diff --git 
a/docs/zh-CN/docs/sql-manual/sql-functions/numeric-functions/round.md 
b/docs/zh-CN/docs/sql-manual/sql-functions/numeric-functions/round.md
index 8909f7c9d11..576a96e1d8d 100644
--- a/docs/zh-CN/docs/sql-manual/sql-functions/numeric-functions/round.md
+++ b/docs/zh-CN/docs/sql-manual/sql-functions/numeric-functions/round.md
@@ -29,7 +29,7 @@ under the License.
 
 `T round(T x[, d])`
 将`x`四舍五入后保留d位小数,d默认为0。如果d为负数,则小数点左边d位为0。如果x或d为null,返回null。
-
+恰好处于中间的值会向远离零的方向舍入,例如2.5舍入为3。如果希望中间值舍入到最近的偶数(银行家舍入:2.5舍入为2,3.5舍入为4),请使用round_bankers函数。
 ### example
 
 ```
diff --git a/regression-test/data/correctness/test_float_round_up.out 
b/regression-test/data/correctness/test_float_round_up.out
new file mode 100644
index 00000000000..812d0826ccd
--- /dev/null
+++ b/regression-test/data/correctness/test_float_round_up.out
@@ -0,0 +1,13 @@
+-- This file is automatically generated. You should know what you did if you 
want to edit this
+-- !select --
+2.5    3.0
+
+-- !select --
+3.5    4.0
+
+-- !select --
+2.5    2.0
+
+-- !select --
+3.5    4.0
+
diff --git a/regression-test/data/nereids_function_p0/scalar_function/D.out 
b/regression-test/data/nereids_function_p0/scalar_function/D.out
index ec685e5982b..93d3090d109 100644
--- a/regression-test/data/nereids_function_p0/scalar_function/D.out
+++ b/regression-test/data/nereids_function_p0/scalar_function/D.out
@@ -2934,7 +2934,7 @@ Monday
 0.0
 0.0
 0.0
-0.0
+1.0
 1.0
 1.0
 1.0
@@ -2948,7 +2948,7 @@ Monday
 0.0
 0.0
 0.0
-0.0
+1.0
 1.0
 1.0
 1.0
diff --git a/regression-test/data/nereids_function_p0/scalar_function/R.out 
b/regression-test/data/nereids_function_p0/scalar_function/R.out
index 7b8252223ad..df1e89de714 100644
--- a/regression-test/data/nereids_function_p0/scalar_function/R.out
+++ b/regression-test/data/nereids_function_p0/scalar_function/R.out
@@ -469,7 +469,7 @@ string3
 0.0
 0.0
 0.0
-0.0
+1.0
 1.0
 1.0
 1.0
@@ -483,7 +483,7 @@ string3
 0.0
 0.0
 0.0
-0.0
+1.0
 1.0
 1.0
 1.0
diff --git 
a/regression-test/data/query_p0/sql_functions/math_functions/test_round.out 
b/regression-test/data/query_p0/sql_functions/math_functions/test_round.out
index bbc66e5b976..50d15b2843b 100644
--- a/regression-test/data/query_p0/sql_functions/math_functions/test_round.out
+++ b/regression-test/data/query_p0/sql_functions/math_functions/test_round.out
@@ -11,18 +11,6 @@
 -- !select --
 10.12
 
--- !select --
-0.0    0.0
-0.5    0.0
-1.0    1.0
-1.5    2.0
-2.0    2.0
-2.5    2.0
-3.0    3.0
-3.5    4.0
-4.0    4.0
-4.5    4.0
-
 -- !select --
 0.0    0.0
 0.5    1.0
@@ -35,6 +23,18 @@
 4.0    4.0
 4.5    5.0
 
+-- !select --
+0.0    0
+0.5    1
+1.0    1
+1.5    2
+2.0    2
+2.5    3
+3.0    3
+3.5    4
+4.0    4
+4.5    5
+
 -- !truncate --
 1      1989    1001    123.1   0.1     6.3
 2      1986    1001    1243.5  20.2    789.2
diff --git 
a/regression-test/data/variant_github_events_p0/issuesWithTheMostComments3.out 
b/regression-test/data/variant_github_events_p0/issuesWithTheMostComments3.out
index 9cb3bda773f..f17b6c34936 100644
--- 
a/regression-test/data/variant_github_events_p0/issuesWithTheMostComments3.out
+++ 
b/regression-test/data/variant_github_events_p0/issuesWithTheMostComments3.out
@@ -33,7 +33,7 @@ tipfortip/issues      10      1       10.0
 Mindwerks/wildmidi     9       9       1.0
 NeuroVault/NeuroVault  9       1       9.0
 THE-ESCAPIST/RSSHub    9       7       1.29
-WhisperSystems/TextSecure      9       8       1.12
+WhisperSystems/TextSecure      9       8       1.13
 XLabs/Xamarin-Forms-Labs       9       6       1.5
 aws/eks-distro 9       1       9.0
 disco-trooper/weather-app      9       9       1.0
diff --git 
a/regression-test/data/variant_github_events_p0_new/sql/issuesWithTheMostComments3.out
 
b/regression-test/data/variant_github_events_p0_new/sql/issuesWithTheMostComments3.out
index 9cb3bda773f..f17b6c34936 100644
--- 
a/regression-test/data/variant_github_events_p0_new/sql/issuesWithTheMostComments3.out
+++ 
b/regression-test/data/variant_github_events_p0_new/sql/issuesWithTheMostComments3.out
@@ -33,7 +33,7 @@ tipfortip/issues      10      1       10.0
 Mindwerks/wildmidi     9       9       1.0
 NeuroVault/NeuroVault  9       1       9.0
 THE-ESCAPIST/RSSHub    9       7       1.29
-WhisperSystems/TextSecure      9       8       1.12
+WhisperSystems/TextSecure      9       8       1.13
 XLabs/Xamarin-Forms-Labs       9       6       1.5
 aws/eks-distro 9       1       9.0
 disco-trooper/weather-app      9       9       1.0
diff --git a/regression-test/suites/correctness/test_float_round_up.groovy 
b/regression-test/suites/correctness/test_float_round_up.groovy
new file mode 100644
index 00000000000..26839e115aa
--- /dev/null
+++ b/regression-test/suites/correctness/test_float_round_up.groovy
@@ -0,0 +1,38 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_float_round_up") {
+    sql """ set enable_nereids_planner=true; """
+    sql """ set enable_fallback_to_original_planner=false; """
+
+    // round() breaks ties away from zero (2.5 -> 3, 3.5 -> 4); round_bankers() breaks ties to even (2.5 -> 2, 3.5 -> 4).
+    qt_select """
+        select 5/2 , round(5/2);
+    """
+
+    qt_select """
+        select 7/ 2 ,  round(7/2);
+    """
+
+    qt_select """
+        select 5/2 , round_bankers(5/2);
+    """
+
+    qt_select """
+        select 7/ 2 ,  round_bankers(7/2);
+    """
+}
\ No newline at end of file


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to