This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 10fd918754a [SPARK-42953][CONNECT] Typed filter, map, flatMap, 
mapPartitions
10fd918754a is described below

commit 10fd918754a26077bf8a34b7bccbfc5fec49424c
Author: Zhen Li <[email protected]>
AuthorDate: Thu Apr 6 15:25:50 2023 -0400

    [SPARK-42953][CONNECT] Typed filter, map, flatMap, mapPartitions
    
    ### What changes were proposed in this pull request?
    Implemented new missing methods in the client Dataset API: filter, map, 
flatMap, mapPartitions.
    
    ### Why are the changes needed?
    Missing APIs
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Integration tests.
    The UDF test does not work with Maven.
    
    Closes #40581 from zhenlineo/filter-typed.
    
    Authored-by: Zhen Li <[email protected]>
    Signed-off-by: Herman van Hovell <[email protected]>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 110 ++++++++++++++++-
 .../apache/spark/sql/connect/client/UdfUtils.scala |  46 ++++++++
 .../sql/expressions/UserDefinedFunction.scala      |  18 ++-
 .../sql/UserDefinedFunctionE2ETestSuite.scala      | 131 +++++++++++++++++++++
 .../CheckConnectJvmClientCompatibility.scala       |   4 -
 .../connect/client/util/IntegrationTestUtils.scala |  22 +++-
 .../connect/client/util/RemoteSparkSession.scala   |   2 +-
 .../sql/connect/planner/SparkConnectPlanner.scala  | 105 +++++++++++++----
 8 files changed, 408 insertions(+), 30 deletions(-)

diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 246e9f6e0a6..913dad00952 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -23,12 +23,14 @@ import scala.collection.mutable
 import scala.util.control.NonFatal
 
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.api.java.function.{FilterFunction, FlatMapFunction, 
MapFunction, MapPartitionsFunction}
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
-import 
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveLongEncoder, 
ProductEncoder, StringEncoder, UnboundRowEncoder}
+import 
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveBooleanEncoder,
 PrimitiveLongEncoder, ProductEncoder, StringEncoder, UnboundRowEncoder}
 import org.apache.spark.sql.catalyst.expressions.RowOrdering
-import org.apache.spark.sql.connect.client.SparkResult
+import org.apache.spark.sql.connect.client.{SparkResult, UdfUtils}
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, 
StorageLevelProtoConverter}
+import org.apache.spark.sql.expressions.ScalarUserDefinedFunction
 import org.apache.spark.sql.functions.{struct, to_json}
 import org.apache.spark.sql.types.{Metadata, StructType}
 import org.apache.spark.storage.StorageLevel
@@ -2468,6 +2470,110 @@ class Dataset[T] private[sql] (
    */
   def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this)
 
+  /**
+   * (Scala-specific) Returns a new Dataset that only contains elements where 
`func` returns
+   * `true`.
+   *
+   * @group typedrel
+   * @since 3.5.0
+   */
+  def filter(func: T => Boolean): Dataset[T] = {
+    val udf = ScalarUserDefinedFunction(
+      function = func,
+      inputEncoders = encoder :: Nil,
+      outputEncoder = PrimitiveBooleanEncoder)
+    sparkSession.newDataset[T](encoder) { builder =>
+      builder.getFilterBuilder
+        .setInput(plan.getRoot)
+        .setCondition(udf.apply(col("*")).expr)
+    }
+  }
+
+  /**
+   * (Java-specific) Returns a new Dataset that only contains elements where 
`func` returns
+   * `true`.
+   *
+   * @group typedrel
+   * @since 3.5.0
+   */
+  def filter(f: FilterFunction[T]): Dataset[T] = {
+    filter(UdfUtils.filterFuncToScalaFunc(f))
+  }
+
+  /**
+   * (Scala-specific) Returns a new Dataset that contains the result of 
applying `func` to each
+   * element.
+   *
+   * @group typedrel
+   * @since 3.5.0
+   */
+  def map[U: Encoder](f: T => U): Dataset[U] = {
+    mapPartitions(UdfUtils.mapFuncToMapPartitionsAdaptor(f))
+  }
+
+  /**
+   * (Java-specific) Returns a new Dataset that contains the result of 
applying `func` to each
+   * element.
+   *
+   * @group typedrel
+   * @since 3.5.0
+   */
+  def map[U](f: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
+    map(UdfUtils.mapFunctionToScalaFunc(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific) Returns a new Dataset that contains the result of 
applying `func` to each
+   * partition.
+   *
+   * @group typedrel
+   * @since 3.5.0
+   */
+  def mapPartitions[U: Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] 
= {
+    val outputEncoder = encoderFor[U]
+    val udf = ScalarUserDefinedFunction(
+      function = func,
+      inputEncoders = encoder :: Nil,
+      outputEncoder = outputEncoder)
+    sparkSession.newDataset(outputEncoder) { builder =>
+      builder.getMapPartitionsBuilder
+        .setInput(plan.getRoot)
+        .setFunc(udf.apply().expr.getCommonInlineUserDefinedFunction)
+    }
+  }
+
+  /**
+   * (Java-specific) Returns a new Dataset that contains the result of 
applying `f` to each
+   * partition.
+   *
+   * @group typedrel
+   * @since 3.5.0
+   */
+  def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): 
Dataset[U] = {
+    mapPartitions(UdfUtils.mapPartitionsFuncToScalaFunc(f))(encoder)
+  }
+
+  /**
+   * (Scala-specific) Returns a new Dataset by first applying a function to 
all elements of this
+   * Dataset, and then flattening the results.
+   *
+   * @group typedrel
+   * @since 3.5.0
+   */
+  def flatMap[U: Encoder](func: T => TraversableOnce[U]): Dataset[U] =
+    mapPartitions(UdfUtils.flatMapFuncToMapPartitionsAdaptor(func))
+
+  /**
+   * (Java-specific) Returns a new Dataset by first applying a function to all 
elements of this
+   * Dataset, and then flattening the results.
+   *
+   * @group typedrel
+   * @since 3.5.0
+   */
+  def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
+    flatMap(UdfUtils.flatMapFuncToScalaFunc(f))(encoder)
+  }
+
   /**
    * Returns the first `n` rows in the Dataset.
    *
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/UdfUtils.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/UdfUtils.scala
new file mode 100644
index 00000000000..9e6eabeb4db
--- /dev/null
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/UdfUtils.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.api.java.function._
+
+/**
+ * Util functions to help convert input functions between typed filter, map, 
flatMap,
+ * mapPartitions etc. These functions cannot be defined inside the client 
Dataset class, as doing so
+ * would cause Dataset sync conflicts when used together with UDFs. Thus we 
define them outside, in
+ * the client package.
+ */
+private[sql] object UdfUtils {
+
+  def mapFuncToMapPartitionsAdaptor[T, U](f: T => U): Iterator[T] => 
Iterator[U] = _.map(f(_))
+
+  def flatMapFuncToMapPartitionsAdaptor[T, U](
+      f: T => TraversableOnce[U]): Iterator[T] => Iterator[U] = _.flatMap(f)
+
+  def filterFuncToScalaFunc[T](f: FilterFunction[T]): T => Boolean = f.call
+
+  def mapFunctionToScalaFunc[T, U](f: MapFunction[T, U]): T => U = f.call
+
+  def flatMapFuncToScalaFunc[T, U](f: FlatMapFunction[T, U]): T => 
TraversableOnce[U] = x =>
+    f.call(x).asScala
+
+  def mapPartitionsFuncToScalaFunc[T, U](
+      f: MapPartitionsFunction[T, U]): Iterator[T] => Iterator[U] = x => 
f.call(x.asJava).asScala
+
+}
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index 0fe47092e4e..ad1aae73876 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -25,7 +25,7 @@ import org.apache.spark.connect.proto
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
-import org.apache.spark.sql.connect.common.UdfPacket
+import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, UdfPacket}
 import org.apache.spark.util.Utils
 
 /**
@@ -106,6 +106,10 @@ case class ScalarUserDefinedFunction(
     val scalaUdfBuilder = proto.ScalarScalaUDF
       .newBuilder()
       .setPayload(ByteString.copyFrom(udfPacketBytes))
+      // Send the real input and return types to obtain the types without 
deserializing the udf bytes.
+      .addAllInputTypes(
+        
inputEncoders.map(_.dataType).map(DataTypeProtoConverter.toConnectProtoType).asJava)
+      
.setOutputType(DataTypeProtoConverter.toConnectProtoType(outputEncoder.dataType))
       .setNullable(nullable)
 
     scalaUdfBuilder.build()
@@ -138,7 +142,17 @@ object ScalarUserDefinedFunction {
     ScalarUserDefinedFunction(
       function = function,
       inputEncoders = parameterTypes.map(tag => 
ScalaReflection.encoderFor(tag)),
-      outputEncoder = ScalaReflection.encoderFor(returnType),
+      outputEncoder = ScalaReflection.encoderFor(returnType))
+  }
+
+  private[sql] def apply(
+      function: AnyRef,
+      inputEncoders: Seq[AgnosticEncoder[_]],
+      outputEncoder: AgnosticEncoder[_]): ScalarUserDefinedFunction = {
+    ScalarUserDefinedFunction(
+      function = function,
+      inputEncoders = inputEncoders,
+      outputEncoder = outputEncoder,
       name = None,
       nullable = true,
       deterministic = true)
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
new file mode 100644
index 00000000000..f352e106b8b
--- /dev/null
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import java.lang.{Long => JLong}
+import java.util.{Iterator => JIterator}
+import java.util.Arrays
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.api.java.function._
+import 
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveIntEncoder, 
PrimitiveLongEncoder}
+import org.apache.spark.sql.connect.client.util.RemoteSparkSession
+import org.apache.spark.sql.functions.{col, udf}
+
+/**
+ * All tests in this class require client UDF artifacts to be synced with the 
server. TODO: This means
+ * these tests only work with SBT for now.
+ */
+class UserDefinedFunctionE2ETestSuite extends RemoteSparkSession {
+  test("Dataset typed filter") {
+    val rows = spark.range(10).filter(n => n % 2 == 0).collectAsList()
+    assert(rows == Arrays.asList[Long](0, 2, 4, 6, 8))
+  }
+
+  test("Dataset typed filter - java") {
+    val rows = spark
+      .range(10)
+      .filter(new FilterFunction[JLong] {
+        override def call(value: JLong): Boolean = value % 2 == 0
+      })
+      .collectAsList()
+    assert(rows == Arrays.asList[Long](0, 2, 4, 6, 8))
+  }
+
+  test("Dataset typed map") {
+    val rows = spark.range(10).map(n => n / 
2)(PrimitiveLongEncoder).collectAsList()
+    assert(rows == Arrays.asList[Long](0, 0, 1, 1, 2, 2, 3, 3, 4, 4))
+  }
+
+  test("filter with condition") {
+    // This should go via `def filter(condition: Column)` rather than
+    // `def filter(func: T => Boolean)`
+    def func(i: Long): Boolean = i < 5
+    val under5 = udf(func _)
+    val longs = spark.range(10).filter(under5(col("id") * 2)).collectAsList()
+    assert(longs == Arrays.asList[Long](0, 1, 2))
+  }
+
+  test("filter with col(*)") {
+    // This should go via `def filter(condition: Column)` but it is executed as
+    // `def filter(func: T => Boolean)`. This is fine as the result is the 
same.
+    def func(i: Long): Boolean = i < 5
+    val under5 = udf(func _)
+    val longs = spark.range(10).filter(under5(col("*"))).collectAsList()
+    assert(longs == Arrays.asList[Long](0, 1, 2, 3, 4))
+  }
+
+  test("Dataset typed map - java") {
+    val rows = spark
+      .range(10)
+      .map(
+        new MapFunction[JLong, Long] {
+          def call(value: JLong): Long = value / 2
+        },
+        PrimitiveLongEncoder)
+      .collectAsList()
+    assert(rows == Arrays.asList[Long](0, 0, 1, 1, 2, 2, 3, 3, 4, 4))
+  }
+
+  test("Dataset typed flat map") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val rows = spark
+      .range(5)
+      .flatMap(n => Iterator(42, 42))
+      .collectAsList()
+    assert(rows.size() == 10)
+    rows.forEach(x => assert(x == 42))
+  }
+
+  test("Dataset typed flat map - java") {
+    val rows = spark
+      .range(5)
+      .flatMap(
+        new FlatMapFunction[JLong, Int] {
+          def call(value: JLong): JIterator[Int] = Arrays.asList(42, 
42).iterator()
+        },
+        PrimitiveIntEncoder)
+      .collectAsList()
+    assert(rows.size() == 10)
+    rows.forEach(x => assert(x == 42))
+  }
+
+  test("Dataset typed map partition") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val df = spark.range(0, 100, 1, 50).repartition(4)
+    val result =
+      df.mapPartitions(iter => Iterator.single(iter.length)).collect()
+    assert(result.sorted.toSeq === Seq(23, 25, 25, 27))
+  }
+
+  test("Dataset typed map partition - java") {
+    val df = spark.range(0, 100, 1, 50).repartition(4)
+    val result = df
+      .mapPartitions(
+        new MapPartitionsFunction[JLong, Int] {
+          override def call(input: JIterator[JLong]): JIterator[Int] = {
+            Arrays.asList(input.asScala.length).iterator()
+          }
+        },
+        PrimitiveIntEncoder)
+      .collect()
+    assert(result.sorted.toSeq === Seq(23, 25, 25, 27))
+  }
+}
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 3a257cdd2df..301bfb26f9c 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -161,10 +161,6 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.reduce"),
       
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.groupByKey"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.explode"), 
// deprecated
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.filter"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.map"),
-      
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.mapPartitions"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.flatMap"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.foreach"),
       
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.foreachPartition"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.rdd"),
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
index 408caa58534..9196db175d2 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
@@ -27,6 +27,8 @@ object IntegrationTestUtils {
 
   // System properties used for testing and debugging
   private val DEBUG_SC_JVM_CLIENT = "spark.debug.sc.jvm.client"
+  // Enable this flag to print all client debug logs and server logs to the 
console
+  private[connect] val isDebug = System.getProperty(DEBUG_SC_JVM_CLIENT, 
"false").toBoolean
 
   private[sql] lazy val scalaVersion = {
     versionNumberString.split('.') match {
@@ -43,7 +45,25 @@ object IntegrationTestUtils {
     }
     sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME"))
   }
-  private[connect] val isDebug = System.getProperty(DEBUG_SC_JVM_CLIENT, 
"false").toBoolean
+
+  private[connect] lazy val debugConfig: Seq[String] = {
+    val log4j2 = 
s"$sparkHome/connector/connect/client/jvm/src/test/resources/log4j2.properties"
+    if (isDebug) {
+      Seq(
+        // Enable to see the server plan change log
+        // "--conf",
+        // "spark.sql.planChangeLog.level=WARN",
+
+        // Enable to see the server grpc received
+        // "--conf",
+        // "spark.connect.grpc.interceptor.classes=" +
+        //  "org.apache.spark.sql.connect.service.LoggingInterceptor",
+
+        // Redirect server log into console
+        "--conf",
+        s"spark.driver.extraJavaOptions=-Dlog4j.configuration=$log4j2")
+    } else Seq.empty
+  }
 
   // Log server start stop debug info into console
   // scalastyle:off println
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
index 43bf722020c..47d89961d52 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
@@ -86,7 +86,7 @@ object SparkConnectServerUtils {
         "--conf",
         
"spark.sql.catalog.testcat=org.apache.spark.sql.connector.catalog.InMemoryTableCatalog",
         "--conf",
-        s"spark.sql.catalogImplementation=$catalogImplementation",
+        s"spark.sql.catalogImplementation=$catalogImplementation") ++ 
debugConfig ++ Seq(
         "--class",
         "org.apache.spark.sql.connect.SimpleSparkConnectService",
         connectJar),
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 0700d8bea04..e1ea48d0da7 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -40,13 +40,13 @@ import org.apache.spark.ml.{functions => MLFunctions}
 import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, 
FunctionIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, 
MultiAlias, ParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, 
UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, 
UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, 
MultiAlias, ParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, 
UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, 
UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, 
ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, 
LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, 
CommandResult, Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, 
Project, Sample, Sort, SubqueryAlias, Union, Unpivot, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, 
CommandResult, Deduplicate, DeserializeToObject, Except, Intersect, 
LocalRelation, LogicalPlan, MapPartitions, Project, Sample, 
SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, 
UnresolvedHint}
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, 
CharVarcharUtils}
 import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, 
InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, 
UdfPacket}
@@ -493,27 +493,64 @@ class SparkConnectPlanner(val session: SparkSession) {
   }
 
   private def transformMapPartitions(rel: proto.MapPartitions): LogicalPlan = {
+    val baseRel = transformRelation(rel.getInput)
     val commonUdf = rel.getFunc
-    val pythonUdf = transformPythonUDF(commonUdf)
-    val isBarrier = if (rel.hasIsBarrier) rel.getIsBarrier else false
-    pythonUdf.evalType match {
-      case PythonEvalType.SQL_MAP_PANDAS_ITER_UDF =>
-        logical.MapInPandas(
-          pythonUdf,
-          pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
-          transformRelation(rel.getInput),
-          isBarrier)
-      case PythonEvalType.SQL_MAP_ARROW_ITER_UDF =>
-        logical.PythonMapInArrow(
-          pythonUdf,
-          pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
-          transformRelation(rel.getInput),
-          isBarrier)
+    commonUdf.getFunctionCase match {
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF 
=>
+        transformTypedMapPartitions(commonUdf, baseRel)
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
+        val pythonUdf = transformPythonUDF(commonUdf)
+        val isBarrier = if (rel.hasIsBarrier) rel.getIsBarrier else false
+        pythonUdf.evalType match {
+          case PythonEvalType.SQL_MAP_PANDAS_ITER_UDF =>
+            logical.MapInPandas(
+              pythonUdf,
+              pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
+              baseRel,
+              isBarrier)
+          case PythonEvalType.SQL_MAP_ARROW_ITER_UDF =>
+            logical.PythonMapInArrow(
+              pythonUdf,
+              pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
+              baseRel,
+              isBarrier)
+          case _ =>
+            throw InvalidPlanInput(
+              s"Function with EvalType: ${pythonUdf.evalType} is not 
supported")
+        }
       case _ =>
-        throw InvalidPlanInput(s"Function with EvalType: ${pythonUdf.evalType} 
is not supported")
+        throw InvalidPlanInput(
+          s"Function with ID: ${commonUdf.getFunctionCase.getNumber} is not 
supported")
     }
   }
 
+  private def generateObjAttr[T](enc: ExpressionEncoder[T]): Attribute = {
+    val dataType = enc.deserializer.dataType
+    val nullable = !enc.clsTag.runtimeClass.isPrimitive
+    AttributeReference("obj", dataType, nullable)()
+  }
+
+  private def transformTypedMapPartitions(
+      fun: proto.CommonInlineUserDefinedFunction,
+      child: LogicalPlan): LogicalPlan = {
+    val udf = fun.getScalarScalaUdf
+    val udfPacket =
+      Utils.deserialize[UdfPacket](
+        udf.getPayload.toByteArray,
+        SparkConnectArtifactManager.classLoaderWithArtifacts)
+    assert(udfPacket.inputEncoders.size == 1)
+    val iEnc = ExpressionEncoder(udfPacket.inputEncoders.head)
+    val rEnc = ExpressionEncoder(udfPacket.outputEncoder)
+
+    val deserializer = UnresolvedDeserializer(iEnc.deserializer)
+    val deserialized = DeserializeToObject(deserializer, 
generateObjAttr(iEnc), child)
+    val mapped = MapPartitions(
+      udfPacket.function.asInstanceOf[Iterator[Any] => Iterator[Any]],
+      generateObjAttr(rEnc),
+      deserialized)
+    SerializeFromObject(rEnc.namedExpressions, mapped)
+  }
+
   private def transformGroupMap(rel: proto.GroupMap): LogicalPlan = {
     val pythonUdf = transformPythonUDF(rel.getFunc)
     val cols =
@@ -887,7 +924,35 @@ class SparkConnectPlanner(val session: SparkSession) {
   private def transformFilter(rel: proto.Filter): LogicalPlan = {
     assert(rel.hasInput)
     val baseRel = transformRelation(rel.getInput)
-    logical.Filter(condition = transformExpression(rel.getCondition), child = 
baseRel)
+    val cond = rel.getCondition
+    cond.getExprTypeCase match {
+      case proto.Expression.ExprTypeCase.COMMON_INLINE_USER_DEFINED_FUNCTION
+          if isTypedFilter(cond.getCommonInlineUserDefinedFunction) =>
+        transformTypedFilter(cond.getCommonInlineUserDefinedFunction, baseRel)
+      case _ =>
+        logical.Filter(condition = transformExpression(cond), child = baseRel)
+    }
+  }
+
+  private def isTypedFilter(udf: proto.CommonInlineUserDefinedFunction): 
Boolean = {
+    // It is a scala udf && the udf argument is an unresolved star.
+    // This means the udf is a typed filter to filter on all inputs.
+    udf.getFunctionCase == 
proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF &&
+    udf.getArgumentsCount == 1 &&
+    udf.getArguments(0).getExprTypeCase == 
proto.Expression.ExprTypeCase.UNRESOLVED_STAR
+  }
+
+  private def transformTypedFilter(
+      fun: proto.CommonInlineUserDefinedFunction,
+      child: LogicalPlan): TypedFilter = {
+    val udf = fun.getScalarScalaUdf
+    val udfPacket =
+      Utils.deserialize[UdfPacket](
+        udf.getPayload.toByteArray,
+        SparkConnectArtifactManager.classLoaderWithArtifacts)
+    assert(udfPacket.inputEncoders.size == 1)
+    val iEnc = ExpressionEncoder(udfPacket.inputEncoders.head)
+    TypedFilter(udfPacket.function, child)(iEnc)
   }
 
   private def transformProject(rel: proto.Project): LogicalPlan = {
@@ -1065,7 +1130,7 @@ class SparkConnectPlanner(val session: SparkSession) {
         SparkConnectArtifactManager.classLoaderWithArtifacts)
     ScalaUDF(
       function = udfPacket.function,
-      dataType = udfPacket.outputEncoder.dataType,
+      dataType = transformDataType(udf.getOutputType),
       children = fun.getArgumentsList.asScala.map(transformExpression).toSeq,
       inputEncoders = udfPacket.inputEncoders.map(e => 
Option(ExpressionEncoder(e))),
       outputEncoder = Option(ExpressionEncoder(udfPacket.outputEncoder)),


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to