This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e35e29a0517d [SPARK-46915][SQL] Simplify `UnaryMinus` `Abs` and align
error class
e35e29a0517d is described below
commit e35e29a0517db930e12fe801f0f0ab1a31c3b23e
Author: panbingkun <[email protected]>
AuthorDate: Fri Feb 2 20:33:31 2024 +0300
[SPARK-46915][SQL] Simplify `UnaryMinus` `Abs` and align error class
### What changes were proposed in this pull request?
The pr aims to:
- simplify `UnaryMinus` & `Abs`
- convert error-class `_LEGACY_ERROR_TEMP_2043` to `ARITHMETIC_OVERFLOW`,
and remove it.
### Why are the changes needed?
1. When the data type in `UnaryMinus` and `Abs` is `ByteType` or
`ShortType`, the error class raised on an overflow exception is
`_LEGACY_ERROR_TEMP_2043`.
However, when the data type is `IntegerType` or `LongType`, the error class
raised on an overflow exception is `ARITHMETIC_OVERFLOW`. We should unify
them.
2. In the `codegen` logic of `UnaryMinus` and `Abs`, the code generated when
the data type is `ByteType` or `ShortType` differs from the code generated
when the data type is `IntegerType` or `LongType`. We can unify it and
simplify the code.
### Does this PR introduce _any_ user-facing change?
Yes, the error class raised on overflow for `ByteType`/`ShortType` in
`UnaryMinus` and `Abs` changes from `_LEGACY_ERROR_TEMP_2043` to
`ARITHMETIC_OVERFLOW`.
### How was this patch tested?
- Update existing UTs.
- Pass GA.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #44942 from panbingkun/UnaryMinus_improve.
Authored-by: panbingkun <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
---
.../src/main/resources/error/error-classes.json | 5 ---
.../sql/catalyst/expressions/arithmetic.scala | 45 ++++++++--------------
.../spark/sql/errors/QueryExecutionErrors.scala | 8 ----
.../org/apache/spark/sql/types/numerics.scala | 6 +--
.../expressions/ArithmeticExpressionSuite.scala | 27 ++++++++-----
5 files changed, 36 insertions(+), 55 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-classes.json
b/common/utils/src/main/resources/error/error-classes.json
index 136825ab374d..6d88f5ee511c 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -5747,11 +5747,6 @@
"<message>. If necessary set <ansiConfig> to false to bypass this error."
]
},
- "_LEGACY_ERROR_TEMP_2043" : {
- "message" : [
- "- <sqlValue> caused overflow."
- ]
- },
"_LEGACY_ERROR_TEMP_2045" : {
"message" : [
"Unsupported table change: <message>"
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 9f1b42ad84d3..0f95ae821ab0 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -60,23 +60,15 @@ case class UnaryMinus(
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
dataType match {
case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
- case ByteType | ShortType if failOnError =>
+ case ByteType | ShortType | IntegerType | LongType if failOnError =>
+ val typeUtils = TypeUtils.getClass.getCanonicalName.stripSuffix("$")
+ val refDataType = ctx.addReferenceObj("refDataType", dataType,
dataType.getClass.getName)
nullSafeCodeGen(ctx, ev, eval => {
val javaBoxedType = CodeGenerator.boxedType(dataType)
- val javaType = CodeGenerator.javaType(dataType)
- val originValue = ctx.freshName("origin")
s"""
- |$javaType $originValue = ($javaType)($eval);
- |if ($originValue == $javaBoxedType.MIN_VALUE) {
- | throw
QueryExecutionErrors.unaryMinusCauseOverflowError($originValue);
- |}
- |${ev.value} = ($javaType)(-($originValue));
- """.stripMargin
- })
- case IntegerType | LongType if failOnError =>
- val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
- nullSafeCodeGen(ctx, ev, eval => {
- s"${ev.value} = $mathUtils.negateExact($eval);"
+ |${ev.value} = ($javaBoxedType)$typeUtils.getNumeric(
+ | $refDataType, $failOnError).negate($eval);
+ """.stripMargin
})
case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => {
val originValue = ctx.freshName("origin")
@@ -181,23 +173,16 @@ case class Abs(child: Expression, failOnError: Boolean =
SQLConf.get.ansiEnabled
case _: DecimalType =>
defineCodeGen(ctx, ev, c => s"$c.abs()")
- case ByteType | ShortType if failOnError =>
- val javaBoxedType = CodeGenerator.boxedType(dataType)
- val javaType = CodeGenerator.javaType(dataType)
- nullSafeCodeGen(ctx, ev, eval =>
+ case ByteType | ShortType | IntegerType | LongType if failOnError =>
+ val typeUtils = TypeUtils.getClass.getCanonicalName.stripSuffix("$")
+ val refDataType = ctx.addReferenceObj("refDataType", dataType,
dataType.getClass.getName)
+ nullSafeCodeGen(ctx, ev, eval => {
+ val javaBoxedType = CodeGenerator.boxedType(dataType)
s"""
- |if ($eval == $javaBoxedType.MIN_VALUE) {
- | throw QueryExecutionErrors.unaryMinusCauseOverflowError($eval);
- |} else if ($eval < 0) {
- | ${ev.value} = ($javaType)-$eval;
- |} else {
- | ${ev.value} = $eval;
- |}
- |""".stripMargin)
-
- case IntegerType | LongType if failOnError =>
- val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
- defineCodeGen(ctx, ev, c => s"$c < 0 ? $mathUtils.negateExact($c) : $c")
+ |${ev.value} = ($javaBoxedType)$typeUtils.getNumeric(
+ | $refDataType, $failOnError).abs($eval);
+ """.stripMargin
+ })
case _: AnsiIntervalType =>
val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index b09885c904a5..9ff076c5fd50 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -601,14 +601,6 @@ private[sql] object QueryExecutionErrors extends
QueryErrorsBase with ExecutionE
summary = "")
}
- def unaryMinusCauseOverflowError(originValue: Int): SparkArithmeticException
= {
- new SparkArithmeticException(
- errorClass = "_LEGACY_ERROR_TEMP_2043",
- messageParameters = Map("sqlValue" -> toSQLValue(originValue,
IntegerType)),
- context = Array.empty,
- summary = "")
- }
-
def binaryArithmeticCauseOverflowError(
eval1: Short, symbol: String, eval2: Short): SparkArithmeticException = {
new SparkArithmeticException(
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala
index c3d893d82fce..45b6cb44e5fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.types
import scala.math.Numeric._
import org.apache.spark.sql.catalyst.util.{MathUtils, SQLOrderingUtil}
-import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.errors.{ExecutionErrors, QueryExecutionErrors}
import org.apache.spark.sql.types.Decimal.DecimalIsConflicted
private[sql] object ByteExactNumeric extends ByteIsIntegral with
Ordering.ByteOrdering {
@@ -50,7 +50,7 @@ private[sql] object ByteExactNumeric extends ByteIsIntegral
with Ordering.ByteOr
override def negate(x: Byte): Byte = {
if (x == Byte.MinValue) { // if and only if x is Byte.MinValue, overflow
can happen
- throw QueryExecutionErrors.unaryMinusCauseOverflowError(x)
+ throw ExecutionErrors.arithmeticOverflowError("byte overflow")
}
(-x).toByte
}
@@ -84,7 +84,7 @@ private[sql] object ShortExactNumeric extends ShortIsIntegral
with Ordering.Shor
override def negate(x: Short): Short = {
if (x == Short.MinValue) { // if and only if x is Short.MinValue, overflow
can happen
- throw QueryExecutionErrors.unaryMinusCauseOverflowError(x)
+ throw ExecutionErrors.arithmeticOverflowError("short overflow")
}
(-x).toShort
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 7a80188d445d..89f0b95f5c18 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -29,7 +29,8 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.catalyst.trees.Origin
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.errors.DataTypeErrors.toSQLConf
+import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
import org.apache.spark.sql.types._
class ArithmeticExpressionSuite extends SparkFunSuite with
ExpressionEvalHelper {
@@ -116,14 +117,22 @@ class ArithmeticExpressionSuite extends SparkFunSuite
with ExpressionEvalHelper
checkEvaluation(UnaryMinus(Literal(Byte.MinValue)), Byte.MinValue)
}
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
- checkExceptionInExpression[ArithmeticException](
- UnaryMinus(Literal(Long.MinValue)), "overflow")
- checkExceptionInExpression[ArithmeticException](
- UnaryMinus(Literal(Int.MinValue)), "overflow")
- checkExceptionInExpression[ArithmeticException](
- UnaryMinus(Literal(Short.MinValue)), "overflow")
- checkExceptionInExpression[ArithmeticException](
- UnaryMinus(Literal(Byte.MinValue)), "overflow")
+ checkErrorInExpression[SparkArithmeticException](
+ UnaryMinus(Literal(Long.MinValue)), "ARITHMETIC_OVERFLOW",
+ Map("message" -> "long overflow", "alternative" -> "",
+ "config" -> toSQLConf(SqlApiConf.ANSI_ENABLED_KEY)))
+ checkErrorInExpression[SparkArithmeticException](
+ UnaryMinus(Literal(Int.MinValue)), "ARITHMETIC_OVERFLOW",
+ Map("message" -> "integer overflow", "alternative" -> "",
+ "config" -> toSQLConf(SqlApiConf.ANSI_ENABLED_KEY)))
+ checkErrorInExpression[SparkArithmeticException](
+ UnaryMinus(Literal(Short.MinValue)), "ARITHMETIC_OVERFLOW",
+ Map("message" -> "short overflow", "alternative" -> "",
+ "config" -> toSQLConf(SqlApiConf.ANSI_ENABLED_KEY)))
+ checkErrorInExpression[SparkArithmeticException](
+ UnaryMinus(Literal(Byte.MinValue)), "ARITHMETIC_OVERFLOW",
+ Map("message" -> "byte overflow", "alternative" -> "",
+ "config" -> toSQLConf(SqlApiConf.ANSI_ENABLED_KEY)))
checkEvaluation(UnaryMinus(positiveShortLit), (- positiveShort).toShort)
checkEvaluation(UnaryMinus(negativeShortLit), (- negativeShort).toShort)
checkEvaluation(UnaryMinus(positiveIntLit), - positiveInt)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]