This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new 5567dca35f9f [SPARK-49961][SQL] Correct transform type signature for
both Scala and Java
5567dca35f9f is described below
commit 5567dca35f9fe0c426e2f44f6c5843f58cc06744
Author: Chris Twiner <[email protected]>
AuthorDate: Tue Jan 28 21:24:06 2025 -0400
[SPARK-49961][SQL] Correct transform type signature for both Scala and Java
### What changes were proposed in this pull request?
Rollback of transform function signature from SPARK-49568 to SPARK-49029.
SPARK-49961 notes an API bug introduced during SPARK-49568 with changes
applied to SPARK-49029. The bug fix reintroduces the aaf61e69 change and
applies the same to sql.Dataset. Without this change, compilation of code that uses
transform will break due to incorrect types. Scala code compiles with just the
base api.Dataset change, but Java will not; hence the sql.Dataset change as well.
### Why are the changes needed?
```scala
import sparkSession.implicits._
val ds = Seq(1, 2).toDS()
val f: Dataset[Int] => Dataset[Int] = d => d.selectExpr("(value + 1)
value").as[Int]
val transformed = ds.transform(f)
assert(transformed.collect().sorted === Array(2, 3))
```
fails to compile. With this patch (re)applied, it compiles and succeeds.
### Does this PR introduce _any_ user-facing change?
No relative to previous stable releases; yes relative to 4.0.0-preview2, which introduced the regression.
### How was this patch tested?
Simple tests introduced for Scala and Java interfaces.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #48479 from chris-twiner/temp/transform.
Authored-by: Chris Twiner <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
(cherry picked from commit af9f4ef1e7412ebd10051efe5056d32330cf660f)
Signed-off-by: Herman van Hovell <[email protected]>
---
.../sql/connect/ClientDataFrameStatSuite.scala | 9 ++++++++
.../main/scala/org/apache/spark/sql/Dataset.scala | 3 ++-
.../org/apache/spark/sql/JavaDataFrameSuite.java | 27 ++++++++++++++++++++++
.../scala/org/apache/spark/sql/DatasetSuite.scala | 24 +++++++++++++++++++
4 files changed, 62 insertions(+), 1 deletion(-)
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientDataFrameStatSuite.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientDataFrameStatSuite.scala
index a7e2e61a106f..d812f5e3deb7 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientDataFrameStatSuite.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientDataFrameStatSuite.scala
@@ -264,4 +264,13 @@ class ClientDataFrameStatSuite extends ConnectFunSuite
with RemoteSparkSession {
}
assert(error3.getCondition === "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE")
}
+
+ test("SPARK-49961: transform type should be consistent") {
+ val session = spark
+ import session.implicits._
+ val ds = Seq(1, 2).toDS()
+ val f: Dataset[Int] => Dataset[Int] = d => d.selectExpr("(value + 1)
value").as[Int]
+ val transformed = ds.transform(f)
+ assert(transformed.collect().sorted === Array(2, 3))
+ }
}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala
b/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala
index c49a5d5a5088..11291079f84f 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2699,7 +2699,8 @@ abstract class Dataset[T] extends Serializable {
* @group typedrel
* @since 1.6.0
*/
- def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] =
t(this.asInstanceOf[Dataset[T]])
+ def transform[U, DSO[_] <: Dataset[_]](t: this.type => DSO[U]): DSO[U] =
+ t(this)
/**
* (Scala-specific) Returns a new Dataset that contains the result of
applying `func` to each
diff --git
a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 26a19cbed1b9..27137e53934d 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -35,6 +35,7 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.expressions.UserDefinedFunction;
@@ -540,4 +541,30 @@ public class JavaDataFrameSuite {
.map(row -> row.get(0).toString() +
row.getString(1)).toArray(String[]::new);
Assertions.assertArrayEquals(expected, result);
}
+
+ @Test
+ public void testTransformBase() {
+ // SPARK-49961 - transform must have the correct type
+ Dataset<Integer> ds = spark.createDataset(Arrays.asList(1,2),
Encoders.INT());
+ Dataset<Integer> transformed = ds.transform((Dataset<Integer> d) ->
+ ds.selectExpr("(value + 1) value").as(Encoders.INT()));
+ Integer[] expected = {2, 3};
+ Integer[] got = transformed.collectAsList().toArray(new Integer[0]);
+ Arrays.sort(got);
+ Assertions.assertArrayEquals(expected, got);
+ }
+
+ @Test
+ public void testTransformAsClassic() {
+ // SPARK-49961 - transform must have the correct type
+ org.apache.spark.sql.classic.Dataset<Integer> ds =
+ spark.createDataset(Arrays.asList(1,2), Encoders.INT());
+ org.apache.spark.sql.classic.Dataset<Integer> transformed =
+ ds.transform((Dataset<Integer> d) ->
+ ds.selectExpr("(value + 1) value").as(Encoders.INT()));
+ Integer[] expected = {2, 3};
+ Integer[] got = transformed.collectAsList().toArray(new Integer[0]);
+ Arrays.sort(got);
+ Assertions.assertArrayEquals(expected, got);
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 403d2b697a9e..8963c9de4ee4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -2802,6 +2802,30 @@ class DatasetSuite extends QueryTest
}
}
}
+
+ test("SPARK-49961: transform type should be consistent (classic)") {
+ val ds = Seq(1, 2).toDS()
+ val f: classic.Dataset[Int] => classic.Dataset[Int] =
+ d => d.selectExpr("(value + 1) value").as[Int]
+ val transformed = ds.transform(f)
+ assert(transformed.collect().sorted === Array(2, 3))
+ }
+
+ test("SPARK-49961: transform type should be consistent (base to classic)") {
+ val ds = Seq(1, 2).toDS()
+ val f: Dataset[Int] => classic.Dataset[Int] =
+ d => d.selectExpr("(value + 1) value").as[Int]
+ val transformed = ds.transform(f)
+ assert(transformed.collect().sorted === Array(2, 3))
+ }
+
+ test("SPARK-49961: transform type should be consistent (as base)") {
+ val ds = Seq(1, 2).toDS().asInstanceOf[Dataset[Int]]
+ val f: Dataset[Int] => Dataset[Int] =
+ d => d.selectExpr("(value + 1) value").as[Int]
+ val transformed = ds.transform(f)
+ assert(transformed.collect().sorted === Array(2, 3))
+ }
}
class DatasetLargeResultCollectingSuite extends QueryTest
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]