This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 43a9b88991b2 [SPARK-50693][CONNECT] The inputs for TypedScalaUdf should be analyzed
43a9b88991b2 is described below

commit 43a9b88991b22757c1b5ae40d3fc7efcdb893d82
Author: Takuya Ueshin <[email protected]>
AuthorDate: Mon Dec 30 11:50:28 2024 -0800

    [SPARK-50693][CONNECT] The inputs for TypedScalaUdf should be analyzed
    
    ### What changes were proposed in this pull request?
    
    Fixes `SparkConnectPlanner` to analyze the inputs for `TypedScalaUdf`.
    
    ### Why are the changes needed?
    
    The inputs for `TypedScalaUdf` may still be unresolved when `SparkConnectPlanner` reads their `output`, so they should be analyzed first; e.g. `select("*")` leaves an unresolved `Star` in the plan, and calling `toAttribute` on it raises an internal error.
    
    For example:
    
    ```scala
    val df = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDF("c1", 
"c2")
    df.select("*").filter(r => r.getInt(1) > 5)
    ```
    
    fails with:
    
    ```
    org.apache.spark.SparkException: [INTERNAL_ERROR] Invalid call to toAttribute on unresolved object SQLSTATE: XX000
      at org.apache.spark.sql.catalyst.analysis.Star.toAttribute(unresolved.scala:438)
      at org.apache.spark.sql.catalyst.plans.logical.Project.$anonfun$output$1(basicLogicalOperators.scala:74)
      at scala.collection.immutable.List.map(List.scala:247)
      at scala.collection.immutable.List.map(List.scala:79)
      at org.apache.spark.sql.catalyst.plans.logical.Project.output(basicLogicalOperators.scala:74)
      at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformTypedFilter(SparkConnectPlanner.scala:1460)
      at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformFilter(SparkConnectPlanner.scala:1437)
    ...
    ```
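
    The root cause is that `transformTypedFilter` and the typed group-by helpers read `output` from a child plan that may still contain unresolved nodes, such as the `Star` produced by `select("*")`. A minimal sketch of the fix pattern, reusing the names from the `transformTypedFilter` change in the diff below:

    ```scala
    // Analyze the child plan first so that unresolved nodes such as Star
    // are resolved into concrete attributes before `output` is consulted.
    val analyzed = session.sessionState.executePlan(child).analyzed
    val udf = TypedScalaUdf(fun, Some(analyzed.output))
    TypedFilter(udf.function, analyzed)(udf.inEnc)
    ```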
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Queries like the example above no longer fail with the internal error.
    
    ### How was this patch tested?
    
    Added related tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #49327 from ueshin/issues/SPARK-50693/typed_scala_udf.
    
    Authored-by: Takuya Ueshin <[email protected]>
    Signed-off-by: Takuya Ueshin <[email protected]>
---
 .../sql/KeyValueGroupedDatasetE2ETestSuite.scala   | 68 ++++++++++++++++++++++
 .../sql/UserDefinedFunctionE2ETestSuite.scala      |  8 +++
 .../sql/connect/planner/SparkConnectPlanner.scala  | 23 +++++---
 3 files changed, 90 insertions(+), 9 deletions(-)

diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
index 988774d5eec9..6fd664d90540 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
@@ -479,6 +479,25 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with RemoteSparkSessi
       "(c,1,1)")
   }
 
+  test("SPARK-50693: groupby on unresolved plan") {
+    val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 
1, 1))
+      .toDF("key", "seq", "value")
+    val grouped = ds.select("*").groupBy($"key").as[String, (String, Int, Int)]
+    val aggregated = grouped
+      .flatMapSortedGroups($"seq", expr("length(key)"), $"value") { (g, iter) =>
+        Iterator(g, iter.mkString(", "))
+      }
+
+    checkDatasetUnorderly(
+      aggregated,
+      "a",
+      "(a,1,10), (a,2,20)",
+      "b",
+      "(b,1,2), (b,2,1)",
+      "c",
+      "(c,1,1)")
+  }
+
   test("groupby - keyAs, keys") {
     val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 
1, 1))
       .toDF("key", "seq", "value")
@@ -597,6 +616,16 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with RemoteSparkSessi
       ("c", 1L))
   }
 
+  test("SPARK-50693: RowEncoder in udf on unresolved plan") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 
1)).toDF("c1", "c2")
+
+    checkDatasetUnorderly(
+      ds.select("*").groupByKey(k => 
k.getAs[String](0)).agg(sum("c2").as[Long]),
+      ("a", 30L),
+      ("b", 3L),
+      ("c", 1L))
+  }
+
   test("mapGroups with row encoder") {
     val df = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 
1)).toDF("c1", "c2")
 
@@ -611,6 +640,21 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with RemoteSparkSessi
       1)
   }
 
+  test("SPARK-50693: mapGroups with row encoder on unresolved plan") {
+    val df = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 
1)).toDF("c1", "c2")
+
+    checkDataset(
+      df.select("*")
+        .groupByKey(r => r.getAs[String]("c1"))
+        .mapGroups((_, it) =>
+          it.map(r => {
+            r.getAs[Int]("c2")
+          }).sum),
+      30,
+      3,
+      1)
+  }
+
   test("coGroup with row encoder") {
     val df1 = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 
1)).toDF("c1", "c2")
     val df2 = Seq(("x", 10), ("x", 20), ("y", 1), ("y", 2), ("a", 
1)).toDF("c1", "c2")
@@ -632,6 +676,30 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with RemoteSparkSessi
       3)
   }
 
+  test("SPARK-50693: coGroup with row encoder on unresolved plan") {
+    val df1 = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 
1)).toDF("c1", "c2")
+    val df2 = Seq(("x", 10), ("x", 20), ("y", 1), ("y", 2), ("a", 
1)).toDF("c1", "c2")
+
+    Seq((df1.select("*"), df2), (df1, df2.select("*")), (df1.select("*"), 
df2.select("*")))
+      .foreach { case (df1, df2) =>
+        val ds1: KeyValueGroupedDataset[String, Row] =
+          df1.groupByKey(r => r.getAs[String]("c1"))
+        val ds2: KeyValueGroupedDataset[String, Row] =
+          df2.groupByKey(r => r.getAs[String]("c1"))
+        checkDataset(
+          ds1.cogroup(ds2)((_, it, it2) => {
+            val sum1 = it.map(r => r.getAs[Int]("c2")).sum
+            val sum2 = it2.map(r => r.getAs[Int]("c2")).sum
+            Iterator(sum1 + sum2)
+          }),
+          31,
+          3,
+          1,
+          30,
+          3)
+      }
+  }
+
   test("serialize as null") {
     val kvgds = session.range(10).groupByKey(_ % 2)
     val bytes = SparkSerDeUtils.serialize(kvgds)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
index ca754c7b542f..8415444c10aa 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
@@ -301,6 +301,14 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest with RemoteSparkSession
     checkDataset(df.filter(r => r.getInt(1) > 5), Row("a", 10), Row("a", 20))
   }
 
+  test("SPARK-50693: Filter with row input encoder on unresolved plan") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val df = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 
1)).toDF("c1", "c2")
+
+    checkDataset(df.select("*").filter(r => r.getInt(1) > 5), Row("a", 10), 
Row("a", 20))
+  }
+
   test("mapPartitions with row input encoder") {
     val session: SparkSession = spark
     import session.implicits._
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index d6ade1ac9126..8bb5e54c36cc 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -875,18 +875,20 @@ class SparkConnectPlanner(
         logicalPlan: LogicalPlan,
         groupingExprs: java.util.List[proto.Expression],
         sortOrder: Seq[SortOrder]): UntypedKeyValueGroupedDataset = {
+      val analyzed = session.sessionState.executePlan(logicalPlan).analyzed
+
       assertPlan(groupingExprs.size() >= 1)
       val dummyFunc = TypedScalaUdf(groupingExprs.get(0), None)
       val groupExprs = groupingExprs.asScala.toSeq.drop(1).map(expr => transformExpression(expr))
 
       val (qe, aliasedGroupings) =
-        RelationalGroupedDataset.handleGroupingExpression(logicalPlan, session, groupExprs)
+        RelationalGroupedDataset.handleGroupingExpression(analyzed, session, groupExprs)
 
       UntypedKeyValueGroupedDataset(
         dummyFunc.outEnc,
         dummyFunc.inEnc,
         qe.analyzed,
-        logicalPlan.output,
+        analyzed.output,
         aliasedGroupings,
         sortOrder)
     }
@@ -895,20 +897,22 @@ class SparkConnectPlanner(
         logicalPlan: LogicalPlan,
         groupingExprs: java.util.List[proto.Expression],
         sortOrder: Seq[SortOrder]): UntypedKeyValueGroupedDataset = {
+      val analyzed = session.sessionState.executePlan(logicalPlan).analyzed
+
       assertPlan(groupingExprs.size() == 1)
-      val groupFunc = TypedScalaUdf(groupingExprs.get(0), Some(logicalPlan.output))
+      val groupFunc = TypedScalaUdf(groupingExprs.get(0), Some(analyzed.output))
       val vEnc = groupFunc.inEnc
       val kEnc = groupFunc.outEnc
 
-      val withGroupingKey = AppendColumns(groupFunc.function, vEnc, kEnc, logicalPlan)
+      val withGroupingKey = AppendColumns(groupFunc.function, vEnc, kEnc, analyzed)
       // The input logical plan of KeyValueGroupedDataset need to be executed and analyzed
-      val analyzed = session.sessionState.executePlan(withGroupingKey).analyzed
+      val withGroupingKeyAnalyzed = session.sessionState.executePlan(withGroupingKey).analyzed
 
       UntypedKeyValueGroupedDataset(
         kEnc,
         vEnc,
-        analyzed,
-        logicalPlan.output,
+        withGroupingKeyAnalyzed,
+        analyzed.output,
         withGroupingKey.newColumns,
         sortOrder)
     }
@@ -1457,8 +1461,9 @@ class SparkConnectPlanner(
   private def transformTypedFilter(
       fun: proto.CommonInlineUserDefinedFunction,
       child: LogicalPlan): TypedFilter = {
-    val udf = TypedScalaUdf(fun, Some(child.output))
-    TypedFilter(udf.function, child)(udf.inEnc)
+    val analyzed = session.sessionState.executePlan(child).analyzed
+    val udf = TypedScalaUdf(fun, Some(analyzed.output))
+    TypedFilter(udf.function, analyzed)(udf.inEnc)
   }
 
   private def transformProject(rel: proto.Project): LogicalPlan = {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
