This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new f496cd1ee2a  [SPARK-44736][CONNECT] Add Dataset.explode to Spark Connect Scala Client
f496cd1ee2a is described below

commit f496cd1ee2a7e59af08e1bd3ab0579f93cc46da9
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Sun Aug 13 20:27:08 2023 +0200

    [SPARK-44736][CONNECT] Add Dataset.explode to Spark Connect Scala Client

    ### What changes were proposed in this pull request?
    This PR adds Dataset.explode to the Spark Connect Scala Client.

    ### Why are the changes needed?
    To increase compatibility with the existing Dataset API in sql/core.

    ### Does this PR introduce _any_ user-facing change?
    Yes, it adds a new method to the Scala client.

    ### How was this patch tested?
    I added a test to `UserDefinedFunctionE2ETestSuite`.

    Closes #42418 from hvanhovell/SPARK-44736.

    Lead-authored-by: Herman van Hovell <her...@databricks.com>
    Co-authored-by: itholic <haejoon....@databricks.com>
    Co-authored-by: Juliusz Sompolski <ju...@databricks.com>
    Co-authored-by: Martin Grund <martin.gr...@databricks.com>
    Co-authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Co-authored-by: Kent Yao <y...@apache.org>
    Co-authored-by: Wenchen Fan <wenc...@databricks.com>
    Co-authored-by: Wei Liu <wei....@databricks.com>
    Co-authored-by: Ruifeng Zheng <ruife...@apache.org>
    Co-authored-by: Gengliang Wang <gengli...@apache.org>
    Co-authored-by: Yuming Wang <yumw...@ebay.com>
    Co-authored-by: Herman van Hovell <hvanhov...@databricks.com>
    Co-authored-by: 余良 <yul...@chinaunicom.cn>
    Co-authored-by: Dongjoon Hyun <dh...@apple.com>
    Co-authored-by: Jack Chen <jack.c...@databricks.com>
    Co-authored-by: srielau <se...@rielau.com>
    Co-authored-by: zhyhimont <zhyhim...@gmail.com>
    Co-authored-by: Daniel Tenedorio <daniel.tenedo...@databricks.com>
    Co-authored-by: Dongjoon Hyun <dongj...@apache.org>
    Co-authored-by: Zhyhimont Dmitry <zhyhimon...@profitero.com>
    Co-authored-by: Sandip Agarwala <131817656+sandip...@users.noreply.github.com>
    Co-authored-by: yangjie01 <yangji...@baidu.com>
    Co-authored-by: Yihong He <yihong...@databricks.com>
    Co-authored-by: Rameshkrishnan Muthusamy <rameshkrishnan_muthus...@apple.com>
    Co-authored-by: Jia Fan <fanjiaemi...@qq.com>
    Co-authored-by: allisonwang-db <allison.w...@databricks.com>
    Co-authored-by: Utkarsh <utkarsh.agar...@databricks.com>
    Co-authored-by: Cheng Pan <cheng...@apache.org>
    Co-authored-by: Jason Li <jason...@databricks.com>
    Co-authored-by: Shu Wang <swa...@linkedin.com>
    Co-authored-by: Nicolas Fraison <nicolas.frai...@datadoghq.com>
    Co-authored-by: Max Gekk <max.g...@gmail.com>
    Co-authored-by: panbingkun <pbk1...@gmail.com>
    Co-authored-by: Ziqi Liu <ziqi....@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 70 ++++++++++++++++++++++
 .../sql/UserDefinedFunctionE2ETestSuite.scala      | 60 +++++++++++++++++++
 .../CheckConnectJvmClientCompatibility.scala       |  1 -
 .../apache/spark/sql/connect/common/UdfUtils.scala |  4 ++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  3 +-
 5 files changed, 136 insertions(+), 2 deletions(-)
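Before the per-file diff, a minimal usage sketch of the API this commit adds, side by side with the replacements named in its deprecation message. This is a sketch only: `spark` is an assumed Spark Connect session, and `Book` mirrors the example schema in the scaladoc below rather than anything shipped by the patch.

{{{
  // Sketch, not part of the patch. Exercises the deprecated single-column
  // explode added by this commit, plus the recommended alternatives.
  import org.apache.spark.sql.functions.{col, explode, split}

  case class Book(title: String, words: String)

  val session = spark
  import session.implicits._

  val ds = Seq(Book("t1", "a b c"), Book("t2", "a b")).toDS()

  // Deprecated overload, added purely for parity with sql/core:
  val viaMethod = ds.explode("words", "word") { (words: String) =>
    words.split(" ").toSeq
  }

  // Replacements recommended by the @deprecated message:
  val viaSelect = ds.select(col("title"), explode(split(col("words"), " ")).as("word"))
  val viaFlatMap = ds.flatMap(_.words.split(" "))
}}}

Note that the method is deprecated from the moment it is introduced; it is added so that code written against the sql/core Dataset API compiles against the Connect client unchanged.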
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 2d72ea6bda8..28b04fb850e 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -21,12 +21,14 @@ import java.util.{Collections, Locale}
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
 
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.api.java.function._
 import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.expressions.OrderUtils
@@ -2728,6 +2730,74 @@ class Dataset[T] private[sql] (
     flatMap(UdfUtils.flatMapFuncToScalaFunc(f))(encoder)
   }
 
+  /**
+   * (Scala-specific) Returns a new Dataset where each row has been expanded to zero or more rows
+   * by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of the
+   * input row are implicitly joined with each row that is output by the function.
+   *
+   * Given that this is deprecated, as an alternative, you can explode columns either using
+   * `functions.explode()` or `flatMap()`. The following example uses these alternatives to count
+   * the number of books that contain a given word:
+   *
+   * {{{
+   *   case class Book(title: String, words: String)
+   *   val ds: Dataset[Book]
+   *
+   *   val allWords = ds.select($"title", explode(split($"words", " ")).as("word"))
+   *
+   *   val bookCountPerWord = allWords.groupBy("word").agg(count_distinct("title"))
+   * }}}
+   *
+   * Using `flatMap()` this can similarly be exploded as:
+   *
+   * {{{
+   *   ds.flatMap(_.words.split(" "))
+   * }}}
+   *
+   * @group untypedrel
+   * @since 3.5.0
+   */
+  @deprecated("use flatMap() or select() with functions.explode() instead", "3.5.0")
+  def explode[A <: Product: TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = {
+    val generator = ScalarUserDefinedFunction(
+      UdfUtils.traversableOnceToSeq(f),
+      UnboundRowEncoder :: Nil,
+      ScalaReflection.encoderFor[Seq[A]])
+    select(col("*"), functions.inline(generator(struct(input: _*))))
+  }
+
+  /**
+   * (Scala-specific) Returns a new Dataset where a single column has been expanded to zero or
+   * more rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. All
+   * columns of the input row are implicitly joined with each value that is output by the
+   * function.
+   *
+   * Given that this is deprecated, as an alternative, you can explode columns either using
+   * `functions.explode()`:
+   *
+   * {{{
+   *   ds.select(explode(split($"words", " ")).as("word"))
+   * }}}
+   *
+   * or `flatMap()`:
+   *
+   * {{{
+   *   ds.flatMap(_.words.split(" "))
+   * }}}
+   *
+   * @group untypedrel
+   * @since 3.5.0
+   */
+  @deprecated("use flatMap() or select() with functions.explode() instead", "3.5.0")
+  def explode[A, B: TypeTag](inputColumn: String, outputColumn: String)(
+      f: A => TraversableOnce[B]): DataFrame = {
+    val generator = ScalarUserDefinedFunction(
+      UdfUtils.traversableOnceToSeq(f),
+      Nil,
+      ScalaReflection.encoderFor[Seq[B]])
+    select(col("*"), functions.explode(generator(col(inputColumn))).as((outputColumn)))
+  }
+
   /**
    * Applies a function `f` to all rows.
    *
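How the overloads above work: the user function is wrapped in a ScalarUserDefinedFunction returning a Seq, the multi-column variant packs its input columns into a struct so the function sees a whole Row, and functions.inline (or functions.explode) flattens the generated sequence back into rows. A rough user-level analogue follows, with a plain functions.udf standing in for the internal ScalarUserDefinedFunction path; WordPair, wordPairs, and ds are hypothetical names.

{{{
  // Hypothetical analogue of select(col("*"), inline(generator(...))).
  // A UDF returning Seq[WordPair] encodes as an array of structs,
  // which inline() expands into one output row per element.
  import org.apache.spark.sql.functions.{col, inline, udf}

  case class WordPair(word: String, reversed: String)

  val wordPairs = udf { (words: String) =>
    words.split(" ").map(w => WordPair(w, w.reverse)).toSeq
  }

  // col("*") keeps every input column, and each input row is joined with
  // each generated struct, matching the LATERAL VIEW semantics described
  // in the scaladoc above.
  val out = ds.select(col("*"), inline(wordPairs(col("words"))))
}}}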
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
index 3a931c9a6ba..d00659ac2d8 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
@@ -95,6 +95,66 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest {
     rows.forEach(x => assert(x == 42))
   }
 
+  test("(deprecated) Dataset explode") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val result1 = spark
+      .range(3)
+      .filter(col("id") =!= 1L)
+      .explode(col("id") + 41, col("id") + 10) { case Row(x: Long, y: Long) =>
+        Iterator((x, x - 1), (y, y + 1))
+      }
+      .as[(Long, Long, Long)]
+      .collect()
+      .toSeq
+    assert(result1 === Seq((0L, 41L, 40L), (0L, 10L, 11L), (2L, 43L, 42L), (2L, 12L, 13L)))
+
+    val result2 = Seq((1, "a b c"), (2, "a b"), (3, "a"))
+      .toDF("number", "letters")
+      .explode('letters) { case Row(letters: String) =>
+        letters.split(' ').map(Tuple1.apply).toSeq
+      }
+      .as[(Int, String, String)]
+      .collect()
+      .toSeq
+    assert(
+      result2 === Seq(
+        (1, "a b c", "a"),
+        (1, "a b c", "b"),
+        (1, "a b c", "c"),
+        (2, "a b", "a"),
+        (2, "a b", "b"),
+        (3, "a", "a")))
+
+    val result3 = Seq("a b c", "d e")
+      .toDF("words")
+      .explode("words", "word") { word: String =>
+        word.split(' ').toSeq
+      }
+      .select(col("word"))
+      .as[String]
+      .collect()
+      .toSeq
+    assert(result3 === Seq("a", "b", "c", "d", "e"))
+
+    val result4 = Seq("a b c", "d e")
+      .toDF("words")
+      .explode("words", "word") { word: String =>
+        word.split(' ').map(s => s -> s.head.toInt).toSeq
+      }
+      .select(col("word"), col("words"))
+      .as[((String, Int), String)]
+      .collect()
+      .toSeq
+    assert(
+      result4 === Seq(
+        (("a", 97), "a b c"),
+        (("b", 98), "a b c"),
+        (("c", 99), "a b c"),
+        (("d", 100), "d e"),
+        (("e", 101), "d e")))
+  }
+
   test("Dataset typed flat map - java") {
     val rows = spark
       .range(5)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 04b162eceec..7356d4daa79 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -184,7 +184,6 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.sqlContext"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.metadataColumn"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.selectUntyped"), // protected
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.explode"), // deprecated
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.rdd"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.toJavaRDD"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.javaRDD"),
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
index 16d5823f4a4..433614a4afc 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
@@ -131,6 +131,10 @@ private[sql] object UdfUtils extends Serializable {
 
   def noOp[V, K](): V => K = _ => null.asInstanceOf[K]
 
+  def traversableOnceToSeq[A, B](f: A => TraversableOnce[B]): A => Seq[B] = { value =>
+    f(value).toSeq
+  }
+
   // (1 to 22).foreach { i =>
   //   val extTypeArgs = (0 to i).map(_ => "_").mkString(", ")
   //   val anyTypeArgs = (0 to i).map(_ => "Any").mkString(", ")
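The UdfUtils addition above is the small adapter that both explode overloads rely on: the public signatures accept any A => TraversableOnce[B], while the UDF plumbing wants a materialized Seq[B] so that a single Seq encoder applies. A standalone sketch of the same idea (plain Scala 2.12, no Spark required):

{{{
  // Normalize any TraversableOnce result (for example a lazy Iterator)
  // into a Seq, so one encoder for Seq[B] covers every user function.
  def traversableOnceToSeq[A, B](f: A => TraversableOnce[B]): A => Seq[B] = { value =>
    f(value).toSeq
  }

  val splitLazily: String => TraversableOnce[String] = _.split(" ").iterator
  val splitEagerly: String => Seq[String] = traversableOnceToSeq(splitLazily)
  assert(splitEagerly("a b c") == Seq("a", "b", "c"))
}}}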
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 45f962f7920..e6305cd9d1a 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -508,7 +508,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
     val commonUdf = rel.getFunc
     commonUdf.getFunctionCase match {
       case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF =>
-        transformTypedMapPartitions(commonUdf, baseRel)
+        val analyzed = session.sessionState.executePlan(baseRel).analyzed
+        transformTypedMapPartitions(commonUdf, analyzed)
       case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
         val pythonUdf = transformPythonUDF(commonUdf)
         val isBarrier = if (rel.hasIsBarrier) rel.getIsBarrier else false

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org