This is an automated email from the ASF dual-hosted git repository.

yamamuro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b79cf0d  [SPARK-28224][SQL] Check overflow in decimal Sum aggregate
b79cf0d is described below

commit b79cf0d14351c741efe4f27523919a0e24b8b2ed
Author: Mick Jermsurawong <mickjermsuraw...@stripe.com>
AuthorDate: Tue Aug 20 09:47:04 2019 +0900

    [SPARK-28224][SQL] Check overflow in decimal Sum aggregate
    
    ## What changes were proposed in this pull request?
    - Currently, the `sum` aggregate over a decimal type can overflow and
      return null.
      - The `Sum` expression codegens arithmetic on `sql.Decimal`, and the
        output, which preserves scale and precision, goes into
        `UnsafeRowWriter`. There, an overflowing value is converted to null
        when it is written out.
      - The sum also does not go through this branch in `DecimalAggregates`,
        because that branch expects the precision of the sum (not of the
        elements to be summed) to be less than 5.
    
https://github.com/apache/spark/blob/4ebff5b6d68f26cc1ff9265a5489e0d7c2e05449/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala#L1400-L1403
    
    - This PR adds the overflow check to the final result of the sum operator itself.
    
https://github.com/apache/spark/blob/4ebff5b6d68f26cc1ff9265a5489e0d7c2e05449/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala#L372-L376
    
    https://issues.apache.org/jira/browse/SPARK-28224
    
    ## How was this patch tested?
    
    - Added an integration test to `DataFrameSuite`.
    
    cc mgaido91 JoshRosen
    
    Closes #25033 from mickjermsurawong-stripe/SPARK-28224.
    
    Authored-by: Mick Jermsurawong <mickjermsuraw...@stripe.com>
    Signed-off-by: Takeshi Yamamuro <yamam...@apache.org>
---
 .../sql/catalyst/expressions/aggregate/Sum.scala   |  7 ++++++-
 .../org/apache/spark/sql/DataFrameSuite.scala      | 23 +++++++++++++++++++++-
 2 files changed, 28 insertions(+), 2 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index ef204ec..d04fe92 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 @ExpressionDescription(
@@ -89,5 +90,9 @@ case class Sum(child: Expression) extends 
DeclarativeAggregate with ImplicitCast
     )
   }
 
-  override lazy val evaluateExpression: Expression = sum
+  override lazy val evaluateExpression: Expression = resultType match {
+    case d: DecimalType => CheckOverflow(sum, d, 
SQLConf.get.decimalOperationsNullOnOverflow)
+    case _ => sum
+  }
+
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index ba8fced..c6daff1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -38,7 +38,7 @@ import 
org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExc
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, 
SharedSparkSession}
-import org.apache.spark.sql.test.SQLTestData.{NullStrings, TestData2}
+import org.apache.spark.sql.test.SQLTestData.{DecimalData, NullStrings, 
TestData2}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 import org.apache.spark.util.random.XORShiftRandom
@@ -156,6 +156,27 @@ class DataFrameSuite extends QueryTest with 
SharedSparkSession {
       structDf.select(xxhash64($"a", $"record.*")))
   }
 
+  test("SPARK-28224: Aggregate sum big decimal overflow") {
+    val largeDecimals = spark.sparkContext.parallelize(
+      DecimalData(BigDecimal("1"* 20 + ".123"), BigDecimal("1"* 20 + ".123")) 
::
+        DecimalData(BigDecimal("9"* 20 + ".123"), BigDecimal("9"* 20 + 
".123")) :: Nil).toDF()
+
+    Seq(true, false).foreach { nullOnOverflow =>
+      withSQLConf((SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key, 
nullOnOverflow.toString)) {
+        val structDf = largeDecimals.select("a").agg(sum("a"))
+        if (nullOnOverflow) {
+          checkAnswer(structDf, Row(null))
+        } else {
+          val e = intercept[SparkException] {
+            structDf.collect
+          }
+          assert(e.getCause.getClass.equals(classOf[ArithmeticException]))
+          assert(e.getCause.getMessage.contains("cannot be represented as 
Decimal"))
+        }
+      }
+    }
+  }
+
   test("Star Expansion - explode should fail with a meaningful message if it 
takes a star") {
     val df = Seq(("1,2"), ("4"), ("7,8,9")).toDF("csv")
     val e = intercept[AnalysisException] {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to