This is an automated email from the ASF dual-hosted git repository.
kxiao pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-2.0 by this push:
new 9605afa39bb [feature](function) round function defaults to rounding
normally #31583 (#32050)
9605afa39bb is described below
commit 9605afa39bb0c0245f983e19e675afc3e0370798
Author: Mryange <[email protected]>
AuthorDate: Wed Mar 13 23:08:38 2024 +0800
[feature](function) round function defaults to rounding normally #31583
(#32050)
---
be/src/vec/functions/round.h | 92 +++++++---------------
.../sql-functions/numeric-functions/round.md | 1 +
.../sql-functions/numeric-functions/round.md | 2 +-
.../data/correctness/test_float_round_up.out | 13 +++
.../data/nereids_function_p0/scalar_function/D.out | 4 +-
.../data/nereids_function_p0/scalar_function/R.out | 4 +-
.../sql_functions/math_functions/test_round.out | 24 +++---
.../suites/correctness/test_float_round_up.groovy | 38 +++++++++
8 files changed, 97 insertions(+), 81 deletions(-)
diff --git a/be/src/vec/functions/round.h b/be/src/vec/functions/round.h
index ae79804022a..30c2b71e841 100644
--- a/be/src/vec/functions/round.h
+++ b/be/src/vec/functions/round.h
@@ -59,7 +59,7 @@ enum class RoundingMode {
};
enum class TieBreakingMode {
- Auto, // use banker's rounding for floating point numbers, round up
otherwise
+ Auto, // use round up
Bankers, // use banker's rounding
};
@@ -178,59 +178,16 @@ public:
}
};
-#if defined(__SSE4_1__) || defined(__aarch64__)
-
-template <typename T>
-class BaseFloatRoundingComputation;
-
-template <>
-class BaseFloatRoundingComputation<Float32> {
-public:
- using ScalarType = Float32;
- using VectorType = __m128;
- static const size_t data_count = 4;
-
- static VectorType load(const ScalarType* in) { return _mm_loadu_ps(in); }
- static VectorType load1(const ScalarType in) { return _mm_load1_ps(&in); }
- static void store(ScalarType* out, VectorType val) { _mm_storeu_ps(out,
val); }
- static VectorType multiply(VectorType val, VectorType scale) { return
_mm_mul_ps(val, scale); }
- static VectorType divide(VectorType val, VectorType scale) { return
_mm_div_ps(val, scale); }
- template <RoundingMode mode>
- static VectorType apply(VectorType val) {
- return _mm_round_ps(val, int(mode));
- }
-
- static VectorType prepare(size_t scale) { return load1(scale); }
-};
-
-template <>
-class BaseFloatRoundingComputation<Float64> {
-public:
- using ScalarType = Float64;
- using VectorType = __m128d;
- static const size_t data_count = 2;
-
- static VectorType load(const ScalarType* in) { return _mm_loadu_pd(in); }
- static VectorType load1(const ScalarType in) { return _mm_load1_pd(&in); }
- static void store(ScalarType* out, VectorType val) { _mm_storeu_pd(out,
val); }
- static VectorType multiply(VectorType val, VectorType scale) { return
_mm_mul_pd(val, scale); }
- static VectorType divide(VectorType val, VectorType scale) { return
_mm_div_pd(val, scale); }
- template <RoundingMode mode>
- static VectorType apply(VectorType val) {
- return _mm_round_pd(val, int(mode));
- }
-
- static VectorType prepare(size_t scale) { return load1(scale); }
-};
-
-#else
-
-/// Implementation for ARM. Not vectorized.
-
+template <TieBreakingMode tie_breaking_mode>
inline float roundWithMode(float x, RoundingMode mode) {
switch (mode) {
- case RoundingMode::Round:
- return nearbyintf(x);
+ case RoundingMode::Round: {
+ if constexpr (tie_breaking_mode == TieBreakingMode::Bankers) {
+ return nearbyintf(x);
+ } else {
+ return roundf(x);
+ }
+ }
case RoundingMode::Floor:
return floorf(x);
case RoundingMode::Ceil:
@@ -243,10 +200,16 @@ inline float roundWithMode(float x, RoundingMode mode) {
__builtin_unreachable();
}
+template <TieBreakingMode tie_breaking_mode>
inline double roundWithMode(double x, RoundingMode mode) {
switch (mode) {
- case RoundingMode::Round:
- return nearbyint(x);
+ case RoundingMode::Round: {
+ if constexpr (tie_breaking_mode == TieBreakingMode::Bankers) {
+ return nearbyint(x);
+ } else {
+ return round(x);
+ }
+ }
case RoundingMode::Floor:
return floor(x);
case RoundingMode::Ceil:
@@ -259,7 +222,7 @@ inline double roundWithMode(double x, RoundingMode mode) {
__builtin_unreachable();
}
-template <typename T>
+template <typename T, TieBreakingMode tie_breaking_mode>
class BaseFloatRoundingComputation {
public:
using ScalarType = T;
@@ -273,19 +236,18 @@ public:
static VectorType divide(VectorType val, VectorType scale) { return val /
scale; }
template <RoundingMode mode>
static VectorType apply(VectorType val) {
- return roundWithMode(val, mode);
+ return roundWithMode<tie_breaking_mode>(val, mode);
}
static VectorType prepare(size_t scale) { return load1(scale); }
};
-#endif
-
/** Implementation of low-level round-off functions for floating-point values.
*/
-template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
-class FloatRoundingComputation : public BaseFloatRoundingComputation<T> {
- using Base = BaseFloatRoundingComputation<T>;
+template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode,
+ TieBreakingMode tie_breaking_mode>
+class FloatRoundingComputation : public BaseFloatRoundingComputation<T,
tie_breaking_mode> {
+ using Base = BaseFloatRoundingComputation<T, tie_breaking_mode>;
public:
static inline void compute(const T* __restrict in, const typename
Base::VectorType& scale,
@@ -312,12 +274,13 @@ public:
/** Implementing high-level rounding functions.
*/
-template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
+template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode,
+ TieBreakingMode tie_breaking_mode>
struct FloatRoundingImpl {
private:
static_assert(!IsDecimalNumber<T>);
- using Op = FloatRoundingComputation<T, rounding_mode, scale_mode>;
+ using Op = FloatRoundingComputation<T, rounding_mode, scale_mode,
tie_breaking_mode>;
using Data = std::array<T, Op::data_count>;
using ColumnType = ColumnVector<T>;
using Container = typename ColumnType::Container;
@@ -433,7 +396,8 @@ struct Dispatcher {
using FunctionRoundingImpl = std::conditional_t<
IsDecimalNumber<T>, DecimalRoundingImpl<T, rounding_mode,
tie_breaking_mode>,
std::conditional_t<
- std::is_floating_point_v<T>, FloatRoundingImpl<T,
rounding_mode, scale_mode>,
+ std::is_floating_point_v<T>,
+ FloatRoundingImpl<T, rounding_mode, scale_mode,
tie_breaking_mode>,
IntegerRoundingImpl<T, rounding_mode, scale_mode,
tie_breaking_mode>>>;
static ColumnPtr apply(const IColumn* col_general, Int16 scale_arg) {
diff --git a/docs/en/docs/sql-manual/sql-functions/numeric-functions/round.md
b/docs/en/docs/sql-manual/sql-functions/numeric-functions/round.md
index ad574d3ff27..f34519acb60 100644
--- a/docs/en/docs/sql-manual/sql-functions/numeric-functions/round.md
+++ b/docs/en/docs/sql-manual/sql-functions/numeric-functions/round.md
@@ -29,6 +29,7 @@ under the License.
`T round(T x[, d])`
Rounds the argument `x` to `d` decimal places. `d` defaults to 0 if not specified. If d is negative, the left d digits of the decimal point are 0. If x or d is null, null is returned.
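+By default, halfway values are rounded up: 2.5 rounds to 3. To round halfway values to the nearest even number instead (banker's rounding, so 2.5 rounds to 2 but 3.5 rounds to 4), use the round_bankers function.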
### example
diff --git
a/docs/zh-CN/docs/sql-manual/sql-functions/numeric-functions/round.md
b/docs/zh-CN/docs/sql-manual/sql-functions/numeric-functions/round.md
index 8909f7c9d11..576a96e1d8d 100644
--- a/docs/zh-CN/docs/sql-manual/sql-functions/numeric-functions/round.md
+++ b/docs/zh-CN/docs/sql-manual/sql-functions/numeric-functions/round.md
@@ -29,7 +29,7 @@ under the License.
`T round(T x[, d])`
将`x`四舍五入后保留d位小数,d默认为0。如果d为负数,则小数点左边d位为0。如果x或d为null,返回null。
-
+默认情况下,恰好处于中间的值会向上舍入,例如2.5舍入到3。如果希望将中间值舍入到最近的偶数(银行家舍入,例如2.5舍入到2、3.5舍入到4),请使用round_bankers函数。
### example
```
diff --git a/regression-test/data/correctness/test_float_round_up.out
b/regression-test/data/correctness/test_float_round_up.out
new file mode 100644
index 00000000000..812d0826ccd
--- /dev/null
+++ b/regression-test/data/correctness/test_float_round_up.out
@@ -0,0 +1,13 @@
+-- This file is automatically generated. You should know what you did if you
want to edit this
+-- !select --
+2.5 3.0
+
+-- !select --
+3.5 4.0
+
+-- !select --
+2.5 2.0
+
+-- !select --
+3.5 4.0
+
diff --git a/regression-test/data/nereids_function_p0/scalar_function/D.out
b/regression-test/data/nereids_function_p0/scalar_function/D.out
index 7814c1c6f99..18d46ce727c 100644
--- a/regression-test/data/nereids_function_p0/scalar_function/D.out
+++ b/regression-test/data/nereids_function_p0/scalar_function/D.out
@@ -2876,7 +2876,7 @@ Monday
0.0
0.0
0.0
-0.0
+1.0
1.0
1.0
1.0
@@ -2890,7 +2890,7 @@ Monday
0.0
0.0
0.0
-0.0
+1.0
1.0
1.0
1.0
diff --git a/regression-test/data/nereids_function_p0/scalar_function/R.out
b/regression-test/data/nereids_function_p0/scalar_function/R.out
index 34c2514e147..1ad09e5cb6a 100644
--- a/regression-test/data/nereids_function_p0/scalar_function/R.out
+++ b/regression-test/data/nereids_function_p0/scalar_function/R.out
@@ -469,7 +469,7 @@ string3
0.0
0.0
0.0
-0.0
+1.0
1.0
1.0
1.0
@@ -483,7 +483,7 @@ string3
0.0
0.0
0.0
-0.0
+1.0
1.0
1.0
1.0
diff --git
a/regression-test/data/query_p0/sql_functions/math_functions/test_round.out
b/regression-test/data/query_p0/sql_functions/math_functions/test_round.out
index bbc66e5b976..50d15b2843b 100644
--- a/regression-test/data/query_p0/sql_functions/math_functions/test_round.out
+++ b/regression-test/data/query_p0/sql_functions/math_functions/test_round.out
@@ -11,18 +11,6 @@
-- !select --
10.12
--- !select --
-0.0 0.0
-0.5 0.0
-1.0 1.0
-1.5 2.0
-2.0 2.0
-2.5 2.0
-3.0 3.0
-3.5 4.0
-4.0 4.0
-4.5 4.0
-
-- !select --
0.0 0.0
0.5 1.0
@@ -35,6 +23,18 @@
4.0 4.0
4.5 5.0
+-- !select --
+0.0 0
+0.5 1
+1.0 1
+1.5 2
+2.0 2
+2.5 3
+3.0 3
+3.5 4
+4.0 4
+4.5 5
+
-- !truncate --
1 1989 1001 123.1 0.1 6.3
2 1986 1001 1243.5 20.2 789.2
diff --git a/regression-test/suites/correctness/test_float_round_up.groovy
b/regression-test/suites/correctness/test_float_round_up.groovy
new file mode 100644
index 00000000000..26839e115aa
--- /dev/null
+++ b/regression-test/suites/correctness/test_float_round_up.groovy
@@ -0,0 +1,38 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_float_round_up") {
+ sql """ set enable_nereids_planner=true; """
+ sql """ set enable_fallback_to_original_planner=false; """
+
+
+ qt_select """
+ select 5/2 , round(5/2);
+ """
+
+ qt_select """
+ select 7/ 2 , round(7/2);
+ """
+
+ qt_select """
+ select 5/2 , round_bankers(5/2);
+ """
+
+ qt_select """
+ select 7/ 2 , round_bankers(7/2);
+ """
+}
\ No newline at end of file
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]