This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new ba05a6bcd972 [SPARK-49261][SQL] Don't replace literals in aggregate
expressions with group-by expressions
ba05a6bcd972 is described below
commit ba05a6bcd972ed4d5d1ee7a31f1c770ed7bfaed7
Author: Bruce Robbins <[email protected]>
AuthorDate: Thu Sep 12 08:11:03 2024 -0700
[SPARK-49261][SQL] Don't replace literals in aggregate expressions with
group-by expressions
### What changes were proposed in this pull request?
Before this PR, `RewriteDistinctAggregates` could potentially replace
literals in the aggregate expressions with output attributes from the `Expand`
operator. This can occur when a group-by expression is a literal that happens
by chance to match a literal used in an aggregate expression. E.g.:
```
create or replace temp view v1(a, b, c) as values
(1, 1.001d, 2), (2, 3.001d, 4), (2, 3.001, 4);
cache table v1;
select
round(sum(b), 6) as sum1,
count(distinct a) as count1,
count(distinct c) as count2
from (
select
6 as gb,
*
from v1
)
group by a, gb;
```
In the optimized plan, you can see that the literal 6 in the `round`
function invocation has been replaced with an output attribute (6#163) of the
`Expand` operator:
```
== Optimized Logical Plan ==
'Aggregate [a#123, 6#163],
[round(first(sum(__auto_generated_subquery_name.b)#167, true) FILTER (WHERE
(gid#162 = 0)), 6#163) AS sum1#114, count(__auto_generated_subquery_name.a#164)
FILTER (WHERE (gid#162 = 1)) AS count1#115L,
count(__auto_generated_subquery_name.c#165) FILTER (WHERE (gid#162 = 2)) AS
count2#116L]
+- Aggregate [a#123, 6#163, __auto_generated_subquery_name.a#164,
__auto_generated_subquery_name.c#165, gid#162], [a#123, 6#163,
__auto_generated_subquery_name.a#164, __auto_generated_subquery_name.c#165,
gid#162, sum(__auto_generated_subquery_name.b#166) AS
sum(__auto_generated_subquery_name.b)#167]
+- Expand [[a#123, 6, null, null, 0, b#124], [a#123, 6, a#123, null, 1,
null], [a#123, 6, null, c#125, 2, null]], [a#123, 6#163,
__auto_generated_subquery_name.a#164, __auto_generated_subquery_name.c#165,
gid#162, __auto_generated_subquery_name.b#166]
+- InMemoryRelation [a#123, b#124, c#125], StorageLevel(disk, memory,
deserialized, 1 replicas)
+- LocalTableScan [a#6, b#7, c#8]
```
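This happens because the literal 6 is also a group-by expression
(aliased as `gb` in the query, and exposed as 6#163 in the `Expand` operator's
output attributes). The rule replaces every aggregate sub-expression that is
semantically equal to a group-by expression with the corresponding `Expand`
attribute, so the unrelated literal 6 in `round` matched and was replaced too.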
After this PR, foldable group-by expressions are excluded from this
substitution, so foldable expressions in the aggregate expressions are kept
as-is.
### Why are the changes needed?
Some expressions require a foldable argument. In the above example, the
`round` function requires a foldable expression as the scale argument. Because
the scale argument is replaced with an attribute, which is not foldable,
`RoundBase#checkInputDataTypes` returns an error, which leaves the `Aggregate`
operator unresolved:
```
[INTERNAL_ERROR] Invalid call to dataType on unresolved object SQLSTATE:
XX000
org.apache.spark.sql.catalyst.analysis.UnresolvedException:
[INTERNAL_ERROR] Invalid call to dataType on unresolved object SQLSTATE: XX000
at
org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute.dataType(unresolved.scala:255)
at
org.apache.spark.sql.catalyst.types.DataTypeUtils$.$anonfun$fromAttributes$1(DataTypeUtils.scala:241)
at scala.collection.immutable.List.map(List.scala:247)
at scala.collection.immutable.List.map(List.scala:79)
at
org.apache.spark.sql.catalyst.types.DataTypeUtils$.fromAttributes(DataTypeUtils.scala:241)
at
org.apache.spark.sql.catalyst.plans.QueryPlan.schema$lzycompute(QueryPlan.scala:428)
at
org.apache.spark.sql.catalyst.plans.QueryPlan.schema(QueryPlan.scala:428)
at
org.apache.spark.sql.execution.SparkPlan.executeCollectPublic(SparkPlan.scala:474)
...
```
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New tests in `RewriteDistinctAggregatesSuite` and `DataFrameAggregateSuite`.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47876 from bersprockets/group_by_lit_issue.
Authored-by: Bruce Robbins <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
(cherry picked from commit 1a0791d006e25898b67cc17e1420f053a39091b9)
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../optimizer/RewriteDistinctAggregates.scala | 3 ++-
.../optimizer/RewriteDistinctAggregatesSuite.scala | 18 +++++++++++++++++-
.../apache/spark/sql/DataFrameAggregateSuite.scala | 21 +++++++++++++++++++++
3 files changed, 40 insertions(+), 2 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index 801bd2693af4..5aef82b64ed3 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -400,13 +400,14 @@ object RewriteDistinctAggregates extends
Rule[LogicalPlan] {
(distinctAggOperatorMap.flatMap(_._2) ++
regularAggOperatorMap.map(e => (e._1, e._3))).toMap
+ val groupByMapNonFoldable = groupByMap.filter(!_._1.foldable)
val patchedAggExpressions = a.aggregateExpressions.map { e =>
e.transformDown {
case e: Expression =>
// The same GROUP BY clauses can have different forms (different
names for instance) in
// the groupBy and aggregate expressions of an aggregate. This
makes a map lookup
// tricky. So we do a linear search for a semantically equal group
by expression.
- groupByMap
+ groupByMapNonFoldable
.find(ge => e.semanticEquals(ge._1))
.map(_._2)
.getOrElse(transformations.getOrElse(e, e))
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
index ac136dfb898e..4d31999ded65 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions.{Literal, Round}
import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand,
LocalRelation, LogicalPlan}
@@ -109,4 +109,20 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
case _ => fail(s"Plan is not rewritten:\n$rewrite")
}
}
+
+ test("SPARK-49261: Literals in grouping expressions shouldn't result in
unresolved aggregation") {
+ val relation = testRelation2
+ .select(Literal(6).as("gb"), $"a", $"b", $"c", $"d")
+ val input = relation
+ .groupBy($"a", $"gb")(
+ countDistinct($"b").as("agg1"),
+ countDistinct($"d").as("agg2"),
+ Round(sum($"c").as("sum1"), 6)).analyze
+ val rewriteFold = FoldablePropagation(input)
+ // without the fix, the below produces an unresolved plan
+ val rewrite = RewriteDistinctAggregates(rewriteFold)
+ if (!rewrite.resolved) {
+ fail(s"Plan is not as expected:\n$rewrite")
+ }
+ }
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 0967a562980b..9ce542e45b74 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -1673,6 +1673,27 @@ class DataFrameAggregateSuite extends QueryTest
})
}
}
+
+ test("SPARK-49261: Literals in grouping expressions shouldn't result in
unresolved aggregation") {
+ val data = Seq((1, 1.001d, 2), (2, 3.001d, 4), (2, 3.001, 4)).toDF("a",
"b", "c")
+ withTempView("v1") {
+ data.createOrReplaceTempView("v1")
+ val df =
+ sql("""SELECT
+ | ROUND(SUM(b), 6) AS sum1,
+ | COUNT(DISTINCT a) AS count1,
+ | COUNT(DISTINCT c) AS count2
+ |FROM (
+ | SELECT
+ | 6 AS gb,
+ | *
+ | FROM v1
+ |)
+ |GROUP BY a, gb
+ |""".stripMargin)
+ checkAnswer(df, Row(1.001d, 1, 1) :: Row(6.002d, 1, 1) :: Nil)
+ }
+ }
}
case class B(c: Option[Double])
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]