viirya commented on a change in pull request #29166:
URL: https://github.com/apache/spark/pull/29166#discussion_r457874203



##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
##########
@@ -1237,20 +1249,44 @@ class Analyzer(
       if (conflictPlans.isEmpty) {
         right
       } else {
-        val attributeRewrites = AttributeMap(conflictPlans.flatMap {
-          case (oldRelation, newRelation) => 
oldRelation.output.zip(newRelation.output)})
-        val conflictPlanMap = conflictPlans.toMap
-        // transformDown so that we can replace all the old Relations in one 
turn due to
-        // the reason that `conflictPlans` are also collected in pre-order.
-        right transformDown {
-          case r => conflictPlanMap.getOrElse(r, r)
-        } transformUp {
-          case other => other transformExpressions {
+        rewritePlan(right, conflictPlans.toMap)._1
+      }
+    }
+
+    private def rewritePlan(plan: LogicalPlan, conflictPlanMap: 
Map[LogicalPlan, LogicalPlan])
+      : (LogicalPlan, mutable.ArrayBuffer[(Attribute, Attribute)]) = {
+      val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
+      if (conflictPlanMap.contains(plan)) {
+        // If the plan is the one that conflict the with left one, we'd
+        // just replace it with the new plan and collect the rewrite
+        // attributes for the parent node.
+        val newRelation = conflictPlanMap(plan)
+        attrMapping ++= plan.output.zip(newRelation.output)
+        newRelation -> attrMapping
+      } else {
+        var newPlan = plan.mapChildren { child =>
+          // If not, we'd rewrite child plan recursively until we find the
+          // conflict node or reach the leaf node.
+          val (newChild, childAttrMapping) = rewritePlan(child, 
conflictPlanMap)
+          attrMapping ++= childAttrMapping
+          newChild
+        }
+
+        if (attrMapping.isEmpty) {
+          newPlan -> attrMapping
+        } else {
+          assert(!attrMapping.groupBy(_._1.exprId)
+            .exists(_._2.map(_._2.exprId).distinct.length > 1),
+            "Found duplicate rewrite attributes")
+          val attributeRewrites = AttributeMap(attrMapping)
+          // rewrite the attributes of parent node
+          newPlan = newPlan.transformExpressions {

Review comment:
       Oh, I see. This looks clearer. +1 for this change.

##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
##########
@@ -1192,11 +1192,23 @@ class Analyzer(
             if 
findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
           Seq((oldVersion, oldVersion.copy(projectList = 
newAliases(projectList))))
 
+        // We don't need to search child plan recursively if the projectList 
of a Project
+        // is only composed of Alias and doesn't contain any conflicting 
attributes.
+        // Because, even if the child plan has some conflicting attributes, 
the attributes
+        // will be aliased to non-conflicting attributes by the Project at the 
end.
+        case _ @ Project(projectList, _)
+          if findAliases(projectList).size == projectList.size =>
+          Nil

Review comment:
       Don't we need to put this before the previous `Project` pattern?

##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
##########
@@ -1192,11 +1192,23 @@ class Analyzer(
             if 
findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
           Seq((oldVersion, oldVersion.copy(projectList = 
newAliases(projectList))))
 
+        // We don't need to search child plan recursively if the projectList 
of a Project
+        // is only composed of Alias and doesn't contain any conflicting 
attributes.
+        // Because, even if the child plan has some conflicting attributes, 
the attributes
+        // will be aliased to non-conflicting attributes by the Project at the 
end.
+        case _ @ Project(projectList, _)
+          if findAliases(projectList).size == projectList.size =>
+          Nil
+
         case oldVersion @ Aggregate(_, aggregateExpressions, _)
             if 
findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
           Seq((oldVersion, oldVersion.copy(
             aggregateExpressions = newAliases(aggregateExpressions))))
 
+        case _ @ Aggregate(_, aggregateExpressions, _)

Review comment:
       Is this for the same reason as above? If so, could you add a short comment here too?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to