This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch branch-2.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push: new 6ee0eb4 [SPARK-32280][SPARK-32372][2.4][SQL] ResolveReferences.dedupRight should only rewrite attributes for ancestor nodes of the conflict plan 6ee0eb4 is described below commit 6ee0eb40870c92889bc3c627d4b3178033a64a18 Author: yi.wu <yi...@databricks.com> AuthorDate: Fri Jul 24 04:26:22 2020 +0000 [SPARK-32280][SPARK-32372][2.4][SQL] ResolveReferences.dedupRight should only rewrite attributes for ancestor nodes of the conflict plan ### What changes were proposed in this pull request? This PR refactors `ResolveReferences.dedupRight` to make sure it only rewrites attributes for ancestor nodes of the conflict plan. ### Why are the changes needed? This is a bug fix. ```scala sql("SELECT name, avg(age) as avg_age FROM person GROUP BY name") .createOrReplaceTempView("person_a") sql("SELECT p1.name, p2.avg_age FROM person p1 JOIN person_a p2 ON p1.name = p2.name") .createOrReplaceTempView("person_b") sql("SELECT * FROM person_a UNION SELECT * FROM person_b") .createOrReplaceTempView("person_c") sql("SELECT p1.name, p2.avg_age FROM person_c p1 JOIN person_c p2 ON p1.name = p2.name").show() ``` When executing the above query, we'll hit the error: ```scala [info] Failed to analyze query: org.apache.spark.sql.AnalysisException: Resolved attribute(s) avg_age#231 missing from name#223,avg_age#218,id#232,age#234,name#233 in operator !Project [name#233, avg_age#231]. Attribute(s) with the same name appear in the operation: avg_age. Please check if the right attribute(s) are used.;; ... ``` The plan below is the problematic plan, which is the right plan of a `Join` operator and contains plans that conflict with the left plan. In this problematic plan, the first `Aggregate` operator (the one under the first child of `Union`) becomes a conflict plan compared to the left one and has the rewrite attribute pair `avg_age#218` -> `avg_age#231`. 
With the current `dedupRight` logic, we'll first replace this `Aggregate` with a new one, and then rewrites the attribute `avg_age# [...] ```scala : : +- SubqueryAlias p2 +- SubqueryAlias person_c +- Distinct +- Union :- Project [name#233, avg_age#231] : +- SubqueryAlias person_a : +- Aggregate [name#233], [name#233, avg(cast(age#234 as bigint)) AS avg_age#231] : +- SubqueryAlias person : +- SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$Person, true])).id AS id#232, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$Person, true])).name, true, false) AS name#233, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$Person, true])).age AS age#234] : +- ExternalRDD [obj#165] +- Project [name#233 AS name#227, avg_age#231 AS avg_age#228] +- Project [name#233, avg_age#231] +- SubqueryAlias person_b +- !Project [name#233, avg_age#231] +- Join Inner, (name#233 = name#223) :- SubqueryAlias p1 : +- SubqueryAlias person : +- SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$Person, true])).id AS id#232, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$Person, true])).name, true, false) AS name#233, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$Person, true])).age AS age#234] : +- ExternalRDD [obj#165] +- SubqueryAlias p2 +- SubqueryAlias person_a +- Aggregate [name#223], [name#223, avg(cast(age#224 as bigint)) AS avg_age#218] +- SubqueryAlias person +- SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$Person, true])).id AS id#222, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, 
org.apache.spark.sql.test.SQLTestData$Person, true])).name, true, false) AS name#223, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$Person, true])).age AS age#224] +- ExternalRDD [obj#165] ``` ### Does this PR introduce _any_ user-facing change? Yes, users would no longer hit the error after this fix. ### How was this patch tested? Added test. Closes #29208 from Ngone51/cherry-pick-spark-32372. Authored-by: yi.wu <yi...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/analysis/Analyzer.scala | 54 ++++++++++++++++++---- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 24 ++++++++++ 2 files changed, 68 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index afe7b4f..aaaf707 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -886,17 +886,51 @@ class Analyzer( */ right case Some((oldRelation, newRelation)) => - val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) - right transformUp { - case r if r == oldRelation => newRelation - } transformUp { - case other => other transformExpressions { - case a: Attribute => - dedupAttr(a, attributeRewrites) - case s: SubqueryExpression => - s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites)) - } + rewritePlan(right, Map(oldRelation -> newRelation))._1 + } + } + + private def rewritePlan(plan: LogicalPlan, conflictPlanMap: Map[LogicalPlan, LogicalPlan]) + : (LogicalPlan, Seq[(Attribute, Attribute)]) = { + if (conflictPlanMap.contains(plan)) { + // If the plan is the one that conflict the with left one, we'd + // just replace it with the new plan and collect the rewrite + // attributes for the parent node. 
+ val newRelation = conflictPlanMap(plan) + newRelation -> plan.output.zip(newRelation.output) + } else { + val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]() + val newPlan = plan.mapChildren { child => + // If not, we'd rewrite child plan recursively until we find the + // conflict node or reach the leaf node. + val (newChild, childAttrMapping) = rewritePlan(child, conflictPlanMap) + attrMapping ++= childAttrMapping.filter { case (oldAttr, _) => + // `attrMapping` is not only used to replace the attributes of the current `plan`, + // but also to be propagated to the parent plans of the current `plan`. Therefore, + // the `oldAttr` must be part of either `plan.references` (so that it can be used to + // replace attributes of the current `plan`) or `plan.outputSet` (so that it can be + // used by those parent plans). + (plan.outputSet ++ plan.references).contains(oldAttr) } + newChild + } + + if (attrMapping.isEmpty) { + newPlan -> attrMapping + } else { + assert(!attrMapping.groupBy(_._1.exprId) + .exists(_._2.map(_._2.exprId).distinct.length > 1), + "Found duplicate rewrite attributes") + val attributeRewrites = AttributeMap(attrMapping) + // Using attrMapping from the children plans to rewrite their parent node. + // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes. 
+ newPlan.transformExpressions { + case a: Attribute => + dedupAttr(a, attributeRewrites) + case s: SubqueryExpression => + s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites)) + } -> attrMapping + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index d0114f6..c424ef8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3078,6 +3078,30 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(1)) } } + + test("SPARK-32372: ResolveReferences.dedupRight should only rewrite attributes for ancestor " + + "plans of the conflict plan") { + sql("SELECT name, avg(age) as avg_age FROM person GROUP BY name") + .createOrReplaceTempView("person_a") + sql("SELECT p1.name, p2.avg_age FROM person p1 JOIN person_a p2 ON p1.name = p2.name") + .createOrReplaceTempView("person_b") + sql("SELECT * FROM person_a UNION SELECT * FROM person_b") + .createOrReplaceTempView("person_c") + checkAnswer( + sql("SELECT p1.name, p2.avg_age FROM person_c p1 JOIN person_c p2 ON p1.name = p2.name"), + Row("jim", 20.0) :: Row("mike", 30.0) :: Nil) + } + + test("SPARK-32280: Avoid duplicate rewrite attributes when there're multiple JOINs") { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + sql("SELECT 1 AS id").createOrReplaceTempView("A") + sql("SELECT id, 'foo' AS kind FROM A").createOrReplaceTempView("B") + sql("SELECT l.id as id FROM B AS l LEFT SEMI JOIN B AS r ON l.kind = r.kind") + .createOrReplaceTempView("C") + checkAnswer(sql("SELECT 0 FROM ( SELECT * FROM B JOIN C USING (id)) " + + "JOIN ( SELECT * FROM B JOIN C USING (id)) USING (id)"), Row(0)) + } + } } case class Foo(bar: Option[String]) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional 
commands, e-mail: commits-h...@spark.apache.org