This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.x by this push:
     new 0f88ecaf03ed [SPARK-57043][SQL] Collapse grouped SUM(COUNT(1)) rollups
0f88ecaf03ed is described below

commit 0f88ecaf03ed905e180656176a0c31b3103d82fb
Author: Chao Sun <[email protected]>
AuthorDate: Fri May 29 13:30:35 2026 -0700

    [SPARK-57043][SQL] Collapse grouped SUM(COUNT(1)) rollups
    
    ### Why are the changes needed?
    
    Queries that first compute fine-grained grouped counts and then roll those counts up with
    `SUM` can retain an unnecessary lower aggregation. For example:
    
    ```sql
    SELECT a, SUM(cnt) AS total_count
    FROM (
      SELECT a, b, COUNT(1) AS cnt
      FROM t
      GROUP BY a, b
    )
    GROUP BY a
    ```
    
    The result only needs to count input rows by `a`; it does not need the intermediate groups by
    `(a, b)`. When `b` is an expensive derived expression or nested input field, retaining that
    lower aggregation can also force the scan to read and compute data that the final result does
    not need.
    
    ### What changes were proposed in this pull request?
    
    This PR introduces a dedicated `CollapseGroupedSumOfCount` optimization that simplifies:
    
    ```text
    Aggregate(a, SUM(cnt))
    +- Aggregate(a, b, COUNT(1) AS cnt)
       +- input
    ```
    
    to:
    
    ```text
    Aggregate(a, SUM(1))
    +- Project(a)
       +- input
    ```
    
    The rule is applied during early scan planning so schema pruning can remove inputs used only by
    the discarded lower grouping expressions. For data source V2, it runs after the source has had
    the opportunity to push down the original aggregate and before required scan columns are
    finalized. It also honors `spark.sql.optimizer.excludedRules` when used from that embedded V2
    path.
    
    The rewrite accepts eligible grouped `SUM(COUNT(1))` rollups and preserves compatible `MIN` or
    `MAX` expressions on retained grouping outputs. It excludes `TRY_SUM`, nullable or filtered
    counts, distinct counts, global lower aggregates, and plans whose retained expressions require
    discarded lower outputs.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Eligible grouped count rollups receive a simpler optimized plan and can read fewer input
    columns when the removed grouping expressions previously kept those columns alive.
    
    ### How was this patch tested?
    
    - Added Catalyst optimizer unit coverage for eligible rewrites and rejected cases, including
      ANSI/legacy `SUM`, `TRY_SUM`, `MIN`/`MAX`, distinct/filtered counts, and nondeterministic
      grouping expressions.
    - Added file-source aggregate-pushdown coverage for preserving source pushdown when available,
      applying the fallback when pushdown is rejected, and honoring optimizer rule exclusion.
    - Added schema-pruning coverage for removing nested fields required only by discarded grouping
      expressions.
    - Validated the pending update with `git diff --check`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: OpenAI Codex (GPT-5)
    
    Closes #56129 from sunchao/dev/chao/codex/collapse-grouped-count-rollup.
    
    Authored-by: Chao Sun <[email protected]>
    Signed-off-by: Chao Sun <[email protected]>
    (cherry picked from commit 9bebc0d71f2b30bcfedb7dccc3b0da710cdcda0f)
    Signed-off-by: Chao Sun <[email protected]>
---
 .../optimizer/CollapseGroupedSumOfCount.scala      | 124 ++++++++++++++++++
 .../sql/catalyst/rules/RuleIdCollection.scala      |   1 +
 .../optimizer/CollapseGroupedSumOfCountSuite.scala | 143 +++++++++++++++++++++
 .../spark/sql/execution/SparkOptimizer.scala       |   4 +
 .../datasources/v2/V2ScanRelationPushDown.scala    |  16 ++-
 .../FileSourceAggregatePushDownSuite.scala         |  63 +++++++++
 .../execution/datasources/SchemaPruningSuite.scala |  16 +++
 7 files changed, 366 insertions(+), 1 deletion(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CollapseGroupedSumOfCount.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CollapseGroupedSumOfCount.scala
new file mode 100644
index 000000000000..d6a57405d467
--- /dev/null
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CollapseGroupedSumOfCount.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, 
AttributeSet, EvalMode, Literal,
+  NamedExpression}
+import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Count, Max, Min, Sum}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, 
Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE
+
+/**
+ * Collapses a grouped `SUM(COUNT(1))` rollup to a single `SUM(1)` aggregation.
+ *
+ * The V2 scan-planning path invokes this after first attempting to push the original `COUNT(1)`
+ * and before final column pruning. The V1 path runs it in `SparkOptimizer.earlyScanPushDownRules`,
+ * followed by another `SchemaPruning` pass so that nested fields referenced only by the discarded
+ * lower grouping expressions can be pruned. Unlike the standalone V1 invocation, the nested V2
+ * invocation must explicitly honor `spark.sql.optimizer.excludedRules`.
+ */
+object CollapseGroupedSumOfCount extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+    _.containsPattern(AGGREGATE), ruleId) {
+    case upper @ Aggregate(_, _, lower: Aggregate, _) =>
+      collapse(upper, lower).getOrElse(upper)
+  }
+
+  /**
+   * `SUM(COUNT(1))` over a finer-grained grouped aggregation can be evaluated as `SUM(1)` over
+   * the input when the lower aggregation is grouped. Unfiltered non-distinct `MAX` and `MIN`
+   * over retained lower grouping outputs can remain alongside the rewritten `SUM`, because
+   * repeating their input values does not change their result. This avoids evaluating lower
+   * grouping expressions that only partition rows before their counts are added back together.
+   *
+   * Retaining `SUM` preserves its empty-input and ANSI overflow semantics. `TRY` is excluded
+   * because the lower `COUNT` can overflow differently from a direct `TRY_SUM(1)`. A global
+   * lower aggregate is excluded because it produces one row even when its input is empty.
+   * Nondeterministic upper expressions are excluded because the rewrite evaluates them over the
+   * lower input rows instead of over the lower groups, changing how often they are evaluated.
+   *
+   * @param upper the coarser aggregate whose child is `lower`
+   * @param lower the finer-grained grouped aggregate that produces the `COUNT(1)` outputs
+   * @return the rewritten upper aggregate over a `Project` of the lower input, or `None` when
+   *         the plan is not eligible for the rewrite
+   */
+  private def collapse(upper: Aggregate, lower: Aggregate): Option[Aggregate] = {
+    if (lower.groupingExpressions.isEmpty || !lower.groupingExpressions.forall(_.deterministic)) {
+      return None
+    }
+    // E.g. MAX(b + rand()) is evaluated once per lower group before the rewrite, but would be
+    // evaluated once per input row after it, so its result distribution would change.
+    if (!(upper.groupingExpressions ++ upper.aggregateExpressions).forall(_.deterministic)) {
+      return None
+    }
+
+    // Outputs of unfiltered, non-distinct COUNT(<non-null literal>): per-group row counts.
+    val countOutputs = AttributeSet(lower.aggregateExpressions.collect {
+      case a @ Alias(AggregateExpression(Count(Seq(l: Literal)), _, false, None, _), _)
+          if l.value != null => a.toAttribute
+    })
+    if (countOutputs.isEmpty) {
+      return None
+    }
+
+    val upperAggs = upper.aggregateExpressions.flatMap(_.collect {
+      case ae: AggregateExpression => ae
+    })
+    def isCollapsibleSum(ae: AggregateExpression): Boolean = ae match {
+      case AggregateExpression(Sum(a: Attribute, context), _, false, None, _) =>
+        context.evalMode != EvalMode.TRY && countOutputs.contains(a)
+      case _ => false
+    }
+    // Require at least one collapsible SUM; every other upper aggregate must be a plain MIN/MAX.
+    if (!upperAggs.exists(isCollapsibleSum) || !upperAggs.forall {
+      case ae if isCollapsibleSum(ae) => true
+      case AggregateExpression(_: Max, _, false, None, _) => true
+      case AggregateExpression(_: Min, _, false, None, _) => true
+      case _ => false
+    }) {
+      return None
+    }
+
+    // Replace each SUM over a lower count with SUM(1L), keeping the original eval context.
+    val rewrittenExpressions = upper.aggregateExpressions.map { expr =>
+      expr.transform {
+        case ae @ AggregateExpression(Sum(a: Attribute, context), _, false, None, _)
+            if countOutputs.contains(a) =>
+          ae.copy(aggregateFunction = Sum(Literal(1L), context))
+      }.asInstanceOf[NamedExpression]
+    }
+    val rewritten = upper.copy(aggregateExpressions = rewrittenExpressions)
+    val lowerNonAggOutputs = AttributeSet(lower.aggregateExpressions
+      .filter(_.deterministic)
+      .filterNot(AggregateExpression.containsAggregate)
+      .map(_.toAttribute))
+    // MIN and MAX above are safe only when their inputs are retained lower grouping outputs.
+    if (!rewritten.references.subsetOf(lowerNonAggOutputs)) {
+      return None
+    }
+
+    // Re-evaluate only the lower outputs that the rewritten aggregate still references.
+    val projectList = lower.aggregateExpressions.filter(rewritten.references.contains(_))
+    val discardedGroupingRefs =
+      AttributeSet(lower.groupingExpressions.flatMap(_.references)) -- rewritten.references
+    if (hasNondeterministicProducer(lower.child, discardedGroupingRefs)) {
+      return None
+    }
+    Some(rewritten.copy(child = Project(projectList, lower.child)))
+  }
+
+  /**
+   * The analyzer pulls a nondeterministic grouping expression, such as `rand()`, into a `Project`
+   * below the aggregate and replaces the grouping key with its output attribute. Do not remove
+   * that projected evaluation when the corresponding grouping key is discarded by this rule.
+   *
+   * Walks down through consecutive `Project`s only, following the producers of `attributes`.
+   */
+  private def hasNondeterministicProducer(
+      plan: LogicalPlan,
+      attributes: AttributeSet): Boolean = plan match {
+    case Project(projectList, child) =>
+      val producers = projectList.filter(expr => attributes.contains(expr.toAttribute))
+      producers.exists(!_.deterministic) ||
+        hasNondeterministicProducer(child, AttributeSet(producers.flatMap(_.references)))
+    case _ => false
+  }
+}
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index a890d43f0672..fb6254d82056 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -119,6 +119,7 @@ object RuleIdCollection {
       
"org.apache.spark.sql.catalyst.expressions.ValidateAndStripPipeExpressions" ::
       // Catalyst Optimizer rules
       "org.apache.spark.sql.catalyst.optimizer.BooleanSimplification" ::
+      "org.apache.spark.sql.catalyst.optimizer.CollapseGroupedSumOfCount" ::
       "org.apache.spark.sql.catalyst.optimizer.CollapseProject" ::
       "org.apache.spark.sql.catalyst.optimizer.CollapseRepartition" ::
       "org.apache.spark.sql.catalyst.optimizer.CollapseWindow" ::
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseGroupedSumOfCountSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseGroupedSumOfCountSuite.scala
new file mode 100644
index 000000000000..242bed45a693
--- /dev/null
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseGroupedSumOfCountSuite.scala
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{EvalMode, Literal, 
NumericEvalContext}
+import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Unit tests for [[CollapseGroupedSumOfCount]]. Tests named "collapse ..." compare the optimized
+ * plan against the expected single-aggregate form; tests named "keep ..." assert that the rule
+ * leaves an ineligible plan unchanged.
+ */
+class CollapseGroupedSumOfCountSuite extends PlanTest {
+
+  // RemoveNoopOperators runs after the rule so expected plans may omit a Project that merely
+  // forwards its child's output (as in the MIN/MAX test, where every input column is retained).
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches = Batch("CollapseGroupedSumOfCount", Once,
+      CollapseGroupedSumOfCount,
+      RemoveNoopOperators) :: Nil
+  }
+
+  private val relation = LocalRelation($"a".int, $"b".int)
+
+  test("SPARK-57043: collapse grouped sum of count in legacy and ANSI modes") {
+    Seq("false", "true").foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled) {
+        val query = relation
+          .groupBy($"a", $"b")($"a", $"b", count(Literal(1)).as("cnt"))
+          .groupBy($"a")($"a", sum($"cnt").as("total"))
+          .analyze
+        // The discarded lower grouping key `b` is projected away below a single SUM(1L).
+        val expected = relation
+          .select($"a")
+          .groupBy($"a")($"a", sum(Literal(1L)).as("total"))
+          .analyze
+        comparePlans(Optimize.execute(query), expected)
+      }
+    }
+  }
+
+  test("SPARK-57043: collapse global sum of grouped count while preserving 
empty input semantics") {
+    // The lower aggregate is grouped, so an empty input yields no lower rows and the global SUM
+    // is NULL both before and after the rewrite.
+    val query = relation
+      .groupBy($"a", $"b")($"a", $"b", count(Literal(1)).as("cnt"))
+      .groupBy()(sum($"cnt").as("total"))
+      .analyze
+    val expected = relation
+      .select()
+      .groupBy()(sum(Literal(1L)).as("total"))
+      .analyze
+    comparePlans(Optimize.execute(query), expected)
+  }
+
+  test("SPARK-57043: keep grouped try_sum of count because overflow behavior 
differs") {
+    // TRY_SUM is rejected: the lower COUNT can overflow differently from a direct TRY_SUM(1).
+    val query = relation
+      .groupBy($"a", $"b")($"a", $"b", count(Literal(1)).as("cnt"))
+      .groupBy($"a")(
+        $"a",
+        Sum($"cnt", 
NumericEvalContext(EvalMode.TRY)).toAggregateExpression().as("total"))
+      .analyze
+    comparePlans(Optimize.execute(query), query)
+  }
+
+  test("SPARK-57043: keep grouped sum of nullable count") {
+    // COUNT(b) skips NULL values of `b`, so summing it is not the same as counting input rows.
+    val query = relation
+      .groupBy($"a", $"b")($"a", $"b", count($"b").as("cnt"))
+      .groupBy($"a")($"a", sum($"cnt").as("total"))
+      .analyze
+    comparePlans(Optimize.execute(query), query)
+  }
+
+  test("SPARK-57043: keep sum of global count to preserve empty input 
semantics") {
+    // A global lower aggregate produces one row even for empty input, so it must not collapse.
+    val query = relation
+      .groupBy()(count(Literal(1)).as("cnt"))
+      .groupBy()(sum($"cnt").as("total"))
+      .analyze
+    comparePlans(Optimize.execute(query), query)
+  }
+
+  test("SPARK-57043: collapse grouped sum of count with max and min of 
grouping output") {
+    // MAX/MIN over a retained lower grouping output are unaffected by repeated input values.
+    val query = relation
+      .groupBy($"a", $"b")($"a", $"b", count(Literal(1)).as("cnt"))
+      .groupBy($"a")(
+        $"a", sum($"cnt").as("total"), max($"b").as("max_b"), 
min($"b").as("min_b"))
+      .analyze
+    val expected = relation.groupBy($"a")(
+      $"a", sum(Literal(1L)).as("total"), max($"b").as("max_b"), 
min($"b").as("min_b"))
+      .analyze
+    comparePlans(Optimize.execute(query), expected)
+  }
+
+  test("SPARK-57043: keep grouped sum of count with max of count output") {
+    // MAX(cnt) depends on the per-group counts, which the rewrite would discard.
+    val query = relation
+      .groupBy($"a", $"b")($"a", $"b", count(Literal(1)).as("cnt"))
+      .groupBy($"a")($"a", sum($"cnt").as("total"), max($"cnt").as("max_cnt"))
+      .analyze
+    comparePlans(Optimize.execute(query), query)
+  }
+
+  test("SPARK-57043: keep grouped sum of filtered and distinct counts") {
+    // Only unfiltered, non-distinct counts are eligible.
+    val filtered = relation
+      .groupBy($"a", $"b")(
+        $"a", $"b", count(Literal(1), Some($"b" > Literal(0))).as("cnt"))
+      .groupBy($"a")($"a", sum($"cnt").as("total"))
+      .analyze
+    comparePlans(Optimize.execute(filtered), filtered)
+
+    val distinct = relation
+      .groupBy($"a", $"b")($"a", $"b", countDistinct(Literal(1)).as("cnt"))
+      .groupBy($"a")($"a", sum($"cnt").as("total"))
+      .analyze
+    comparePlans(Optimize.execute(distinct), distinct)
+  }
+
+  test("SPARK-57043: keep grouped sum when outer grouping depends on count 
output") {
+    // The upper grouping key `cnt` is a lower aggregate output that the rewrite cannot retain.
+    val query = relation
+      .groupBy($"a", $"b")($"a", $"b", count(Literal(1)).as("cnt"))
+      .groupBy($"a", $"cnt")($"a", $"cnt", sum($"cnt").as("total"))
+      .analyze
+    comparePlans(Optimize.execute(query), query)
+  }
+
+  test("SPARK-57043: keep grouped sum of count with non-deterministic lower 
grouping") {
+    // The analyzer moves rand(0) into a Project below the lower aggregate; the rule must not
+    // drop that nondeterministic evaluation when discarding its grouping key.
+    val query = relation
+      .groupBy($"a", rand(0))($"a", count(Literal(1)).as("cnt"))
+      .groupBy($"a")($"a", sum($"cnt").as("total"))
+      .analyze
+    comparePlans(Optimize.execute(query), query)
+  }
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 7f3b8383f0f8..e27faf7b4d9e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -41,6 +41,10 @@ class SparkOptimizer(
       GroupBasedRowLevelOperationScanPlanning,
       V1Writes,
       V2ScanRelationPushDown,
+      // V2 applies this fallback before building its scan. For V1, apply it 
here and rerun
+      // nested pruning because the original grouping expressions may have 
kept extra fields.
+      CollapseGroupedSumOfCount,
+      SchemaPruning,
       V2ScanPartitioningAndOrdering,
       V2Writes,
       PruneFileSourcePartitions,
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index a1c69847c509..a17fe787b81e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -26,7 +26,7 @@ import 
org.apache.spark.internal.LogKeys.{AGGREGATE_FUNCTIONS, COLUMN_NAMES, GRO
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, 
Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, Expression, 
ExpressionSet, ExprId, IntegerLiteral, Literal, NamedExpression, 
PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.optimizer.CollapseProject
+import org.apache.spark.sql.catalyst.optimizer.{CollapseGroupedSumOfCount, 
CollapseProject}
 import org.apache.spark.sql.catalyst.planning.{PhysicalOperation, 
ScanOperation}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, 
LeafNode, Limit, LimitAndOffset, LocalLimit, LogicalPlan, Offset, 
OffsetAndLimit, Project, Sample, SampleMethod, Sort}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -36,11 +36,13 @@ import 
org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, C
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, 
SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownJoin, 
SupportsPushDownVariantExtractions, V1Scan, VariantExtraction}
 import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, 
VariantInRelation}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.connector.VariantExtractionImpl
 import org.apache.spark.sql.sources
 import org.apache.spark.sql.types.{DataType, DecimalType, IntegerType, 
StructField, StructType}
 import org.apache.spark.sql.util.SchemaUtils._
 import org.apache.spark.util.ArrayImplicits._
+import org.apache.spark.util.Utils
 
 object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
   import DataSourceV2Implicits._
@@ -52,6 +54,9 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with 
PredicateHelper {
       pushDownFilters,
       pushDownJoin,
       pushDownAggregates,
+      // Apply the fallback after the source has tried the lower aggregation, 
but before its
+      // required columns are finalized by pruneColumns.
+      collapseGroupedSumOfCount,
       pushDownVariants,
       pushDownLimitAndOffset,
       buildScanWithPushedAggregate,
@@ -64,6 +69,15 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with 
PredicateHelper {
     }
   }
 
+  /**
+   * Applies [[CollapseGroupedSumOfCount]] as one step of this rule's push-down pipeline. The step
+   * is invoked directly rather than scheduled by the optimizer, so it consults
+   * `spark.sql.optimizer.excludedRules` itself and returns `plan` untouched when the rule is
+   * listed there.
+   */
+  private def collapseGroupedSumOfCount(plan: LogicalPlan): LogicalPlan = {
+    val ruleExcluded = SQLConf.get.optimizerExcludedRules.exists { excluded =>
+      Utils.stringToSeq(excluded).contains(CollapseGroupedSumOfCount.ruleName)
+    }
+    if (ruleExcluded) plan else CollapseGroupedSumOfCount(plan)
+  }
+
   private def createScanBuilder(plan: LogicalPlan) = plan.transform {
     case r: DataSourceV2Relation =>
       ScanBuilderHolder(r.output, r, 
r.table.asReadable.newScanBuilder(r.options))
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala
index 1a0beedfcad1..706657712763 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala
@@ -21,6 +21,8 @@ import java.sql.{Date, Timestamp}
 
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, Row}
+import org.apache.spark.sql.catalyst.optimizer.CollapseGroupedSumOfCount
+import org.apache.spark.sql.catalyst.plans.logical.Aggregate
 import org.apache.spark.sql.execution.datasources.orc.OrcTest
 import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
@@ -262,6 +264,67 @@ trait FileSourceAggregatePushDownSuite
     }
   }
 
+  test("SPARK-57043: preserve pushed count before collapsing its grouped 
rollup") {
+    // Both lower grouping columns are partition columns. The expected pushed info below shows
+    // the source accepting the lower COUNT(*), which must not be collapsed away before the
+    // source has had the chance to push it.
+    withTempPath { dir =>
+      Seq((10, 1, 2), (2, 1, 2), (3, 2, 1), (4, 2, 1), (5, 2, 2))
+        .toDF("value", "p1", "p2")
+        .write
+        .partitionBy("p1", "p2")
+        .format(format)
+        .save(dir.getCanonicalPath)
+
+      withTempView("tmp") {
+        
spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp")
+        val query =
+          """
+            |SELECT p1, SUM(cnt)
+            |FROM (SELECT p1, p2, COUNT(*) AS cnt FROM tmp GROUP BY p1, p2)
+            |GROUP BY p1
+            |""".stripMargin
+        // Reference answer computed with aggregate push-down disabled.
+        var expected = Array.empty[Row]
+        withSQLConf(aggPushDownEnabledKey -> "false") {
+          expected = sql(query).collect()
+        }
+        withSQLConf(aggPushDownEnabledKey -> "true") {
+          val df = sql(query)
+          checkPushedInfo(
+            df,
+            "PushedAggregation: [COUNT(*)], PushedFilters: [], PushedGroupBy: 
[p1, p2]")
+          checkAnswer(df, expected)
+        }
+      }
+    }
+  }
+
+  test("SPARK-57043: collapse grouped count rollup after rejected scan push 
down") {
+    val data = Seq((1, 1), (1, 1), (1, 2), (2, 1), (2, 2))
+    withDataSourceTable(data, "t") {
+      // The source does not push this aggregation (empty PushedAggregation), so the fallback
+      // collapses the rollup into a single Aggregate.
+      withSQLConf(aggPushDownEnabledKey -> "true") {
+        val df = sql(
+          """
+            |SELECT _1, SUM(cnt)
+            |FROM (SELECT _1, _2, COUNT(*) AS cnt FROM t GROUP BY _1, _2)
+            |GROUP BY _1
+            |""".stripMargin)
+        checkPushedInfo(df, "PushedAggregation: []")
+        assert(df.queryExecution.optimizedPlan.collect { case _: Aggregate => 
true }.size == 1)
+        checkAnswer(df, Seq(Row(1, 3L), Row(2, 2L)))
+      }
+      // With the rule listed in spark.sql.optimizer.excludedRules, the V2 path must skip it too:
+      // both aggregates remain and the answer is unchanged.
+      withSQLConf(
+          aggPushDownEnabledKey -> "true",
+          SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> 
CollapseGroupedSumOfCount.ruleName) {
+        val df = sql(
+          """
+            |SELECT _1, SUM(cnt)
+            |FROM (SELECT _1, _2, COUNT(*) AS cnt FROM t GROUP BY _1, _2)
+            |GROUP BY _1
+            |""".stripMargin)
+        assert(df.queryExecution.optimizedPlan.collect { case _: Aggregate => 
true }.size == 2)
+        checkAnswer(df, Seq(Row(1, 3L), Row(2, 2L)))
+      }
+    }
+  }
+
   test("push down only if all the aggregates can be pushed down") {
     val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
       (9, "mno", 7), (2, null, 7))
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index 5213c0c5f4e2..7f4b83ee342e 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -636,6 +636,22 @@ abstract class SchemaPruningSuite
     checkAnswer(query4, Row(2, null) :: Row(2, 4) :: Nil)
   }
 
+  testSchemaPruning("SPARK-57043: prune nested grouping field after collapsing 
count rollup") {
+    val query = sql(
+      """
+        |SELECT id, SUM(cnt)
+        |FROM (
+        |  SELECT id, name.last, COUNT(*) AS cnt
+        |  FROM contacts
+        |  WHERE name.first IS NOT NULL
+        |  GROUP BY id, name.last
+        |)
+        |GROUP BY id
+        |""".stripMargin)
+    // `name.last` is used only by the discarded lower grouping key, so after the collapse the
+    // scan needs just `id` and `name.first` (still referenced by the filter).
+    checkScan(query, "struct<id:int,name:struct<first:string>>")
+    checkAnswer(query.orderBy("id"), Row(0, 1L) :: Row(1, 1L) :: Row(2, 1L) :: 
Row(3, 1L) :: Nil)
+  }
+
   testSchemaPruning("select nested field in window function") {
     val windowSql =
       """


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to