gengliangwang commented on a change in pull request #25461: [SPARK-28741][SQL]Throw exceptions when casting to integers causes overflow URL: https://github.com/apache/spark/pull/25461#discussion_r316707668
########## File path: sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala ########## @@ -232,12 +232,149 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } + /** + * @return the Long value that is equal to the rounded decimal. + * @throws ArithmeticException if checkOverflow is true and + * the decimal too big to fit in Long type. + */ + def toLong(checkOverflow: Boolean): Long = { + if (!checkOverflow) { + toLong + } else { + roundToLong() + } + } + def toInt: Int = toLong.toInt + /** + * @return the Int value that is equal to the rounded decimal. + * @throws ArithmeticException if checkOverflow is true and + * the decimal too big to fit in Int type. + */ + def toInt(checkOverflow: Boolean): Int = { + if (!checkOverflow) { + toInt + } else { + roundToInt() + } + } + def toShort: Short = toLong.toShort + /** + * @return the Short value that is equal to the rounded decimal. + * @throws ArithmeticException if checkOverflow is true and + * the decimal is too big to fit in Short type. + */ + def toShort(checkOverflow: Boolean): Short = { + if (!checkOverflow) { + toShort + } else { + roundToShort() + } + } + def toByte: Byte = toLong.toByte + /** + * @return the Byte value that is equal to the rounded decimal. + * @throws ArithmeticException if checkOverflow is true and + * the decimal is too big to fit in Byte type. + */ + def toByte(checkOverflow: Boolean): Byte = { + if (!checkOverflow) { + toByte + } else { + roundToByte() + } + } + + private def overflowException(dataType: String) = + throw new ArithmeticException(s"Casting $this to $dataType causes overflow.") + + /** + * @return the Byte value that is equal to the rounded decimal. + * @throws ArithmeticException if the decimal is too big to fit in Byte type. 
+ */ + private def roundToByte(): Byte = { + if (decimalVal.eq(null)) { + val actualLongVal = longVal / POW_10(_scale) + if (actualLongVal == actualLongVal.toByte) { + actualLongVal.toByte + } else { + overflowException("byte") + } + } else { + val doubleVal = decimalVal.toDouble + if (Math.floor(doubleVal) <= Byte.MaxValue && Math.ceil(doubleVal) >= Byte.MinValue) { + doubleVal.toByte + } else { + overflowException("byte") + } + } + } + + /** + * @return the Short value that is equal to the rounded decimal. + * @throws ArithmeticException if the decimal is too big to fit in Short type. + */ + private def roundToShort(): Short = { + if (decimalVal.eq(null)) { + val actualLongVal = longVal / POW_10(_scale) + if (actualLongVal == actualLongVal.toShort) { + actualLongVal.toShort + } else { + overflowException("short") + } + } else { + val doubleVal = decimalVal.toDouble + if (Math.floor(doubleVal) <= Short.MaxValue && Math.ceil(doubleVal) >= Short.MinValue) { + doubleVal.toShort + } else { + overflowException("short") + } + } + } + + /** + * @return the Int value that is equal to the rounded decimal. + * @throws ArithmeticException if the decimal too big to fit in Int type. + */ + private def roundToInt(): Int = { + if (decimalVal.eq(null)) { + val actualLongVal = longVal / POW_10(_scale) + if (actualLongVal == actualLongVal.toInt) { + actualLongVal.toInt + } else { + overflowException("int") + } + } else { + val doubleVal = decimalVal.toDouble + if (Math.floor(doubleVal) <= Int.MaxValue && Math.ceil(doubleVal) >= Int.MinValue) { + doubleVal.toInt + } else { + overflowException("int") + } + } + } + + /** + * @return the Long value that is equal to the rounded decimal. + * @throws ArithmeticException if the decimal too big to fit in Long type. 
+ */ + private def roundToLong(): Long = { + if (decimalVal.eq(null)) { + longVal / POW_10(_scale) + } else { + try { + decimalVal.bigDecimal.toBigInteger.longValueExact() Review comment: I think the current implementation is simple and accurate. If we convert the value to double, then it won't be accurate; if we compare the value with another `Decimal`, then internally both values are converted to `BigDecimal`. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org