This is an automated email from the ASF dual-hosted git repository.
ptoth pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c3843a0bbc47 [SPARK-54881][SQL] Improve `BooleanSimplification` to
handle negation of conjunction and disjunction in one pass
c3843a0bbc47 is described below
commit c3843a0bbc47f3d92780b8c79a1c4436a9dec63d
Author: Asif Hussain Shahid <[email protected]>
AuthorDate: Tue Jan 13 15:55:31 2026 +0100
[SPARK-54881][SQL] Improve `BooleanSimplification` to handle negation of
conjunction and disjunction in one pass
Fix to simplify boolean expressions of the form !(expr1 || expr2) in a
single pass, where expr1 and expr2 are binary comparison expressions.
### What changes were proposed in this pull request?
In the rule BooleanSimplification, the following two changes are made:
1) The partial function currently passed as a lambda to the
transformExpressionsUpWithPruning API is stored in a
`val actualExprTransformer`.
2) Instead of the lambda, the val actualExprTransformer is passed to
transformExpressionsUpWithPruning.
Up to this point, the code change is mere refactoring.
The main change in the logic is:
3) For the following two cases
case Not(a Or b) =>
And(Not(a), Not(b)).transformDownWithPruning(_.containsPattern(NOT),
ruleId) {
actualExprTransformer
}
case Not(a And b) =>
Or(Not(a), Not(b)).transformDownWithPruning(_.containsPattern(NOT),
ruleId) {
actualExprTransformer
}
The new child nodes of the AND and OR are immediately acted upon by the same
expression-transformer partial function using transformDownWithPruning. This
is efficient because the traversal of a subtree stops immediately if that
subtree does not contain any NOT operator.
### Why are the changes needed?
The change is needed because, with transformUp, idempotency is not achieved
in the optimal way (two passes are needed instead of a single pass).
The issue arises from the rule transforming
Not(A || B) => (Not(A) AND Not(B))
Because transformUp has already visited the children before this rewrite
happens, the newly added Not operations are not acted upon in that pass.
With transformDown, the new children containing Not are simplified in the
same pass.
Please note that merely changing transformExpressionsUp to
transformExpressionsDown would fix this issue, but it would break idempotency
for other cases (as seen by failures in ConstantFoldingSuite).
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added a regression test, "SPARK-54881: simplify Not(Expr) in single pass",
to BooleanSimplificationSuite.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #53658 from ahshahid/SPARK-54881.
Authored-by: Asif Hussain Shahid <[email protected]>
Signed-off-by: Peter Toth <[email protected]>
---
.../spark/sql/catalyst/optimizer/expressions.scala | 287 +++++++++++----------
.../optimizer/BooleanSimplificationSuite.scala | 31 +++
2 files changed, 181 insertions(+), 137 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 661e43f8548b..98379241c366 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -361,154 +361,167 @@ object BooleanSimplification extends Rule[LogicalPlan]
with PredicateHelper {
_.containsAnyPattern(AND, OR, NOT), ruleId) {
case q: LogicalPlan => q.transformExpressionsUpWithPruning(
_.containsAnyPattern(AND, OR, NOT), ruleId) {
- case TrueLiteral And e => e
- case e And TrueLiteral => e
- case FalseLiteral Or e => e
- case e Or FalseLiteral => e
-
- case FalseLiteral And _ => FalseLiteral
- case _ And FalseLiteral => FalseLiteral
- case TrueLiteral Or _ => TrueLiteral
- case _ Or TrueLiteral => TrueLiteral
-
- case a And b if Not(a).semanticEquals(b) =>
- If(IsNull(a), Literal.create(null, a.dataType), FalseLiteral)
- case a And b if a.semanticEquals(Not(b)) =>
- If(IsNull(b), Literal.create(null, b.dataType), FalseLiteral)
-
- case a Or b if Not(a).semanticEquals(b) =>
- If(IsNull(a), Literal.create(null, a.dataType), TrueLiteral)
- case a Or b if a.semanticEquals(Not(b)) =>
- If(IsNull(b), Literal.create(null, b.dataType), TrueLiteral)
-
- case a And b if a.semanticEquals(b) => a
- case a Or b if a.semanticEquals(b) => a
-
- // The following optimizations are applicable only when the operands are
not nullable,
- // since the three-value logic of AND and OR are different in NULL
handling.
- // See the chart:
- // +---------+---------+---------+---------+
- // | operand | operand | OR | AND |
- // +---------+---------+---------+---------+
- // | TRUE | TRUE | TRUE | TRUE |
- // | TRUE | FALSE | TRUE | FALSE |
- // | FALSE | FALSE | FALSE | FALSE |
- // | UNKNOWN | TRUE | TRUE | UNKNOWN |
- // | UNKNOWN | FALSE | UNKNOWN | FALSE |
- // | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN |
- // +---------+---------+---------+---------+
-
- // (NULL And (NULL Or FALSE)) = NULL, but (NULL And FALSE) = FALSE.
Thus, a can't be nullable.
- case a And (b Or c) if !a.nullable && Not(a).semanticEquals(b) => And(a,
c)
- // (NULL And (FALSE Or NULL)) = NULL, but (NULL And FALSE) = FALSE.
Thus, a can't be nullable.
- case a And (b Or c) if !a.nullable && Not(a).semanticEquals(c) => And(a,
b)
- // ((NULL Or FALSE) And NULL) = NULL, but (FALSE And NULL) = FALSE.
Thus, c can't be nullable.
- case (a Or b) And c if !c.nullable && a.semanticEquals(Not(c)) => And(b,
c)
- // ((FALSE Or NULL) And NULL) = NULL, but (FALSE And NULL) = FALSE.
Thus, c can't be nullable.
- case (a Or b) And c if !c.nullable && b.semanticEquals(Not(c)) => And(a,
c)
-
- // (NULL Or (NULL And TRUE)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a
can't be nullable.
- case a Or (b And c) if !a.nullable && Not(a).semanticEquals(b) => Or(a,
c)
- // (NULL Or (TRUE And NULL)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a
can't be nullable.
- case a Or (b And c) if !a.nullable && Not(a).semanticEquals(c) => Or(a,
b)
- // ((NULL And TRUE) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c
can't be nullable.
- case (a And b) Or c if !c.nullable && a.semanticEquals(Not(c)) => Or(b,
c)
- // ((TRUE And NULL) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c
can't be nullable.
- case (a And b) Or c if !c.nullable && b.semanticEquals(Not(c)) => Or(a,
c)
-
- // Common factor elimination for conjunction
- case and @ (left And right) =>
- // 1. Split left and right to get the disjunctive predicates,
- // i.e. lhs = (a || b), rhs = (a || c)
- // 2. Find the common predict between lhsSet and rhsSet, i.e. common =
(a)
- // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b),
rdiff = (c)
- // 4. If common is non-empty, apply the formula to get the optimized
predicate:
- // common || (ldiff && rdiff)
- // 5. Else if common is empty, split left and right to get the
conjunctive predicates.
- // for example lhs = (a && b), rhs = (a && c) => all = (a, b, a,
c), distinct = (a, b, c)
- // optimized predicate: (a && b && c)
- val lhs = splitDisjunctivePredicates(left)
- val rhs = splitDisjunctivePredicates(right)
- val common = lhs.filter(e => rhs.exists(e.semanticEquals))
- if (common.nonEmpty) {
- val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
- val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
- if (ldiff.isEmpty || rdiff.isEmpty) {
- // (a || b || c || ...) && (a || b) => (a || b)
- common.reduce(Or)
- } else {
- // (a || b || c || ...) && (a || b || d || ...) =>
- // a || b || ((c || ...) && (d || ...))
- (common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or)
- }
+ actualExprTransformer
+ }
+ }
+
+ val actualExprTransformer: PartialFunction[Expression, Expression] = {
+ case TrueLiteral And e => e
+ case e And TrueLiteral => e
+ case FalseLiteral Or e => e
+ case e Or FalseLiteral => e
+
+ case FalseLiteral And _ => FalseLiteral
+ case _ And FalseLiteral => FalseLiteral
+ case TrueLiteral Or _ => TrueLiteral
+ case _ Or TrueLiteral => TrueLiteral
+
+ case a And b if Not(a).semanticEquals(b) =>
+ If(IsNull(a), Literal.create(null, a.dataType), FalseLiteral)
+ case a And b if a.semanticEquals(Not(b)) =>
+ If(IsNull(b), Literal.create(null, b.dataType), FalseLiteral)
+
+ case a Or b if Not(a).semanticEquals(b) =>
+ If(IsNull(a), Literal.create(null, a.dataType), TrueLiteral)
+ case a Or b if a.semanticEquals(Not(b)) =>
+ If(IsNull(b), Literal.create(null, b.dataType), TrueLiteral)
+
+ case a And b if a.semanticEquals(b) => a
+ case a Or b if a.semanticEquals(b) => a
+
+ // The following optimizations are applicable only when the operands are
not nullable,
+ // since the three-value logic of AND and OR are different in NULL
handling.
+ // See the chart:
+ // +---------+---------+---------+---------+
+ // | operand | operand | OR | AND |
+ // +---------+---------+---------+---------+
+ // | TRUE | TRUE | TRUE | TRUE |
+ // | TRUE | FALSE | TRUE | FALSE |
+ // | FALSE | FALSE | FALSE | FALSE |
+ // | UNKNOWN | TRUE | TRUE | UNKNOWN |
+ // | UNKNOWN | FALSE | UNKNOWN | FALSE |
+ // | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN |
+ // +---------+---------+---------+---------+
+
+ // (NULL And (NULL Or FALSE)) = NULL, but (NULL And FALSE) = FALSE. Thus,
a can't be nullable.
+ case a And (b Or c) if !a.nullable && Not(a).semanticEquals(b) => And(a, c)
+ // (NULL And (FALSE Or NULL)) = NULL, but (NULL And FALSE) = FALSE. Thus,
a can't be nullable.
+ case a And (b Or c) if !a.nullable && Not(a).semanticEquals(c) => And(a, b)
+ // ((NULL Or FALSE) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus,
c can't be nullable.
+ case (a Or b) And c if !c.nullable && a.semanticEquals(Not(c)) => And(b, c)
+ // ((FALSE Or NULL) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus,
c can't be nullable.
+ case (a Or b) And c if !c.nullable && b.semanticEquals(Not(c)) => And(a, c)
+
+ // (NULL Or (NULL And TRUE)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a
can't be nullable.
+ case a Or (b And c) if !a.nullable && Not(a).semanticEquals(b) => Or(a, c)
+ // (NULL Or (TRUE And NULL)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a
can't be nullable.
+ case a Or (b And c) if !a.nullable && Not(a).semanticEquals(c) => Or(a, b)
+ // ((NULL And TRUE) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c
can't be nullable.
+ case (a And b) Or c if !c.nullable && a.semanticEquals(Not(c)) => Or(b, c)
+ // ((TRUE And NULL) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c
can't be nullable.
+ case (a And b) Or c if !c.nullable && b.semanticEquals(Not(c)) => Or(a, c)
+
+ // Common factor elimination for conjunction
+ case and @ (left And right) =>
+ // 1. Split left and right to get the disjunctive predicates,
+ // i.e. lhs = (a || b), rhs = (a || c)
+ // 2. Find the common predict between lhsSet and rhsSet, i.e. common =
(a)
+ // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b),
rdiff = (c)
+ // 4. If common is non-empty, apply the formula to get the optimized
predicate:
+ // common || (ldiff && rdiff)
+ // 5. Else if common is empty, split left and right to get the
conjunctive predicates.
+ // for example lhs = (a && b), rhs = (a && c) => all = (a, b, a, c),
distinct = (a, b, c)
+ // optimized predicate: (a && b && c)
+ val lhs = splitDisjunctivePredicates(left)
+ val rhs = splitDisjunctivePredicates(right)
+ val common = lhs.filter(e => rhs.exists(e.semanticEquals))
+ if (common.nonEmpty) {
+ val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
+ val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
+ if (ldiff.isEmpty || rdiff.isEmpty) {
+ // (a || b || c || ...) && (a || b) => (a || b)
+ common.reduce(Or)
} else {
- // No common factors from disjunctive predicates, reduce common
factor from conjunction
- val all = splitConjunctivePredicates(left) ++
splitConjunctivePredicates(right)
- val distinct = ExpressionSet(all)
- if (all.size == distinct.size) {
- // No common factors, return the original predicate
- and
- } else {
- // (a && b) && a && (a && c) => a && b && c
- buildBalancedPredicate(distinct.toSeq, And)
- }
+ // (a || b || c || ...) && (a || b || d || ...) =>
+ // a || b || ((c || ...) && (d || ...))
+ (common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or)
}
+ } else {
+ // No common factors from disjunctive predicates, reduce common factor
from conjunction
+ val all = splitConjunctivePredicates(left) ++
splitConjunctivePredicates(right)
+ val distinct = ExpressionSet(all)
+ if (all.size == distinct.size) {
+ // No common factors, return the original predicate
+ and
+ } else {
+ // (a && b) && a && (a && c) => a && b && c
+ buildBalancedPredicate(distinct.toSeq, And)
+ }
+ }
- // Common factor elimination for disjunction
- case or @ (left Or right) =>
- // 1. Split left and right to get the conjunctive predicates,
- // i.e. lhs = (a && b), rhs = (a && c)
- // 2. Find the common predict between lhsSet and rhsSet, i.e. common =
(a)
- // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b),
rdiff = (c)
- // 4. If common is non-empty, apply the formula to get the optimized
predicate:
- // common && (ldiff || rdiff)
- // 5. Else if common is empty, split left and right to get the
conjunctive predicates.
- // for example lhs = (a || b), rhs = (a || c) => all = (a, b, a, c),
distinct = (a, b, c)
- // optimized predicate: (a || b || c)
- val lhs = splitConjunctivePredicates(left)
- val rhs = splitConjunctivePredicates(right)
- val common = lhs.filter(e => rhs.exists(e.semanticEquals))
- if (common.nonEmpty) {
- val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
- val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
- if (ldiff.isEmpty || rdiff.isEmpty) {
- // (a && b) || (a && b && c && ...) => a && b
- common.reduce(And)
- } else {
- // (a && b && c && ...) || (a && b && d && ...) =>
- // a && b && ((c && ...) || (d && ...))
- (common :+ Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And)
- }
+ // Common factor elimination for disjunction
+ case or @ (left Or right) =>
+ // 1. Split left and right to get the conjunctive predicates,
+ // i.e. lhs = (a && b), rhs = (a && c)
+ // 2. Find the common predict between lhsSet and rhsSet, i.e. common =
(a)
+ // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b),
rdiff = (c)
+ // 4. If common is non-empty, apply the formula to get the optimized
predicate:
+ // common && (ldiff || rdiff)
+ // 5. Else if common is empty, split left and right to get the
conjunctive predicates.
+ // for example lhs = (a || b), rhs = (a || c) => all = (a, b, a, c),
distinct = (a, b, c)
+ // optimized predicate: (a || b || c)
+ val lhs = splitConjunctivePredicates(left)
+ val rhs = splitConjunctivePredicates(right)
+ val common = lhs.filter(e => rhs.exists(e.semanticEquals))
+ if (common.nonEmpty) {
+ val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
+ val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
+ if (ldiff.isEmpty || rdiff.isEmpty) {
+ // (a && b) || (a && b && c && ...) => a && b
+ common.reduce(And)
} else {
- // No common factors in conjunctive predicates, reduce common factor
from disjunction
- val all = splitDisjunctivePredicates(left) ++
splitDisjunctivePredicates(right)
- val distinct = ExpressionSet(all)
- if (all.size == distinct.size) {
- // No common factors, return the original predicate
- or
- } else {
- // (a || b) || a || (a || c) => a || b || c
- buildBalancedPredicate(distinct.toSeq, Or)
- }
+ // (a && b && c && ...) || (a && b && d && ...) =>
+ // a && b && ((c && ...) || (d && ...))
+ (common :+ Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And)
}
+ } else {
+ // No common factors in conjunctive predicates, reduce common factor
from disjunction
+ val all = splitDisjunctivePredicates(left) ++
splitDisjunctivePredicates(right)
+ val distinct = ExpressionSet(all)
+ if (all.size == distinct.size) {
+ // No common factors, return the original predicate
+ or
+ } else {
+ // (a || b) || a || (a || c) => a || b || c
+ buildBalancedPredicate(distinct.toSeq, Or)
+ }
+ }
- case Not(TrueLiteral) => FalseLiteral
- case Not(FalseLiteral) => TrueLiteral
+ case Not(TrueLiteral) => FalseLiteral
+ case Not(FalseLiteral) => TrueLiteral
- case Not(a GreaterThan b) => LessThanOrEqual(a, b)
- case Not(a GreaterThanOrEqual b) => LessThan(a, b)
+ case Not(a GreaterThan b) => LessThanOrEqual(a, b)
+ case Not(a GreaterThanOrEqual b) => LessThan(a, b)
- case Not(a LessThan b) => GreaterThanOrEqual(a, b)
- case Not(a LessThanOrEqual b) => GreaterThan(a, b)
+ case Not(a LessThan b) => GreaterThanOrEqual(a, b)
+ case Not(a LessThanOrEqual b) => GreaterThan(a, b)
- case Not(a Or b) => And(Not(a), Not(b))
- case Not(a And b) => Or(Not(a), Not(b))
+ // SPARK-54881: push down the NOT operators on children, before attaching
the junction Node
+ // to the main tree. This ensures idempotency in an optimal way and avoids
an extra rule
+ // iteration.
+ case Not(a Or b) =>
+ And(Not(a), Not(b)).transformDownWithPruning(_.containsPattern(NOT),
ruleId) {
+ actualExprTransformer
+ }
+ case Not(a And b) =>
+ Or(Not(a), Not(b)).transformDownWithPruning(_.containsPattern(NOT),
ruleId) {
+ actualExprTransformer
+ }
- case Not(Not(e)) => e
+ case Not(Not(e)) => e
- case Not(IsNull(e)) => IsNotNull(e)
- case Not(IsNotNull(e)) => IsNull(e)
- }
+ case Not(IsNull(e)) => IsNotNull(e)
+ case Not(IsNotNull(e)) => IsNull(e)
}
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
index 4cc2ee99284a..5a44119bf049 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
@@ -291,6 +291,37 @@ class BooleanSimplificationSuite extends PlanTest with
ExpressionEvalHelper {
checkCondition(Not(IsNull($"b")), IsNotNull($"b"))
}
+ test("SPARK-54881: simplify Not(Expr) in single pass") {
+ def executeRuleOnce(exprToTest: Expression, optimizedExprExpected:
Expression): Unit = {
+ val planAfterRuleApp =
BooleanSimplification.apply(testRelation.where(exprToTest).analyze)
+ val expectedOptPlan = testRelation.where(optimizedExprExpected).analyze
+ comparePlans(expectedOptPlan, planAfterRuleApp)
+ }
+ // check simplify Not(A <= B OR A >= B) to (a > b AND a < b) in single pass
+ executeRuleOnce(
+ Not(($"a" <= $"b") || ($"a" >= $"b")),
+ $"a" > $"b" && $"a" < $"b"
+ )
+
+ // check simplify Not((expr1 OR expr2) OR (expr3 AND expr4)) in single pass
+ executeRuleOnce(
+ Not(($"a" <= $"b" || $"c" > $"a" + 4) || ($"a" >= $"b" && $"c" < $"a")),
+ And(
+ And($"a" > $"b", $"c" <= $"a" + 4),
+ Or($"a" < $"b", $"c" >= $"a")
+ )
+ )
+
+ // check simplify Not((expr1 OR expr2) AND (expr3 OR expr4)) in single pass
+ executeRuleOnce(
+ Not(($"a" <= $"b" || $"c" > $"a" + 4) && ($"a" >= $"b" || $"c" < $"a")),
+ Or(
+ And($"a" > $"b", $"c" <= $"a" + 4),
+ And($"a" < $"b", $"c" >= $"a")
+ )
+ )
+ }
+
protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
val correctAnswer = Project(Alias(e2, "out")() :: Nil,
OneRowRelation()).analyze
val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil,
OneRowRelation()).analyze)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]