This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1a0791d006e2 [SPARK-49261][SQL] Don't replace literals in aggregate expressions with group-by expressions
1a0791d006e2 is described below

commit 1a0791d006e25898b67cc17e1420f053a39091b9
Author: Bruce Robbins <[email protected]>
AuthorDate: Thu Sep 12 08:11:03 2024 -0700

    [SPARK-49261][SQL] Don't replace literals in aggregate expressions with group-by expressions
    
    ### What changes were proposed in this pull request?
    
    Before this PR, `RewriteDistinctAggregates` could replace literals in the aggregate expressions with output attributes from the `Expand` operator. This occurs when a group-by expression is a literal that happens to match a literal used in an aggregate expression. For example:
    
    ```
    create or replace temp view v1(a, b, c) as values
    (1, 1.001d, 2), (2, 3.001d, 4), (2, 3.001, 4);
    
    cache table v1;
    
    select
      round(sum(b), 6) as sum1,
      count(distinct a) as count1,
      count(distinct c) as count2
    from (
      select
        6 as gb,
        *
      from v1
    )
    group by a, gb;
    ```
    In the optimized plan, you can see that the literal 6 in the `round` function invocation has been replaced with an output attribute (6#163) from the `Expand` operator:
    ```
    == Optimized Logical Plan ==
    'Aggregate [a#123, 6#163], [round(first(sum(__auto_generated_subquery_name.b)#167, true) FILTER (WHERE (gid#162 = 0)), 6#163) AS sum1#114, count(__auto_generated_subquery_name.a#164) FILTER (WHERE (gid#162 = 1)) AS count1#115L, count(__auto_generated_subquery_name.c#165) FILTER (WHERE (gid#162 = 2)) AS count2#116L]
    +- Aggregate [a#123, 6#163, __auto_generated_subquery_name.a#164, __auto_generated_subquery_name.c#165, gid#162], [a#123, 6#163, __auto_generated_subquery_name.a#164, __auto_generated_subquery_name.c#165, gid#162, sum(__auto_generated_subquery_name.b#166) AS sum(__auto_generated_subquery_name.b)#167]
       +- Expand [[a#123, 6, null, null, 0, b#124], [a#123, 6, a#123, null, 1, null], [a#123, 6, null, c#125, 2, null]], [a#123, 6#163, __auto_generated_subquery_name.a#164, __auto_generated_subquery_name.c#165, gid#162, __auto_generated_subquery_name.b#166]
          +- InMemoryRelation [a#123, b#124, c#125], StorageLevel(disk, memory, deserialized, 1 replicas)
                +- LocalTableScan [a#6, b#7, c#8]
    ```
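    This happens because the literal 6 is also a group-by expression (named gb in the query and renamed 6#163 in the `Expand` operator's output attributes). When patching the aggregate expressions, `RewriteDistinctAggregates` replaces any sub-expression that is semantically equal to a group-by expression with the corresponding `Expand` output attribute, so the unrelated literal 6 passed to `round` is replaced as well.
    
    After this PR, foldable group-by expressions are excluded from this substitution, so foldable expressions in the aggregate expressions (such as the scale argument of `round`) are kept as-is.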
    
    ### Why are the changes needed?
    
    Some expressions require a foldable argument. In the above example, the `round` function requires a foldable expression as its scale argument. Because the scale argument is replaced with an attribute, `RoundBase#checkInputDataTypes` returns an error, which leaves the `Aggregate` operator unresolved and causes the query to fail:
    ```
    [INTERNAL_ERROR] Invalid call to dataType on unresolved object SQLSTATE: XX000
    org.apache.spark.sql.catalyst.analysis.UnresolvedException: [INTERNAL_ERROR] Invalid call to dataType on unresolved object SQLSTATE: XX000
            at org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute.dataType(unresolved.scala:255)
            at org.apache.spark.sql.catalyst.types.DataTypeUtils$.$anonfun$fromAttributes$1(DataTypeUtils.scala:241)
            at scala.collection.immutable.List.map(List.scala:247)
            at scala.collection.immutable.List.map(List.scala:79)
            at org.apache.spark.sql.catalyst.types.DataTypeUtils$.fromAttributes(DataTypeUtils.scala:241)
            at org.apache.spark.sql.catalyst.plans.QueryPlan.schema$lzycompute(QueryPlan.scala:428)
            at org.apache.spark.sql.catalyst.plans.QueryPlan.schema(QueryPlan.scala:428)
            at org.apache.spark.sql.execution.SparkPlan.executeCollectPublic(SparkPlan.scala:474)
            ...
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New tests in `RewriteDistinctAggregatesSuite` and `DataFrameAggregateSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #47876 from bersprockets/group_by_lit_issue.
    
    Authored-by: Bruce Robbins <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 .../optimizer/RewriteDistinctAggregates.scala       |  3 ++-
 .../optimizer/RewriteDistinctAggregatesSuite.scala  | 18 +++++++++++++++++-
 .../apache/spark/sql/DataFrameAggregateSuite.scala  | 21 +++++++++++++++++++++
 3 files changed, 40 insertions(+), 2 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index 801bd2693af4..5aef82b64ed3 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -400,13 +400,14 @@ object RewriteDistinctAggregates extends 
Rule[LogicalPlan] {
         (distinctAggOperatorMap.flatMap(_._2) ++
           regularAggOperatorMap.map(e => (e._1, e._3))).toMap
 
+      val groupByMapNonFoldable = groupByMap.filter(!_._1.foldable)
       val patchedAggExpressions = a.aggregateExpressions.map { e =>
         e.transformDown {
           case e: Expression =>
             // The same GROUP BY clauses can have different forms (different 
names for instance) in
             // the groupBy and aggregate expressions of an aggregate. This 
makes a map lookup
             // tricky. So we do a linear search for a semantically equal group 
by expression.
-            groupByMap
+            groupByMapNonFoldable
               .find(ge => e.semanticEquals(ge._1))
               .map(_._2)
               .getOrElse(transformations.getOrElse(e, e))
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
index ac136dfb898e..4d31999ded65 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions.{Literal, Round}
 import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, 
LocalRelation, LogicalPlan}
@@ -109,4 +109,20 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
       case _ => fail(s"Plan is not rewritten:\n$rewrite")
     }
   }
+
+  test("SPARK-49261: Literals in grouping expressions shouldn't result in 
unresolved aggregation") {
+    val relation = testRelation2
+      .select(Literal(6).as("gb"), $"a", $"b", $"c", $"d")
+    val input = relation
+      .groupBy($"a", $"gb")(
+        countDistinct($"b").as("agg1"),
+        countDistinct($"d").as("agg2"),
+        Round(sum($"c").as("sum1"), 6)).analyze
+    val rewriteFold = FoldablePropagation(input)
+    // without the fix, the below produces an unresolved plan
+    val rewrite = RewriteDistinctAggregates(rewriteFold)
+    if (!rewrite.resolved) {
+      fail(s"Plan is not as expected:\n$rewrite")
+    }
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 0e9d34c3bd96..e80c3b23a7db 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -2490,6 +2490,27 @@ class DataFrameAggregateSuite extends QueryTest
       })
     }
   }
+
+  test("SPARK-49261: Literals in grouping expressions shouldn't result in 
unresolved aggregation") {
+    val data = Seq((1, 1.001d, 2), (2, 3.001d, 4), (2, 3.001, 4)).toDF("a", 
"b", "c")
+    withTempView("v1") {
+      data.createOrReplaceTempView("v1")
+      val df =
+        sql("""SELECT
+              |  ROUND(SUM(b), 6) AS sum1,
+              |  COUNT(DISTINCT a) AS count1,
+              |  COUNT(DISTINCT c) AS count2
+              |FROM (
+              |  SELECT
+              |    6 AS gb,
+              |    *
+              |  FROM v1
+              |)
+              |GROUP BY a, gb
+              |""".stripMargin)
+      checkAnswer(df, Row(1.001d, 1, 1) :: Row(6.002d, 1, 1) :: Nil)
+    }
+  }
 }
 
 case class B(c: Option[Double])


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to