ahshahid commented on code in PR #53658:
URL: https://github.com/apache/spark/pull/53658#discussion_r2666129524


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala:
##########
@@ -357,157 +357,168 @@ object OptimizeIn extends Rule[LogicalPlan] {
  * 4. Removes `Not` operator.
  */
 object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
-  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
-    _.containsAnyPattern(AND, OR, NOT), ruleId) {
-    case q: LogicalPlan => q.transformExpressionsUpWithPruning(
-      _.containsAnyPattern(AND, OR, NOT), ruleId) {
-      case TrueLiteral And e => e
-      case e And TrueLiteral => e
-      case FalseLiteral Or e => e
-      case e Or FalseLiteral => e
-
-      case FalseLiteral And _ => FalseLiteral
-      case _ And FalseLiteral => FalseLiteral
-      case TrueLiteral Or _ => TrueLiteral
-      case _ Or TrueLiteral => TrueLiteral
-
-      case a And b if Not(a).semanticEquals(b) =>
-        If(IsNull(a), Literal.create(null, a.dataType), FalseLiteral)
-      case a And b if a.semanticEquals(Not(b)) =>
-        If(IsNull(b), Literal.create(null, b.dataType), FalseLiteral)
-
-      case a Or b if Not(a).semanticEquals(b) =>
-        If(IsNull(a), Literal.create(null, a.dataType), TrueLiteral)
-      case a Or b if a.semanticEquals(Not(b)) =>
-        If(IsNull(b), Literal.create(null, b.dataType), TrueLiteral)
-
-      case a And b if a.semanticEquals(b) => a
-      case a Or b if a.semanticEquals(b) => a
-
-      // The following optimizations are applicable only when the operands are 
not nullable,
-      // since the three-value logic of AND and OR are different in NULL 
handling.
-      // See the chart:
-      // +---------+---------+---------+---------+
-      // | operand | operand |   OR    |   AND   |
-      // +---------+---------+---------+---------+
-      // | TRUE    | TRUE    | TRUE    | TRUE    |
-      // | TRUE    | FALSE   | TRUE    | FALSE   |
-      // | FALSE   | FALSE   | FALSE   | FALSE   |
-      // | UNKNOWN | TRUE    | TRUE    | UNKNOWN |
-      // | UNKNOWN | FALSE   | UNKNOWN | FALSE   |
-      // | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN |
-      // +---------+---------+---------+---------+
-
-      // (NULL And (NULL Or FALSE)) = NULL, but (NULL And FALSE) = FALSE. 
Thus, a can't be nullable.
-      case a And (b Or c) if !a.nullable && Not(a).semanticEquals(b) => And(a, 
c)
-      // (NULL And (FALSE Or NULL)) = NULL, but (NULL And FALSE) = FALSE. 
Thus, a can't be nullable.
-      case a And (b Or c) if !a.nullable && Not(a).semanticEquals(c) => And(a, 
b)
-      // ((NULL Or FALSE) And NULL) = NULL, but (FALSE And NULL) = FALSE. 
Thus, c can't be nullable.
-      case (a Or b) And c if !c.nullable && a.semanticEquals(Not(c)) => And(b, 
c)
-      // ((FALSE Or NULL) And NULL) = NULL, but (FALSE And NULL) = FALSE. 
Thus, c can't be nullable.
-      case (a Or b) And c if !c.nullable && b.semanticEquals(Not(c)) => And(a, 
c)
-
-      // (NULL Or (NULL And TRUE)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a 
can't be nullable.
-      case a Or (b And c) if !a.nullable && Not(a).semanticEquals(b) => Or(a, 
c)
-      // (NULL Or (TRUE And NULL)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a 
can't be nullable.
-      case a Or (b And c) if !a.nullable && Not(a).semanticEquals(c) => Or(a, 
b)
-      // ((NULL And TRUE) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c 
can't be nullable.
-      case (a And b) Or c if !c.nullable && a.semanticEquals(Not(c)) => Or(b, 
c)
-      // ((TRUE And NULL) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c 
can't be nullable.
-      case (a And b) Or c if !c.nullable && b.semanticEquals(Not(c)) => Or(a, 
c)
-
-      // Common factor elimination for conjunction
-      case and @ (left And right) =>
-        // 1. Split left and right to get the disjunctive predicates,
-        //    i.e. lhs = (a || b), rhs = (a || c)
-        // 2. Find the common predict between lhsSet and rhsSet, i.e. common = 
(a)
-        // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), 
rdiff = (c)
-        // 4. If common is non-empty, apply the formula to get the optimized 
predicate:
-        //    common || (ldiff && rdiff)
-        // 5. Else if common is empty, split left and right to get the 
conjunctive predicates.
-        //    for example lhs = (a && b), rhs = (a && c) => all = (a, b, a, 
c), distinct = (a, b, c)
-        //    optimized predicate: (a && b && c)
-        val lhs = splitDisjunctivePredicates(left)
-        val rhs = splitDisjunctivePredicates(right)
-        val common = lhs.filter(e => rhs.exists(e.semanticEquals))
-        if (common.nonEmpty) {
-          val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
-          val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
-          if (ldiff.isEmpty || rdiff.isEmpty) {
-            // (a || b || c || ...) && (a || b) => (a || b)
-            common.reduce(Or)
-          } else {
-            // (a || b || c || ...) && (a || b || d || ...) =>
-            // a || b || ((c || ...) && (d || ...))
-            (common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or)
-          }
+
+  val actualExprTransformer: PartialFunction[Expression, Expression] = {
+    case TrueLiteral And e => e
+    case e And TrueLiteral => e
+    case FalseLiteral Or e => e
+    case e Or FalseLiteral => e
+
+    case FalseLiteral And _ => FalseLiteral
+    case _ And FalseLiteral => FalseLiteral
+    case TrueLiteral Or _ => TrueLiteral
+    case _ Or TrueLiteral => TrueLiteral
+
+    case a And b if Not(a).semanticEquals(b) =>
+      If(IsNull(a), Literal.create(null, a.dataType), FalseLiteral)
+    case a And b if a.semanticEquals(Not(b)) =>
+      If(IsNull(b), Literal.create(null, b.dataType), FalseLiteral)
+
+    case a Or b if Not(a).semanticEquals(b) =>
+      If(IsNull(a), Literal.create(null, a.dataType), TrueLiteral)
+    case a Or b if a.semanticEquals(Not(b)) =>
+      If(IsNull(b), Literal.create(null, b.dataType), TrueLiteral)
+
+    case a And b if a.semanticEquals(b) => a
+    case a Or b if a.semanticEquals(b) => a
+
+    // The following optimizations are applicable only when the operands are 
not nullable,
+    // since the three-value logic of AND and OR are different in NULL 
handling.
+    // See the chart:
+    // +---------+---------+---------+---------+
+    // | operand | operand |   OR    |   AND   |
+    // +---------+---------+---------+---------+
+    // | TRUE    | TRUE    | TRUE    | TRUE    |
+    // | TRUE    | FALSE   | TRUE    | FALSE   |
+    // | FALSE   | FALSE   | FALSE   | FALSE   |
+    // | UNKNOWN | TRUE    | TRUE    | UNKNOWN |
+    // | UNKNOWN | FALSE   | UNKNOWN | FALSE   |
+    // | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN |
+    // +---------+---------+---------+---------+
+
+    // (NULL And (NULL Or FALSE)) = NULL, but (NULL And FALSE) = FALSE. Thus, 
a can't be nullable.
+    case a And (b Or c) if !a.nullable && Not(a).semanticEquals(b) => And(a, c)
+    // (NULL And (FALSE Or NULL)) = NULL, but (NULL And FALSE) = FALSE. Thus, 
a can't be nullable.
+    case a And (b Or c) if !a.nullable && Not(a).semanticEquals(c) => And(a, b)
+    // ((NULL Or FALSE) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, 
c can't be nullable.
+    case (a Or b) And c if !c.nullable && a.semanticEquals(Not(c)) => And(b, c)
+    // ((FALSE Or NULL) And NULL) = NULL, but (FALSE And NULL) = FALSE. Thus, 
c can't be nullable.
+    case (a Or b) And c if !c.nullable && b.semanticEquals(Not(c)) => And(a, c)
+
+    // (NULL Or (NULL And TRUE)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a 
can't be nullable.
+    case a Or (b And c) if !a.nullable && Not(a).semanticEquals(b) => Or(a, c)
+    // (NULL Or (TRUE And NULL)) = NULL, but (NULL Or TRUE) = TRUE. Thus, a 
can't be nullable.
+    case a Or (b And c) if !a.nullable && Not(a).semanticEquals(c) => Or(a, b)
+    // ((NULL And TRUE) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c 
can't be nullable.
+    case (a And b) Or c if !c.nullable && a.semanticEquals(Not(c)) => Or(b, c)
+    // ((TRUE And NULL) Or NULL) = NULL, but (TRUE Or NULL) = TRUE. Thus, c 
can't be nullable.
+    case (a And b) Or c if !c.nullable && b.semanticEquals(Not(c)) => Or(a, c)
+
+    // Common factor elimination for conjunction
+    case and @ (left And right) =>
+      // 1. Split left and right to get the disjunctive predicates,
+      //    i.e. lhs = (a || b), rhs = (a || c)
+      // 2. Find the common predict between lhsSet and rhsSet, i.e. common = 
(a)
+      // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), 
rdiff = (c)
+      // 4. If common is non-empty, apply the formula to get the optimized 
predicate:
+      //    common || (ldiff && rdiff)
+      // 5. Else if common is empty, split left and right to get the 
conjunctive predicates.
+      //    for example lhs = (a && b), rhs = (a && c) => all = (a, b, a, c), 
distinct = (a, b, c)
+      //    optimized predicate: (a && b && c)
+      val lhs = splitDisjunctivePredicates(left)
+      val rhs = splitDisjunctivePredicates(right)
+      val common = lhs.filter(e => rhs.exists(e.semanticEquals))
+      if (common.nonEmpty) {
+        val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
+        val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
+        if (ldiff.isEmpty || rdiff.isEmpty) {
+          // (a || b || c || ...) && (a || b) => (a || b)
+          common.reduce(Or)
         } else {
-          // No common factors from disjunctive predicates, reduce common 
factor from conjunction
-          val all = splitConjunctivePredicates(left) ++ 
splitConjunctivePredicates(right)
-          val distinct = ExpressionSet(all)
-          if (all.size == distinct.size) {
-            // No common factors, return the original predicate
-            and
-          } else {
-            // (a && b) && a && (a && c) => a && b && c
-            buildBalancedPredicate(distinct.toSeq, And)
-          }
+          // (a || b || c || ...) && (a || b || d || ...) =>
+          // a || b || ((c || ...) && (d || ...))
+          (common :+ And(ldiff.reduce(Or), rdiff.reduce(Or))).reduce(Or)
+        }
+      } else {
+        // No common factors from disjunctive predicates, reduce common factor 
from conjunction
+        val all = splitConjunctivePredicates(left) ++ 
splitConjunctivePredicates(right)
+        val distinct = ExpressionSet(all)
+        if (all.size == distinct.size) {
+          // No common factors, return the original predicate
+          and
+        } else {
+          // (a && b) && a && (a && c) => a && b && c
+          buildBalancedPredicate(distinct.toSeq, And)
         }
+      }
 
-      // Common factor elimination for disjunction
-      case or @ (left Or right) =>
-        // 1. Split left and right to get the conjunctive predicates,
-        //    i.e.  lhs = (a && b), rhs = (a && c)
-        // 2. Find the common predict between lhsSet and rhsSet, i.e. common = 
(a)
-        // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), 
rdiff = (c)
-        // 4. If common is non-empty, apply the formula to get the optimized 
predicate:
-        //    common && (ldiff || rdiff)
-        // 5. Else if common is empty, split left and right to get the 
conjunctive predicates.
-        // for example lhs = (a || b), rhs = (a || c) => all = (a, b, a, c), 
distinct = (a, b, c)
-        // optimized predicate: (a || b || c)
-        val lhs = splitConjunctivePredicates(left)
-        val rhs = splitConjunctivePredicates(right)
-        val common = lhs.filter(e => rhs.exists(e.semanticEquals))
-        if (common.nonEmpty) {
-          val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
-          val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
-          if (ldiff.isEmpty || rdiff.isEmpty) {
-            // (a && b) || (a && b && c && ...) => a && b
-            common.reduce(And)
-          } else {
-            // (a && b && c && ...) || (a && b && d && ...) =>
-            // a && b && ((c && ...) || (d && ...))
-            (common :+ Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And)
-          }
+    // Common factor elimination for disjunction
+    case or @ (left Or right) =>
+      // 1. Split left and right to get the conjunctive predicates,
+      //    i.e.  lhs = (a && b), rhs = (a && c)
+      // 2. Find the common predict between lhsSet and rhsSet, i.e. common = 
(a)
+      // 3. Remove common predict from lhsSet and rhsSet, i.e. ldiff = (b), 
rdiff = (c)
+      // 4. If common is non-empty, apply the formula to get the optimized 
predicate:
+      //    common && (ldiff || rdiff)
+      // 5. Else if common is empty, split left and right to get the 
conjunctive predicates.
+      // for example lhs = (a || b), rhs = (a || c) => all = (a, b, a, c), 
distinct = (a, b, c)
+      // optimized predicate: (a || b || c)
+      val lhs = splitConjunctivePredicates(left)
+      val rhs = splitConjunctivePredicates(right)
+      val common = lhs.filter(e => rhs.exists(e.semanticEquals))
+      if (common.nonEmpty) {
+        val ldiff = lhs.filterNot(e => common.exists(e.semanticEquals))
+        val rdiff = rhs.filterNot(e => common.exists(e.semanticEquals))
+        if (ldiff.isEmpty || rdiff.isEmpty) {
+          // (a && b) || (a && b && c && ...) => a && b
+          common.reduce(And)
         } else {
-          // No common factors in conjunctive predicates, reduce common factor 
from disjunction
-          val all = splitDisjunctivePredicates(left) ++ 
splitDisjunctivePredicates(right)
-          val distinct = ExpressionSet(all)
-          if (all.size == distinct.size) {
-            // No common factors, return the original predicate
-            or
-          } else {
-            // (a || b) || a || (a || c) => a || b || c
-            buildBalancedPredicate(distinct.toSeq, Or)
-          }
+          // (a && b && c && ...) || (a && b && d && ...) =>
+          // a && b && ((c && ...) || (d && ...))
+          (common :+ Or(ldiff.reduce(And), rdiff.reduce(And))).reduce(And)
         }
+      } else {
+        // No common factors in conjunctive predicates, reduce common factor 
from disjunction
+        val all = splitDisjunctivePredicates(left) ++ 
splitDisjunctivePredicates(right)
+        val distinct = ExpressionSet(all)
+        if (all.size == distinct.size) {
+          // No common factors, return the original predicate
+          or
+        } else {
+          // (a || b) || a || (a || c) => a || b || c
+          buildBalancedPredicate(distinct.toSeq, Or)
+        }
+      }
+
+    case Not(TrueLiteral) => FalseLiteral
+    case Not(FalseLiteral) => TrueLiteral
 
-      case Not(TrueLiteral) => FalseLiteral
-      case Not(FalseLiteral) => TrueLiteral
+    case Not(a GreaterThan b) => LessThanOrEqual(a, b)
+    case Not(a GreaterThanOrEqual b) => LessThan(a, b)
 
-      case Not(a GreaterThan b) => LessThanOrEqual(a, b)
-      case Not(a GreaterThanOrEqual b) => LessThan(a, b)
+    case Not(a LessThan b) => GreaterThanOrEqual(a, b)
+    case Not(a LessThanOrEqual b) => GreaterThan(a, b)
 
-      case Not(a LessThan b) => GreaterThanOrEqual(a, b)
-      case Not(a LessThanOrEqual b) => GreaterThan(a, b)
+    case Not(a Or b) =>
+      And(Not(a), Not(b)).transformDownWithPruning(_.containsPattern(NOT), 
ruleId) {

Review Comment:
   > Is this safe? I mean, before this PR the simplification logic of 
`actualExprTransformer` was called with `transformUp...`, but now you call it 
with `transformDown...` (please note that a `Not` node can be deep down in `a` 
or `b`). Is there any reason why we invoke the logic with `transformUp` or 
could the whole rule use `transformDown` on expression trees?
   
   I believe it's safe.
   If the original logic were changed to use transform down instead of
   transform up, this bug would be fixed, but other cases, such as the one
   mentioned in the Constant folding suite, would break idempotency.
   To handle both cases, both transform up and transform down are needed,
   as in the PR. This reasoning is also given in the initial PR
   description.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to