This is an automated email from the ASF dual-hosted git repository.
gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 061127b67d2 [SPARK-38548][SQL][FOLLOWUP] try_sum: return null if
overflow happens before merging
061127b67d2 is described below
commit 061127b67d2fbae0042505f1dabfad10eed4a782
Author: Gengliang Wang <[email protected]>
AuthorDate: Fri Apr 8 09:32:21 2022 +0800
[SPARK-38548][SQL][FOLLOWUP] try_sum: return null if overflow happens
before merging
### What changes were proposed in this pull request?
This PR fixes a bug in the new function `try_sum`: it should return null
if an overflow happens in a partial aggregation, before the sums from the
map tasks are merged.
For example:
MAP TASK 1: partial aggregation TRY_SUM(large_numbers_column) -> overflows,
turns into NULL
MAP TASK 2: partial aggregation TRY_SUM(large_numbers_column) -> succeeds,
returns 12345
REDUCE TASK: merge TRY_SUM(NULL, 12345) -> returns 12345 (wrong: the
expected result is NULL, because one of the partial sums overflowed)
To fix this, the aggregation buffer gets an extra slot, `isEmpty`, which
tracks whether the partial aggregation has seen any non-null input. If a
partial result is null even though its buffer is non-empty, an overflow must
have occurred, so the merged result is `NULL`.
### Why are the changes needed?
Bug fix
### Does this PR introduce _any_ user-facing change?
No, the new function has not been released yet.
### How was this patch tested?
New unit test in SQLQuerySuite.
Closes #36097 from gengliangwang/fixTrySum.
Lead-authored-by: Gengliang Wang <[email protected]>
Co-authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
---
.../sql/catalyst/expressions/aggregate/Sum.scala | 131 ++++++++++++---------
.../scala/org/apache/spark/sql/SQLQuerySuite.scala | 13 ++
2 files changed, 86 insertions(+), 58 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index 5d8fd702ba4..fd27edfc8fc 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -32,6 +32,8 @@ abstract class SumBase(child: Expression) extends
DeclarativeAggregate
def failOnError: Boolean
+ protected def shouldTrackIsEmpty: Boolean
+
override def nullable: Boolean = true
// Return data type.
@@ -45,7 +47,7 @@ abstract class SumBase(child: Expression) extends
DeclarativeAggregate
final override val nodePatterns: Seq[TreePattern] = Seq(SUM)
- private lazy val resultType = child.dataType match {
+ protected lazy val resultType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType.bounded(precision + 10, scale)
case _: IntegralType => LongType
@@ -60,51 +62,51 @@ abstract class SumBase(child: Expression) extends
DeclarativeAggregate
private lazy val zero = Literal.default(resultType)
- override lazy val aggBufferAttributes = resultType match {
- case _: DecimalType => sum :: isEmpty :: Nil
- case _ => sum :: Nil
+ override lazy val aggBufferAttributes = if (shouldTrackIsEmpty) {
+ sum :: isEmpty :: Nil
+ } else {
+ sum :: Nil
}
- override lazy val initialValues: Seq[Expression] = resultType match {
- case _: DecimalType => Seq(zero, Literal(true, BooleanType))
- case _ => Seq(Literal(null, resultType))
- }
+ override lazy val initialValues: Seq[Expression] =
+ if (shouldTrackIsEmpty) {
+ Seq(zero, Literal(true, BooleanType))
+ } else {
+ Seq(Literal(null, resultType))
+ }
- protected def getUpdateExpressions: Seq[Expression] = {
- resultType match {
- case _: DecimalType =>
- // For decimal type, the initial value of `sum` is 0. We need to keep
`sum` unchanged if
- // the input is null, as SUM function ignores null input. The `sum`
can only be null if
- // overflow happens under non-ansi mode.
- val sumExpr = if (child.nullable) {
- If(child.isNull, sum,
- Add(sum, KnownNotNull(child).cast(resultType), failOnError =
failOnError))
- } else {
- Add(sum, child.cast(resultType), failOnError = failOnError)
- }
- // The buffer becomes non-empty after seeing the first not-null input.
- val isEmptyExpr = if (child.nullable) {
- isEmpty && child.isNull
- } else {
- Literal(false, BooleanType)
- }
- Seq(sumExpr, isEmptyExpr)
- case _ =>
- // For non-decimal type, the initial value of `sum` is null, which
indicates no value.
- // We need `coalesce(sum, zero)` to start summing values. And we need
an outer `coalesce`
- // in case the input is nullable. The `sum` can only be null if there
is no value, as
- // non-decimal type can produce overflowed value under non-ansi mode.
- if (child.nullable) {
- Seq(coalesce(Add(coalesce(sum, zero), child.cast(resultType),
failOnError = failOnError),
- sum))
- } else {
- Seq(Add(coalesce(sum, zero), child.cast(resultType), failOnError =
failOnError))
- }
+ protected def getUpdateExpressions: Seq[Expression] = if
(shouldTrackIsEmpty) {
+ // If shouldTrackIsEmpty is true, the initial value of `sum` is 0. We need
to keep `sum`
+ // unchanged if the input is null, as SUM function ignores null input. The
`sum` can only be
+ // null if overflow happens under non-ansi mode.
+ val sumExpr = if (child.nullable) {
+ If(child.isNull, sum,
+ Add(sum, KnownNotNull(child).cast(resultType), failOnError =
failOnError))
+ } else {
+ Add(sum, child.cast(resultType), failOnError = failOnError)
+ }
+ // The buffer becomes non-empty after seeing the first not-null input.
+ val isEmptyExpr = if (child.nullable) {
+ isEmpty && child.isNull
+ } else {
+ Literal(false, BooleanType)
+ }
+ Seq(sumExpr, isEmptyExpr)
+ } else {
+ // If shouldTrackIsEmpty is false, the initial value of `sum` is null,
which indicates no value.
+ // We need `coalesce(sum, zero)` to start summing values. And we need an
outer `coalesce`
+ // in case the input is nullable. The `sum` can only be null if there is
no value, as
+ // non-decimal type can produce overflowed value under non-ansi mode.
+ if (child.nullable) {
+ Seq(coalesce(Add(coalesce(sum, zero), child.cast(resultType),
failOnError = failOnError),
+ sum))
+ } else {
+ Seq(Add(coalesce(sum, zero), child.cast(resultType), failOnError =
failOnError))
}
}
/**
- * For decimal type:
+ * When shouldTrackIsEmpty is true:
* If isEmpty is false and if sum is null, then it means we have had an
overflow.
*
* update of the sum is as follows:
@@ -113,26 +115,24 @@ abstract class SumBase(child: Expression) extends
DeclarativeAggregate
* If it did not have overflow, then add the sum.left and sum.right
*
* isEmpty: Set to false if either one of the left or right is set to
false. This
- * means we have seen atleast a value that was not null.
+ * means we have seen at least a value that was not null.
*/
- protected def getMergeExpressions: Seq[Expression] = {
- resultType match {
- case _: DecimalType =>
- val bufferOverflow = !isEmpty.left && sum.left.isNull
- val inputOverflow = !isEmpty.right && sum.right.isNull
- Seq(
- If(
- bufferOverflow || inputOverflow,
- Literal.create(null, resultType),
- // If both the buffer and the input do not overflow, just add
them, as they can't be
- // null. See the comments inside `updateExpressions`: `sum` can
only be null if
- // overflow happens.
- KnownNotNull(sum.left) + KnownNotNull(sum.right)),
- isEmpty.left && isEmpty.right)
- case _ => Seq(coalesce(
- Add(coalesce(sum.left, zero), sum.right, failOnError = failOnError),
- sum.left))
- }
+ protected def getMergeExpressions: Seq[Expression] = if (shouldTrackIsEmpty)
{
+ val bufferOverflow = !isEmpty.left && sum.left.isNull
+ val inputOverflow = !isEmpty.right && sum.right.isNull
+ Seq(
+ If(
+ bufferOverflow || inputOverflow,
+ Literal.create(null, resultType),
+ // If both the buffer and the input do not overflow, just add them, as
they can't be
+ // null. See the comments inside `updateExpressions`: `sum` can only
be null if
+ // overflow happens.
+ Add(KnownNotNull(sum.left), KnownNotNull(sum.right), failOnError)),
+ isEmpty.left && isEmpty.right)
+ } else {
+ Seq(coalesce(
+ Add(coalesce(sum.left, zero), sum.right, failOnError = failOnError),
+ sum.left))
}
/**
@@ -146,6 +146,8 @@ abstract class SumBase(child: Expression) extends
DeclarativeAggregate
case d: DecimalType =>
If(isEmpty, Literal.create(null, resultType),
CheckOverflowInSum(sum, d, !failOnError))
+ case _ if shouldTrackIsEmpty =>
+ If(isEmpty, Literal.create(null, resultType), sum)
case _ => sum
}
@@ -172,6 +174,11 @@ case class Sum(
extends SumBase(child) {
def this(child: Expression) = this(child, failOnError =
SQLConf.get.ansiEnabled)
+ override def shouldTrackIsEmpty: Boolean = resultType match {
+ case _: DecimalType => true
+ case _ => false
+ }
+
override protected def withNewChildInternal(newChild: Expression): Sum =
copy(child = newChild)
override lazy val updateExpressions: Seq[Expression] = getUpdateExpressions
@@ -208,6 +215,14 @@ case class TrySum(child: Expression) extends
SumBase(child) {
case _ => true
}
+ override def shouldTrackIsEmpty: Boolean = resultType match {
+ // The sum of following data types can cause overflow.
+ case _: DecimalType | _: IntegralType | _: YearMonthIntervalType | _:
DayTimeIntervalType =>
+ true
+ case _ =>
+ false
+ }
+
override lazy val updateExpressions: Seq[Expression] =
if (failOnError) {
val expressions = getUpdateExpressions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index d720e542b40..f25a3e399aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -4328,6 +4328,19 @@ class SQLQuerySuite extends QueryTest with
SharedSparkSession with AdaptiveSpark
Row(3, 2, 6) :: Nil)
}
}
+
+ test("SPARK-38548: try_sum should return null if overflow happens before
merging") {
+ val longDf = Seq(Long.MaxValue, Long.MaxValue, 2).toDF("v")
+ val yearMonthDf = Seq(Int.MaxValue, Int.MaxValue, 2)
+ .map(Period.ofMonths)
+ .toDF("v")
+ val dayTimeDf = Seq(106751991L, 106751991L, 2L)
+ .map(Duration.ofDays)
+ .toDF("v")
+ Seq(longDf, yearMonthDf, dayTimeDf).foreach { df =>
+ checkAnswer(df.repartitionByRange(2, col("v")).selectExpr("try_sum(v)"),
Row(null))
+ }
+ }
}
case class Foo(bar: Option[String])
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]