This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new af9f4ef1e741 [SPARK-49961][SQL] Correct transform type signature for 
both Scala and Java
af9f4ef1e741 is described below

commit af9f4ef1e7412ebd10051efe5056d32330cf660f
Author: Chris Twiner <[email protected]>
AuthorDate: Tue Jan 28 21:24:06 2025 -0400

    [SPARK-49961][SQL] Correct transform type signature for both Scala and Java
    
    ### What changes were proposed in this pull request?
    
    Roll back the `transform` function signature from the SPARK-49568 version to
    the SPARK-49029 version.
    
    SPARK-49961 reports an API bug introduced in SPARK-49568, which changed the
    `transform` signature that SPARK-49029 had put in place.  This fix
    reintroduces the change from commit aaf61e69 and applies the same change to
    sql.Dataset.  Without it, code that uses `transform` fails to compile
    because of the incorrect types.  Scala code compiles with only the base
    api.Dataset change, but Java code does not, hence the sql.Dataset change.
    
    ### Why are the changes needed?
    
    ```scala
    import sparkSession.implicits._
    val ds = Seq(1, 2).toDS()
    val f: Dataset[Int] => Dataset[Int] = d => d.selectExpr("(value + 1) value").as[Int]
    val transformed = ds.transform(f)
    assert(transformed.collect().sorted === Array(2, 3))
    ```
    
    fails to compile.  With this (re)patch it succeeds.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No relative to previous releases, whose behaviour this restores; yes
    relative to 4.0.0-preview2.
    
    ### How was this patch tested?
    
    Simple tests introduced for Scala and Java interfaces.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #48479 from chris-twiner/temp/transform.
    
    Authored-by: Chris Twiner <[email protected]>
    Signed-off-by: Herman van Hovell <[email protected]>
---
 .../sql/connect/ClientDataFrameStatSuite.scala     |  9 ++++++++
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  3 ++-
 .../org/apache/spark/sql/JavaDataFrameSuite.java   | 27 ++++++++++++++++++++++
 .../scala/org/apache/spark/sql/DatasetSuite.scala  | 24 +++++++++++++++++++
 4 files changed, 62 insertions(+), 1 deletion(-)

diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientDataFrameStatSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientDataFrameStatSuite.scala
index a7e2e61a106f..d812f5e3deb7 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientDataFrameStatSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientDataFrameStatSuite.scala
@@ -264,4 +264,13 @@ class ClientDataFrameStatSuite extends ConnectFunSuite 
with RemoteSparkSession {
     }
     assert(error3.getCondition === "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE")
   }
+
+  test("SPARK-49961: transform type should be consistent") {
+    val session = spark
+    import session.implicits._
+    val ds = Seq(1, 2).toDS()
+    val f: Dataset[Int] => Dataset[Int] = d => d.selectExpr("(value + 1) 
value").as[Int]
+    val transformed = ds.transform(f)
+    assert(transformed.collect().sorted === Array(2, 3))
+  }
 }
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala
index c49a5d5a5088..11291079f84f 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2699,7 +2699,8 @@ abstract class Dataset[T] extends Serializable {
    * @group typedrel
    * @since 1.6.0
    */
-  def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = 
t(this.asInstanceOf[Dataset[T]])
+  def transform[U, DSO[_] <: Dataset[_]](t: this.type => DSO[U]): DSO[U] =
+    t(this)
 
   /**
    * (Scala-specific) Returns a new Dataset that contains the result of 
applying `func` to each
diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java 
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 26a19cbed1b9..27137e53934d 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -35,6 +35,7 @@ import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.Column;
 import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
 import org.apache.spark.sql.expressions.UserDefinedFunction;
@@ -540,4 +541,30 @@ public class JavaDataFrameSuite {
       .map(row -> row.get(0).toString() + 
row.getString(1)).toArray(String[]::new);
     Assertions.assertArrayEquals(expected, result);
   }
+
+  @Test
+  public void testTransformBase() {
+    // SPARK-49961 - transform must have the correct type
+    Dataset<Integer> ds = spark.createDataset(Arrays.asList(1,2), 
Encoders.INT());
+    Dataset<Integer> transformed = ds.transform((Dataset<Integer> d) ->
+            ds.selectExpr("(value + 1) value").as(Encoders.INT()));
+    Integer[] expected = {2, 3};
+    Integer[] got = transformed.collectAsList().toArray(new Integer[0]);
+    Arrays.sort(got);
+    Assertions.assertArrayEquals(expected, got);
+  }
+
+  @Test
+  public void testTransformAsClassic() {
+    // SPARK-49961 - transform must have the correct type
+    org.apache.spark.sql.classic.Dataset<Integer> ds =
+            spark.createDataset(Arrays.asList(1,2), Encoders.INT());
+    org.apache.spark.sql.classic.Dataset<Integer> transformed =
+            ds.transform((Dataset<Integer> d) ->
+              ds.selectExpr("(value + 1) value").as(Encoders.INT()));
+    Integer[] expected = {2, 3};
+    Integer[] got = transformed.collectAsList().toArray(new Integer[0]);
+    Arrays.sort(got);
+    Assertions.assertArrayEquals(expected, got);
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 403d2b697a9e..8963c9de4ee4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -2802,6 +2802,30 @@ class DatasetSuite extends QueryTest
       }
     }
   }
+
+  test("SPARK-49961: transform type should be consistent (classic)") {
+    val ds = Seq(1, 2).toDS()
+    val f: classic.Dataset[Int] => classic.Dataset[Int] =
+      d => d.selectExpr("(value + 1) value").as[Int]
+    val transformed = ds.transform(f)
+    assert(transformed.collect().sorted === Array(2, 3))
+  }
+
+  test("SPARK-49961: transform type should be consistent (base to classic)") {
+    val ds = Seq(1, 2).toDS()
+    val f: Dataset[Int] => classic.Dataset[Int] =
+      d => d.selectExpr("(value + 1) value").as[Int]
+    val transformed = ds.transform(f)
+    assert(transformed.collect().sorted === Array(2, 3))
+  }
+
+  test("SPARK-49961: transform type should be consistent (as base)") {
+    val ds = Seq(1, 2).toDS().asInstanceOf[Dataset[Int]]
+    val f: Dataset[Int] => Dataset[Int] =
+      d => d.selectExpr("(value + 1) value").as[Int]
+    val transformed = ds.transform(f)
+    assert(transformed.collect().sorted === Array(2, 3))
+  }
 }
 
 class DatasetLargeResultCollectingSuite extends QueryTest


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to