Github user dbtsai commented on a diff in the pull request: https://github.com/apache/spark/pull/22857#discussion_r228739018 --- Diff: sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala --- @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{And, CaseWhen, Expression, GreaterThan, If, Literal, Or} +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.{BooleanType, IntegerType} + +class ReplaceNullWithFalseSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Replace null literals", FixedPoint(10), + NullPropagation, + ConstantFolding, + BooleanSimplification, + SimplifyConditionals, + ReplaceNullWithFalse) :: Nil + } + + private val testRelation = LocalRelation('i.int, 'b.boolean) + private val anotherTestRelation = LocalRelation('d.int) + + test("successful replacement of null literals in filter and join conditions (1)") { + testFilter(originalCond = Literal(null), expectedCond = FalseLiteral) + testJoin(originalCond = Literal(null), expectedCond = FalseLiteral) + } + + test("successful replacement of null literals in filter and join conditions (2)") { + val originalCond = If( + UnresolvedAttribute("i") > Literal(10), + FalseLiteral, + Literal(null, BooleanType)) + testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("successful replacement of null literals in filter and join conditions (3)") { + val originalCond = If( + UnresolvedAttribute("i") > Literal(10), + TrueLiteral && Literal(null, BooleanType), + UnresolvedAttribute("b") && Literal(null, BooleanType)) + testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, 
expectedCond = FalseLiteral) + } + + test("successful replacement of null literals in filter and join conditions (4)") { + val branches = Seq( + (UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral, + (UnresolvedAttribute("i") > Literal(40)) -> FalseLiteral) + val originalCond = CaseWhen(branches, Literal(null, BooleanType)) + val expectedCond = CaseWhen(branches, FalseLiteral) + testFilter(originalCond, expectedCond) + testJoin(originalCond, expectedCond) + } + + test("successful replacement of null literals in filter and join conditions (5)") { + val branches = Seq( + (UnresolvedAttribute("i") < Literal(10)) -> Literal(null, BooleanType), + (UnresolvedAttribute("i") > Literal(40)) -> FalseLiteral) + val originalCond = CaseWhen(branches, Literal(null)) + testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("successful replacement of null literals in filter and join conditions (6)") { + val originalBranches = Seq( + (UnresolvedAttribute("i") < Literal(10)) -> + If(UnresolvedAttribute("i") < Literal(20), Literal(null, BooleanType), FalseLiteral), + (UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral) + val originalCond = CaseWhen(originalBranches) + + val expectedBranches = Seq( + (UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral, + (UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral) + val expectedCond = CaseWhen(expectedBranches) + + testFilter(originalCond, expectedCond) + testJoin(originalCond, expectedCond) + } + + test("successful replacement of null literals in filter and join conditions (7)") { + val originalBranches = Seq( + (UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral, + (Literal(6) <= Literal(1)) -> FalseLiteral, + (Literal(4) === Literal(5)) -> FalseLiteral, + (UnresolvedAttribute("i") > Literal(10)) -> Literal(null, BooleanType), + (Literal(4) === Literal(4)) -> TrueLiteral) + val originalCond = CaseWhen(originalBranches) + + val expectedBranches = Seq( + 
(UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral, + (UnresolvedAttribute("i") > Literal(10)) -> FalseLiteral, + TrueLiteral -> TrueLiteral) + val expectedCond = CaseWhen(expectedBranches) + + testFilter(originalCond, expectedCond) + testJoin(originalCond, expectedCond) + } + + test("successful replacement of null literals in filter and join conditions (8)") { + val originalCond = Or(UnresolvedAttribute("b"), Literal(null)) + val expectedCond = UnresolvedAttribute("b") + testFilter(originalCond, expectedCond) + testJoin(originalCond, expectedCond) + } + + test("successful replacement of null literals in filter and join conditions (9)") { + val originalCond = And(UnresolvedAttribute("b"), Literal(null)) + testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("successful replacement of null literals in filter and join conditions (10)") { + val originalCond = And( + And(UnresolvedAttribute("b"), Literal(null)), + Or(Literal(null), And(Literal(null), And(UnresolvedAttribute("b"), Literal(null))))) + testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("successful replacement of null literals in filter and join conditions (11)") { + val originalCond = If( + UnresolvedAttribute("i") > Literal(10), + FalseLiteral, + And(UnresolvedAttribute("b"), Literal(null, BooleanType))) + testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("successful replacement of null literals in filter and join conditions (12)") { + val originalCond = And( + UnresolvedAttribute("b"), + If( + UnresolvedAttribute("i") > Literal(10), + Literal(null), + And(FalseLiteral, UnresolvedAttribute("b")))) + testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("successful replacement of null literals in filter and join conditions (13)") 
{ + val originalCond = If( + If(UnresolvedAttribute("b"), Literal(null), FalseLiteral), + TrueLiteral, + Literal(null)) + testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("successful replacement of null literals in filter and join conditions (14)") { + val nestedCaseWhen = CaseWhen(Seq(UnresolvedAttribute("b") -> FalseLiteral), Literal(null)) + val originalCond = CaseWhen(Seq(nestedCaseWhen -> TrueLiteral), Literal(null)) + testFilter(originalCond, expectedCond = FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("inability to replace null literals in filter and join conditions (1)") { + val condition = If( + UnresolvedAttribute("i") > Literal(10), + Literal(5) > If( + UnresolvedAttribute("i") === Literal(15), + Literal(null, IntegerType), + Literal(3)), + FalseLiteral) + testFilter(originalCond = condition, expectedCond = condition) + testJoin(originalCond = condition, expectedCond = condition) + } + + test("inability to replace null literals in filter and join conditions (2)") { + val nestedCaseWhen = CaseWhen( + Seq((UnresolvedAttribute("i") > Literal(20)) -> Literal(2)), + Literal(null, IntegerType)) + val branchValue = If( + Literal(2) === nestedCaseWhen, + TrueLiteral, + FalseLiteral) + val branches = Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue) + val condition = CaseWhen(branches) + testFilter(originalCond = condition, expectedCond = condition) + testJoin(originalCond = condition, expectedCond = condition) + } + + test("inability to replace null literals in filter and join conditions (3)") { + val condition = If( + Literal(5) > If( + UnresolvedAttribute("i") === Literal(15), + Literal(null, IntegerType), + Literal(3)), + TrueLiteral, + FalseLiteral) + testFilter(originalCond = condition, expectedCond = condition) + testJoin(originalCond = condition, expectedCond = condition) + } + + test("successful replacement of null literals in join 
conditions (1)") { + // this test is only for joins as the condition involves columns from different relations + val originalCond = If( + UnresolvedAttribute("d") > UnresolvedAttribute("i"), + Literal(null), + FalseLiteral) + testJoin(originalCond, expectedCond = FalseLiteral) + } + + test("successful replacement of null literals in join conditions (2)") { + // this test is only for joins as the condition involves columns from different relations + val originalBranches = Seq( + (UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> Literal(null), + (UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral) + + val expectedBranches = Seq( + (UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> FalseLiteral, + (UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral) + + testJoin( + originalCond = CaseWhen(originalBranches, FalseLiteral), + expectedCond = CaseWhen(expectedBranches, FalseLiteral)) + } + + test("inability to replace null literals in join conditions (1)") { + // this test is only for joins as the condition involves columns from different relations + val branches = Seq( + (UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> Literal(null, BooleanType), + (UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral) + val condition = UnresolvedAttribute("b") === CaseWhen(branches, FalseLiteral) + testJoin(originalCond = condition, expectedCond = condition) + } + + test("successful replacement of null literals in if predicates (1)") { + val predicate = And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null)) + testProjection( + originalExpr = If(predicate, Literal(5), Literal(1)).as("out"), + expectedExpr = Literal(1).as("out")) + } + + test("successful replacement of null literals in if predicates (2)") { + val predicate = If( + And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null)), + TrueLiteral, + FalseLiteral) + testProjection( + originalExpr = If(predicate, Literal(5), 
Literal(1)).as("out"), + expectedExpr = Literal(1).as("out")) + } + + test("inability to replace null literals in if predicates") { + val predicate = GreaterThan( + UnresolvedAttribute("i"), + If(UnresolvedAttribute("b"), Literal(null, IntegerType), Literal(4))) + val column = If(predicate, Literal(5), Literal(1)).as("out") + testProjection(originalExpr = column, expectedExpr = column) + } + + test("successful replacement of null literals in branches of case when (1)") { + val branches = Seq( + And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null)) -> Literal(5)) + testProjection( + originalExpr = CaseWhen(branches, Literal(2)).as("out"), + expectedExpr = Literal(2).as("out")) + } + + test("successful replacement of null literals in branches of case when (2)") { + val nestedCaseWhen = CaseWhen( + Seq(And(UnresolvedAttribute("b"), Literal(null)) -> Literal(5)), + Literal(2)) + val branches = Seq(GreaterThan(Literal(3), nestedCaseWhen) -> Literal(1)) + testProjection( + originalExpr = CaseWhen(branches).as("out"), + expectedExpr = Literal(1).as("out")) + } + + test("inability to replace null literals in branches of case when") { + val condition = GreaterThan( + UnresolvedAttribute("i"), + If(UnresolvedAttribute("b"), Literal(null, IntegerType), Literal(4))) + val column = CaseWhen(Seq(condition -> Literal(5)), Literal(2)).as("out") + testProjection(originalExpr = column, expectedExpr = column) + } + + private def testFilter(originalCond: Expression, expectedCond: Expression): Unit = { + test((rel, exp) => rel.where(exp), originalCond, expectedCond) + } + + private def testJoin(originalCond: Expression, expectedCond: Expression): Unit = { + test((rel, exp) => rel.join(anotherTestRelation, Inner, Some(exp)), originalCond, expectedCond) + } + + private def testProjection(originalExpr: Expression, expectedExpr: Expression): Unit = { + test((rel, exp) => rel.select(exp), originalExpr, expectedExpr) + } + + private def test( + func: (LogicalPlan, 
Expression) => LogicalPlan, + originalExpr: Expression, + expectedExpr: Expression): Unit = { + + val originalPlan = func(testRelation, originalExpr).analyze + val optimizedPlan = Optimize.execute(originalPlan) + val expectedPlan = func(testRelation, expectedExpr).analyze + comparePlans(optimizedPlan, expectedPlan) + } + --- End diff -- Please remove the extra blank line after the closing brace above.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org