This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 838954e5080 [SPARK-41554] fix changing of Decimal scale when scale decreased by more than 18
838954e5080 is described below
commit 838954e50807e583ceb8317877710d58acff0a4b
Author: oleksii.diagiliev <[email protected]>
AuthorDate: Fri Dec 30 15:52:05 2022 +0800
[SPARK-41554] fix changing of Decimal scale when scale decreased by more than 18
### What changes were proposed in this pull request?
Fix scaling of a `Decimal` that is stored internally as a compact long when the
scale is decreased by more than 18. For example,
```
Decimal(1, 38, 19).changePrecision(38, 0)
```
produces an exception
```
java.lang.ArrayIndexOutOfBoundsException: 19
at org.apache.spark.sql.types.Decimal.changePrecision(Decimal.scala:377)
at org.apache.spark.sql.types.Decimal.changePrecision(Decimal.scala:328)
```
Another way to reproduce it is with a SQL query (the final cast decreases the
scale from 37 to 17, i.e. by 20):
```
sql("select cast(cast(cast(cast(id as decimal(38,15)) as decimal(38,30)) as
decimal(38,37)) as decimal(38,17)) from range(3)").show
```
The bug exists only for a `Decimal` that is stored using a compact long; it
works fine with a `Decimal` that uses `scala.math.BigDecimal` internally.
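With the fix, when the scale difference exceeds `MAX_LONG_DIGITS` (18), every
digit stored in the compact long is dropped, so the value collapses to 0 with at
most a ±1 adjustment for the directed rounding modes; since a compact long holds
at most 18 significant digits, the dropped fraction is always below 0.5 and both
HALF modes yield 0. A minimal sketch of that rule (`roundWhenAllDigitsDropped`
is a hypothetical helper, written against `scala.math.BigDecimal.RoundingMode`
rather than Decimal's `ROUND_*` aliases):
```
import scala.math.BigDecimal.RoundingMode

// Rounding when the scale decrease exceeds the 18 digits a compact Long holds:
// all stored digits are discarded, so only the sign can still matter.
def roundWhenAllDigitsDropped(lv: Long, mode: RoundingMode.Value): Long =
  mode match {
    case RoundingMode.FLOOR   => if (lv < 0) -1L else 0L // toward negative infinity
    case RoundingMode.CEILING => if (lv > 0) 1L else 0L  // toward positive infinity
    // |dropped fraction| < 10^18 / 10^19 = 0.1 < 0.5, so both HALF modes give 0
    case RoundingMode.HALF_UP | RoundingMode.HALF_EVEN => 0L
    case other => throw new IllegalArgumentException(s"unsupported mode: $other")
  }
```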
### Why are the changes needed?
Without the fix, the SQL query mentioned above cannot be executed. Please note
that for my use case the SQL query is generated programmatically, so I cannot
optimize it manually.
### Does this PR introduce _any_ user-facing change?
Yes, it allows scaling a `Decimal` properly, which is currently not possible
due to the exception.
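For example, a hypothetical spark-shell check using the repro from above:
```
import org.apache.spark.sql.types.Decimal

val d = Decimal(1L, 38, 19)  // 0.0000000000000000001, stored as a compact long
d.changePrecision(38, 0)     // previously threw ArrayIndexOutOfBoundsException;
                             // now returns true (default HALF_UP rounding)
assert(d.toBigDecimal == BigDecimal(0)) // all 19 fractional digits round away
```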
### How was this patch tested?
Tests were added. The fix affects scale decrease only, but I decided to also
include tests for scale increase, as I didn't find any existing ones.
Closes #39099 from fe2s/SPARK-41554-fix-decimal-scaling.
Authored-by: oleksii.diagiliev <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../scala/org/apache/spark/sql/types/Decimal.scala | 60 +++++++++++++---------
.../org/apache/spark/sql/types/DecimalSuite.scala | 52 ++++++++++++++++++-
2 files changed, 87 insertions(+), 25 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 44c00df379f..2c0b6677541 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -374,30 +374,42 @@ final class Decimal extends Ordered[Decimal] with Serializable {
     if (scale < _scale) {
       // Easier case: we just need to divide our scale down
       val diff = _scale - scale
-      val pow10diff = POW_10(diff)
-      // % and / always round to 0
-      val droppedDigits = lv % pow10diff
-      lv /= pow10diff
-      roundMode match {
-        case ROUND_FLOOR =>
-          if (droppedDigits < 0) {
-            lv += -1L
-          }
-        case ROUND_CEILING =>
-          if (droppedDigits > 0) {
-            lv += 1L
-          }
-        case ROUND_HALF_UP =>
-          if (math.abs(droppedDigits) * 2 >= pow10diff) {
-            lv += (if (droppedDigits < 0) -1L else 1L)
-          }
-        case ROUND_HALF_EVEN =>
-          val doubled = math.abs(droppedDigits) * 2
-          if (doubled > pow10diff || doubled == pow10diff && lv % 2 != 0) {
-            lv += (if (droppedDigits < 0) -1L else 1L)
-          }
-        case _ =>
-          throw QueryExecutionErrors.unsupportedRoundingMode(roundMode)
+      // If diff is greater than max number of digits we store in Long, then
+      // value becomes 0. Otherwise we calculate new value dividing by power of 10.
+      // In both cases we apply rounding after that.
+      if (diff > MAX_LONG_DIGITS) {
+        lv = roundMode match {
+          case ROUND_FLOOR => if (lv < 0) -1L else 0L
+          case ROUND_CEILING => if (lv > 0) 1L else 0L
+          case ROUND_HALF_UP | ROUND_HALF_EVEN => 0L
+          case _ => throw QueryExecutionErrors.unsupportedRoundingMode(roundMode)
+        }
+      } else {
+        val pow10diff = POW_10(diff)
+        // % and / always round to 0
+        val droppedDigits = lv % pow10diff
+        lv /= pow10diff
+        roundMode match {
+          case ROUND_FLOOR =>
+            if (droppedDigits < 0) {
+              lv += -1L
+            }
+          case ROUND_CEILING =>
+            if (droppedDigits > 0) {
+              lv += 1L
+            }
+          case ROUND_HALF_UP =>
+            if (math.abs(droppedDigits) * 2 >= pow10diff) {
+              lv += (if (droppedDigits < 0) -1L else 1L)
+            }
+          case ROUND_HALF_EVEN =>
+            val doubled = math.abs(droppedDigits) * 2
+            if (doubled > pow10diff || doubled == pow10diff && lv % 2 != 0) {
+              lv += (if (droppedDigits < 0) -1L else 1L)
+            }
+          case _ =>
+            throw QueryExecutionErrors.unsupportedRoundingMode(roundMode)
+        }
       }
     } else if (scale > _scale) {
       // We might be able to multiply lv by a power of 10 and not overflow, but if not,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
index 89b6cc0b0cb..73944d9dff9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
@@ -27,6 +27,9 @@ import org.apache.spark.sql.types.Decimal._
 import org.apache.spark.unsafe.types.UTF8String
 
 class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper {
+
+  val allSupportedRoundModes = Seq(ROUND_HALF_UP, ROUND_HALF_EVEN, ROUND_CEILING, ROUND_FLOOR)
+
   /** Check that a Decimal has the given string representation, precision and scale */
   private def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = {
     assert(d.toString === string)
@@ -278,7 +281,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper
   }
 
   test("changePrecision/toPrecision on compact decimal should respect rounding mode") {
-    Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode =>
+    allSupportedRoundModes.foreach { mode =>
       Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n =>
         Seq("", "-").foreach { sign =>
           val bd = BigDecimal(sign + n)
@@ -384,4 +387,51 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper
       }
     }
   }
+
+  // 18 is a max number of digits in Decimal's compact long
+  test("SPARK-41554: decrease/increase scale by 18 and more on compact decimal") {
+    val unscaledNums = Seq(
+      0L, 1L, 10L, 51L, 123L, 523L,
+      // 18 digits
+      912345678901234567L,
+      112345678901234567L,
+      512345678901234567L
+    )
+    val precision = 38
+    // generate some (from, to) scale pairs, e.g. (38, 18), (-20, -2), etc
+    val scalePairs = for {
+      scale <- Seq(38, 20, 19, 18)
+      delta <- Seq(38, 20, 19, 18)
+      a = scale
+      b = scale - delta
+    } yield {
+      Seq((a, b), (-a, -b), (b, a), (-b, -a))
+    }
+
+    for {
+      unscaled <- unscaledNums
+      mode <- allSupportedRoundModes
+      (scaleFrom, scaleTo) <- scalePairs.flatten
+      sign <- Seq(1L, -1L)
+    } {
+      val unscaledWithSign = unscaled * sign
+      if (scaleFrom < 0 || scaleTo < 0) {
+        withSQLConf(SQLConf.LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED.key -> "true") {
+          checkScaleChange(unscaledWithSign, scaleFrom, scaleTo, mode)
+        }
+      } else {
+        checkScaleChange(unscaledWithSign, scaleFrom, scaleTo, mode)
+      }
+    }
+
+    def checkScaleChange(unscaled: Long, scaleFrom: Int, scaleTo: Int,
+        roundMode: BigDecimal.RoundingMode.Value): Unit = {
+      val decimal = Decimal(unscaled, precision, scaleFrom)
+      checkCompact(decimal, true)
+      decimal.changePrecision(precision, scaleTo, roundMode)
+      val bd = BigDecimal(unscaled, scaleFrom).setScale(scaleTo, roundMode)
+      assert(decimal.toBigDecimal === bd,
+        s"unscaled: $unscaled, scaleFrom: $scaleFrom, scaleTo: $scaleTo, mode: $roundMode")
+    }
+  }
 }
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]