This is an automated email from the ASF dual-hosted git repository.
rui pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 920cfafac [VL] Fix high precision rounding (#6707)
920cfafac is described below
commit 920cfafac58bdcd3cfb7c8112a05f0fb3b976ec9
Author: Arnav Balyan <[email protected]>
AuthorDate: Fri Aug 9 18:57:11 2024 +0530
[VL] Fix high precision rounding (#6707)
---
cpp/velox/operators/functions/Arithmetic.h | 11 +++++++----
.../sql/catalyst/expressions/GlutenMathExpressionsSuite.scala | 3 +++
.../sql/catalyst/expressions/GlutenMathExpressionsSuite.scala | 4 ++++
.../sql/catalyst/expressions/GlutenMathExpressionsSuite.scala | 3 +++
4 files changed, 17 insertions(+), 4 deletions(-)
diff --git a/cpp/velox/operators/functions/Arithmetic.h b/cpp/velox/operators/functions/Arithmetic.h
index 0474e1554..7b4c9ae9d 100644
--- a/cpp/velox/operators/functions/Arithmetic.h
+++ b/cpp/velox/operators/functions/Arithmetic.h
@@ -17,6 +17,7 @@
#include <folly/CPortability.h>
#include <stdint.h>
#include <cmath>
+#include <limits>
#include <type_traits>
namespace gluten {
@@ -38,14 +39,16 @@ struct RoundFunction {
return number;
}
- double factor = std::pow(10, decimals);
+ // Using long double for high precision during intermediate calculations.
+ // TODO: Make this more efficient with Boost to support high arbitrary
precision at runtime.
+ long double factor = std::pow(10.0L, static_cast<long double>(decimals));
static const TNum kInf = std::numeric_limits<TNum>::infinity();
+
if (number < 0) {
- return (std::round(std::nextafter(number, -kInf) * factor * -1) /
factor) * -1;
+ return static_cast<TNum>((std::round(std::nextafter(number, -kInf) *
factor * -1) / factor) * -1);
}
- return std::round(std::nextafter(number, kInf) * factor) / factor;
+ return static_cast<TNum>(std::round(std::nextafter(number, kInf) * factor)
/ factor);
}
-
template <typename TInput>
FOLLY_ALWAYS_INLINE void call(TInput& result, const TInput& a, const int32_t
b = 0) {
result = round(a, b);
diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala
index 54583547d..765a64f91 100644
--- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala
+++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala
@@ -121,6 +121,9 @@ class GlutenMathExpressionsSuite extends MathExpressionsSuite with GlutenTestsTr
checkEvaluation(Round(-3.5, 0), -4.0)
checkEvaluation(Round(-0.35, 1), -0.4)
checkEvaluation(Round(-35, -1), -40)
+ checkEvaluation(Round(1.12345678901234567, 8), 1.12345679)
+ checkEvaluation(Round(-0.98765432109876543, 5), -0.98765)
+ checkEvaluation(Round(12345.67890123456789, 6), 12345.678901)
checkEvaluation(BRound(2.5, 0), 2.0)
checkEvaluation(BRound(3.5, 0), 4.0)
checkEvaluation(BRound(-2.5, 0), -2.0)
diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala
index a60f0dce6..122f8dc06 100644
--- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala
+++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala
@@ -249,6 +249,10 @@ class GlutenMathExpressionsSuite extends MathExpressionsSuite with GlutenTestsTr
checkEvaluation(Round(-3.5, 0), -4.0)
checkEvaluation(Round(-0.35, 1), -0.4)
checkEvaluation(Round(-35, -1), -40)
+ checkEvaluation(Round(1.12345678901234567, 8), 1.12345679)
+ checkEvaluation(Round(-0.98765432109876543, 5), -0.98765)
+ checkEvaluation(Round(12345.67890123456789, 6), 12345.678901)
+ checkEvaluation(Round(-35, -1), -40)
checkEvaluation(Round(BigDecimal("45.00"), -1), BigDecimal(50))
checkEvaluation(BRound(2.5, 0), 2.0)
checkEvaluation(BRound(3.5, 0), 4.0)
diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala
index e22092488..7308352e4 100644
--- a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala
+++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala
@@ -248,6 +248,9 @@ class GlutenMathExpressionsSuite extends MathExpressionsSuite with GlutenTestsTr
checkEvaluation(BRound(-3.5, 0), -4.0)
checkEvaluation(BRound(-0.35, 1), -0.4)
checkEvaluation(BRound(-35, -1), -40)
+ checkEvaluation(Round(1.12345678901234567, 8), 1.12345679)
+ checkEvaluation(Round(-0.98765432109876543, 5), -0.98765)
+ checkEvaluation(Round(12345.67890123456789, 6), 12345.678901)
checkEvaluation(BRound(BigDecimal("45.00"), -1), BigDecimal(40))
checkEvaluation(checkDataTypeAndCast(RoundFloor(Literal(2.5),
Literal(0))), Decimal(2))
checkEvaluation(checkDataTypeAndCast(RoundFloor(Literal(3.5),
Literal(0))), Decimal(3))
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]