Repository: spark Updated Branches: refs/heads/branch-2.0 d73ce364e -> 51706f8a4
[SPARK-14541][SQL] Support IFNULL, NULLIF, NVL and NVL2 ## What changes were proposed in this pull request? This patch adds support for a few SQL functions to improve compatibility with other databases: IFNULL, NULLIF, NVL and NVL2. In order to do this, this patch introduced a RuntimeReplaceable expression trait that allows replacing an unevaluable expression in the optimizer before evaluation. Note that the semantics are not completely identical to other databases in esoteric cases. ## How was this patch tested? Added a new test suite SQLCompatibilityFunctionSuite. Closes #12373. Author: Reynold Xin <[email protected]> Closes #13084 from rxin/SPARK-14541. (cherry picked from commit eda2800d44843b6478e22d2c99bca4af7e9c9613) Signed-off-by: Yin Huai <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/51706f8a Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/51706f8a Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/51706f8a Branch: refs/heads/branch-2.0 Commit: 51706f8a4dd94e235cf4e2c0627bc3788fec8251 Parents: d73ce36 Author: Reynold Xin <[email protected]> Authored: Thu May 12 22:18:39 2016 -0700 Committer: Yin Huai <[email protected]> Committed: Thu May 12 22:19:03 2016 -0700 ---------------------------------------------------------------------- .../catalyst/analysis/FunctionRegistry.scala | 5 +- .../catalyst/analysis/HiveTypeCoercion.scala | 2 + .../sql/catalyst/expressions/Expression.scala | 27 +++++++ .../catalyst/expressions/nullExpressions.scala | 78 +++++++++++++++++++- .../sql/catalyst/optimizer/Optimizer.scala | 12 +++ .../spark/sql/DataFrameFunctionsSuite.scala | 6 -- .../sql/SQLCompatibilityFunctionSuite.scala | 72 ++++++++++++++++++ .../sql/catalyst/ExpressionToSQLSuite.scala | 1 - 8 files changed, 194 insertions(+), 9 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/51706f8a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index c459fe5..eca837c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -165,13 +165,16 @@ object FunctionRegistry { expression[Greatest]("greatest"), expression[If]("if"), expression[IsNaN]("isnan"), + expression[IfNull]("ifnull"), expression[IsNull]("isnull"), expression[IsNotNull]("isnotnull"), expression[Least]("least"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), expression[NaNvl]("nanvl"), - expression[Coalesce]("nvl"), + expression[NullIf]("nullif"), + expression[Nvl]("nvl"), + expression[Nvl2]("nvl2"), expression[Rand]("rand"), expression[Randn]("randn"), expression[CreateStruct]("struct"), http://git-wip-us.apache.org/repos/asf/spark/blob/51706f8a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 8319ec0..537dda6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -521,6 +521,8 @@ object HiveTypeCoercion { NaNvl(l, Cast(r, DoubleType)) case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType => NaNvl(Cast(l, DoubleType), r) + + case e: RuntimeReplaceable => e.replaceForTypeCoercion() } } http://git-wip-us.apache.org/repos/asf/spark/blob/51706f8a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index c26faee..fab1634 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -222,6 +222,33 @@ trait Unevaluable extends Expression { /** + * An expression that gets replaced at runtime (currently by the optimizer) into a different + * expression for evaluation. This is mainly used to provide compatibility with other databases. + * For example, we use this to support "nvl" by replacing it with "coalesce". + */ +trait RuntimeReplaceable extends Unevaluable { + /** + * Method for concrete implementations to override that specifies how to construct the expression + * that should replace the current one. + */ + def replaceForEvaluation(): Expression + + /** + * Method for concrete implementations to override that specifies how to coerce the input types. + */ + def replaceForTypeCoercion(): Expression + + /** The expression that should be used during evaluation. */ + lazy val replaced: Expression = replaceForEvaluation() + + override def nullable: Boolean = replaced.nullable + override def foldable: Boolean = replaced.foldable + override def dataType: DataType = replaced.dataType + override def checkInputDataTypes(): TypeCheckResult = replaced.checkInputDataTypes() +} + + +/** * Expressions that don't have SQL representation should extend this trait. Examples are * `ScalaUDF`, `ScalaUDAF`, and object expressions like `MapObjects` and `Invoke`. */ http://git-wip-us.apache.org/repos/asf/spark/blob/51706f8a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 421200e..641c81b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{HiveTypeCoercion, TypeCheckResult} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -88,6 +88,82 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } +@ExpressionDescription(usage = "_FUNC_(a,b) - Returns b if a is null, or a otherwise.") +case class IfNull(left: Expression, right: Expression) extends RuntimeReplaceable { + override def children: Seq[Expression] = Seq(left, right) + + override def replaceForEvaluation(): Expression = Coalesce(Seq(left, right)) + + override def replaceForTypeCoercion(): Expression = { + if (left.dataType != right.dataType) { + HiveTypeCoercion.findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { dtype => + copy(left = Cast(left, dtype), right = Cast(right, dtype)) + }.getOrElse(this) + } else { + this + } + } +} + + +@ExpressionDescription(usage = "_FUNC_(a,b) - Returns null if a equals to b, or a otherwise.") +case class NullIf(left: Expression, right: Expression) extends RuntimeReplaceable { + override def children: Seq[Expression] = Seq(left, right) + + override def replaceForEvaluation(): Expression = { + If(EqualTo(left, right), Literal.create(null, left.dataType), left) + } + + override def replaceForTypeCoercion(): Expression = { + if (left.dataType != right.dataType) { + HiveTypeCoercion.findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { dtype => + copy(left = Cast(left, dtype), right = Cast(right, dtype)) + }.getOrElse(this) + } else { + this + } + } +} + + +@ExpressionDescription(usage = "_FUNC_(a,b) - Returns b if a is null, or a otherwise.") +case class Nvl(left: Expression, right: Expression) extends RuntimeReplaceable { + override def children: Seq[Expression] = Seq(left, right) + + override def replaceForEvaluation(): Expression = Coalesce(Seq(left, right)) + + override def replaceForTypeCoercion(): Expression = { + if (left.dataType != right.dataType) { + HiveTypeCoercion.findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { dtype => + copy(left = Cast(left, dtype), right = Cast(right, dtype)) + }.getOrElse(this) + } else { + this + } + } +} + + +@ExpressionDescription(usage = "_FUNC_(a,b,c) - Returns b if a is not null, or c otherwise.") +case class Nvl2(expr1: Expression, expr2: Expression, expr3: Expression) + extends RuntimeReplaceable { + + override def replaceForEvaluation(): Expression = If(IsNotNull(expr1), expr2, expr3) + + override def children: Seq[Expression] = Seq(expr1, expr2, expr3) + + override def replaceForTypeCoercion(): Expression = { + if (expr2.dataType != expr3.dataType) { + HiveTypeCoercion.findTightestCommonTypeOfTwo(expr2.dataType, expr3.dataType).map { dtype => + copy(expr2 = Cast(expr2, dtype), expr3 = Cast(expr3, dtype)) + }.getOrElse(this) + } else { + this + } + } +} + + /** * Evaluates to `true` iff it's NaN. */ http://git-wip-us.apache.org/repos/asf/spark/blob/51706f8a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 928ba21..af7532e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -49,6 +49,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) // we do not eliminate subqueries or compute current time in the analyzer. Batch("Finish Analysis", Once, EliminateSubqueryAliases, + ReplaceExpressions, ComputeCurrentTime, GetCurrentDatabase(sessionCatalog), DistinctAggregationRewriter) :: @@ -1512,6 +1513,17 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] { } /** + * Finds all [[RuntimeReplaceable]] expressions and replace them with the expressions that can + * be evaluated. This is mainly used to provide compatibility with other databases. + * For example, we use this to support "nvl" by replacing it with "coalesce". + */ +object ReplaceExpressions extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case e: RuntimeReplaceable => e.replaced + } +} + +/** * Computes the current date and time to make sure we return the same result in a single query. */ object ComputeCurrentTime extends Rule[LogicalPlan] { http://git-wip-us.apache.org/repos/asf/spark/blob/51706f8a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 746e25a..73d7765 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -152,12 +152,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row("one", "not_one")) } - test("nvl function") { - checkAnswer( - sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"), - Row("x", "y", null)) - } - test("misc md5 function") { val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") checkAnswer( http://git-wip-us.apache.org/repos/asf/spark/blob/51706f8a/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala new file mode 100644 index 0000000..1e32395 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test.SharedSQLContext + +/** + * A test suite for functions added for compatibility with other databases such as Oracle, MSSQL. + * These functions are typically implemented using the trait + * [[org.apache.spark.sql.catalyst.expressions.RuntimeReplaceable]]. + */ +class SQLCompatibilityFunctionSuite extends QueryTest with SharedSQLContext { + + test("ifnull") { + checkAnswer( + sql("SELECT ifnull(null, 'x'), ifnull('y', 'x'), ifnull(null, null)"), + Row("x", "y", null)) + + // Type coercion + checkAnswer( + sql("SELECT ifnull(1, 2.1d), ifnull(null, 2.1d)"), + Row(1.0, 2.1)) + } + + test("nullif") { + checkAnswer( + sql("SELECT nullif('x', 'x'), nullif('x', 'y')"), + Row(null, "x")) + + // Type coercion + checkAnswer( + sql("SELECT nullif(1, 2.1d), nullif(1, 1.0d)"), + Row(1.0, null)) + } + + test("nvl") { + checkAnswer( + sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"), + Row("x", "y", null)) + + // Type coercion + checkAnswer( + sql("SELECT nvl(1, 2.1d), nvl(null, 2.1d)"), + Row(1.0, 2.1)) + } + + test("nvl2") { + checkAnswer( + sql("SELECT nvl2(null, 'x', 'y'), nvl2('n', 'x', 'y'), nvl2(null, null, null)"), + Row("y", "x", null)) + + // Type coercion + checkAnswer( + sql("SELECT nvl2(null, 1, 2.1d), nvl2('n', 1, 2.1d)"), + Row(2.1, 1.0)) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/51706f8a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala index 72736ee..b4eb50e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala @@ -102,7 +102,6 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkSqlGeneration("SELECT map(1, 'a', 2, 'b')") checkSqlGeneration("SELECT named_struct('c1',1,'c2',2,'c3',3)") checkSqlGeneration("SELECT nanvl(a, 5), nanvl(b, 10), nanvl(d, c) from t2") - checkSqlGeneration("SELECT nvl(null, 1, 2)") checkSqlGeneration("SELECT rand(1)") checkSqlGeneration("SELECT randn(3)") checkSqlGeneration("SELECT struct(1,2,3)") --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
