Ngone51 commented on a change in pull request #31470:
URL: https://github.com/apache/spark/pull/31470#discussion_r599300393
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
##########
@@ -1428,124 +1429,21 @@ class Analyzer(override val catalogManager:
CatalogManager)
* a logical plan node's children.
*/
object ResolveReferences extends Rule[LogicalPlan] {
- /**
- * Generate a new logical plan for the right child with different
expression IDs
- * for all conflicting attributes.
- */
- private def dedupRight (left: LogicalPlan, right: LogicalPlan):
LogicalPlan = {
- val conflictingAttributes = left.outputSet.intersect(right.outputSet)
- logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")}
" +
- s"between $left and $right")
-
- /**
- * For LogicalPlan likes MultiInstanceRelation, Project, Aggregate, etc,
whose output doesn't
- * inherit directly from its children, we could just stop collect on it.
Because we could
- * always replace all the lower conflict attributes with the new
attributes from the new
- * plan. Theoretically, we should do recursively collect for Generate
and Window but we leave
- * it to the next batch to reduce possible overhead because this should
be a corner case.
- */
- def collectConflictPlans(plan: LogicalPlan): Seq[(LogicalPlan,
LogicalPlan)] = plan match {
- // Handle base relations that might appear more than once.
- case oldVersion: MultiInstanceRelation
- if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty
=>
- val newVersion = oldVersion.newInstance()
- newVersion.copyTagsFrom(oldVersion)
- Seq((oldVersion, newVersion))
-
- case oldVersion: SerializeFromObject
- if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty
=>
- Seq((oldVersion, oldVersion.copy(
- serializer = oldVersion.serializer.map(_.newInstance()))))
-
- // Handle projects that create conflicting aliases.
- case oldVersion @ Project(projectList, _)
- if
findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
- Seq((oldVersion, oldVersion.copy(projectList =
newAliases(projectList))))
-
- // We don't need to search child plan recursively if the projectList
of a Project
- // is only composed of Alias and doesn't contain any conflicting
attributes.
- // Because, even if the child plan has some conflicting attributes,
the attributes
- // will be aliased to non-conflicting attributes by the Project at the
end.
- case _ @ Project(projectList, _)
- if findAliases(projectList).size == projectList.size =>
- Nil
-
- case oldVersion @ Aggregate(_, aggregateExpressions, _)
- if
findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
- Seq((oldVersion, oldVersion.copy(
- aggregateExpressions = newAliases(aggregateExpressions))))
-
- // We don't search the child plan recursively for the same reason as
the above Project.
- case _ @ Aggregate(_, aggregateExpressions, _)
- if findAliases(aggregateExpressions).size ==
aggregateExpressions.size =>
- Nil
-
- case oldVersion @ FlatMapGroupsInPandas(_, _, output, _)
- if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty
=>
- Seq((oldVersion, oldVersion.copy(output =
output.map(_.newInstance()))))
-
- case oldVersion @ FlatMapCoGroupsInPandas(_, _, _, output, _, _)
- if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty
=>
- Seq((oldVersion, oldVersion.copy(output =
output.map(_.newInstance()))))
-
- case oldVersion @ MapInPandas(_, output, _)
- if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty
=>
- Seq((oldVersion, oldVersion.copy(output =
output.map(_.newInstance()))))
-
- case oldVersion: Generate
- if
oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
- val newOutput = oldVersion.generatorOutput.map(_.newInstance())
- Seq((oldVersion, oldVersion.copy(generatorOutput = newOutput)))
-
- case oldVersion: Expand
- if
oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
- val producedAttributes = oldVersion.producedAttributes
- val newOutput = oldVersion.output.map { attr =>
- if (producedAttributes.contains(attr)) {
- attr.newInstance()
- } else {
- attr
- }
- }
- Seq((oldVersion, oldVersion.copy(output = newOutput)))
-
- case oldVersion @ Window(windowExpressions, _, _, child)
- if
AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
- .nonEmpty =>
- Seq((oldVersion, oldVersion.copy(windowExpressions =
newAliases(windowExpressions))))
- case oldVersion @ ScriptTransformation(_, _, output, _, _)
- if AttributeSet(output).intersect(conflictingAttributes).nonEmpty
=>
- Seq((oldVersion, oldVersion.copy(output =
output.map(_.newInstance()))))
-
- case _ => plan.children.flatMap(collectConflictPlans)
- }
-
- val conflictPlans = collectConflictPlans(right)
-
- /*
- * Note that it's possible `conflictPlans` can be empty which implies
that there
- * is a logical plan node that produces new references that this rule
cannot handle.
- * When that is the case, there must be another rule that resolves these
conflicts.
- * Otherwise, the analysis will fail.
- */
- if (conflictPlans.isEmpty) {
- right
- } else {
- val planMapping = conflictPlans.toMap
- right.transformUpWithNewOutput {
- case oldPlan =>
- val newPlanOpt = planMapping.get(oldPlan)
- newPlanOpt.map { newPlan =>
- newPlan -> oldPlan.output.zip(newPlan.output)
- }.getOrElse(oldPlan -> Nil)
- }
+ private def hasConflictingAttrs(p: LogicalPlan): Boolean = {
+ p.children.length > 1 && {
+ p.children.tail.foldLeft(p.children.head.outputSet) {
+ case (conflictAttrs, child) =>
conflictAttrs.intersect(child.outputSet)
Review comment:
For example, the test `SQLQuerySuite.self join with alias in agg` fails with:
```text
[info] org.apache.spark.sql.AnalysisException: cannot resolve 'x.str'
given input columns: [x.str, y.str, x.strCount, y.strCount]; line 3 pos 23;
[info] 'Aggregate ['x.str], ['x.str, unresolvedalias('SUM('x.strCount),
None)]
[info] +- 'Join Inner, ('x.str = 'y.str)
[info] :- SubqueryAlias x
[info] : +- SubqueryAlias df
[info] : +- View (`df`, [str#226,str#226,strCount#233L])
[info] : +- Aggregate [str#226], [str#226, str#226,
count(str#226) AS strCount#233L]
[info] : +- Project [_1#220 AS int#225, _2#221 AS str#226]
[info] : +- LocalRelation [_1#220, _2#221]
[info] +- SubqueryAlias y
[info] +- SubqueryAlias df
[info] +- View (`df`, [str#241,str#241,strCount#239L])
[info] +- Aggregate [str#241], [str#241, str#241,
count(str#241) AS strCount#239L]
[info] +- Project [_1#237 AS int#240, _2#238 AS str#241]
[info] +- LocalRelation [_1#237, _2#238]
```
Here, each `Aggregate` outputs the `str` attribute twice (`str#226, str#226` on the left side and `str#241, str#241` on the right).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]