This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 10fd918754a [SPARK-42953][CONNECT] Typed filter, map, flatMap,
mapPartitions
10fd918754a is described below
commit 10fd918754a26077bf8a34b7bccbfc5fec49424c
Author: Zhen Li <[email protected]>
AuthorDate: Thu Apr 6 15:25:50 2023 -0400
[SPARK-42953][CONNECT] Typed filter, map, flatMap, mapPartitions
### What changes were proposed in this pull request?
Implemented new missing methods in the client Dataset API: filter, map,
flatMap, mapPartitions.
### Why are the changes needed?
Missing APIs
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Integration tests.
The UDF tests do not work with Maven.
Closes #40581 from zhenlineo/filter-typed.
Authored-by: Zhen Li <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
---
.../main/scala/org/apache/spark/sql/Dataset.scala | 110 ++++++++++++++++-
.../apache/spark/sql/connect/client/UdfUtils.scala | 46 ++++++++
.../sql/expressions/UserDefinedFunction.scala | 18 ++-
.../sql/UserDefinedFunctionE2ETestSuite.scala | 131 +++++++++++++++++++++
.../CheckConnectJvmClientCompatibility.scala | 4 -
.../connect/client/util/IntegrationTestUtils.scala | 22 +++-
.../connect/client/util/RemoteSparkSession.scala | 2 +-
.../sql/connect/planner/SparkConnectPlanner.scala | 105 +++++++++++++----
8 files changed, 408 insertions(+), 30 deletions(-)
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 246e9f6e0a6..913dad00952 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -23,12 +23,14 @@ import scala.collection.mutable
import scala.util.control.NonFatal
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.api.java.function.{FilterFunction, FlatMapFunction,
MapFunction, MapPartitionsFunction}
import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
-import
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveLongEncoder,
ProductEncoder, StringEncoder, UnboundRowEncoder}
+import
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveBooleanEncoder,
PrimitiveLongEncoder, ProductEncoder, StringEncoder, UnboundRowEncoder}
import org.apache.spark.sql.catalyst.expressions.RowOrdering
-import org.apache.spark.sql.connect.client.SparkResult
+import org.apache.spark.sql.connect.client.{SparkResult, UdfUtils}
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter,
StorageLevelProtoConverter}
+import org.apache.spark.sql.expressions.ScalarUserDefinedFunction
import org.apache.spark.sql.functions.{struct, to_json}
import org.apache.spark.sql.types.{Metadata, StructType}
import org.apache.spark.storage.StorageLevel
@@ -2468,6 +2470,110 @@ class Dataset[T] private[sql] (
*/
def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this)
+ /**
+ * (Scala-specific) Returns a new Dataset that only contains elements where
`func` returns
+ * `true`.
+ *
+ * @group typedrel
+ * @since 3.5.0
+ */
+ def filter(func: T => Boolean): Dataset[T] = {
+ val udf = ScalarUserDefinedFunction(
+ function = func,
+ inputEncoders = encoder :: Nil,
+ outputEncoder = PrimitiveBooleanEncoder)
+ sparkSession.newDataset[T](encoder) { builder =>
+ builder.getFilterBuilder
+ .setInput(plan.getRoot)
+ .setCondition(udf.apply(col("*")).expr)
+ }
+ }
+
+ /**
+ * (Java-specific) Returns a new Dataset that only contains elements where
`func` returns
+ * `true`.
+ *
+ * @group typedrel
+ * @since 3.5.0
+ */
+ def filter(f: FilterFunction[T]): Dataset[T] = {
+ filter(UdfUtils.filterFuncToScalaFunc(f))
+ }
+
+ /**
+ * (Scala-specific) Returns a new Dataset that contains the result of
applying `func` to each
+ * element.
+ *
+ * @group typedrel
+ * @since 3.5.0
+ */
+ def map[U: Encoder](f: T => U): Dataset[U] = {
+ mapPartitions(UdfUtils.mapFuncToMapPartitionsAdaptor(f))
+ }
+
+ /**
+ * (Java-specific) Returns a new Dataset that contains the result of
applying `func` to each
+ * element.
+ *
+ * @group typedrel
+ * @since 3.5.0
+ */
+ def map[U](f: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
+ map(UdfUtils.mapFunctionToScalaFunc(f))(encoder)
+ }
+
+ /**
+ * (Scala-specific) Returns a new Dataset that contains the result of
applying `func` to each
+ * partition.
+ *
+ * @group typedrel
+ * @since 3.5.0
+ */
+ def mapPartitions[U: Encoder](func: Iterator[T] => Iterator[U]): Dataset[U]
= {
+ val outputEncoder = encoderFor[U]
+ val udf = ScalarUserDefinedFunction(
+ function = func,
+ inputEncoders = encoder :: Nil,
+ outputEncoder = outputEncoder)
+ sparkSession.newDataset(outputEncoder) { builder =>
+ builder.getMapPartitionsBuilder
+ .setInput(plan.getRoot)
+ .setFunc(udf.apply().expr.getCommonInlineUserDefinedFunction)
+ }
+ }
+
+ /**
+ * (Java-specific) Returns a new Dataset that contains the result of
applying `f` to each
+ * partition.
+ *
+ * @group typedrel
+ * @since 3.5.0
+ */
+ def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]):
Dataset[U] = {
+ mapPartitions(UdfUtils.mapPartitionsFuncToScalaFunc(f))(encoder)
+ }
+
+ /**
+ * (Scala-specific) Returns a new Dataset by first applying a function to
all elements of this
+ * Dataset, and then flattening the results.
+ *
+ * @group typedrel
+ * @since 3.5.0
+ */
+ def flatMap[U: Encoder](func: T => TraversableOnce[U]): Dataset[U] =
+ mapPartitions(UdfUtils.flatMapFuncToMapPartitionsAdaptor(func))
+
+ /**
+ * (Java-specific) Returns a new Dataset by first applying a function to all
elements of this
+ * Dataset, and then flattening the results.
+ *
+ * @group typedrel
+ * @since 3.5.0
+ */
+ def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
+ flatMap(UdfUtils.flatMapFuncToScalaFunc(f))(encoder)
+ }
+
/**
* Returns the first `n` rows in the Dataset.
*
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/UdfUtils.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/UdfUtils.scala
new file mode 100644
index 00000000000..9e6eabeb4db
--- /dev/null
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/UdfUtils.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.api.java.function._
+
+/**
+ * Util functions to help convert input functions between typed filter, map,
flatMap,
+ * mapPartitions etc. These functions cannot be defined inside the client
Dataset class as it will
+ * cause Dataset sync conflicts when used together with UDFs. Thus we define
them outside, in the
+ * client package.
+ */
+private[sql] object UdfUtils {
+
+ def mapFuncToMapPartitionsAdaptor[T, U](f: T => U): Iterator[T] =>
Iterator[U] = _.map(f(_))
+
+ def flatMapFuncToMapPartitionsAdaptor[T, U](
+ f: T => TraversableOnce[U]): Iterator[T] => Iterator[U] = _.flatMap(f)
+
+ def filterFuncToScalaFunc[T](f: FilterFunction[T]): T => Boolean = f.call
+
+ def mapFunctionToScalaFunc[T, U](f: MapFunction[T, U]): T => U = f.call
+
+ def flatMapFuncToScalaFunc[T, U](f: FlatMapFunction[T, U]): T =>
TraversableOnce[U] = x =>
+ f.call(x).asScala
+
+ def mapPartitionsFuncToScalaFunc[T, U](
+ f: MapPartitionsFunction[T, U]): Iterator[T] => Iterator[U] = x =>
f.call(x.asJava).asScala
+
+}
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index 0fe47092e4e..ad1aae73876 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -25,7 +25,7 @@ import org.apache.spark.connect.proto
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
-import org.apache.spark.sql.connect.common.UdfPacket
+import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, UdfPacket}
import org.apache.spark.util.Utils
/**
@@ -106,6 +106,10 @@ case class ScalarUserDefinedFunction(
val scalaUdfBuilder = proto.ScalarScalaUDF
.newBuilder()
.setPayload(ByteString.copyFrom(udfPacketBytes))
+ // Send the real input and return types to obtain the types without
deserializing the udf bytes.
+ .addAllInputTypes(
+
inputEncoders.map(_.dataType).map(DataTypeProtoConverter.toConnectProtoType).asJava)
+
.setOutputType(DataTypeProtoConverter.toConnectProtoType(outputEncoder.dataType))
.setNullable(nullable)
scalaUdfBuilder.build()
@@ -138,7 +142,17 @@ object ScalarUserDefinedFunction {
ScalarUserDefinedFunction(
function = function,
inputEncoders = parameterTypes.map(tag =>
ScalaReflection.encoderFor(tag)),
- outputEncoder = ScalaReflection.encoderFor(returnType),
+ outputEncoder = ScalaReflection.encoderFor(returnType))
+ }
+
+ private[sql] def apply(
+ function: AnyRef,
+ inputEncoders: Seq[AgnosticEncoder[_]],
+ outputEncoder: AgnosticEncoder[_]): ScalarUserDefinedFunction = {
+ ScalarUserDefinedFunction(
+ function = function,
+ inputEncoders = inputEncoders,
+ outputEncoder = outputEncoder,
name = None,
nullable = true,
deterministic = true)
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
new file mode 100644
index 00000000000..f352e106b8b
--- /dev/null
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import java.lang.{Long => JLong}
+import java.util.{Iterator => JIterator}
+import java.util.Arrays
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.api.java.function._
+import
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveIntEncoder,
PrimitiveLongEncoder}
+import org.apache.spark.sql.connect.client.util.RemoteSparkSession
+import org.apache.spark.sql.functions.{col, udf}
+
+/**
+ * All tests in this class require client UDF artifacts synced with the
server. TODO: It means
+ * these tests only work with SBT for now.
+ */
+class UserDefinedFunctionE2ETestSuite extends RemoteSparkSession {
+ test("Dataset typed filter") {
+ val rows = spark.range(10).filter(n => n % 2 == 0).collectAsList()
+ assert(rows == Arrays.asList[Long](0, 2, 4, 6, 8))
+ }
+
+ test("Dataset typed filter - java") {
+ val rows = spark
+ .range(10)
+ .filter(new FilterFunction[JLong] {
+ override def call(value: JLong): Boolean = value % 2 == 0
+ })
+ .collectAsList()
+ assert(rows == Arrays.asList[Long](0, 2, 4, 6, 8))
+ }
+
+ test("Dataset typed map") {
+ val rows = spark.range(10).map(n => n /
2)(PrimitiveLongEncoder).collectAsList()
+ assert(rows == Arrays.asList[Long](0, 0, 1, 1, 2, 2, 3, 3, 4, 4))
+ }
+
+ test("filter with condition") {
+ // This should go via `def filter(condition: Column)` rather than
+ // `def filter(func: T => Boolean)`
+ def func(i: Long): Boolean = i < 5
+ val under5 = udf(func _)
+ val longs = spark.range(10).filter(under5(col("id") * 2)).collectAsList()
+ assert(longs == Arrays.asList[Long](0, 1, 2))
+ }
+
+ test("filter with col(*)") {
+ // This should go via `def filter(condition: Column)` but it is executed as
+ // `def filter(func: T => Boolean)`. This is fine as the result is the
same.
+ def func(i: Long): Boolean = i < 5
+ val under5 = udf(func _)
+ val longs = spark.range(10).filter(under5(col("*"))).collectAsList()
+ assert(longs == Arrays.asList[Long](0, 1, 2, 3, 4))
+ }
+
+ test("Dataset typed map - java") {
+ val rows = spark
+ .range(10)
+ .map(
+ new MapFunction[JLong, Long] {
+ def call(value: JLong): Long = value / 2
+ },
+ PrimitiveLongEncoder)
+ .collectAsList()
+ assert(rows == Arrays.asList[Long](0, 0, 1, 1, 2, 2, 3, 3, 4, 4))
+ }
+
+ test("Dataset typed flat map") {
+ val session: SparkSession = spark
+ import session.implicits._
+ val rows = spark
+ .range(5)
+ .flatMap(n => Iterator(42, 42))
+ .collectAsList()
+ assert(rows.size() == 10)
+ rows.forEach(x => assert(x == 42))
+ }
+
+ test("Dataset typed flat map - java") {
+ val rows = spark
+ .range(5)
+ .flatMap(
+ new FlatMapFunction[JLong, Int] {
+ def call(value: JLong): JIterator[Int] = Arrays.asList(42,
42).iterator()
+ },
+ PrimitiveIntEncoder)
+ .collectAsList()
+ assert(rows.size() == 10)
+ rows.forEach(x => assert(x == 42))
+ }
+
+ test("Dataset typed map partition") {
+ val session: SparkSession = spark
+ import session.implicits._
+ val df = spark.range(0, 100, 1, 50).repartition(4)
+ val result =
+ df.mapPartitions(iter => Iterator.single(iter.length)).collect()
+ assert(result.sorted.toSeq === Seq(23, 25, 25, 27))
+ }
+
+ test("Dataset typed map partition - java") {
+ val df = spark.range(0, 100, 1, 50).repartition(4)
+ val result = df
+ .mapPartitions(
+ new MapPartitionsFunction[JLong, Int] {
+ override def call(input: JIterator[JLong]): JIterator[Int] = {
+ Arrays.asList(input.asScala.length).iterator()
+ }
+ },
+ PrimitiveIntEncoder)
+ .collect()
+ assert(result.sorted.toSeq === Seq(23, 25, 25, 27))
+ }
+}
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 3a257cdd2df..301bfb26f9c 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -161,10 +161,6 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.reduce"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.groupByKey"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.explode"),
// deprecated
- ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.filter"),
- ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.map"),
-
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.mapPartitions"),
- ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.flatMap"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.foreach"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.foreachPartition"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.rdd"),
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
index 408caa58534..9196db175d2 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
@@ -27,6 +27,8 @@ object IntegrationTestUtils {
// System properties used for testing and debugging
private val DEBUG_SC_JVM_CLIENT = "spark.debug.sc.jvm.client"
+ // Enable this flag to print all client debug logs + server logs to the
console
+ private[connect] val isDebug = System.getProperty(DEBUG_SC_JVM_CLIENT,
"false").toBoolean
private[sql] lazy val scalaVersion = {
versionNumberString.split('.') match {
@@ -43,7 +45,25 @@ object IntegrationTestUtils {
}
sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME"))
}
- private[connect] val isDebug = System.getProperty(DEBUG_SC_JVM_CLIENT,
"false").toBoolean
+
+ private[connect] lazy val debugConfig: Seq[String] = {
+ val log4j2 =
s"$sparkHome/connector/connect/client/jvm/src/test/resources/log4j2.properties"
+ if (isDebug) {
+ Seq(
+ // Enable to see the server plan change log
+ // "--conf",
+ // "spark.sql.planChangeLog.level=WARN",
+
+ // Enable to see the server grpc received
+ // "--conf",
+ // "spark.connect.grpc.interceptor.classes=" +
+ // "org.apache.spark.sql.connect.service.LoggingInterceptor",
+
+ // Redirect server log into console
+ "--conf",
+ s"spark.driver.extraJavaOptions=-Dlog4j.configuration=$log4j2")
+ } else Seq.empty
+ }
// Log server start stop debug info into console
// scalastyle:off println
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
index 43bf722020c..47d89961d52 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
@@ -86,7 +86,7 @@ object SparkConnectServerUtils {
"--conf",
"spark.sql.catalog.testcat=org.apache.spark.sql.connector.catalog.InMemoryTableCatalog",
"--conf",
- s"spark.sql.catalogImplementation=$catalogImplementation",
+ s"spark.sql.catalogImplementation=$catalogImplementation") ++
debugConfig ++ Seq(
"--class",
"org.apache.spark.sql.connect.SimpleSparkConnectService",
connectJar),
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 0700d8bea04..e1ea48d0da7 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -40,13 +40,13 @@ import org.apache.spark.ml.{functions => MLFunctions}
import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession}
import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier,
FunctionIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView,
MultiAlias, ParameterizedQuery, UnresolvedAlias, UnresolvedAttribute,
UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex,
UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView,
MultiAlias, ParameterizedQuery, UnresolvedAlias, UnresolvedAttribute,
UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction,
UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser,
ParseException, ParserUtils}
import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType,
LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics,
CommandResult, Deduplicate, Except, Intersect, LocalRelation, LogicalPlan,
Project, Sample, Sort, SubqueryAlias, Union, Unpivot, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics,
CommandResult, Deduplicate, DeserializeToObject, Except, Intersect,
LocalRelation, LogicalPlan, MapPartitions, Project, Sample,
SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot,
UnresolvedHint}
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap,
CharVarcharUtils}
import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter,
InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter,
UdfPacket}
@@ -493,27 +493,64 @@ class SparkConnectPlanner(val session: SparkSession) {
}
private def transformMapPartitions(rel: proto.MapPartitions): LogicalPlan = {
+ val baseRel = transformRelation(rel.getInput)
val commonUdf = rel.getFunc
- val pythonUdf = transformPythonUDF(commonUdf)
- val isBarrier = if (rel.hasIsBarrier) rel.getIsBarrier else false
- pythonUdf.evalType match {
- case PythonEvalType.SQL_MAP_PANDAS_ITER_UDF =>
- logical.MapInPandas(
- pythonUdf,
- pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
- transformRelation(rel.getInput),
- isBarrier)
- case PythonEvalType.SQL_MAP_ARROW_ITER_UDF =>
- logical.PythonMapInArrow(
- pythonUdf,
- pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
- transformRelation(rel.getInput),
- isBarrier)
+ commonUdf.getFunctionCase match {
+ case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF
=>
+ transformTypedMapPartitions(commonUdf, baseRel)
+ case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
+ val pythonUdf = transformPythonUDF(commonUdf)
+ val isBarrier = if (rel.hasIsBarrier) rel.getIsBarrier else false
+ pythonUdf.evalType match {
+ case PythonEvalType.SQL_MAP_PANDAS_ITER_UDF =>
+ logical.MapInPandas(
+ pythonUdf,
+ pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
+ baseRel,
+ isBarrier)
+ case PythonEvalType.SQL_MAP_ARROW_ITER_UDF =>
+ logical.PythonMapInArrow(
+ pythonUdf,
+ pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
+ baseRel,
+ isBarrier)
+ case _ =>
+ throw InvalidPlanInput(
+ s"Function with EvalType: ${pythonUdf.evalType} is not
supported")
+ }
case _ =>
- throw InvalidPlanInput(s"Function with EvalType: ${pythonUdf.evalType}
is not supported")
+ throw InvalidPlanInput(
+ s"Function with ID: ${commonUdf.getFunctionCase.getNumber} is not
supported")
}
}
+ private def generateObjAttr[T](enc: ExpressionEncoder[T]): Attribute = {
+ val dataType = enc.deserializer.dataType
+ val nullable = !enc.clsTag.runtimeClass.isPrimitive
+ AttributeReference("obj", dataType, nullable)()
+ }
+
+ private def transformTypedMapPartitions(
+ fun: proto.CommonInlineUserDefinedFunction,
+ child: LogicalPlan): LogicalPlan = {
+ val udf = fun.getScalarScalaUdf
+ val udfPacket =
+ Utils.deserialize[UdfPacket](
+ udf.getPayload.toByteArray,
+ SparkConnectArtifactManager.classLoaderWithArtifacts)
+ assert(udfPacket.inputEncoders.size == 1)
+ val iEnc = ExpressionEncoder(udfPacket.inputEncoders.head)
+ val rEnc = ExpressionEncoder(udfPacket.outputEncoder)
+
+ val deserializer = UnresolvedDeserializer(iEnc.deserializer)
+ val deserialized = DeserializeToObject(deserializer,
generateObjAttr(iEnc), child)
+ val mapped = MapPartitions(
+ udfPacket.function.asInstanceOf[Iterator[Any] => Iterator[Any]],
+ generateObjAttr(rEnc),
+ deserialized)
+ SerializeFromObject(rEnc.namedExpressions, mapped)
+ }
+
private def transformGroupMap(rel: proto.GroupMap): LogicalPlan = {
val pythonUdf = transformPythonUDF(rel.getFunc)
val cols =
@@ -887,7 +924,35 @@ class SparkConnectPlanner(val session: SparkSession) {
private def transformFilter(rel: proto.Filter): LogicalPlan = {
assert(rel.hasInput)
val baseRel = transformRelation(rel.getInput)
- logical.Filter(condition = transformExpression(rel.getCondition), child =
baseRel)
+ val cond = rel.getCondition
+ cond.getExprTypeCase match {
+ case proto.Expression.ExprTypeCase.COMMON_INLINE_USER_DEFINED_FUNCTION
+ if isTypedFilter(cond.getCommonInlineUserDefinedFunction) =>
+ transformTypedFilter(cond.getCommonInlineUserDefinedFunction, baseRel)
+ case _ =>
+ logical.Filter(condition = transformExpression(cond), child = baseRel)
+ }
+ }
+
+ private def isTypedFilter(udf: proto.CommonInlineUserDefinedFunction):
Boolean = {
+ // It is a Scala UDF && the UDF argument is an unresolved star.
+ // This means the UDF is a typed filter that filters on all inputs.
+ udf.getFunctionCase ==
proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF &&
+ udf.getArgumentsCount == 1 &&
+ udf.getArguments(0).getExprTypeCase ==
proto.Expression.ExprTypeCase.UNRESOLVED_STAR
+ }
+
+ private def transformTypedFilter(
+ fun: proto.CommonInlineUserDefinedFunction,
+ child: LogicalPlan): TypedFilter = {
+ val udf = fun.getScalarScalaUdf
+ val udfPacket =
+ Utils.deserialize[UdfPacket](
+ udf.getPayload.toByteArray,
+ SparkConnectArtifactManager.classLoaderWithArtifacts)
+ assert(udfPacket.inputEncoders.size == 1)
+ val iEnc = ExpressionEncoder(udfPacket.inputEncoders.head)
+ TypedFilter(udfPacket.function, child)(iEnc)
}
private def transformProject(rel: proto.Project): LogicalPlan = {
@@ -1065,7 +1130,7 @@ class SparkConnectPlanner(val session: SparkSession) {
SparkConnectArtifactManager.classLoaderWithArtifacts)
ScalaUDF(
function = udfPacket.function,
- dataType = udfPacket.outputEncoder.dataType,
+ dataType = transformDataType(udf.getOutputType),
children = fun.getArgumentsList.asScala.map(transformExpression).toSeq,
inputEncoders = udfPacket.inputEncoders.map(e =>
Option(ExpressionEncoder(e))),
outputEncoder = Option(ExpressionEncoder(udfPacket.outputEncoder)),
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]