This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5318846db1e3 [SPARK-47641][SQL] Improve the performance for `UnaryMinus` and `Abs`
5318846db1e3 is described below
commit 5318846db1e367b17bb04366aa57419867e6b538
Author: panbingkun <[email protected]>
AuthorDate: Fri Mar 29 09:05:19 2024 -0700
[SPARK-47641][SQL] Improve the performance for `UnaryMinus` and `Abs`
### What changes were proposed in this pull request?
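This PR improves the performance of `UnaryMinus` and `Abs` on integral types (`ByteType`, `ShortType`, `IntegerType`, `LongType`) when `failOnError` (ANSI mode) is enabled. The generated code now calls `MathUtils.negateExact` directly instead of going through `TypeUtils.getNumeric(...)`. To support this, `Byte` and `Short` overloads of `negateExact` are added to `MathUtils`, and `ByteExactNumeric.negate` / `ShortExactNumeric.negate` now delegate to them.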
### Why are the changes needed?
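The previous generated code did three things for every evaluated value:
- fetched a `Numeric` instance via `TypeUtils.getNumeric(refDataType, failOnError)`;
- called its generic `negate` / `abs` method;
- cast the result back to the boxed Java type.

Calling `MathUtils.negateExact` directly removes the reference-object lookup, the generic (boxing) call, and the cast. For `Abs`, the generated code becomes `c < 0 ? negateExact(c) : c`. The suggested change is illustrated below:
<img width="905" alt="image"
src="https://github.com/apache/spark/assets/15246973/456b142d-a15d-408e-8aad-91b53841fc16">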
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- Tested manually.
- Passed GitHub Actions (GA) CI.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #45766 from panbingkun/improve_UnaryMinus.
Authored-by: panbingkun <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../apache/spark/sql/catalyst/util/MathUtils.scala | 14 ++++++++++++++
.../spark/sql/catalyst/expressions/arithmetic.scala | 20 ++++----------------
.../scala/org/apache/spark/sql/types/numerics.scala | 12 +++---------
3 files changed, 21 insertions(+), 25 deletions(-)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
index 99caef978bb4..96c3fb81aa66 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
@@ -61,6 +61,20 @@ object MathUtils {
withOverflow(Math.multiplyExact(a, b), hint = "try_multiply", context)
}
+ def negateExact(a: Byte): Byte = {
+ if (a == Byte.MinValue) { // if and only if x is Byte.MinValue, overflow can happen
+ throw ExecutionErrors.arithmeticOverflowError("byte overflow")
+ }
+ (-a).toByte
+ }
+
+ def negateExact(a: Short): Short = {
+ if (a == Short.MinValue) { // if and only if x is Short.MinValue, overflow can happen
+ throw ExecutionErrors.arithmeticOverflowError("short overflow")
+ }
+ (-a).toShort
+ }
+
def negateExact(a: Int): Int = withOverflow(Math.negateExact(a))
def negateExact(a: Long): Long = withOverflow(Math.negateExact(a))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 4e54e7890e1a..9eecf81684ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -61,14 +61,9 @@ case class UnaryMinus(
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
dataType match {
case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
case ByteType | ShortType | IntegerType | LongType if failOnError =>
- val typeUtils = TypeUtils.getClass.getCanonicalName.stripSuffix("$")
- val refDataType = ctx.addReferenceObj("refDataType", dataType, dataType.getClass.getName)
+ val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
nullSafeCodeGen(ctx, ev, eval => {
- val javaBoxedType = CodeGenerator.boxedType(dataType)
- s"""
- |${ev.value} = ($javaBoxedType)$typeUtils.getNumeric(
- | $refDataType, $failOnError).negate($eval);
- """.stripMargin
+ s"${ev.value} = $mathUtils.negateExact($eval);"
})
case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => {
val originValue = ctx.freshName("origin")
@@ -174,15 +169,8 @@ case class Abs(child: Expression, failOnError: Boolean = SQLConf.get.ansiEnabled
defineCodeGen(ctx, ev, c => s"$c.abs()")
case ByteType | ShortType | IntegerType | LongType if failOnError =>
- val typeUtils = TypeUtils.getClass.getCanonicalName.stripSuffix("$")
- val refDataType = ctx.addReferenceObj("refDataType", dataType, dataType.getClass.getName)
- nullSafeCodeGen(ctx, ev, eval => {
- val javaBoxedType = CodeGenerator.boxedType(dataType)
- s"""
- |${ev.value} = ($javaBoxedType)$typeUtils.getNumeric(
- | $refDataType, $failOnError).abs($eval);
- """.stripMargin
- })
+ val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
+ defineCodeGen(ctx, ev, c => s"$c < 0 ? $mathUtils.negateExact($c) : $c")
case _: AnsiIntervalType =>
val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala
index 45b6cb44e5fa..19b1b5d8af26 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.types
import scala.math.Numeric._
import org.apache.spark.sql.catalyst.util.{MathUtils, SQLOrderingUtil}
-import org.apache.spark.sql.errors.{ExecutionErrors, QueryExecutionErrors}
+import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types.Decimal.DecimalIsConflicted
private[sql] object ByteExactNumeric extends ByteIsIntegral with Ordering.ByteOrdering {
@@ -49,10 +49,7 @@ private[sql] object ByteExactNumeric extends ByteIsIntegral with Ordering.ByteOr
}
override def negate(x: Byte): Byte = {
- if (x == Byte.MinValue) { // if and only if x is Byte.MinValue, overflow can happen
- throw ExecutionErrors.arithmeticOverflowError("byte overflow")
- }
- (-x).toByte
+ MathUtils.negateExact(x)
}
}
@@ -83,10 +80,7 @@ private[sql] object ShortExactNumeric extends ShortIsIntegral with Ordering.Shor
}
override def negate(x: Short): Short = {
- if (x == Short.MinValue) { // if and only if x is Byte.MinValue, overflow can happen
- throw ExecutionErrors.arithmeticOverflowError("short overflow")
- }
- (-x).toShort
+ MathUtils.negateExact(x)
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]