This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3569e768e657 [SPARK-50789][CONNECT] The inputs for typed aggregations should be analyzed
3569e768e657 is described below

commit 3569e768e657d4e28ee7520808ec910cdff2b099
Author: Takuya Ueshin <[email protected]>
AuthorDate: Mon Jan 13 11:17:05 2025 -0800

    [SPARK-50789][CONNECT] The inputs for typed aggregations should be analyzed
    
    ### What changes were proposed in this pull request?
    
    Fixes `SparkConnectPlanner` to analyze the inputs for typed aggregations.
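    
    A condensed sketch of the change (taken from the `transformAggregate` hunk in the diff below): when an aggregate contains a typed aggregate expression, the planner now analyzes the input plan before resolving the expressions against it.
    
    ```scala
    // Condensed from the transformAggregate change in the diff below:
    val logicalPlan =
      if (rel.getAggregateExpressionsList.asScala.toSeq.exists(
          _.getExprTypeCase == proto.Expression.ExprTypeCase.TYPED_AGGREGATE_EXPRESSION)) {
        // Analyze first so unresolved expressions (e.g. `*`) are resolved
        // before the typed aggregation is bound to the plan's output.
        session.sessionState.executePlan(input).analyzed
      } else {
        input
      }
    ```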
    
    ### Why are the changes needed?
    
    The inputs for typed aggregations should be analyzed before the typed expressions are resolved against them; otherwise inputs that still contain unresolved expressions (such as the `Star` from `select("*")`) make the planner fail.
    
    For example:
    
    ```scala
    val ds = Seq("abc", "xyz", "hello").toDS().select("*").as[String]
    ds.groupByKey(_.length).reduceGroups(_ + _).show()
    ```
    
    fails with:
    
    ```
    org.apache.spark.SparkException: [INTERNAL_ERROR] Invalid call to toAttribute on unresolved object SQLSTATE: XX000
      org.apache.spark.sql.catalyst.analysis.Star.toAttribute(unresolved.scala:439)
      org.apache.spark.sql.catalyst.plans.logical.Project.$anonfun$output$1(basicLogicalOperators.scala:74)
      scala.collection.immutable.List.map(List.scala:247)
      scala.collection.immutable.List.map(List.scala:79)
      org.apache.spark.sql.catalyst.plans.logical.Project.output(basicLogicalOperators.scala:74)
      org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformExpressionWithTypedReduceExpression(SparkConnectPlanner.scala:2340)
      org.apache.spark.sql.connect.planner.SparkConnectPlanner.$anonfun$transformKeyValueGroupedAggregate$1(SparkConnectPlanner.scala:2244)
      scala.collection.immutable.List.map(List.scala:247)
      scala.collection.immutable.List.map(List.scala:79)
      org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformKeyValueGroupedAggregate(SparkConnectPlanner.scala:2244)
      org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformAggregate(SparkConnectPlanner.scala:2232)
    ...
    ```
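    
    The root cause: `select("*")` leaves an unresolved `Star` in the `Project`'s project list, and `Project.output` maps `toAttribute` over that list, which is only valid after analysis. A minimal illustration against Catalyst internals (a sketch, not part of this change; `LocalRelation()` is just a stand-in child):
    
    ```scala
    import org.apache.spark.sql.catalyst.analysis.UnresolvedStar
    import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
    
    // A Project as produced by select("*"), before analysis:
    val unresolved = Project(Seq(UnresolvedStar(None)), LocalRelation())
    // Star.toAttribute throws the INTERNAL_ERROR above; after analysis the
    // Star would be expanded into the child's concrete attributes.
    unresolved.output
    ```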
    
    ### Does this PR introduce _any_ user-facing change?
    
    The failure above will no longer occur.
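    
    For example, the repro above now succeeds (a usage sketch; the expected groups `(3, "abcxyz")` and `(5, "hello")` are taken from the added test):
    
    ```scala
    val ds = Seq("abc", "xyz", "hello").toDS().select("*").as[String]
    // Groups strings by length and concatenates within each group;
    // previously this threw INTERNAL_ERROR, now it shows the reduced groups.
    ds.groupByKey(_.length).reduceGroups(_ + _).show()
    ```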
    
    ### How was this patch tested?
    
    Added related tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #49449 from ueshin/issues/SPARK-50789/typed_agg.
    
    Authored-by: Takuya Ueshin <[email protected]>
    Signed-off-by: Takuya Ueshin <[email protected]>
---
 .../sql/KeyValueGroupedDatasetE2ETestSuite.scala   |  8 ++++
 .../sql/UserDefinedFunctionE2ETestSuite.scala      | 22 ++++++++++-
 .../sql/connect/planner/SparkConnectPlanner.scala  | 43 +++++++++++++++-------
 3 files changed, 59 insertions(+), 14 deletions(-)

diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
index 6fd664d90540..021b4fea26e2 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
@@ -460,6 +460,14 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with RemoteSparkSessi
       (5, "hello"))
   }
 
+  test("SPARK-50789: reduceGroups on unresolved plan") {
+    val ds = Seq("abc", "xyz", "hello").toDS().select("*").as[String]
+    checkDatasetUnorderly(
+      ds.groupByKey(_.length).reduceGroups(_ + _),
+      (3, "abcxyz"),
+      (5, "hello"))
+  }
+
   test("groupby") {
     val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 1, 1))
       .toDF("key", "seq", "value")
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
index 8415444c10aa..19275326d642 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
@@ -401,6 +401,13 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest with RemoteSparkSession
     assert(ds.select(aggCol).head() == 135) // 45 + 90
   }
 
+  test("SPARK-50789: UDAF custom Aggregator - toColumn on unresolved plan") {
+    val encoder = Encoders.product[UdafTestInput]
+    val aggCol = new CompleteUdafTestInputAggregator().toColumn
+    val ds = spark.range(10).withColumn("extra", col("id") * 2).select("*").as(encoder)
+    assert(ds.select(aggCol).head() == 135) // 45 + 90
+  }
+
   test("UDAF custom Aggregator - multiple extends - toColumn") {
     val encoder = Encoders.product[UdafTestInput]
     val aggCol = new CompleteGrandChildUdafTestInputAggregator().toColumn
@@ -408,11 +415,24 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest with RemoteSparkSession
     assert(ds.select(aggCol).head() == 540) // (45 + 90) * 4
   }
 
-  test("UDAF custom aggregator - with rows - toColumn") {
+  test("SPARK-50789: UDAF custom Aggregator - multiple extends - toColumn on 
unresolved plan") {
+    val encoder = Encoders.product[UdafTestInput]
+    val aggCol = new CompleteGrandChildUdafTestInputAggregator().toColumn
+    val ds = spark.range(10).withColumn("extra", col("id") * 2).select("*").as(encoder)
+    assert(ds.select(aggCol).head() == 540) // (45 + 90) * 4
+  }
+
+  test("UDAF custom Aggregator - with rows - toColumn") {
     val ds = spark.range(10).withColumn("extra", col("id") * 2)
     assert(ds.select(RowAggregator.toColumn).head() == 405)
     assert(ds.agg(RowAggregator.toColumn).head().getLong(0) == 405)
   }
+
+  test("SPARK-50789: UDAF custom Aggregator - with rows - toColumn on 
unresolved plan") {
+    val ds = spark.range(10).withColumn("extra", col("id") * 2).select("*")
+    assert(ds.select(RowAggregator.toColumn).head() == 405)
+    assert(ds.agg(RowAggregator.toColumn).head().getLong(0) == 405)
+  }
 }
 
 case class UdafTestInput(id: Long, extra: Long)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index c0b4384af8b6..6ab69aea12e5 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -845,9 +845,10 @@ class SparkConnectPlanner(
       kEncoder: ExpressionEncoder[_],
       vEncoder: ExpressionEncoder[_],
       analyzed: LogicalPlan,
-      dataAttributes: Seq[Attribute],
+      analyzedData: LogicalPlan,
       groupingAttributes: Seq[Attribute],
       sortOrder: Seq[SortOrder]) {
+    val dataAttributes: Seq[Attribute] = analyzedData.output
     val valueDeserializer: Expression =
       UnresolvedDeserializer(vEncoder.deserializer, dataAttributes)
   }
@@ -900,7 +901,7 @@ class SparkConnectPlanner(
         dummyFunc.outEnc,
         dummyFunc.inEnc,
         qe.analyzed,
-        analyzed.output,
+        analyzed,
         aliasedGroupings,
         sortOrder)
     }
@@ -924,7 +925,7 @@ class SparkConnectPlanner(
         kEnc,
         vEnc,
         withGroupingKeyAnalyzed,
-        analyzed.output,
+        analyzed,
         withGroupingKey.newColumns,
         sortOrder)
     }
@@ -1489,11 +1490,19 @@ class SparkConnectPlanner(
       logical.OneRowRelation()
     }
 
+    val logicalPlan =
+      if (rel.getExpressionsList.asScala.toSeq.exists(
+          _.getExprTypeCase == proto.Expression.ExprTypeCase.TYPED_AGGREGATE_EXPRESSION)) {
+        session.sessionState.executePlan(baseRel).analyzed
+      } else {
+        baseRel
+      }
+
     val projection = rel.getExpressionsList.asScala.toSeq
-      .map(transformExpression(_, Some(baseRel)))
+      .map(transformExpression(_, Some(logicalPlan)))
       .map(toNamedExpression)
 
-    logical.Project(projectList = projection, child = baseRel)
+    logical.Project(projectList = projection, child = logicalPlan)
   }
 
   /**
@@ -2241,7 +2250,7 @@ class SparkConnectPlanner(
 
     val keyColumn = TypedAggUtils.aggKeyColumn(ds.kEncoder, ds.groupingAttributes)
     val namedColumns = rel.getAggregateExpressionsList.asScala.toSeq
-      .map(expr => transformExpressionWithTypedReduceExpression(expr, input))
+      .map(expr => transformExpressionWithTypedReduceExpression(expr, ds.analyzedData))
       .map(toNamedExpression)
     logical.Aggregate(ds.groupingAttributes, keyColumn +: namedColumns, ds.analyzed)
   }
@@ -2252,9 +2261,17 @@ class SparkConnectPlanner(
     }
     val input = transformRelation(rel.getInput)
 
+    val logicalPlan =
+      if (rel.getAggregateExpressionsList.asScala.toSeq.exists(
+          _.getExprTypeCase == proto.Expression.ExprTypeCase.TYPED_AGGREGATE_EXPRESSION)) {
+        session.sessionState.executePlan(input).analyzed
+      } else {
+        input
+      }
+
     val groupingExprs = rel.getGroupingExpressionsList.asScala.toSeq.map(transformExpression)
     val aggExprs = rel.getAggregateExpressionsList.asScala.toSeq
-      .map(expr => transformExpressionWithTypedReduceExpression(expr, input))
+      .map(expr => transformExpressionWithTypedReduceExpression(expr, logicalPlan))
     val aliasedAgg = (groupingExprs ++ aggExprs).map(toNamedExpression)
 
     rel.getGroupType match {
@@ -2262,19 +2279,19 @@ class SparkConnectPlanner(
         logical.Aggregate(
           groupingExpressions = groupingExprs,
           aggregateExpressions = aliasedAgg,
-          child = input)
+          child = logicalPlan)
 
       case proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP =>
         logical.Aggregate(
           groupingExpressions = Seq(Rollup(groupingExprs.map(Seq(_)))),
           aggregateExpressions = aliasedAgg,
-          child = input)
+          child = logicalPlan)
 
       case proto.Aggregate.GroupType.GROUP_TYPE_CUBE =>
         logical.Aggregate(
           groupingExpressions = Seq(Cube(groupingExprs.map(Seq(_)))),
           aggregateExpressions = aliasedAgg,
-          child = input)
+          child = logicalPlan)
 
       case proto.Aggregate.GroupType.GROUP_TYPE_PIVOT =>
         if (!rel.hasPivot) {
@@ -2286,7 +2303,7 @@ class SparkConnectPlanner(
           rel.getPivot.getValuesList.asScala.toSeq.map(transformLiteral)
         } else {
           RelationalGroupedDataset
-            .collectPivotValues(Dataset.ofRows(session, input), Column(pivotExpr))
+            .collectPivotValues(Dataset.ofRows(session, logicalPlan), Column(pivotExpr))
             .map(expressions.Literal.apply)
         }
         logical.Pivot(
@@ -2294,7 +2311,7 @@ class SparkConnectPlanner(
           pivotColumn = pivotExpr,
           pivotValues = valueExprs,
           aggregates = aggExprs,
-          child = input)
+          child = logicalPlan)
 
       case proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS =>
         val groupingSetsExprs = rel.getGroupingSetsList.asScala.toSeq.map { getGroupingSets =>
@@ -2306,7 +2323,7 @@ class SparkConnectPlanner(
               groupingSets = groupingSetsExprs,
               userGivenGroupByExprs = groupingExprs)),
           aggregateExpressions = aliasedAgg,
-          child = input)
+          child = logicalPlan)
 
       case other => throw InvalidPlanInput(s"Unknown Group Type $other")
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
