This is an automated email from the ASF dual-hosted git repository.
sunchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 9bebc0d71f2b [SPARK-57043][SQL] Collapse grouped SUM(COUNT(1)) rollups
9bebc0d71f2b is described below
commit 9bebc0d71f2b30bcfedb7dccc3b0da710cdcda0f
Author: Chao Sun <[email protected]>
AuthorDate: Fri May 29 13:30:35 2026 -0700
[SPARK-57043][SQL] Collapse grouped SUM(COUNT(1)) rollups
### Why are the changes needed?
Queries that first compute fine-grained grouped counts and then roll those counts up with `SUM`
can retain an unnecessary lower aggregation. For example:
```sql
SELECT a, SUM(cnt) AS total_count
FROM (
  SELECT a, b, COUNT(1) AS cnt
  FROM t
  GROUP BY a, b
)
GROUP BY a
```
The result only needs the number of input rows per `a`; the intermediate groups by `(a, b)` are
not needed. When `b` is an expensive derived expression or a nested input field, retaining that
lower aggregation can also force the scan to read and compute data that the final result does
not use.
### What changes were proposed in this pull request?
This PR introduces a dedicated `CollapseGroupedSumOfCount` optimizer rule that simplifies:
```text
Aggregate(a, SUM(cnt))
+- Aggregate(a, b, COUNT(1) AS cnt)
   +- input
```
to:
```text
Aggregate(a, SUM(1))
+- Project(a)
   +- input
```
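The equivalence behind this rewrite can be checked on plain Scala collections. The sketch below is illustrative only: it does not call Spark, `GroupedCountRollup`, `twoLevel` and `collapsed` are hypothetical names, and a `Map` from `a` to its total stands in for the output rows. `twoLevel` mimics the original pair of aggregates, while `collapsed` mimics `SUM(1)` over `Project(a)` and never looks at `b`.

```scala
// Hypothetical sketch: plain Scala collections stand in for Spark rows and
// operators; nothing here calls Spark.
object GroupedCountRollup {
  type Row = (Int, Int) // (a, b)

  // Aggregate(a, SUM(cnt)) over Aggregate(a, b, COUNT(1) AS cnt)
  def twoLevel(rows: Seq[Row]): Map[Int, Long] = {
    // Lower aggregate: one count per distinct (a, b) pair.
    val lower: Map[Row, Long] =
      rows.groupBy(identity).map { case (key, group) => key -> group.size.toLong }
    // Upper aggregate: add the per-(a, b) counts back together for each a.
    lower.toSeq
      .groupBy(_._1._1)
      .map { case (a, counts) => a -> counts.map(_._2).sum }
  }

  // Aggregate(a, SUM(1)) over Project(a): b is never evaluated.
  def collapsed(rows: Seq[Row]): Map[Int, Long] =
    rows.map(_._1).groupBy(identity).map { case (a, group) => a -> group.map(_ => 1L).sum }

  def main(args: Array[String]): Unit = {
    // Same rows as the rejected-pushdown test in FileSourceAggregatePushDownSuite.
    val rows: Seq[Row] = Seq((1, 1), (1, 1), (1, 2), (2, 1), (2, 2))
    assert(twoLevel(rows) == collapsed(rows))
    println(collapsed(rows).toList.sorted) // List((1,3), (2,2))
  }
}
```

With those rows both functions yield `1 -> 3` and `2 -> 2`, matching the `Row(1, 3L), Row(2, 2L)` answer the file-source test expects.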
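The rule is applied during early scan planning so that schema pruning can remove inputs used only
by the discarded lower grouping expressions. For Data Source V2, it runs inside
`V2ScanRelationPushDown`, after the source has had the opportunity to push down the original
aggregate and before the required scan columns are finalized. For V1 sources, it runs in
`SparkOptimizer.earlyScanPushDownRules`, followed by another `SchemaPruning` pass. Because the V2
invocation is embedded in another rule rather than scheduled by the optimizer, it checks
`spark.sql.optimizer.excludedRules` explicitly.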
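The explicit check on the embedded V2 path reduces to a small guard. The sketch below is hypothetical: `ExcludedRulesCheck`, `stringToSeq` and `applyUnlessExcluded` are illustrative names, not Spark APIs, and the local `stringToSeq` assumes that Spark's `Utils.stringToSeq` splits the conf value on commas, trims each entry and drops empty ones.

```scala
// Hypothetical sketch of the exclusion guard; not Spark code.
object ExcludedRulesCheck {
  // Assumed stand-in for Spark's Utils.stringToSeq: split on commas,
  // trim whitespace and drop empty entries.
  def stringToSeq(str: String): Seq[String] =
    str.split(",").map(_.trim).filter(_.nonEmpty).toSeq

  // Mirrors the shape of V2ScanRelationPushDown.collapseGroupedSumOfCount:
  // apply `rule` unless `ruleName` appears in the optional conf value.
  def applyUnlessExcluded[T](
      excludedRulesConf: Option[String],
      ruleName: String,
      rule: T => T)(plan: T): T = {
    val excluded = excludedRulesConf.toSeq.flatMap(stringToSeq)
    if (excluded.contains(ruleName)) plan else rule(plan)
  }

  def main(args: Array[String]): Unit = {
    val ruleName = "org.apache.spark.sql.catalyst.optimizer.CollapseGroupedSumOfCount"
    val rewrite: String => String = _ => "collapsed"
    println(applyUnlessExcluded(None, ruleName, rewrite)("original")) // collapsed
    println(applyUnlessExcluded(Some(s" Other, $ruleName "), ruleName, rewrite)("original")) // original
  }
}
```

The value to list in `spark.sql.optimizer.excludedRules` is the rule's fully qualified class name, the same string registered in `RuleIdCollection` and returned by `CollapseGroupedSumOfCount.ruleName`.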
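The rewrite accepts eligible grouped `SUM(COUNT(1))` rollups and keeps compatible `MIN` or `MAX`
expressions over retained grouping outputs. It does not apply to `TRY_SUM`; to counts of anything
other than a non-null literal (for example `COUNT(b)`, which skips nulls); to filtered or distinct
counts; to global (ungrouped) lower aggregates; to lower aggregates with nondeterministic grouping
expressions; or to plans whose retained expressions still reference discarded lower outputs, such
as the count itself.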
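Two of these exclusions can be reproduced with the same kind of collection model. In this hypothetical sketch (plain Scala, not Spark), `None` stands in for SQL `NULL`, both for the result of `SUM` over zero rows and for null values of `b`; `RollupExclusions` and its methods are illustrative names.

```scala
// Hypothetical sketch; plain Scala, no Spark.
object RollupExclusions {
  // SQL SUM over zero rows returns NULL, modelled here as None.
  def sqlSum(xs: Seq[Long]): Option[Long] = if (xs.isEmpty) None else Some(xs.sum)

  // SUM(cnt) over a global COUNT(1): the lower aggregate always emits one row,
  // even when its input is empty.
  def sumOfGlobalCount(rows: Seq[Int]): Option[Long] = sqlSum(Seq(rows.size.toLong))

  // The collapsed form, SUM(1) evaluated directly over the input rows.
  def sumOfOne(rows: Seq[Int]): Option[Long] = sqlSum(rows.map(_ => 1L))

  // SUM(cnt) over GROUP BY a, b with COUNT(b) AS cnt; COUNT(b) skips NULL b values.
  def sumOfCountB(rows: Seq[(Int, Option[Int])]): Map[Int, Long] =
    rows.groupBy(identity).toSeq
      .map { case ((a, b), group) => a -> (if (b.isDefined) group.size.toLong else 0L) }
      .groupBy(_._1)
      .map { case (a, counts) => a -> counts.map(_._2).sum }

  // What a SUM(1) rewrite would compute instead: every row counts.
  def countRows(rows: Seq[(Int, Option[Int])]): Map[Int, Long] =
    rows.groupBy(_._1).map { case (a, group) => a -> group.size.toLong }

  def main(args: Array[String]): Unit = {
    println(sumOfGlobalCount(Seq.empty)) // Some(0)
    println(sumOfOne(Seq.empty))         // None
    val rows: Seq[(Int, Option[Int])] = Seq((1, Some(1)), (1, None))
    println(sumOfCountB(rows))           // Map(1 -> 1)
    println(countRows(rows))             // Map(1 -> 2)
  }
}
```

In both cases a `SUM(1)` rewrite would return a different answer, which is why the rule leaves these plans unchanged.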
### Does this PR introduce _any_ user-facing change?
Yes. Eligible grouped count rollups get a simpler optimized plan and can read fewer input columns
when the removed grouping expressions were the only consumers of those columns.
### How was this patch tested?
- Added Catalyst optimizer unit coverage for eligible rewrites and rejected cases, including
  ANSI/legacy `SUM`, `TRY_SUM`, `MIN`/`MAX`, distinct/filtered counts, and nondeterministic
  grouping expressions.
- Added file-source aggregate-pushdown coverage for preserving source pushdown when available,
  applying the fallback when pushdown is rejected, and honoring optimizer rule exclusion.
- Added schema-pruning coverage for removing nested fields required only by discarded grouping
  expressions.
- Validated the pending update with `git diff --check`.
### Was this patch authored or co-authored using generative AI tooling?
Generated-by: OpenAI Codex (GPT-5)
Closes #56129 from sunchao/dev/chao/codex/collapse-grouped-count-rollup.
Authored-by: Chao Sun <[email protected]>
Signed-off-by: Chao Sun <[email protected]>
---
.../optimizer/CollapseGroupedSumOfCount.scala | 124 ++++++++++++++++++
.../sql/catalyst/rules/RuleIdCollection.scala | 1 +
.../optimizer/CollapseGroupedSumOfCountSuite.scala | 143 +++++++++++++++++++++
.../spark/sql/execution/SparkOptimizer.scala | 4 +
.../datasources/v2/V2ScanRelationPushDown.scala | 16 ++-
.../FileSourceAggregatePushDownSuite.scala | 63 +++++++++
.../execution/datasources/SchemaPruningSuite.scala | 16 +++
7 files changed, 366 insertions(+), 1 deletion(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CollapseGroupedSumOfCount.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CollapseGroupedSumOfCount.scala
new file mode 100644
index 000000000000..d6a57405d467
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CollapseGroupedSumOfCount.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, EvalMode, Literal,
+  NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Count, Max, Min, Sum}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE
+
+/**
+ * Collapses a grouped `SUM(COUNT(1))` rollup to a single `SUM(1)` aggregation.
+ *
+ * The V2 scan-planning path invokes this after first attempting to push the original `COUNT(1)`
+ * and before final column pruning. The V1 path runs it in `SparkOptimizer.earlyScanPushDownRules`,
+ * followed by another `SchemaPruning` pass, when the collapse is needed. Unlike the standalone V1
+ * invocation, the nested V2 invocation must explicitly honor `spark.sql.optimizer.excludedRules`.
+ */
+object CollapseGroupedSumOfCount extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+    _.containsPattern(AGGREGATE), ruleId) {
+    case upper @ Aggregate(_, _, lower: Aggregate, _) =>
+      collapse(upper, lower).getOrElse(upper)
+  }
+
+  /**
+   * `SUM(COUNT(1))` over a finer-grained grouped aggregation can be evaluated as `SUM(1)` over
+   * the input when the lower aggregation is grouped. Unfiltered non-distinct `MAX` and `MIN`
+   * over retained lower grouping outputs can remain alongside the rewritten `SUM`, because
+   * repeating their input values does not change their result. This avoids evaluating lower
+   * grouping expressions that only partition rows before their counts are added back together.
+   *
+   * Retaining `SUM` preserves its empty-input and ANSI overflow semantics. `TRY` is excluded
+   * because the lower `COUNT` can overflow differently from a direct `TRY_SUM(1)`. A global
+   * lower aggregate is excluded because it produces one row even when its input is empty.
+   */
+  private def collapse(upper: Aggregate, lower: Aggregate): Option[Aggregate] = {
+    if (lower.groupingExpressions.isEmpty || !lower.groupingExpressions.forall(_.deterministic)) {
+      return None
+    }
+
+    val countOutputs = AttributeSet(lower.aggregateExpressions.collect {
+      case a @ Alias(
+          AggregateExpression(Count(Seq(l: Literal)), _, false, None, _), _)
+          if l.value != null => a.toAttribute
+    })
+    if (countOutputs.isEmpty) {
+      return None
+    }
+
+    val upperAggs = upper.aggregateExpressions.flatMap(_.collect {
+      case ae: AggregateExpression => ae
+    })
+    def isCollapsibleSum(ae: AggregateExpression): Boolean = ae match {
+      case AggregateExpression(Sum(a: Attribute, context), _, false, None, _) =>
+        context.evalMode != EvalMode.TRY && countOutputs.contains(a)
+      case _ => false
+    }
+    if (!upperAggs.exists(isCollapsibleSum) || !upperAggs.forall {
+      case ae if isCollapsibleSum(ae) => true
+      case AggregateExpression(_: Max, _, false, None, _) => true
+      case AggregateExpression(_: Min, _, false, None, _) => true
+      case _ => false
+    }) {
+      return None
+    }
+
+    val rewrittenExpressions = upper.aggregateExpressions.map { expr =>
+      expr.transform {
+        case ae @ AggregateExpression(Sum(a: Attribute, context), _, false, None, _)
+            if countOutputs.contains(a) =>
+          ae.copy(aggregateFunction = Sum(Literal(1L), context))
+      }.asInstanceOf[NamedExpression]
+    }
+    val rewritten = upper.copy(aggregateExpressions = rewrittenExpressions)
+    val lowerNonAggOutputs = AttributeSet(lower.aggregateExpressions
+      .filter(_.deterministic)
+      .filterNot(AggregateExpression.containsAggregate)
+      .map(_.toAttribute))
+    // MIN and MAX above are safe only when their inputs are retained lower grouping outputs.
+    if (!rewritten.references.subsetOf(lowerNonAggOutputs)) {
+      return None
+    }
+
+    val projectList = lower.aggregateExpressions.filter(rewritten.references.contains(_))
+    val discardedGroupingRefs =
+      AttributeSet(lower.groupingExpressions.flatMap(_.references)) -- rewritten.references
+    if (hasNondeterministicProducer(lower.child, discardedGroupingRefs)) {
+      return None
+    }
+    Some(rewritten.copy(child = Project(projectList, lower.child)))
+  }
+
+  /**
+   * The analyzer pulls a nondeterministic grouping expression, such as `rand()`, into a `Project`
+   * below the aggregate and replaces the grouping key with its output attribute. Do not remove
+   * that projected evaluation when the corresponding grouping key is discarded by this rule.
+   */
+  private def hasNondeterministicProducer(
+      plan: LogicalPlan,
+      attributes: AttributeSet): Boolean = plan match {
+    case Project(projectList, child) =>
+      val producers = projectList.filter(expr => attributes.contains(expr.toAttribute))
+      producers.exists(!_.deterministic) ||
+        hasNondeterministicProducer(child, AttributeSet(producers.flatMap(_.references)))
+    case _ => false
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index a890d43f0672..fb6254d82056 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -119,6 +119,7 @@ object RuleIdCollection {
       "org.apache.spark.sql.catalyst.expressions.ValidateAndStripPipeExpressions" ::
       // Catalyst Optimizer rules
       "org.apache.spark.sql.catalyst.optimizer.BooleanSimplification" ::
+      "org.apache.spark.sql.catalyst.optimizer.CollapseGroupedSumOfCount" ::
       "org.apache.spark.sql.catalyst.optimizer.CollapseProject" ::
       "org.apache.spark.sql.catalyst.optimizer.CollapseRepartition" ::
       "org.apache.spark.sql.catalyst.optimizer.CollapseWindow" ::
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseGroupedSumOfCountSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseGroupedSumOfCountSuite.scala
new file mode 100644
index 000000000000..242bed45a693
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseGroupedSumOfCountSuite.scala
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{EvalMode, Literal, NumericEvalContext}
+import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.internal.SQLConf
+
+class CollapseGroupedSumOfCountSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches = Batch("CollapseGroupedSumOfCount", Once,
+      CollapseGroupedSumOfCount,
+      RemoveNoopOperators) :: Nil
+  }
+
+  private val relation = LocalRelation($"a".int, $"b".int)
+
+  test("SPARK-57043: collapse grouped sum of count in legacy and ANSI modes") {
+    Seq("false", "true").foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled) {
+        val query = relation
+          .groupBy($"a", $"b")($"a", $"b", count(Literal(1)).as("cnt"))
+          .groupBy($"a")($"a", sum($"cnt").as("total"))
+          .analyze
+        val expected = relation
+          .select($"a")
+          .groupBy($"a")($"a", sum(Literal(1L)).as("total"))
+          .analyze
+        comparePlans(Optimize.execute(query), expected)
+      }
+    }
+  }
+
+  test("SPARK-57043: collapse global sum of grouped count while preserving empty input semantics") {
+    val query = relation
+      .groupBy($"a", $"b")($"a", $"b", count(Literal(1)).as("cnt"))
+      .groupBy()(sum($"cnt").as("total"))
+      .analyze
+    val expected = relation
+      .select()
+      .groupBy()(sum(Literal(1L)).as("total"))
+      .analyze
+    comparePlans(Optimize.execute(query), expected)
+  }
+
+  test("SPARK-57043: keep grouped try_sum of count because overflow behavior differs") {
+    val query = relation
+      .groupBy($"a", $"b")($"a", $"b", count(Literal(1)).as("cnt"))
+      .groupBy($"a")(
+        $"a",
+        Sum($"cnt", NumericEvalContext(EvalMode.TRY)).toAggregateExpression().as("total"))
+      .analyze
+    comparePlans(Optimize.execute(query), query)
+  }
+
+  test("SPARK-57043: keep grouped sum of nullable count") {
+    val query = relation
+      .groupBy($"a", $"b")($"a", $"b", count($"b").as("cnt"))
+      .groupBy($"a")($"a", sum($"cnt").as("total"))
+      .analyze
+    comparePlans(Optimize.execute(query), query)
+  }
+
+  test("SPARK-57043: keep sum of global count to preserve empty input semantics") {
+    val query = relation
+      .groupBy()(count(Literal(1)).as("cnt"))
+      .groupBy()(sum($"cnt").as("total"))
+      .analyze
+    comparePlans(Optimize.execute(query), query)
+  }
+
+  test("SPARK-57043: collapse grouped sum of count with max and min of grouping output") {
+    val query = relation
+      .groupBy($"a", $"b")($"a", $"b", count(Literal(1)).as("cnt"))
+      .groupBy($"a")(
+        $"a", sum($"cnt").as("total"), max($"b").as("max_b"), min($"b").as("min_b"))
+      .analyze
+    val expected = relation.groupBy($"a")(
+      $"a", sum(Literal(1L)).as("total"), max($"b").as("max_b"), min($"b").as("min_b"))
+      .analyze
+    comparePlans(Optimize.execute(query), expected)
+  }
+
+  test("SPARK-57043: keep grouped sum of count with max of count output") {
+    val query = relation
+      .groupBy($"a", $"b")($"a", $"b", count(Literal(1)).as("cnt"))
+      .groupBy($"a")($"a", sum($"cnt").as("total"), max($"cnt").as("max_cnt"))
+      .analyze
+    comparePlans(Optimize.execute(query), query)
+  }
+
+  test("SPARK-57043: keep grouped sum of filtered and distinct counts") {
+    val filtered = relation
+      .groupBy($"a", $"b")(
+        $"a", $"b", count(Literal(1), Some($"b" > Literal(0))).as("cnt"))
+      .groupBy($"a")($"a", sum($"cnt").as("total"))
+      .analyze
+    comparePlans(Optimize.execute(filtered), filtered)
+
+    val distinct = relation
+      .groupBy($"a", $"b")($"a", $"b", countDistinct(Literal(1)).as("cnt"))
+      .groupBy($"a")($"a", sum($"cnt").as("total"))
+      .analyze
+    comparePlans(Optimize.execute(distinct), distinct)
+  }
+
+  test("SPARK-57043: keep grouped sum when outer grouping depends on count output") {
+    val query = relation
+      .groupBy($"a", $"b")($"a", $"b", count(Literal(1)).as("cnt"))
+      .groupBy($"a", $"cnt")($"a", $"cnt", sum($"cnt").as("total"))
+      .analyze
+    comparePlans(Optimize.execute(query), query)
+  }
+
+  test("SPARK-57043: keep grouped sum of count with non-deterministic lower grouping") {
+    val query = relation
+      .groupBy($"a", rand(0))($"a", count(Literal(1)).as("cnt"))
+      .groupBy($"a")($"a", sum($"cnt").as("total"))
+      .analyze
+    comparePlans(Optimize.execute(query), query)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 7f3b8383f0f8..e27faf7b4d9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -41,6 +41,10 @@ class SparkOptimizer(
       GroupBasedRowLevelOperationScanPlanning,
       V1Writes,
       V2ScanRelationPushDown,
+      // V2 applies this fallback before building its scan. For V1, apply it here and rerun
+      // nested pruning because the original grouping expressions may have kept extra fields.
+      CollapseGroupedSumOfCount,
+      SchemaPruning,
       V2ScanPartitioningAndOrdering,
       V2Writes,
       PruneFileSourcePartitions,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index a1c69847c509..a17fe787b81e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -26,7 +26,7 @@ import org.apache.spark.internal.LogKeys.{AGGREGATE_FUNCTIONS, COLUMN_NAMES, GRO
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, Expression, ExpressionSet, ExprId, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.optimizer.CollapseProject
+import org.apache.spark.sql.catalyst.optimizer.{CollapseGroupedSumOfCount, CollapseProject}
 import org.apache.spark.sql.catalyst.planning.{PhysicalOperation, ScanOperation}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LeafNode, Limit, LimitAndOffset, LocalLimit, LogicalPlan, Offset, OffsetAndLimit, Project, Sample, SampleMethod, Sort}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -36,11 +36,13 @@ import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, C
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownJoin, SupportsPushDownVariantExtractions, V1Scan, VariantExtraction}
 import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, VariantInRelation}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.connector.VariantExtractionImpl
 import org.apache.spark.sql.sources
 import org.apache.spark.sql.types.{DataType, DecimalType, IntegerType, StructField, StructType}
 import org.apache.spark.sql.util.SchemaUtils._
 import org.apache.spark.util.ArrayImplicits._
+import org.apache.spark.util.Utils
 
 object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
   import DataSourceV2Implicits._
@@ -52,6 +54,9 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
       pushDownFilters,
       pushDownJoin,
       pushDownAggregates,
+      // Apply the fallback after the source has tried the lower aggregation, but before its
+      // required columns are finalized by pruneColumns.
+      collapseGroupedSumOfCount,
       pushDownVariants,
       pushDownLimitAndOffset,
       buildScanWithPushedAggregate,
@@ -64,6 +69,15 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
     }
   }
 
+  private def collapseGroupedSumOfCount(plan: LogicalPlan): LogicalPlan = {
+    val excludedRules = SQLConf.get.optimizerExcludedRules.toSeq.flatMap(Utils.stringToSeq)
+    if (excludedRules.contains(CollapseGroupedSumOfCount.ruleName)) {
+      plan
+    } else {
+      CollapseGroupedSumOfCount(plan)
+    }
+  }
+
   private def createScanBuilder(plan: LogicalPlan) = plan.transform {
     case r: DataSourceV2Relation =>
       ScanBuilderHolder(r.output, r, r.table.asReadable.newScanBuilder(r.options))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala
index 1a0beedfcad1..706657712763 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala
@@ -21,6 +21,8 @@ import java.sql.{Date, Timestamp}
 
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, Row}
+import org.apache.spark.sql.catalyst.optimizer.CollapseGroupedSumOfCount
+import org.apache.spark.sql.catalyst.plans.logical.Aggregate
 import org.apache.spark.sql.execution.datasources.orc.OrcTest
 import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
@@ -262,6 +264,67 @@ trait FileSourceAggregatePushDownSuite
     }
   }
 
+  test("SPARK-57043: preserve pushed count before collapsing its grouped rollup") {
+    withTempPath { dir =>
+      Seq((10, 1, 2), (2, 1, 2), (3, 2, 1), (4, 2, 1), (5, 2, 2))
+        .toDF("value", "p1", "p2")
+        .write
+        .partitionBy("p1", "p2")
+        .format(format)
+        .save(dir.getCanonicalPath)
+
+      withTempView("tmp") {
+        spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp")
+        val query =
+          """
+            |SELECT p1, SUM(cnt)
+            |FROM (SELECT p1, p2, COUNT(*) AS cnt FROM tmp GROUP BY p1, p2)
+            |GROUP BY p1
+            |""".stripMargin
+        var expected = Array.empty[Row]
+        withSQLConf(aggPushDownEnabledKey -> "false") {
+          expected = sql(query).collect()
+        }
+        withSQLConf(aggPushDownEnabledKey -> "true") {
+          val df = sql(query)
+          checkPushedInfo(
+            df,
+            "PushedAggregation: [COUNT(*)], PushedFilters: [], PushedGroupBy: [p1, p2]")
+          checkAnswer(df, expected)
+        }
+      }
+    }
+  }
+
+  test("SPARK-57043: collapse grouped count rollup after rejected scan push down") {
+    val data = Seq((1, 1), (1, 1), (1, 2), (2, 1), (2, 2))
+    withDataSourceTable(data, "t") {
+      withSQLConf(aggPushDownEnabledKey -> "true") {
+        val df = sql(
+          """
+            |SELECT _1, SUM(cnt)
+            |FROM (SELECT _1, _2, COUNT(*) AS cnt FROM t GROUP BY _1, _2)
+            |GROUP BY _1
+            |""".stripMargin)
+        checkPushedInfo(df, "PushedAggregation: []")
+        assert(df.queryExecution.optimizedPlan.collect { case _: Aggregate => true }.size == 1)
+        checkAnswer(df, Seq(Row(1, 3L), Row(2, 2L)))
+      }
+      withSQLConf(
+        aggPushDownEnabledKey -> "true",
+        SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> CollapseGroupedSumOfCount.ruleName) {
+        val df = sql(
+          """
+            |SELECT _1, SUM(cnt)
+            |FROM (SELECT _1, _2, COUNT(*) AS cnt FROM t GROUP BY _1, _2)
+            |GROUP BY _1
+            |""".stripMargin)
+        assert(df.queryExecution.optimizedPlan.collect { case _: Aggregate => true }.size == 2)
+        checkAnswer(df, Seq(Row(1, 3L), Row(2, 2L)))
+      }
+    }
+  }
+
   test("push down only if all the aggregates can be pushed down") {
     val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
       (9, "mno", 7), (2, null, 7))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index 5213c0c5f4e2..7f4b83ee342e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -636,6 +636,22 @@ abstract class SchemaPruningSuite
     checkAnswer(query4, Row(2, null) :: Row(2, 4) :: Nil)
   }
 
+  testSchemaPruning("SPARK-57043: prune nested grouping field after collapsing count rollup") {
+    val query = sql(
+      """
+        |SELECT id, SUM(cnt)
+        |FROM (
+        |  SELECT id, name.last, COUNT(*) AS cnt
+        |  FROM contacts
+        |  WHERE name.first IS NOT NULL
+        |  GROUP BY id, name.last
+        |)
+        |GROUP BY id
+        |""".stripMargin)
+    checkScan(query, "struct<id:int,name:struct<first:string>>")
+    checkAnswer(query.orderBy("id"), Row(0, 1L) :: Row(1, 1L) :: Row(2, 1L) :: Row(3, 1L) :: Nil)
+  }
+
   testSchemaPruning("select nested field in window function") {
     val windowSql =
       """