Repository: spark
Updated Branches:
refs/heads/master 68dde3481 -> bc9f9b4d6
[SPARK-25860][SQL] Replace Literal(null, _) with FalseLiteral whenever possible
## What changes were proposed in this pull request?
This PR proposes a new optimization rule that replaces `Literal(null, _)` with
`FalseLiteral` in conditions of `Join` and `Filter`, predicates of `If`, and
branch conditions of `CaseWhen`.
The idea is that in these positions a `null` value is handled exactly like
`false` (see, for example, `GeneratePredicate$create` or the `doGenCode` and
`eval` methods in `If` and `CaseWhen`). Therefore, we can replace
`Literal(null, _)` with `FalseLiteral` there, which can enable more
optimizations later on.
Let's consider a few examples.
```
val df = spark.range(1, 100).select($"id".as("l"), $"id".cast("string").as("s"))
df.createOrReplaceTempView("t")
df.createOrReplaceTempView("p")
```
**Case 1**
```
spark.sql("SELECT * FROM t WHERE if(l > 10, false, NULL)").explain(true)
// without the new rule
...
== Optimized Logical Plan ==
Project [id#0L AS l#2L, cast(id#0L as string) AS s#3]
+- Filter if ((id#0L > 10)) false else null
+- Range (1, 100, step=1, splits=Some(12))
== Physical Plan ==
*(1) Project [id#0L AS l#2L, cast(id#0L as string) AS s#3]
+- *(1) Filter if ((id#0L > 10)) false else null
+- *(1) Range (1, 100, step=1, splits=12)
// with the new rule
...
== Optimized Logical Plan ==
LocalRelation <empty>, [l#2L, s#3]
== Physical Plan ==
LocalTableScan <empty>, [l#2L, s#3]
```
**Case 2**
```
spark.sql("SELECT * FROM t WHERE CASE WHEN l < 10 THEN null WHEN l > 40 THEN false ELSE null END").explain(true)
// without the new rule
...
== Optimized Logical Plan ==
Project [id#0L AS l#2L, cast(id#0L as string) AS s#3]
+- Filter CASE WHEN (id#0L < 10) THEN null WHEN (id#0L > 40) THEN false ELSE null END
+- Range (1, 100, step=1, splits=Some(12))
== Physical Plan ==
*(1) Project [id#0L AS l#2L, cast(id#0L as string) AS s#3]
+- *(1) Filter CASE WHEN (id#0L < 10) THEN null WHEN (id#0L > 40) THEN false ELSE null END
+- *(1) Range (1, 100, step=1, splits=12)
// with the new rule
...
== Optimized Logical Plan ==
LocalRelation <empty>, [l#2L, s#3]
== Physical Plan ==
LocalTableScan <empty>, [l#2L, s#3]
```
**Case 3**
```
spark.sql("SELECT * FROM t JOIN p ON IF(t.l > p.l, null, false)").explain(true)
// without the new rule
...
== Optimized Logical Plan ==
Join Inner, if ((l#2L > l#37L)) null else false
:- Project [id#0L AS l#2L, cast(id#0L as string) AS s#3]
: +- Range (1, 100, step=1, splits=Some(12))
+- Project [id#0L AS l#37L, cast(id#0L as string) AS s#38]
+- Range (1, 100, step=1, splits=Some(12))
== Physical Plan ==
BroadcastNestedLoopJoin BuildRight, Inner, if ((l#2L > l#37L)) null else false
:- *(1) Project [id#0L AS l#2L, cast(id#0L as string) AS s#3]
: +- *(1) Range (1, 100, step=1, splits=12)
+- BroadcastExchange IdentityBroadcastMode
+- *(2) Project [id#0L AS l#37L, cast(id#0L as string) AS s#38]
+- *(2) Range (1, 100, step=1, splits=12)
// with the new rule
...
== Optimized Logical Plan ==
LocalRelation <empty>, [l#2L, s#3, l#37L, s#38]
```
## How was this patch tested?
This PR comes with a set of dedicated tests.
Closes #22857 from aokolnychyi/spark-25860.
Authored-by: Anton Okolnychyi <[email protected]>
Signed-off-by: DB Tsai <[email protected]>
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bc9f9b4d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bc9f9b4d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bc9f9b4d
Branch: refs/heads/master
Commit: bc9f9b4d6e6ac983a903a0b9a3a668950dc0b2a7
Parents: 68dde34
Author: Anton Okolnychyi <[email protected]>
Authored: Wed Oct 31 18:35:33 2018 +0000
Committer: DB Tsai <[email protected]>
Committed: Wed Oct 31 18:35:33 2018 +0000
----------------------------------------------------------------------
.../sql/catalyst/optimizer/Optimizer.scala | 1 +
.../sql/catalyst/optimizer/expressions.scala | 57 ++++
.../optimizer/ReplaceNullWithFalseSuite.scala | 323 +++++++++++++++++++
.../org/apache/spark/sql/DataFrameSuite.scala | 4 +-
.../sql/ReplaceNullWithFalseEndToEndSuite.scala | 71 ++++
5 files changed, 454 insertions(+), 2 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/bc9f9b4d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 95455ff..a330a84 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -84,6 +84,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
SimplifyConditionals,
RemoveDispensableExpressions,
SimplifyBinaryComparison,
+ ReplaceNullWithFalse,
PruneFilters,
EliminateSorts,
SimplifyCasts,
http://git-wip-us.apache.org/repos/asf/spark/blob/bc9f9b4d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 468a950..2b29b49 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -736,3 +736,60 @@ object CombineConcats extends Rule[LogicalPlan] {
flattenConcats(concat)
}
}
+
+/**
+ * A rule that replaces `Literal(null, _)` with `FalseLiteral` for further optimizations.
+ *
+ * This rule applies to conditions in [[Filter]] and [[Join]]. Moreover, it transforms predicates
+ * in all [[If]] expressions as well as branch conditions in all [[CaseWhen]] expressions.
+ *
+ * For example, `Filter(Literal(null, _))` is equal to `Filter(FalseLiteral)`.
+ *
+ * Another example containing branches is `Filter(If(cond, FalseLiteral, Literal(null, _)))`;
+ * this can be optimized to `Filter(If(cond, FalseLiteral, FalseLiteral))`, and eventually
+ * `Filter(FalseLiteral)`.
+ *
+ * As this rule is not limited to conditions in [[Filter]] and [[Join]], arbitrary plans can
+ * benefit from it. For example, `Project(If(And(cond, Literal(null)), Literal(1), Literal(2)))`
+ * can be simplified into `Project(Literal(2))`.
+ *
+ * As a result, many unnecessary computations can be removed in the query optimization phase.
+ *
+ * The rewrite is only applied where a `null` result is handled exactly like `false`:
+ * a [[Filter]] or [[Join]] condition, the predicate of an [[If]], or a branch condition of a
+ * [[CaseWhen]]. Branch values of an `If`/`CaseWhen` are only rewritten when that boolean-typed
+ * `If`/`CaseWhen` itself sits in such a position (see `replaceNullWithFalse`).
+ *
+ * NOTE(review): `transform` applies only the first matching case to each plan node, so for a
+ * [[Filter]], or a [[Join]] with a condition, the generic `transformExpressions` case below is
+ * not applied. An [[If]] or [[CaseWhen]] nested under a non-boolean operator inside such a
+ * condition (e.g. `a > If(...)`) is therefore not rewritten by this rule.
+ */
+object ReplaceNullWithFalse extends Rule[LogicalPlan] {
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond)) // whole condition
+ case j @ Join(_, _, _, Some(cond)) => j.copy(condition =
Some(replaceNullWithFalse(cond)))
+ case p: LogicalPlan => p transformExpressions { // other nodes: If predicates, CaseWhen conds
+ case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred))
+ case cw @ CaseWhen(branches, _) =>
+ val newBranches = branches.map { case (cond, value) =>
+ replaceNullWithFalse(cond) -> value // a branch value may legitimately be null; keep it
+ }
+ cw.copy(branches = newBranches)
+ }
+ }
+
+ /**
+ * Recursively replaces `Literal(null, _)` with `FalseLiteral`.
+ *
+ * Note that `transformExpressionsDown` can not be used here as we must stop as soon as we hit
+ * an expression that is not [[CaseWhen]], [[If]], [[And]], [[Or]] or `Literal(null, _)`.
+ * Below any other expression (e.g. a comparison or `Not`), a `null` operand and a `false`
+ * operand can lead to different results, so nothing below such an expression is rewritten.
+ *
+ * [[And]] and [[Or]] are safe to descend into: replacing a `null` operand with `false` can only
+ * turn a `null` result into `false`; it never changes whether the result is `true`.
+ * Boolean-typed [[If]]/[[CaseWhen]] nodes return one of their branch values as their own
+ * result, so those values are in the same kind of position as the node itself. Their
+ * predicates/conditions are also safe, because a `null` condition does not select its branch,
+ * just like `false`.
+ *
+ * @param e an expression whose result is used in a position where `null` behaves like `false`
+ * @return `e` with the null literals reachable through CaseWhen/If/And/Or replaced by
+ * `FalseLiteral`
+ */
+ private def replaceNullWithFalse(e: Expression): Expression = e match {
+ case cw: CaseWhen if cw.dataType == BooleanType =>
+ val newBranches = cw.branches.map { case (cond, value) =>
+ replaceNullWithFalse(cond) -> replaceNullWithFalse(value)
+ }
+ val newElseValue = cw.elseValue.map(replaceNullWithFalse)
+ CaseWhen(newBranches, newElseValue)
+ case i @ If(pred, trueVal, falseVal) if i.dataType == BooleanType =>
+ If(replaceNullWithFalse(pred), replaceNullWithFalse(trueVal),
replaceNullWithFalse(falseVal))
+ case And(left, right) =>
+ And(replaceNullWithFalse(left), replaceNullWithFalse(right))
+ case Or(left, right) =>
+ Or(replaceNullWithFalse(left), replaceNullWithFalse(right))
+ case Literal(null, _) => FalseLiteral // the actual substitution
+ case _ => e // any other expression: stop descending and keep it unchanged
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/bc9f9b4d/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala
new file mode 100644
index 0000000..c6b5d0e
--- /dev/null
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala
@@ -0,0 +1,323 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{And, CaseWhen, Expression,
GreaterThan, If, Literal, Or}
+import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral,
TrueLiteral}
+import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.{BooleanType, IntegerType}
+
+class ReplaceNullWithFalseSuite extends PlanTest { // unit tests for the ReplaceNullWithFalse rule
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Replace null literals", FixedPoint(10), // the rule plus simplifications it enables
+ NullPropagation,
+ ConstantFolding,
+ BooleanSimplification,
+ SimplifyConditionals,
+ ReplaceNullWithFalse) :: Nil
+ }
+
+ private val testRelation = LocalRelation('i.int, 'b.boolean) // filter/projection input, join left
+ private val anotherTestRelation = LocalRelation('d.int) // join right side; provides column d
+
+ test("replace null inside filter and join conditions") {
+ testFilter(originalCond = Literal(null), expectedCond = FalseLiteral)
+ testJoin(originalCond = Literal(null), expectedCond = FalseLiteral)
+ }
+
+ test("replace null in branches of If") {
+ val originalCond = If(
+ UnresolvedAttribute("i") > Literal(10),
+ FalseLiteral,
+ Literal(null, BooleanType))
+ testFilter(originalCond, expectedCond = FalseLiteral)
+ testJoin(originalCond, expectedCond = FalseLiteral)
+ }
+
+ test("replace nulls in nested expressions in branches of If") {
+ val originalCond = If(
+ UnresolvedAttribute("i") > Literal(10),
+ TrueLiteral && Literal(null, BooleanType),
+ UnresolvedAttribute("b") && Literal(null, BooleanType))
+ testFilter(originalCond, expectedCond = FalseLiteral)
+ testJoin(originalCond, expectedCond = FalseLiteral)
+ }
+
+ test("replace null in elseValue of CaseWhen") {
+ val branches = Seq(
+ (UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral,
+ (UnresolvedAttribute("i") > Literal(40)) -> FalseLiteral)
+ val originalCond = CaseWhen(branches, Literal(null, BooleanType))
+ val expectedCond = CaseWhen(branches, FalseLiteral)
+ testFilter(originalCond, expectedCond)
+ testJoin(originalCond, expectedCond)
+ }
+
+ test("replace null in branch values of CaseWhen") {
+ val branches = Seq(
+ (UnresolvedAttribute("i") < Literal(10)) -> Literal(null, BooleanType),
+ (UnresolvedAttribute("i") > Literal(40)) -> FalseLiteral)
+ val originalCond = CaseWhen(branches, Literal(null))
+ testFilter(originalCond, expectedCond = FalseLiteral)
+ testJoin(originalCond, expectedCond = FalseLiteral)
+ }
+
+ test("replace null in branches of If inside CaseWhen") {
+ val originalBranches = Seq(
+ (UnresolvedAttribute("i") < Literal(10)) ->
+ If(UnresolvedAttribute("i") < Literal(20), Literal(null, BooleanType),
FalseLiteral),
+ (UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral)
+ val originalCond = CaseWhen(originalBranches)
+
+ val expectedBranches = Seq(
+ (UnresolvedAttribute("i") < Literal(10)) -> FalseLiteral,
+ (UnresolvedAttribute("i") > Literal(40)) -> TrueLiteral)
+ val expectedCond = CaseWhen(expectedBranches)
+
+ testFilter(originalCond, expectedCond)
+ testJoin(originalCond, expectedCond)
+ }
+
+ test("replace null in complex CaseWhen expressions") { // constant conditions get folded/pruned
+ val originalBranches = Seq(
+ (UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral,
+ (Literal(6) <= Literal(1)) -> FalseLiteral,
+ (Literal(4) === Literal(5)) -> FalseLiteral,
+ (UnresolvedAttribute("i") > Literal(10)) -> Literal(null, BooleanType),
+ (Literal(4) === Literal(4)) -> TrueLiteral)
+ val originalCond = CaseWhen(originalBranches)
+
+ val expectedBranches = Seq(
+ (UnresolvedAttribute("i") < Literal(10)) -> TrueLiteral,
+ (UnresolvedAttribute("i") > Literal(10)) -> FalseLiteral,
+ TrueLiteral -> TrueLiteral)
+ val expectedCond = CaseWhen(expectedBranches)
+
+ testFilter(originalCond, expectedCond)
+ testJoin(originalCond, expectedCond)
+ }
+
+ test("replace null in Or") {
+ val originalCond = Or(UnresolvedAttribute("b"), Literal(null))
+ val expectedCond = UnresolvedAttribute("b")
+ testFilter(originalCond, expectedCond)
+ testJoin(originalCond, expectedCond)
+ }
+
+ test("replace null in And") {
+ val originalCond = And(UnresolvedAttribute("b"), Literal(null))
+ testFilter(originalCond, expectedCond = FalseLiteral)
+ testJoin(originalCond, expectedCond = FalseLiteral)
+ }
+
+ test("replace nulls in nested And/Or expressions") {
+ val originalCond = And(
+ And(UnresolvedAttribute("b"), Literal(null)),
+ Or(Literal(null), And(Literal(null), And(UnresolvedAttribute("b"),
Literal(null)))))
+ testFilter(originalCond, expectedCond = FalseLiteral)
+ testJoin(originalCond, expectedCond = FalseLiteral)
+ }
+
+ test("replace null in And inside branches of If") {
+ val originalCond = If(
+ UnresolvedAttribute("i") > Literal(10),
+ FalseLiteral,
+ And(UnresolvedAttribute("b"), Literal(null, BooleanType)))
+ testFilter(originalCond, expectedCond = FalseLiteral)
+ testJoin(originalCond, expectedCond = FalseLiteral)
+ }
+
+ test("replace null in branches of If inside And") {
+ val originalCond = And(
+ UnresolvedAttribute("b"),
+ If(
+ UnresolvedAttribute("i") > Literal(10),
+ Literal(null),
+ And(FalseLiteral, UnresolvedAttribute("b"))))
+ testFilter(originalCond, expectedCond = FalseLiteral)
+ testJoin(originalCond, expectedCond = FalseLiteral)
+ }
+
+ test("replace null in branches of If inside another If") {
+ val originalCond = If(
+ If(UnresolvedAttribute("b"), Literal(null), FalseLiteral),
+ TrueLiteral,
+ Literal(null))
+ testFilter(originalCond, expectedCond = FalseLiteral)
+ testJoin(originalCond, expectedCond = FalseLiteral)
+ }
+
+ test("replace null in CaseWhen inside another CaseWhen") {
+ val nestedCaseWhen = CaseWhen(Seq(UnresolvedAttribute("b") ->
FalseLiteral), Literal(null))
+ val originalCond = CaseWhen(Seq(nestedCaseWhen -> TrueLiteral),
Literal(null))
+ testFilter(originalCond, expectedCond = FalseLiteral)
+ testJoin(originalCond, expectedCond = FalseLiteral)
+ }
+
+ test("inability to replace null in non-boolean branches of If") { // Int null under `>`: kept
+ val condition = If(
+ UnresolvedAttribute("i") > Literal(10),
+ Literal(5) > If(
+ UnresolvedAttribute("i") === Literal(15),
+ Literal(null, IntegerType),
+ Literal(3)),
+ FalseLiteral)
+ testFilter(originalCond = condition, expectedCond = condition)
+ testJoin(originalCond = condition, expectedCond = condition)
+ }
+
+ test("inability to replace null in non-boolean values of CaseWhen") { // Int-typed CaseWhen: kept
+ val nestedCaseWhen = CaseWhen(
+ Seq((UnresolvedAttribute("i") > Literal(20)) -> Literal(2)),
+ Literal(null, IntegerType))
+ val branchValue = If(
+ Literal(2) === nestedCaseWhen,
+ TrueLiteral,
+ FalseLiteral)
+ val branches = Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue)
+ val condition = CaseWhen(branches)
+ testFilter(originalCond = condition, expectedCond = condition)
+ testJoin(originalCond = condition, expectedCond = condition)
+ }
+
+ test("inability to replace null in non-boolean branches of If inside another
If") {
+ val condition = If(
+ Literal(5) > If(
+ UnresolvedAttribute("i") === Literal(15),
+ Literal(null, IntegerType),
+ Literal(3)),
+ TrueLiteral,
+ FalseLiteral)
+ testFilter(originalCond = condition, expectedCond = condition)
+ testJoin(originalCond = condition, expectedCond = condition)
+ }
+
+ test("replace null in If used as a join condition") {
+ // this test is only for joins as the condition involves columns from
different relations
+ val originalCond = If(
+ UnresolvedAttribute("d") > UnresolvedAttribute("i"),
+ Literal(null),
+ FalseLiteral)
+ testJoin(originalCond, expectedCond = FalseLiteral)
+ }
+
+ test("replace null in CaseWhen used as a join condition") {
+ // this test is only for joins as the condition involves columns from
different relations
+ val originalBranches = Seq(
+ (UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> Literal(null),
+ (UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral)
+
+ val expectedBranches = Seq(
+ (UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> FalseLiteral,
+ (UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral)
+
+ testJoin(
+ originalCond = CaseWhen(originalBranches, FalseLiteral),
+ expectedCond = CaseWhen(expectedBranches, FalseLiteral))
+ }
+
+ test("inability to replace null in CaseWhen inside EqualTo used as a join
condition") { // the rule does not descend into EqualTo
+ // this test is only for joins as the condition involves columns from
different relations
+ val branches = Seq(
+ (UnresolvedAttribute("d") > UnresolvedAttribute("i")) -> Literal(null,
BooleanType),
+ (UnresolvedAttribute("d") === UnresolvedAttribute("i")) -> TrueLiteral)
+ val condition = UnresolvedAttribute("b") === CaseWhen(branches,
FalseLiteral)
+ testJoin(originalCond = condition, expectedCond = condition)
+ }
+
+ test("replace null in predicates of If") { // If predicate inside a Project column
+ val predicate = And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)),
Literal(null))
+ testProjection(
+ originalExpr = If(predicate, Literal(5), Literal(1)).as("out"),
+ expectedExpr = Literal(1).as("out"))
+ }
+
+ test("replace null in predicates of If inside another If") {
+ val predicate = If(
+ And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null)),
+ TrueLiteral,
+ FalseLiteral)
+ testProjection(
+ originalExpr = If(predicate, Literal(5), Literal(1)).as("out"),
+ expectedExpr = Literal(1).as("out"))
+ }
+
+ test("inability to replace null in non-boolean expressions inside If
predicates") { // Int null under `>`: kept
+ val predicate = GreaterThan(
+ UnresolvedAttribute("i"),
+ If(UnresolvedAttribute("b"), Literal(null, IntegerType), Literal(4)))
+ val column = If(predicate, Literal(5), Literal(1)).as("out")
+ testProjection(originalExpr = column, expectedExpr = column)
+ }
+
+ test("replace null in conditions of CaseWhen") { // CaseWhen condition inside a Project column
+ val branches = Seq(
+ And(GreaterThan(UnresolvedAttribute("i"), Literal(0.5)), Literal(null))
-> Literal(5))
+ testProjection(
+ originalExpr = CaseWhen(branches, Literal(2)).as("out"),
+ expectedExpr = Literal(2).as("out"))
+ }
+
+ test("replace null in conditions of CaseWhen inside another CaseWhen") {
+ val nestedCaseWhen = CaseWhen(
+ Seq(And(UnresolvedAttribute("b"), Literal(null)) -> Literal(5)),
+ Literal(2))
+ val branches = Seq(GreaterThan(Literal(3), nestedCaseWhen) -> Literal(1))
+ testProjection(
+ originalExpr = CaseWhen(branches).as("out"),
+ expectedExpr = Literal(1).as("out"))
+ }
+
+ test("inability to replace null in non-boolean exprs inside CaseWhen
conditions") { // Int null under `>`: kept
+ val condition = GreaterThan(
+ UnresolvedAttribute("i"),
+ If(UnresolvedAttribute("b"), Literal(null, IntegerType), Literal(4)))
+ val column = CaseWhen(Seq(condition -> Literal(5)), Literal(2)).as("out")
+ testProjection(originalExpr = column, expectedExpr = column)
+ }
+
+ private def testFilter(originalCond: Expression, expectedCond: Expression):
Unit = { // checks the expression as a Filter condition on testRelation
+ test((rel, exp) => rel.where(exp), originalCond, expectedCond)
+ }
+
+ private def testJoin(originalCond: Expression, expectedCond: Expression):
Unit = { // checks it as an inner-join condition with anotherTestRelation
+ test((rel, exp) => rel.join(anotherTestRelation, Inner, Some(exp)),
originalCond, expectedCond)
+ }
+
+ private def testProjection(originalExpr: Expression, expectedExpr:
Expression): Unit = { // checks it as a projected column of testRelation
+ test((rel, exp) => rel.select(exp), originalExpr, expectedExpr)
+ }
+
+ private def test( // overloads the suite's test(name) { ... } registration; resolved by signature
+ func: (LogicalPlan, Expression) => LogicalPlan,
+ originalExpr: Expression,
+ expectedExpr: Expression): Unit = { // original plan is optimized; expected is only analyzed
+
+ val originalPlan = func(testRelation, originalExpr).analyze
+ val optimizedPlan = Optimize.execute(originalPlan)
+ val expectedPlan = func(testRelation, expectedExpr).analyze
+ comparePlans(optimizedPlan, expectedPlan)
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/bc9f9b4d/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index a430884..4afae56 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -31,14 +31,14 @@ import org.apache.spark.scheduler.{SparkListener,
SparkListenerJobEnd}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.Uuid
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation,
Union}
+import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union}
import org.apache.spark.sql.execution.{FilterExec, QueryExecution,
WholeStageCodegenExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec,
ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT,
SharedSQLContext}
-import org.apache.spark.sql.test.SQLTestData.{NullInts, NullStrings, TestData2}
+import org.apache.spark.sql.test.SQLTestData.{NullStrings, TestData2}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
http://git-wip-us.apache.org/repos/asf/spark/blob/bc9f9b4d/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala
new file mode 100644
index 0000000..fc6ecc4
--- /dev/null
+++
b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.{CaseWhen, If}
+import org.apache.spark.sql.execution.LocalTableScanExec
+import org.apache.spark.sql.functions.{lit, when}
+import org.apache.spark.sql.test.SharedSQLContext
+
+class ReplaceNullWithFalseEndToEndSuite extends QueryTest with // end-to-end ReplaceNullWithFalse tests
SharedSQLContext {
+ import testImplicits._
+
+ test("SPARK-25860: Replace Literal(null, _) with FalseLiteral whenever
possible") {
+ withTable("t1", "t2") { // t1: (l=1, b=true), (l=2, b=false); t2: l = 2, 3
+ Seq((1, true), (2, false)).toDF("l", "b").write.saveAsTable("t1")
+ Seq(2, 3).toDF("l").write.saveAsTable("t2")
+ val df1 = spark.table("t1")
+ val df2 = spark.table("t2")
+
+ val q1 = df1.where("IF(l > 10, false, b AND null)") // never true: expect an empty local scan
+ checkAnswer(q1, Seq.empty)
+ checkPlanIsEmptyLocalScan(q1)
+
+ val q2 = df1.where("CASE WHEN l < 10 THEN null WHEN l > 40 THEN false
ELSE null END") // never true either
+ checkAnswer(q2, Seq.empty)
+ checkPlanIsEmptyLocalScan(q2)
+
+ val q3 = df1.join(df2, when(df1("l") > df2("l"),
lit(null)).otherwise(df1("b") && lit(null))) // join condition that is never true
+ checkAnswer(q3, Seq.empty)
+ checkPlanIsEmptyLocalScan(q3)
+
+ val q4 = df1.where("IF(IF(b, null, false), true, null)") // nested If, never true
+ checkAnswer(q4, Seq.empty)
+ checkPlanIsEmptyLocalScan(q4)
+
+ val q5 = df1.selectExpr("IF(l > 1 AND null, 5, 1) AS out") // no If may survive execution
+ checkAnswer(q5, Row(1) :: Row(1) :: Nil)
+ q5.queryExecution.executedPlan.foreach { p =>
+ assert(p.expressions.forall(e => e.find(_.isInstanceOf[If]).isEmpty))
+ }
+
+ val q6 = df1.selectExpr("CASE WHEN (l > 2 AND null) THEN 3 ELSE 2 END") // likewise CaseWhen
+ checkAnswer(q6, Row(2) :: Row(2) :: Nil)
+ q6.queryExecution.executedPlan.foreach { p =>
+ assert(p.expressions.forall(e =>
e.find(_.isInstanceOf[CaseWhen]).isEmpty))
+ }
+
+ checkAnswer(df1.where("IF(l > 10, false, b OR null)"), Row(1, true)) // b = true rows survive
+ }
+
+ def checkPlanIsEmptyLocalScan(df: DataFrame): Unit = // local helper inside this test body
df.queryExecution.executedPlan match {
+ case s: LocalTableScanExec => assert(s.rows.isEmpty)
+ case p => fail(s"$p is not LocalTableScanExec")
+ }
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]