maropu commented on a change in pull request #30145:
URL: https://github.com/apache/spark/pull/30145#discussion_r605606388
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
##########
@@ -658,14 +658,15 @@ class Analyzer(override val catalogManager:
CatalogManager)
// CUBE/ROLLUP/GROUPING SETS. This also replace grouping()/grouping_id()
in resolved
// Filter/Sort.
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
- case h @ UnresolvedHaving(_, agg @ Aggregate(Seq(gs: GroupingSet),
aggregateExpressions, _))
+ case h @ UnresolvedHaving(
+ _, agg @ Aggregate(Seq(gs: GroupingAnalytic), aggregateExpressions, _))
Review comment:
nit: to avoid the line break, how about renaming `aggregateExpressions`
to `aggExprs`?
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
##########
@@ -1950,16 +1951,39 @@ class Analyzer(override val catalogManager:
CatalogManager)
// Replace the index with the corresponding expression in
aggregateExpressions. The index is
// a 1-base position of aggregateExpressions, which is output columns
(select expression)
case Aggregate(groups, aggs, child) if aggs.forall(_.resolved) &&
- groups.exists(_.isInstanceOf[UnresolvedOrdinal]) =>
- val newGroups = groups.map {
- case u @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size
=>
- aggs(index - 1)
- case ordinal @ UnresolvedOrdinal(index) =>
- throw QueryCompilationErrors.groupByPositionRangeError(index,
aggs.size, ordinal)
- case o => o
- }
+ groups.exists(containUnresolvedOrdinal) =>
+ val newGroups = groups.map((resolveGroupByExpressionOrdinal(_, aggs)))
Aggregate(newGroups, aggs, child)
}
+
+ private def containUnresolvedOrdinal(e: Expression): Boolean = e match {
+ case _: UnresolvedOrdinal => true
+ case Cube(_, groupByExprs) =>
groupByExprs.exists(containUnresolvedOrdinal)
+ case Rollup(_, groupByExprs) =>
groupByExprs.exists(containUnresolvedOrdinal)
+ case GroupingSets(_, flatGroupingSets, groupByExprs) =>
+ flatGroupingSets.exists(containUnresolvedOrdinal) ||
+ groupByExprs.exists(containUnresolvedOrdinal)
+ case _ => false
+ }
+
+ private def resolveGroupByExpressionOrdinal(
+ expr: Expression,
+ aggs: Seq[Expression]): Expression = expr match {
+ case u @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size =>
Review comment:
`u` is not used, so the `u @` binding can be dropped.
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
##########
@@ -1950,16 +1951,39 @@ class Analyzer(override val catalogManager:
CatalogManager)
// Replace the index with the corresponding expression in
aggregateExpressions. The index is
// a 1-base position of aggregateExpressions, which is output columns
(select expression)
case Aggregate(groups, aggs, child) if aggs.forall(_.resolved) &&
- groups.exists(_.isInstanceOf[UnresolvedOrdinal]) =>
- val newGroups = groups.map {
- case u @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size
=>
- aggs(index - 1)
- case ordinal @ UnresolvedOrdinal(index) =>
- throw QueryCompilationErrors.groupByPositionRangeError(index,
aggs.size, ordinal)
- case o => o
- }
+ groups.exists(containUnresolvedOrdinal) =>
+ val newGroups = groups.map((resolveGroupByExpressionOrdinal(_, aggs)))
Aggregate(newGroups, aggs, child)
}
+
+ private def containUnresolvedOrdinal(e: Expression): Boolean = e match {
+ case _: UnresolvedOrdinal => true
+ case Cube(_, groupByExprs) =>
groupByExprs.exists(containUnresolvedOrdinal)
+ case Rollup(_, groupByExprs) =>
groupByExprs.exists(containUnresolvedOrdinal)
+ case GroupingSets(_, flatGroupingSets, groupByExprs) =>
+ flatGroupingSets.exists(containUnresolvedOrdinal) ||
+ groupByExprs.exists(containUnresolvedOrdinal)
+ case _ => false
+ }
+
+ private def resolveGroupByExpressionOrdinal(
+ expr: Expression,
+ aggs: Seq[Expression]): Expression = expr match {
+ case u @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size =>
+ aggs(index - 1)
+ case ordinal @ UnresolvedOrdinal(index) =>
+ throw QueryCompilationErrors.groupByPositionRangeError(index,
aggs.size, ordinal)
+ case cube @ Cube(_, groupByExprs) =>
+ cube.copy(children =
groupByExprs.map(resolveGroupByExpressionOrdinal(_, aggs)))
+ case rollup @ Rollup(_, groupByExprs) =>
+ rollup.copy(children =
groupByExprs.map(resolveGroupByExpressionOrdinal(_, aggs)))
+ case groupingSets @ GroupingSets(_, flatGroupingSets, groupByExprs) =>
+ groupingSets.copy(
+ flatGroupingSets =
flatGroupingSets.map(resolveGroupByExpressionOrdinal(_, aggs)),
+ groupByExprs = groupByExprs.map(resolveGroupByExpressionOrdinal(_,
aggs))
+ )
Review comment:
We can merge these `case` branches (`Cube`/`Rollup`/`GroupingSets`) into a single entry.
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
##########
@@ -1950,16 +1951,39 @@ class Analyzer(override val catalogManager:
CatalogManager)
// Replace the index with the corresponding expression in
aggregateExpressions. The index is
// a 1-base position of aggregateExpressions, which is output columns
(select expression)
case Aggregate(groups, aggs, child) if aggs.forall(_.resolved) &&
- groups.exists(_.isInstanceOf[UnresolvedOrdinal]) =>
- val newGroups = groups.map {
- case u @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size
=>
- aggs(index - 1)
- case ordinal @ UnresolvedOrdinal(index) =>
- throw QueryCompilationErrors.groupByPositionRangeError(index,
aggs.size, ordinal)
- case o => o
- }
+ groups.exists(containUnresolvedOrdinal) =>
+ val newGroups = groups.map((resolveGroupByExpressionOrdinal(_, aggs)))
Aggregate(newGroups, aggs, child)
}
+
+ private def containUnresolvedOrdinal(e: Expression): Boolean = e match {
+ case _: UnresolvedOrdinal => true
+ case Cube(_, groupByExprs) =>
groupByExprs.exists(containUnresolvedOrdinal)
+ case Rollup(_, groupByExprs) =>
groupByExprs.exists(containUnresolvedOrdinal)
+ case GroupingSets(_, flatGroupingSets, groupByExprs) =>
+ flatGroupingSets.exists(containUnresolvedOrdinal) ||
+ groupByExprs.exists(containUnresolvedOrdinal)
+ case _ => false
+ }
+
+ private def resolveGroupByExpressionOrdinal(
+ expr: Expression,
+ aggs: Seq[Expression]): Expression = expr match {
+ case u @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size =>
Review comment:
nit: how about this?
```
case ordinal @ UnresolvedOrdinal(index) =>
  if (index > 0 && index <= aggs.size) {
    aggs(index - 1)
  } else {
    throw QueryCompilationErrors.groupByPositionRangeError(index, aggs.size, ordinal)
  }
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]