Repository: spark Updated Branches: refs/heads/master 86cc90744 -> a09d5ba88
[SPARK-26107][SQL] Extend ReplaceNullWithFalseInPredicate to support higher-order functions: ArrayExists, ArrayFilter, MapFilter

## What changes were proposed in this pull request?

Extend the `ReplaceNullWithFalse` optimizer rule introduced in SPARK-25860 (https://github.com/apache/spark/pull/22857) so that it also optimizes predicates in the higher-order functions `ArrayExists`, `ArrayFilter`, and `MapFilter`.

Also rename the rule to `ReplaceNullWithFalseInPredicate` to better reflect its intent.

Example:
```sql
select filter(a, e -> if(e is null, null, true)) as b from (
  select array(null, 1, null, 3) as a)
```

The optimized logical plan:

**Before**:
```
== Optimized Logical Plan ==
Project [filter([null,1,null,3], lambdafunction(if (isnull(lambda e#13)) null else true, lambda e#13, false)) AS b#9]
+- OneRowRelation
```

**After**:
```
== Optimized Logical Plan ==
Project [filter([null,1,null,3], lambdafunction(if (isnull(lambda e#13)) false else true, lambda e#13, false)) AS b#9]
+- OneRowRelation
```

## How was this patch tested?

Added new unit test cases to `ReplaceNullWithFalseInPredicateSuite` (renamed from `ReplaceNullWithFalseSuite`).

Closes #23079 from rednaxelafx/catalyst-master.
Authored-by: Kris Mok <[email protected]> Signed-off-by: Wenchen Fan <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a09d5ba8 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a09d5ba8 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a09d5ba8 Branch: refs/heads/master Commit: a09d5ba88680d07121ce94a4e68c3f42fc635f4f Parents: 86cc907 Author: Kris Mok <[email protected]> Authored: Tue Nov 20 09:27:46 2018 +0800 Committer: Wenchen Fan <[email protected]> Committed: Tue Nov 20 09:27:46 2018 +0800 ---------------------------------------------------------------------- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../sql/catalyst/optimizer/expressions.scala | 11 +- .../ReplaceNullWithFalseInPredicateSuite.scala | 363 +++++++++++++++++++ .../optimizer/ReplaceNullWithFalseSuite.scala | 323 ----------------- .../sql/ReplaceNullWithFalseEndToEndSuite.scala | 71 ---- ...eNullWithFalseInPredicateEndToEndSuite.scala | 112 ++++++ 6 files changed, 486 insertions(+), 396 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/a09d5ba8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a330a84..8d251ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -84,7 +84,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) SimplifyConditionals, RemoveDispensableExpressions, SimplifyBinaryComparison, - ReplaceNullWithFalse, + ReplaceNullWithFalseInPredicate, PruneFilters, 
EliminateSorts, SimplifyCasts, http://git-wip-us.apache.org/repos/asf/spark/blob/a09d5ba8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 2b29b49..354efd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -755,7 +755,7 @@ object CombineConcats extends Rule[LogicalPlan] { * * As a result, many unnecessary computations can be removed in the query optimization phase. */ -object ReplaceNullWithFalse extends Rule[LogicalPlan] { +object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond)) @@ -767,6 +767,15 @@ object ReplaceNullWithFalse extends Rule[LogicalPlan] { replaceNullWithFalse(cond) -> value } cw.copy(branches = newBranches) + case af @ ArrayFilter(_, lf @ LambdaFunction(func, _, _)) => + val newLambda = lf.copy(function = replaceNullWithFalse(func)) + af.copy(function = newLambda) + case ae @ ArrayExists(_, lf @ LambdaFunction(func, _, _)) => + val newLambda = lf.copy(function = replaceNullWithFalse(func)) + ae.copy(function = newLambda) + case mf @ MapFilter(_, lf @ LambdaFunction(func, _, _)) => + val newLambda = lf.copy(function = replaceNullWithFalse(func)) + mf.copy(function = newLambda) } } http://git-wip-us.apache.org/repos/asf/spark/blob/a09d5ba8/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala ---------------------------------------------------------------------- diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala new file mode 100644 index 0000000..3a9e6ca --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -0,0 +1,363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or} +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.{BooleanType, IntegerType} + +class ReplaceNullWithFalseInPredicateSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Replace null literals", FixedPoint(10), + NullPropagation, + ConstantFolding, + BooleanSimplification, + SimplifyConditionals, + ReplaceNullWithFalseInPredicate) :: Nil + } + + private val testRelation = + LocalRelation('i.int, 'b.boolean, 'a.array(IntegerType), 'm.map(IntegerType, IntegerType)) + private val anotherTestRelation = LocalRelation('d.int) + + test("replace null inside filter and join conditions") { + testFilter(originalCond = Literal(null), expectedCond = FalseLiteral) + testJoin(originalCond = Literal(null), expectedCond = FalseLiteral) + } + + test("replace null in branches of If") { + val originalCond = If( + UnresolvedAttribute("i") > Literal(10), + FalseLiteral, + Literal(null, BooleanType)) + testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("replace nulls in nested expressions in branches of If") { + val originalCond = If( + UnresolvedAttribute("i") > Literal(10), + TrueLiteral && Literal(null, BooleanType), + UnresolvedAttribute("b") && Literal(null, BooleanType)) + 
testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("replace null in elseValue of CaseWhen") { + val branches = Seq( + (UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral, + (UnresolvedAttribute("i") > Literal(40)) -> FalseLiteral) + val originalCond = CaseWhen(branches, Literal(null, BooleanType)) + val expectedCond = CaseWhen(branches, FalseLiteral) + testFilter(originalCond, expectedCond) + testJoin(originalCond, expectedCond) + } + + test("replace null in branch values of CaseWhen") { + val branches = Seq( + (UnresolvedAttribute("i") < Literal(10)) -> Literal(null, BooleanType), + (UnresolvedAttribute("i") > Literal(40)) -> FalseLiteral) + val originalCond = CaseWhen(branches, Literal(null)) + testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("replace null in branches of If inside CaseWhen") { + val originalBranches = Seq( + (UnresolvedAttribute("i") < Literal(10)) -> + If(UnresolvedAttribute("i") < Literal(20), Literal(null, BooleanType), FalseLiteral), + (UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral) + val originalCond = CaseWhen(originalBranches) + + val expectedBranches = Seq( + (UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral, + (UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral) + val expectedCond = CaseWhen(expectedBranches) + + testFilter(originalCond, expectedCond) + testJoin(originalCond, expectedCond) + } + + test("replace null in complex CaseWhen expressions") { + val originalBranches = Seq( + (UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral, + (Literal(6) <= Literal(1)) -> FalseLiteral, + (Literal(4) === Literal(5)) -> FalseLiteral, + (UnresolvedAttribute("i") > Literal(10)) -> Literal(null, BooleanType), + (Literal(4) === Literal(4)) -> TrueLiteral) + val originalCond = CaseWhen(originalBranches) + + val expectedBranches = Seq( + (UnresolvedAttribute("i") < Literal(10)) -> 
TrueLiteral, + (UnresolvedAttribute("i") > Literal(10)) -> FalseLiteral, + TrueLiteral -> TrueLiteral) + val expectedCond = CaseWhen(expectedBranches) + + testFilter(originalCond, expectedCond) + testJoin(originalCond, expectedCond) + } + + test("replace null in Or") { + val originalCond = Or(UnresolvedAttribute("b"), Literal(null)) + val expectedCond = UnresolvedAttribute("b") + testFilter(originalCond, expectedCond) + testJoin(originalCond, expectedCond) + } + + test("replace null in And") { + val originalCond = And(UnresolvedAttribute("b"), Literal(null)) + testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("replace nulls in nested And/Or expressions") { + val originalCond = And( + And(UnresolvedAttribute("b"), Literal(null)), + Or(Literal(null), And(Literal(null), And(UnresolvedAttribute("b"), Literal(null))))) + testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("replace null in And inside branches of If") { + val originalCond = If( + UnresolvedAttribute("i") > Literal(10), + FalseLiteral, + And(UnresolvedAttribute("b"), Literal(null, BooleanType))) + testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("replace null in branches of If inside And") { + val originalCond = And( + UnresolvedAttribute("b"), + If( + UnresolvedAttribute("i") > Literal(10), + Literal(null), + And(FalseLiteral, UnresolvedAttribute("b")))) + testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("replace null in branches of If inside another If") { + val originalCond = If( + If(UnresolvedAttribute("b"), Literal(null), FalseLiteral), + TrueLiteral, + Literal(null)) + testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("replace null in CaseWhen inside 
another CaseWhen") { + val nestedCaseWhen = CaseWhen(Seq(UnresolvedAttribute("b") -> FalseLiteral), Literal(null)) + val originalCond = CaseWhen(Seq(nestedCaseWhen -> TrueLiteral), Literal(null)) + testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("inability to replace null in non-boolean branches of If") { + val condition = If( + UnresolvedAttribute("i") > Literal(10), + Literal(5) > If( + UnresolvedAttribute("i") === Literal(15), + Literal(null, IntegerType), + Literal(3)), + FalseLiteral) + testFilter(originalCond = condition, expectedCond = condition) + testJoin(originalCond = condition, expectedCond = condition) + } + + test("inability to replace null in non-boolean values of CaseWhen") { + val nestedCaseWhen = CaseWhen( + Seq((UnresolvedAttribute("i") > Literal(20)) -> Literal(2)), + Literal(null, IntegerType)) + val branchValue = If( + Literal(2) === nestedCaseWhen, + TrueLiteral, + FalseLiteral) + val branches = Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue) + val condition = CaseWhen(branches) + testFilter(originalCond = condition, expectedCond = condition) + testJoin(originalCond = condition, expectedCond = condition) + } + + test("inability to replace null in non-boolean branches of If inside another If") { + val condition = If( + Literal(5) > If( + UnresolvedAttribute("i") === Literal(15), + Literal(null, IntegerType), + Literal(3)), + TrueLiteral, + FalseLiteral) + testFilter(originalCond = condition, expectedCond = condition) + testJoin(originalCond = condition, expectedCond = condition) + } + + test("replace null in If used as a join condition") { + // this test is only for joins as the condition involves columns from different relations + val originalCond = If( + UnresolvedAttribute("d") > UnresolvedAttribute("i"), + Literal(null), + FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("replace null in CaseWhen used as a join condition") 
{ + // this test is only for joins as the condition involves columns from different relations + val originalBranches = Seq( + (UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> Literal(null), + (UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral) + + val expectedBranches = Seq( + (UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> FalseLiteral, + (UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral) + + testJoin( + originalCond = CaseWhen(originalBranches, FalseLiteral), + expectedCond = CaseWhen(expectedBranches, FalseLiteral)) + } + + test("inability to replace null in CaseWhen inside EqualTo used as a join condition") { + // this test is only for joins as the condition involves columns from different relations + val branches = Seq( + (UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> Literal(null, BooleanType), + (UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral) + val condition = UnresolvedAttribute("b") === CaseWhen(branches, FalseLiteral) + testJoin(originalCond = condition, expectedCond = condition) + } + + test("replace null in predicates of If") { + val predicate = And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null)) + testProjection( + originalExpr = If(predicate, Literal(5), Literal(1)).as("out"), + expectedExpr = Literal(1).as("out")) + } + + test("replace null in predicates of If inside another If") { + val predicate = If( + And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null)), + TrueLiteral, + FalseLiteral) + testProjection( + originalExpr = If(predicate, Literal(5), Literal(1)).as("out"), + expectedExpr = Literal(1).as("out")) + } + + test("inability to replace null in non-boolean expressions inside If predicates") { + val predicate = GreaterThan( + UnresolvedAttribute("i"), + If(UnresolvedAttribute("b"), Literal(null, IntegerType), Literal(4))) + val column = If(predicate, Literal(5), Literal(1)).as("out") + testProjection(originalExpr = 
column, expectedExpr = column) + } + + test("replace null in conditions of CaseWhen") { + val branches = Seq( + And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null)) -> Literal(5)) + testProjection( + originalExpr = CaseWhen(branches, Literal(2)).as("out"), + expectedExpr = Literal(2).as("out")) + } + + test("replace null in conditions of CaseWhen inside another CaseWhen") { + val nestedCaseWhen = CaseWhen( + Seq(And(UnresolvedAttribute("b"), Literal(null)) -> Literal(5)), + Literal(2)) + val branches = Seq(GreaterThan(Literal(3), nestedCaseWhen) -> Literal(1)) + testProjection( + originalExpr = CaseWhen(branches).as("out"), + expectedExpr = Literal(1).as("out")) + } + + test("inability to replace null in non-boolean exprs inside CaseWhen conditions") { + val condition = GreaterThan( + UnresolvedAttribute("i"), + If(UnresolvedAttribute("b"), Literal(null, IntegerType), Literal(4))) + val column = CaseWhen(Seq(condition -> Literal(5)), Literal(2)).as("out") + testProjection(originalExpr = column, expectedExpr = column) + } + + test("replace nulls in lambda function of ArrayFilter") { + testHigherOrderFunc('a, ArrayFilter, Seq('e)) + } + + test("replace nulls in lambda function of ArrayExists") { + testHigherOrderFunc('a, ArrayExists, Seq('e)) + } + + test("replace nulls in lambda function of MapFilter") { + testHigherOrderFunc('m, MapFilter, Seq('k, 'v)) + } + + test("inability to replace nulls in arbitrary higher-order function") { + val lambdaFunc = LambdaFunction( + function = If('e > 0, Literal(null, BooleanType), TrueLiteral), + arguments = Seq[NamedExpression]('e)) + val column = ArrayTransform('a, lambdaFunc) + testProjection(originalExpr = column, expectedExpr = column) + } + + private def testFilter(originalCond: Expression, expectedCond: Expression): Unit = { + test((rel, exp) => rel.where(exp), originalCond, expectedCond) + } + + private def testJoin(originalCond: Expression, expectedCond: Expression): Unit = { + test((rel, exp) => 
rel.join(anotherTestRelation, Inner, Some(exp)), originalCond, expectedCond) + } + + private def testProjection(originalExpr: Expression, expectedExpr: Expression): Unit = { + test((rel, exp) => rel.select(exp), originalExpr, expectedExpr) + } + + private def testHigherOrderFunc( + argument: Expression, + createExpr: (Expression, Expression) => Expression, + lambdaArgs: Seq[NamedExpression]): Unit = { + val condArg = lambdaArgs.last + // the lambda body is: if(arg > 0, null, true) + val cond = GreaterThan(condArg, Literal(0)) + val lambda1 = LambdaFunction( + function = If(cond, Literal(null, BooleanType), TrueLiteral), + arguments = lambdaArgs) + // the optimized lambda body is: if(arg > 0, false, true) + val lambda2 = LambdaFunction( + function = If(cond, FalseLiteral, TrueLiteral), + arguments = lambdaArgs) + testProjection( + originalExpr = createExpr(argument, lambda1) as 'x, + expectedExpr = createExpr(argument, lambda2) as 'x) + } + + private def test( + func: (LogicalPlan, Expression) => LogicalPlan, + originalExpr: Expression, + expectedExpr: Expression): Unit = { + + val originalPlan = func(testRelation, originalExpr).analyze + val optimizedPlan = Optimize.execute(originalPlan) + val expectedPlan = func(testRelation, expectedExpr).analyze + comparePlans(optimizedPlan, expectedPlan) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/a09d5ba8/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala deleted file mode 100644 index c6b5d0e..0000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala +++ /dev/null @@ -1,323 +0,0 @@ -/* - * Licensed to the Apache Software 
Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{And, CaseWhen, Expression, GreaterThan, If, Literal, Or} -import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} -import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.types.{BooleanType, IntegerType} - -class ReplaceNullWithFalseSuite extends PlanTest { - - object Optimize extends RuleExecutor[LogicalPlan] { - val batches = - Batch("Replace null literals", FixedPoint(10), - NullPropagation, - ConstantFolding, - BooleanSimplification, - SimplifyConditionals, - ReplaceNullWithFalse) :: Nil - } - - private val testRelation = LocalRelation('i.int, 'b.boolean) - private val anotherTestRelation = LocalRelation('d.int) - - test("replace null inside filter and join conditions") { - testFilter(originalCond = Literal(null), expectedCond = 
FalseLiteral) - testJoin(originalCond = Literal(null), expectedCond = FalseLiteral) - } - - test("replace null in branches of If") { - val originalCond = If( - UnresolvedAttribute("i") > Literal(10), - FalseLiteral, - Literal(null, BooleanType)) - testFilter(originalCond, expectedCond = FalseLiteral) - testJoin(originalCond, expectedCond = FalseLiteral) - } - - test("replace nulls in nested expressions in branches of If") { - val originalCond = If( - UnresolvedAttribute("i") > Literal(10), - TrueLiteral && Literal(null, BooleanType), - UnresolvedAttribute("b") && Literal(null, BooleanType)) - testFilter(originalCond, expectedCond = FalseLiteral) - testJoin(originalCond, expectedCond = FalseLiteral) - } - - test("replace null in elseValue of CaseWhen") { - val branches = Seq( - (UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral, - (UnresolvedAttribute("i") > Literal(40)) -> FalseLiteral) - val originalCond = CaseWhen(branches, Literal(null, BooleanType)) - val expectedCond = CaseWhen(branches, FalseLiteral) - testFilter(originalCond, expectedCond) - testJoin(originalCond, expectedCond) - } - - test("replace null in branch values of CaseWhen") { - val branches = Seq( - (UnresolvedAttribute("i") < Literal(10)) -> Literal(null, BooleanType), - (UnresolvedAttribute("i") > Literal(40)) -> FalseLiteral) - val originalCond = CaseWhen(branches, Literal(null)) - testFilter(originalCond, expectedCond = FalseLiteral) - testJoin(originalCond, expectedCond = FalseLiteral) - } - - test("replace null in branches of If inside CaseWhen") { - val originalBranches = Seq( - (UnresolvedAttribute("i") < Literal(10)) -> - If(UnresolvedAttribute("i") < Literal(20), Literal(null, BooleanType), FalseLiteral), - (UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral) - val originalCond = CaseWhen(originalBranches) - - val expectedBranches = Seq( - (UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral, - (UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral) - val expectedCond = 
CaseWhen(expectedBranches) - - testFilter(originalCond, expectedCond) - testJoin(originalCond, expectedCond) - } - - test("replace null in complex CaseWhen expressions") { - val originalBranches = Seq( - (UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral, - (Literal(6) <= Literal(1)) -> FalseLiteral, - (Literal(4) === Literal(5)) -> FalseLiteral, - (UnresolvedAttribute("i") > Literal(10)) -> Literal(null, BooleanType), - (Literal(4) === Literal(4)) -> TrueLiteral) - val originalCond = CaseWhen(originalBranches) - - val expectedBranches = Seq( - (UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral, - (UnresolvedAttribute("i") > Literal(10)) -> FalseLiteral, - TrueLiteral -> TrueLiteral) - val expectedCond = CaseWhen(expectedBranches) - - testFilter(originalCond, expectedCond) - testJoin(originalCond, expectedCond) - } - - test("replace null in Or") { - val originalCond = Or(UnresolvedAttribute("b"), Literal(null)) - val expectedCond = UnresolvedAttribute("b") - testFilter(originalCond, expectedCond) - testJoin(originalCond, expectedCond) - } - - test("replace null in And") { - val originalCond = And(UnresolvedAttribute("b"), Literal(null)) - testFilter(originalCond, expectedCond = FalseLiteral) - testJoin(originalCond, expectedCond = FalseLiteral) - } - - test("replace nulls in nested And/Or expressions") { - val originalCond = And( - And(UnresolvedAttribute("b"), Literal(null)), - Or(Literal(null), And(Literal(null), And(UnresolvedAttribute("b"), Literal(null))))) - testFilter(originalCond, expectedCond = FalseLiteral) - testJoin(originalCond, expectedCond = FalseLiteral) - } - - test("replace null in And inside branches of If") { - val originalCond = If( - UnresolvedAttribute("i") > Literal(10), - FalseLiteral, - And(UnresolvedAttribute("b"), Literal(null, BooleanType))) - testFilter(originalCond, expectedCond = FalseLiteral) - testJoin(originalCond, expectedCond = FalseLiteral) - } - - test("replace null in branches of If inside And") { - val originalCond 
= And( - UnresolvedAttribute("b"), - If( - UnresolvedAttribute("i") > Literal(10), - Literal(null), - And(FalseLiteral, UnresolvedAttribute("b")))) - testFilter(originalCond, expectedCond = FalseLiteral) - testJoin(originalCond, expectedCond = FalseLiteral) - } - - test("replace null in branches of If inside another If") { - val originalCond = If( - If(UnresolvedAttribute("b"), Literal(null), FalseLiteral), - TrueLiteral, - Literal(null)) - testFilter(originalCond, expectedCond = FalseLiteral) - testJoin(originalCond, expectedCond = FalseLiteral) - } - - test("replace null in CaseWhen inside another CaseWhen") { - val nestedCaseWhen = CaseWhen(Seq(UnresolvedAttribute("b") -> FalseLiteral), Literal(null)) - val originalCond = CaseWhen(Seq(nestedCaseWhen -> TrueLiteral), Literal(null)) - testFilter(originalCond, expectedCond = FalseLiteral) - testJoin(originalCond, expectedCond = FalseLiteral) - } - - test("inability to replace null in non-boolean branches of If") { - val condition = If( - UnresolvedAttribute("i") > Literal(10), - Literal(5) > If( - UnresolvedAttribute("i") === Literal(15), - Literal(null, IntegerType), - Literal(3)), - FalseLiteral) - testFilter(originalCond = condition, expectedCond = condition) - testJoin(originalCond = condition, expectedCond = condition) - } - - test("inability to replace null in non-boolean values of CaseWhen") { - val nestedCaseWhen = CaseWhen( - Seq((UnresolvedAttribute("i") > Literal(20)) -> Literal(2)), - Literal(null, IntegerType)) - val branchValue = If( - Literal(2) === nestedCaseWhen, - TrueLiteral, - FalseLiteral) - val branches = Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue) - val condition = CaseWhen(branches) - testFilter(originalCond = condition, expectedCond = condition) - testJoin(originalCond = condition, expectedCond = condition) - } - - test("inability to replace null in non-boolean branches of If inside another If") { - val condition = If( - Literal(5) > If( - UnresolvedAttribute("i") === 
Literal(15), - Literal(null, IntegerType), - Literal(3)), - TrueLiteral, - FalseLiteral) - testFilter(originalCond = condition, expectedCond = condition) - testJoin(originalCond = condition, expectedCond = condition) - } - - test("replace null in If used as a join condition") { - // this test is only for joins as the condition involves columns from different relations - val originalCond = If( - UnresolvedAttribute("d") > UnresolvedAttribute("i"), - Literal(null), - FalseLiteral) - testJoin(originalCond, expectedCond = FalseLiteral) - } - - test("replace null in CaseWhen used as a join condition") { - // this test is only for joins as the condition involves columns from different relations - val originalBranches = Seq( - (UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> Literal(null), - (UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral) - - val expectedBranches = Seq( - (UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> FalseLiteral, - (UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral) - - testJoin( - originalCond = CaseWhen(originalBranches, FalseLiteral), - expectedCond = CaseWhen(expectedBranches, FalseLiteral)) - } - - test("inability to replace null in CaseWhen inside EqualTo used as a join condition") { - // this test is only for joins as the condition involves columns from different relations - val branches = Seq( - (UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> Literal(null, BooleanType), - (UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral) - val condition = UnresolvedAttribute("b") === CaseWhen(branches, FalseLiteral) - testJoin(originalCond = condition, expectedCond = condition) - } - - test("replace null in predicates of If") { - val predicate = And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null)) - testProjection( - originalExpr = If(predicate, Literal(5), Literal(1)).as("out"), - expectedExpr = Literal(1).as("out")) - } - - test("replace null in 
predicates of If inside another If") { - val predicate = If( - And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null)), - TrueLiteral, - FalseLiteral) - testProjection( - originalExpr = If(predicate, Literal(5), Literal(1)).as("out"), - expectedExpr = Literal(1).as("out")) - } - - test("inability to replace null in non-boolean expressions inside If predicates") { - val predicate = GreaterThan( - UnresolvedAttribute("i"), - If(UnresolvedAttribute("b"), Literal(null, IntegerType), Literal(4))) - val column = If(predicate, Literal(5), Literal(1)).as("out") - testProjection(originalExpr = column, expectedExpr = column) - } - - test("replace null in conditions of CaseWhen") { - val branches = Seq( - And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null)) -> Literal(5)) - testProjection( - originalExpr = CaseWhen(branches, Literal(2)).as("out"), - expectedExpr = Literal(2).as("out")) - } - - test("replace null in conditions of CaseWhen inside another CaseWhen") { - val nestedCaseWhen = CaseWhen( - Seq(And(UnresolvedAttribute("b"), Literal(null)) -> Literal(5)), - Literal(2)) - val branches = Seq(GreaterThan(Literal(3), nestedCaseWhen) -> Literal(1)) - testProjection( - originalExpr = CaseWhen(branches).as("out"), - expectedExpr = Literal(1).as("out")) - } - - test("inability to replace null in non-boolean exprs inside CaseWhen conditions") { - val condition = GreaterThan( - UnresolvedAttribute("i"), - If(UnresolvedAttribute("b"), Literal(null, IntegerType), Literal(4))) - val column = CaseWhen(Seq(condition -> Literal(5)), Literal(2)).as("out") - testProjection(originalExpr = column, expectedExpr = column) - } - - private def testFilter(originalCond: Expression, expectedCond: Expression): Unit = { - test((rel, exp) => rel.where(exp), originalCond, expectedCond) - } - - private def testJoin(originalCond: Expression, expectedCond: Expression): Unit = { - test((rel, exp) => rel.join(anotherTestRelation, Inner, Some(exp)), originalCond, 
expectedCond) - } - - private def testProjection(originalExpr: Expression, expectedExpr: Expression): Unit = { - test((rel, exp) => rel.select(exp), originalExpr, expectedExpr) - } - - private def test( - func: (LogicalPlan, Expression) => LogicalPlan, - originalExpr: Expression, - expectedExpr: Expression): Unit = { - - val originalPlan = func(testRelation, originalExpr).analyze - val optimizedPlan = Optimize.execute(originalPlan) - val expectedPlan = func(testRelation, expectedExpr).analyze - comparePlans(optimizedPlan, expectedPlan) - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/a09d5ba8/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala deleted file mode 100644 index fc6ecc4..0000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql - -import org.apache.spark.sql.catalyst.expressions.{CaseWhen, If} -import org.apache.spark.sql.execution.LocalTableScanExec -import org.apache.spark.sql.functions.{lit, when} -import org.apache.spark.sql.test.SharedSQLContext - -class ReplaceNullWithFalseEndToEndSuite extends QueryTest with SharedSQLContext { - import testImplicits._ - - test("SPARK-25860: Replace Literal(null, _) with FalseLiteral whenever possible") { - withTable("t1", "t2") { - Seq((1, true), (2, false)).toDF("l", "b").write.saveAsTable("t1") - Seq(2, 3).toDF("l").write.saveAsTable("t2") - val df1 = spark.table("t1") - val df2 = spark.table("t2") - - val q1 = df1.where("IF(l > 10, false, b AND null)") - checkAnswer(q1, Seq.empty) - checkPlanIsEmptyLocalScan(q1) - - val q2 = df1.where("CASE WHEN l < 10 THEN null WHEN l > 40 THEN false ELSE null END") - checkAnswer(q2, Seq.empty) - checkPlanIsEmptyLocalScan(q2) - - val q3 = df1.join(df2, when(df1("l") > df2("l"), lit(null)).otherwise(df1("b") && lit(null))) - checkAnswer(q3, Seq.empty) - checkPlanIsEmptyLocalScan(q3) - - val q4 = df1.where("IF(IF(b, null, false), true, null)") - checkAnswer(q4, Seq.empty) - checkPlanIsEmptyLocalScan(q4) - - val q5 = df1.selectExpr("IF(l > 1 AND null, 5, 1) AS out") - checkAnswer(q5, Row(1) :: Row(1) :: Nil) - q5.queryExecution.executedPlan.foreach { p => - assert(p.expressions.forall(e => e.find(_.isInstanceOf[If]).isEmpty)) - } - - val q6 = df1.selectExpr("CASE WHEN (l > 2 AND null) THEN 3 ELSE 2 END") - checkAnswer(q6, Row(2) :: Row(2) :: Nil) - q6.queryExecution.executedPlan.foreach { p => - assert(p.expressions.forall(e => e.find(_.isInstanceOf[CaseWhen]).isEmpty)) - } - - checkAnswer(df1.where("IF(l > 10, false, b OR null)"), Row(1, true)) - } - - def checkPlanIsEmptyLocalScan(df: DataFrame): Unit = df.queryExecution.executedPlan match { - case s: LocalTableScanExec => assert(s.rows.isEmpty) - case p => fail(s"$p is not LocalTableScanExec") - } - } -} 
http://git-wip-us.apache.org/repos/asf/spark/blob/a09d5ba8/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala new file mode 100644 index 0000000..0f84b0c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.{CaseWhen, If, Literal} +import org.apache.spark.sql.execution.LocalTableScanExec +import org.apache.spark.sql.functions.{lit, when} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.BooleanType + +class ReplaceNullWithFalseInPredicateEndToEndSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("SPARK-25860: Replace Literal(null, _) with FalseLiteral whenever possible") { + withTable("t1", "t2") { + Seq((1, true), (2, false)).toDF("l", "b").write.saveAsTable("t1") + Seq(2, 3).toDF("l").write.saveAsTable("t2") + val df1 = spark.table("t1") + val df2 = spark.table("t2") + + val q1 = df1.where("IF(l > 10, false, b AND null)") + checkAnswer(q1, Seq.empty) + checkPlanIsEmptyLocalScan(q1) + + val q2 = df1.where("CASE WHEN l < 10 THEN null WHEN l > 40 THEN false ELSE null END") + checkAnswer(q2, Seq.empty) + checkPlanIsEmptyLocalScan(q2) + + val q3 = df1.join(df2, when(df1("l") > df2("l"), lit(null)).otherwise(df1("b") && lit(null))) + checkAnswer(q3, Seq.empty) + checkPlanIsEmptyLocalScan(q3) + + val q4 = df1.where("IF(IF(b, null, false), true, null)") + checkAnswer(q4, Seq.empty) + checkPlanIsEmptyLocalScan(q4) + + val q5 = df1.selectExpr("IF(l > 1 AND null, 5, 1) AS out") + checkAnswer(q5, Row(1) :: Row(1) :: Nil) + q5.queryExecution.executedPlan.foreach { p => + assert(p.expressions.forall(e => e.find(_.isInstanceOf[If]).isEmpty)) + } + + val q6 = df1.selectExpr("CASE WHEN (l > 2 AND null) THEN 3 ELSE 2 END") + checkAnswer(q6, Row(2) :: Row(2) :: Nil) + q6.queryExecution.executedPlan.foreach { p => + assert(p.expressions.forall(e => e.find(_.isInstanceOf[CaseWhen]).isEmpty)) + } + + checkAnswer(df1.where("IF(l > 10, false, b OR null)"), Row(1, true)) + } + + def checkPlanIsEmptyLocalScan(df: DataFrame): Unit = df.queryExecution.executedPlan match { + case s: LocalTableScanExec => assert(s.rows.isEmpty) + case p => 
fail(s"$p is not LocalTableScanExec") + } + } + + test("SPARK-26107: Replace Literal(null, _) with FalseLiteral in higher-order functions") { + def assertNoLiteralNullInPlan(df: DataFrame): Unit = { + df.queryExecution.executedPlan.foreach { p => + assert(p.expressions.forall(_.find { + case Literal(null, BooleanType) => true + case _ => false + }.isEmpty)) + } + } + + withTable("t1", "t2") { + // to test ArrayFilter and ArrayExists + spark.sql("select array(null, 1, null, 3) as a") + .write.saveAsTable("t1") + // to test MapFilter + spark.sql(""" + select map_from_entries(arrays_zip(a, transform(a, e -> if(mod(e, 2) = 0, null, e)))) as m + from (select array(0, 1, 2, 3) as a) + """).write.saveAsTable("t2") + + val df1 = spark.table("t1") + val df2 = spark.table("t2") + + // ArrayExists + val q1 = df1.selectExpr("EXISTS(a, e -> IF(e is null, null, true))") + checkAnswer(q1, Row(true) :: Nil) + assertNoLiteralNullInPlan(q1) + + // ArrayFilter + val q2 = df1.selectExpr("FILTER(a, e -> IF(e is null, null, true))") + checkAnswer(q2, Row(Seq[Any](1, 3)) :: Nil) + assertNoLiteralNullInPlan(q2) + + // MapFilter + val q3 = df2.selectExpr("MAP_FILTER(m, (k, v) -> IF(v is null, null, true))") + checkAnswer(q3, Row(Map[Any, Any](1 -> 1, 3 -> 3))) + assertNoLiteralNullInPlan(q3) + } + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
