GitHub user cloud-fan commented on a diff in the pull request:
https://github.com/apache/spark/pull/20094#discussion_r158855737
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
---
@@ -1079,100 +1083,76 @@ class Analyzer(
case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa
case sa @ Sort(_, _, child: Aggregate) => sa
- case s @ Sort(order, _, originalChild) if !s.resolved &&
originalChild.resolved =>
- val child = EliminateBarriers(originalChild)
- try {
- val newOrder = order.map(resolveExpressionRecursively(_,
child).asInstanceOf[SortOrder])
- val requiredAttrs = AttributeSet(newOrder).filter(_.resolved)
- val missingAttrs = requiredAttrs -- child.outputSet
- if (missingAttrs.nonEmpty) {
- // Add missing attributes and then project them away after the
sort.
- Project(child.output,
- Sort(newOrder, s.global, addMissingAttr(child,
missingAttrs)))
- } else if (newOrder != order) {
- s.copy(order = newOrder)
- } else {
- s
- }
- } catch {
- // Attempting to resolve it might fail. When this happens,
return the original plan.
- // Users will see an AnalysisException for resolution failure of
missing attributes
- // in Sort
- case ae: AnalysisException => s
+ case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
+ val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(order,
child)
+ val ordering = newOrder.map(_.asInstanceOf[SortOrder])
+ if (child.output == newChild.output) {
+ s.copy(order = ordering)
+ } else {
+ // Add missing attributes and then project them away.
+ val newSort = s.copy(order = ordering, child = newChild)
+ Project(child.output, newSort)
}
- case f @ Filter(cond, originalChild) if !f.resolved &&
originalChild.resolved =>
- val child = EliminateBarriers(originalChild)
- try {
- val newCond = resolveExpressionRecursively(cond, child)
- val requiredAttrs = newCond.references.filter(_.resolved)
- val missingAttrs = requiredAttrs -- child.outputSet
- if (missingAttrs.nonEmpty) {
- // Add missing attributes and then project them away.
- Project(child.output,
- Filter(newCond, addMissingAttr(child, missingAttrs)))
- } else if (newCond != cond) {
- f.copy(condition = newCond)
- } else {
- f
- }
- } catch {
- // Attempting to resolve it might fail. When this happens,
return the original plan.
- // Users will see an AnalysisException for resolution failure of
missing attributes
- case ae: AnalysisException => f
+ case f @ Filter(cond, child) if !f.resolved && child.resolved =>
+ val (newCond, newChild) =
resolveExprsAndAddMissingAttrs(Seq(cond), child)
+ if (child.output == newChild.output) {
+ f.copy(condition = newCond.head)
+ } else {
+ // Add missing attributes and then project them away.
+ val newFilter = Filter(newCond.head, newChild)
+ Project(child.output, newFilter)
}
}
- /**
- * Add the missing attributes into projectList of Project/Window or
aggregateExpressions of
- * Aggregate.
- */
- private def addMissingAttr(plan: LogicalPlan, missingAttrs:
AttributeSet): LogicalPlan = {
- if (missingAttrs.isEmpty) {
- return AnalysisBarrier(plan)
- }
- plan match {
- case p: Project =>
- val missing = missingAttrs -- p.child.outputSet
- Project(p.projectList ++ missingAttrs, addMissingAttr(p.child,
missing))
- case a: Aggregate =>
- // all the missing attributes should be grouping expressions
- // TODO: push down AggregateExpression
- missingAttrs.foreach { attr =>
- if (!a.groupingExpressions.exists(_.semanticEquals(attr))) {
- throw new AnalysisException(s"Can't add $attr to
${a.simpleString}")
- }
- }
- val newAggregateExpressions = a.aggregateExpressions ++
missingAttrs
- a.copy(aggregateExpressions = newAggregateExpressions)
- case g: Generate =>
- // If join is false, we will convert it to true for getting from
the child the missing
- // attributes that its child might have or could have.
- val missing = missingAttrs -- g.child.outputSet
- g.copy(join = true, child = addMissingAttr(g.child, missing))
- case d: Distinct =>
- throw new AnalysisException(s"Can't add $missingAttrs to $d")
- case u: UnaryNode =>
- u.withNewChildren(addMissingAttr(u.child, missingAttrs) :: Nil)
- case other =>
- throw new AnalysisException(s"Can't add $missingAttrs to $other")
- }
- }
-
- /**
- * Resolve the expression on a specified logical plan and it's child
(recursively), until
- * the expression is resolved or meet a non-unary node or Subquery.
- */
- @tailrec
- private def resolveExpressionRecursively(expr: Expression, plan:
LogicalPlan): Expression = {
- val resolved = resolveExpression(expr, plan)
- if (resolved.resolved) {
- resolved
+ private def resolveExprsAndAddMissingAttrs(
--- End diff --
I refactored the code to resolve expressions and add missing attributes in
one shot, replacing the separate `resolveExpressionRecursively` and
`addMissingAttr` helpers. This gives us a single central place to deal with
analysis barriers and to decide which operators are supported and which are not.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]