This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1dc50d79115 [SPARK-40285][SQL] Simplify the `roundTo[Numeric]` for 
Spark `Decimal`
1dc50d79115 is described below

commit 1dc50d7911586ee48ca0d067617417680f0f19ca
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Thu Sep 1 12:41:12 2022 +0800

    [SPARK-40285][SQL] Simplify the `roundTo[Numeric]` for Spark `Decimal`
    
    ### What changes were proposed in this pull request?
    Simplify the `roundTo[Numeric]` for Spark `Decimal`.
    
    ### Why are the changes needed?
    Spark `Decimal` has a lot of methods named `roundTo[*]`.
    Except for `roundToLong`, their implementations are redundant copies of the same logic.
    
    ### Does this PR introduce _any_ user-facing change?
    'No'.
    This only simplifies the internal implementation.
    
    ### How was this patch tested?
    N/A
    
    Closes #37736 from beliefer/SPARK-40285.
    
    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../scala/org/apache/spark/sql/types/Decimal.scala | 61 ++++++----------------
 1 file changed, 16 insertions(+), 45 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 69eff6de4b9..57e8fc060a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -258,70 +258,41 @@ final class Decimal extends Ordered[Decimal] with 
Serializable {
    * @return the Byte value that is equal to the rounded decimal.
    * @throws ArithmeticException if the decimal is too big to fit in Byte type.
    */
-  private[sql] def roundToByte(): Byte = {
-    if (decimalVal.eq(null)) {
-      val actualLongVal = longVal / POW_10(_scale)
-      if (actualLongVal == actualLongVal.toByte) {
-        actualLongVal.toByte
-      } else {
-        throw QueryExecutionErrors.castingCauseOverflowError(
-          this, DecimalType(this.precision, this.scale), ByteType)
-      }
-    } else {
-      val doubleVal = decimalVal.toDouble
-      if (Math.floor(doubleVal) <= Byte.MaxValue && Math.ceil(doubleVal) >= 
Byte.MinValue) {
-        doubleVal.toByte
-      } else {
-        throw QueryExecutionErrors.castingCauseOverflowError(
-          this, DecimalType(this.precision, this.scale), ByteType)
-      }
-    }
-  }
+  private[sql] def roundToByte(): Byte =
+    roundToNumeric[Byte](ByteType, Byte.MaxValue, Byte.MinValue) (_.toByte) 
(_.toByte)
 
   /**
    * @return the Short value that is equal to the rounded decimal.
    * @throws ArithmeticException if the decimal is too big to fit in Short 
type.
    */
-  private[sql] def roundToShort(): Short = {
-    if (decimalVal.eq(null)) {
-      val actualLongVal = longVal / POW_10(_scale)
-      if (actualLongVal == actualLongVal.toShort) {
-        actualLongVal.toShort
-      } else {
-        throw QueryExecutionErrors.castingCauseOverflowError(
-          this, DecimalType(this.precision, this.scale), ShortType)
-      }
-    } else {
-      val doubleVal = decimalVal.toDouble
-      if (Math.floor(doubleVal) <= Short.MaxValue && Math.ceil(doubleVal) >= 
Short.MinValue) {
-        doubleVal.toShort
-      } else {
-        throw QueryExecutionErrors.castingCauseOverflowError(
-          this, DecimalType(this.precision, this.scale), ShortType)
-      }
-    }
-  }
+  private[sql] def roundToShort(): Short =
+    roundToNumeric[Short](ShortType, Short.MaxValue, Short.MinValue) 
(_.toShort) (_.toShort)
 
   /**
    * @return the Int value that is equal to the rounded decimal.
    * @throws ArithmeticException if the decimal too big to fit in Int type.
    */
-  private[sql] def roundToInt(): Int = {
+  private[sql] def roundToInt(): Int =
+    roundToNumeric[Int](IntegerType, Int.MaxValue, Int.MinValue) (_.toInt) 
(_.toInt)
+
+  private def roundToNumeric[T <: AnyVal](integralType: IntegralType, 
maxValue: Int, minValue: Int)
+      (f1: Long => T) (f2: Double => T): T = {
     if (decimalVal.eq(null)) {
       val actualLongVal = longVal / POW_10(_scale)
-      if (actualLongVal == actualLongVal.toInt) {
-        actualLongVal.toInt
+      val numericVal = f1(actualLongVal)
+      if (actualLongVal == numericVal) {
+        numericVal
       } else {
         throw QueryExecutionErrors.castingCauseOverflowError(
-          this, DecimalType(this.precision, this.scale), IntegerType)
+          this, DecimalType(this.precision, this.scale), integralType)
       }
     } else {
       val doubleVal = decimalVal.toDouble
-      if (Math.floor(doubleVal) <= Int.MaxValue && Math.ceil(doubleVal) >= 
Int.MinValue) {
-        doubleVal.toInt
+      if (Math.floor(doubleVal) <= maxValue && Math.ceil(doubleVal) >= 
minValue) {
+        f2(doubleVal)
       } else {
         throw QueryExecutionErrors.castingCauseOverflowError(
-          this, DecimalType(this.precision, this.scale), IntegerType)
+          this, DecimalType(this.precision, this.scale), integralType)
       }
     }
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to