Repository: spark
Updated Branches:
  refs/heads/branch-2.2 3d1908fd5 -> 2cac317a8


[SPARK-20665][SQL] "Bround" and "Round" function return NULL

## What changes were proposed in this pull request?
   spark-sql>select bround(12.3, 2);
   spark-sql>NULL
In this case the expected result is 12.3, but NULL is returned.
When the second parameter (the requested scale) is larger than the scale of the
input decimal, the result is not what we expected.
The "round" function has the same problem. This PR fixes the problem for both
of them.

## How was this patch tested?
unit test cases in MathExpressionsSuite and MathFunctionsSuite

Author: liuxian <liu.xi...@zte.com.cn>

Closes #17906 from 10110346/wip_lx_0509.

(cherry picked from commit 2b36eb696f6c738e1328582630755aaac4293460)
Signed-off-by: Wenchen Fan <wenc...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2cac317a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2cac317a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2cac317a

Branch: refs/heads/branch-2.2
Commit: 2cac317a84a234f034b0c75dcb5e4c27860a4cc0
Parents: 3d1908f
Author: liuxian <liu.xi...@zte.com.cn>
Authored: Fri May 12 11:38:50 2017 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Fri May 12 11:39:02 2017 +0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/mathExpressions.scala     | 12 ++++++------
 .../catalyst/expressions/MathExpressionsSuite.scala    |  7 +++----
 .../org/apache/spark/sql/MathFunctionsSuite.scala      | 13 +++++++++++++
 3 files changed, 22 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2cac317a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index c4d47ab..de1a46d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -1023,10 +1023,10 @@ abstract class RoundBase(child: Expression, scale: 
Expression,
 
   // not overriding since _scale is a constant int at runtime
   def nullSafeEval(input1: Any): Any = {
-    child.dataType match {
-      case _: DecimalType =>
+    dataType match {
+      case DecimalType.Fixed(_, s) =>
         val decimal = input1.asInstanceOf[Decimal]
-        decimal.toPrecision(decimal.precision, _scale, mode).orNull
+        decimal.toPrecision(decimal.precision, s, mode).orNull
       case ByteType =>
         BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
       case ShortType =>
@@ -1055,10 +1055,10 @@ abstract class RoundBase(child: Expression, scale: 
Expression,
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val ce = child.genCode(ctx)
 
-    val evaluationCode = child.dataType match {
-      case _: DecimalType =>
+    val evaluationCode = dataType match {
+      case DecimalType.Fixed(_, s) =>
         s"""
-        if (${ce.value}.changePrecision(${ce.value}.precision(), ${_scale},
+        if (${ce.value}.changePrecision(${ce.value}.precision(), ${s},
             java.math.BigDecimal.${modeStr})) {
           ${ev.value} = ${ce.value};
         } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/2cac317a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
index 6b5bfac..1555dd1 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
@@ -546,15 +546,14 @@ class MathExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), 
BigDecimal(3.14),
       BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159),
       BigDecimal(3.141593), BigDecimal(3.1415927))
-    // round_scale > current_scale would result in precision increase
-    // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null
+
     (0 to 7).foreach { i =>
       checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow)
       checkEvaluation(BRound(bdPi, i), bdResults(i), EmptyRow)
     }
     (8 to 10).foreach { scale =>
-      checkEvaluation(Round(bdPi, scale), null, EmptyRow)
-      checkEvaluation(BRound(bdPi, scale), null, EmptyRow)
+      checkEvaluation(Round(bdPi, scale), bdPi, EmptyRow)
+      checkEvaluation(BRound(bdPi, scale), bdPi, EmptyRow)
     }
 
     DataTypeTestUtils.numericTypes.foreach { dataType =>

http://git-wip-us.apache.org/repos/asf/spark/blob/2cac317a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
index 328c539..c2d08a0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
@@ -231,6 +231,19 @@ class MathFunctionsSuite extends QueryTest with 
SharedSQLContext {
       Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), 
BigDecimal(3),
         BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142")))
     )
+
+    val bdPi: BigDecimal = BigDecimal(31415925L, 7)
+    checkAnswer(
+      sql(s"SELECT round($bdPi, 7), round($bdPi, 8), round($bdPi, 9), 
round($bdPi, 10), " +
+        s"round($bdPi, 100), round($bdPi, 6), round(null, 8)"),
+      Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141593"), null))
+    )
+
+    checkAnswer(
+      sql(s"SELECT bround($bdPi, 7), bround($bdPi, 8), bround($bdPi, 9), 
bround($bdPi, 10), " +
+        s"bround($bdPi, 100), bround($bdPi, 6), bround(null, 8)"),
+      Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141592"), null))
+    )
   }
 
   test("round/bround with data frame from a local Seq of Product") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to