This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 5318846db1e3 [SPARK-47641][SQL] Improve the performance for `UnaryMinus` and `Abs` 5318846db1e3 is described below commit 5318846db1e367b17bb04366aa57419867e6b538 Author: panbingkun <panbing...@baidu.com> AuthorDate: Fri Mar 29 09:05:19 2024 -0700 [SPARK-47641][SQL] Improve the performance for `UnaryMinus` and `Abs` ### What changes were proposed in this pull request? The PR aims to improve the performance of `UnaryMinus` and `Abs`. ### Why are the changes needed? We can further improve the performance of `UnaryMinus` and `Abs` by applying the following suggestions: <img width="905" alt="image" src="https://github.com/apache/spark/assets/15246973/456b142d-a15d-408e-8aad-91b53841fc16"> In short, the generated code now calls `MathUtils.negateExact` directly on the primitive value instead of looking up a `Numeric` via `TypeUtils.getNumeric(...)` and casting the boxed result. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? - Manually tested. - Passed GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #45766 from panbingkun/improve_UnaryMinus. 
Authored-by: panbingkun <panbing...@baidu.com> Signed-off-by: Dongjoon Hyun <dh...@apple.com> --- .../apache/spark/sql/catalyst/util/MathUtils.scala | 14 ++++++++++++++ .../spark/sql/catalyst/expressions/arithmetic.scala | 20 ++++---------------- .../scala/org/apache/spark/sql/types/numerics.scala | 12 +++--------- 3 files changed, 21 insertions(+), 25 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala index 99caef978bb4..96c3fb81aa66 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala @@ -61,6 +61,20 @@ object MathUtils { withOverflow(Math.multiplyExact(a, b), hint = "try_multiply", context) } + def negateExact(a: Byte): Byte = { + if (a == Byte.MinValue) { // if and only if x is Byte.MinValue, overflow can happen + throw ExecutionErrors.arithmeticOverflowError("byte overflow") + } + (-a).toByte + } + + def negateExact(a: Short): Short = { + if (a == Short.MinValue) { // if and only if x is Short.MinValue, overflow can happen + throw ExecutionErrors.arithmeticOverflowError("short overflow") + } + (-a).toShort + } + def negateExact(a: Int): Int = withOverflow(Math.negateExact(a)) def negateExact(a: Long): Long = withOverflow(Math.negateExact(a)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 4e54e7890e1a..9eecf81684ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -61,14 +61,9 @@ case class UnaryMinus( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { case _: DecimalType => defineCodeGen(ctx, ev, c => 
s"$c.unary_$$minus()") case ByteType | ShortType | IntegerType | LongType if failOnError => - val typeUtils = TypeUtils.getClass.getCanonicalName.stripSuffix("$") - val refDataType = ctx.addReferenceObj("refDataType", dataType, dataType.getClass.getName) + val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$") nullSafeCodeGen(ctx, ev, eval => { - val javaBoxedType = CodeGenerator.boxedType(dataType) - s""" - |${ev.value} = ($javaBoxedType)$typeUtils.getNumeric( - | $refDataType, $failOnError).negate($eval); - """.stripMargin + s"${ev.value} = $mathUtils.negateExact($eval);" }) case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => { val originValue = ctx.freshName("origin") @@ -174,15 +169,8 @@ case class Abs(child: Expression, failOnError: Boolean = SQLConf.get.ansiEnabled defineCodeGen(ctx, ev, c => s"$c.abs()") case ByteType | ShortType | IntegerType | LongType if failOnError => - val typeUtils = TypeUtils.getClass.getCanonicalName.stripSuffix("$") - val refDataType = ctx.addReferenceObj("refDataType", dataType, dataType.getClass.getName) - nullSafeCodeGen(ctx, ev, eval => { - val javaBoxedType = CodeGenerator.boxedType(dataType) - s""" - |${ev.value} = ($javaBoxedType)$typeUtils.getNumeric( - | $refDataType, $failOnError).abs($eval); - """.stripMargin - }) + val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$") + defineCodeGen(ctx, ev, c => s"$c < 0 ? 
$mathUtils.negateExact($c) : $c") case _: AnsiIntervalType => val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala index 45b6cb44e5fa..19b1b5d8af26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.types import scala.math.Numeric._ import org.apache.spark.sql.catalyst.util.{MathUtils, SQLOrderingUtil} -import org.apache.spark.sql.errors.{ExecutionErrors, QueryExecutionErrors} +import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.Decimal.DecimalIsConflicted private[sql] object ByteExactNumeric extends ByteIsIntegral with Ordering.ByteOrdering { @@ -49,10 +49,7 @@ private[sql] object ByteExactNumeric extends ByteIsIntegral with Ordering.ByteOr } override def negate(x: Byte): Byte = { - if (x == Byte.MinValue) { // if and only if x is Byte.MinValue, overflow can happen - throw ExecutionErrors.arithmeticOverflowError("byte overflow") - } - (-x).toByte + MathUtils.negateExact(x) } } @@ -83,10 +80,7 @@ private[sql] object ShortExactNumeric extends ShortIsIntegral with Ordering.Shor } override def negate(x: Short): Short = { - if (x == Short.MinValue) { // if and only if x is Byte.MinValue, overflow can happen - throw ExecutionErrors.arithmeticOverflowError("short overflow") - } - (-x).toShort + MathUtils.negateExact(x) } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org