This is an automated email from the ASF dual-hosted git repository.
ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 3569e768e657 [SPARK-50789][CONNECT] The inputs for typed aggregations should be analyzed
3569e768e657 is described below
commit 3569e768e657d4e28ee7520808ec910cdff2b099
Author: Takuya Ueshin <[email protected]>
AuthorDate: Mon Jan 13 11:17:05 2025 -0800
[SPARK-50789][CONNECT] The inputs for typed aggregations should be analyzed
### What changes were proposed in this pull request?
Fixes `SparkConnectPlanner` to analyze the inputs for typed aggregations.
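When a typed aggregate expression is present, the planner now analyzes the input plan before transforming the expressions, so that the plan's output attributes are resolved. A simplified sketch of the pattern from `transformAggregate` in the diff below (the projection path is handled analogously):
```scala
val input = transformRelation(rel.getInput)
// Typed aggregations resolve their deserializers against the child plan's
// output attributes, so the child must be analyzed first when such an
// expression occurs; otherwise the plan stays unresolved, as before.
val logicalPlan =
  if (rel.getAggregateExpressionsList.asScala.toSeq.exists(
      _.getExprTypeCase == proto.Expression.ExprTypeCase.TYPED_AGGREGATE_EXPRESSION)) {
    session.sessionState.executePlan(input).analyzed
  } else {
    input
  }
```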
### Why are the changes needed?
When transforming typed aggregation expressions, the planner reads the input plan's output attributes, which fails if the plan is still unresolved (e.g., it contains a `Star` from `select("*")`).
For example:
```scala
val ds = Seq("abc", "xyz", "hello").toDS().select("*").as[String]
ds.groupByKey(_.length).reduceGroups(_ + _).show()
```
fails with:
```
org.apache.spark.SparkException: [INTERNAL_ERROR] Invalid call to toAttribute on unresolved object SQLSTATE: XX000
org.apache.spark.sql.catalyst.analysis.Star.toAttribute(unresolved.scala:439)
org.apache.spark.sql.catalyst.plans.logical.Project.$anonfun$output$1(basicLogicalOperators.scala:74)
scala.collection.immutable.List.map(List.scala:247)
scala.collection.immutable.List.map(List.scala:79)
org.apache.spark.sql.catalyst.plans.logical.Project.output(basicLogicalOperators.scala:74)
org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformExpressionWithTypedReduceExpression(SparkConnectPlanner.scala:2340)
org.apache.spark.sql.connect.planner.SparkConnectPlanner.$anonfun$transformKeyValueGroupedAggregate$1(SparkConnectPlanner.scala:2244)
scala.collection.immutable.List.map(List.scala:247)
scala.collection.immutable.List.map(List.scala:79)
org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformKeyValueGroupedAggregate(SparkConnectPlanner.scala:2244)
org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformAggregate(SparkConnectPlanner.scala:2232)
...
```
### Does this PR introduce _any_ user-facing change?
Yes. The failure described above will no longer occur.
### How was this patch tested?
Added the related tests.
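For reference, a quick way to exercise the fix is the snippet from the description, which mirrors the new test in `KeyValueGroupedDatasetE2ETestSuite`:
```scala
val ds = Seq("abc", "xyz", "hello").toDS().select("*").as[String]
// Previously failed with INTERNAL_ERROR; per the new test, it now yields
// the groups (3, "abcxyz") and (5, "hello").
ds.groupByKey(_.length).reduceGroups(_ + _).show()
```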
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49449 from ueshin/issues/SPARK-50789/typed_agg.
Authored-by: Takuya Ueshin <[email protected]>
Signed-off-by: Takuya Ueshin <[email protected]>
---
.../sql/KeyValueGroupedDatasetE2ETestSuite.scala | 8 ++++
.../sql/UserDefinedFunctionE2ETestSuite.scala | 22 ++++++++++-
.../sql/connect/planner/SparkConnectPlanner.scala | 43 +++++++++++++++-------
3 files changed, 59 insertions(+), 14 deletions(-)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
index 6fd664d90540..021b4fea26e2 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
@@ -460,6 +460,14 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with RemoteSparkSessi
(5, "hello"))
}
+ test("SPARK-50789: reduceGroups on unresolved plan") {
+ val ds = Seq("abc", "xyz", "hello").toDS().select("*").as[String]
+ checkDatasetUnorderly(
+ ds.groupByKey(_.length).reduceGroups(_ + _),
+ (3, "abcxyz"),
+ (5, "hello"))
+ }
+
test("groupby") {
val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 1, 1))
.toDF("key", "seq", "value")
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
index 8415444c10aa..19275326d642 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
@@ -401,6 +401,13 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest with RemoteSparkSession
assert(ds.select(aggCol).head() == 135) // 45 + 90
}
+ test("SPARK-50789: UDAF custom Aggregator - toColumn on unresolved plan") {
+ val encoder = Encoders.product[UdafTestInput]
+ val aggCol = new CompleteUdafTestInputAggregator().toColumn
+ val ds = spark.range(10).withColumn("extra", col("id") * 2).select("*").as(encoder)
+ assert(ds.select(aggCol).head() == 135) // 45 + 90
+ }
+
test("UDAF custom Aggregator - multiple extends - toColumn") {
val encoder = Encoders.product[UdafTestInput]
val aggCol = new CompleteGrandChildUdafTestInputAggregator().toColumn
@@ -408,11 +415,24 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest with RemoteSparkSession
assert(ds.select(aggCol).head() == 540) // (45 + 90) * 4
}
- test("UDAF custom aggregator - with rows - toColumn") {
+ test("SPARK-50789: UDAF custom Aggregator - multiple extends - toColumn on
unresolved plan") {
+ val encoder = Encoders.product[UdafTestInput]
+ val aggCol = new CompleteGrandChildUdafTestInputAggregator().toColumn
+ val ds = spark.range(10).withColumn("extra", col("id") * 2).select("*").as(encoder)
+ assert(ds.select(aggCol).head() == 540) // (45 + 90) * 4
+ }
+
+ test("UDAF custom Aggregator - with rows - toColumn") {
val ds = spark.range(10).withColumn("extra", col("id") * 2)
assert(ds.select(RowAggregator.toColumn).head() == 405)
assert(ds.agg(RowAggregator.toColumn).head().getLong(0) == 405)
}
+
+ test("SPARK-50789: UDAF custom Aggregator - with rows - toColumn on
unresolved plan") {
+ val ds = spark.range(10).withColumn("extra", col("id") * 2).select("*")
+ assert(ds.select(RowAggregator.toColumn).head() == 405)
+ assert(ds.agg(RowAggregator.toColumn).head().getLong(0) == 405)
+ }
}
case class UdafTestInput(id: Long, extra: Long)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index c0b4384af8b6..6ab69aea12e5 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -845,9 +845,10 @@ class SparkConnectPlanner(
kEncoder: ExpressionEncoder[_],
vEncoder: ExpressionEncoder[_],
analyzed: LogicalPlan,
- dataAttributes: Seq[Attribute],
+ analyzedData: LogicalPlan,
groupingAttributes: Seq[Attribute],
sortOrder: Seq[SortOrder]) {
+ val dataAttributes: Seq[Attribute] = analyzedData.output
val valueDeserializer: Expression =
UnresolvedDeserializer(vEncoder.deserializer, dataAttributes)
}
@@ -900,7 +901,7 @@ class SparkConnectPlanner(
dummyFunc.outEnc,
dummyFunc.inEnc,
qe.analyzed,
- analyzed.output,
+ analyzed,
aliasedGroupings,
sortOrder)
}
@@ -924,7 +925,7 @@ class SparkConnectPlanner(
kEnc,
vEnc,
withGroupingKeyAnalyzed,
- analyzed.output,
+ analyzed,
withGroupingKey.newColumns,
sortOrder)
}
@@ -1489,11 +1490,19 @@ class SparkConnectPlanner(
logical.OneRowRelation()
}
+ val logicalPlan =
+ if (rel.getExpressionsList.asScala.toSeq.exists(
+ _.getExprTypeCase == proto.Expression.ExprTypeCase.TYPED_AGGREGATE_EXPRESSION)) {
+ session.sessionState.executePlan(baseRel).analyzed
+ } else {
+ baseRel
+ }
+
val projection = rel.getExpressionsList.asScala.toSeq
- .map(transformExpression(_, Some(baseRel)))
+ .map(transformExpression(_, Some(logicalPlan)))
.map(toNamedExpression)
- logical.Project(projectList = projection, child = baseRel)
+ logical.Project(projectList = projection, child = logicalPlan)
}
/**
@@ -2241,7 +2250,7 @@ class SparkConnectPlanner(
val keyColumn = TypedAggUtils.aggKeyColumn(ds.kEncoder, ds.groupingAttributes)
val namedColumns = rel.getAggregateExpressionsList.asScala.toSeq
- .map(expr => transformExpressionWithTypedReduceExpression(expr, input))
+ .map(expr => transformExpressionWithTypedReduceExpression(expr, ds.analyzedData))
.map(toNamedExpression)
logical.Aggregate(ds.groupingAttributes, keyColumn +: namedColumns, ds.analyzed)
}
@@ -2252,9 +2261,17 @@ class SparkConnectPlanner(
}
val input = transformRelation(rel.getInput)
+ val logicalPlan =
+ if (rel.getAggregateExpressionsList.asScala.toSeq.exists(
+ _.getExprTypeCase == proto.Expression.ExprTypeCase.TYPED_AGGREGATE_EXPRESSION)) {
+ session.sessionState.executePlan(input).analyzed
+ } else {
+ input
+ }
+
val groupingExprs =
rel.getGroupingExpressionsList.asScala.toSeq.map(transformExpression)
val aggExprs = rel.getAggregateExpressionsList.asScala.toSeq
- .map(expr => transformExpressionWithTypedReduceExpression(expr, input))
+ .map(expr => transformExpressionWithTypedReduceExpression(expr, logicalPlan))
val aliasedAgg = (groupingExprs ++ aggExprs).map(toNamedExpression)
rel.getGroupType match {
@@ -2262,19 +2279,19 @@ class SparkConnectPlanner(
logical.Aggregate(
groupingExpressions = groupingExprs,
aggregateExpressions = aliasedAgg,
- child = input)
+ child = logicalPlan)
case proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP =>
logical.Aggregate(
groupingExpressions = Seq(Rollup(groupingExprs.map(Seq(_)))),
aggregateExpressions = aliasedAgg,
- child = input)
+ child = logicalPlan)
case proto.Aggregate.GroupType.GROUP_TYPE_CUBE =>
logical.Aggregate(
groupingExpressions = Seq(Cube(groupingExprs.map(Seq(_)))),
aggregateExpressions = aliasedAgg,
- child = input)
+ child = logicalPlan)
case proto.Aggregate.GroupType.GROUP_TYPE_PIVOT =>
if (!rel.hasPivot) {
@@ -2286,7 +2303,7 @@ class SparkConnectPlanner(
rel.getPivot.getValuesList.asScala.toSeq.map(transformLiteral)
} else {
RelationalGroupedDataset
- .collectPivotValues(Dataset.ofRows(session, input), Column(pivotExpr))
+ .collectPivotValues(Dataset.ofRows(session, logicalPlan), Column(pivotExpr))
.map(expressions.Literal.apply)
}
logical.Pivot(
@@ -2294,7 +2311,7 @@ class SparkConnectPlanner(
pivotColumn = pivotExpr,
pivotValues = valueExprs,
aggregates = aggExprs,
- child = input)
+ child = logicalPlan)
case proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS =>
val groupingSetsExprs = rel.getGroupingSetsList.asScala.toSeq.map { getGroupingSets =>
@@ -2306,7 +2323,7 @@ class SparkConnectPlanner(
groupingSets = groupingSetsExprs,
userGivenGroupByExprs = groupingExprs)),
aggregateExpressions = aliasedAgg,
- child = input)
+ child = logicalPlan)
case other => throw InvalidPlanInput(s"Unknown Group Type $other")
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]