This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new d36e3e62c72a [SPARK-45786][SQL] Fix inaccurate Decimal multiplication and division results d36e3e62c72a is described below commit d36e3e62c72ae121ebf3404db7c4cc51fe66066b Author: Kazuyuki Tanimura <ktanim...@apple.com> AuthorDate: Tue Nov 7 09:06:00 2023 -0800 [SPARK-45786][SQL] Fix inaccurate Decimal multiplication and division results ### What changes were proposed in this pull request? This PR fixes inaccurate Decimal multiplication and division results. ### Why are the changes needed? Decimal multiplication and division results may be inaccurate due to rounding issues. #### Multiplication: ``` scala> sql("select -14120025096157587712113961295153.858047 * -0.4652").show(truncate=false) +----------------------------------------------------+ |(-14120025096157587712113961295153.858047 * -0.4652)| +----------------------------------------------------+ |6568635674732509803675414794505.574764 | +----------------------------------------------------+ ``` The correct answer is `6568635674732509803675414794505.574763` Please note that the last digit is `3` instead of `4` as ``` scala> java.math.BigDecimal("-14120025096157587712113961295153.858047").multiply(java.math.BigDecimal("-0.4652")) val res21: java.math.BigDecimal = 6568635674732509803675414794505.5747634644 ``` Since the fractional part `.574763` is followed by `4644`, it should not be rounded up. 
#### Division: ``` scala> sql("select -0.172787979 / 533704665545018957788294905796.5").show(truncate=false) +-------------------------------------------------+ |(-0.172787979 / 533704665545018957788294905796.5)| +-------------------------------------------------+ |-3.237521E-31 | +-------------------------------------------------+ ``` The correct answer is `-3.237520E-31` Please note that the last digit is `0` instead of `1` as ``` scala> java.math.BigDecimal("-0.172787979").divide(java.math.BigDecimal("533704665545018957788294905796.5"), 100, java.math.RoundingMode.DOWN) val res22: java.math.BigDecimal = -3.237520489418037889998826491401059986665344697406144511563561222578738E-31 ``` Since the fractional part `.237520` is followed by `4894...`, it should not be rounded up. ### Does this PR introduce _any_ user-facing change? Yes, users will see correct Decimal multiplication and division results. Directly multiplying and dividing with `org.apache.spark.sql.types.Decimal()` (not via SQL) will return 39 digits at maximum instead of 38 at maximum and round down instead of round half-up ### How was this patch tested? Test added ### Was this patch authored or co-authored using generative AI tooling? No Closes #43678 from kazuyukitanimura/SPARK-45786. 
Authored-by: Kazuyuki Tanimura <ktanim...@apple.com> Signed-off-by: Dongjoon Hyun <dh...@apple.com> (cherry picked from commit 5ef3a846f52ab90cb7183953cff3080449d0b57b) Signed-off-by: Dongjoon Hyun <dh...@apple.com> --- .../scala/org/apache/spark/sql/types/Decimal.scala | 8 +- .../expressions/ArithmeticExpressionSuite.scala | 107 +++++++++++++++++++++ .../ansi/decimalArithmeticOperations.sql.out | 14 +-- 3 files changed, 120 insertions(+), 9 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala index afe73635a682..77e9aa06c830 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -499,7 +499,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { def / (that: Decimal): Decimal = if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal, - DecimalType.MAX_SCALE, MATH_CONTEXT.getRoundingMode)) + DecimalType.MAX_SCALE + 1, MATH_CONTEXT.getRoundingMode)) def % (that: Decimal): Decimal = if (that.isZero) null @@ -547,7 +547,11 @@ object Decimal { val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong) - private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION, RoundingMode.HALF_UP) + // SPARK-45786 Using RoundingMode.HALF_UP with MathContext may cause inaccurate SQL results + // because TypeCoercion later rounds again. Instead, always round down and use 1 digit longer + // precision than DecimalType.MAX_PRECISION. Then, TypeCoercion will properly round up/down + // the last extra digit. 
+ private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION + 1, RoundingMode.DOWN) private[sql] val ZERO = Decimal(0) private[sql] val ONE = Decimal(1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index e21793ab506c..568dcd10d116 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import java.math.RoundingMode import java.sql.{Date, Timestamp} import java.time.{Duration, Period} import java.time.temporal.ChronoUnit @@ -225,6 +226,112 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } } + test("SPARK-45786: Decimal multiply, divide, remainder, quot") { + // Some known cases + checkEvaluation( + Multiply( + Literal(Decimal(BigDecimal("-14120025096157587712113961295153.858047"), 38, 6)), + Literal(Decimal(BigDecimal("-0.4652"), 4, 4)) + ), + Decimal(BigDecimal("6568635674732509803675414794505.574763")) + ) + checkEvaluation( + Multiply( + Literal(Decimal(BigDecimal("-240810500742726"), 15, 0)), + Literal(Decimal(BigDecimal("-5677.6988688550027099967697071"), 29, 25)) + ), + Decimal(BigDecimal("1367249507675382200.164877854336665327")) + ) + checkEvaluation( + Divide( + Literal(Decimal(BigDecimal("-0.172787979"), 9, 9)), + Literal(Decimal(BigDecimal("533704665545018957788294905796.5"), 31, 1)) + ), + Decimal(BigDecimal("-3.237520E-31")) + ) + checkEvaluation( + Divide( + Literal(Decimal(BigDecimal("-0.574302343618"), 12, 12)), + Literal(Decimal(BigDecimal("-795826820326278835912868.106"), 27, 3)) + ), + Decimal(BigDecimal("7.21642358550E-25")) + ) + + // Random tests + val rand = scala.util.Random + def makeNum(p: 
Int, s: Int): String = { + val int1 = rand.nextLong() + val int2 = rand.nextLong().abs + val frac1 = rand.nextLong().abs + val frac2 = rand.nextLong().abs + s"$int1$int2".take(p - s + (int1 >>> 63).toInt) + "." + s"$frac1$frac2".take(s) + } + + (0 until 100).foreach { _ => + val p1 = rand.nextInt(38) + 1 // 1 <= p1 <= 38 + val s1 = rand.nextInt(p1 + 1) // 0 <= s1 <= p1 + val p2 = rand.nextInt(38) + 1 + val s2 = rand.nextInt(p2 + 1) + + val n1 = makeNum(p1, s1) + val n2 = makeNum(p2, s2) + + val mulActual = Multiply( + Literal(Decimal(BigDecimal(n1), p1, s1)), + Literal(Decimal(BigDecimal(n2), p2, s2)) + ) + val mulExact = new java.math.BigDecimal(n1).multiply(new java.math.BigDecimal(n2)) + + val divActual = Divide( + Literal(Decimal(BigDecimal(n1), p1, s1)), + Literal(Decimal(BigDecimal(n2), p2, s2)) + ) + val divExact = new java.math.BigDecimal(n1) + .divide(new java.math.BigDecimal(n2), 100, RoundingMode.DOWN) + + val remActual = Remainder( + Literal(Decimal(BigDecimal(n1), p1, s1)), + Literal(Decimal(BigDecimal(n2), p2, s2)) + ) + val remExact = new java.math.BigDecimal(n1).remainder(new java.math.BigDecimal(n2)) + + val quotActual = IntegralDivide( + Literal(Decimal(BigDecimal(n1), p1, s1)), + Literal(Decimal(BigDecimal(n2), p2, s2)) + ) + val quotExact = + new java.math.BigDecimal(n1).divideToIntegralValue(new java.math.BigDecimal(n2)) + + Seq(true, false).foreach { allowPrecLoss => + withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> allowPrecLoss.toString) { + val mulType = Multiply(null, null).resultDecimalType(p1, s1, p2, s2) + val mulResult = Decimal(mulExact.setScale(mulType.scale, RoundingMode.HALF_UP)) + val mulExpected = + if (mulResult.precision > DecimalType.MAX_PRECISION) null else mulResult + checkEvaluation(mulActual, mulExpected) + + val divType = Divide(null, null).resultDecimalType(p1, s1, p2, s2) + val divResult = Decimal(divExact.setScale(divType.scale, RoundingMode.HALF_UP)) + val divExpected = + if (divResult.precision > 
DecimalType.MAX_PRECISION) null else divResult + checkEvaluation(divActual, divExpected) + + val remType = Remainder(null, null).resultDecimalType(p1, s1, p2, s2) + val remResult = Decimal(remExact.setScale(remType.scale, RoundingMode.HALF_UP)) + val remExpected = + if (remResult.precision > DecimalType.MAX_PRECISION) null else remResult + checkEvaluation(remActual, remExpected) + + val quotType = IntegralDivide(null, null).resultDecimalType(p1, s1, p2, s2) + val quotResult = Decimal(quotExact.setScale(quotType.scale, RoundingMode.HALF_UP)) + val quotExpected = + if (quotResult.precision > DecimalType.MAX_PRECISION) null else quotResult + checkEvaluation(quotActual, quotExpected.toLong) + } + } + } + } + private def testDecimalAndDoubleType(testFunc: (Int => Any) => Unit): Unit = { testFunc(_.toDouble) testFunc(Decimal(_)) diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/decimalArithmeticOperations.sql.out index 699c916fd8fd..9593291fae21 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/decimalArithmeticOperations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/decimalArithmeticOperations.sql.out @@ -155,7 +155,7 @@ org.apache.spark.SparkArithmeticException "config" : "\"spark.sql.ansi.enabled\"", "precision" : "38", "scale" : "6", - "value" : "1000000000000000000000000000000000000.00000000000000000000000000000000000000" + "value" : "1000000000000000000000000000000000000.000000000000000000000000000000000000000" }, "queryContext" : [ { "objectType" : "", @@ -204,7 +204,7 @@ org.apache.spark.SparkArithmeticException "config" : "\"spark.sql.ansi.enabled\"", "precision" : "38", "scale" : "6", - "value" : "10123456789012345678901234567890123456.00000000000000000000000000000000000000" + "value" : "10123456789012345678901234567890123456.000000000000000000000000000000000000000" }, "queryContext" : [ { "objectType" : "", @@ -229,7 +229,7 @@ 
org.apache.spark.SparkArithmeticException "config" : "\"spark.sql.ansi.enabled\"", "precision" : "38", "scale" : "6", - "value" : "101234567890123456789012345678901234.56000000000000000000000000000000000000" + "value" : "101234567890123456789012345678901234.560000000000000000000000000000000000000" }, "queryContext" : [ { "objectType" : "", @@ -254,7 +254,7 @@ org.apache.spark.SparkArithmeticException "config" : "\"spark.sql.ansi.enabled\"", "precision" : "38", "scale" : "6", - "value" : "10123456789012345678901234567890123.45600000000000000000000000000000000000" + "value" : "10123456789012345678901234567890123.456000000000000000000000000000000000000" }, "queryContext" : [ { "objectType" : "", @@ -279,7 +279,7 @@ org.apache.spark.SparkArithmeticException "config" : "\"spark.sql.ansi.enabled\"", "precision" : "38", "scale" : "6", - "value" : "1012345678901234567890123456789012.34560000000000000000000000000000000000" + "value" : "1012345678901234567890123456789012.345600000000000000000000000000000000000" }, "queryContext" : [ { "objectType" : "", @@ -304,7 +304,7 @@ org.apache.spark.SparkArithmeticException "config" : "\"spark.sql.ansi.enabled\"", "precision" : "38", "scale" : "6", - "value" : "101234567890123456789012345678901.23456000000000000000000000000000000000" + "value" : "101234567890123456789012345678901.234560000000000000000000000000000000000" }, "queryContext" : [ { "objectType" : "", @@ -337,7 +337,7 @@ org.apache.spark.SparkArithmeticException "config" : "\"spark.sql.ansi.enabled\"", "precision" : "38", "scale" : "6", - "value" : "101234567890123456789012345678901.23456000000000000000000000000000000000" + "value" : "101234567890123456789012345678901.234560000000000000000000000000000000000" }, "queryContext" : [ { "objectType" : "", --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org