This is an automated email from the ASF dual-hosted git repository.
ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 43a9b88991b2 [SPARK-50693][CONNECT] The inputs for TypedScalaUdf should be analyzed
43a9b88991b2 is described below
commit 43a9b88991b22757c1b5ae40d3fc7efcdb893d82
Author: Takuya Ueshin <[email protected]>
AuthorDate: Mon Dec 30 11:50:28 2024 -0800
[SPARK-50693][CONNECT] The inputs for TypedScalaUdf should be analyzed
### What changes were proposed in this pull request?
Fixes `SparkConnectPlanner` to analyze the inputs for `TypedScalaUdf`.
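In short: wherever the planner read `output` from a possibly unresolved child plan, it now runs the plan through the analyzer first. A condensed sketch of the pattern, with the surrounding planner code elided (names match the `transformTypedFilter` change in the diff below):
```scala
// Sketch of the fix pattern: resolve the child plan before reading its
// output attributes. `session`, `fun`, and `child` come from the planner.
val analyzed = session.sessionState.executePlan(child).analyzed
val udf = TypedScalaUdf(fun, Some(analyzed.output)) // was: Some(child.output)
TypedFilter(udf.function, analyzed)(udf.inEnc)      // was: (udf.function, child)
```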
### Why are the changes needed?
The input plans for `TypedScalaUdf` must be analyzed before their `output` attributes are accessed; otherwise an unresolved plan reaches the planner and fails.
For example:
```scala
val df = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDF("c1", "c2")
df.select("*").filter(r => r.getInt(1) > 5)
```
fails with:
```
org.apache.spark.SparkException: [INTERNAL_ERROR] Invalid call to toAttribute on unresolved object SQLSTATE: XX000
  at org.apache.spark.sql.catalyst.analysis.Star.toAttribute(unresolved.scala:438)
  at org.apache.spark.sql.catalyst.plans.logical.Project.$anonfun$output$1(basicLogicalOperators.scala:74)
  at scala.collection.immutable.List.map(List.scala:247)
  at scala.collection.immutable.List.map(List.scala:79)
  at org.apache.spark.sql.catalyst.plans.logical.Project.output(basicLogicalOperators.scala:74)
  at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformTypedFilter(SparkConnectPlanner.scala:1460)
  at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformFilter(SparkConnectPlanner.scala:1437)
  ...
```
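The root cause: `df.select("*")` yields a `Project` whose projection list still contains an unresolved `Star`, and `Project.output` maps `toAttribute` over that list before the analyzer has expanded the star. A minimal illustration against the Catalyst internals (a sketch, not part of this patch; the `LocalRelation` and its attributes are made up for the example):
```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedStar
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.types.{IntegerType, StringType}

// A Project that has not been analyzed yet: its projection list is still a Star.
val child = LocalRelation(
  AttributeReference("c1", StringType)(),
  AttributeReference("c2", IntegerType)())
val unanalyzed = Project(Seq(UnresolvedStar(None)), child)

// Throws the INTERNAL_ERROR above: Star.toAttribute is an invalid call
// until the analyzer replaces the star with concrete attribute references.
unanalyzed.output
```
Running the plan through `session.sessionState.executePlan(...).analyzed` first replaces the star with the resolved `c1` and `c2` attributes, so `output` becomes safe to call.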
### Does this PR introduce _any_ user-facing change?
Yes; the failure shown above will no longer occur.
### How was this patch tested?
Added related tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49327 from ueshin/issues/SPARK-50693/typed_scala_udf.
Authored-by: Takuya Ueshin <[email protected]>
Signed-off-by: Takuya Ueshin <[email protected]>
---
.../sql/KeyValueGroupedDatasetE2ETestSuite.scala | 68 ++++++++++++++++++++++
.../sql/UserDefinedFunctionE2ETestSuite.scala | 8 +++
.../sql/connect/planner/SparkConnectPlanner.scala | 23 +++++---
3 files changed, 90 insertions(+), 9 deletions(-)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
index 988774d5eec9..6fd664d90540 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
@@ -479,6 +479,25 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with RemoteSparkSessi
       "(c,1,1)")
   }
 
+  test("SPARK-50693: groupby on unresolved plan") {
+    val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 1, 1))
+      .toDF("key", "seq", "value")
+    val grouped = ds.select("*").groupBy($"key").as[String, (String, Int, Int)]
+    val aggregated = grouped
+      .flatMapSortedGroups($"seq", expr("length(key)"), $"value") { (g, iter) =>
+        Iterator(g, iter.mkString(", "))
+      }
+
+    checkDatasetUnorderly(
+      aggregated,
+      "a",
+      "(a,1,10), (a,2,20)",
+      "b",
+      "(b,1,2), (b,2,1)",
+      "c",
+      "(c,1,1)")
+  }
+
   test("groupby - keyAs, keys") {
     val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 1, 1))
       .toDF("key", "seq", "value")
@@ -597,6 +616,16 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with RemoteSparkSessi
       ("c", 1L))
   }
 
+  test("SPARK-50693: RowEncoder in udf on unresolved plan") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDF("c1", "c2")
+
+    checkDatasetUnorderly(
+      ds.select("*").groupByKey(k => k.getAs[String](0)).agg(sum("c2").as[Long]),
+      ("a", 30L),
+      ("b", 3L),
+      ("c", 1L))
+  }
+
   test("mapGroups with row encoder") {
     val df = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDF("c1", "c2")
@@ -611,6 +640,21 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with RemoteSparkSessi
       1)
   }
 
+  test("SPARK-50693: mapGroups with row encoder on unresolved plan") {
+    val df = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDF("c1", "c2")
+
+    checkDataset(
+      df.select("*")
+        .groupByKey(r => r.getAs[String]("c1"))
+        .mapGroups((_, it) =>
+          it.map(r => {
+            r.getAs[Int]("c2")
+          }).sum),
+      30,
+      3,
+      1)
+  }
+
   test("coGroup with row encoder") {
     val df1 = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDF("c1", "c2")
     val df2 = Seq(("x", 10), ("x", 20), ("y", 1), ("y", 2), ("a", 1)).toDF("c1", "c2")
@@ -632,6 +676,30 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with RemoteSparkSessi
       3)
   }
 
+  test("SPARK-50693: coGroup with row encoder on unresolved plan") {
+    val df1 = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDF("c1", "c2")
+    val df2 = Seq(("x", 10), ("x", 20), ("y", 1), ("y", 2), ("a", 1)).toDF("c1", "c2")
+
+    Seq((df1.select("*"), df2), (df1, df2.select("*")), (df1.select("*"), df2.select("*")))
+      .foreach { case (df1, df2) =>
+        val ds1: KeyValueGroupedDataset[String, Row] =
+          df1.groupByKey(r => r.getAs[String]("c1"))
+        val ds2: KeyValueGroupedDataset[String, Row] =
+          df2.groupByKey(r => r.getAs[String]("c1"))
+        checkDataset(
+          ds1.cogroup(ds2)((_, it, it2) => {
+            val sum1 = it.map(r => r.getAs[Int]("c2")).sum
+            val sum2 = it2.map(r => r.getAs[Int]("c2")).sum
+            Iterator(sum1 + sum2)
+          }),
+          31,
+          3,
+          1,
+          30,
+          3)
+      }
+  }
+
   test("serialize as null") {
     val kvgds = session.range(10).groupByKey(_ % 2)
     val bytes = SparkSerDeUtils.serialize(kvgds)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
index ca754c7b542f..8415444c10aa 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
@@ -301,6 +301,14 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest with RemoteSparkSession
     checkDataset(df.filter(r => r.getInt(1) > 5), Row("a", 10), Row("a", 20))
   }
 
+  test("SPARK-50693: Filter with row input encoder on unresolved plan") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val df = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDF("c1", "c2")
+
+    checkDataset(df.select("*").filter(r => r.getInt(1) > 5), Row("a", 10), Row("a", 20))
+  }
+
   test("mapPartitions with row input encoder") {
     val session: SparkSession = spark
     import session.implicits._
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index d6ade1ac9126..8bb5e54c36cc 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -875,18 +875,20 @@ class SparkConnectPlanner(
       logicalPlan: LogicalPlan,
       groupingExprs: java.util.List[proto.Expression],
       sortOrder: Seq[SortOrder]): UntypedKeyValueGroupedDataset = {
+    val analyzed = session.sessionState.executePlan(logicalPlan).analyzed
+
     assertPlan(groupingExprs.size() >= 1)
     val dummyFunc = TypedScalaUdf(groupingExprs.get(0), None)
     val groupExprs = groupingExprs.asScala.toSeq.drop(1).map(expr => transformExpression(expr))
 
     val (qe, aliasedGroupings) =
-      RelationalGroupedDataset.handleGroupingExpression(logicalPlan, session, groupExprs)
+      RelationalGroupedDataset.handleGroupingExpression(analyzed, session, groupExprs)
 
     UntypedKeyValueGroupedDataset(
       dummyFunc.outEnc,
       dummyFunc.inEnc,
       qe.analyzed,
-      logicalPlan.output,
+      analyzed.output,
       aliasedGroupings,
       sortOrder)
   }
@@ -895,20 +897,22 @@ class SparkConnectPlanner(
       logicalPlan: LogicalPlan,
       groupingExprs: java.util.List[proto.Expression],
       sortOrder: Seq[SortOrder]): UntypedKeyValueGroupedDataset = {
+    val analyzed = session.sessionState.executePlan(logicalPlan).analyzed
+
     assertPlan(groupingExprs.size() == 1)
-    val groupFunc = TypedScalaUdf(groupingExprs.get(0), Some(logicalPlan.output))
+    val groupFunc = TypedScalaUdf(groupingExprs.get(0), Some(analyzed.output))
     val vEnc = groupFunc.inEnc
     val kEnc = groupFunc.outEnc
 
-    val withGroupingKey = AppendColumns(groupFunc.function, vEnc, kEnc, logicalPlan)
+    val withGroupingKey = AppendColumns(groupFunc.function, vEnc, kEnc, analyzed)
     // The input logical plan of KeyValueGroupedDataset need to be executed and analyzed
-    val analyzed = session.sessionState.executePlan(withGroupingKey).analyzed
+    val withGroupingKeyAnalyzed = session.sessionState.executePlan(withGroupingKey).analyzed
 
     UntypedKeyValueGroupedDataset(
       kEnc,
       vEnc,
-      analyzed,
-      logicalPlan.output,
+      withGroupingKeyAnalyzed,
+      analyzed.output,
       withGroupingKey.newColumns,
       sortOrder)
   }
@@ -1457,8 +1461,9 @@ class SparkConnectPlanner(
   private def transformTypedFilter(
       fun: proto.CommonInlineUserDefinedFunction,
       child: LogicalPlan): TypedFilter = {
-    val udf = TypedScalaUdf(fun, Some(child.output))
-    TypedFilter(udf.function, child)(udf.inEnc)
+    val analyzed = session.sessionState.executePlan(child).analyzed
+    val udf = TypedScalaUdf(fun, Some(analyzed.output))
+    TypedFilter(udf.function, analyzed)(udf.inEnc)
   }
 
   private def transformProject(rel: proto.Project): LogicalPlan = {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]