This is an automated email from the ASF dual-hosted git repository.

ptoth pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c3843a0bbc47 [SPARK-54881][SQL] Improve `BooleanSimplification` to 
handle negation of conjunction and disjunction in one pass
c3843a0bbc47 is described below

commit c3843a0bbc47f3d92780b8c79a1c4436a9dec63d
Author: Asif Hussain Shahid <[email protected]>
AuthorDate: Tue Jan 13 15:55:31 2026 +0100

    [SPARK-54881][SQL] Improve `BooleanSimplification` to handle negation of 
conjunction and disjunction in one pass
    
    Simplify boolean expressions of the form !(expr1 || expr2) in a single pass, where expr1 and
    expr2 are binary comparison expressions.
    
    ### What changes were proposed in this pull request?
    In the rule BooleanSimplification, the following two changes are made:
    1) The partial function previously passed as a lambda to the
     transformExpressionsUpWithPruning API is now stored in a "val actualExprTransformer".
    2) Instead of the lambda, the val actualExprTransformer is passed to
     transformExpressionsUpWithPruning.
    
    Up to this point the change is a pure refactoring.
    The main change in the logic is
    3) for the following two cases:
    
    case Not(a Or b) =>
          And(Not(a), Not(b)).transformDownWithPruning(_.containsPattern(NOT), 
ruleId) {
            actualExprTransformer
          }
    
    case Not(a And b) =>
          Or(Not(a), Not(b)).transformDownWithPruning(_.containsPattern(NOT), 
ruleId) {
            actualExprTransformer
          }
    
    The new children of the AND/OR node are immediately simplified by the same partial function
    using transformDownWithPruning. This is cheap, because the traversal of a subtree stops
    immediately if the subtree does not contain any NOT operator.
    
    ### Why are the changes needed?
    The change is needed because, with transformUp, the fixed point is not reached in the optimal
    way (two passes instead of a single one).
    The issue arises from the rule transforming
    Not (A || B) => (Not(A) AND Not(B))
    Because the new children contain newly added Not operators, they are not simplified in that
    pass: transformUp has already visited the children before rewriting their parent.
    With transformDown, the new children with Not are simplified in the same pass.
    
    Please note that merely changing transformExpressionsUp to transformExpressionsDown would fix
    this issue, but it would break idempotency for other cases (as seen by a failure in
    ConstantFoldingSuite).
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added a unit test to BooleanSimplificationSuite.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #53658 from ahshahid/SPARK-54881.
    
    Authored-by: Asif Hussain Shahid <[email protected]>
    Signed-off-by: Peter Toth <[email protected]>
---
 .../spark/sql/catalyst/optimizer/expressions.scala | 287 +++++++++++----------
 .../optimizer/BooleanSimplificationSuite.scala     |  31 +++
 2 files changed, 181 insertions(+), 137 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 661e43f8548b..98379241c366 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -361,154 +361,167 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
     _.containsAnyPattern(AND, OR, NOT), ruleId) {
     case q: LogicalPlan => q.transformExpressionsUpWithPruning(
       _.containsAnyPattern(AND, OR, NOT), ruleId) {
-      case TrueLiteral And e => e
-      case e And TrueLiteral => e
-      case FalseLiteral Or e => e
-      case e Or FalseLiteral => e
-
-      case FalseLiteral And _ => FalseLiteral
-      case _ And FalseLiteral => FalseLiteral
-      case TrueLiteral Or _ => TrueLiteral
-      case _ Or TrueLiteral => TrueLiteral
-
-      case a And b if Not(a).semanticEquals(b) =>
-        If(IsNull(a), Literal.create(null, a.dataType), FalseLiteral)
-      case a And b if a.semanticEquals(Not(b)) =>
-        If(IsNull(b), Literal.create(null, b.dataType), FalseLiteral)
-
-      case a Or b if Not(a).semanticEquals(b) =>
-        If(IsNull(a), Literal.create(null, a.dataType), TrueLiteral)
-      case a Or b if a.semanticEquals(Not(b)) =>
-        If(IsNull(b), Literal.create(null, b.dataType), TrueLiteral)
-
-      case a And b if a.semanticEquals(b) => a
-      case a Or b if a.semanticEquals(b) => a
-
-      // The following optimizations are applicable only when the operands are not nullable,
-      // since the three-value logic of AND and OR are different in NULL handling.
-      // See the chart:
-      // +---------+---------+---------+---------+
-      // | operand | operand |   OR    |   AND   |
-      // +---------+---------+---------+---------+
-      // | TRUE    | TRUE    | TRUE    | TRUE    |
-      // | TRUE    | FALSE   | TRUE    | FALSE   |
-      // | FALSE   | FALSE   | FALSE   | FALSE   |
-      // | UNKNOWN | TRUE    | TRUE    | UNKNOWN |
-      // | UNKNOWN | FALSE   | UNKNOWN | FALSE   |
-      // | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN |
-      // +---------+---------+---------+---------+
-
-      // (NULL And (NULL Or FALSE)) = NULL, but (NULL And FALSE) = FALSE. Thus, a can't be nullable.
-      case a And (b Or c) if !a.nullable && Not(a).semanticEquals(b) => And(a, c)
-      // (NULL And (FALSE Or NULL)) = NULL, but (NULL And FALSE) = FALSE. Thus, a can't be nullable.
-      case a And (b Or c) if !a.nullable && Not(a).semanticEquals(c) => And(a, b)
-      // ((NULL Or FALSE) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, c can't be nullable.
-      case (a Or b) And c if !c.nullable && a.semanticEquals(Not(c)) => And(b, c)
-      // ((FALSE Or NULL) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, c can't be nullable.
-      case (a Or b) And c if !c.nullable && b.semanticEquals(Not(c)) => And(a, c)
-
-      // (NULL Or (NULL And TRUE)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a can't be nullable.
-      case a Or (b And c) if !a.nullable && Not(a).semanticEquals(b) => Or(a, c)
-      // (NULL Or (TRUE And NULL)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a can't be nullable.
-      case a Or (b And c) if !a.nullable && Not(a).semanticEquals(c) => Or(a, b)
-      // ((NULL And TRUE) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c can't be nullable.
-      case (a And b) Or c if !c.nullable && a.semanticEquals(Not(c)) => Or(b, c)
-      // ((TRUE And NULL) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c can't be nullable.
-      case (a And b) Or c if !c.nullable && b.semanticEquals(Not(c)) => Or(a, c)
-
-      // Common factor elimination for conjunction
-      case and @ (left And right) =>
-        // 1. Split left and right to get the disjunctive predicates,
-        //    i.e. lhs = (a || b), rhs = (a || c)
-        // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
-        // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
-        // 4. If common is non-empty, apply the formula to get the optimized predicate:
-        //    common || (ldiff && rdiff)
-        // 5. Else if common is empty, split left and right to get the conjunctive predicates.
-        //    for example lhs = (a && b), rhs = (a && c) => all = (a, b, a, c), distinct = (a, b, c)
-        //    optimized predicate: (a && b && c)
-        val lhs = splitDisjunctivePredicates(left)
-        val rhs = splitDisjunctivePredicates(right)
-        val common = lhs.filter(e => rhs.exists(e.semanticEquals))
-        if (common.nonEmpty) {
-          val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
-          val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
-          if (ldiff.isEmpty || rdiff.isEmpty) {
-            // (a || b || c || ...) && (a || b) => (a || b)
-            common.reduce(Or)
-          } else {
-            // (a || b || c || ...) && (a || b || d || ...) =>
-            // a || b || ((c || ...) && (d || ...))
-            (common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or)
-          }
+      actualExprTransformer
+    }
+  }
+
+  val actualExprTransformer: PartialFunction[Expression, Expression] = {
+    case TrueLiteral And e => e
+    case e And TrueLiteral => e
+    case FalseLiteral Or e => e
+    case e Or FalseLiteral => e
+
+    case FalseLiteral And _ => FalseLiteral
+    case _ And FalseLiteral => FalseLiteral
+    case TrueLiteral Or _ => TrueLiteral
+    case _ Or TrueLiteral => TrueLiteral
+
+    case a And b if Not(a).semanticEquals(b) =>
+      If(IsNull(a), Literal.create(null, a.dataType), FalseLiteral)
+    case a And b if a.semanticEquals(Not(b)) =>
+      If(IsNull(b), Literal.create(null, b.dataType), FalseLiteral)
+
+    case a Or b if Not(a).semanticEquals(b) =>
+      If(IsNull(a), Literal.create(null, a.dataType), TrueLiteral)
+    case a Or b if a.semanticEquals(Not(b)) =>
+      If(IsNull(b), Literal.create(null, b.dataType), TrueLiteral)
+
+    case a And b if a.semanticEquals(b) => a
+    case a Or b if a.semanticEquals(b) => a
+
+    // The following optimizations are applicable only when the operands are not nullable,
+    // since the three-value logic of AND and OR are different in NULL handling.
+    // See the chart:
+    // +---------+---------+---------+---------+
+    // | operand | operand |   OR    |   AND   |
+    // +---------+---------+---------+---------+
+    // | TRUE    | TRUE    | TRUE    | TRUE    |
+    // | TRUE    | FALSE   | TRUE    | FALSE   |
+    // | FALSE   | FALSE   | FALSE   | FALSE   |
+    // | UNKNOWN | TRUE    | TRUE    | UNKNOWN |
+    // | UNKNOWN | FALSE   | UNKNOWN | FALSE   |
+    // | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN |
+    // +---------+---------+---------+---------+
+
+    // (NULL And (NULL Or FALSE)) = NULL, but (NULL And FALSE) = FALSE. Thus, a can't be nullable.
+    case a And (b Or c) if !a.nullable && Not(a).semanticEquals(b) => And(a, c)
+    // (NULL And (FALSE Or NULL)) = NULL, but (NULL And FALSE) = FALSE. Thus, a can't be nullable.
+    case a And (b Or c) if !a.nullable && Not(a).semanticEquals(c) => And(a, b)
+    // ((NULL Or FALSE) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, c can't be nullable.
+    case (a Or b) And c if !c.nullable && a.semanticEquals(Not(c)) => And(b, c)
+    // ((FALSE Or NULL) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, c can't be nullable.
+    case (a Or b) And c if !c.nullable && b.semanticEquals(Not(c)) => And(a, c)
+
+    // (NULL Or (NULL And TRUE)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a can't be nullable.
+    case a Or (b And c) if !a.nullable && Not(a).semanticEquals(b) => Or(a, c)
+    // (NULL Or (TRUE And NULL)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a can't be nullable.
+    case a Or (b And c) if !a.nullable && Not(a).semanticEquals(c) => Or(a, b)
+    // ((NULL And TRUE) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c can't be nullable.
+    case (a And b) Or c if !c.nullable && a.semanticEquals(Not(c)) => Or(b, c)
+    // ((TRUE And NULL) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c can't be nullable.
+    case (a And b) Or c if !c.nullable && b.semanticEquals(Not(c)) => Or(a, c)
+
+    // Common factor elimination for conjunction
+    case and @ (left And right) =>
+      // 1. Split left and right to get the disjunctive predicates,
+      //    i.e. lhs = (a || b), rhs = (a || c)
+      // 2. Find the common predicates between lhs and rhs, i.e. common = (a)
+      // 3. Remove the common predicates from lhs and rhs, i.e. ldiff = (b), rdiff = (c)
+      // 4. If common is non-empty, apply the formula to get the optimized predicate:
+      //    common || (ldiff && rdiff)
+      // 5. Else if common is empty, split left and right to get the conjunctive predicates.
+      //    for example lhs = (a && b), rhs = (a && c) => all = (a, b, a, c), distinct = (a, b, c)
+      //    optimized predicate: (a && b && c)
+      val lhs = splitDisjunctivePredicates(left)
+      val rhs = splitDisjunctivePredicates(right)
+      val common = lhs.filter(e => rhs.exists(e.semanticEquals))
+      if (common.nonEmpty) {
+        val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
+        val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
+        if (ldiff.isEmpty || rdiff.isEmpty) {
+          // (a || b || c || ...) && (a || b) => (a || b)
+          common.reduce(Or)
         } else {
-          // No common factors from disjunctive predicates, reduce common factor from conjunction
-          val all = splitConjunctivePredicates(left) ++ splitConjunctivePredicates(right)
-          val distinct = ExpressionSet(all)
-          if (all.size == distinct.size) {
-            // No common factors, return the original predicate
-            and
-          } else {
-            // (a && b) && a && (a && c) => a && b && c
-            buildBalancedPredicate(distinct.toSeq, And)
-          }
+          // (a || b || c || ...) && (a || b || d || ...) =>
+          // a || b || ((c || ...) && (d || ...))
+          (common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or)
         }
+      } else {
+        // No common factors from disjunctive predicates, reduce common factor from conjunction
+        val all = splitConjunctivePredicates(left) ++ splitConjunctivePredicates(right)
+        val distinct = ExpressionSet(all)
+        if (all.size == distinct.size) {
+          // No common factors, return the original predicate
+          and
+        } else {
+          // (a && b) && a && (a && c) => a && b && c
+          buildBalancedPredicate(distinct.toSeq, And)
+        }
+      }
 
-      // Common factor elimination for disjunction
-      case or @ (left Or right) =>
-        // 1. Split left and right to get the conjunctive predicates,
-        //    i.e.  lhs = (a && b), rhs = (a && c)
-        // 2. Find the common predict between lhsSet and rhsSet, i.e. common = (a)
-        // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), rdiff = (c)
-        // 4. If common is non-empty, apply the formula to get the optimized predicate:
-        //    common && (ldiff || rdiff)
-        // 5. Else if common is empty, split left and right to get the conjunctive predicates.
-        // for example lhs = (a || b), rhs = (a || c) => all = (a, b, a, c), distinct = (a, b, c)
-        // optimized predicate: (a || b || c)
-        val lhs = splitConjunctivePredicates(left)
-        val rhs = splitConjunctivePredicates(right)
-        val common = lhs.filter(e => rhs.exists(e.semanticEquals))
-        if (common.nonEmpty) {
-          val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
-          val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
-          if (ldiff.isEmpty || rdiff.isEmpty) {
-            // (a && b) || (a && b && c && ...) => a && b
-            common.reduce(And)
-          } else {
-            // (a && b && c && ...) || (a && b && d && ...) =>
-            // a && b && ((c && ...) || (d && ...))
-            (common :+ Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And)
-          }
+    // Common factor elimination for disjunction
+    case or @ (left Or right) =>
+      // 1. Split left and right to get the conjunctive predicates,
+      //    i.e. lhs = (a && b), rhs = (a && c)
+      // 2. Find the common predicates between lhs and rhs, i.e. common = (a)
+      // 3. Remove the common predicates from lhs and rhs, i.e. ldiff = (b), rdiff = (c)
+      // 4. If common is non-empty, apply the formula to get the optimized predicate:
+      //    common && (ldiff || rdiff)
+      // 5. Else if common is empty, split left and right to get the disjunctive predicates.
+      //    for example lhs = (a || b), rhs = (a || c) => all = (a, b, a, c), distinct = (a, b, c)
+      //    optimized predicate: (a || b || c)
+      val lhs = splitConjunctivePredicates(left)
+      val rhs = splitConjunctivePredicates(right)
+      val common = lhs.filter(e => rhs.exists(e.semanticEquals))
+      if (common.nonEmpty) {
+        val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
+        val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
+        if (ldiff.isEmpty || rdiff.isEmpty) {
+          // (a && b) || (a && b && c && ...) => a && b
+          common.reduce(And)
         } else {
-          // No common factors in conjunctive predicates, reduce common factor from disjunction
-          val all = splitDisjunctivePredicates(left) ++ splitDisjunctivePredicates(right)
-          val distinct = ExpressionSet(all)
-          if (all.size == distinct.size) {
-            // No common factors, return the original predicate
-            or
-          } else {
-            // (a || b) || a || (a || c) => a || b || c
-            buildBalancedPredicate(distinct.toSeq, Or)
-          }
+          // (a && b && c && ...) || (a && b && d && ...) =>
+          // a && b && ((c && ...) || (d && ...))
+          (common :+ Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And)
         }
+      } else {
+        // No common factors in conjunctive predicates, reduce common factor from disjunction
+        val all = splitDisjunctivePredicates(left) ++ splitDisjunctivePredicates(right)
+        val distinct = ExpressionSet(all)
+        if (all.size == distinct.size) {
+          // No common factors, return the original predicate
+          or
+        } else {
+          // (a || b) || a || (a || c) => a || b || c
+          buildBalancedPredicate(distinct.toSeq, Or)
+        }
+      }
 
-      case Not(TrueLiteral) => FalseLiteral
-      case Not(FalseLiteral) => TrueLiteral
+    case Not(TrueLiteral) => FalseLiteral
+    case Not(FalseLiteral) => TrueLiteral
 
-      case Not(a GreaterThan b) => LessThanOrEqual(a, b)
-      case Not(a GreaterThanOrEqual b) => LessThan(a, b)
+    case Not(a GreaterThan b) => LessThanOrEqual(a, b)
+    case Not(a GreaterThanOrEqual b) => LessThan(a, b)
 
-      case Not(a LessThan b) => GreaterThanOrEqual(a, b)
-      case Not(a LessThanOrEqual b) => GreaterThan(a, b)
+    case Not(a LessThan b) => GreaterThanOrEqual(a, b)
+    case Not(a LessThanOrEqual b) => GreaterThan(a, b)
 
-      case Not(a Or b) => And(Not(a), Not(b))
-      case Not(a And b) => Or(Not(a), Not(b))
+    // SPARK-54881: simplify the NOTs pushed onto the new children before attaching the junction
+    // node; this avoids an extra rule iteration. No ruleId is passed: pruning on NOT alone could
+    // mark nodes ineffective for this rule while NOT-free descendants are still simplifiable.
+    case Not(a Or b) =>
+      And(Not(a), Not(b)).transformDownWithPruning(_.containsPattern(NOT)) {
+        actualExprTransformer
+      }
+    case Not(a And b) =>
+      Or(Not(a), Not(b)).transformDownWithPruning(_.containsPattern(NOT)) {
+        actualExprTransformer
+      }
 
-      case Not(Not(e)) => e
+    case Not(Not(e)) => e
 
-      case Not(IsNull(e)) => IsNotNull(e)
-      case Not(IsNotNull(e)) => IsNull(e)
-    }
+    case Not(IsNull(e)) => IsNotNull(e)
+    case Not(IsNotNull(e)) => IsNull(e)
   }
 }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
index 4cc2ee99284a..5a44119bf049 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
@@ -291,6 +291,37 @@ class BooleanSimplificationSuite extends PlanTest with ExpressionEvalHelper {
     checkCondition(Not(IsNull($"b")), IsNotNull($"b"))
   }
 
+  test("SPARK-54881: simplify Not(Expr) in single pass") {
+    def executeRuleOnce(exprToTest: Expression, optimizedExprExpected: Expression): Unit = {
+      val planAfterRuleApp = BooleanSimplification.apply(testRelation.where(exprToTest).analyze)
+      val expectedOptPlan = testRelation.where(optimizedExprExpected).analyze
+      comparePlans(expectedOptPlan, planAfterRuleApp)
+      // A single pass must already reach a fixed point: applying the rule again is a no-op.
+      comparePlans(expectedOptPlan, BooleanSimplification.apply(planAfterRuleApp))
+    }
+    // check simplify Not(a <= b OR a >= b) to (a > b AND a < b) in single pass
+    executeRuleOnce(
+      Not(($"a" <= $"b") || ($"a" >= $"b")), $"a" > $"b" && $"a" < $"b")
+
+    // check simplify Not((expr1 OR expr2) OR (expr3 AND expr4)) in single pass
+    executeRuleOnce(
+      Not(($"a" <= $"b" || $"c" > $"a" + 4) || ($"a" >= $"b" && $"c" < $"a")),
+      And(
+        And($"a" > $"b", $"c" <= $"a" + 4),
+        Or($"a" < $"b", $"c" >= $"a")
+      )
+    )
+
+    // check simplify Not((expr1 OR expr2) AND (expr3 OR expr4)) in single pass
+    executeRuleOnce(
+      Not(($"a" <= $"b" || $"c" > $"a" + 4) && ($"a" >= $"b" || $"c" < $"a")),
+      Or(
+        And($"a" > $"b", $"c" <= $"a" + 4),
+        And($"a" < $"b", $"c" >= $"a")
+      )
+    )
+  }
+
   protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
     val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation()).analyze
     val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation()).analyze)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to