gengliangwang commented on a change in pull request #25461: 
[SPARK-28741][SQL]Throw exceptions when casting to integers causes overflow
URL: https://github.com/apache/spark/pull/25461#discussion_r316707668
 
 

 ##########
 File path: sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
 ##########
 @@ -232,12 +232,149 @@ final class Decimal extends Ordered[Decimal] with 
Serializable {
     }
   }
 
+  /**
+   * @return the Long value that is equal to the rounded decimal.
+   * @throws ArithmeticException if checkOverflow is true and
+   *                             the decimal too big to fit in Long type.
+   */
+  def toLong(checkOverflow: Boolean): Long = {
+    if (!checkOverflow) {
+      toLong
+    } else {
+      roundToLong()
+    }
+  }
+
   def toInt: Int = toLong.toInt
 
+  /**
+   * @return the Int value that is equal to the rounded decimal.
+   * @throws ArithmeticException if checkOverflow is true and
+   *                             the decimal too big to fit in Int type.
+   */
+  def toInt(checkOverflow: Boolean): Int = {
+    if (!checkOverflow) {
+      toInt
+    } else {
+      roundToInt()
+    }
+  }
+
   def toShort: Short = toLong.toShort
 
+  /**
+   * @return the Short value that is equal to the rounded decimal.
+   * @throws ArithmeticException if checkOverflow is true and
+   *                             the decimal is too big to fit in Short type.
+   */
+  def toShort(checkOverflow: Boolean): Short = {
+    if (!checkOverflow) {
+      toShort
+    } else {
+      roundToShort()
+    }
+  }
+
   def toByte: Byte = toLong.toByte
 
+  /**
+   * @return the Byte value that is equal to the rounded decimal.
+   * @throws ArithmeticException if checkOverflow is true and
+   *                             the decimal is too big to fit in Byte type.
+   */
+  def toByte(checkOverflow: Boolean): Byte = {
+    if (!checkOverflow) {
+      toByte
+    } else {
+      roundToByte()
+    }
+  }
+
+  private def overflowException(dataType: String) =
+    throw new ArithmeticException(s"Casting $this to $dataType causes 
overflow.")
+
+  /**
+   * @return the Byte value that is equal to the rounded decimal.
+   * @throws ArithmeticException if the decimal is too big to fit in Byte type.
+   */
+  private def roundToByte(): Byte = {
+    if (decimalVal.eq(null)) {
+      val actualLongVal = longVal / POW_10(_scale)
+      if (actualLongVal == actualLongVal.toByte) {
+        actualLongVal.toByte
+      } else {
+        overflowException("byte")
+      }
+    } else {
+      val doubleVal = decimalVal.toDouble
+      if (Math.floor(doubleVal) <= Byte.MaxValue && Math.ceil(doubleVal) >= 
Byte.MinValue) {
+        doubleVal.toByte
+      } else {
+        overflowException("byte")
+      }
+    }
+  }
+
+  /**
+   * @return the Short value that is equal to the rounded decimal.
+   * @throws ArithmeticException if the decimal is too big to fit in Short 
type.
+   */
+  private def roundToShort(): Short = {
+    if (decimalVal.eq(null)) {
+      val actualLongVal = longVal / POW_10(_scale)
+      if (actualLongVal == actualLongVal.toShort) {
+        actualLongVal.toShort
+      } else {
+        overflowException("short")
+      }
+    } else {
+      val doubleVal = decimalVal.toDouble
+      if (Math.floor(doubleVal) <= Short.MaxValue && Math.ceil(doubleVal) >= 
Short.MinValue) {
+        doubleVal.toShort
+      } else {
+        overflowException("short")
+      }
+    }
+  }
+
+  /**
+   * @return the Int value that is equal to the rounded decimal.
+   * @throws ArithmeticException if the decimal too big to fit in Int type.
+   */
+  private def roundToInt(): Int = {
+    if (decimalVal.eq(null)) {
+      val actualLongVal = longVal / POW_10(_scale)
+      if (actualLongVal == actualLongVal.toInt) {
+        actualLongVal.toInt
+      } else {
+        overflowException("int")
+      }
+    } else {
+      val doubleVal = decimalVal.toDouble
+      if (Math.floor(doubleVal) <= Int.MaxValue && Math.ceil(doubleVal) >= 
Int.MinValue) {
+        doubleVal.toInt
+      } else {
+        overflowException("int")
+      }
+    }
+  }
+
+  /**
+   * @return the Long value that is equal to the rounded decimal.
+   * @throws ArithmeticException if the decimal too big to fit in Long type.
+   */
+  private def roundToLong(): Long = {
+    if (decimalVal.eq(null)) {
+      longVal / POW_10(_scale)
+    } else {
+      try {
+        decimalVal.bigDecimal.toBigInteger.longValueExact()
 
 Review comment:
   I think the current implementation is simple and accurate.
   If we convert the value to double, then it won't be accurate;
   if we compare the value with another `Decimal`, then internally both values are converted to `BigDecimal`.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to