peter-toth commented on a change in pull request #29771:
URL: https://github.com/apache/spark/pull/29771#discussion_r490413007



##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
##########
@@ -624,67 +624,89 @@ object NullPropagation extends Rule[LogicalPlan] {
  */
 object FoldablePropagation extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = {
-    var foldableMap = AttributeMap(plan.flatMap {
-      case Project(projectList, _) => projectList.collect {
-        case a: Alias if a.child.foldable => (a.toAttribute, a)
-      }
-      case _ => Nil
-    })
-    val replaceFoldable: PartialFunction[Expression, Expression] = {
-      case a: AttributeReference if foldableMap.contains(a) => foldableMap(a)
+    CleanupAliases(propagateFoldables(plan)._1)
+  }
+
+  private def propagateFoldables(plan: LogicalPlan): (LogicalPlan, 
AttributeMap[Alias]) = {
+    plan match {
+      case p: Project =>
+        val (newChild, foldableMap) = propagateFoldables(p.child)
+        val newProject =
+          
replaceFoldable(p.withNewChildren(Seq(newChild)).asInstanceOf[Project], 
foldableMap)
+        val newFoldableMap = AttributeMap(newProject.projectList.collect {
+          case a: Alias if a.child.foldable => (a.toAttribute, a)
+        })
+        (newProject, newFoldableMap)
+
+      // We can not replace the attributes in `Expand.output`. If there are 
other non-leaf
+      // operators that have the `output` field, we should put them here too.
+      case e: Expand =>
+        val (newChild, foldableMap) = propagateFoldables(e.child)
+        val expandWithNewChildren = 
e.withNewChildren(Seq(newChild)).asInstanceOf[Expand]
+        val newExpand = if (foldableMap.isEmpty) {
+          expandWithNewChildren
+        } else {
+          val newProjections = 
expandWithNewChildren.projections.map(_.map(_.transform {
+            case a: AttributeReference if foldableMap.contains(a) => 
foldableMap(a)
+          }))
+          if (newProjections == expandWithNewChildren.projections) {
+            expandWithNewChildren
+          } else {
+            expandWithNewChildren.copy(projections = newProjections)
+          }
+        }
+        (newExpand, foldableMap)
+
+      case u: UnaryNode if canPropagateFoldables(u) =>
+        val (newChild, foldableMap) = propagateFoldables(u.child)
+        val newU = replaceFoldable(u.withNewChildren(Seq(newChild)), 
foldableMap)
+        (newU, foldableMap)
+
+      // Join derives the output attributes from its child while they are 
actually not the
+      // same attributes. For example, the output of outer join is not always 
picked from its
+      // children, but can also be null. We should exclude these miss-derived 
attributes when
+      // propagating the foldable expressions.
+      // TODO(cloud-fan): It seems more reasonable to use new attributes as 
the output attributes
+      // of outer join.
+      case j: Join =>
+        val (newChildren, foldableMaps) = 
j.children.map(propagateFoldables).unzip
+        val foldableMap = AttributeMap(
+          foldableMaps.foldLeft(Iterable.empty[(Attribute, Alias)])(_ ++ 
_.baseMap.values).toSeq)
+        val newJoin =
+          replaceFoldable(j.withNewChildren(newChildren).asInstanceOf[Join], 
foldableMap)
+        val missDerivedAttrsSet: AttributeSet = AttributeSet(newJoin.joinType 
match {
+          case _: InnerLike | LeftExistence(_) => Nil
+          case LeftOuter => newJoin.right.output
+          case RightOuter => newJoin.left.output
+          case FullOuter => newJoin.left.output ++ newJoin.right.output
+        })
+        val newFoldableMap = AttributeMap(foldableMap.baseMap.values.filterNot 
{
+          case (attr, _) => missDerivedAttrsSet.contains(attr)
+        }.toSeq)
+        (newJoin, newFoldableMap)
+
+      // For other plans, they are not safe to apply foldable propagation, and 
they should not
+      // propagate foldable expressions from children.
+      case o =>
+        val newOther = o.mapChildren(propagateFoldables(_)._1)
+        (newOther, AttributeMap.empty)
     }
+  }
 
+  private def replaceFoldable(plan: LogicalPlan, foldableMap: 
AttributeMap[Alias]): plan.type = {
     if (foldableMap.isEmpty) {
       plan
     } else {
-      CleanupAliases(plan.transformUp {
-        // We can only propagate foldables for a subset of unary nodes.
-        case u: UnaryNode if foldableMap.nonEmpty && canPropagateFoldables(u) 
=>
-          u.transformExpressions(replaceFoldable)
-
-        // Join derives the output attributes from its child while they are 
actually not the
-        // same attributes. For example, the output of outer join is not 
always picked from its
-        // children, but can also be null. We should exclude these 
miss-derived attributes when
-        // propagating the foldable expressions.
-        // TODO(cloud-fan): It seems more reasonable to use new attributes as 
the output attributes
-        // of outer join.
-        case j @ Join(left, right, joinType, _, _) if foldableMap.nonEmpty =>
-          val newJoin = j.transformExpressions(replaceFoldable)
-          val missDerivedAttrsSet: AttributeSet = AttributeSet(joinType match {
-            case _: InnerLike | LeftExistence(_) => Nil
-            case LeftOuter => right.output
-            case RightOuter => left.output
-            case FullOuter => left.output ++ right.output
-          })
-          foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot {
-            case (attr, _) => missDerivedAttrsSet.contains(attr)
-          }.toSeq)
-          newJoin
-
-        // We can not replace the attributes in `Expand.output`. If there are 
other non-leaf
-        // operators that have the `output` field, we should put them here too.
-        case expand: Expand if foldableMap.nonEmpty =>
-          expand.copy(projections = expand.projections.map { projection =>
-            projection.map(_.transform(replaceFoldable))
-          })
-
-        // For other plans, they are not safe to apply foldable propagation, 
and they should not
-        // propagate foldable expressions from children.
-        case other if foldableMap.nonEmpty =>
-          val childrenOutputSet = 
AttributeSet(other.children.flatMap(_.output))
-          foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot {
-            case (attr, _) => childrenOutputSet.contains(attr)
-          }.toSeq)
-          other
-      })
+      plan transformExpressions {
+        case a: AttributeReference if foldableMap.contains(a) => foldableMap(a)
+      }
     }
   }
 
   /**
    * List of all [[UnaryNode]]s which allow foldable propagation.
    */
   private def canPropagateFoldables(u: UnaryNode): Boolean = u match {
-    case _: Project => true

Review comment:
       We no longer need the `case _: Project => true` entry in 
`canPropagateFoldables`, because `propagateFoldables` now handles `Project` 
(https://github.com/apache/spark/pull/29771/files#diff-a1acb054bc8888376603ef510e6d0ee0R632)
 in its own case, which is matched before the 
`case u: UnaryNode if canPropagateFoldables(u)` branch 
(https://github.com/apache/spark/pull/29771/files#diff-a1acb054bc8888376603ef510e6d0ee0R660).




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to