This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5318846db1e3 [SPARK-47641][SQL] Improve the performance for `UnaryMinus` and `Abs`
5318846db1e3 is described below

commit 5318846db1e367b17bb04366aa57419867e6b538
Author: panbingkun <panbing...@baidu.com>
AuthorDate: Fri Mar 29 09:05:19 2024 -0700

    [SPARK-47641][SQL] Improve the performance for `UnaryMinus` and `Abs`
    
    ### What changes were proposed in this pull request?
    This PR aims to improve the performance of `UnaryMinus` and `Abs`.
    
    ### Why are the changes needed?
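    We can further improve the performance of `UnaryMinus` and `Abs` by applying the following suggestions:
    <img width="905" alt="image" src="https://github.com/apache/spark/assets/15246973/456b142d-a15d-408e-8aad-91b53841fc16">
    In ANSI mode (`failOnError`), the generated code for byte/short/int/long inputs no longer looks up a `Numeric` via `TypeUtils.getNumeric` on a reference object and casts the boxed result; it now calls the primitive `MathUtils.negateExact` overloads directly.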
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    - Manually tested.
    - Passed GitHub Actions (GA).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #45766 from panbingkun/improve_UnaryMinus.
    
    Authored-by: panbingkun <panbing...@baidu.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../apache/spark/sql/catalyst/util/MathUtils.scala   | 14 ++++++++++++++
 .../spark/sql/catalyst/expressions/arithmetic.scala  | 20 ++++----------------
 .../scala/org/apache/spark/sql/types/numerics.scala  | 12 +++---------
 3 files changed, 21 insertions(+), 25 deletions(-)

diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
index 99caef978bb4..96c3fb81aa66 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
@@ -61,6 +61,20 @@ object MathUtils {
     withOverflow(Math.multiplyExact(a, b), hint = "try_multiply", context)
   }
 
+  def negateExact(a: Byte): Byte = {
+    if (a == Byte.MinValue) { // if and only if x is Byte.MinValue, overflow 
can happen
+      throw ExecutionErrors.arithmeticOverflowError("byte overflow")
+    }
+    (-a).toByte
+  }
+
+  def negateExact(a: Short): Short = {
+    if (a == Short.MinValue) { // if and only if x is Short.MinValue, overflow 
can happen
+      throw ExecutionErrors.arithmeticOverflowError("short overflow")
+    }
+    (-a).toShort
+  }
+
   def negateExact(a: Int): Int = withOverflow(Math.negateExact(a))
 
   def negateExact(a: Long): Long = withOverflow(Math.negateExact(a))
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 4e54e7890e1a..9eecf81684ce 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -61,14 +61,9 @@ case class UnaryMinus(
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = 
dataType match {
     case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
     case ByteType | ShortType | IntegerType | LongType if failOnError =>
-      val typeUtils = TypeUtils.getClass.getCanonicalName.stripSuffix("$")
-      val refDataType = ctx.addReferenceObj("refDataType", dataType, 
dataType.getClass.getName)
+      val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
       nullSafeCodeGen(ctx, ev, eval => {
-        val javaBoxedType = CodeGenerator.boxedType(dataType)
-        s"""
-           |${ev.value} = ($javaBoxedType)$typeUtils.getNumeric(
-           |  $refDataType, $failOnError).negate($eval);
-         """.stripMargin
+        s"${ev.value} = $mathUtils.negateExact($eval);"
       })
     case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => {
       val originValue = ctx.freshName("origin")
@@ -174,15 +169,8 @@ case class Abs(child: Expression, failOnError: Boolean = 
SQLConf.get.ansiEnabled
       defineCodeGen(ctx, ev, c => s"$c.abs()")
 
     case ByteType | ShortType | IntegerType | LongType if failOnError =>
-      val typeUtils = TypeUtils.getClass.getCanonicalName.stripSuffix("$")
-      val refDataType = ctx.addReferenceObj("refDataType", dataType, 
dataType.getClass.getName)
-      nullSafeCodeGen(ctx, ev, eval => {
-        val javaBoxedType = CodeGenerator.boxedType(dataType)
-        s"""
-           |${ev.value} = ($javaBoxedType)$typeUtils.getNumeric(
-           |  $refDataType, $failOnError).abs($eval);
-         """.stripMargin
-      })
+      val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
+      defineCodeGen(ctx, ev, c => s"$c < 0 ? $mathUtils.negateExact($c) : $c")
 
     case _: AnsiIntervalType =>
       val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala
index 45b6cb44e5fa..19b1b5d8af26 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.types
 import scala.math.Numeric._
 
 import org.apache.spark.sql.catalyst.util.{MathUtils, SQLOrderingUtil}
-import org.apache.spark.sql.errors.{ExecutionErrors, QueryExecutionErrors}
+import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types.Decimal.DecimalIsConflicted
 
 private[sql] object ByteExactNumeric extends ByteIsIntegral with 
Ordering.ByteOrdering {
@@ -49,10 +49,7 @@ private[sql] object ByteExactNumeric extends ByteIsIntegral 
with Ordering.ByteOr
   }
 
   override def negate(x: Byte): Byte = {
-    if (x == Byte.MinValue) { // if and only if x is Byte.MinValue, overflow 
can happen
-      throw ExecutionErrors.arithmeticOverflowError("byte overflow")
-    }
-    (-x).toByte
+    MathUtils.negateExact(x)
   }
 }
 
@@ -83,10 +80,7 @@ private[sql] object ShortExactNumeric extends 
ShortIsIntegral with Ordering.Shor
   }
 
   override def negate(x: Short): Short = {
-    if (x == Short.MinValue) { // if and only if x is Byte.MinValue, overflow 
can happen
-      throw ExecutionErrors.arithmeticOverflowError("short overflow")
-    }
-    (-x).toShort
+    MathUtils.negateExact(x)
   }
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to