Repository: spark
Updated Branches:
  refs/heads/master 68dde3481 -> bc9f9b4d6


[SPARK-25860][SQL] Replace Literal(null, _) with FalseLiteral whenever possible

## What changes were proposed in this pull request?

This PR proposes a new optimization rule that replaces `Literal(null, _)` with 
`FalseLiteral` in the conditions of `Join` and `Filter`, in the predicates of 
`If`, and in the branch conditions of `CaseWhen`.

The idea is that these contexts evaluate a `null` condition the same way as 
`false` (see, for example, `GeneratePredicate$create` or the `doGenCode` and 
`eval` methods of `If` and `CaseWhen`). Therefore, replacing `Literal(null, _)` 
with `FalseLiteral` preserves semantics and can enable further optimizations 
later on.

Let’s consider a few examples.

```
val df = spark.range(1, 100).select($"id".as("l"), $"id".cast("string").as("s"))
df.createOrReplaceTempView("t")
df.createOrReplaceTempView("p")
```
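
Because these contexts discard a row when the predicate is `null`, exactly as they 
do for `false`, the rewrite cannot change query results. A quick sanity check (a 
minimal sketch reusing `df` from above; the assertions are illustrative and not 
part of the patch):

```
// A predicate that evaluates to null filters out the row, just like false does.
assert(df.where("l > 10 AND null").count() == 0)

// The same holds when the null comes out of IF/CASE used as a predicate (see Case 1).
assert(df.where("IF(l > 10, false, NULL)").count() == 0)
```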

**Case 1**
```
spark.sql("SELECT * FROM t WHERE if(l > 10, false, NULL)").explain(true)

// without the new rule
…
== Optimized Logical Plan ==
Project [id#0L AS l#2L, cast(id#0L as string) AS s#3]
+- Filter if ((id#0L > 10)) false else null
   +- Range (1, 100, step=1, splits=Some(12))

== Physical Plan ==
*(1) Project [id#0L AS l#2L, cast(id#0L as string) AS s#3]
+- *(1) Filter if ((id#0L > 10)) false else null
   +- *(1) Range (1, 100, step=1, splits=12)

// with the new rule
…
== Optimized Logical Plan ==
LocalRelation <empty>, [l#2L, s#3]

== Physical Plan ==
LocalTableScan <empty>, [l#2L, s#3]
```

**Case 2**
```
spark.sql("SELECT * FROM t WHERE CASE WHEN l < 10 THEN null WHEN l > 40 THEN 
false ELSE null END”).explain(true)

// without the new rule
...
== Optimized Logical Plan ==
Project [id#0L AS l#2L, cast(id#0L as string) AS s#3]
+- Filter CASE WHEN (id#0L < 10) THEN null WHEN (id#0L > 40) THEN false ELSE null END
   +- Range (1, 100, step=1, splits=Some(12))

== Physical Plan ==
*(1) Project [id#0L AS l#2L, cast(id#0L as string) AS s#3]
+- *(1) Filter CASE WHEN (id#0L < 10) THEN null WHEN (id#0L > 40) THEN false ELSE null END
   +- *(1) Range (1, 100, step=1, splits=12)

// with the new rule
...
== Optimized Logical Plan ==
LocalRelation <empty>, [l#2L, s#3]

== Physical Plan ==
LocalTableScan <empty>, [l#2L, s#3]
```

**Case 3**
```
spark.sql("SELECT * FROM t JOIN p ON IF(t.l > p.l, null, false)").explain(true)

// without the new rule
...
== Optimized Logical Plan ==
Join Inner, if ((l#2L > l#37L)) null else false
:- Project [id#0L AS l#2L, cast(id#0L as string) AS s#3]
:  +- Range (1, 100, step=1, splits=Some(12))
+- Project [id#0L AS l#37L, cast(id#0L as string) AS s#38]
   +- Range (1, 100, step=1, splits=Some(12))

== Physical Plan ==
BroadcastNestedLoopJoin BuildRight, Inner, if ((l#2L > l#37L)) null else false
:- *(1) Project [id#0L AS l#2L, cast(id#0L as string) AS s#3]
:  +- *(1) Range (1, 100, step=1, splits=12)
+- BroadcastExchange IdentityBroadcastMode
   +- *(2) Project [id#0L AS l#37L, cast(id#0L as string) AS s#38]
      +- *(2) Range (1, 100, step=1, splits=12)

// with the new rule
...
== Optimized Logical Plan ==
LocalRelation <empty>, [l#2L, s#3, l#37L, s#38]
```
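
Besides reading the full `explain` output, the effect of the rule can be checked 
programmatically; a hedged sketch (it relies only on the public 
`queryExecution.optimizedPlan` hook and the Case 3 plan shown above, and assumes 
the patch is applied):

```
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val q = spark.sql("SELECT * FROM t JOIN p ON IF(t.l > p.l, null, false)")
// With ReplaceNullWithFalse in place, other existing optimizer rules can then
// collapse the whole join to an empty LocalRelation, as in the plan above.
assert(q.queryExecution.optimizedPlan.isInstanceOf[LocalRelation])
```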

## How was this patch tested?

This PR comes with a set of dedicated tests: rule-level tests in 
`ReplaceNullWithFalseSuite` and end-to-end tests in 
`ReplaceNullWithFalseEndToEndSuite`.

Closes #22857 from aokolnychyi/spark-25860.

Authored-by: Anton Okolnychyi <[email protected]>
Signed-off-by: DB Tsai <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bc9f9b4d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bc9f9b4d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bc9f9b4d

Branch: refs/heads/master
Commit: bc9f9b4d6e6ac983a903a0b9a3a668950dc0b2a7
Parents: 68dde34
Author: Anton Okolnychyi <[email protected]>
Authored: Wed Oct 31 18:35:33 2018 +0000
Committer: DB Tsai <[email protected]>
Committed: Wed Oct 31 18:35:33 2018 +0000

----------------------------------------------------------------------
 .../sql/catalyst/optimizer/Optimizer.scala      |   1 +
 .../sql/catalyst/optimizer/expressions.scala    |  57 ++++
 .../optimizer/ReplaceNullWithFalseSuite.scala   | 323 +++++++++++++++++++
 .../org/apache/spark/sql/DataFrameSuite.scala   |   4 +-
 .../sql/ReplaceNullWithFalseEndToEndSuite.scala |  71 ++++
 5 files changed, 454 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bc9f9b4d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 95455ff..a330a84 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -84,6 +84,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
         SimplifyConditionals,
         RemoveDispensableExpressions,
         SimplifyBinaryComparison,
+        ReplaceNullWithFalse,
         PruneFilters,
         EliminateSorts,
         SimplifyCasts,

http://git-wip-us.apache.org/repos/asf/spark/blob/bc9f9b4d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 468a950..2b29b49 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -736,3 +736,60 @@ object CombineConcats extends Rule[LogicalPlan] {
       flattenConcats(concat)
   }
 }
+
+/**
+ * A rule that replaces `Literal(null, _)` with `FalseLiteral` for further optimizations.
+ *
+ * This rule applies to conditions in [[Filter]] and [[Join]]. Moreover, it transforms predicates
+ * in all [[If]] expressions as well as branch conditions in all [[CaseWhen]] expressions.
+ *
+ * For example, `Filter(Literal(null, _))` is equal to `Filter(FalseLiteral)`.
+ *
+ * Another example containing branches is `Filter(If(cond, FalseLiteral, Literal(null, _)))`;
+ * this can be optimized to `Filter(If(cond, FalseLiteral, FalseLiteral))`, and eventually
+ * `Filter(FalseLiteral)`.
+ *
+ * As this rule is not limited to conditions in [[Filter]] and [[Join]], arbitrary plans can
+ * benefit from it. For example, `Project(If(And(cond, Literal(null)), Literal(1), Literal(2)))`
+ * can be simplified into `Project(Literal(2))`.
+ *
+ * As a result, many unnecessary computations can be removed in the query optimization phase.
+ */
+object ReplaceNullWithFalse extends Rule[LogicalPlan] {
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond))
+    case j @ Join(_, _, _, Some(cond)) => j.copy(condition = Some(replaceNullWithFalse(cond)))
+    case p: LogicalPlan => p transformExpressions {
+      case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred))
+      case cw @ CaseWhen(branches, _) =>
+        val newBranches = branches.map { case (cond, value) =>
+          replaceNullWithFalse(cond) -> value
+        }
+        cw.copy(branches = newBranches)
+    }
+  }
+
+  /**
+   * Recursively replaces `Literal(null, _)` with `FalseLiteral`.
+   *
+   * Note that `transformExpressionsDown` can not be used here as we must stop as soon as we hit
+   * an expression that is not [[CaseWhen]], [[If]], [[And]], [[Or]] or `Literal(null, _)`.
+   */
+  private def replaceNullWithFalse(e: Expression): Expression = e match {
+    case cw: CaseWhen if cw.dataType == BooleanType =>
+      val newBranches = cw.branches.map { case (cond, value) =>
+        replaceNullWithFalse(cond) -> replaceNullWithFalse(value)
+      }
+      val newElseValue = cw.elseValue.map(replaceNullWithFalse)
+      CaseWhen(newBranches, newElseValue)
+    case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType =>
+      If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal), replaceNullWithFalse(falseVal))
+    case And(left, right) =>
+      And(replaceNullWithFalse(left), replaceNullWithFalse(right))
+    case Or(left, right) =>
+      Or(replaceNullWithFalse(left), replaceNullWithFalse(right))
+    case Literal(null, _) => FalseLiteral
+    case _ => e
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/bc9f9b4d/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala
new file mode 100644
index 0000000..c6b5d0e
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala
@@ -0,0 +1,323 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{And, CaseWhen, Expression, GreaterThan, If, Literal, Or}
+import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
+import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.{BooleanType, IntegerType}
+
+class ReplaceNullWithFalseSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("Replace null literals", FixedPoint(10),
+        NullPropagation,
+        ConstantFolding,
+        BooleanSimplification,
+        SimplifyConditionals,
+        ReplaceNullWithFalse) :: Nil
+  }
+
+  private val testRelation = LocalRelation('i.int, 'b.boolean)
+  private val anotherTestRelation = LocalRelation('d.int)
+
+  test("replace null inside filter and join conditions") {
+    testFilter(originalCond = Literal(null), expectedCond = FalseLiteral)
+    testJoin(originalCond = Literal(null), expectedCond = FalseLiteral)
+  }
+
+  test("replace null in branches of If") {
+    val originalCond = If(
+      UnresolvedAttribute("i") > Literal(10),
+      FalseLiteral,
+      Literal(null, BooleanType))
+    testFilter(originalCond, expectedCond = FalseLiteral)
+    testJoin(originalCond, expectedCond = FalseLiteral)
+  }
+
+  test("replace nulls in nested expressions in branches of If") {
+    val originalCond = If(
+      UnresolvedAttribute("i") > Literal(10),
+      TrueLiteral && Literal(null, BooleanType),
+      UnresolvedAttribute("b") && Literal(null, BooleanType))
+    testFilter(originalCond, expectedCond = FalseLiteral)
+    testJoin(originalCond, expectedCond = FalseLiteral)
+  }
+
+  test("replace null in elseValue of CaseWhen") {
+    val branches = Seq(
+      (UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral,
+      (UnresolvedAttribute("i") > Literal(40)) -> FalseLiteral)
+    val originalCond = CaseWhen(branches, Literal(null, BooleanType))
+    val expectedCond = CaseWhen(branches, FalseLiteral)
+    testFilter(originalCond, expectedCond)
+    testJoin(originalCond, expectedCond)
+  }
+
+  test("replace null in branch values of CaseWhen") {
+    val branches = Seq(
+      (UnresolvedAttribute("i") < Literal(10)) -> Literal(null, BooleanType),
+      (UnresolvedAttribute("i") > Literal(40)) -> FalseLiteral)
+    val originalCond = CaseWhen(branches, Literal(null))
+    testFilter(originalCond, expectedCond = FalseLiteral)
+    testJoin(originalCond, expectedCond = FalseLiteral)
+  }
+
+  test("replace null in branches of If inside CaseWhen") {
+    val originalBranches = Seq(
+      (UnresolvedAttribute("i") < Literal(10)) ->
+        If(UnresolvedAttribute("i") < Literal(20), Literal(null, BooleanType), FalseLiteral),
+      (UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral)
+    val originalCond = CaseWhen(originalBranches)
+
+    val expectedBranches = Seq(
+      (UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral,
+      (UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral)
+    val expectedCond = CaseWhen(expectedBranches)
+
+    testFilter(originalCond, expectedCond)
+    testJoin(originalCond, expectedCond)
+  }
+
+  test("replace null in complex CaseWhen expressions") {
+    val originalBranches = Seq(
+      (UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral,
+      (Literal(6) <= Literal(1)) -> FalseLiteral,
+      (Literal(4) === Literal(5)) -> FalseLiteral,
+      (UnresolvedAttribute("i") > Literal(10)) -> Literal(null, BooleanType),
+      (Literal(4) === Literal(4)) -> TrueLiteral)
+    val originalCond = CaseWhen(originalBranches)
+
+    val expectedBranches = Seq(
+      (UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral,
+      (UnresolvedAttribute("i") > Literal(10)) -> FalseLiteral,
+      TrueLiteral -> TrueLiteral)
+    val expectedCond = CaseWhen(expectedBranches)
+
+    testFilter(originalCond, expectedCond)
+    testJoin(originalCond, expectedCond)
+  }
+
+  test("replace null in Or") {
+    val originalCond = Or(UnresolvedAttribute("b"), Literal(null))
+    val expectedCond = UnresolvedAttribute("b")
+    testFilter(originalCond, expectedCond)
+    testJoin(originalCond, expectedCond)
+  }
+
+  test("replace null in And") {
+    val originalCond = And(UnresolvedAttribute("b"), Literal(null))
+    testFilter(originalCond, expectedCond = FalseLiteral)
+    testJoin(originalCond, expectedCond = FalseLiteral)
+  }
+
+  test("replace nulls in nested And/Or expressions") {
+    val originalCond = And(
+      And(UnresolvedAttribute("b"), Literal(null)),
+      Or(Literal(null), And(Literal(null), And(UnresolvedAttribute("b"), Literal(null)))))
+    testFilter(originalCond, expectedCond = FalseLiteral)
+    testJoin(originalCond, expectedCond = FalseLiteral)
+  }
+
+  test("replace null in And inside branches of If") {
+    val originalCond = If(
+      UnresolvedAttribute("i") > Literal(10),
+      FalseLiteral,
+      And(UnresolvedAttribute("b"), Literal(null, BooleanType)))
+    testFilter(originalCond, expectedCond = FalseLiteral)
+    testJoin(originalCond, expectedCond = FalseLiteral)
+  }
+
+  test("replace null in branches of If inside And") {
+    val originalCond = And(
+      UnresolvedAttribute("b"),
+      If(
+        UnresolvedAttribute("i") > Literal(10),
+        Literal(null),
+        And(FalseLiteral, UnresolvedAttribute("b"))))
+    testFilter(originalCond, expectedCond = FalseLiteral)
+    testJoin(originalCond, expectedCond = FalseLiteral)
+  }
+
+  test("replace null in branches of If inside another If") {
+    val originalCond = If(
+      If(UnresolvedAttribute("b"), Literal(null), FalseLiteral),
+      TrueLiteral,
+      Literal(null))
+    testFilter(originalCond, expectedCond = FalseLiteral)
+    testJoin(originalCond, expectedCond = FalseLiteral)
+  }
+
+  test("replace null in CaseWhen inside another CaseWhen") {
+    val nestedCaseWhen = CaseWhen(Seq(UnresolvedAttribute("b") -> FalseLiteral), Literal(null))
+    val originalCond = CaseWhen(Seq(nestedCaseWhen -> TrueLiteral), Literal(null))
+    testFilter(originalCond, expectedCond = FalseLiteral)
+    testJoin(originalCond, expectedCond = FalseLiteral)
+  }
+
+  test("inability to replace null in non-boolean branches of If") {
+    val condition = If(
+      UnresolvedAttribute("i") > Literal(10),
+      Literal(5) > If(
+        UnresolvedAttribute("i") === Literal(15),
+        Literal(null, IntegerType),
+        Literal(3)),
+      FalseLiteral)
+    testFilter(originalCond = condition, expectedCond = condition)
+    testJoin(originalCond = condition, expectedCond = condition)
+  }
+
+  test("inability to replace null in non-boolean values of CaseWhen") {
+    val nestedCaseWhen = CaseWhen(
+      Seq((UnresolvedAttribute("i") > Literal(20)) -> Literal(2)),
+      Literal(null, IntegerType))
+    val branchValue = If(
+      Literal(2) === nestedCaseWhen,
+      TrueLiteral,
+      FalseLiteral)
+    val branches = Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue)
+    val condition = CaseWhen(branches)
+    testFilter(originalCond = condition, expectedCond = condition)
+    testJoin(originalCond = condition, expectedCond = condition)
+  }
+
+  test("inability to replace null in non-boolean branches of If inside another 
If") {
+    val condition = If(
+      Literal(5) > If(
+        UnresolvedAttribute("i") === Literal(15),
+        Literal(null, IntegerType),
+        Literal(3)),
+      TrueLiteral,
+      FalseLiteral)
+    testFilter(originalCond = condition, expectedCond = condition)
+    testJoin(originalCond = condition, expectedCond = condition)
+  }
+
+  test("replace null in If used as a join condition") {
+    // this test is only for joins as the condition involves columns from different relations
+    val originalCond = If(
+      UnresolvedAttribute("d") > UnresolvedAttribute("i"),
+      Literal(null),
+      FalseLiteral)
+    testJoin(originalCond, expectedCond = FalseLiteral)
+  }
+
+  test("replace null in CaseWhen used as a join condition") {
+    // this test is only for joins as the condition involves columns from different relations
+    val originalBranches = Seq(
+      (UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> Literal(null),
+      (UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral)
+
+    val expectedBranches = Seq(
+      (UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> FalseLiteral,
+      (UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral)
+
+    testJoin(
+      originalCond = CaseWhen(originalBranches, FalseLiteral),
+      expectedCond = CaseWhen(expectedBranches, FalseLiteral))
+  }
+
+  test("inability to replace null in CaseWhen inside EqualTo used as a join 
condition") {
+    // this test is only for joins as the condition involves columns from 
different relations
+    val branches = Seq(
+      (UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> Literal(null, BooleanType),
+      (UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral)
+    val condition = UnresolvedAttribute("b") === CaseWhen(branches, FalseLiteral)
+    testJoin(originalCond = condition, expectedCond = condition)
+  }
+
+  test("replace null in predicates of If") {
+    val predicate = And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null))
+    testProjection(
+      originalExpr = If(predicate, Literal(5), Literal(1)).as("out"),
+      expectedExpr = Literal(1).as("out"))
+  }
+
+  test("replace null in predicates of If inside another If") {
+    val predicate = If(
+      And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null)),
+      TrueLiteral,
+      FalseLiteral)
+    testProjection(
+      originalExpr = If(predicate, Literal(5), Literal(1)).as("out"),
+      expectedExpr = Literal(1).as("out"))
+  }
+
+  test("inability to replace null in non-boolean expressions inside If 
predicates") {
+    val predicate = GreaterThan(
+      UnresolvedAttribute("i"),
+      If(UnresolvedAttribute("b"), Literal(null, IntegerType), Literal(4)))
+    val column = If(predicate, Literal(5), Literal(1)).as("out")
+    testProjection(originalExpr = column, expectedExpr = column)
+  }
+
+  test("replace null in conditions of CaseWhen") {
+    val branches = Seq(
+      And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null)) -> Literal(5))
+    testProjection(
+      originalExpr = CaseWhen(branches, Literal(2)).as("out"),
+      expectedExpr = Literal(2).as("out"))
+  }
+
+  test("replace null in conditions of CaseWhen inside another CaseWhen") {
+    val nestedCaseWhen = CaseWhen(
+      Seq(And(UnresolvedAttribute("b"), Literal(null)) -> Literal(5)),
+      Literal(2))
+    val branches = Seq(GreaterThan(Literal(3), nestedCaseWhen) -> Literal(1))
+    testProjection(
+      originalExpr = CaseWhen(branches).as("out"),
+      expectedExpr = Literal(1).as("out"))
+  }
+
+  test("inability to replace null in non-boolean exprs inside CaseWhen 
conditions") {
+    val condition = GreaterThan(
+      UnresolvedAttribute("i"),
+      If(UnresolvedAttribute("b"), Literal(null, IntegerType), Literal(4)))
+    val column = CaseWhen(Seq(condition -> Literal(5)), Literal(2)).as("out")
+    testProjection(originalExpr = column, expectedExpr = column)
+  }
+
+  private def testFilter(originalCond: Expression, expectedCond: Expression): Unit = {
+    test((rel, exp) => rel.where(exp), originalCond, expectedCond)
+  }
+
+  private def testJoin(originalCond: Expression, expectedCond: Expression): Unit = {
+    test((rel, exp) => rel.join(anotherTestRelation, Inner, Some(exp)), originalCond, expectedCond)
+  }
+
+  private def testProjection(originalExpr: Expression, expectedExpr: Expression): Unit = {
+    test((rel, exp) => rel.select(exp), originalExpr, expectedExpr)
+  }
+
+  private def test(
+      func: (LogicalPlan, Expression) => LogicalPlan,
+      originalExpr: Expression,
+      expectedExpr: Expression): Unit = {
+
+    val originalPlan = func(testRelation, originalExpr).analyze
+    val optimizedPlan = Optimize.execute(originalPlan)
+    val expectedPlan = func(testRelation, expectedExpr).analyze
+    comparePlans(optimizedPlan, expectedPlan)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/bc9f9b4d/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index a430884..4afae56 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -31,14 +31,14 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.expressions.Uuid
 import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union}
+import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union}
 import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext}
-import org.apache.spark.sql.test.SQLTestData.{NullInts, NullStrings, TestData2}
+import org.apache.spark.sql.test.SQLTestData.{NullStrings, TestData2}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 import org.apache.spark.util.random.XORShiftRandom

http://git-wip-us.apache.org/repos/asf/spark/blob/bc9f9b4d/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala
new file mode 100644
index 0000000..fc6ecc4
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.{CaseWhen, If}
+import org.apache.spark.sql.execution.LocalTableScanExec
+import org.apache.spark.sql.functions.{lit, when}
+import org.apache.spark.sql.test.SharedSQLContext
+
+class ReplaceNullWithFalseEndToEndSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
+
+  test("SPARK-25860: Replace Literal(null, _) with FalseLiteral whenever 
possible") {
+    withTable("t1", "t2") {
+      Seq((1, true), (2, false)).toDF("l", "b").write.saveAsTable("t1")
+      Seq(2, 3).toDF("l").write.saveAsTable("t2")
+      val df1 = spark.table("t1")
+      val df2 = spark.table("t2")
+
+      val q1 = df1.where("IF(l > 10, false, b AND null)")
+      checkAnswer(q1, Seq.empty)
+      checkPlanIsEmptyLocalScan(q1)
+
+      val q2 = df1.where("CASE WHEN l < 10 THEN null WHEN l > 40 THEN false ELSE null END")
+      checkAnswer(q2, Seq.empty)
+      checkPlanIsEmptyLocalScan(q2)
+
+      val q3 = df1.join(df2, when(df1("l") > df2("l"), lit(null)).otherwise(df1("b") && lit(null)))
+      checkAnswer(q3, Seq.empty)
+      checkPlanIsEmptyLocalScan(q3)
+
+      val q4 = df1.where("IF(IF(b, null, false), true, null)")
+      checkAnswer(q4, Seq.empty)
+      checkPlanIsEmptyLocalScan(q4)
+
+      val q5 = df1.selectExpr("IF(l > 1 AND null, 5, 1) AS out")
+      checkAnswer(q5, Row(1) :: Row(1) :: Nil)
+      q5.queryExecution.executedPlan.foreach { p =>
+        assert(p.expressions.forall(e => e.find(_.isInstanceOf[If]).isEmpty))
+      }
+
+      val q6 = df1.selectExpr("CASE WHEN (l > 2 AND null) THEN 3 ELSE 2 END")
+      checkAnswer(q6, Row(2) :: Row(2) :: Nil)
+      q6.queryExecution.executedPlan.foreach { p =>
+        assert(p.expressions.forall(e => e.find(_.isInstanceOf[CaseWhen]).isEmpty))
+      }
+
+      checkAnswer(df1.where("IF(l > 10, false, b OR null)"), Row(1, true))
+    }
+
+    def checkPlanIsEmptyLocalScan(df: DataFrame): Unit = df.queryExecution.executedPlan match {
+      case s: LocalTableScanExec => assert(s.rows.isEmpty)
+      case p => fail(s"$p is not LocalTableScanExec")
+    }
+  }
+}

