Repository: spark
Updated Branches:
refs/heads/branch-2.4 0f58b989d -> 5554a33f2
[SPARK-25714] Fix Null Handling in the Optimizer rule BooleanSimplification
## What changes were proposed in this pull request?
```Scala
val df1 = Seq(("abc", 1), (null, 3)).toDF("col1", "col2")
df1.write.mode(SaveMode.Overwrite).parquet("/tmp/test1")
val df2 = spark.read.parquet("/tmp/test1")
df2.filter("col1 = 'abc' OR (col1 != 'abc' AND col2 == 3)").show()
```
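Before the PR, it returns both rows. After the fix, it returns only
`Row("abc", 1)`: for the row whose `col1` is NULL, both `col1 = 'abc'` and
`col1 != 'abc'` evaluate to NULL, so the whole predicate is NULL and the row
must be filtered out. This fixes a bug in the NULL handling of
BooleanSimplification that was introduced in the Spark 1.6 release.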
## How was this patch tested?
Added test cases
Closes #22702 from gatorsmile/fixBooleanSimplify2.
Authored-by: gatorsmile <[email protected]>
Signed-off-by: gatorsmile <[email protected]>
(cherry picked from commit c9ba59d38e2be17b802156b49d374a726e66c6b9)
Signed-off-by: gatorsmile <[email protected]>
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5554a33f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5554a33f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5554a33f
Branch: refs/heads/branch-2.4
Commit: 5554a33f2809495d78d396339e87fde311427328
Parents: 0f58b98
Author: gatorsmile <[email protected]>
Authored: Fri Oct 12 21:02:38 2018 -0700
Committer: gatorsmile <[email protected]>
Committed: Fri Oct 12 21:02:53 2018 -0700
----------------------------------------------------------------------
.../sql/catalyst/expressions/predicates.scala | 35 ++++++
.../sql/catalyst/optimizer/expressions.scala | 34 ++++--
.../optimizer/BooleanSimplificationSuite.scala | 111 +++++++++++++++----
.../org/apache/spark/sql/DataFrameSuite.scala | 10 ++
4 files changed, 157 insertions(+), 33 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/5554a33f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 149bd79..7f21a62 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -129,6 +129,13 @@ case class Not(child: Expression)
override def inputTypes: Seq[DataType] = Seq(BooleanType)
+ // +---------+-----------+
+ // | CHILD | NOT CHILD |
+ // +---------+-----------+
+ // | TRUE | FALSE |
+ // | FALSE | TRUE |
+ // | UNKNOWN | UNKNOWN |
+ // +---------+-----------+
protected override def nullSafeEval(input: Any): Any =
!input.asInstanceOf[Boolean]
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -406,6 +413,13 @@ case class And(left: Expression, right: Expression)
extends BinaryOperator with
override def sqlOperator: String = "AND"
+ // +---------+---------+---------+---------+
+ // | AND | TRUE | FALSE | UNKNOWN |
+ // +---------+---------+---------+---------+
+ // | TRUE | TRUE | FALSE | UNKNOWN |
+ // | FALSE | FALSE | FALSE | FALSE |
+ // | UNKNOWN | UNKNOWN | FALSE | UNKNOWN |
+ // +---------+---------+---------+---------+
override def eval(input: InternalRow): Any = {
val input1 = left.eval(input)
if (input1 == false) {
@@ -469,6 +483,13 @@ case class Or(left: Expression, right: Expression) extends
BinaryOperator with P
override def sqlOperator: String = "OR"
+ // +---------+---------+---------+---------+
+ // | OR | TRUE | FALSE | UNKNOWN |
+ // +---------+---------+---------+---------+
+ // | TRUE | TRUE | TRUE | TRUE |
+ // | FALSE | TRUE | FALSE | UNKNOWN |
+ // | UNKNOWN | TRUE | UNKNOWN | UNKNOWN |
+ // +---------+---------+---------+---------+
override def eval(input: InternalRow): Any = {
val input1 = left.eval(input)
if (input1 == true) {
@@ -592,6 +613,13 @@ case class EqualTo(left: Expression, right: Expression)
override def symbol: String = "="
+ // +---------+---------+---------+---------+
+ // | = | TRUE | FALSE | UNKNOWN |
+ // +---------+---------+---------+---------+
+ // | TRUE | TRUE | FALSE | UNKNOWN |
+ // | FALSE | FALSE | TRUE | UNKNOWN |
+ // | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN |
+ // +---------+---------+---------+---------+
protected override def nullSafeEval(left: Any, right: Any): Any =
ordering.equiv(left, right)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -629,6 +657,13 @@ case class EqualNullSafe(left: Expression, right:
Expression) extends BinaryComp
override def nullable: Boolean = false
+ // +---------+---------+---------+---------+
+ // | <=> | TRUE | FALSE | UNKNOWN |
+ // +---------+---------+---------+---------+
+ // | TRUE | TRUE | FALSE | FALSE |
+ // | FALSE | FALSE | TRUE | FALSE |
+ // | UNKNOWN | FALSE | FALSE | TRUE |
+ // +---------+---------+---------+---------+
override def eval(input: InternalRow): Any = {
val input1 = left.eval(input)
val input2 = right.eval(input)
http://git-wip-us.apache.org/repos/asf/spark/blob/5554a33f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index f803758..8459043 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -276,15 +276,31 @@ object BooleanSimplification extends Rule[LogicalPlan]
with PredicateHelper {
case a And b if a.semanticEquals(b) => a
case a Or b if a.semanticEquals(b) => a
- case a And (b Or c) if Not(a).semanticEquals(b) => And(a, c)
- case a And (b Or c) if Not(a).semanticEquals(c) => And(a, b)
- case (a Or b) And c if a.semanticEquals(Not(c)) => And(b, c)
- case (a Or b) And c if b.semanticEquals(Not(c)) => And(a, c)
-
- case a Or (b And c) if Not(a).semanticEquals(b) => Or(a, c)
- case a Or (b And c) if Not(a).semanticEquals(c) => Or(a, b)
- case (a And b) Or c if a.semanticEquals(Not(c)) => Or(b, c)
- case (a And b) Or c if b.semanticEquals(Not(c)) => Or(a, c)
+ // The following optimizations are applicable only when the operands are
not nullable,
+ // since the three-valued logic of AND and OR differs in NULL
handling.
+ // See the chart:
+ // +---------+---------+---------+---------+
+ // | p | q | p OR q | p AND q |
+ // +---------+---------+---------+---------+
+ // | TRUE | TRUE | TRUE | TRUE |
+ // | TRUE | FALSE | TRUE | FALSE |
+ // | TRUE | UNKNOWN | TRUE | UNKNOWN |
+ // | FALSE | TRUE | TRUE | FALSE |
+ // | FALSE | FALSE | FALSE | FALSE |
+ // | FALSE | UNKNOWN | UNKNOWN | FALSE |
+ // | UNKNOWN | TRUE | TRUE | UNKNOWN |
+ // | UNKNOWN | FALSE | UNKNOWN | FALSE |
+ // | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN |
+ // +---------+---------+---------+---------+
+ case a And (b Or c) if !a.nullable && Not(a).semanticEquals(b) => And(a,
c)
+ case a And (b Or c) if !a.nullable && Not(a).semanticEquals(c) => And(a,
b)
+ case (a Or b) And c if !a.nullable && a.semanticEquals(Not(c)) => And(b,
c)
+ case (a Or b) And c if !b.nullable && b.semanticEquals(Not(c)) => And(a,
c)
+
+ case a Or (b And c) if !a.nullable && Not(a).semanticEquals(b) => Or(a,
c)
+ case a Or (b And c) if !a.nullable && Not(a).semanticEquals(c) => Or(a,
b)
+ case (a And b) Or c if !a.nullable && a.semanticEquals(Not(c)) => Or(b,
c)
+ case (a And b) Or c if !b.nullable && b.semanticEquals(Not(c)) => Or(a,
c)
// Common factor elimination for conjunction
case and @ (left And right) =>
http://git-wip-us.apache.org/repos/asf/spark/blob/5554a33f/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
index 6cd1108..a0de5f6 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.BooleanType
-class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
+class BooleanSimplificationSuite extends PlanTest with ExpressionEvalHelper
with PredicateHelper {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
@@ -72,6 +72,14 @@ class BooleanSimplificationSuite extends PlanTest with
PredicateHelper {
}
private def checkConditionInNotNullableRelation(
+ input: Expression, expected: Expression): Unit = {
+ val plan = testNotNullableRelationWithData.where(input).analyze
+ val actual = Optimize.execute(plan)
+ val correctAnswer = testNotNullableRelationWithData.where(expected).analyze
+ comparePlans(actual, correctAnswer)
+ }
+
+ private def checkConditionInNotNullableRelation(
input: Expression, expected: LogicalPlan): Unit = {
val plan = testNotNullableRelationWithData.where(input).analyze
val actual = Optimize.execute(plan)
@@ -119,42 +127,55 @@ class BooleanSimplificationSuite extends PlanTest with
PredicateHelper {
'a === 'b || 'b > 3 && 'a > 3 && 'a < 5)
}
- test("e && (!e || f)") {
- checkCondition('e && (!'e || 'f ), 'e && 'f)
+ test("e && (!e || f) - not nullable") {
+ checkConditionInNotNullableRelation('e && (!'e || 'f ), 'e && 'f)
- checkCondition('e && ('f || !'e ), 'e && 'f)
+ checkConditionInNotNullableRelation('e && ('f || !'e ), 'e && 'f)
- checkCondition((!'e || 'f ) && 'e, 'f && 'e)
+ checkConditionInNotNullableRelation((!'e || 'f ) && 'e, 'f && 'e)
- checkCondition(('f || !'e ) && 'e, 'f && 'e)
+ checkConditionInNotNullableRelation(('f || !'e ) && 'e, 'f && 'e)
}
- test("a < 1 && (!(a < 1) || f)") {
- checkCondition('a < 1 && (!('a < 1) || 'f), ('a < 1) && 'f)
- checkCondition('a < 1 && ('f || !('a < 1)), ('a < 1) && 'f)
+ test("e && (!e || f) - nullable") {
+ Seq ('e && (!'e || 'f ),
+ 'e && ('f || !'e ),
+ (!'e || 'f ) && 'e,
+ ('f || !'e ) && 'e,
+ 'e || (!'e && 'f),
+ 'e || ('f && !'e),
+ ('e && 'f) || !'e,
+ ('f && 'e) || !'e).foreach { expr =>
+ checkCondition(expr, expr)
+ }
+ }
- checkCondition('a <= 1 && (!('a <= 1) || 'f), ('a <= 1) && 'f)
- checkCondition('a <= 1 && ('f || !('a <= 1)), ('a <= 1) && 'f)
+ test("a < 1 && (!(a < 1) || f) - not nullable") {
+ checkConditionInNotNullableRelation('a < 1 && (!('a < 1) || 'f), ('a < 1)
&& 'f)
+ checkConditionInNotNullableRelation('a < 1 && ('f || !('a < 1)), ('a < 1)
&& 'f)
- checkCondition('a > 1 && (!('a > 1) || 'f), ('a > 1) && 'f)
- checkCondition('a > 1 && ('f || !('a > 1)), ('a > 1) && 'f)
+ checkConditionInNotNullableRelation('a <= 1 && (!('a <= 1) || 'f), ('a <=
1) && 'f)
+ checkConditionInNotNullableRelation('a <= 1 && ('f || !('a <= 1)), ('a <=
1) && 'f)
- checkCondition('a >= 1 && (!('a >= 1) || 'f), ('a >= 1) && 'f)
- checkCondition('a >= 1 && ('f || !('a >= 1)), ('a >= 1) && 'f)
+ checkConditionInNotNullableRelation('a > 1 && (!('a > 1) || 'f), ('a > 1)
&& 'f)
+ checkConditionInNotNullableRelation('a > 1 && ('f || !('a > 1)), ('a > 1)
&& 'f)
+
+ checkConditionInNotNullableRelation('a >= 1 && (!('a >= 1) || 'f), ('a >=
1) && 'f)
+ checkConditionInNotNullableRelation('a >= 1 && ('f || !('a >= 1)), ('a >=
1) && 'f)
}
- test("a < 1 && ((a >= 1) || f)") {
- checkCondition('a < 1 && ('a >= 1 || 'f ), ('a < 1) && 'f)
- checkCondition('a < 1 && ('f || 'a >= 1), ('a < 1) && 'f)
+ test("a < 1 && ((a >= 1) || f) - not nullable") {
+ checkConditionInNotNullableRelation('a < 1 && ('a >= 1 || 'f ), ('a < 1)
&& 'f)
+ checkConditionInNotNullableRelation('a < 1 && ('f || 'a >= 1), ('a < 1) &&
'f)
- checkCondition('a <= 1 && ('a > 1 || 'f ), ('a <= 1) && 'f)
- checkCondition('a <= 1 && ('f || 'a > 1), ('a <= 1) && 'f)
+ checkConditionInNotNullableRelation('a <= 1 && ('a > 1 || 'f ), ('a <= 1)
&& 'f)
+ checkConditionInNotNullableRelation('a <= 1 && ('f || 'a > 1), ('a <= 1)
&& 'f)
- checkCondition('a > 1 && (('a <= 1) || 'f), ('a > 1) && 'f)
- checkCondition('a > 1 && ('f || ('a <= 1)), ('a > 1) && 'f)
+ checkConditionInNotNullableRelation('a > 1 && (('a <= 1) || 'f), ('a > 1)
&& 'f)
+ checkConditionInNotNullableRelation('a > 1 && ('f || ('a <= 1)), ('a > 1)
&& 'f)
- checkCondition('a >= 1 && (('a < 1) || 'f), ('a >= 1) && 'f)
- checkCondition('a >= 1 && ('f || ('a < 1)), ('a >= 1) && 'f)
+ checkConditionInNotNullableRelation('a >= 1 && (('a < 1) || 'f), ('a >= 1)
&& 'f)
+ checkConditionInNotNullableRelation('a >= 1 && ('f || ('a < 1)), ('a >= 1)
&& 'f)
}
test("DeMorgan's law") {
@@ -217,4 +238,46 @@ class BooleanSimplificationSuite extends PlanTest with
PredicateHelper {
checkCondition('e || !'f, testRelationWithData.where('e || !'f).analyze)
checkCondition(!'f || 'e, testRelationWithData.where(!'f || 'e).analyze)
}
+
+ protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
+ val correctAnswer = Project(Alias(e2, "out")() :: Nil,
OneRowRelation()).analyze
+ val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil,
OneRowRelation()).analyze)
+ comparePlans(actual, correctAnswer)
+ }
+
+ test("filter reduction - positive cases") {
+ val fields = Seq(
+ 'col1NotNULL.boolean.notNull,
+ 'col2NotNULL.boolean.notNull
+ )
+ val Seq(col1NotNULL, col2NotNULL) = fields.zipWithIndex.map { case (f, i)
=> f.at(i) }
+
+ val exprs = Seq(
+ // actual expressions of the transformations: original -> transformed
+ (col1NotNULL && (!col1NotNULL || col2NotNULL)) -> (col1NotNULL &&
col2NotNULL),
+ (col1NotNULL && (col2NotNULL || !col1NotNULL)) -> (col1NotNULL &&
col2NotNULL),
+ ((!col1NotNULL || col2NotNULL) && col1NotNULL) -> (col2NotNULL &&
col1NotNULL),
+ ((col2NotNULL || !col1NotNULL) && col1NotNULL) -> (col2NotNULL &&
col1NotNULL),
+
+ (col1NotNULL || (!col1NotNULL && col2NotNULL)) -> (col1NotNULL ||
col2NotNULL),
+ (col1NotNULL || (col2NotNULL && !col1NotNULL)) -> (col1NotNULL ||
col2NotNULL),
+ ((!col1NotNULL && col2NotNULL) || col1NotNULL) -> (col2NotNULL ||
col1NotNULL),
+ ((col2NotNULL && !col1NotNULL) || col1NotNULL) -> (col2NotNULL ||
col1NotNULL)
+ )
+
+ // check plans
+ for ((originalExpr, expectedExpr) <- exprs) {
+ assertEquivalent(originalExpr, expectedExpr)
+ }
+
+ // check evaluation
+ val binaryBooleanValues = Seq(true, false)
+ for (col1NotNULLVal <- binaryBooleanValues;
+ col2NotNULLVal <- binaryBooleanValues;
+ (originalExpr, expectedExpr) <- exprs) {
+ val inputRow = create_row(col1NotNULLVal, col2NotNULLVal)
+ val optimizedVal = evaluateWithoutCodegen(expectedExpr, inputRow)
+ checkEvaluation(originalExpr, optimizedVal, inputRow)
+ }
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/5554a33f/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 279b7b8..4a7bd2f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.SparkException
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.Uuid
+import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation,
Union}
import org.apache.spark.sql.execution.{FilterExec, QueryExecution,
WholeStageCodegenExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
@@ -2579,4 +2580,13 @@ class DataFrameSuite extends QueryTest with
SharedSQLContext {
checkAnswer(df.where("(NOT a) OR a"), Seq.empty)
}
+
+ test("SPARK-25714 Null handling in BooleanSimplification") {
+ withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
ConvertToLocalRelation.ruleName) {
+ val df = Seq(("abc", 1), (null, 3)).toDF("col1", "col2")
+ checkAnswer(
+ df.filter("col1 = 'abc' OR (col1 != 'abc' AND col2 == 3)"),
+ Row ("abc", 1))
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]