wangyum commented on code in PR #40268:
URL: https://github.com/apache/spark/pull/40268#discussion_r1151346831


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala:
##########
@@ -113,15 +114,13 @@ object ConstantPropagation extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
     _.containsAllPatterns(LITERAL, FILTER), ruleId) {
     case f: Filter =>
-      val (newCondition, _) = traverse(f.condition, replaceChildren = true, nullIsFalse = true)
-      if (newCondition.isDefined) {
-        f.copy(condition = newCondition.get)
-      } else {
-        f
-      }
+      f.mapExpressions(traverse(_, nullIsFalse = true, None))

Review Comment:
   This is our internal change:
   ```scala
   object ConstantPropagation extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
       _.containsAllPatterns(LITERAL, FILTER), ruleId) {
       case f: Filter =>
         // For `WHERE e`, a null result of `e` means the same as false, so the root condition is
         // traversed in a `nullIsFalse` context.
         val (newCondition, _) = traverse(f.condition, replaceChildren = true, nullIsFalse = true)
         newCondition.map(c => f.copy(condition = c)).getOrElse(f)
     }

     /**
      * Traverse a condition as a tree and replace attributes with constant values.
      * - On matching [[And]], recursively traverse each child and get propagated mappings.
      *   If the current node is not a child of another [[And]], replace all occurrences of the
      *   attributes with the corresponding constant values.
      * - If a child of [[And]] is [[EqualTo]] or [[EqualNullSafe]], propagate the mapping
      *   of attribute => constant.
      * - On matching [[Or]] or [[Not]], recursively traverse each child, propagate empty mapping.
      * - Otherwise, stop traversal and propagate empty mapping.
      * @param condition condition to be traversed
      * @param replaceChildren whether to replace attributes with constant values in children
      * @param nullIsFalse whether a boolean expression result can be considered to be false e.g.
      *                    in the case of `WHERE e`, null result of expression `e` means the same
      *                    as if it resulted false
      * @return A tuple including:
      *         1. Option[Expression]: optional changed condition after traversal
      *         2. EqualityPredicates: propagated mapping of attribute => constant
      */
     private def traverse(condition: Expression, replaceChildren: Boolean, nullIsFalse: Boolean)
       : (Option[Expression], AttributeMap[(Literal, BinaryComparison)]) =
       condition match {
         case e @ EqualTo(left: AttributeReference, right: Literal)
           if safeToReplaceEqualTo(left, right, nullIsFalse) =>
           (None, AttributeMap(Map(left -> (right, e))))
         case e @ EqualTo(left: Literal, right: AttributeReference)
           if safeToReplaceEqualTo(right, left, nullIsFalse) =>
           (None, AttributeMap(Map(right -> (left, e))))
         case e @ EqualNullSafe(left: AttributeReference, right: Literal)
           if safeToReplace(left, nullIsFalse) =>
           (None, AttributeMap(Map(left -> (right, e))))
         case e @ EqualNullSafe(left: Literal, right: AttributeReference)
           if safeToReplace(right, nullIsFalse) =>
           (None, AttributeMap(Map(right -> (left, e))))
         case a: And =>
           val (newLeft, equalityPredicatesLeft) =
             traverse(a.left, replaceChildren = false, nullIsFalse)
           val (newRight, equalityPredicatesRight) =
             traverse(a.right, replaceChildren = false, nullIsFalse)
           // On a duplicate attribute the right-hand mapping wins (`++` semantics).
           val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight
           val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) {
             Some(And(replaceConstants(newLeft.getOrElse(a.left), equalityPredicates),
               replaceConstants(newRight.getOrElse(a.right), equalityPredicates)))
           } else {
             if (newLeft.isDefined || newRight.isDefined) {
               Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right)))
             } else {
               None
             }
           }
           (newSelf, equalityPredicates)
         case o: Or =>
           // Ignore the EqualityPredicates from children since they are only propagated through And.
           val (newLeft, _) = traverse(o.left, replaceChildren = true, nullIsFalse)
           val (newRight, _) = traverse(o.right, replaceChildren = true, nullIsFalse)
           val newSelf = if (newLeft.isDefined || newRight.isDefined) {
             Some(Or(left = newLeft.getOrElse(o.left), right = newRight.getOrElse(o.right)))
           } else {
             None
           }
           (newSelf, AttributeMap.empty)
         case n: Not =>
           // Ignore the EqualityPredicates from children since they are only propagated through And.
           // NOT flips the meaning of a null result, so the child is no longer `nullIsFalse`.
           val (newChild, _) = traverse(n.child, replaceChildren = true, nullIsFalse = false)
           (newChild.map(Not), AttributeMap.empty)
         case _ => (None, AttributeMap.empty)
       }

     // We need to take into account if an attribute is nullable and the context of the conjunctive
     // expression. E.g. `SELECT * FROM t WHERE NOT(c = 1 AND c + 1 = 1)` where attribute `c` can be
     // substituted into `1 + 1 = 1` if 'c' isn't nullable. If 'c' is nullable then the enclosing
     // NOT prevents us to do the substitution as NOT flips the context (`nullIsFalse`) of what a
     // null result of the enclosed expression means.
     private def safeToReplace(ar: AttributeReference, nullIsFalse: Boolean) =
       !ar.nullable || nullIsFalse

     // `attr = lit` evaluates to null whenever `lit` is a null literal, even if `attr` itself is not
     // nullable (unlike `attr <=> lit`, which never evaluates to null). So outside of a
     // `nullIsFalse` context the literal must be non-null as well. E.g. in
     // `NOT(c = null AND c > 0)` with non-nullable `c = -1`, substituting `c` would turn
     // `NOT(null AND false)` (true) into `NOT(null AND null)` (null) and wrongly drop the row.
     private def safeToReplaceEqualTo(
         ar: AttributeReference,
         lit: Literal,
         nullIsFalse: Boolean): Boolean =
       safeToReplace(ar, nullIsFalse) && (nullIsFalse || !lit.nullable)

     // Substitutes the propagated constants for their attributes inside every binary comparison of
     // `condition`, except the equality predicates that established the mappings themselves.
     private def replaceConstants(
         condition: Expression,
         equalityPredicates: AttributeMap[(Literal, BinaryComparison)]): Expression = {
       val constantsMap = AttributeMap(equalityPredicates.map { case (attr, (lit, _)) => attr -> lit })
       val predicates = equalityPredicates.values.map(_._2).toSet
       condition transform {
         case b: BinaryComparison if !predicates.contains(b) => b transform {
           case a: AttributeReference => constantsMap.getOrElse(a, a)
         }
       }
     }
   }
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to