This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 061127b67d2 [SPARK-38548][SQL][FOLLOWUP] try_sum: return null if 
overflow happens before merging
061127b67d2 is described below

commit 061127b67d2fbae0042505f1dabfad10eed4a782
Author: Gengliang Wang <[email protected]>
AuthorDate: Fri Apr 8 09:32:21 2022 +0800

    [SPARK-38548][SQL][FOLLOWUP] try_sum: return null if overflow happens 
before merging
    
    ### What changes were proposed in this pull request?
    
    This PR is to fix a bug in the new function `try_sum`. It should return 
null if overflow happens before merging the sums from map tasks.
    For example, before this fix:
    MAP TASK 1: partial aggregation TRY_SUM(large_numbers_column) -> overflows, 
turns into NULL
    MAP TASK 2: partial aggregation TRY_SUM(large_numbers_column) -> succeeds, 
returns 12345
    REDUCE TASK: merge TRY_SUM(NULL, 12345) -> returns 12345
    
    We should use a new slot buffer `isEmpty` to track if there is a non-empty 
value in partial aggregation. If the partial result is null and there is a 
non-empty value, the merge result should be `NULL`.
    ### Why are the changes needed?
    
    Bug fix
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, the new function is not released yet.
    
    ### How was this patch tested?
    
    UT
    
    Closes #36097 from gengliangwang/fixTrySum.
    
    Lead-authored-by: Gengliang Wang <[email protected]>
    Co-authored-by: Gengliang Wang <[email protected]>
    Signed-off-by: Gengliang Wang <[email protected]>
---
 .../sql/catalyst/expressions/aggregate/Sum.scala   | 131 ++++++++++++---------
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala |  13 ++
 2 files changed, 86 insertions(+), 58 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index 5d8fd702ba4..fd27edfc8fc 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -32,6 +32,8 @@ abstract class SumBase(child: Expression) extends 
DeclarativeAggregate
 
   def failOnError: Boolean
 
+  protected def shouldTrackIsEmpty: Boolean
+
   override def nullable: Boolean = true
 
   // Return data type.
@@ -45,7 +47,7 @@ abstract class SumBase(child: Expression) extends 
DeclarativeAggregate
 
   final override val nodePatterns: Seq[TreePattern] = Seq(SUM)
 
-  private lazy val resultType = child.dataType match {
+  protected lazy val resultType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
       DecimalType.bounded(precision + 10, scale)
     case _: IntegralType => LongType
@@ -60,51 +62,51 @@ abstract class SumBase(child: Expression) extends 
DeclarativeAggregate
 
   private lazy val zero = Literal.default(resultType)
 
-  override lazy val aggBufferAttributes = resultType match {
-    case _: DecimalType => sum :: isEmpty :: Nil
-    case _ => sum :: Nil
+  override lazy val aggBufferAttributes = if (shouldTrackIsEmpty) {
+    sum :: isEmpty :: Nil
+  } else {
+    sum :: Nil
   }
 
-  override lazy val initialValues: Seq[Expression] = resultType match {
-    case _: DecimalType => Seq(zero, Literal(true, BooleanType))
-    case _ => Seq(Literal(null, resultType))
-  }
+  override lazy val initialValues: Seq[Expression] =
+    if (shouldTrackIsEmpty) {
+      Seq(zero, Literal(true, BooleanType))
+    } else {
+      Seq(Literal(null, resultType))
+    }
 
-  protected def getUpdateExpressions: Seq[Expression] = {
-    resultType match {
-      case _: DecimalType =>
-        // For decimal type, the initial value of `sum` is 0. We need to keep 
`sum` unchanged if
-        // the input is null, as SUM function ignores null input. The `sum` 
can only be null if
-        // overflow happens under non-ansi mode.
-        val sumExpr = if (child.nullable) {
-          If(child.isNull, sum,
-            Add(sum, KnownNotNull(child).cast(resultType), failOnError = 
failOnError))
-        } else {
-          Add(sum, child.cast(resultType), failOnError = failOnError)
-        }
-        // The buffer becomes non-empty after seeing the first not-null input.
-        val isEmptyExpr = if (child.nullable) {
-          isEmpty && child.isNull
-        } else {
-          Literal(false, BooleanType)
-        }
-        Seq(sumExpr, isEmptyExpr)
-      case _ =>
-        // For non-decimal type, the initial value of `sum` is null, which 
indicates no value.
-        // We need `coalesce(sum, zero)` to start summing values. And we need 
an outer `coalesce`
-        // in case the input is nullable. The `sum` can only be null if there 
is no value, as
-        // non-decimal type can produce overflowed value under non-ansi mode.
-        if (child.nullable) {
-          Seq(coalesce(Add(coalesce(sum, zero), child.cast(resultType), 
failOnError = failOnError),
-            sum))
-        } else {
-          Seq(Add(coalesce(sum, zero), child.cast(resultType), failOnError = 
failOnError))
-        }
+  protected def getUpdateExpressions: Seq[Expression] = if 
(shouldTrackIsEmpty) {
+    // If shouldTrackIsEmpty is true, the initial value of `sum` is 0. We need 
to keep `sum`
+    // unchanged if the input is null, as SUM function ignores null input. The 
`sum` can only be
+    // null if overflow happens under non-ansi mode.
+    val sumExpr = if (child.nullable) {
+      If(child.isNull, sum,
+        Add(sum, KnownNotNull(child).cast(resultType), failOnError = 
failOnError))
+    } else {
+      Add(sum, child.cast(resultType), failOnError = failOnError)
+    }
+    // The buffer becomes non-empty after seeing the first not-null input.
+    val isEmptyExpr = if (child.nullable) {
+      isEmpty && child.isNull
+    } else {
+      Literal(false, BooleanType)
+    }
+    Seq(sumExpr, isEmptyExpr)
+  } else {
+    // If shouldTrackIsEmpty is false, the initial value of `sum` is null, 
which indicates no value.
+    // We need `coalesce(sum, zero)` to start summing values. And we need an 
outer `coalesce`
+    // in case the input is nullable. The `sum` can only be null if there is 
no value, as
+    // non-decimal type can produce overflowed value under non-ansi mode.
+    if (child.nullable) {
+      Seq(coalesce(Add(coalesce(sum, zero), child.cast(resultType), 
failOnError = failOnError),
+        sum))
+    } else {
+      Seq(Add(coalesce(sum, zero), child.cast(resultType), failOnError = 
failOnError))
     }
   }
 
   /**
-   * For decimal type:
+   * When shouldTrackIsEmpty is true:
    * If isEmpty is false and if sum is null, then it means we have had an 
overflow.
    *
    * update of the sum is as follows:
@@ -113,26 +115,24 @@ abstract class SumBase(child: Expression) extends 
DeclarativeAggregate
    * If it did not have overflow, then add the sum.left and sum.right
    *
    * isEmpty:  Set to false if either one of the left or right is set to 
false. This
-   * means we have seen atleast a value that was not null.
+   * means we have seen at least a value that was not null.
    */
-  protected def getMergeExpressions: Seq[Expression] = {
-    resultType match {
-      case _: DecimalType =>
-        val bufferOverflow = !isEmpty.left && sum.left.isNull
-        val inputOverflow = !isEmpty.right && sum.right.isNull
-        Seq(
-          If(
-            bufferOverflow || inputOverflow,
-            Literal.create(null, resultType),
-            // If both the buffer and the input do not overflow, just add 
them, as they can't be
-            // null. See the comments inside `updateExpressions`: `sum` can 
only be null if
-            // overflow happens.
-            KnownNotNull(sum.left) + KnownNotNull(sum.right)),
-          isEmpty.left && isEmpty.right)
-      case _ => Seq(coalesce(
-        Add(coalesce(sum.left, zero), sum.right, failOnError = failOnError),
-        sum.left))
-    }
+  protected def getMergeExpressions: Seq[Expression] = if (shouldTrackIsEmpty) 
{
+    val bufferOverflow = !isEmpty.left && sum.left.isNull
+    val inputOverflow = !isEmpty.right && sum.right.isNull
+    Seq(
+      If(
+        bufferOverflow || inputOverflow,
+        Literal.create(null, resultType),
+        // If both the buffer and the input do not overflow, just add them, as 
they can't be
+        // null. See the comments inside `updateExpressions`: `sum` can only 
be null if
+        // overflow happens.
+        Add(KnownNotNull(sum.left), KnownNotNull(sum.right), failOnError)),
+      isEmpty.left && isEmpty.right)
+  } else {
+    Seq(coalesce(
+      Add(coalesce(sum.left, zero), sum.right, failOnError = failOnError),
+      sum.left))
   }
 
   /**
@@ -146,6 +146,8 @@ abstract class SumBase(child: Expression) extends 
DeclarativeAggregate
     case d: DecimalType =>
       If(isEmpty, Literal.create(null, resultType),
         CheckOverflowInSum(sum, d, !failOnError))
+    case _ if shouldTrackIsEmpty =>
+      If(isEmpty, Literal.create(null, resultType), sum)
     case _ => sum
   }
 
@@ -172,6 +174,11 @@ case class Sum(
   extends SumBase(child) {
   def this(child: Expression) = this(child, failOnError = 
SQLConf.get.ansiEnabled)
 
+  override def shouldTrackIsEmpty: Boolean = resultType match {
+    case _: DecimalType => true
+    case _ => false
+  }
+
   override protected def withNewChildInternal(newChild: Expression): Sum = 
copy(child = newChild)
 
   override lazy val updateExpressions: Seq[Expression] = getUpdateExpressions
@@ -208,6 +215,14 @@ case class TrySum(child: Expression) extends 
SumBase(child) {
     case _ => true
   }
 
+  override def shouldTrackIsEmpty: Boolean = resultType match {
+    // The sum of following data types can cause overflow.
+    case _: DecimalType | _: IntegralType | _: YearMonthIntervalType | _: 
DayTimeIntervalType =>
+      true
+    case _ =>
+      false
+  }
+
   override lazy val updateExpressions: Seq[Expression] =
     if (failOnError) {
       val expressions = getUpdateExpressions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index d720e542b40..f25a3e399aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -4328,6 +4328,19 @@ class SQLQuerySuite extends QueryTest with 
SharedSparkSession with AdaptiveSpark
         Row(3, 2, 6) :: Nil)
     }
   }
+
+  test("SPARK-38548: try_sum should return null if overflow happens before 
merging") {
+    val longDf = Seq(Long.MaxValue, Long.MaxValue, 2).toDF("v")
+    val yearMonthDf = Seq(Int.MaxValue, Int.MaxValue, 2)
+      .map(Period.ofMonths)
+      .toDF("v")
+    val dayTimeDf = Seq(106751991L, 106751991L, 2L)
+      .map(Duration.ofDays)
+      .toDF("v")
+    Seq(longDf, yearMonthDf, dayTimeDf).foreach { df =>
+      checkAnswer(df.repartitionByRange(2, col("v")).selectExpr("try_sum(v)"), 
Row(null))
+    }
+  }
 }
 
 case class Foo(bar: Option[String])


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to