This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 2d539c5c702 [SPARK-41554] fix changing of Decimal scale when scale decreased by m…
2d539c5c702 is described below

commit 2d539c5c7022d44d8a2d53e752287c42c2601444
Author: oleksii.diagiliev <oleksii.diagil...@workday.com>
AuthorDate: Fri Feb 3 10:48:56 2023 -0600

    [SPARK-41554] fix changing of Decimal scale when scale decreased by m…

    …ore than 18

    This is a backport PR for https://github.com/apache/spark/pull/39099

    Closes #39813 from fe2s/branch-3.3-fix-decimal-scaling.

    Authored-by: oleksii.diagiliev <oleksii.diagil...@workday.com>
    Signed-off-by: Sean Owen <sro...@gmail.com>
---
 .../scala/org/apache/spark/sql/types/Decimal.scala | 60 +++++++++++++---------
 .../org/apache/spark/sql/types/DecimalSuite.scala  | 53 ++++++++++++++++++-
 2 files changed, 88 insertions(+), 25 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 7a43d01eb2f..07a2c47cff0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -397,30 +397,42 @@ final class Decimal extends Ordered[Decimal] with Serializable {
       if (scale < _scale) {
         // Easier case: we just need to divide our scale down
         val diff = _scale - scale
-        val pow10diff = POW_10(diff)
-        // % and / always round to 0
-        val droppedDigits = longVal % pow10diff
-        longVal /= pow10diff
-        roundMode match {
-          case ROUND_FLOOR =>
-            if (droppedDigits < 0) {
-              longVal += -1L
-            }
-          case ROUND_CEILING =>
-            if (droppedDigits > 0) {
-              longVal += 1L
-            }
-          case ROUND_HALF_UP =>
-            if (math.abs(droppedDigits) * 2 >= pow10diff) {
-              longVal += (if (droppedDigits < 0) -1L else 1L)
-            }
-          case ROUND_HALF_EVEN =>
-            val doubled = math.abs(droppedDigits) * 2
-            if (doubled > pow10diff || doubled == pow10diff && longVal % 2 != 0) {
-              longVal += (if (droppedDigits < 0) -1L else 1L)
-            }
-          case _ =>
-            throw QueryExecutionErrors.unsupportedRoundingMode(roundMode)
+        // If diff is greater than max number of digits we store in Long, then
+        // value becomes 0. Otherwise we calculate new value dividing by power of 10.
+        // In both cases we apply rounding after that.
+        if (diff > MAX_LONG_DIGITS) {
+          longVal = roundMode match {
+            case ROUND_FLOOR => if (longVal < 0) -1L else 0L
+            case ROUND_CEILING => if (longVal > 0) 1L else 0L
+            case ROUND_HALF_UP | ROUND_HALF_EVEN => 0L
+            case _ => sys.error(s"Not supported rounding mode: $roundMode")
+          }
+        } else {
+          val pow10diff = POW_10(diff)
+          // % and / always round to 0
+          val droppedDigits = longVal % pow10diff
+          longVal /= pow10diff
+          roundMode match {
+            case ROUND_FLOOR =>
+              if (droppedDigits < 0) {
+                longVal += -1L
+              }
+            case ROUND_CEILING =>
+              if (droppedDigits > 0) {
+                longVal += 1L
+              }
+            case ROUND_HALF_UP =>
+              if (math.abs(droppedDigits) * 2 >= pow10diff) {
+                longVal += (if (droppedDigits < 0) -1L else 1L)
+              }
+            case ROUND_HALF_EVEN =>
+              val doubled = math.abs(droppedDigits) * 2
+              if (doubled > pow10diff || doubled == pow10diff && longVal % 2 != 0) {
+                longVal += (if (droppedDigits < 0) -1L else 1L)
+              }
+            case _ =>
+              throw QueryExecutionErrors.unsupportedRoundingMode(roundMode)
+          }
+        }
         }
       } else if (scale > _scale) {
         // We might be able to multiply longVal by a power of 10 and not overflow, but if not,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
index 6f70dc51b95..6ccd2b9bd32 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
@@ -27,6 +27,9 @@ import org.apache.spark.sql.types.Decimal._
 import org.apache.spark.unsafe.types.UTF8String

 class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper {
+
+  val allSupportedRoundModes = Seq(ROUND_HALF_UP, ROUND_HALF_EVEN, ROUND_CEILING, ROUND_FLOOR)
+
   /** Check that a Decimal has the given string representation, precision and scale */
   private def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = {
     assert(d.toString === string)
@@ -222,7 +225,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper
   }

   test("changePrecision/toPrecision on compact decimal should respect rounding mode") {
-    Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode =>
+    allSupportedRoundModes.foreach { mode =>
       Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n =>
         Seq("", "-").foreach { sign =>
           val bd = BigDecimal(sign + n)
@@ -315,4 +318,52 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper
       }
     }
   }
+
+  // 18 is a max number of digits in Decimal's compact long
+  test("SPARK-41554: decrease/increase scale by 18 and more on compact decimal") {
+    val unscaledNums = Seq(
+      0L, 1L, 10L, 51L, 123L, 523L,
+      // 18 digits
+      912345678901234567L,
+      112345678901234567L,
+      512345678901234567L
+    )
+    val precision = 38
+    // generate some (from, to) scale pairs, e.g.
+    // (38, 18), (-20, -2), etc
+    val scalePairs = for {
+      scale <- Seq(38, 20, 19, 18)
+      delta <- Seq(38, 20, 19, 18)
+      a = scale
+      b = scale - delta
+    } yield {
+      Seq((a, b), (-a, -b), (b, a), (-b, -a))
+    }
+
+    for {
+      unscaled <- unscaledNums
+      mode <- allSupportedRoundModes
+      (scaleFrom, scaleTo) <- scalePairs.flatten
+      sign <- Seq(1L, -1L)
+    } {
+      val unscaledWithSign = unscaled * sign
+      if (scaleFrom < 0 || scaleTo < 0) {
+        withSQLConf(SQLConf.LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED.key -> "true") {
+          checkScaleChange(unscaledWithSign, scaleFrom, scaleTo, mode)
+        }
+      } else {
+        checkScaleChange(unscaledWithSign, scaleFrom, scaleTo, mode)
+      }
+    }
+
+    def checkScaleChange(unscaled: Long, scaleFrom: Int, scaleTo: Int,
+        roundMode: BigDecimal.RoundingMode.Value): Unit = {
+      val decimal = Decimal(unscaled, precision, scaleFrom)
+      checkCompact(decimal, true)
+      decimal.changePrecision(precision, scaleTo, roundMode)
+      val bd = BigDecimal(unscaled, scaleFrom).setScale(scaleTo, roundMode)
+      assert(decimal.toBigDecimal === bd,
+        s"unscaled: $unscaled, scaleFrom: $scaleFrom, scaleTo: $scaleTo, mode: $roundMode")
+    }
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
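
For readers who want to try the fixed behaviour outside the suite, here is a minimal sketch, not part of the commit, built only from API that already appears in the diff above: Decimal(unscaled, precision, scale), changePrecision(precision, scale, roundMode) and the ROUND_* constants from the Decimal companion object. It declares the org.apache.spark.sql.types package because changePrecision is not a public method (the test suite above lives in that package for the same reason); the object name Spark41554Sketch is made up for illustration.

package org.apache.spark.sql.types

import org.apache.spark.sql.types.Decimal._

object Spark41554Sketch {
  def main(args: Array[String]): Unit = {
    // An 18-digit unscaled value held compactly in a Long, at scale 38
    // (the same literal the new test uses).
    val d = Decimal(912345678901234567L, 38, 38)

    // Reducing the scale from 38 to 18 drops 20 digits, i.e. more than
    // MAX_LONG_DIGITS (18). The removed code looked up POW_10(diff) even for
    // this case; the patched code rounds the compact value straight to zero
    // (FLOOR of a negative value gives -1, CEILING of a positive value gives 1).
    d.changePrecision(38, 18, ROUND_HALF_UP)

    // Same result as plain BigDecimal arithmetic, which is what the new
    // SPARK-41554 test asserts for many (value, scaleFrom, scaleTo) combinations.
    assert(d.toBigDecimal == BigDecimal(912345678901234567L, 38).setScale(18, ROUND_HALF_UP))
    println(d.toBigDecimal)
  }
}

Before the patch, the changePrecision call above reached the POW_10 lookup with an index larger than MAX_LONG_DIGITS; that is the path the new "if (diff > MAX_LONG_DIGITS)" branch handles.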