This is an automated email from the ASF dual-hosted git repository.
gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 42721120f3c [SPARK-42045][SQL] ANSI SQL mode: Round/Bround should
return an error on integer overflow
42721120f3c is described below
commit 42721120f3c7206a9fc22db5d0bb7cf40f0cacfd
Author: Gengliang Wang <[email protected]>
AuthorDate: Fri Jan 13 09:40:36 2023 -0800
[SPARK-42045][SQL] ANSI SQL mode: Round/Bround should return an error on
integer overflow
### What changes were proposed in this pull request?
In ANSI SQL mode, Round/Bround should return an error on integer overflow.
Note this PR covers the integer type only. Once it is merged, I will create a
follow-up PR for the remaining integral types: byte, short, and long.
Also, the functions ceil and floor accept decimal type input, so there is
no need to change them.
### Why are the changes needed?
In ANSI SQL mode, integer overflow should cause an error instead of returning
an unreasonable (wrapped-around) result.
For example, `round(2147483647, -1)` should return an error instead of
returning `-2147483646`.
### Does this PR introduce _any_ user-facing change?
Yes, in ANSI SQL mode, the SQL functions Round and Bround will return an error
on integer overflow.
### How was this patch tested?
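New unit test in MathExpressionsSuite and new SQL query tests (math.sql, run with and without ANSI mode).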
Closes #39546 from gengliangwang/fixRound.
Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
---
.../sql/catalyst/expressions/mathExpressions.scala | 60 +++++--
.../apache/spark/sql/catalyst/util/MathUtils.scala | 12 +-
.../expressions/MathExpressionsSuite.scala | 15 +-
.../catalyst/util/PhysicalAggregationSuite.scala | 2 +-
.../test/resources/sql-tests/inputs/ansi/math.sql | 1 +
.../src/test/resources/sql-tests/inputs/math.sql | 17 ++
.../resources/sql-tests/results/ansi/math.sql.out | 175 +++++++++++++++++++++
.../test/resources/sql-tests/results/math.sql.out | 111 +++++++++++++
8 files changed, 381 insertions(+), 12 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 9ffc148180a..50a1194c2f1 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -26,8 +26,10 @@ import
org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch,
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.util.{NumberConverter, TypeUtils}
+import org.apache.spark.sql.catalyst.trees.SQLQueryContext
+import org.apache.spark.sql.catalyst.util.{MathUtils, NumberConverter,
TypeUtils}
import org.apache.spark.sql.errors.{QueryCompilationErrors,
QueryExecutionErrors}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -1447,11 +1449,13 @@ case class Logarithm(left: Expression, right:
Expression)
*/
abstract class RoundBase(child: Expression, scale: Expression,
mode: BigDecimal.RoundingMode.Value, modeStr: String)
- extends BinaryExpression with Serializable with ImplicitCastInputTypes {
+ extends BinaryExpression with Serializable with ImplicitCastInputTypes with
SupportQueryContext {
override def left: Expression = child
override def right: Expression = scale
+ protected def ansiEnabled: Boolean = false
+
// round of Decimal would eval to null if it fails to `changePrecision`
override def nullable: Boolean = true
@@ -1501,6 +1505,14 @@ abstract class RoundBase(child: Expression, scale:
Expression,
private lazy val scaleV: Any = scale.eval(EmptyRow)
protected lazy val _scale: Int = scaleV.asInstanceOf[Int]
+ override def initQueryContext(): Option[SQLQueryContext] = {
+ if (ansiEnabled) {
+ Some(origin.context)
+ } else {
+ None
+ }
+ }
+
override def eval(input: InternalRow): Any = {
if (scaleV == null) { // if scale is null, no need to eval its child at all
null
@@ -1529,6 +1541,10 @@ abstract class RoundBase(child: Expression, scale:
Expression,
BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
case ShortType =>
BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, mode).toShort
+ case IntegerType if ansiEnabled =>
+ MathUtils.withOverflow(
+ f = BigDecimal(input1.asInstanceOf[Int]).setScale(_scale,
mode).toIntExact,
+ context = getContextOrNull)
case IntegerType =>
BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, mode).toInt
case LongType =>
@@ -1584,9 +1600,19 @@ abstract class RoundBase(child: Expression, scale:
Expression,
}
case IntegerType =>
if (_scale < 0) {
- s"""
- ${ev.value} = new java.math.BigDecimal(${ce.value}).
- setScale(${_scale}, java.math.BigDecimal.${modeStr}).intValue();"""
+ if (ansiEnabled) {
+ val errorContext = getContextOrNullCode(ctx)
+ val evalCode = s"""
+ |${ev.value} = new java.math.BigDecimal(${ce.value}).
+ |setScale(${_scale},
java.math.BigDecimal.${modeStr}).intValueExact();
+ |""".stripMargin
+ MathUtils.withOverflowCode(evalCode, errorContext)
+ } else {
+ s"""
+ |${ev.value} = new java.math.BigDecimal(${ce.value}).
+ |setScale(${_scale},
java.math.BigDecimal.${modeStr}).intValue();
+ |""".stripMargin
+ }
} else {
s"${ev.value} = ${ce.value};"
}
@@ -1648,9 +1674,17 @@ abstract class RoundBase(child: Expression, scale:
Expression,
since = "1.5.0",
group = "math_funcs")
// scalastyle:on line.size.limit
-case class Round(child: Expression, scale: Expression)
+case class Round(
+ child: Expression,
+ scale: Expression,
+ override val ansiEnabled: Boolean = SQLConf.get.ansiEnabled)
extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_UP,
"ROUND_HALF_UP") {
- def this(child: Expression) = this(child, Literal(0))
+ def this(child: Expression) = this(child, Literal(0),
SQLConf.get.ansiEnabled)
+
+ def this(child: Expression, scale: Expression) = this(child, scale,
SQLConf.get.ansiEnabled)
+
+ override def flatArguments: Iterator[Any] = Iterator(child, scale)
+
override protected def withNewChildrenInternal(newLeft: Expression,
newRight: Expression): Round =
copy(child = newLeft, scale = newRight)
}
@@ -1673,9 +1707,17 @@ case class Round(child: Expression, scale: Expression)
since = "2.0.0",
group = "math_funcs")
// scalastyle:on line.size.limit
-case class BRound(child: Expression, scale: Expression)
+case class BRound(
+ child: Expression,
+ scale: Expression,
+ override val ansiEnabled: Boolean = SQLConf.get.ansiEnabled)
extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_EVEN,
"ROUND_HALF_EVEN") {
- def this(child: Expression) = this(child, Literal(0))
+ def this(child: Expression) = this(child, Literal(0),
SQLConf.get.ansiEnabled)
+
+ def this(child: Expression, scale: Expression) = this(child, scale,
SQLConf.get.ansiEnabled)
+
+ override def flatArguments: Iterator[Any] = Iterator(child, scale)
+
override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): BRound = copy(child = newLeft,
scale = newRight)
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
index e79e483076d..b285b1df572 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
@@ -75,7 +75,7 @@ object MathUtils {
def floorMod(a: Long, b: Long): Long = withOverflow(Math.floorMod(a, b))
- private def withOverflow[A](
+ def withOverflow[A](
f: => A,
hint: String = "",
context: SQLQueryContext = null): A = {
@@ -86,4 +86,14 @@ object MathUtils {
throw QueryExecutionErrors.arithmeticOverflowError(e.getMessage, hint,
context)
}
}
+
+ def withOverflowCode(evalCode: String, context: String): String = {
+ s"""
+ |try {
+ | $evalCode
+ |} catch (ArithmeticException e) {
+ | throw QueryExecutionErrors.arithmeticOverflowError(e.getMessage(),
"", $context);
+ |}
+ |""".stripMargin
+ }
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
index c78d72e7a98..92b683a7106 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
@@ -23,7 +23,7 @@ import java.time.temporal.ChronoUnit
import com.google.common.math.LongMath
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkArithmeticException, SparkFunSuite}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCoercion.implicitCast
import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -838,6 +838,19 @@ class MathExpressionsSuite extends SparkFunSuite with
ExpressionEvalHelper {
checkEvaluation(checkDataTypeAndCast(RoundCeil(Literal(135.135),
Literal(-2))), Decimal(200))
}
+ test("SPARK-42045: integer overflow in round/bround") {
+ val input = 2147483647
+ val scale = -1
+ Seq(Round(input, scale, ansiEnabled = true),
+ BRound(input, scale, ansiEnabled = true)).foreach { expr =>
+ checkExceptionInExpression[SparkArithmeticException](expr, "Overflow")
+ }
+ Seq(Round(input, scale, ansiEnabled = false),
+ BRound(input, scale, ansiEnabled = false)).foreach { expr =>
+ checkEvaluation(expr, -2147483646)
+ }
+ }
+
test("SPARK-36922: Support ANSI intervals for SIGN/SIGNUM") {
checkEvaluation(Signum(Literal(Period.ZERO)), 0.0)
checkEvaluation(Signum(Literal(Period.ofYears(10))), 1.0)
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/PhysicalAggregationSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/PhysicalAggregationSuite.scala
index cf9c9490fab..c0db9c61388 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/PhysicalAggregationSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/PhysicalAggregationSuite.scala
@@ -48,7 +48,7 @@ class PhysicalAggregationSuite extends PlanTest {
// Verify that Round's scale parameter is a Literal.
resultExpressions(1) match {
- case Alias(Round(_, _: Literal), _) =>
+ case Alias(Round(_, _: Literal, _), _) =>
case other => fail("unexpected result expression: " + other)
}
}
diff --git a/sql/core/src/test/resources/sql-tests/inputs/ansi/math.sql
b/sql/core/src/test/resources/sql-tests/inputs/ansi/math.sql
new file mode 100644
index 00000000000..5ee19c28ca6
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/ansi/math.sql
@@ -0,0 +1 @@
+--IMPORT math.sql
diff --git a/sql/core/src/test/resources/sql-tests/inputs/math.sql
b/sql/core/src/test/resources/sql-tests/inputs/math.sql
new file mode 100644
index 00000000000..df7210c4595
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/math.sql
@@ -0,0 +1,17 @@
+-- Round with integer input
+SELECT round(525, 1);
+SELECT round(525, 0);
+SELECT round(525, -1);
+SELECT round(525, -2);
+SELECT round(525, -3);
+SELECT round(2147483647, -1);
+SELECT round(-2147483647, -1);
+
+-- BRound with integer input
+SELECT bround(525, 1);
+SELECT bround(525, 0);
+SELECT bround(525, -1);
+SELECT bround(525, -2);
+SELECT bround(525, -3);
+SELECT bround(2147483647, -1);
+SELECT bround(-2147483647, -1);
\ No newline at end of file
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out
b/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out
new file mode 100644
index 00000000000..e7866b59047
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out
@@ -0,0 +1,175 @@
+-- Automatically generated by SQLQueryTestSuite
+-- !query
+SELECT round(525, 1)
+-- !query schema
+struct<round(525, 1):int>
+-- !query output
+525
+
+
+-- !query
+SELECT round(525, 0)
+-- !query schema
+struct<round(525, 0):int>
+-- !query output
+525
+
+
+-- !query
+SELECT round(525, -1)
+-- !query schema
+struct<round(525, -1):int>
+-- !query output
+530
+
+
+-- !query
+SELECT round(525, -2)
+-- !query schema
+struct<round(525, -2):int>
+-- !query output
+500
+
+
+-- !query
+SELECT round(525, -3)
+-- !query schema
+struct<round(525, -3):int>
+-- !query output
+1000
+
+
+-- !query
+SELECT round(2147483647, -1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+ "errorClass" : "ARITHMETIC_OVERFLOW",
+ "sqlState" : "22003",
+ "messageParameters" : {
+ "alternative" : "",
+ "config" : "\"spark.sql.ansi.enabled\"",
+ "message" : "Overflow"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 28,
+ "fragment" : "round(2147483647, -1)"
+ } ]
+}
+
+
+-- !query
+SELECT round(-2147483647, -1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+ "errorClass" : "ARITHMETIC_OVERFLOW",
+ "sqlState" : "22003",
+ "messageParameters" : {
+ "alternative" : "",
+ "config" : "\"spark.sql.ansi.enabled\"",
+ "message" : "Overflow"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 29,
+ "fragment" : "round(-2147483647, -1)"
+ } ]
+}
+
+
+-- !query
+SELECT bround(525, 1)
+-- !query schema
+struct<bround(525, 1):int>
+-- !query output
+525
+
+
+-- !query
+SELECT bround(525, 0)
+-- !query schema
+struct<bround(525, 0):int>
+-- !query output
+525
+
+
+-- !query
+SELECT bround(525, -1)
+-- !query schema
+struct<bround(525, -1):int>
+-- !query output
+520
+
+
+-- !query
+SELECT bround(525, -2)
+-- !query schema
+struct<bround(525, -2):int>
+-- !query output
+500
+
+
+-- !query
+SELECT bround(525, -3)
+-- !query schema
+struct<bround(525, -3):int>
+-- !query output
+1000
+
+
+-- !query
+SELECT bround(2147483647, -1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+ "errorClass" : "ARITHMETIC_OVERFLOW",
+ "sqlState" : "22003",
+ "messageParameters" : {
+ "alternative" : "",
+ "config" : "\"spark.sql.ansi.enabled\"",
+ "message" : "Overflow"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 29,
+ "fragment" : "bround(2147483647, -1)"
+ } ]
+}
+
+
+-- !query
+SELECT bround(-2147483647, -1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+ "errorClass" : "ARITHMETIC_OVERFLOW",
+ "sqlState" : "22003",
+ "messageParameters" : {
+ "alternative" : "",
+ "config" : "\"spark.sql.ansi.enabled\"",
+ "message" : "Overflow"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 30,
+ "fragment" : "bround(-2147483647, -1)"
+ } ]
+}
diff --git a/sql/core/src/test/resources/sql-tests/results/math.sql.out
b/sql/core/src/test/resources/sql-tests/results/math.sql.out
new file mode 100644
index 00000000000..693ce3e8cbf
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/math.sql.out
@@ -0,0 +1,111 @@
+-- Automatically generated by SQLQueryTestSuite
+-- !query
+SELECT round(525, 1)
+-- !query schema
+struct<round(525, 1):int>
+-- !query output
+525
+
+
+-- !query
+SELECT round(525, 0)
+-- !query schema
+struct<round(525, 0):int>
+-- !query output
+525
+
+
+-- !query
+SELECT round(525, -1)
+-- !query schema
+struct<round(525, -1):int>
+-- !query output
+530
+
+
+-- !query
+SELECT round(525, -2)
+-- !query schema
+struct<round(525, -2):int>
+-- !query output
+500
+
+
+-- !query
+SELECT round(525, -3)
+-- !query schema
+struct<round(525, -3):int>
+-- !query output
+1000
+
+
+-- !query
+SELECT round(2147483647, -1)
+-- !query schema
+struct<round(2147483647, -1):int>
+-- !query output
+-2147483646
+
+
+-- !query
+SELECT round(-2147483647, -1)
+-- !query schema
+struct<round(-2147483647, -1):int>
+-- !query output
+2147483646
+
+
+-- !query
+SELECT bround(525, 1)
+-- !query schema
+struct<bround(525, 1):int>
+-- !query output
+525
+
+
+-- !query
+SELECT bround(525, 0)
+-- !query schema
+struct<bround(525, 0):int>
+-- !query output
+525
+
+
+-- !query
+SELECT bround(525, -1)
+-- !query schema
+struct<bround(525, -1):int>
+-- !query output
+520
+
+
+-- !query
+SELECT bround(525, -2)
+-- !query schema
+struct<bround(525, -2):int>
+-- !query output
+500
+
+
+-- !query
+SELECT bround(525, -3)
+-- !query schema
+struct<bround(525, -3):int>
+-- !query output
+1000
+
+
+-- !query
+SELECT bround(2147483647, -1)
+-- !query schema
+struct<bround(2147483647, -1):int>
+-- !query output
+-2147483646
+
+
+-- !query
+SELECT bround(-2147483647, -1)
+-- !query schema
+struct<bround(-2147483647, -1):int>
+-- !query output
+2147483646
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]