This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new bc4a676  [SPARK-28201][SQL] Revisit MakeDecimal behavior on overflow
bc4a676 is described below

commit bc4a676b2752c691f7c1d824a58387dbfac6d695
Author: Marco Gaido <[email protected]>
AuthorDate: Mon Jul 1 11:54:58 2019 +0800

    [SPARK-28201][SQL] Revisit MakeDecimal behavior on overflow
    
    ## What changes were proposed in this pull request?
    
    SPARK-23179 introduced a flag to control the behavior in case of
    overflow on decimals: when `spark.sql.decimalOperations.nullOnOverflow`
    is true (the default, matching traditional Spark behavior), `null` is
    returned; when it is false, an `ArithmeticException` is thrown
    (following the SQL standard and the behavior of other DBs).
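
    A minimal sketch of the two behaviors (assuming a spark-shell session;
    the overflowing expression is illustrative, not taken from this patch):

    ```scala
    // DECIMAL(38, 0) * DECIMAL(3, 0) is capped at 38 digits of precision,
    // but the value below needs 40 digits, so the result overflows.
    val q = "SELECT CAST(1e37 AS DECIMAL(38, 0)) * 100"

    spark.conf.set("spark.sql.decimalOperations.nullOnOverflow", "true")
    spark.sql(q).collect()   // Array([null])

    spark.conf.set("spark.sql.decimalOperations.nullOnOverflow", "false")
    spark.sql(q).collect()   // fails with an ArithmeticException
                             // (possibly wrapped by the task runner)
    ```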
    
    `MakeDecimal` so far had an inconsistent behavior: in codegen mode it
    returned `null` like the other operators, but in interpreted mode it
    threw an `IllegalArgumentException`.
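
    Concretely, the interpreted path went through `Decimal.set`, which
    threw, while the generated code used `Decimal.setOrNull`, which
    returned `null`. A simplified sketch of that pre-patch difference:

    ```scala
    import scala.util.Try
    import org.apache.spark.sql.types.Decimal

    // Unscaled 1000L at precision 3, scale 1 would be 100.0: one digit
    // too many for DECIMAL(3, 1).
    val codegenStyle = new Decimal().setOrNull(1000L, 3, 1)   // null
    val interpretedStyle = Try(new Decimal().set(1000L, 3, 1))
    // before this patch: Failure(IllegalArgumentException)
    ```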
    
    This PR aligns `MakeDecimal`'s behavior with that of the other
    operators as defined in SPARK-23179, so both modes now return `null`
    or throw an `ArithmeticException` according to the value of
    `spark.sql.decimalOperations.nullOnOverflow`.
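
    A sketch mirroring the new tests, which now pass in both modes:

    ```scala
    import org.apache.spark.sql.catalyst.expressions.{Literal, MakeDecimal}

    // Unscaled 1000L does not fit DECIMAL(3, 1) (it would be 100.0).
    // Note: MakeDecimal reads the conf when it is constructed.
    val overflowExpr = MakeDecimal(Literal(1000L), 3, 1)
    // spark.sql.decimalOperations.nullOnOverflow = true  -> null
    // spark.sql.decimalOperations.nullOnOverflow = false -> ArithmeticException
    // ... in both interpreted and codegen mode.
    ```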
    
    Credit for this PR goes to mickjermsurawong-stripe, who pointed out
    the wrong behavior in #20350.
    
    ## How was this patch tested?
    
    Improved UTs.
    
    Closes #25010 from mgaido91/SPARK-28201.
    
    Authored-by: Marco Gaido <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../catalyst/expressions/decimalExpressions.scala  | 32 ++++++++++++++++++----
 .../scala/org/apache/spark/sql/types/Decimal.scala |  9 +++---
 .../expressions/DecimalExpressionSuite.scala       | 20 ++++++++++++--
 .../org/apache/spark/sql/types/DecimalSuite.scala  | 10 +++----
 4 files changed, 54 insertions(+), 17 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index ad7f7dd..b5b712c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 /**
@@ -46,19 +47,38 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
  */
 case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression {
 
+  private val nullOnOverflow = SQLConf.get.decimalOperationsNullOnOverflow
+
   override def dataType: DataType = DecimalType(precision, scale)
-  override def nullable: Boolean = true
+  override def nullable: Boolean = child.nullable || nullOnOverflow
   override def toString: String = s"MakeDecimal($child,$precision,$scale)"
 
-  protected override def nullSafeEval(input: Any): Any =
-    Decimal(input.asInstanceOf[Long], precision, scale)
+  protected override def nullSafeEval(input: Any): Any = {
+    val longInput = input.asInstanceOf[Long]
+    val result = new Decimal()
+    if (nullOnOverflow) {
+      result.setOrNull(longInput, precision, scale)
+    } else {
+      result.set(longInput, precision, scale)
+    }
+  }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     nullSafeCodeGen(ctx, ev, eval => {
+      val setMethod = if (nullOnOverflow) {
+        "setOrNull"
+      } else {
+        "set"
+      }
+      val setNull = if (nullable) {
+        s"${ev.isNull} = ${ev.value} == null;"
+      } else {
+        ""
+      }
       s"""
-        ${ev.value} = (new Decimal()).setOrNull($eval, $precision, $scale);
-        ${ev.isNull} = ${ev.value} == null;
-      """
+         |${ev.value} = (new Decimal()).$setMethod($eval, $precision, $scale);
+         |$setNull
+         |""".stripMargin
     })
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index b7b7097..1bf322a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -76,7 +76,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
    */
   def set(unscaled: Long, precision: Int, scale: Int): Decimal = {
     if (setOrNull(unscaled, precision, scale) == null) {
-      throw new IllegalArgumentException("Unscaled value too large for precision")
+      throw new ArithmeticException("Unscaled value too large for precision")
     }
     this
   }
@@ -111,9 +111,10 @@ final class Decimal extends Ordered[Decimal] with Serializable {
    */
   def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = {
     this.decimalVal = decimal.setScale(scale, ROUND_HALF_UP)
-    require(
-      decimalVal.precision <= precision,
-      s"Decimal precision ${decimalVal.precision} exceeds max precision 
$precision")
+    if (decimalVal.precision > precision) {
+      throw new ArithmeticException(
+        s"Decimal precision ${decimalVal.precision} exceeds max precision 
$precision")
+    }
     this.longVal = 0L
     this._precision = precision
     this._scale = scale
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
index d14eceb..fc5e8dc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{Decimal, DecimalType, LongType}
 
 class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -31,8 +32,23 @@ class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("MakeDecimal") {
-    checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1"))
-    checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null)
+    withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "true") {
+      checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1"))
+      checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null)
+      val overflowExpr = MakeDecimal(Literal.create(1000L, LongType), 3, 1)
+      checkEvaluation(overflowExpr, null)
+      checkEvaluationWithMutableProjection(overflowExpr, null)
+      evaluateWithoutCodegen(overflowExpr, null)
+      checkEvaluationWithUnsafeProjection(overflowExpr, null)
+    }
+    withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "false") {
+      checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1"))
+      checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null)
+      val overflowExpr = MakeDecimal(Literal.create(1000L, LongType), 3, 1)
+      intercept[ArithmeticException](checkEvaluationWithMutableProjection(overflowExpr, null))
+      intercept[ArithmeticException](evaluateWithoutCodegen(overflowExpr, null))
+      intercept[ArithmeticException](checkEvaluationWithUnsafeProjection(overflowExpr, null))
+    }
   }
 
   test("PromotePrecision") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
index 8abd762..d69bb2f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
@@ -56,11 +56,11 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
     checkDecimal(Decimal(1000000000000000000L, 20, 2), "10000000000000000.00", 20, 2)
     checkDecimal(Decimal(Long.MaxValue), Long.MaxValue.toString, 20, 0)
     checkDecimal(Decimal(Long.MinValue), Long.MinValue.toString, 20, 0)
-    intercept[IllegalArgumentException](Decimal(170L, 2, 1))
-    intercept[IllegalArgumentException](Decimal(170L, 2, 0))
-    intercept[IllegalArgumentException](Decimal(BigDecimal("10.030"), 2, 1))
-    intercept[IllegalArgumentException](Decimal(BigDecimal("-9.95"), 2, 1))
-    intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0))
+    intercept[ArithmeticException](Decimal(170L, 2, 1))
+    intercept[ArithmeticException](Decimal(170L, 2, 0))
+    intercept[ArithmeticException](Decimal(BigDecimal("10.030"), 2, 1))
+    intercept[ArithmeticException](Decimal(BigDecimal("-9.95"), 2, 1))
+    intercept[ArithmeticException](Decimal(1e17.toLong, 17, 0))
   }
 
   test("creating decimals with negative scale") {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
