allisonwang-db commented on a change in pull request #32470:
URL: https://github.com/apache/spark/pull/32470#discussion_r650165639
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
##########
@@ -2457,164 +2450,136 @@ class Analyzer(override val catalogManager:
CatalogManager)
_.containsPattern(AGGREGATE), ruleId) {
// Resolve aggregate with having clause to Filter(..., Aggregate()).
Note, to avoid wrongly
// resolve the having condition expression, here we skip resolving it in
ResolveReferences
- // and transform it to Filter after aggregate is resolved. See more
details in SPARK-31519.
+ // and transform it to Filter after aggregate is resolved. Basically
columns in HAVING should
+ // be resolved with `agg.child.output` first. See more details in
SPARK-31519.
case UnresolvedHaving(cond, agg: Aggregate) if agg.resolved =>
- resolveHaving(Filter(cond, agg), agg)
-
- case f @ Filter(_, agg: Aggregate) if agg.resolved =>
- resolveHaving(f, agg)
-
- case sort @ Sort(sortOrder, global, aggregate: Aggregate) if
aggregate.resolved =>
-
- // Try resolving the ordering as though it is in the aggregate clause.
- try {
- // If a sort order is unresolved, containing references not in
aggregate, or containing
- // `AggregateExpression`, we need to push down it to the underlying
aggregate operator.
- val unresolvedSortOrders = sortOrder.filter { s =>
- !s.resolved || !s.references.subsetOf(aggregate.outputSet) ||
containsAggregate(s)
+ resolveOperatorWithAggregate(Seq(cond), agg, (newExprs, newChild) => {
+ Filter(newExprs.head, newChild)
+ })
+
+ case Filter(cond, agg: Aggregate) if agg.resolved =>
+ // We should resolve the references normally based on child.output
first.
+ val maybeResolved = resolveExpressionByPlanOutput(cond, agg)
+ resolveOperatorWithAggregate(Seq(maybeResolved), agg, (newExprs,
newChild) => {
+ Filter(newExprs.head, newChild)
+ })
+
+ case Sort(sortOrder, global, agg: Aggregate) if agg.resolved =>
+ // We should resolve the references normally based on child.output
first.
+ val maybeResolved =
sortOrder.map(_.child).map(resolveExpressionByPlanOutput(_, agg))
+ resolveOperatorWithAggregate(maybeResolved, agg, (newExprs, newChild)
=> {
+ val newSortOrder = sortOrder.zip(newExprs).map {
+ case (sortOrder, expr) => sortOrder.copy(child = expr)
}
- val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child,
"aggOrder")())
-
- val aggregateWithExtraOrdering = aggregate.copy(
- aggregateExpressions = aggregate.aggregateExpressions ++
aliasedOrdering)
-
- val resolvedAggregate: Aggregate =
-
executeSameContext(aggregateWithExtraOrdering).asInstanceOf[Aggregate]
-
- val (reResolvedAggExprs, resolvedAliasedOrdering) =
-
resolvedAggregate.aggregateExpressions.splitAt(aggregate.aggregateExpressions.length)
-
- // If we pass the analysis check, then the ordering expressions
should only reference to
- // aggregate expressions or grouping expressions, and it's safe to
push them down to
- // Aggregate.
- checkAnalysis(resolvedAggregate)
-
- val originalAggExprs =
aggregate.aggregateExpressions.map(trimNonTopLevelAliases)
-
- // If the ordering expression is same with original aggregate
expression, we don't need
- // to push down this ordering expression and can reference the
original aggregate
- // expression instead.
- val needsPushDown = ArrayBuffer.empty[NamedExpression]
- val orderToAlias = unresolvedSortOrders.zip(aliasedOrdering)
- val evaluatedOrderings =
-
resolvedAliasedOrdering.asInstanceOf[Seq[Alias]].zip(orderToAlias).map {
- case (evaluated, (order, aliasOrder)) =>
- val index = reResolvedAggExprs.indexWhere {
- case Alias(child, _) => child semanticEquals evaluated.child
- case other => other semanticEquals evaluated.child
- }
+ Sort(newSortOrder, global, newChild)
+ })
+ }
- if (index == -1) {
- if (hasCharVarchar(evaluated)) {
- needsPushDown += aliasOrder
- order.copy(child = aliasOrder)
- } else {
- needsPushDown += evaluated
- order.copy(child = evaluated.toAttribute)
- }
- } else {
- order.copy(child = originalAggExprs(index).toAttribute)
- }
+ /**
+ * Resolves the given expressions as if they are in the given Aggregate
operator, which means
+ * the column can be resolved using `agg.child` and aggregate
functions/grouping columns are
+ * allowed. It returns a list of named expressions that need to be
appended to
+ * `agg.aggregateExpressions`, and the list of resolved expressions.
+ */
+ def resolveExprsWithAggregate(
+ exprs: Seq[Expression],
+ agg: Aggregate): (Seq[NamedExpression], Seq[Expression]) = {
+ val extraAggExprs = ArrayBuffer.empty[NamedExpression]
+ val transformed = exprs.map { e =>
+ // Try resolving the expression as though it is in the aggregate
clause.
+ def resolveCol(input: Expression): Expression = {
+ input.transform {
+ case u: UnresolvedAttribute =>
+ try {
+ agg.child.resolve(u.nameParts, resolver)
+ .map(TempResolvedColumn(_, u.nameParts)).getOrElse(u)
+ } catch {
+ case _: AnalysisException => u
+ }
}
-
- val sortOrdersMap = unresolvedSortOrders
- .map(new TreeNodeRef(_))
- .zip(evaluatedOrderings)
- .toMap
- val finalSortOrders = sortOrder.map(s => sortOrdersMap.getOrElse(new
TreeNodeRef(s), s))
-
- // Since we don't rely on sort.resolved as the stop condition for
this rule,
- // we need to check this and prevent applying this rule multiple
times
- if (sortOrder == finalSortOrders) {
- sort
+ }
+ def resolveSubQuery(input: Expression): Expression = {
+ if (SubqueryExpression.hasSubquery(input)) {
+ val fake = Project(Alias(input, "fake")() :: Nil, agg.child)
+
ResolveSubquery(fake).asInstanceOf[Project].projectList.head.asInstanceOf[Alias].child
} else {
- Project(aggregate.output,
- Sort(finalSortOrders, global,
- aggregate.copy(aggregateExpressions = originalAggExprs ++
needsPushDown)))
+ input
}
- } catch {
- // Attempting to resolve in the aggregate can result in ambiguity.
When this happens,
- // just return the original plan.
- case ae: AnalysisException => sort
}
- }
- def hasCharVarchar(expr: Alias): Boolean = {
- expr.find {
- case ne: NamedExpression =>
CharVarcharUtils.getRawType(ne.metadata).nonEmpty
- case _ => false
- }.nonEmpty
+ val maybeResolved = resolveSubQuery(resolveCol(e))
+ if (!maybeResolved.resolved) {
+ maybeResolved
+ } else {
+ buildAggExprList(maybeResolved, agg, extraAggExprs)
+ }
+ }
+ (extraAggExprs.toSeq, transformed)
}
- def containsAggregate(condition: Expression): Boolean = {
- condition.find(_.isInstanceOf[AggregateExpression]).isDefined
+ private def trimTempResolvedField(input: Expression): Expression =
input.transform {
+ case t: TempResolvedColumn => t.child
}
- def resolveFilterCondInAggregate(
- filterCond: Expression, agg: Aggregate): Option[(Seq[NamedExpression],
Expression)] = {
- try {
- val aggregatedCondition =
- Aggregate(
- agg.groupingExpressions,
- Alias(filterCond, "havingCondition")() :: Nil,
- agg.child)
- val resolvedOperator = executeSameContext(aggregatedCondition)
- def resolvedAggregateFilter =
- resolvedOperator
- .asInstanceOf[Aggregate]
- .aggregateExpressions.head
-
- // If resolution was successful and we see the filter has an aggregate
in it, add it to
- // the original aggregate operator.
- if (resolvedOperator.resolved) {
- // Try to replace all aggregate expressions in the filter by an
alias.
- val aggregateExpressions = ArrayBuffer.empty[NamedExpression]
- val transformedAggregateFilter = resolvedAggregateFilter.transform {
- case ae: AggregateExpression =>
- val alias = Alias(ae, ae.toString)()
- aggregateExpressions += alias
+ private def buildAggExprList(
+ expr: Expression,
+ agg: Aggregate,
+ aggExprList: ArrayBuffer[NamedExpression]): Expression = {
+ // Avoid adding an extra aggregate expression if it's already present in
+ // `agg.aggregateExpressions`.
+ val index = agg.aggregateExpressions.indexWhere {
Review comment:
Should we also take into account the aggregate expressions that have already been added
to `aggExprList`? For example:
```sql
select c1 from t1 group by c1 having sum(c2) > 0 and sum(c2) < 10
```
Here the first `sum(c2)` will be added to `aggExprList`, so we don't need to create
another alias for the second `sum(c2)`. Maybe use a map of aggregate expressions instead of a list?
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
##########
@@ -325,8 +325,6 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan]
with PredicateHelper
*/
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsPattern(PLAN_EXPRESSION)) {
- case f @ Filter(_, a: Aggregate) =>
Review comment:
This is not related to this PR, but I just thought of the following
case:
```sql
-- t1: [c1, c2]: (0,1),(1,2) , t: [a, b]: (0,1),(1,2)
select sum(c2) as c1 from t1 group by c1 having (select sum(b) from t where
c1 = a) > 0
```
Because `ResolveSubquery` runs before `ResolveAggregateFunctions`, the `c1` in
`c1 = a` will be resolved as `sum(c2)` (the aggregate alias `c1`) instead of `t1.c1`.
This happens because `UnresolvedHaving` is also a `UnaryNode`, so it is handled by the
`UnaryNode` case below, which resolves outer references against the aggregate's output.
As a result, the query can return incorrect results.
We need to exclude `UnresolvedHaving` from this rule.
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
##########
@@ -839,9 +839,7 @@ trait CheckAnalysis extends PredicateHelper with
LookupCatalog {
def checkMixedReferencesInsideAggregateExpr(expr: Expression): Unit = {
expr.foreach {
case a: AggregateExpression if containsOuter(a) =>
- val outer = a.collect { case OuterReference(e) => e.toAttribute }
- val local = a.references -- outer
Review comment:
Good catch!
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
##########
@@ -325,8 +325,6 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan]
with PredicateHelper
*/
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsPattern(PLAN_EXPRESSION)) {
- case f @ Filter(_, a: Aggregate) =>
Review comment:
This is not related to this PR, but I just thought of the following
case:
```sql
-- t1: [c1, c2]: (0,1),(1,2) , t: [a, b]: (0,1),(1,2)
select sum(c2) as c1 from t1 group by c1 having (select sum(b) from t where
c1 = a) > 0
```
Because `ResolveSubquery` runs before `ResolveAggregateFunctions`, the `c1` in
`c1 = a` will be resolved as `sum(c2)` (the aggregate alias `c1`) instead of `t1.c1`.
This happens because `UnresolvedHaving` is also a `UnaryNode`, so it is handled by the
`UnaryNode` case below, which resolves outer references against the aggregate's output.
As a result, the query can return incorrect results.
We need to exclude `UnresolvedHaving` from this rule.
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
##########
@@ -2378,10 +2375,6 @@ class Analyzer(override val catalogManager:
CatalogManager)
*/
def apply(plan: LogicalPlan): LogicalPlan =
plan.resolveOperatorsUpWithPruning(
_.containsPattern(PLAN_EXPRESSION), ruleId) {
- // In case of HAVING (a filter after an aggregate) we use both the
aggregate and
- // its child for resolution.
- case f @ Filter(_, a: Aggregate) if f.childrenResolved =>
Review comment:
This is not related to this PR, but I just thought of the following
case:
```sql
-- t1: [c1, c2]: (0,1),(1,2) , t: [a, b]: (0,1),(1,2)
select sum(c2) as c1 from t1 group by c1 having (select sum(b) from t where
c1 = a) > 0
```
Because `ResolveSubquery` runs before `ResolveAggregateFunctions`, the `c1` in
`c1 = a` will be resolved as `sum(c2)` (the aggregate alias `c1`) instead of `t1.c1`.
This happens because `UnresolvedHaving` is also a `UnaryNode`, so it is handled by the
`UnaryNode` case below, which resolves outer references against the aggregate's output.
As a result, the query can return incorrect results.
We need to exclude `UnresolvedHaving` from this rule.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]