This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new f496cd1ee2a [SPARK-44736][CONNECT] Add Dataset.explode to Spark Connect Scala Client
f496cd1ee2a is described below
commit f496cd1ee2a7e59af08e1bd3ab0579f93cc46da9
Author: Herman van Hovell <[email protected]>
AuthorDate: Sun Aug 13 20:27:08 2023 +0200
[SPARK-44736][CONNECT] Add Dataset.explode to Spark Connect Scala Client
### What changes were proposed in this pull request?
This PR adds Dataset.explode to the Spark Connect Scala Client.
### Why are the changes needed?
To increase compatibility with the existing Dataset API in sql/core.
### Does this PR introduce _any_ user-facing change?
Yes, it adds the (deprecated) `Dataset.explode` methods to the Scala client.
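For example, after this change the deprecated API can be called from a Connect session much like in sql/core. A minimal sketch mirroring the new E2E test (assuming `spark` is a Spark Connect `SparkSession`; the data is illustrative):
```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col

val session = spark
import session.implicits._

// Column-based variant: each input row expands to zero or more rows.
spark
  .range(3)
  .explode(col("id") + 41, col("id") + 10) { case Row(x: Long, y: Long) =>
    Iterator((x, x - 1), (y, y + 1))
  }
  .show()

// String-based variant: a single input column expands into a new output column.
Seq("a b c", "d e")
  .toDF("words")
  .explode("words", "word") { line: String => line.split(' ').toSeq }
  .select(col("word"))
  .show()
```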
### How was this patch tested?
I added a test to `UserDefinedFunctionE2ETestSuite`.
Closes #42418 from hvanhovell/SPARK-44736.
Lead-authored-by: Herman van Hovell <[email protected]>
Co-authored-by: itholic <[email protected]>
Co-authored-by: Juliusz Sompolski <[email protected]>
Co-authored-by: Martin Grund <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Co-authored-by: Kent Yao <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Co-authored-by: Wei Liu <[email protected]>
Co-authored-by: Ruifeng Zheng <[email protected]>
Co-authored-by: Gengliang Wang <[email protected]>
Co-authored-by: Yuming Wang <[email protected]>
Co-authored-by: Herman van Hovell <[email protected]>
Co-authored-by: 余良 <[email protected]>
Co-authored-by: Dongjoon Hyun <[email protected]>
Co-authored-by: Jack Chen <[email protected]>
Co-authored-by: srielau <[email protected]>
Co-authored-by: zhyhimont <[email protected]>
Co-authored-by: Daniel Tenedorio <[email protected]>
Co-authored-by: Dongjoon Hyun <[email protected]>
Co-authored-by: Zhyhimont Dmitry <[email protected]>
Co-authored-by: Sandip Agarwala <[email protected]>
Co-authored-by: yangjie01 <[email protected]>
Co-authored-by: Yihong He <[email protected]>
Co-authored-by: Rameshkrishnan Muthusamy <[email protected]>
Co-authored-by: Jia Fan <[email protected]>
Co-authored-by: allisonwang-db <[email protected]>
Co-authored-by: Utkarsh <[email protected]>
Co-authored-by: Cheng Pan <[email protected]>
Co-authored-by: Jason Li <[email protected]>
Co-authored-by: Shu Wang <[email protected]>
Co-authored-by: Nicolas Fraison <[email protected]>
Co-authored-by: Max Gekk <[email protected]>
Co-authored-by: panbingkun <[email protected]>
Co-authored-by: Ziqi Liu <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
---
.../main/scala/org/apache/spark/sql/Dataset.scala | 70 ++++++++++++++++++++++
.../sql/UserDefinedFunctionE2ETestSuite.scala | 60 +++++++++++++++++++
.../CheckConnectJvmClientCompatibility.scala | 1 -
.../apache/spark/sql/connect/common/UdfUtils.scala | 4 ++
.../sql/connect/planner/SparkConnectPlanner.scala | 3 +-
5 files changed, 136 insertions(+), 2 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 2d72ea6bda8..28b04fb850e 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -21,12 +21,14 @@ import java.util.{Collections, Locale}
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
import org.apache.spark.SparkException
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.function._
import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
import org.apache.spark.sql.catalyst.expressions.OrderUtils
@@ -2728,6 +2730,74 @@ class Dataset[T] private[sql] (
flatMap(UdfUtils.flatMapFuncToScalaFunc(f))(encoder)
}
+ /**
+ * (Scala-specific) Returns a new Dataset where each row has been expanded to zero or more rows
+ * by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of the
+ * input row are implicitly joined with each row that is output by the function.
+ *
+ * Given that this is deprecated, as an alternative, you can explode columns either using
+ * `functions.explode()` or `flatMap()`. The following example uses these alternatives to count
+ * the number of books that contain a given word:
+ *
+ * {{{
+ *   case class Book(title: String, words: String)
+ *   val ds: Dataset[Book]
+ *
+ *   val allWords = ds.select($"title", explode(split($"words", " ")).as("word"))
+ *
+ *   val bookCountPerWord = allWords.groupBy("word").agg(count_distinct("title"))
+ * }}}
+ *
+ * Using `flatMap()` this can similarly be exploded as:
+ *
+ * {{{
+ *   ds.flatMap(_.words.split(" "))
+ * }}}
+ *
+ * @group untypedrel
+ * @since 3.5.0
+ */
+ @deprecated("use flatMap() or select() with functions.explode() instead", "3.5.0")
+ def explode[A <: Product: TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = {
+ val generator = ScalarUserDefinedFunction(
+ UdfUtils.traversableOnceToSeq(f),
+ UnboundRowEncoder :: Nil,
+ ScalaReflection.encoderFor[Seq[A]])
+ select(col("*"), functions.inline(generator(struct(input: _*))))
+ }
+
+ /**
+ * (Scala-specific) Returns a new Dataset where a single column has been expanded to zero or
+ * more rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. All
+ * columns of the input row are implicitly joined with each value that is output by the
+ * function.
+ *
+ * Given that this is deprecated, as an alternative, you can explode columns either using
+ * `functions.explode()`:
+ *
+ * {{{
+ *   ds.select(explode(split($"words", " ")).as("word"))
+ * }}}
+ *
+ * or `flatMap()`:
+ *
+ * {{{
+ *   ds.flatMap(_.words.split(" "))
+ * }}}
+ *
+ * @group untypedrel
+ * @since 3.5.0
+ */
+ @deprecated("use flatMap() or select() with functions.explode() instead", "3.5.0")
+ def explode[A, B: TypeTag](inputColumn: String, outputColumn: String)(
+ f: A => TraversableOnce[B]): DataFrame = {
+ val generator = ScalarUserDefinedFunction(
+ UdfUtils.traversableOnceToSeq(f),
+ Nil,
+ ScalaReflection.encoderFor[Seq[B]])
+ select(col("*"), functions.explode(generator(col(inputColumn))).as((outputColumn)))
+ }
+
/**
* Applies a function `f` to all rows.
*
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
index 3a931c9a6ba..d00659ac2d8 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
@@ -95,6 +95,66 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest {
rows.forEach(x => assert(x == 42))
}
+ test("(deprecated) Dataset explode") {
+ val session: SparkSession = spark
+ import session.implicits._
+ val result1 = spark
+ .range(3)
+ .filter(col("id") =!= 1L)
+ .explode(col("id") + 41, col("id") + 10) { case Row(x: Long, y: Long) =>
+ Iterator((x, x - 1), (y, y + 1))
+ }
+ .as[(Long, Long, Long)]
+ .collect()
+ .toSeq
+ assert(result1 === Seq((0L, 41L, 40L), (0L, 10L, 11L), (2L, 43L, 42L), (2L, 12L, 13L)))
+
+ val result2 = Seq((1, "a b c"), (2, "a b"), (3, "a"))
+ .toDF("number", "letters")
+ .explode('letters) { case Row(letters: String) =>
+ letters.split(' ').map(Tuple1.apply).toSeq
+ }
+ .as[(Int, String, String)]
+ .collect()
+ .toSeq
+ assert(
+ result2 === Seq(
+ (1, "a b c", "a"),
+ (1, "a b c", "b"),
+ (1, "a b c", "c"),
+ (2, "a b", "a"),
+ (2, "a b", "b"),
+ (3, "a", "a")))
+
+ val result3 = Seq("a b c", "d e")
+ .toDF("words")
+ .explode("words", "word") { word: String =>
+ word.split(' ').toSeq
+ }
+ .select(col("word"))
+ .as[String]
+ .collect()
+ .toSeq
+ assert(result3 === Seq("a", "b", "c", "d", "e"))
+
+ val result4 = Seq("a b c", "d e")
+ .toDF("words")
+ .explode("words", "word") { word: String =>
+ word.split(' ').map(s => s -> s.head.toInt).toSeq
+ }
+ .select(col("word"), col("words"))
+ .as[((String, Int), String)]
+ .collect()
+ .toSeq
+ assert(
+ result4 === Seq(
+ (("a", 97), "a b c"),
+ (("b", 98), "a b c"),
+ (("c", 99), "a b c"),
+ (("d", 100), "d e"),
+ (("e", 101), "d e")))
+ }
+
test("Dataset typed flat map - java") {
val rows = spark
.range(5)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 04b162eceec..7356d4daa79 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -184,7 +184,6 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.sqlContext"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.metadataColumn"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.selectUntyped"),
// protected
- ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.explode"),
// deprecated
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.rdd"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.toJavaRDD"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.javaRDD"),
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
index 16d5823f4a4..433614a4afc 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
@@ -131,6 +131,10 @@ private[sql] object UdfUtils extends Serializable {
def noOp[V, K](): V => K = _ => null.asInstanceOf[K]
+ def traversableOnceToSeq[A, B](f: A => TraversableOnce[B]): A => Seq[B] = { value =>
+ f(value).toSeq
+ }
+
// (1 to 22).foreach { i =>
// val extTypeArgs = (0 to i).map(_ => "_").mkString(", ")
// val anyTypeArgs = (0 to i).map(_ => "Any").mkString(", ")
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 45f962f7920..e6305cd9d1a 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -508,7 +508,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
val commonUdf = rel.getFunc
commonUdf.getFunctionCase match {
case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF =>
- transformTypedMapPartitions(commonUdf, baseRel)
+ val analyzed = session.sessionState.executePlan(baseRel).analyzed
+ transformTypedMapPartitions(commonUdf, analyzed)
case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
val pythonUdf = transformPythonUDF(commonUdf)
val isBarrier = if (rel.hasIsBarrier) rel.getIsBarrier else false