Repository: spark Updated Branches: refs/heads/master ddc5baf17 -> 31641128b
[SPARK-8363][SQL] Move sqrt to math and extend UnaryMathExpression JIRA: https://issues.apache.org/jira/browse/SPARK-8363 Author: Liang-Chi Hsieh <[email protected]> Closes #6823 from viirya/move_sqrt and squashes the following commits: 8977e11 [Liang-Chi Hsieh] Remove unnecessary old tests. d23e79e [Liang-Chi Hsieh] Explicitly indicate sqrt value sequence. 699f48b [Liang-Chi Hsieh] Use correct @since tag. 8dff6d1 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into move_sqrt bc2ed77 [Liang-Chi Hsieh] Remove/move arithmetic expression test and expression type checking test. Remove unnecessary Sqrt type rule. d38492f [Liang-Chi Hsieh] Now sqrt accepts boolean because type casting is handled by HiveTypeCoercion. 297cc90 [Liang-Chi Hsieh] Sqrt only accepts double input. ef4a21a [Liang-Chi Hsieh] Move sqrt to math. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/31641128 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/31641128 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/31641128 Branch: refs/heads/master Commit: 31641128b34d6f2aa7cb67324c24dd8b3ed84689 Parents: ddc5baf Author: Liang-Chi Hsieh <[email protected]> Authored: Thu Jun 18 13:00:31 2015 -0700 Committer: Reynold Xin <[email protected]> Committed: Thu Jun 18 13:00:31 2015 -0700 ---------------------------------------------------------------------- .../catalyst/analysis/HiveTypeCoercion.scala | 1 - .../sql/catalyst/expressions/arithmetic.scala | 32 -------------------- .../spark/sql/catalyst/expressions/math.scala | 2 ++ .../expressions/ArithmeticExpressionSuite.scala | 15 --------- .../ExpressionTypeCheckingSuite.scala | 2 -- .../expressions/MathFunctionsSuite.scala | 10 ++++++ .../scala/org/apache/spark/sql/functions.scala | 10 +++++- .../apache/spark/sql/MathExpressionsSuite.scala | 10 ++++++ 8 files changed, 31 insertions(+), 51 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/31641128/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 189451d..8012b22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -307,7 +307,6 @@ trait HiveTypeCoercion { case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) - case Sqrt(e @ StringType()) => Sqrt(Cast(e, DoubleType)) } } http://git-wip-us.apache.org/repos/asf/spark/blob/31641128/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 167e460..ace8427 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -67,38 +67,6 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic { protected override def evalInternal(evalE: Any) = evalE } -case class Sqrt(child: Expression) extends UnaryArithmetic { - override def dataType: DataType = DoubleType - override def nullable: Boolean = true - override def toString: String = s"SQRT($child)" - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function sqrt") - - private lazy val numeric = TypeUtils.getNumeric(child.dataType) - - protected override def evalInternal(evalE: Any) = { - val value = numeric.toDouble(evalE) - if (value < 0) null - else math.sqrt(value) - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val eval = child.gen(ctx) - eval.code + s""" - boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - if (${eval.primitive} < 0.0) { - ${ev.isNull} = true; - } else { - ${ev.primitive} = java.lang.Math.sqrt(${eval.primitive}); - } - } - """ - } -} - /** * A function that get the absolute value of the numeric value. */ http://git-wip-us.apache.org/repos/asf/spark/blob/31641128/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 67cb0b5..3b83c6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -193,6 +193,8 @@ case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN") case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH") +case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT") + case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN") case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH") http://git-wip-us.apache.org/repos/asf/spark/blob/31641128/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 3f48432..4bbbbe6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -142,19 +142,4 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MinOf(Literal.create(null, IntegerType), 1), 1) checkEvaluation(MinOf(1, Literal.create(null, IntegerType)), 1) } - - test("SQRT") { - val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24)) - val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble)) - val rowSequence = inputSequence.map(l => create_row(l.toDouble)) - val d = 'a.double.at(0) - - for ((row, expected) <- rowSequence zip expectedResults) { - checkEvaluation(Sqrt(d), expected, row) - } - - checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) - checkEvaluation(Sqrt(-1), null, EmptyRow) - checkEvaluation(Sqrt(-1.5), null, EmptyRow) - } } http://git-wip-us.apache.org/repos/asf/spark/blob/31641128/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala index dcb3635..49b1119 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala @@ -54,8 +54,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { test("check types for unary arithmetic") { assertError(UnaryMinus('stringField), "operator - accepts numeric type") - assertSuccess(Sqrt('stringField)) // We will cast String to Double for sqrt - assertError(Sqrt('booleanField), "function sqrt accepts numeric type") assertError(Abs('stringField), "function abs accepts numeric type") assertError(BitwiseNot('stringField), "operator ~ accepts integral type") } http://git-wip-us.apache.org/repos/asf/spark/blob/31641128/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 0050ad3..21e9b92 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.DoubleType @@ -191,6 +192,15 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { testUnary(Log2, f, (-5 to -1).map(_ * 1.0), expectNull = true) } + test("sqrt") { + testUnary(Sqrt, math.sqrt, (0 to 20).map(_ * 0.1)) + testUnary(Sqrt, math.sqrt, (-5 to -1).map(_ * 1.0), expectNull = true) + + checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) + checkEvaluation(Sqrt(Literal(-1.0)), null, EmptyRow) + checkEvaluation(Sqrt(Literal(-1.5)), null, EmptyRow) + } + test("pow") { testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true) http://git-wip-us.apache.org/repos/asf/spark/blob/31641128/sql/core/src/main/scala/org/apache/spark/sql/functions.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index dff0932..d8a91be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -707,12 +707,20 @@ object functions { /** * Computes the square root of the specified float value. * - * @group normal_funcs + * @group math_funcs * @since 1.3.0 */ def sqrt(e: Column): Column = Sqrt(e.expr) /** + * Computes the square root of the specified float value. + * + * @group math_funcs + * @since 1.5.0 + */ + def sqrt(colName: String): Column = sqrt(Column(colName)) + + /** * Creates a new struct column. The input column must be a column in a [[DataFrame]], or * a derived column expression that is named (i.e. aliased). * http://git-wip-us.apache.org/repos/asf/spark/blob/31641128/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 7c9c121..2768d7d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -270,6 +270,16 @@ class MathExpressionsSuite extends QueryTest { checkAnswer(ctx.sql("SELECT LOG2(8), LOG2(null)"), Row(3, null)) } + test("sqrt") { + val df = Seq((1, 4)).toDF("a", "b") + checkAnswer( + df.select(sqrt("a"), sqrt("b")), + Row(1.0, 2.0)) + + checkAnswer(ctx.sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null)) + checkAnswer(df.selectExpr("sqrt(a)", "sqrt(b)", "sqrt(null)"), Row(1.0, 2.0, null)) + } + test("negative") { checkAnswer( ctx.sql("SELECT negative(1), negative(0), negative(-1)"), --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
