cloud-fan commented on a change in pull request #27627: [WIP][SPARK-28067][SQL] Fix incorrect results for decimal aggregate sum by returning null on decimal overflow
URL: https://github.com/apache/spark/pull/27627#discussion_r387056495
 
 

 ##########
 File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
 ##########
 @@ -60,38 +60,104 @@ case class Sum(child: Expression) extends 
DeclarativeAggregate with ImplicitCast
   private lazy val sumDataType = resultType
 
   private lazy val sum = AttributeReference("sum", sumDataType)()
+  private lazy val overflow = AttributeReference("overflow", BooleanType, 
false)()
 
   private lazy val zero = Literal.default(resultType)
 
-  override lazy val aggBufferAttributes = sum :: Nil
+  override lazy val aggBufferAttributes = sum :: overflow :: Nil
 
   override lazy val initialValues: Seq[Expression] = Seq(
-    /* sum = */ Literal.create(null, sumDataType)
+    /* sum = */ Literal.create(null, sumDataType),
+    /* overflow = */ Literal.create(false, BooleanType)
   )
 
   override lazy val updateExpressions: Seq[Expression] = {
-    if (child.nullable) {
+    if (!SQLConf.get.ansiEnabled) {
+      if (child.nullable) {
+        Seq(
+          /* sum = */
+          coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum),
+          /* overflow = */
+          resultType match {
+            case d: DecimalType =>
+              If(overflow, true,
+                HasOverflow(coalesce(sum, zero) + child.cast(sumDataType), d))
+            case _ => false
+          })
+      } else {
+        Seq(
+          /* sum = */
+          coalesce(sum, zero) + child.cast(sumDataType),
+          /* overflow = */
+          resultType match {
+            case d: DecimalType =>
+              If(overflow, true, HasOverflow(coalesce(sum, zero) + 
child.cast(sumDataType), d))
+            case _ => false
+          })
+      }
+    } else {
+      if (child.nullable) {
+        Seq(
+          /* sum = */
+          resultType match {
+            case d: DecimalType => coalesce(
+              CheckOverflow(
+                coalesce(sum, zero) + child.cast(sumDataType), d, 
!SQLConf.get.ansiEnabled), sum)
+            case _ => coalesce(coalesce(sum, zero) + child.cast(sumDataType), 
sum)
+          },
+          /* overflow = */
+          // overflow flag doesnt need any updates since CheckOverflow will 
throw exception
+          // if overflow happens
+          false
+        )
+      } else {
+        Seq(
+          /* sum = */
+          resultType match {
+            case d: DecimalType => CheckOverflow(
+              coalesce(sum, zero) + child.cast(sumDataType), d, 
!SQLConf.get.ansiEnabled)
+            case _ => coalesce(sum, zero) + child.cast(sumDataType)
+          },
+          /* overflow = */
+          false
+        )
+      }
+    }
+  }
+
+  override lazy val mergeExpressions: Seq[Expression] = {
+    if (!SQLConf.get.ansiEnabled) {
       Seq(
         /* sum = */
-        coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
+        coalesce(coalesce(sum.left, zero) + sum.right, sum.left),
+        /* overflow = */
+        resultType match {
+          case d: DecimalType =>
+            If(coalesce(overflow.left, false) || coalesce(overflow.right, 
false),
+              true, HasOverflow(coalesce(sum.left, zero) + sum.right, d))
+          case _ =>
+            If(coalesce(overflow.left, false) || coalesce(overflow.right, 
false), true, false)
+        }
       )
     } else {
       Seq(
         /* sum = */
-        coalesce(sum, zero) + child.cast(sumDataType)
+        resultType match {
+          case d: DecimalType =>
+            coalesce(
+              CheckOverflow(coalesce(sum.left, zero) + sum.right, d, 
!SQLConf.get.ansiEnabled),
+              sum.left)
+          case _ => coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
+        },
+        /* overflow = */
+        If(coalesce(overflow.left, false) || coalesce(overflow.right, false), 
true, false)
       )
     }
   }
 
-  override lazy val mergeExpressions: Seq[Expression] = {
-    Seq(
-      /* sum = */
-      coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
-    )
-  }
-
   override lazy val evaluateExpression: Expression = resultType match {
-    case d: DecimalType => CheckOverflow(sum, d, !SQLConf.get.ansiEnabled)
 
 Review comment:
   So you basically mean we should check for overflow at every step of the sum
calculation?
   
   This can be super expensive. Maybe we should have a special sum expression
that keeps an unlimited-precision decimal representation (like Java's BigDecimal)
as its buffer, and only checks for overflow when the buffer needs to be serialized.
   
   cc @hvanhovell @maryannxue @rednaxelafx 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to