This is an automated email from the ASF dual-hosted git repository. yumwang pushed a commit to branch branch-3.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push: new b89b927af4f [SPARK-43050][SQL] Fix construct aggregate expressions by replacing grouping functions b89b927af4f is described below commit b89b927af4fe2ae35c2818d1f2f0efd879a39565 Author: Yuming Wang <yumw...@ebay.com> AuthorDate: Sat Apr 15 09:27:22 2023 +0800 [SPARK-43050][SQL] Fix construct aggregate expressions by replacing grouping functions ### What changes were proposed in this pull request? This PR fixes construct aggregate expressions by replacing grouping functions if an expression is part of aggregation. In the following example, the second `b` should also be replaced: <img width="545" alt="image" src="https://user-images.githubusercontent.com/5399861/230415618-84cd6334-690e-4b0b-867b-ccc4056226a8.png"> ### Why are the changes needed? Fix bug: ``` spark-sql (default)> SELECT CASE WHEN a IS NULL THEN count(b) WHEN b IS NULL THEN count(c) END > FROM grouping > GROUP BY GROUPING SETS (a, b, c); [MISSING_AGGREGATION] The non-aggregating expression "b" is based on columns which are not participating in the GROUP BY clause. ``` ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit test. Closes #40685 from wangyum/SPARK-43050. 
Authored-by: Yuming Wang <yumw...@ebay.com> Signed-off-by: Yuming Wang <yumw...@ebay.com> (cherry picked from commit 45b84cd37add1b9ce274273ad5e519e6bc1d8013) Signed-off-by: Yuming Wang <yumw...@ebay.com> # Conflicts: # sql/core/src/test/resources/sql-tests/analyzer-results/grouping_set.sql.out --- .../spark/sql/catalyst/analysis/Analyzer.scala | 27 ++++++++-------------- .../resources/sql-tests/inputs/grouping_set.sql | 4 ++++ .../sql-tests/results/grouping_set.sql.out | 18 +++++++++++++++ 3 files changed, 31 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d9c2c0ef63b..dc8904d7f74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -608,31 +608,22 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor aggregations: Seq[NamedExpression], groupByAliases: Seq[Alias], groupingAttrs: Seq[Expression], - gid: Attribute): Seq[NamedExpression] = aggregations.map { agg => - // collect all the found AggregateExpression, so we can check an expression is part of - // any AggregateExpression or not. - val aggsBuffer = ArrayBuffer[Expression]() - // Returns whether the expression belongs to any expressions in `aggsBuffer` or not. - def isPartOfAggregation(e: Expression): Boolean = { - aggsBuffer.exists(a => a.exists(_ eq e)) - } - replaceGroupingFunc(agg, groupByExprs, gid).transformDown { - // AggregateExpression should be computed on the unmodified value of its argument - // expressions, so we should not replace any references to grouping expression - // inside it. 
- case e if AggregateExpression.isAggregate(e) => - aggsBuffer += e - e - case e if isPartOfAggregation(e) => e + gid: Attribute): Seq[NamedExpression] = { + def replaceExprs(e: Expression): Expression = e match { + case e if AggregateExpression.isAggregate(e) => e case e => // Replace expression by expand output attribute. val index = groupByAliases.indexWhere(_.child.semanticEquals(e)) if (index == -1) { - e + e.mapChildren(replaceExprs) } else { groupingAttrs(index) } - }.asInstanceOf[NamedExpression] + } + aggregations + .map(replaceGroupingFunc(_, groupByExprs, gid)) + .map(replaceExprs) + .map(_.asInstanceOf[NamedExpression]) } /* diff --git a/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql b/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql index 4d516bdda7b..909c36c926c 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql @@ -60,3 +60,7 @@ SELECT grouping(k1), k1, k2, avg(v) FROM (VALUES (1,1,1),(2,2,2)) AS t(k1,k2,v) -- grouping_id function SELECT grouping_id(k1, k2), avg(v) from (VALUES (1,1,1),(2,2,2)) AS t(k1,k2,v) GROUP BY k1, k2 GROUPING SETS ((k2, k1), k1); + +SELECT CASE WHEN a IS NULL THEN count(b) WHEN b IS NULL THEN count(c) END +FROM grouping +GROUP BY GROUPING SETS (a, b, c); diff --git a/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out b/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out index f85cd8fad3d..61d9523da6d 100644 --- a/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out @@ -224,3 +224,21 @@ struct<grouping_id(k1, k2):bigint,avg(v):double> 0 2.0 1 1.0 1 2.0 + + +-- !query +SELECT CASE WHEN a IS NULL THEN count(b) WHEN b IS NULL THEN count(c) END +FROM grouping +GROUP BY GROUPING SETS (a, b, c) +-- !query schema +struct<CASE WHEN (a IS NULL) THEN count(b) WHEN (b IS NULL) THEN count(c) END:bigint> 
+-- !query output +1 +1 +1 +1 +1 +1 +1 +1 +1 --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org