This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch branch-3.0 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push: new ebac47b [SPARK-32280][SPARK-32372][SQL] ResolveReferences.dedupRight should only rewrite attributes for ancestor nodes of the conflict plan ebac47b is described below commit ebac47b96cb38d7cd5d73a3f238017a5f8d77d1a Author: yi.wu <yi...@databricks.com> AuthorDate: Thu Jul 23 14:24:47 2020 +0000 [SPARK-32280][SPARK-32372][SQL] ResolveReferences.dedupRight should only rewrite attributes for ancestor nodes of the conflict plan This PR refactors `ResolveReferences.dedupRight` to make sure it only rewrites attributes for ancestor nodes of the conflict plan. This is a bug fix. ```scala sql("SELECT name, avg(age) as avg_age FROM person GROUP BY name") .createOrReplaceTempView("person_a") sql("SELECT p1.name, p2.avg_age FROM person p1 JOIN person_a p2 ON p1.name = p2.name") .createOrReplaceTempView("person_b") sql("SELECT * FROM person_a UNION SELECT * FROM person_b") .createOrReplaceTempView("person_c") sql("SELECT p1.name, p2.avg_age FROM person_c p1 JOIN person_c p2 ON p1.name = p2.name").show() ``` When executing the above query, we'll hit the error: ```scala [info] Failed to analyze query: org.apache.spark.sql.AnalysisException: Resolved attribute(s) avg_age#231 missing from name#223,avg_age#218,id#232,age#234,name#233 in operator !Project [name#233, avg_age#231]. Attribute(s) with the same name appear in the operation: avg_age. Please check if the right attribute(s) are used.;; ... ``` The plan below is the problematic plan, which is the right plan of a `Join` operator. And, it has conflict plans compared to the left plan. In this problematic plan, the first `Aggregate` operator (the one under the first child of `Union`) becomes a conflict plan compared to the left one and has a rewrite attribute pair as `avg_age#218` -> `avg_age#231`. With the current `dedupRight` logic, we'll first replace this `Aggregate` with a new one, and then rewrite the attribute `avg_age# [...] 
```scala : : +- SubqueryAlias p2 +- SubqueryAlias person_c +- Distinct +- Union :- Project [name#233, avg_age#231] : +- SubqueryAlias person_a : +- Aggregate [name#233], [name#233, avg(cast(age#234 as bigint)) AS avg_age#231] : +- SubqueryAlias person : +- SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$Person, true])).id AS id#232, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$Person, true])).name, true, false) AS name#233, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$Person, true])).age AS age#234] : +- ExternalRDD [obj#165] +- Project [name#233 AS name#227, avg_age#231 AS avg_age#228] +- Project [name#233, avg_age#231] +- SubqueryAlias person_b +- !Project [name#233, avg_age#231] +- Join Inner, (name#233 = name#223) :- SubqueryAlias p1 : +- SubqueryAlias person : +- SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$Person, true])).id AS id#232, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$Person, true])).name, true, false) AS name#233, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$Person, true])).age AS age#234] : +- ExternalRDD [obj#165] +- SubqueryAlias p2 +- SubqueryAlias person_a +- Aggregate [name#223], [name#223, avg(cast(age#224 as bigint)) AS avg_age#218] +- SubqueryAlias person +- SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$Person, true])).id AS id#222, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$Person, true])).name, true, false) AS name#223, knownnotnull(assertnotnull(input[0, 
org.apache.spark.sql.test.SQLTestData$Person, true])).age AS age#224] +- ExternalRDD [obj#165] ``` Yes, users would no longer hit the error after this fix. Added test. Closes #29166 from Ngone51/impr-dedup. Authored-by: yi.wu <yi...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> (cherry picked from commit a8e3de36e7d543f1c7923886628ac3178f45f512) Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/analysis/Analyzer.scala | 63 ++++++++++++++++++---- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 22 ++++++++ 2 files changed, 75 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 68fe580..bd5a797 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1181,11 +1181,24 @@ class Analyzer( if findAliases(projectList).intersect(conflictingAttributes).nonEmpty => Seq((oldVersion, oldVersion.copy(projectList = newAliases(projectList)))) + // We don't need to search child plan recursively if the projectList of a Project + // is only composed of Alias and doesn't contain any conflicting attributes. + // Because, even if the child plan has some conflicting attributes, the attributes + // will be aliased to non-conflicting attributes by the Project at the end. + case _ @ Project(projectList, _) + if findAliases(projectList).size == projectList.size => + Nil + case oldVersion @ Aggregate(_, aggregateExpressions, _) if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => Seq((oldVersion, oldVersion.copy( aggregateExpressions = newAliases(aggregateExpressions)))) + // We don't search the child plan recursively for the same reason as the above Project. 
+ case _ @ Aggregate(_, aggregateExpressions, _) + if findAliases(aggregateExpressions).size == aggregateExpressions.size => + Nil + case oldVersion @ FlatMapGroupsInPandas(_, _, output, _) if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance())))) @@ -1226,20 +1239,50 @@ class Analyzer( if (conflictPlans.isEmpty) { right } else { - val attributeRewrites = AttributeMap(conflictPlans.flatMap { - case (oldRelation, newRelation) => oldRelation.output.zip(newRelation.output)}) - val conflictPlanMap = conflictPlans.toMap - // transformDown so that we can replace all the old Relations in one turn due to - // the reason that `conflictPlans` are also collected in pre-order. - right transformDown { - case r => conflictPlanMap.getOrElse(r, r) - } transformUp { - case other => other transformExpressions { + rewritePlan(right, conflictPlans.toMap)._1 + } + } + + private def rewritePlan(plan: LogicalPlan, conflictPlanMap: Map[LogicalPlan, LogicalPlan]) + : (LogicalPlan, Seq[(Attribute, Attribute)]) = { + if (conflictPlanMap.contains(plan)) { + // If the plan is the one that conflicts with the left one, we'd + // just replace it with the new plan and collect the rewrite + // attributes for the parent node. + val newRelation = conflictPlanMap(plan) + newRelation -> plan.output.zip(newRelation.output) + } else { + val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]() + val newPlan = plan.mapChildren { child => + // If not, we'd rewrite the child plan recursively until we find the + // conflict node or reach the leaf node. + val (newChild, childAttrMapping) = rewritePlan(child, conflictPlanMap) + attrMapping ++= childAttrMapping.filter { case (oldAttr, _) => + // `attrMapping` is not only used to replace the attributes of the current `plan`, + // but also to be propagated to the parent plans of the current `plan`. 
Therefore, + // the `oldAttr` must be part of either `plan.references` (so that it can be used to + // replace attributes of the current `plan`) or `plan.outputSet` (so that it can be + // used by those parent plans). + (plan.outputSet ++ plan.references).contains(oldAttr) + } + newChild + } + + if (attrMapping.isEmpty) { + newPlan -> attrMapping + } else { + assert(!attrMapping.groupBy(_._1.exprId) + .exists(_._2.map(_._2.exprId).distinct.length > 1), + "Found duplicate rewrite attributes") + val attributeRewrites = AttributeMap(attrMapping) + // Using attrMapping from the children plans to rewrite their parent node. + // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes. + newPlan.transformExpressions { case a: Attribute => dedupAttr(a, attributeRewrites) case s: SubqueryExpression => s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites)) - } + } -> attrMapping } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 093f2db..6fab47d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3467,6 +3467,28 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark |""".stripMargin), Row(1)) } } + + test("SPARK-32372: ResolveReferences.dedupRight should only rewrite attributes for ancestor " + + "plans of the conflict plan") { + sql("SELECT name, avg(age) as avg_age FROM person GROUP BY name") + .createOrReplaceTempView("person_a") + sql("SELECT p1.name, p2.avg_age FROM person p1 JOIN person_a p2 ON p1.name = p2.name") + .createOrReplaceTempView("person_b") + sql("SELECT * FROM person_a UNION SELECT * FROM person_b") + .createOrReplaceTempView("person_c") + checkAnswer( + sql("SELECT p1.name, p2.avg_age FROM person_c p1 JOIN person_c p2 ON p1.name = p2.name"), + Row("jim", 20.0) :: Row("mike", 30.0) :: 
Nil) + } + + test("SPARK-32280: Avoid duplicate rewrite attributes when there're multiple JOINs") { + sql("SELECT 1 AS id").createOrReplaceTempView("A") + sql("SELECT id, 'foo' AS kind FROM A").createOrReplaceTempView("B") + sql("SELECT l.id as id FROM B AS l LEFT SEMI JOIN B AS r ON l.kind = r.kind") + .createOrReplaceTempView("C") + checkAnswer(sql("SELECT 0 FROM ( SELECT * FROM B JOIN C USING (id)) " + + "JOIN ( SELECT * FROM B JOIN C USING (id)) USING (id)"), Row(0)) + } } case class Foo(bar: Option[String]) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org