GitHub user cloud-fan commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20094#discussion_r158891321
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 ---
    @@ -1079,100 +1083,76 @@ class Analyzer(
           case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa
           case sa @ Sort(_, _, child: Aggregate) => sa
     
    -      case s @ Sort(order, _, originalChild) if !s.resolved && 
originalChild.resolved =>
    -        val child = EliminateBarriers(originalChild)
    -        try {
    -          val newOrder = order.map(resolveExpressionRecursively(_, 
child).asInstanceOf[SortOrder])
    -          val requiredAttrs = AttributeSet(newOrder).filter(_.resolved)
    -          val missingAttrs = requiredAttrs -- child.outputSet
    -          if (missingAttrs.nonEmpty) {
    -            // Add missing attributes and then project them away after the 
sort.
    -            Project(child.output,
    -              Sort(newOrder, s.global, addMissingAttr(child, 
missingAttrs)))
    -          } else if (newOrder != order) {
    -            s.copy(order = newOrder)
    -          } else {
    -            s
    -          }
    -        } catch {
    -          // Attempting to resolve it might fail. When this happens, 
return the original plan.
    -          // Users will see an AnalysisException for resolution failure of 
missing attributes
    -          // in Sort
    -          case ae: AnalysisException => s
    +      case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
    +        val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(order, 
child)
    +        val ordering = newOrder.map(_.asInstanceOf[SortOrder])
    +        if (child.output == newChild.output) {
    +          s.copy(order = ordering)
    +        } else {
    +          // Add missing attributes and then project them away.
    +          val newSort = s.copy(order = ordering, child = newChild)
    +          Project(child.output, newSort)
             }
     
    -      case f @ Filter(cond, originalChild) if !f.resolved && 
originalChild.resolved =>
    -        val child = EliminateBarriers(originalChild)
    -        try {
    -          val newCond = resolveExpressionRecursively(cond, child)
    -          val requiredAttrs = newCond.references.filter(_.resolved)
    -          val missingAttrs = requiredAttrs -- child.outputSet
    -          if (missingAttrs.nonEmpty) {
    -            // Add missing attributes and then project them away.
    -            Project(child.output,
    -              Filter(newCond, addMissingAttr(child, missingAttrs)))
    -          } else if (newCond != cond) {
    -            f.copy(condition = newCond)
    -          } else {
    -            f
    -          }
    -        } catch {
    -          // Attempting to resolve it might fail. When this happens, 
return the original plan.
    -          // Users will see an AnalysisException for resolution failure of 
missing attributes
    -          case ae: AnalysisException => f
    +      case f @ Filter(cond, child) if !f.resolved && child.resolved =>
    +        val (newCond, newChild) = 
resolveExprsAndAddMissingAttrs(Seq(cond), child)
    +        if (child.output == newChild.output) {
    +          f.copy(condition = newCond.head)
    +        } else {
    +          // Add missing attributes and then project them away.
    +          val newFilter = Filter(newCond.head, newChild)
    +          Project(child.output, newFilter)
             }
         }
     
    -    /**
    -     * Add the missing attributes into projectList of Project/Window or 
aggregateExpressions of
    -     * Aggregate.
    -     */
    -    private def addMissingAttr(plan: LogicalPlan, missingAttrs: 
AttributeSet): LogicalPlan = {
    -      if (missingAttrs.isEmpty) {
    -        return AnalysisBarrier(plan)
    -      }
    -      plan match {
    -        case p: Project =>
    -          val missing = missingAttrs -- p.child.outputSet
    -          Project(p.projectList ++ missingAttrs, addMissingAttr(p.child, 
missing))
    -        case a: Aggregate =>
    -          // all the missing attributes should be grouping expressions
    -          // TODO: push down AggregateExpression
    -          missingAttrs.foreach { attr =>
    -            if (!a.groupingExpressions.exists(_.semanticEquals(attr))) {
    -              throw new AnalysisException(s"Can't add $attr to 
${a.simpleString}")
    -            }
    -          }
    -          val newAggregateExpressions = a.aggregateExpressions ++ 
missingAttrs
    -          a.copy(aggregateExpressions = newAggregateExpressions)
    -        case g: Generate =>
    -          // If join is false, we will convert it to true for getting from 
the child the missing
    -          // attributes that its child might have or could have.
    -          val missing = missingAttrs -- g.child.outputSet
    -          g.copy(join = true, child = addMissingAttr(g.child, missing))
    -        case d: Distinct =>
    -          throw new AnalysisException(s"Can't add $missingAttrs to $d")
    -        case u: UnaryNode =>
    -          u.withNewChildren(addMissingAttr(u.child, missingAttrs) :: Nil)
    -        case other =>
    -          throw new AnalysisException(s"Can't add $missingAttrs to $other")
    -      }
    -    }
    -
    -    /**
    -     * Resolve the expression on a specified logical plan and it's child 
(recursively), until
    -     * the expression is resolved or meet a non-unary node or Subquery.
    -     */
    -    @tailrec
    -    private def resolveExpressionRecursively(expr: Expression, plan: 
LogicalPlan): Expression = {
    -      val resolved = resolveExpression(expr, plan)
    -      if (resolved.resolved) {
    -        resolved
    +    private def resolveExprsAndAddMissingAttrs(
    +        exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], 
LogicalPlan) = {
    +      if (exprs.forall(_.resolved)) {
    +        // All given expressions are resolved, no need to continue anymore.
    +        (exprs, plan)
           } else {
             plan match {
    -          case u: UnaryNode if !u.isInstanceOf[SubqueryAlias] =>
    -            resolveExpressionRecursively(resolved, u.child)
    -          case other => resolved
    +          // For `AnalysisBarrier`, recursively resolve expressions and 
add missing attributes via
    +          // its child.
    +          case barrier: AnalysisBarrier =>
    +            val (newExprs, newChild) = 
resolveExprsAndAddMissingAttrs(exprs, barrier.child)
    +            (newExprs, AnalysisBarrier(newChild))
    +
    +          case p: Project =>
    +            val maybeResolvedExprs = exprs.map(resolveExpression(_, p))
    +            val (newExprs, newChild) = 
resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child)
    +            val missingAttrs = AttributeSet(newExprs) -- 
AttributeSet(maybeResolvedExprs)
    +            (newExprs, Project(p.projectList ++ missingAttrs, newChild))
    +
    +          case a @ Aggregate(groupExprs, aggExprs, child) =>
    +            val maybeResolvedExprs = exprs.map(resolveExpression(_, a))
    +            val (newExprs, newChild) = 
resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child)
    +            val missingAttrs = AttributeSet(newExprs) -- 
AttributeSet(maybeResolvedExprs)
    +            if (missingAttrs.forall(attr => 
groupExprs.exists(_.semanticEquals(attr)))) {
    +              // All the missing attributes are grouping expressions, 
valid case.
    +              (newExprs, a.copy(aggregateExpressions = aggExprs ++ 
missingAttrs, child = newChild))
    +            } else {
    +              // Need to add non-grouping attributes, invalid case.
    +              (exprs, a)
    +            }
    +
    +          case g: Generate =>
    +            val maybeResolvedExprs = exprs.map(resolveExpression(_, g))
    +            val (newExprs, newChild) = 
resolveExprsAndAddMissingAttrs(maybeResolvedExprs, g.child)
    +            (newExprs, g.copy(join = true, child = newChild))
    +
    +          // For `Distinct`, we can't recursively resolve and add 
attributes via its children.
    +          case d: Distinct =>
    +            (exprs.map(resolveExpression(_, d)), d)
    +
    +          case u: UnaryNode =>
    --- End diff --
    
    Ah, good catch! I missed that because the logic was in 
`resolveExpressionRecursively` instead of `addMissingAttr`.
    
    This suggests that it would be clearer to merge these two methods :)


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to