This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
new ee12374 [3.0][SQL] Revert SPARK-32018
ee12374 is described below
commit ee1237483eca1c2f2d9630fce91e2394049ce3a4
Author: Gengliang Wang <[email protected]>
AuthorDate: Mon Aug 17 13:46:41 2020 +0000
[3.0][SQL] Revert SPARK-32018
### What changes were proposed in this pull request?
Revert the SPARK-32018-related changes in branch-3.0:
https://github.com/apache/spark/pull/29125 and
https://github.com/apache/spark/pull/29404
### Why are the changes needed?
https://github.com/apache/spark/pull/29404 was made to fix a correctness
regression introduced by https://github.com/apache/spark/pull/29125. However,
the resulting decimal-overflow behavior in non-ANSI mode is inconsistent
across releases:
1. from 3.0.0 to 3.0.1: decimal overflow throws an exception instead of
returning null
2. from 3.0.1 to 3.1.0: decimal overflow returns null instead of throwing
an exception
So, this PR proposes to revert both
https://github.com/apache/spark/pull/29404 and
https://github.com/apache/spark/pull/29125, so that Spark returns null on
decimal overflow in both Spark 3.0.0 and Spark 3.0.1.
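For illustration, a minimal spark-shell sketch of the post-revert non-ANSI
behavior (the column name `d`, the 12-row range, and the overflowing literal
are illustrative choices modeled on the removed tests, not part of this patch):
```scala
// Assumes spark-shell on branch-3.0 with this revert applied and the default
// spark.sql.ansi.enabled=false. 10^19 fits in Decimal(38, 18), which allows
// 20 integral digits, but summing twelve copies gives 1.2 * 10^20, which
// does not fit.
import org.apache.spark.sql.functions.sum

val df = spark.range(0, 12, 1, 1)
  .selectExpr("cast('10000000000000000000' as decimal(38, 18)) as d")

// After the revert, the overflowing sum evaluates to null rather than
// throwing a SparkException during partial aggregation.
df.agg(sum($"d")).show()
// +------+
// |sum(d)|
// +------+
// |  null|
// +------+
```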
### Does this PR introduce _any_ user-facing change?
Yes. Spark 3.0.1 will return null on decimal overflow, consistent with Spark 3.0.0.
### How was this patch tested?
Unit tests
Closes #29450 from gengliangwang/revertDecimalOverflow.
Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/catalyst/expressions/UnsafeRow.java | 2 +-
.../sql/catalyst/expressions/aggregate/Sum.scala | 19 ++---------
.../apache/spark/sql/DataFrameAggregateSuite.scala | 37 ----------------------
.../org/apache/spark/sql/UnsafeRowSuite.scala | 10 ------
4 files changed, 4 insertions(+), 64 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 4dc5ce1..034894b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -288,7 +288,7 @@ public final class UnsafeRow extends InternalRow implements Externalizable, Kryo
       Platform.putLong(baseObject, baseOffset + cursor, 0L);
       Platform.putLong(baseObject, baseOffset + cursor + 8, 0L);

-      if (value == null || !value.changePrecision(precision, value.scale())) {
+      if (value == null) {
         setNullAt(ordinal);
         // keep the offset for future update
         Platform.putLong(baseObject, getFieldOffset(ordinal), cursor << 32);
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index d442549..d2daaac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -71,36 +71,23 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
   )

   override lazy val updateExpressions: Seq[Expression] = {
-    val sumWithChild = resultType match {
-      case d: DecimalType =>
-        CheckOverflow(coalesce(sum, zero) + child.cast(sumDataType), d, nullOnOverflow = false)
-      case _ =>
-        coalesce(sum, zero) + child.cast(sumDataType)
-    }
-
     if (child.nullable) {
       Seq(
         /* sum = */
-        coalesce(sumWithChild, sum)
+        coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
       )
     } else {
       Seq(
         /* sum = */
-        sumWithChild
+        coalesce(sum, zero) + child.cast(sumDataType)
       )
     }
   }

   override lazy val mergeExpressions: Seq[Expression] = {
-    val sumWithRight = resultType match {
-      case d: DecimalType =>
-        CheckOverflow(coalesce(sum.left, zero) + sum.right, d, nullOnOverflow = false)
-
-      case _ => coalesce(sum.left, zero) + sum.right
-    }
     Seq(
       /* sum = */
-      coalesce(sumWithRight, sum.left)
+      coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
     )
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 8c0358e..54327b3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -21,7 +21,6 @@ import scala.util.Random

 import org.scalatest.Matchers.the

-import org.apache.spark.SparkException
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
@@ -1045,42 +1044,6 @@ class DataFrameAggregateSuite extends QueryTest
     checkAnswer(sql(queryTemplate("FIRST")), Row(1))
     checkAnswer(sql(queryTemplate("LAST")), Row(3))
   }
-
-  private def exceptionOnDecimalOverflow(df: DataFrame): Unit = {
-    val msg = intercept[SparkException] {
-      df.collect()
-    }.getCause.getMessage
-    assert(msg.contains("cannot be represented as Decimal(38, 18)"))
-  }
-
-  test("SPARK-32018: Throw exception on decimal overflow at partial aggregate phase") {
-    val decimalString = "1" + "0" * 19
-    val union = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
-    val hashAgg = union
-      .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), lit("1").as("key"))
-      .groupBy("key")
-      .agg(sum($"d").alias("sumD"))
-      .select($"sumD")
-    exceptionOnDecimalOverflow(hashAgg)
-
-    val sortAgg = union
-      .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), lit("a").as("str"),
-        lit("1").as("key")).groupBy("key")
-      .agg(sum($"d").alias("sumD"), min($"str").alias("minStr")).select($"sumD", $"minStr")
-    exceptionOnDecimalOverflow(sortAgg)
-  }
-
-  test("SPARK-32018: Throw exception on decimal overflow at merge aggregation phase") {
-    val decimalString = "5" + "0" * 19
-    val union = spark.range(0, 1, 1, 1).union(spark.range(0, 1, 1, 1))
-      .union(spark.range(0, 1, 1, 1))
-    val agg = union
-      .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), lit("1").as("key"))
-      .groupBy("key")
-      .agg(sum($"d").alias("sumD"))
-      .select($"sumD")
-    exceptionOnDecimalOverflow(agg)
-  }
 }

 case class B(c: Option[Double])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
index 9daa69c..a5f904c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
@@ -178,14 +178,4 @@ class UnsafeRowSuite extends SparkFunSuite {
     // Makes sure hashCode on unsafe array won't crash
     unsafeRow.getArray(0).hashCode()
   }
-
-  test("SPARK-32018: setDecimal with overflowed value") {
-    val d1 = new Decimal().set(BigDecimal("10000000000000000000")).toPrecision(38, 18)
-    val row = InternalRow.apply(d1)
-    val unsafeRow = UnsafeProjection.create(Array[DataType](DecimalType(38, 18))).apply(row)
-    assert(unsafeRow.getDecimal(0, 38, 18) === d1)
-    val d2 = (d1 * Decimal(10)).toPrecision(39, 18)
-    unsafeRow.setDecimal(0, d2, 38)
-    assert(unsafeRow.getDecimal(0, 38, 18) === null)
-  }
 }
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]