Github user dbtsai commented on a diff in the pull request:

    https://github.com/apache/spark/pull/22857#discussion_r228739018
  
    --- Diff: sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala ---
    @@ -0,0 +1,324 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.catalyst.optimizer
    +
    +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
    +import org.apache.spark.sql.catalyst.dsl.expressions._
    +import org.apache.spark.sql.catalyst.dsl.plans._
    +import org.apache.spark.sql.catalyst.expressions.{And, CaseWhen, 
Expression, GreaterThan, If, Literal, Or}
    +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, 
TrueLiteral}
    +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
    +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, 
LogicalPlan}
    +import org.apache.spark.sql.catalyst.rules.RuleExecutor
    +import org.apache.spark.sql.types.{BooleanType, IntegerType}
    +
    +class ReplaceNullWithFalseSuite extends PlanTest {
    +
    +  /**
    +   * Optimizer under test: a single batch, iterated with FixedPoint(10), that combines
    +   * ReplaceNullWithFalse with the null/boolean/conditional simplification rules.
    +   */
    +  object Optimize extends RuleExecutor[LogicalPlan] {
    +    val batches = List(
    +      Batch(
    +        "Replace null literals",
    +        FixedPoint(10),
    +        NullPropagation,
    +        ConstantFolding,
    +        BooleanSimplification,
    +        SimplifyConditionals,
    +        ReplaceNullWithFalse))
    +  }
    +
    +  // Relation every test plan is built on: int column `i` and boolean column `b`.
    +  private val testRelation = LocalRelation(Symbol("i").int, Symbol("b").boolean)
    +  // Right-hand side of the join tests; contributes the int column `d`.
    +  private val anotherTestRelation = LocalRelation(Symbol("d").int)
    +
    +  test("successful replacement of null literals in filter and join conditions (1)") {
    +    // A bare null condition is expected to be rewritten to FALSE.
    +    val nullCond = Literal(null)
    +    testFilter(nullCond, FalseLiteral)
    +    testJoin(nullCond, FalseLiteral)
    +  }
    +
    +  test("successful replacement of null literals in filter and join conditions (2)") {
    +    val i = UnresolvedAttribute("i")
    +    // Both branches are FALSE or NULL, so the whole condition is expected to fold to FALSE.
    +    val cond = If(i > Literal(10), FalseLiteral, Literal(null, BooleanType))
    +    testFilter(cond, FalseLiteral)
    +    testJoin(cond, FalseLiteral)
    +  }
    +
    +  test("successful replacement of null literals in filter and join conditions (3)") {
    +    val i = UnresolvedAttribute("i")
    +    val b = UnresolvedAttribute("b")
    +    val nullBool = Literal(null, BooleanType)
    +    val cond = If(i > Literal(10), TrueLiteral && nullBool, b && nullBool)
    +    testFilter(cond, FalseLiteral)
    +    testJoin(cond, FalseLiteral)
    +  }
    +
    +  test("successful replacement of null literals in filter and join conditions (4)") {
    +    val i = UnresolvedAttribute("i")
    +    val branches = Seq(
    +      (i < Literal(10)) -> TrueLiteral,
    +      (i > Literal(40)) -> FalseLiteral)
    +    // Only the null ELSE value is expected to change; the explicit branches are kept.
    +    val before = CaseWhen(branches, Literal(null, BooleanType))
    +    val after = CaseWhen(branches, FalseLiteral)
    +    testFilter(before, after)
    +    testJoin(before, after)
    +  }
    +
    +  test("successful replacement of null literals in filter and join conditions (5)") {
    +    val i = UnresolvedAttribute("i")
    +    val branches = Seq(
    +      (i < Literal(10)) -> Literal(null, BooleanType),
    +      (i > Literal(40)) -> FalseLiteral)
    +    val cond = CaseWhen(branches, Literal(null))
    +    testFilter(cond, FalseLiteral)
    +    testJoin(cond, FalseLiteral)
    +  }
    +
    +  test("successful replacement of null literals in filter and join conditions (6)") {
    +    val i = UnresolvedAttribute("i")
    +    val nestedIf = If(i < Literal(20), Literal(null, BooleanType), FalseLiteral)
    +    val branchesBefore = Seq(
    +      (i < Literal(10)) -> nestedIf,
    +      (i > Literal(40)) -> TrueLiteral)
    +    // The nested If (NULL or FALSE) is expected to collapse to a plain FALSE branch value.
    +    val branchesAfter = Seq(
    +      (i < Literal(10)) -> FalseLiteral,
    +      (i > Literal(40)) -> TrueLiteral)
    +    val before = CaseWhen(branchesBefore)
    +    val after = CaseWhen(branchesAfter)
    +    testFilter(before, after)
    +    testJoin(before, after)
    +  }
    +
    +  test("successful replacement of null literals in filter and join conditions (7)") {
    +    val i = UnresolvedAttribute("i")
    +    val branchesBefore = Seq(
    +      (i < Literal(10)) -> TrueLiteral,
    +      (Literal(6) <= Literal(1)) -> FalseLiteral,
    +      (Literal(4) === Literal(5)) -> FalseLiteral,
    +      (i > Literal(10)) -> Literal(null, BooleanType),
    +      (Literal(4) === Literal(4)) -> TrueLiteral)
    +    // The expected plan has no constant-false branches, and `4 = 4` is folded to TRUE.
    +    val branchesAfter = Seq(
    +      (i < Literal(10)) -> TrueLiteral,
    +      (i > Literal(10)) -> FalseLiteral,
    +      TrueLiteral -> TrueLiteral)
    +    val before = CaseWhen(branchesBefore)
    +    val after = CaseWhen(branchesAfter)
    +    testFilter(before, after)
    +    testJoin(before, after)
    +  }
    +
    +  test("successful replacement of null literals in filter and join conditions (8)") {
    +    val b = UnresolvedAttribute("b")
    +    val cond = Or(b, Literal(null))
    +    testFilter(cond, b)
    +    testJoin(cond, b)
    +  }
    +
    +  test("successful replacement of null literals in filter and join conditions (9)") {
    +    val cond = And(UnresolvedAttribute("b"), Literal(null))
    +    testFilter(cond, FalseLiteral)
    +    testJoin(cond, FalseLiteral)
    +  }
    +
    +  test("successful replacement of null literals in filter and join conditions (10)") {
    +    val b = UnresolvedAttribute("b")
    +    val cond = And(
    +      And(b, Literal(null)),
    +      Or(Literal(null), And(Literal(null), And(b, Literal(null)))))
    +    testFilter(cond, FalseLiteral)
    +    testJoin(cond, FalseLiteral)
    +  }
    +
    +  test("successful replacement of null literals in filter and join conditions (11)") {
    +    val i = UnresolvedAttribute("i")
    +    val b = UnresolvedAttribute("b")
    +    val cond = If(i > Literal(10), FalseLiteral, And(b, Literal(null, BooleanType)))
    +    testFilter(cond, FalseLiteral)
    +    testJoin(cond, FalseLiteral)
    +  }
    +
    +  test("successful replacement of null literals in filter and join conditions (12)") {
    +    val b = UnresolvedAttribute("b")
    +    val nullOrFalse = If(
    +      UnresolvedAttribute("i") > Literal(10), Literal(null), And(FalseLiteral, b))
    +    val cond = And(b, nullOrFalse)
    +    testFilter(cond, FalseLiteral)
    +    testJoin(cond, FalseLiteral)
    +  }
    +
    +  test("successful replacement of null literals in filter and join conditions (13)") {
    +    val innerIf = If(UnresolvedAttribute("b"), Literal(null), FalseLiteral)
    +    val cond = If(innerIf, TrueLiteral, Literal(null))
    +    testFilter(cond, FalseLiteral)
    +    testJoin(cond, FalseLiteral)
    +  }
    +
    +  test("successful replacement of null literals in filter and join conditions (14)") {
    +    val innerCaseWhen = CaseWhen(Seq(UnresolvedAttribute("b") -> FalseLiteral), Literal(null))
    +    val cond = CaseWhen(Seq(innerCaseWhen -> TrueLiteral), Literal(null))
    +    testFilter(cond, FalseLiteral)
    +    testJoin(cond, FalseLiteral)
    +  }
    +
    +  test("inability to replace null literals in filter and join conditions (1)") {
    +    val i = UnresolvedAttribute("i")
    +    // The null is an IntegerType comparison operand, and the test expects no rewrite.
    +    val nullableInt = If(i === Literal(15), Literal(null, IntegerType), Literal(3))
    +    val condition = If(i > Literal(10), Literal(5) > nullableInt, FalseLiteral)
    +    testFilter(condition, condition)
    +    testJoin(condition, condition)
    +  }
    +
    +  test("inability to replace null literals in filter and join conditions (2)") {
    +    val i = UnresolvedAttribute("i")
    +    val innerCaseWhen = CaseWhen(
    +      Seq((i > Literal(20)) -> Literal(2)),
    +      Literal(null, IntegerType))
    +    val branches = Seq(
    +      (i > Literal(10)) -> If(Literal(2) === innerCaseWhen, TrueLiteral, FalseLiteral))
    +    val condition = CaseWhen(branches)
    +    testFilter(condition, condition)
    +    testJoin(condition, condition)
    +  }
    +
    +  test("inability to replace null literals in filter and join conditions (3)") {
    +    val nullableInt = If(
    +      UnresolvedAttribute("i") === Literal(15),
    +      Literal(null, IntegerType),
    +      Literal(3))
    +    val condition = If(Literal(5) > nullableInt, TrueLiteral, FalseLiteral)
    +    testFilter(condition, condition)
    +    testJoin(condition, condition)
    +  }
    +
    +  test("successful replacement of null literals in join conditions (1)") {
    +    // Join-only: the condition compares `d` and `i`, which come from different relations.
    +    val dAboveI = UnresolvedAttribute("d") > UnresolvedAttribute("i")
    +    val originalCond = If(dAboveI, Literal(null), FalseLiteral)
    +    testJoin(originalCond, FalseLiteral)
    +  }
    +
    +  test("successful replacement of null literals in join conditions (2)") {
    +    // Join-only: the condition compares `d` and `i`, which come from different relations.
    +    val d = UnresolvedAttribute("d")
    +    val i = UnresolvedAttribute("i")
    +    val branchesBefore = Seq((d > i) -> Literal(null), (d === i) -> TrueLiteral)
    +    val branchesAfter = Seq((d > i) -> FalseLiteral, (d === i) -> TrueLiteral)
    +    testJoin(CaseWhen(branchesBefore, FalseLiteral), CaseWhen(branchesAfter, FalseLiteral))
    +  }
    +
    +  test("inability to replace null literals in join conditions (1)") {
    +    // Join-only: the condition compares `d` and `i`, which come from different relations.
    +    val d = UnresolvedAttribute("d")
    +    val i = UnresolvedAttribute("i")
    +    val branches = Seq(
    +      (d > i) -> Literal(null, BooleanType),
    +      (d === i) -> TrueLiteral)
    +    // The CASE WHEN is an operand of an equality rather than the condition itself;
    +    // the test expects the condition to stay unchanged.
    +    val condition = UnresolvedAttribute("b") === CaseWhen(branches, FalseLiteral)
    +    testJoin(condition, condition)
    +  }
    +
    +  test("successful replacement of null literals in if predicates (1)") {
    +    val iAboveHalf = GreaterThan(UnresolvedAttribute("i"), Literal(0.5))
    +    // `x AND NULL` is never TRUE, so only the ELSE value is expected to remain.
    +    val column = If(And(iAboveHalf, Literal(null)), Literal(5), Literal(1)).as("out")
    +    testProjection(column, Literal(1).as("out"))
    +  }
    +
    +  test("successful replacement of null literals in if predicates (2)") {
    +    val iAboveHalf = GreaterThan(UnresolvedAttribute("i"), Literal(0.5))
    +    val predicate = If(And(iAboveHalf, Literal(null)), TrueLiteral, FalseLiteral)
    +    testProjection(If(predicate, Literal(5), Literal(1)).as("out"), Literal(1).as("out"))
    +  }
    +
    +  test("inability to replace null literals in if predicates") {
    +    val nullableInt = If(UnresolvedAttribute("b"), Literal(null, IntegerType), Literal(4))
    +    val predicate = GreaterThan(UnresolvedAttribute("i"), nullableInt)
    +    val column = If(predicate, Literal(5), Literal(1)).as("out")
    +    testProjection(column, column)
    +  }
    +
    +  test("successful replacement of null literals in branches of case when (1)") {
    +    val neverTrue = And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null))
    +    val branches = Seq(neverTrue -> Literal(5))
    +    testProjection(CaseWhen(branches, Literal(2)).as("out"), Literal(2).as("out"))
    +  }
    +
    +  test("successful replacement of null literals in branches of case when (2)") {
    +    // The inner CASE WHEN can only yield its ELSE value 2, so `3 > 2` makes the outer yield 1.
    +    val innerCaseWhen = CaseWhen(
    +      Seq(And(UnresolvedAttribute("b"), Literal(null)) -> Literal(5)),
    +      Literal(2))
    +    val branches = Seq(GreaterThan(Literal(3), innerCaseWhen) -> Literal(1))
    +    testProjection(CaseWhen(branches).as("out"), Literal(1).as("out"))
    +  }
    +
    +  test("inability to replace null literals in branches of case when") {
    +    val nullableInt = If(UnresolvedAttribute("b"), Literal(null, IntegerType), Literal(4))
    +    val branchCondition = GreaterThan(UnresolvedAttribute("i"), nullableInt)
    +    val column = CaseWhen(Seq(branchCondition -> Literal(5)), Literal(2)).as("out")
    +    testProjection(column, column)
    +  }
    +
    +  /** Checks that a Filter over `testRelation` with `originalCond` optimizes to `expectedCond`. */
    +  private def testFilter(originalCond: Expression, expectedCond: Expression): Unit =
    +    test((plan, cond) => plan.where(cond), originalCond, expectedCond)
    +
    +  /**
    +   * Checks that an inner join of `testRelation` with `anotherTestRelation` on `originalCond`
    +   * optimizes to the same join on `expectedCond`.
    +   */
    +  private def testJoin(originalCond: Expression, expectedCond: Expression): Unit = {
    +    val innerJoinOn = (left: LogicalPlan, cond: Expression) =>
    +      left.join(anotherTestRelation, Inner, Some(cond))
    +    test(innerJoinOn, originalCond, expectedCond)
    +  }
    +
    +  /** Checks that projecting `originalExpr` over `testRelation` optimizes to `expectedExpr`. */
    +  private def testProjection(originalExpr: Expression, expectedExpr: Expression): Unit =
    +    test((plan, projected) => plan.select(projected), originalExpr, expectedExpr)
    +
    +  /**
    +   * Builds two plans over `testRelation` with `func`: one from `originalExpr` and one from
    +   * `expectedExpr`. Both are analyzed, the first is run through `Optimize`, and the result is
    +   * compared to the second with `comparePlans`.
    +   */
    +  private def test(
    +      func: (LogicalPlan, Expression) => LogicalPlan,
    +      originalExpr: Expression,
    +      expectedExpr: Expression): Unit = {
    +    val optimizedPlan = Optimize.execute(func(testRelation, originalExpr).analyze)
    +    val expectedPlan = func(testRelation, expectedExpr).analyze
    +    comparePlans(optimizedPlan, expectedPlan)
    +  }
    +
    --- End diff --
    
    Remove the extra blank line.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to