This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
new 89765f5 [SPARK-32018][SQL][FOLLOWUP][3.0] Throw exception on decimal
value overflow of sum aggregation
89765f5 is described below
commit 89765f556f26252aed1add71a9da84209ff03493
Author: Gengliang Wang <[email protected]>
AuthorDate: Thu Aug 13 03:52:12 2020 +0000
[SPARK-32018][SQL][FOLLOWUP][3.0] Throw exception on decimal value overflow
of sum aggregation
### What changes were proposed in this pull request?
This is a followup of https://github.com/apache/spark/pull/29125
In branch 3.0:
1. For hash aggregation: before https://github.com/apache/spark/pull/29125,
a decimal overflow in sum aggregation raised a runtime exception; after
https://github.com/apache/spark/pull/29125, it could produce a wrong result.
2. For sort aggregation: with or without
https://github.com/apache/spark/pull/29125, a decimal overflow could produce a
wrong result.
In the master branch (the future 3.1 release), the problem doesn't exist,
since https://github.com/apache/spark/pull/27627 added a flag to the
aggregation buffer that marks whether an overflow has happened. However, the
aggregation buffer is written to streaming checkpoints, so changing it would
make existing checkpoints incompatible. Thus, we can't change the aggregation
buffer in branch 3.0 to resolve the issue.
As there is no easy way in branch 3.0 to either return null or throw an
exception on overflow depending on `spark.sql.ansi.enabled`, we have to make a
choice here: always throw an exception on decimal value overflow of sum
aggregation.
### Why are the changes needed?
To avoid returning a wrong result from sum aggregation over decimal values.
### Does this PR introduce _any_ user-facing change?
Yes. Sum aggregation now always throws an exception on decimal value overflow,
instead of possibly returning a wrong result.
### How was this patch tested?
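New unit test cases in `DataFrameAggregateSuite`, covering decimal overflow at
both the partial aggregate phase and the merge aggregation phase.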
Closes #29404 from gengliangwang/fixSum.
Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/catalyst/expressions/aggregate/Sum.scala | 19 +++++++++--
.../apache/spark/sql/DataFrameAggregateSuite.scala | 37 ++++++++++++++++++++++
2 files changed, 53 insertions(+), 3 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index d2daaac..d442549 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -71,23 +71,36 @@ case class Sum(child: Expression) extends
DeclarativeAggregate with ImplicitCast
)
override lazy val updateExpressions: Seq[Expression] = {
+ val sumWithChild = resultType match {
+ case d: DecimalType =>
+ CheckOverflow(coalesce(sum, zero) + child.cast(sumDataType), d,
nullOnOverflow = false)
+ case _ =>
+ coalesce(sum, zero) + child.cast(sumDataType)
+ }
+
if (child.nullable) {
Seq(
/* sum = */
- coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
+ coalesce(sumWithChild, sum)
)
} else {
Seq(
/* sum = */
- coalesce(sum, zero) + child.cast(sumDataType)
+ sumWithChild
)
}
}
override lazy val mergeExpressions: Seq[Expression] = {
+ val sumWithRight = resultType match {
+ case d: DecimalType =>
+ CheckOverflow(coalesce(sum.left, zero) + sum.right, d, nullOnOverflow
= false)
+
+ case _ => coalesce(sum.left, zero) + sum.right
+ }
Seq(
/* sum = */
- coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
+ coalesce(sumWithRight, sum.left)
)
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 54327b3..8c0358e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -21,6 +21,7 @@ import scala.util.Random
import org.scalatest.Matchers.the
+import org.apache.spark.SparkException
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec,
ObjectHashAggregateExec, SortAggregateExec}
@@ -1044,6 +1045,42 @@ class DataFrameAggregateSuite extends QueryTest
checkAnswer(sql(queryTemplate("FIRST")), Row(1))
checkAnswer(sql(queryTemplate("LAST")), Row(3))
}
+
+ private def exceptionOnDecimalOverflow(df: DataFrame): Unit = {
+ val msg = intercept[SparkException] {
+ df.collect()
+ }.getCause.getMessage
+ assert(msg.contains("cannot be represented as Decimal(38, 18)"))
+ }
+
+ test("SPARK-32018: Throw exception on decimal overflow at partial aggregate
phase") {
+ val decimalString = "1" + "0" * 19
+ val union = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
+ val hashAgg = union
+ .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"),
lit("1").as("key"))
+ .groupBy("key")
+ .agg(sum($"d").alias("sumD"))
+ .select($"sumD")
+ exceptionOnDecimalOverflow(hashAgg)
+
+ val sortAgg = union
+ .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"),
lit("a").as("str"),
+ lit("1").as("key")).groupBy("key")
+ .agg(sum($"d").alias("sumD"),
min($"str").alias("minStr")).select($"sumD", $"minStr")
+ exceptionOnDecimalOverflow(sortAgg)
+ }
+
+ test("SPARK-32018: Throw exception on decimal overflow at merge aggregation
phase") {
+ val decimalString = "5" + "0" * 19
+ val union = spark.range(0, 1, 1, 1).union(spark.range(0, 1, 1, 1))
+ .union(spark.range(0, 1, 1, 1))
+ val agg = union
+ .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"),
lit("1").as("key"))
+ .groupBy("key")
+ .agg(sum($"d").alias("sumD"))
+ .select($"sumD")
+ exceptionOnDecimalOverflow(agg)
+ }
}
case class B(c: Option[Double])
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]