This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 1dc50d79115 [SPARK-40285][SQL] Simplify the `roundTo[Numeric]` for Spark `Decimal` 1dc50d79115 is described below commit 1dc50d7911586ee48ca0d067617417680f0f19ca Author: Jiaan Geng <belie...@163.com> AuthorDate: Thu Sep 1 12:41:12 2022 +0800 [SPARK-40285][SQL] Simplify the `roundTo[Numeric]` for Spark `Decimal` ### What changes were proposed in this pull request? Simplify the `roundTo[Numeric]` for Spark `Decimal`. ### Why are the changes needed? Spark `Decimal` has several methods named `roundTo[*]`. Apart from `roundToLong`, they (`roundToByte`, `roundToShort`, `roundToInt`) duplicate the same logic and are redundant. ### Does this PR introduce _any_ user-facing change? No. It only simplifies the internal implementation. ### How was this patch tested? N/A Closes #37736 from beliefer/SPARK-40285. Authored-by: Jiaan Geng <belie...@163.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../scala/org/apache/spark/sql/types/Decimal.scala | 61 ++++++---------------- 1 file changed, 16 insertions(+), 45 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 69eff6de4b9..57e8fc060a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -258,70 +258,41 @@ final class Decimal extends Ordered[Decimal] with Serializable { * @return the Byte value that is equal to the rounded decimal. * @throws ArithmeticException if the decimal is too big to fit in Byte type.
*/ - private[sql] def roundToByte(): Byte = { - if (decimalVal.eq(null)) { - val actualLongVal = longVal / POW_10(_scale) - if (actualLongVal == actualLongVal.toByte) { - actualLongVal.toByte - } else { - throw QueryExecutionErrors.castingCauseOverflowError( - this, DecimalType(this.precision, this.scale), ByteType) - } - } else { - val doubleVal = decimalVal.toDouble - if (Math.floor(doubleVal) <= Byte.MaxValue && Math.ceil(doubleVal) >= Byte.MinValue) { - doubleVal.toByte - } else { - throw QueryExecutionErrors.castingCauseOverflowError( - this, DecimalType(this.precision, this.scale), ByteType) - } - } - } + private[sql] def roundToByte(): Byte = + roundToNumeric[Byte](ByteType, Byte.MaxValue, Byte.MinValue) (_.toByte) (_.toByte) /** * @return the Short value that is equal to the rounded decimal. * @throws ArithmeticException if the decimal is too big to fit in Short type. */ - private[sql] def roundToShort(): Short = { - if (decimalVal.eq(null)) { - val actualLongVal = longVal / POW_10(_scale) - if (actualLongVal == actualLongVal.toShort) { - actualLongVal.toShort - } else { - throw QueryExecutionErrors.castingCauseOverflowError( - this, DecimalType(this.precision, this.scale), ShortType) - } - } else { - val doubleVal = decimalVal.toDouble - if (Math.floor(doubleVal) <= Short.MaxValue && Math.ceil(doubleVal) >= Short.MinValue) { - doubleVal.toShort - } else { - throw QueryExecutionErrors.castingCauseOverflowError( - this, DecimalType(this.precision, this.scale), ShortType) - } - } - } + private[sql] def roundToShort(): Short = + roundToNumeric[Short](ShortType, Short.MaxValue, Short.MinValue) (_.toShort) (_.toShort) /** * @return the Int value that is equal to the rounded decimal. * @throws ArithmeticException if the decimal too big to fit in Int type. 
*/ - private[sql] def roundToInt(): Int = { + private[sql] def roundToInt(): Int = + roundToNumeric[Int](IntegerType, Int.MaxValue, Int.MinValue) (_.toInt) (_.toInt) + + private def roundToNumeric[T <: AnyVal](integralType: IntegralType, maxValue: Int, minValue: Int) + (f1: Long => T) (f2: Double => T): T = { if (decimalVal.eq(null)) { val actualLongVal = longVal / POW_10(_scale) - if (actualLongVal == actualLongVal.toInt) { - actualLongVal.toInt + val numericVal = f1(actualLongVal) + if (actualLongVal == numericVal) { + numericVal } else { throw QueryExecutionErrors.castingCauseOverflowError( - this, DecimalType(this.precision, this.scale), IntegerType) + this, DecimalType(this.precision, this.scale), integralType) } } else { val doubleVal = decimalVal.toDouble - if (Math.floor(doubleVal) <= Int.MaxValue && Math.ceil(doubleVal) >= Int.MinValue) { - doubleVal.toInt + if (Math.floor(doubleVal) <= maxValue && Math.ceil(doubleVal) >= minValue) { + f2(doubleVal) } else { throw QueryExecutionErrors.castingCauseOverflowError( - this, DecimalType(this.precision, this.scale), IntegerType) + this, DecimalType(this.precision, this.scale), integralType) } } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org