This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 838954e5080 [SPARK-41554] fix changing of Decimal scale when scale decreased by more than 18
838954e5080 is described below

commit 838954e50807e583ceb8317877710d58acff0a4b
Author: oleksii.diagiliev <[email protected]>
AuthorDate: Fri Dec 30 15:52:05 2022 +0800

    [SPARK-41554] fix changing of Decimal scale when scale decreased by more than 18
    
    ### What changes were proposed in this pull request?
    Fix `Decimal` scaling for values stored internally as a compact long when the scale is decreased by more than 18. For example,
    ```
    Decimal(1, 38, 19).changePrecision(38, 0)
    ```
    produces an exception
    ```
    java.lang.ArrayIndexOutOfBoundsException: 19
        at org.apache.spark.sql.types.Decimal.changePrecision(Decimal.scala:377)
        at org.apache.spark.sql.types.Decimal.changePrecision(Decimal.scala:328)
    ```
    Another way to reproduce it is with a SQL query:
    ```
    sql("select cast(cast(cast(cast(id as decimal(38,15)) as decimal(38,30)) as 
decimal(38,37)) as decimal(38,17)) from range(3)").show
    ```
    
    The bug exists only for `Decimal` values stored as a compact long; it works fine for `Decimal` values that use `scala.math.BigDecimal` internally.
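    
    For context, here is a minimal standalone sketch of the failure mode. The names `POW_10` and `MAX_LONG_DIGITS` are taken from `Decimal.scala` (see the diff below); the table construction here is an assumption for illustration only.
    ```
    // Sketch: powers of 10 that fit in a Long, indexed 0..MAX_LONG_DIGITS.
    val MAX_LONG_DIGITS = 18
    val POW_10: Array[Long] = Array.iterate(1L, MAX_LONG_DIGITS + 1)(_ * 10)

    val diff = 19 - 0 // _scale - scale when changing scale from 19 to 0
    // POW_10(diff)   // diff = 19 > MAX_LONG_DIGITS => ArrayIndexOutOfBoundsException: 19
    ```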
    
    ### Why are the changes needed?
    The SQL query mentioned above cannot be executed. Note that in my use case the SQL query is generated programmatically, so I cannot optimize it manually.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, it allows `Decimal` values to be rescaled properly, which is currently not possible due to the exception.
    
    ### How was this patch tested?
    Tests were added. The fix affects scale decreases only, but I also included tests for scale increases since I did not find any existing ones.
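    
    For illustration, the added tests compare the compact-long path against plain `scala.math.BigDecimal` as a rounding oracle. A condensed sketch of that style of check (not the actual test code; `changePrecision` with an explicit rounding mode is package-private to `sql`, so this only compiles from that test scope):
    ```
    import scala.math.BigDecimal.RoundingMode

    // An 18-digit compact value at scale 19, i.e. 0.0912345678901234567.
    val d = Decimal(912345678901234567L, 38, 19)
    // Dropping all 19 fractional digits must agree with BigDecimal.
    d.changePrecision(38, 0, Decimal.ROUND_CEILING)
    val expected = BigDecimal(912345678901234567L, 19).setScale(0, RoundingMode.CEILING)
    assert(d.toBigDecimal == expected) // 0.0912... rounds up to 1 under CEILING
    ```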
    
    Closes #39099 from fe2s/SPARK-41554-fix-decimal-scaling.
    
    Authored-by: oleksii.diagiliev <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../scala/org/apache/spark/sql/types/Decimal.scala | 60 +++++++++++++---------
 .../org/apache/spark/sql/types/DecimalSuite.scala  | 52 ++++++++++++++++++-
 2 files changed, 87 insertions(+), 25 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 44c00df379f..2c0b6677541 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -374,30 +374,42 @@ final class Decimal extends Ordered[Decimal] with Serializable {
       if (scale < _scale) {
         // Easier case: we just need to divide our scale down
         val diff = _scale - scale
-        val pow10diff = POW_10(diff)
-        // % and / always round to 0
-        val droppedDigits = lv % pow10diff
-        lv /= pow10diff
-        roundMode match {
-          case ROUND_FLOOR =>
-            if (droppedDigits < 0) {
-              lv += -1L
-            }
-          case ROUND_CEILING =>
-            if (droppedDigits > 0) {
-              lv += 1L
-            }
-          case ROUND_HALF_UP =>
-            if (math.abs(droppedDigits) * 2 >= pow10diff) {
-              lv += (if (droppedDigits < 0) -1L else 1L)
-            }
-          case ROUND_HALF_EVEN =>
-            val doubled = math.abs(droppedDigits) * 2
-            if (doubled > pow10diff || doubled == pow10diff && lv % 2 != 0) {
-              lv += (if (droppedDigits < 0) -1L else 1L)
-            }
-          case _ =>
-            throw QueryExecutionErrors.unsupportedRoundingMode(roundMode)
+        // If diff is greater than max number of digits we store in Long, then
+        // value becomes 0. Otherwise we calculate new value dividing by power of 10.
+        // In both cases we apply rounding after that.
+        if (diff > MAX_LONG_DIGITS) {
+          lv = roundMode match {
+            case ROUND_FLOOR => if (lv < 0) -1L else 0L
+            case ROUND_CEILING => if (lv > 0) 1L else 0L
+            case ROUND_HALF_UP | ROUND_HALF_EVEN => 0L
+            case _ => throw QueryExecutionErrors.unsupportedRoundingMode(roundMode)
+          }
+        } else {
+          val pow10diff = POW_10(diff)
+          // % and / always round to 0
+          val droppedDigits = lv % pow10diff
+          lv /= pow10diff
+          roundMode match {
+            case ROUND_FLOOR =>
+              if (droppedDigits < 0) {
+                lv += -1L
+              }
+            case ROUND_CEILING =>
+              if (droppedDigits > 0) {
+                lv += 1L
+              }
+            case ROUND_HALF_UP =>
+              if (math.abs(droppedDigits) * 2 >= pow10diff) {
+                lv += (if (droppedDigits < 0) -1L else 1L)
+              }
+            case ROUND_HALF_EVEN =>
+              val doubled = math.abs(droppedDigits) * 2
+              if (doubled > pow10diff || doubled == pow10diff && lv % 2 != 0) {
+                lv += (if (droppedDigits < 0) -1L else 1L)
+              }
+            case _ =>
+              throw QueryExecutionErrors.unsupportedRoundingMode(roundMode)
+          }
         }
       } else if (scale > _scale) {
        // We might be able to multiply lv by a power of 10 and not overflow, but if not,
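
A note on the new `diff > MAX_LONG_DIGITS` branch above: a compact value holds at most 18 digits, so |lv| < 10^18 <= 10^(diff - 1), and the quantity being rounded is a fraction with magnitude below 0.1. Rounding therefore reduces to a sign check, as in this standalone sketch (a hypothetical helper mirroring the branch, not the patched method):
```
import scala.math.BigDecimal.RoundingMode
import RoundingMode.{FLOOR, CEILING, HALF_UP, HALF_EVEN}

// FLOOR pulls a negative fraction down to -1, CEILING pushes a positive one
// up to 1; both HALF modes yield 0 because the fraction is below 0.5.
def roundWhenAllDigitsDropped(lv: Long, roundMode: RoundingMode.Value): Long =
  roundMode match {
    case FLOOR               => if (lv < 0) -1L else 0L
    case CEILING             => if (lv > 0) 1L else 0L
    case HALF_UP | HALF_EVEN => 0L
    case other               => throw new IllegalArgumentException(s"Unsupported mode: $other")
  }

assert(roundWhenAllDigitsDropped(1L, CEILING) == 1L)  // 1e-19 rounds up to 1
assert(roundWhenAllDigitsDropped(-1L, FLOOR) == -1L)  // -1e-19 rounds down to -1
assert(roundWhenAllDigitsDropped(1L, HALF_UP) == 0L)  // far below 0.5
```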
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
index 89b6cc0b0cb..73944d9dff9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
@@ -27,6 +27,9 @@ import org.apache.spark.sql.types.Decimal._
 import org.apache.spark.unsafe.types.UTF8String
 
class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper {
+
+  val allSupportedRoundModes = Seq(ROUND_HALF_UP, ROUND_HALF_EVEN, ROUND_CEILING, ROUND_FLOOR)
+
   /** Check that a Decimal has the given string representation, precision and scale */
   private def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = {
     assert(d.toString === string)
@@ -278,7 +281,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper
   }
 
   test("changePrecision/toPrecision on compact decimal should respect rounding 
mode") {
-    Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode =>
+    allSupportedRoundModes.foreach { mode =>
       Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n 
=>
         Seq("", "-").foreach { sign =>
           val bd = BigDecimal(sign + n)
@@ -384,4 +387,51 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester with SQLHelper
       }
     }
   }
+
+  // 18 is a max number of digits in Decimal's compact long
+  test("SPARK-41554: decrease/increase scale by 18 and more on compact 
decimal") {
+    val unscaledNums = Seq(
+      0L, 1L, 10L, 51L, 123L, 523L,
+      // 18 digits
+      912345678901234567L,
+      112345678901234567L,
+      512345678901234567L
+    )
+    val precision = 38
+    // generate some (from, to) scale pairs, e.g. (38, 18), (-20, -2), etc
+    val scalePairs = for {
+      scale <- Seq(38, 20, 19, 18)
+      delta <- Seq(38, 20, 19, 18)
+      a = scale
+      b = scale - delta
+    } yield {
+      Seq((a, b), (-a, -b), (b, a), (-b, -a))
+    }
+
+    for {
+      unscaled <- unscaledNums
+      mode <- allSupportedRoundModes
+      (scaleFrom, scaleTo) <- scalePairs.flatten
+      sign <- Seq(1L, -1L)
+    } {
+      val unscaledWithSign = unscaled * sign
+      if (scaleFrom < 0 || scaleTo < 0) {
+        withSQLConf(SQLConf.LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED.key -> "true") {
+          checkScaleChange(unscaledWithSign, scaleFrom, scaleTo, mode)
+        }
+      } else {
+        checkScaleChange(unscaledWithSign, scaleFrom, scaleTo, mode)
+      }
+    }
+
+    def checkScaleChange(unscaled: Long, scaleFrom: Int, scaleTo: Int,
+                         roundMode: BigDecimal.RoundingMode.Value): Unit = {
+      val decimal = Decimal(unscaled, precision, scaleFrom)
+      checkCompact(decimal, true)
+      decimal.changePrecision(precision, scaleTo, roundMode)
+      val bd = BigDecimal(unscaled, scaleFrom).setScale(scaleTo, roundMode)
+      assert(decimal.toBigDecimal === bd,
+        s"unscaled: $unscaled, scaleFrom: $scaleFrom, scaleTo: $scaleTo, mode: 
$roundMode")
+    }
+  }
 }
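
With the fix, the repro from the description succeeds. A quick sanity check one might run (an illustrative sketch; the two-argument `changePrecision` defaults to HALF_UP rounding):
```
import org.apache.spark.sql.types.Decimal

val d = Decimal(1, 38, 19)       // 0.0000000000000000001 stored as compact long
assert(d.changePrecision(38, 0)) // previously threw ArrayIndexOutOfBoundsException: 19
assert(d.toBigDecimal == BigDecimal(0)) // 1e-19 rounds to 0 under HALF_UP
```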

