Repository: spark
Updated Branches:
  refs/heads/master 3f6296fed -> f0e129740


[SPARK-8279][SQL]Add math function round

JIRA: https://issues.apache.org/jira/browse/SPARK-8279

Author: Yijie Shen <[email protected]>

Closes #6938 from yijieshen/udf_round_3 and squashes the following commits:

07a124c [Yijie Shen] remove useless def children
392b65b [Yijie Shen] add negative scale test in DecimalSuite
61760ee [Yijie Shen] address reviews
302a78a [Yijie Shen] Add dataframe function test
31dfe7c [Yijie Shen] refactor round to make it readable
8c7a949 [Yijie Shen] rebase & inputTypes update
9555e35 [Yijie Shen] tiny style fix
d10be4a [Yijie Shen] use TypeCollection to specify wanted input and implicit 
cast
c3b9839 [Yijie Shen] rely on implicit cast to handle string input
b0bff79 [Yijie Shen] make round's inner method's name more meaningful
9bd6930 [Yijie Shen] revert accidental change
e6f44c4 [Yijie Shen] refactor eval and genCode
1b87540 [Yijie Shen] modify checkInputDataTypes using foldable
5486b2d [Yijie Shen] DataFrame API modification
2077888 [Yijie Shen] codegen versioned eval
6cd9a64 [Yijie Shen] refactor Round's constructor
9be894e [Yijie Shen] add round functions in o.a.s.sql.functions
7c83e13 [Yijie Shen] more tests on round
56db4bb [Yijie Shen] Add decimal support to Round
7e163ae [Yijie Shen] style fix
653d047 [Yijie Shen] Add math function round


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f0e12974
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f0e12974
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f0e12974

Branch: refs/heads/master
Commit: f0e129740dc2442a21dfa7fbd97360df87291095
Parents: 3f6296f
Author: Yijie Shen <[email protected]>
Authored: Tue Jul 14 23:30:41 2015 -0700
Committer: Reynold Xin <[email protected]>
Committed: Tue Jul 14 23:30:41 2015 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    |   1 +
 .../spark/sql/catalyst/expressions/math.scala   | 203 ++++++++++++++++++-
 .../analysis/ExpressionTypeCheckingSuite.scala  |  17 ++
 .../expressions/MathFunctionsSuite.scala        |  44 ++++
 .../spark/sql/types/decimal/DecimalSuite.scala  |  23 ++-
 .../scala/org/apache/spark/sql/functions.scala  |  32 +++
 .../apache/spark/sql/MathExpressionsSuite.scala |  15 ++
 .../hive/execution/HiveCompatibilitySuite.scala |   7 +-
 8 files changed, 329 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f0e12974/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 6b1a94e..ec75f51 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -117,6 +117,7 @@ object FunctionRegistry {
     expression[Pow]("power"),
     expression[UnaryPositive]("positive"),
     expression[Rint]("rint"),
+    expression[Round]("round"),
     expression[ShiftLeft]("shiftleft"),
     expression[ShiftRight]("shiftright"),
     expression[ShiftRightUnsigned]("shiftrightunsigned"),

http://git-wip-us.apache.org/repos/asf/spark/blob/f0e12974/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index 4b7fe05..a7ad452 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -19,8 +19,10 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.{lang => jl}
 
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, 
TypeCheckFailure}
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -520,3 +522,202 @@ case class Logarithm(left: Expression, right: Expression)
     """
   }
 }
+
+/**
+ * Round the `child`'s result to `scale` decimal places when `scale` >= 0
+ * or round at the integral part when `scale` < 0.
+ * For example, round(31.415, 2) would eval to 31.42 and round(31.415, -1) 
would eval to 30.
+ *
+ * Child of IntegralType would eval to itself when `scale` >= 0.
+ * Child of FractionalType whose value is NaN or Infinite would always eval to 
itself.
+ *
+ * Round's dataType would always equal to `child`'s dataType except for 
[[DecimalType.Fixed]],
+ * which leads to scale update in DecimalType's [[PrecisionInfo]]
+ *
+ * @param child expression to be rounded; any [[NumericType]] is allowed as input
+ * @param scale new scale to round to; this should be a constant int at 
runtime
+ */
+case class Round(child: Expression, scale: Expression)
+  extends BinaryExpression with ExpectsInputTypes {
+
+  import BigDecimal.RoundingMode.HALF_UP
+
+  def this(child: Expression) = this(child, Literal(0))
+
+  override def left: Expression = child
+  override def right: Expression = scale
+
+  // round of Decimal would eval to null if it fails to `changePrecision`
+  override def nullable: Boolean = true
+
+  override def foldable: Boolean = child.foldable
+
+  override lazy val dataType: DataType = child.dataType match {
+    // if the new scale is bigger which means we are scaling up,
+    // keep the original scale as `Decimal` does
+    case DecimalType.Fixed(p, s) => DecimalType(p, if (_scale > s) s else 
_scale)
+    case t => t
+  }
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, 
IntegerType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    super.checkInputDataTypes() match {
+      case TypeCheckSuccess =>
+        if (scale.foldable) {
+          TypeCheckSuccess
+        } else {
+          TypeCheckFailure("Only foldable Expression is allowed for scale 
arguments")
+        }
+      case f => f
+    }
+  }
+
+  // Avoid repeated evaluation since `scale` is a constant int,
+  // avoid unnecessary `child` evaluation in both codegen and non-codegen eval
+  // by checking if scaleV == null as well.
+  private lazy val scaleV: Any = scale.eval(EmptyRow)
+  private lazy val _scale: Int = scaleV.asInstanceOf[Int]
+
+  override def eval(input: InternalRow): Any = {
+    if (scaleV == null) { // if scale is null, no need to eval its child at all
+      null
+    } else {
+      val evalE = child.eval(input)
+      if (evalE == null) {
+        null
+      } else {
+        nullSafeEval(evalE)
+      }
+    }
+  }
+
+  // not overriding since _scale is a constant int at runtime
+  def nullSafeEval(input1: Any): Any = {
+    child.dataType match {
+      case _: DecimalType =>
+        val decimal = input1.asInstanceOf[Decimal]
+        if (decimal.changePrecision(decimal.precision, _scale)) decimal else 
null
+      case ByteType =>
+        BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, HALF_UP).toByte
+      case ShortType =>
+        BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, 
HALF_UP).toShort
+      case IntegerType =>
+        BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, HALF_UP).toInt
+      case LongType =>
+        BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, HALF_UP).toLong
+      case FloatType =>
+        val f = input1.asInstanceOf[Float]
+        if (f.isNaN || f.isInfinite) {
+          f
+        } else {
+          BigDecimal(f).setScale(_scale, HALF_UP).toFloat
+        }
+      case DoubleType =>
+        val d = input1.asInstanceOf[Double]
+        if (d.isNaN || d.isInfinite) {
+          d
+        } else {
+          BigDecimal(d).setScale(_scale, HALF_UP).toDouble
+        }
+    }
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
+    val ce = child.gen(ctx)
+
+    val evaluationCode = child.dataType match {
+      case _: DecimalType =>
+        s"""
+        if (${ce.primitive}.changePrecision(${ce.primitive}.precision(), 
${_scale})) {
+          ${ev.primitive} = ${ce.primitive};
+        } else {
+          ${ev.isNull} = true;
+        }"""
+      case ByteType =>
+        if (_scale < 0) {
+          s"""
+          ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
+            setScale(${_scale}, 
java.math.BigDecimal.ROUND_HALF_UP).byteValue();"""
+        } else {
+          s"${ev.primitive} = ${ce.primitive};"
+        }
+      case ShortType =>
+        if (_scale < 0) {
+          s"""
+          ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
+            setScale(${_scale}, 
java.math.BigDecimal.ROUND_HALF_UP).shortValue();"""
+        } else {
+          s"${ev.primitive} = ${ce.primitive};"
+        }
+      case IntegerType =>
+        if (_scale < 0) {
+          s"""
+          ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
+            setScale(${_scale}, 
java.math.BigDecimal.ROUND_HALF_UP).intValue();"""
+        } else {
+          s"${ev.primitive} = ${ce.primitive};"
+        }
+      case LongType =>
+        if (_scale < 0) {
+          s"""
+          ${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
+            setScale(${_scale}, 
java.math.BigDecimal.ROUND_HALF_UP).longValue();"""
+        } else {
+          s"${ev.primitive} = ${ce.primitive};"
+        }
+      case FloatType => // if child eval to NaN or Infinity, just return it.
+        if (_scale == 0) {
+          s"""
+            if (Float.isNaN(${ce.primitive}) || 
Float.isInfinite(${ce.primitive})){
+              ${ev.primitive} = ${ce.primitive};
+            } else {
+              ${ev.primitive} = Math.round(${ce.primitive});
+            }"""
+        } else {
+          s"""
+            if (Float.isNaN(${ce.primitive}) || 
Float.isInfinite(${ce.primitive})){
+              ${ev.primitive} = ${ce.primitive};
+            } else {
+              ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}).
+                setScale(${_scale}, 
java.math.BigDecimal.ROUND_HALF_UP).floatValue();
+            }"""
+        }
+      case DoubleType => // if child eval to NaN or Infinity, just return it.
+        if (_scale == 0) {
+          s"""
+            if (Double.isNaN(${ce.primitive}) || 
Double.isInfinite(${ce.primitive})){
+              ${ev.primitive} = ${ce.primitive};
+            } else {
+              ${ev.primitive} = Math.round(${ce.primitive});
+            }"""
+        } else {
+          s"""
+            if (Double.isNaN(${ce.primitive}) || 
Double.isInfinite(${ce.primitive})){
+              ${ev.primitive} = ${ce.primitive};
+            } else {
+              ${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}).
+                setScale(${_scale}, 
java.math.BigDecimal.ROUND_HALF_UP).doubleValue();
+            }"""
+        }
+    }
+
+    if (scaleV == null) { // if scale is null, no need to eval its child at all
+      s"""
+        boolean ${ev.isNull} = true;
+        ${ctx.javaType(dataType)} ${ev.primitive} = 
${ctx.defaultValue(dataType)};
+      """
+    } else {
+      s"""
+        ${ce.code}
+        boolean ${ev.isNull} = ${ce.isNull};
+        ${ctx.javaType(dataType)} ${ev.primitive} = 
${ctx.defaultValue(dataType)};
+        if (!${ev.isNull}) {
+          $evaluationCode
+        }
+      """
+    }
+  }
+
+  override def prettyName: String = "round"
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f0e12974/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 5958acb..e885a18 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -52,6 +52,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
       s"differing types in '${expr.prettyString}' (int and boolean)")
   }
 
+  def assertErrorWithImplicitCast(expr: Expression, errorMessage: String): 
Unit = {
+    val e = intercept[AnalysisException] {
+      assertSuccess(expr)
+    }
+    assert(e.getMessage.contains(errorMessage))
+  }
+
   test("check types for unary arithmetic") {
     assertError(UnaryMinus('stringField), "operator - accepts numeric type")
     assertError(Abs('stringField), "function abs accepts numeric type")
@@ -171,4 +178,14 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
       CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)),
         "Odd position only allow foldable and not-null StringType expressions")
   }
+
+  test("check types for ROUND") {
+    assertErrorWithImplicitCast(Round(Literal(null), 'booleanField),
+      "data type mismatch: argument 2 is expected to be of type int")
+    assertErrorWithImplicitCast(Round(Literal(null), 'complexField),
+      "data type mismatch: argument 2 is expected to be of type int")
+    assertSuccess(Round(Literal(null), Literal(null)))
+    assertError(Round('booleanField, 'intField),
+      "data type mismatch: argument 1 is expected to be of type numeric")
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f0e12974/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 7ca9e30..52a874a 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import scala.math.BigDecimal.RoundingMode
+
 import com.google.common.math.LongMath
 
 import org.apache.spark.SparkFunSuite
@@ -336,4 +338,46 @@ class MathFunctionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
       null,
       create_row(null))
   }
+
+  test("round") {
+    val domain = -6 to 6
+    val doublePi: Double = math.Pi
+    val shortPi: Short = 31415
+    val intPi: Int = 314159265
+    val longPi: Long = 31415926535897932L
+    val bdPi: BigDecimal = BigDecimal(31415927L, 7)
+
+    val doubleResults: Seq[Double] = Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 
3.1, 3.14, 3.142,
+      3.1416, 3.14159, 3.141593)
+
+    val shortResults: Seq[Short] = Seq[Short](0, 0, 30000, 31000, 31400, 
31420) ++
+      Seq.fill[Short](7)(31415)
+
+    val intResults: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 
314159300,
+      314159270) ++ Seq.fill(7)(314159265)
+
+    val longResults: Seq[Long] = Seq(31415926536000000L, 31415926535900000L,
+      31415926535900000L, 31415926535898000L, 31415926535897900L, 
31415926535897930L) ++
+      Seq.fill(7)(31415926535897932L)
+
+    val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), 
BigDecimal(3.14),
+      BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159),
+      BigDecimal(3.141593), BigDecimal(3.1415927))
+
+    domain.zipWithIndex.foreach { case (scale, i) =>
+      checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow)
+      checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow)
+      checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow)
+      checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow)
+    }
+
+    // round_scale > current_scale would result in a precision increase, which is
+    // not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null
+    (0 to 7).foreach { i =>
+      checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow)
+    }
+    (8 to 10).foreach { scale =>
+      checkEvaluation(Round(bdPi, scale), null, EmptyRow)
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f0e12974/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
index 030bb6d..f0c849d 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala
@@ -24,14 +24,14 @@ import org.scalatest.PrivateMethodTester
 import scala.language.postfixOps
 
 class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
-  test("creating decimals") {
-    /** Check that a Decimal has the given string representation, precision 
and scale */
-    def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): 
Unit = {
-      assert(d.toString === string)
-      assert(d.precision === precision)
-      assert(d.scale === scale)
-    }
+  /** Check that a Decimal has the given string representation, precision and 
scale */
+  private def checkDecimal(d: Decimal, string: String, precision: Int, scale: 
Int): Unit = {
+    assert(d.toString === string)
+    assert(d.precision === precision)
+    assert(d.scale === scale)
+  }
 
+  test("creating decimals") {
     checkDecimal(new Decimal(), "0", 1, 0)
     checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3)
     checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1)
@@ -53,6 +53,15 @@ class DecimalSuite extends SparkFunSuite with 
PrivateMethodTester {
     intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0))
   }
 
+  test("creating decimals with negative scale") {
+    checkDecimal(Decimal(BigDecimal("98765"), 5, -3), "9.9E+4", 5, -3)
+    checkDecimal(Decimal(BigDecimal("314.159"), 6, -2), "3E+2", 6, -2)
+    checkDecimal(Decimal(BigDecimal(1.579e12), 4, -9), "1.579E+12", 4, -9)
+    checkDecimal(Decimal(BigDecimal(1.579e12), 4, -10), "1.58E+12", 4, -10)
+    checkDecimal(Decimal(103050709L, 9, -10), "1.03050709E+18", 9, -10)
+    checkDecimal(Decimal(1e8.toLong, 10, -10), "1.00000000E+18", 10, -10)
+  }
+
   test("double and long values") {
     /** Check that a Decimal converts to the given double and long values */
     def checkValues(d: Decimal, doubleValue: Double, longValue: Long): Unit = {

http://git-wip-us.apache.org/repos/asf/spark/blob/f0e12974/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 0d4e160..5119ee3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1390,6 +1390,38 @@ object functions {
   def rint(columnName: String): Column = rint(Column(columnName))
 
   /**
+   * Returns the value of the column `e` rounded to 0 decimal places.
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def round(e: Column): Column = round(e.expr, 0)
+
+  /**
+   * Returns the value of the given column rounded to 0 decimal places.
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def round(columnName: String): Column = round(Column(columnName), 0)
+
+  /**
+   * Returns the value of `e` rounded to `scale` decimal places.
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale))
+
+  /**
+   * Returns the value of the given column rounded to `scale` decimal places.
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
+  def round(columnName: String, scale: Int): Column = 
round(Column(columnName), scale)
+
+  /**
    * Shift the the given value numBits left. If the given value is a long 
value, this function
    * will return a long value else it will return an integer value.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/f0e12974/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index b30b9f1..087126b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -198,6 +198,21 @@ class MathExpressionsSuite extends QueryTest {
     testOneToOneMathFunction(rint, math.rint)
   }
 
+  test("round") {
+    val df = Seq(5, 55, 555).map(Tuple1(_)).toDF("a")
+    checkAnswer(
+      df.select(round('a), round('a, -1), round('a, -2)),
+      Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600))
+    )
+
+    val pi = 3.1415
+    checkAnswer(
+      ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " +
+        s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"),
+      Seq(Row(0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142))
+    )
+  }
+
   test("exp") {
     testOneToOneMathFunction(exp, math.exp)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/f0e12974/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
 
b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index c884c39..4ada64b 100644
--- 
a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ 
b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -221,9 +221,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with 
BeforeAndAfter {
     "udf_when",
     "udf_case",
 
-    // Needs constant object inspectors
-    "udf_round",
-
     // the table src(key INT, value STRING) is not the same as HIVE unittest. 
In Hive
     // is src(key STRING, value STRING), and in the reflect.q, it failed in
     // Integer.valueOf, which expect the first argument passed as STRING type 
not INT.
@@ -918,8 +915,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with 
BeforeAndAfter {
     "udf_regexp_replace",
     "udf_repeat",
     "udf_rlike",
-    "udf_round",
-    //  "udf_round_3",  TODO: FIX THIS failed due to cast exception
+    // "udf_round",  turn this on after we figure out null vs nan vs infinity
+    "udf_round_3",
     "udf_rpad",
     "udf_rtrim",
     "udf_second",


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to