This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new ee12374  [3.0][SQL] Revert SPARK-32018
ee12374 is described below

commit ee1237483eca1c2f2d9630fce91e2394049ce3a4
Author: Gengliang Wang <[email protected]>
AuthorDate: Mon Aug 17 13:46:41 2020 +0000

    [3.0][SQL] Revert SPARK-32018
    
    ### What changes were proposed in this pull request?
    
    Revert the SPARK-32018-related changes in branch-3.0:
https://github.com/apache/spark/pull/29125 and
https://github.com/apache/spark/pull/29404
    
    ### Why are the changes needed?
    
    https://github.com/apache/spark/pull/29404 was made to fix a correctness
regression introduced by https://github.com/apache/spark/pull/29125. However,
with both changes in place, the decimal overflow behavior in non-ANSI mode
would flip back and forth across releases:
    1. from 3.0.0 to 3.0.1: decimal overflow would throw an exception instead of
returning null.
    2. from 3.0.1 to 3.1.0: decimal overflow would return null instead of
throwing an exception.
    
    So, this PR proposes to revert both
https://github.com/apache/spark/pull/29404 and
https://github.com/apache/spark/pull/29125, so that Spark returns null on
decimal overflow in both Spark 3.0.0 and Spark 3.0.1.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Spark 3.0.1 will return null on decimal overflow, the same as Spark 3.0.0.
    
    ### How was this patch tested?
    
    Unit tests
    
    Closes #29450 from gengliangwang/revertDecimalOverflow.
    
    Authored-by: Gengliang Wang <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/catalyst/expressions/UnsafeRow.java  |  2 +-
 .../sql/catalyst/expressions/aggregate/Sum.scala   | 19 ++---------
 .../apache/spark/sql/DataFrameAggregateSuite.scala | 37 ----------------------
 .../org/apache/spark/sql/UnsafeRowSuite.scala      | 10 ------
 4 files changed, 4 insertions(+), 64 deletions(-)

diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 4dc5ce1..034894b 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -288,7 +288,7 @@ public final class UnsafeRow extends InternalRow implements 
Externalizable, Kryo
       Platform.putLong(baseObject, baseOffset + cursor, 0L);
       Platform.putLong(baseObject, baseOffset + cursor + 8, 0L);
 
-      if (value == null || !value.changePrecision(precision, value.scale())) {
+      if (value == null) {
         setNullAt(ordinal);
         // keep the offset for future update
         Platform.putLong(baseObject, getFieldOffset(ordinal), cursor << 32);
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index d442549..d2daaac 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -71,36 +71,23 @@ case class Sum(child: Expression) extends 
DeclarativeAggregate with ImplicitCast
   )
 
   override lazy val updateExpressions: Seq[Expression] = {
-    val sumWithChild = resultType match {
-      case d: DecimalType =>
-        CheckOverflow(coalesce(sum, zero) + child.cast(sumDataType), d, 
nullOnOverflow = false)
-      case _ =>
-        coalesce(sum, zero) + child.cast(sumDataType)
-    }
-
     if (child.nullable) {
       Seq(
         /* sum = */
-        coalesce(sumWithChild, sum)
+        coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
       )
     } else {
       Seq(
         /* sum = */
-        sumWithChild
+        coalesce(sum, zero) + child.cast(sumDataType)
       )
     }
   }
 
   override lazy val mergeExpressions: Seq[Expression] = {
-    val sumWithRight = resultType match {
-      case d: DecimalType =>
-        CheckOverflow(coalesce(sum.left, zero) + sum.right, d, nullOnOverflow 
= false)
-
-      case _ => coalesce(sum.left, zero) + sum.right
-    }
     Seq(
       /* sum = */
-      coalesce(sumWithRight, sum.left)
+      coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
     )
   }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 8c0358e..54327b3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -21,7 +21,6 @@ import scala.util.Random
 
 import org.scalatest.Matchers.the
 
-import org.apache.spark.SparkException
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, 
ObjectHashAggregateExec, SortAggregateExec}
@@ -1045,42 +1044,6 @@ class DataFrameAggregateSuite extends QueryTest
     checkAnswer(sql(queryTemplate("FIRST")), Row(1))
     checkAnswer(sql(queryTemplate("LAST")), Row(3))
   }
-
-  private def exceptionOnDecimalOverflow(df: DataFrame): Unit = {
-    val msg = intercept[SparkException] {
-      df.collect()
-    }.getCause.getMessage
-    assert(msg.contains("cannot be represented as Decimal(38, 18)"))
-  }
-
-  test("SPARK-32018: Throw exception on decimal overflow at partial aggregate 
phase") {
-    val decimalString = "1" + "0" * 19
-    val union = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
-    val hashAgg = union
-      .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), 
lit("1").as("key"))
-      .groupBy("key")
-      .agg(sum($"d").alias("sumD"))
-      .select($"sumD")
-    exceptionOnDecimalOverflow(hashAgg)
-
-    val sortAgg = union
-      .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), 
lit("a").as("str"),
-      lit("1").as("key")).groupBy("key")
-      .agg(sum($"d").alias("sumD"), 
min($"str").alias("minStr")).select($"sumD", $"minStr")
-    exceptionOnDecimalOverflow(sortAgg)
-  }
-
-  test("SPARK-32018: Throw exception on decimal overflow at merge aggregation 
phase") {
-    val decimalString = "5" + "0" * 19
-    val union = spark.range(0, 1, 1, 1).union(spark.range(0, 1, 1, 1))
-      .union(spark.range(0, 1, 1, 1))
-    val agg = union
-      .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), 
lit("1").as("key"))
-      .groupBy("key")
-      .agg(sum($"d").alias("sumD"))
-      .select($"sumD")
-    exceptionOnDecimalOverflow(agg)
-  }
 }
 
 case class B(c: Option[Double])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
index 9daa69c..a5f904c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
@@ -178,14 +178,4 @@ class UnsafeRowSuite extends SparkFunSuite {
     // Makes sure hashCode on unsafe array won't crash
     unsafeRow.getArray(0).hashCode()
   }
-
-  test("SPARK-32018: setDecimal with overflowed value") {
-    val d1 = new 
Decimal().set(BigDecimal("10000000000000000000")).toPrecision(38, 18)
-    val row = InternalRow.apply(d1)
-    val unsafeRow = UnsafeProjection.create(Array[DataType](DecimalType(38, 
18))).apply(row)
-    assert(unsafeRow.getDecimal(0, 38, 18) === d1)
-    val d2 = (d1 * Decimal(10)).toPrecision(39, 18)
-    unsafeRow.setDecimal(0, d2, 38)
-    assert(unsafeRow.getDecimal(0, 38, 18) === null)
-  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to