Repository: spark
Updated Branches:
  refs/heads/branch-1.5 77269fcb5 -> d9dfd43d4


[SPARK-10090] [SQL] fix decimal scale of division

We should round the result of decimal multiplication/division to the expected
precision/scale, and also check for overflow.
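
For example (mirroring the SQLQuerySuite cases added in this patch; the sqlContext
usage below is only an illustration, not part of the change):

    // Illustrative only; see the new tests in SQLQuerySuite below.
    sqlContext.sql("select 10.3 / 3.0").collect()
    // => Row(3.433333)   // rounded to the planned result scale
    sqlContext.sql("select 10.300000000000000000 * 3.0000000000000000000").collect()
    // => Row(null)       // cannot fit the bounded precision, so overflow yields null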

Author: Davies Liu <[email protected]>

Closes #8287 from davies/decimal_division.

(cherry picked from commit 1f4c4fe6dfd8cc52b5fddfd67a31a77edbb1a036)
Signed-off-by: Michael Armbrust <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d9dfd43d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d9dfd43d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d9dfd43d

Branch: refs/heads/branch-1.5
Commit: d9dfd43d463cd5e3c9e72197850382ad8897b8ac
Parents: 77269fc
Author: Davies Liu <[email protected]>
Authored: Wed Aug 19 14:03:47 2015 -0700
Committer: Michael Armbrust <[email protected]>
Committed: Wed Aug 19 14:04:09 2015 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/HiveTypeCoercion.scala    | 28 +++++----
 .../spark/sql/catalyst/expressions/Cast.scala   | 32 +++++-----
 .../catalyst/expressions/decimalFunctions.scala | 38 +++++++++++-
 .../org/apache/spark/sql/types/Decimal.scala    |  4 +-
 .../expressions/DecimalExpressionSuite.scala    | 63 ++++++++++++++++++++
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 23 ++++++-
 6 files changed, 157 insertions(+), 31 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d9dfd43d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 8581d6b..62c27ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -371,8 +371,8 @@ object HiveTypeCoercion {
       DecimalType.bounded(range + scale, scale)
     }
 
-    private def changePrecision(e: Expression, dataType: DataType): Expression = {
-      ChangeDecimalPrecision(Cast(e, dataType))
+    private def promotePrecision(e: Expression, dataType: DataType): Expression = {
+      PromotePrecision(Cast(e, dataType))
     }
 
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
@@ -383,36 +383,42 @@ object HiveTypeCoercion {
         case e if !e.childrenResolved => e
 
         // Skip nodes who is already promoted
-        case e: BinaryArithmetic if e.left.isInstanceOf[ChangeDecimalPrecision] => e
+        case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e
 
        case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
          val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
-          Add(changePrecision(e1, dt), changePrecision(e2, dt))
+          CheckOverflow(Add(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt)
 
        case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
          val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
-          Subtract(changePrecision(e1, dt), changePrecision(e2, dt))
+          CheckOverflow(Subtract(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt)
 
        case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
-          val dt = DecimalType.bounded(p1 + p2 + 1, s1 + s2)
-          Multiply(changePrecision(e1, dt), changePrecision(e2, dt))
+          val resultType = DecimalType.bounded(p1 + p2 + 1, s1 + s2)
+          val widerType = widerDecimalType(p1, s1, p2, s2)
+          CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
+            resultType)
 
        case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
-          val dt = DecimalType.bounded(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1))
-          Divide(changePrecision(e1, dt), changePrecision(e2, dt))
+          val resultType = DecimalType.bounded(p1 - s1 + s2 + max(6, s1 + p2 + 1),
+            max(6, s1 + p2 + 1))
+          val widerType = widerDecimalType(p1, s1, p2, s2)
+          CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
+            resultType)
 
        case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
          val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
          // resultType may have lower precision, so we cast them into wider type first.
          val widerType = widerDecimalType(p1, s1, p2, s2)
-          Cast(Remainder(changePrecision(e1, widerType), changePrecision(e2, widerType)),
+          CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
             resultType)
 
        case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
          val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
          // resultType may have lower precision, so we cast them into wider type first.
          val widerType = widerDecimalType(p1, s1, p2, s2)
-          Cast(Pmod(changePrecision(e1, widerType), changePrecision(e2, widerType)), resultType)
+          CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)),
+            resultType)
 
         case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
                                  e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
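
For reference, the Divide rule above gives a query like 10.3 / 3.0 (DecimalType(3, 1)
divided by DecimalType(2, 1)) the result type DecimalType(9, 6). A quick sketch of that
arithmetic (plain Scala, not part of the patch; DecimalType.bounded additionally caps
precision at the maximum of 38):

    // Reproduces the Divide rule above for 10.3 / 3.0, i.e. (p1, s1) = (3, 1), (p2, s2) = (2, 1).
    val (p1, s1) = (3, 1)
    val (p2, s2) = (2, 1)
    val scale = math.max(6, s1 + p2 + 1)      // max(6, 4) = 6
    val precision = p1 - s1 + s2 + scale      // 3 - 1 + 1 + 6 = 9
    // => DecimalType(9, 6), so 10.3 / 3.0 evaluates to 3.433333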

http://git-wip-us.apache.org/repos/asf/spark/blob/d9dfd43d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 616b9e0..2db9542 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -447,7 +447,7 @@ case class Cast(child: Expression, dataType: DataType)
     case StringType => castToStringCode(from, ctx)
     case BinaryType => castToBinaryCode(from)
     case DateType => castToDateCode(from, ctx)
-    case decimal: DecimalType => castToDecimalCode(from, decimal)
+    case decimal: DecimalType => castToDecimalCode(from, decimal, ctx)
     case TimestampType => castToTimestampCode(from, ctx)
     case CalendarIntervalType => castToIntervalCode(from)
     case BooleanType => castToBooleanCode(from)
@@ -528,14 +528,18 @@ case class Cast(child: Expression, dataType: DataType)
       }
     """
 
-  private[this] def castToDecimalCode(from: DataType, target: DecimalType): CastFunction = {
+  private[this] def castToDecimalCode(
+      from: DataType,
+      target: DecimalType,
+      ctx: CodeGenContext): CastFunction = {
+    val tmp = ctx.freshName("tmpDecimal")
     from match {
       case StringType =>
         (c, evPrim, evNull) =>
           s"""
             try {
-              Decimal tmpDecimal = Decimal.apply(new java.math.BigDecimal($c.toString()));
-              ${changePrecision("tmpDecimal", target, evPrim, evNull)}
+              Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString()));
+              ${changePrecision(tmp, target, evPrim, evNull)}
             } catch (java.lang.NumberFormatException e) {
               $evNull = true;
             }
@@ -543,8 +547,8 @@ case class Cast(child: Expression, dataType: DataType)
       case BooleanType =>
         (c, evPrim, evNull) =>
           s"""
-            Decimal tmpDecimal = $c ? Decimal.apply(1) : Decimal.apply(0);
-            ${changePrecision("tmpDecimal", target, evPrim, evNull)}
+            Decimal $tmp = $c ? Decimal.apply(1) : Decimal.apply(0);
+            ${changePrecision(tmp, target, evPrim, evNull)}
           """
       case DateType =>
         // date can't cast to decimal in Hive
@@ -553,29 +557,29 @@ case class Cast(child: Expression, dataType: DataType)
         // Note that we lose precision here.
         (c, evPrim, evNull) =>
           s"""
-            Decimal tmpDecimal = Decimal.apply(
+            Decimal $tmp = Decimal.apply(
               scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)}));
-            ${changePrecision("tmpDecimal", target, evPrim, evNull)}
+            ${changePrecision(tmp, target, evPrim, evNull)}
           """
       case DecimalType() =>
         (c, evPrim, evNull) =>
           s"""
-            Decimal tmpDecimal = $c.clone();
-            ${changePrecision("tmpDecimal", target, evPrim, evNull)}
+            Decimal $tmp = $c.clone();
+            ${changePrecision(tmp, target, evPrim, evNull)}
           """
       case x: IntegralType =>
         (c, evPrim, evNull) =>
           s"""
-            Decimal tmpDecimal = Decimal.apply((long) $c);
-            ${changePrecision("tmpDecimal", target, evPrim, evNull)}
+            Decimal $tmp = Decimal.apply((long) $c);
+            ${changePrecision(tmp, target, evPrim, evNull)}
           """
       case x: FractionalType =>
         // All other numeric types can be represented precisely as Doubles
         (c, evPrim, evNull) =>
           s"""
             try {
-              Decimal tmpDecimal = Decimal.apply(scala.math.BigDecimal.valueOf((double) $c));
-              ${changePrecision("tmpDecimal", target, evPrim, evNull)}
+              Decimal $tmp = Decimal.apply(scala.math.BigDecimal.valueOf((double) $c));
+              ${changePrecision(tmp, target, evPrim, evNull)}
             } catch (java.lang.NumberFormatException e) {
               $evNull = true;
             }
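
The point of threading CodeGenContext into castToDecimalCode is visible in the hunks
above: the hard-coded tmpDecimal local is replaced by ctx.freshName("tmpDecimal"), so
several decimal casts compiled into the same generated method no longer declare the
same Java variable twice. A rough illustration of the idea (FakeCodeGenContext is a
made-up stand-in, not Spark's class):

    // Made-up stand-in for CodeGenContext.freshName, only to show the collision it avoids.
    class FakeCodeGenContext {
      private var id = 0
      def freshName(prefix: String): String = { id += 1; s"$prefix$id" }
    }

    val ctx = new FakeCodeGenContext
    val a = ctx.freshName("tmpDecimal")   // e.g. "tmpDecimal1"
    val b = ctx.freshName("tmpDecimal")   // e.g. "tmpDecimal2"
    // With the old fixed name, two casts in one generated method would both emit
    // `Decimal tmpDecimal = ...;` and the generated Java would fail to compile.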

http://git-wip-us.apache.org/repos/asf/spark/blob/d9dfd43d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
index adb33e4..b7be12f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
@@ -66,10 +66,44 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
 * An expression used to wrap the children when promote the precision of DecimalType to avoid
  * promote multiple times.
  */
-case class ChangeDecimalPrecision(child: Expression) extends UnaryExpression {
+case class PromotePrecision(child: Expression) extends UnaryExpression {
   override def dataType: DataType = child.dataType
   override def eval(input: InternalRow): Any = child.eval(input)
  override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx)
  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = ""
-  override def prettyName: String = "change_decimal_precision"
+  override def prettyName: String = "promote_precision"
+}
+
+/**
+ * Rounds the decimal to given scale and check whether the decimal can fit in 
provided precision
+ * or not, returns null if not.
+ */
+case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression {
+
+  override def nullable: Boolean = true
+
+  override def nullSafeEval(input: Any): Any = {
+    val d = input.asInstanceOf[Decimal].clone()
+    if (d.changePrecision(dataType.precision, dataType.scale)) {
+      d
+    } else {
+      null
+    }
+  }
+
+  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    nullSafeCodeGen(ctx, ev, eval => {
+      val tmp = ctx.freshName("tmp")
+      s"""
+         | Decimal $tmp = $eval.clone();
+         | if ($tmp.changePrecision(${dataType.precision}, ${dataType.scale})) {
+         |   ${ev.primitive} = $tmp;
+         | } else {
+         |   ${ev.isNull} = true;
+         | }
+       """.stripMargin
+    })
+  }
+
+  override def toString: String = s"CheckOverflow($child, $dataType)"
 }
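
CheckOverflow leans on Decimal.changePrecision, which rounds the value to the target
scale and reports whether it still fits the target precision. A small sketch of that
contract, using the same values as the new DecimalExpressionSuite below:

    import org.apache.spark.sql.types.Decimal

    val d = Decimal("10.1")
    // Fits: rounds 10.1 to scale 0 (i.e. 10) and returns true.
    assert(d.clone().changePrecision(4, 0))
    // Does not fit: 10.1 at scale 3 needs 5 digits but only 4 are allowed,
    // so changePrecision returns false and CheckOverflow evaluates to null.
    assert(!d.clone().changePrecision(4, 3))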

http://git-wip-us.apache.org/repos/asf/spark/blob/d9dfd43d/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index d95805c..c988f1d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -267,7 +267,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
    if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) {
      Decimal(longVal + that.longVal, Math.max(precision, that.precision), scale)
     } else {
-      Decimal(toBigDecimal + that.toBigDecimal, precision, scale)
+      Decimal(toBigDecimal + that.toBigDecimal)
     }
   }
 
@@ -275,7 +275,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
    if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) {
      Decimal(longVal - that.longVal, Math.max(precision, that.precision), scale)
     } else {
-      Decimal(toBigDecimal - that.toBigDecimal, precision, scale)
+      Decimal(toBigDecimal - that.toBigDecimal)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d9dfd43d/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
new file mode 100644
index 0000000..511f030
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types.{LongType, DecimalType, Decimal}
+
+
+class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+  test("UnscaledValue") {
+    val d1 = Decimal("10.1")
+    checkEvaluation(UnscaledValue(Literal(d1)), 101L)
+    val d2 = Decimal(101, 3, 1)
+    checkEvaluation(UnscaledValue(Literal(d2)), 101L)
+    checkEvaluation(UnscaledValue(Literal.create(null, DecimalType(2, 1))), null)
+  }
+
+  test("MakeDecimal") {
+    checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1"))
+    checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null)
+  }
+
+  test("PromotePrecision") {
+    val d1 = Decimal("10.1")
+    checkEvaluation(PromotePrecision(Literal(d1)), d1)
+    val d2 = Decimal(101, 3, 1)
+    checkEvaluation(PromotePrecision(Literal(d2)), d2)
+    checkEvaluation(PromotePrecision(Literal.create(null, DecimalType(2, 1))), null)
+  }
+
+  test("CheckOverflow") {
+    val d1 = Decimal("10.1")
+    checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0)), Decimal("10"))
+    checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1)), d1)
+    checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2)), d1)
+    checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 3)), null)
+
+    val d2 = Decimal(101, 3, 1)
+    checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0)), Decimal("10"))
+    checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1)), d2)
+    checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2)), d2)
+    checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 3)), null)
+
+    checkEvaluation(CheckOverflow(Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2)), null)
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/d9dfd43d/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index c329fdb..141468c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -17,16 +17,17 @@
 
 package org.apache.spark.sql
 
+import java.math.MathContext
 import java.sql.Timestamp
 
 import org.apache.spark.AccumulatorSuite
-import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.DefaultParserDialect
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.errors.DialectException
 import org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
 /** A SQL Dialect for testing purpose, and it can not be nested type */
@@ -1608,6 +1609,24 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("decimal precision with multiply/division") {
+    checkAnswer(sql("select 10.3 * 3.0"), Row(BigDecimal("30.90")))
+    checkAnswer(sql("select 10.3000 * 3.0"), Row(BigDecimal("30.90000")))
+    checkAnswer(sql("select 10.30000 * 30.0"), Row(BigDecimal("309.000000")))
+    checkAnswer(sql("select 10.300000000000000000 * 3.000000000000000000"),
+      Row(BigDecimal("30.900000000000000000000000000000000000", new 
MathContext(38))))
+    checkAnswer(sql("select 10.300000000000000000 * 3.0000000000000000000"),
+      Row(null))
+
+    checkAnswer(sql("select 10.3 / 3.0"), Row(BigDecimal("3.433333")))
+    checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333")))
+    checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333")))
+    checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"),
+      Row(BigDecimal("3.4333333333333333333333333333333333333", new 
MathContext(38))))
+    checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"),
+      Row(null))
+  }
+
   test("external sorting updates peak execution memory") {
     withSQLConf((SQLConf.EXTERNAL_SORT.key, "true")) {
       val sc = sqlContext.sparkContext

