wangyum commented on code in PR #40268:
URL: https://github.com/apache/spark/pull/40268#discussion_r1126249818


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala:
##########
@@ -138,56 +136,52 @@ object ConstantPropagation extends Rule[LogicalPlan] {
    *                    case of `WHERE e`, null result of expression `e` means 
the same as if it
    *                    resulted false
    * @return A tuple including:
-   *         1. Option[Expression]: optional changed condition after traversal
+   *         1. Expression: optional changed condition after traversal
    *         2. EqualityPredicates: propagated mapping of attribute => constant
    */
-  private def traverse(condition: Expression, replaceChildren: Boolean, 
nullIsFalse: Boolean)
-    : (Option[Expression], EqualityPredicates) =
+  private def traverse(
+      condition: Expression,
+      replaceChildren: Boolean,
+      nullIsFalse: Boolean): (Expression, EqualityPredicates) =
     condition match {
       case e @ EqualTo(left: AttributeReference, right: Literal)
         if safeToReplace(left, nullIsFalse) =>
-        (None, Seq(((left, right), e)))
+        e -> mutable.Map(left.canonicalized -> (right, e))
       case e @ EqualTo(left: Literal, right: AttributeReference)
         if safeToReplace(right, nullIsFalse) =>
-        (None, Seq(((right, left), e)))
+        e -> mutable.Map(right.canonicalized -> (left, e))
       case e @ EqualNullSafe(left: AttributeReference, right: Literal)
         if safeToReplace(left, nullIsFalse) =>
-        (None, Seq(((left, right), e)))
+        e -> mutable.Map(left.canonicalized -> (right, e))
       case e @ EqualNullSafe(left: Literal, right: AttributeReference)
         if safeToReplace(right, nullIsFalse) =>
-        (None, Seq(((right, left), e)))
-      case a: And =>
-        val (newLeft, equalityPredicatesLeft) =
-          traverse(a.left, replaceChildren = false, nullIsFalse)
+        e -> mutable.Map(right.canonicalized -> (left, e))
+      case a @ And(left, right) =>
+        val (newLeft, equalityPredicates) =
+          traverse(left, replaceChildren = false, nullIsFalse)
         val (newRight, equalityPredicatesRight) =
-          traverse(a.right, replaceChildren = false, nullIsFalse)
-        val equalityPredicates = equalityPredicatesLeft ++ 
equalityPredicatesRight
-        val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) {
-          Some(And(replaceConstants(newLeft.getOrElse(a.left), 
equalityPredicates),
-            replaceConstants(newRight.getOrElse(a.right), equalityPredicates)))
+          traverse(right, replaceChildren = false, nullIsFalse)
+        // We could recognize when conflicting constants are coming from the 
left and right sides
+        // and immediately shortcut the `And` expression to 
`Literal.FalseLiteral`, but that case is
+        // not so common and actually it is the job of `ConstantFolding` and 
`BooleanSimplification`
+        // rules to deal with those optimizations.
+        equalityPredicates ++= equalityPredicatesRight
+        val newAnd = a.withNewChildren(if (equalityPredicates.nonEmpty && 
replaceChildren) {
+          val replacedNewLeft = replaceConstants(newLeft, equalityPredicates)
+          val replacedNewRight = replaceConstants(newRight, equalityPredicates)
+          Seq(replacedNewLeft, replacedNewRight)
         } else {
-          if (newLeft.isDefined || newRight.isDefined) {
-            Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right)))
-          } else {
-            None
-          }
-        }
-        (newSelf, equalityPredicates)
+          Seq(newLeft, newRight)
+        })
+        newAnd -> equalityPredicates
       case o: Or =>
         // Ignore the EqualityPredicates from children since they are only 
propagated through And.
-        val (newLeft, _) = traverse(o.left, replaceChildren = true, 
nullIsFalse)
-        val (newRight, _) = traverse(o.right, replaceChildren = true, 
nullIsFalse)
-        val newSelf = if (newLeft.isDefined || newRight.isDefined) {
-          Some(Or(left = newLeft.getOrElse(o.left), right = 
newRight.getOrElse((o.right))))
-        } else {
-          None
-        }
-        (newSelf, Seq.empty)
+        o.mapChildren(traverse(_, replaceChildren = true, nullIsFalse)._1) -> 
mutable.Map.empty

Review Comment:
   ```scala
   val newOr = o.mapChildren(traverse(_, replaceChildren = true, nullIsFalse)._1)
   newOr -> AttributeMap.empty
   ```
   ?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to