This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push:
new 6ee0eb4 [SPARK-32280][SPARK-32372][2.4][SQL]
ResolveReferences.dedupRight should only rewrite attributes for ancestor nodes
of the conflict plan
6ee0eb4 is described below
commit 6ee0eb40870c92889bc3c627d4b3178033a64a18
Author: yi.wu <[email protected]>
AuthorDate: Fri Jul 24 04:26:22 2020 +0000
[SPARK-32280][SPARK-32372][2.4][SQL] ResolveReferences.dedupRight should
only rewrite attributes for ancestor nodes of the conflict plan
### What changes were proposed in this pull request?
This PR refactors `ResolveReferences.dedupRight` to make sure it only
rewrites attributes for ancestor nodes of the conflict plan.
### Why are the changes needed?
This is a bug fix.
```scala
sql("SELECT name, avg(age) as avg_age FROM person GROUP BY name")
.createOrReplaceTempView("person_a")
sql("SELECT p1.name, p2.avg_age FROM person p1 JOIN person_a p2 ON p1.name
= p2.name")
.createOrReplaceTempView("person_b")
sql("SELECT * FROM person_a UNION SELECT * FROM person_b")
.createOrReplaceTempView("person_c")
sql("SELECT p1.name, p2.avg_age FROM person_c p1 JOIN person_c p2 ON
p1.name = p2.name").show()
```
When executing the above query, we'll hit the error:
```scala
[info] Failed to analyze query: org.apache.spark.sql.AnalysisException:
Resolved attribute(s) avg_age#231 missing from
name#223,avg_age#218,id#232,age#234,name#233 in operator !Project [name#233,
avg_age#231]. Attribute(s) with the same name appear in the operation: avg_age.
Please check if the right attribute(s) are used.;;
...
```
The plan below is the problematic plan, which is the right plan of a `Join`
operator. It has conflict plans compared to the left plan. In this
problematic plan, the first `Aggregate` operator (the one under the first child
of `Union`) becomes a conflict plan compared to the left one and has a rewrite
attribute pair `avg_age#218` -> `avg_age#231`. With the current
`dedupRight` logic, we'll first replace this `Aggregate` with a new one, and
then rewrite the attribute `avg_age# [...]
```scala
:
:
+- SubqueryAlias p2
+- SubqueryAlias person_c
+- Distinct
+- Union
:- Project [name#233, avg_age#231]
: +- SubqueryAlias person_a
: +- Aggregate [name#233], [name#233, avg(cast(age#234 as
bigint)) AS avg_age#231]
: +- SubqueryAlias person
: +- SerializeFromObject
[knownnotnull(assertnotnull(input[0,
org.apache.spark.sql.test.SQLTestData$Person, true])).id AS id#232,
staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType,
fromString, knownnotnull(assertnotnull(input[0,
org.apache.spark.sql.test.SQLTestData$Person, true])).name, true, false) AS
name#233, knownnotnull(assertnotnull(input[0,
org.apache.spark.sql.test.SQLTestData$Person, true])).age AS age#234]
: +- ExternalRDD [obj#165]
+- Project [name#233 AS name#227, avg_age#231 AS avg_age#228]
+- Project [name#233, avg_age#231]
+- SubqueryAlias person_b
+- !Project [name#233, avg_age#231]
+- Join Inner, (name#233 = name#223)
:- SubqueryAlias p1
: +- SubqueryAlias person
: +- SerializeFromObject
[knownnotnull(assertnotnull(input[0,
org.apache.spark.sql.test.SQLTestData$Person, true])).id AS id#232,
staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType,
fromString, knownnotnull(assertnotnull(input[0,
org.apache.spark.sql.test.SQLTestData$Person, true])).name, true, false) AS
name#233, knownnotnull(assertnotnull(input[0,
org.apache.spark.sql.test.SQLTestData$Person, true])).age AS age#234]
: +- ExternalRDD [obj#165]
+- SubqueryAlias p2
+- SubqueryAlias person_a
+- Aggregate [name#223], [name#223,
avg(cast(age#224 as bigint)) AS avg_age#218]
+- SubqueryAlias person
+- SerializeFromObject
[knownnotnull(assertnotnull(input[0,
org.apache.spark.sql.test.SQLTestData$Person, true])).id AS id#222,
staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType,
fromString, knownnotnull(assertnotnull(input[0,
org.apache.spark.sql.test.SQLTestData$Person, true])).name, true, false) AS
name#223, knownnotnull(assertnotnull(input[0,
org.apache.spark.sql.test.SQLTestData$Person, true])).age AS age#224]
+- ExternalRDD [obj#165]
```
### Does this PR introduce _any_ user-facing change?
Yes, users would no longer hit the error after this fix.
### How was this patch tested?
Added tests.
Closes #29208 from Ngone51/cherry-pick-spark-32372.
Authored-by: yi.wu <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/catalyst/analysis/Analyzer.scala | 54 ++++++++++++++++++----
.../scala/org/apache/spark/sql/SQLQuerySuite.scala | 24 ++++++++++
2 files changed, 68 insertions(+), 10 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index afe7b4f..aaaf707 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -886,17 +886,51 @@ class Analyzer(
*/
right
case Some((oldRelation, newRelation)) =>
- val attributeRewrites =
AttributeMap(oldRelation.output.zip(newRelation.output))
- right transformUp {
- case r if r == oldRelation => newRelation
- } transformUp {
- case other => other transformExpressions {
- case a: Attribute =>
- dedupAttr(a, attributeRewrites)
- case s: SubqueryExpression =>
- s.withNewPlan(dedupOuterReferencesInSubquery(s.plan,
attributeRewrites))
- }
+ rewritePlan(right, Map(oldRelation -> newRelation))._1
+ }
+ }
+
+ private def rewritePlan(plan: LogicalPlan, conflictPlanMap:
Map[LogicalPlan, LogicalPlan])
+ : (LogicalPlan, Seq[(Attribute, Attribute)]) = {
+ if (conflictPlanMap.contains(plan)) {
+ // If the plan is the one that conflicts with the left one, we'd
+ // just replace it with the new plan and collect the rewrite
+ // attributes for the parent node.
+ val newRelation = conflictPlanMap(plan)
+ newRelation -> plan.output.zip(newRelation.output)
+ } else {
+ val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
+ val newPlan = plan.mapChildren { child =>
+ // If not, we'd rewrite the child plan recursively until we find the
+ // conflict node or reach the leaf node.
+ val (newChild, childAttrMapping) = rewritePlan(child,
conflictPlanMap)
+ attrMapping ++= childAttrMapping.filter { case (oldAttr, _) =>
+ // `attrMapping` is not only used to replace the attributes of the
current `plan`,
+ // but also to be propagated to the parent plans of the current
`plan`. Therefore,
+ // the `oldAttr` must be part of either `plan.references` (so that
it can be used to
+ // replace attributes of the current `plan`) or `plan.outputSet`
(so that it can be
+ // used by those parent plans).
+ (plan.outputSet ++ plan.references).contains(oldAttr)
}
+ newChild
+ }
+
+ if (attrMapping.isEmpty) {
+ newPlan -> attrMapping
+ } else {
+ assert(!attrMapping.groupBy(_._1.exprId)
+ .exists(_._2.map(_._2.exprId).distinct.length > 1),
+ "Found duplicate rewrite attributes")
+ val attributeRewrites = AttributeMap(attrMapping)
+ // Using attrMapping from the children plans to rewrite their parent
node.
+ // Note that we shouldn't rewrite a node using attrMapping from its
sibling nodes.
+ newPlan.transformExpressions {
+ case a: Attribute =>
+ dedupAttr(a, attributeRewrites)
+ case s: SubqueryExpression =>
+ s.withNewPlan(dedupOuterReferencesInSubquery(s.plan,
attributeRewrites))
+ } -> attrMapping
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index d0114f6..c424ef8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -3078,6 +3078,30 @@ class SQLQuerySuite extends QueryTest with
SharedSQLContext {
checkAnswer(df, Row(1))
}
}
+
+ test("SPARK-32372: ResolveReferences.dedupRight should only rewrite
attributes for ancestor " +
+ "plans of the conflict plan") {
+ sql("SELECT name, avg(age) as avg_age FROM person GROUP BY name")
+ .createOrReplaceTempView("person_a")
+ sql("SELECT p1.name, p2.avg_age FROM person p1 JOIN person_a p2 ON p1.name
= p2.name")
+ .createOrReplaceTempView("person_b")
+ sql("SELECT * FROM person_a UNION SELECT * FROM person_b")
+ .createOrReplaceTempView("person_c")
+ checkAnswer(
+ sql("SELECT p1.name, p2.avg_age FROM person_c p1 JOIN person_c p2 ON
p1.name = p2.name"),
+ Row("jim", 20.0) :: Row("mike", 30.0) :: Nil)
+ }
+
+ test("SPARK-32280: Avoid duplicate rewrite attributes when there're multiple
JOINs") {
+ withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+ sql("SELECT 1 AS id").createOrReplaceTempView("A")
+ sql("SELECT id, 'foo' AS kind FROM A").createOrReplaceTempView("B")
+ sql("SELECT l.id as id FROM B AS l LEFT SEMI JOIN B AS r ON l.kind =
r.kind")
+ .createOrReplaceTempView("C")
+ checkAnswer(sql("SELECT 0 FROM ( SELECT * FROM B JOIN C USING (id)) " +
+ "JOIN ( SELECT * FROM B JOIN C USING (id)) USING (id)"), Row(0))
+ }
+ }
}
case class Foo(bar: Option[String])
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]