This is an automated email from the ASF dual-hosted git repository.

yumwang pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new b89b927af4f [SPARK-43050][SQL] Fix construct aggregate expressions by replacing grouping functions
b89b927af4f is described below

commit b89b927af4fe2ae35c2818d1f2f0efd879a39565
Author: Yuming Wang <yumw...@ebay.com>
AuthorDate: Sat Apr 15 09:27:22 2023 +0800

    [SPARK-43050][SQL] Fix construct aggregate expressions by replacing grouping functions
    
    ### What changes were proposed in this pull request?
    
    This PR fixes the construction of aggregate expressions when grouping functions are replaced: a grouping expression should only be left unreplaced if it actually sits inside an aggregate function.
    In the following example, the second `b` should also be replaced:
    <img width="545" alt="image" src="https://user-images.githubusercontent.com/5399861/230415618-84cd6334-690e-4b0b-867b-ccc4056226a8.png">
    
    ### Why are the changes needed?
    
    Fixes the following bug:
    ```
    spark-sql (default)> SELECT CASE WHEN a IS NULL THEN count(b) WHEN b IS NULL THEN count(c) END
                       > FROM grouping
                       > GROUP BY GROUPING SETS (a, b, c);
    [MISSING_AGGREGATION] The non-aggregating expression "b" is based on columns which are not participating in the GROUP BY clause.
    ```
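
    For reference, a minimal reproduction from spark-shell could look like the sketch below; the schema and contents of the `grouping` test table are assumptions (the email only shows the failing SQL). With this fix the same query analyzes successfully and returns rows, as the updated golden file in the diff records.
    ```scala
    // Hypothetical reproduction in spark-shell; table contents are placeholders.
    val df = Seq((1, 2, 3), (4, 5, 6)).toDF("a", "b", "c")
    df.createOrReplaceTempView("grouping")
    spark.sql(
      """SELECT CASE WHEN a IS NULL THEN count(b) WHEN b IS NULL THEN count(c) END
        |FROM grouping
        |GROUP BY GROUPING SETS (a, b, c)""".stripMargin).show()
    // Before the fix: AnalysisException [MISSING_AGGREGATION]; after: the query returns rows.
    ```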
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit test.
    
    Closes #40685 from wangyum/SPARK-43050.
    
    Authored-by: Yuming Wang <yumw...@ebay.com>
    Signed-off-by: Yuming Wang <yumw...@ebay.com>
    (cherry picked from commit 45b84cd37add1b9ce274273ad5e519e6bc1d8013)
    Signed-off-by: Yuming Wang <yumw...@ebay.com>
    
    # Conflicts:
    #       sql/core/src/test/resources/sql-tests/analyzer-results/grouping_set.sql.out
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     | 27 ++++++++--------------
 .../resources/sql-tests/inputs/grouping_set.sql    |  4 ++++
 .../sql-tests/results/grouping_set.sql.out         | 18 +++++++++++++++
 3 files changed, 31 insertions(+), 18 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d9c2c0ef63b..dc8904d7f74 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -608,31 +608,22 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
         aggregations: Seq[NamedExpression],
         groupByAliases: Seq[Alias],
         groupingAttrs: Seq[Expression],
-        gid: Attribute): Seq[NamedExpression] = aggregations.map { agg =>
-      // collect all the found AggregateExpression, so we can check an expression is part of
-      // any AggregateExpression or not.
-      val aggsBuffer = ArrayBuffer[Expression]()
-      // Returns whether the expression belongs to any expressions in `aggsBuffer` or not.
-      def isPartOfAggregation(e: Expression): Boolean = {
-        aggsBuffer.exists(a => a.exists(_ eq e))
-      }
-      replaceGroupingFunc(agg, groupByExprs, gid).transformDown {
-        // AggregateExpression should be computed on the unmodified value of its argument
-        // expressions, so we should not replace any references to grouping expression
-        // inside it.
-        case e if AggregateExpression.isAggregate(e) =>
-          aggsBuffer += e
-          e
-        case e if isPartOfAggregation(e) => e
+        gid: Attribute): Seq[NamedExpression] = {
+      def replaceExprs(e: Expression): Expression = e match {
+        case e if AggregateExpression.isAggregate(e) => e
         case e =>
           // Replace expression by expand output attribute.
           val index = groupByAliases.indexWhere(_.child.semanticEquals(e))
           if (index == -1) {
-            e
+            e.mapChildren(replaceExprs)
           } else {
             groupingAttrs(index)
           }
-      }.asInstanceOf[NamedExpression]
+      }
+      aggregations
+        .map(replaceGroupingFunc(_, groupByExprs, gid))
+        .map(replaceExprs)
+        .map(_.asInstanceOf[NamedExpression])
     }
 
     /*
diff --git a/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql b/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql
index 4d516bdda7b..909c36c926c 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql
@@ -60,3 +60,7 @@ SELECT grouping(k1), k1, k2, avg(v) FROM (VALUES (1,1,1),(2,2,2)) AS t(k1,k2,v)
 
 -- grouping_id function
 SELECT grouping_id(k1, k2), avg(v) from (VALUES (1,1,1),(2,2,2)) AS t(k1,k2,v) GROUP BY k1, k2 GROUPING SETS ((k2, k1), k1);
+
+SELECT CASE WHEN a IS NULL THEN count(b) WHEN b IS NULL THEN count(c) END
+FROM grouping
+GROUP BY GROUPING SETS (a, b, c);
diff --git a/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out b/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out
index f85cd8fad3d..61d9523da6d 100644
--- a/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out
@@ -224,3 +224,21 @@ struct<grouping_id(k1, k2):bigint,avg(v):double>
 0      2.0
 1      1.0
 1      2.0
+
+
+-- !query
+SELECT CASE WHEN a IS NULL THEN count(b) WHEN b IS NULL THEN count(c) END
+FROM grouping
+GROUP BY GROUPING SETS (a, b, c)
+-- !query schema
+struct<CASE WHEN (a IS NULL) THEN count(b) WHEN (b IS NULL) THEN count(c) END:bigint>
+-- !query output
+1
+1
+1
+1
+1
+1
+1
+1
+1


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
