This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new fb9d7067845 [SPARK-42283][CONNECT][SCALA] Simple Scalar Scala UDFs
fb9d7067845 is described below

commit fb9d706784557ef0fe9e17d59b7096374658954e
Author: vicennial <venkata.gud...@databricks.com>
AuthorDate: Wed Feb 1 14:12:38 2023 -0400

    [SPARK-42283][CONNECT][SCALA] Simple Scalar Scala UDFs

### What changes were proposed in this pull request?

This PR adds support for "simple" scalar Scala UDFs in the Spark Connect Scala/JVM client. "Simple" here refers to UDFs that use no client-specific class files (e.g. REPL-generated classes) or JARs; a "simple" UDF may reference only built-in libraries and classes defined within the scope of the UDF itself. A user can then do the following (example):
```
def myFunc(x: Int): Int = x + 5
val myUdf = udf(myFunc _)
val result = df.select(myUdf(Column("id")))
```

#### Implementation Details:

A shared JVM class, `UdfPacket`, is introduced in the common package to encapsulate the Scala UDF together with its encoders (as `AgnosticEncoder`s), so that instances can be serialized on the client and deserialized on the server. In addition, a new protobuf message, `ScalarScalaUDF`, is introduced to transmit Scala/JVM-specific information (such as the serialized `UdfPacket`) to the server.
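To make the mechanism concrete, here is a minimal, self-contained sketch of the round trip this design relies on: a closure is wrapped in a serializable holder, Java-serialized to bytes on the "client", then deserialized and invoked on the "server". `DemoPacket` and `DemoRoundTrip` are hypothetical stand-ins for illustration only (the real `UdfPacket` also carries the input/output `AgnosticEncoder`s); this is not the actual Spark Connect implementation.
```
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Hypothetical, simplified stand-in for UdfPacket: just the function, no encoders.
case class DemoPacket(function: AnyRef) extends Serializable

object DemoRoundTrip extends App {
  // "Client" side: wrap the closure and serialize it to bytes,
  // analogous to Utils.serialize(UdfPacket(...)) in this patch.
  val f: Int => Int = _ + 5
  val bos = new ByteArrayOutputStream()
  val oos = new ObjectOutputStream(bos)
  oos.writeObject(DemoPacket(f))
  oos.close()
  val payload: Array[Byte] = bos.toByteArray // would travel as ScalarScalaUDF.payload

  // "Server" side: deserialize the packet and invoke the function,
  // analogous to Utils.deserialize[UdfPacket](...) in SparkConnectPlanner.
  val packet = new ObjectInputStream(new ByteArrayInputStream(payload))
    .readObject()
    .asInstanceOf[DemoPacket]
  val g = packet.function.asInstanceOf[Int => Int]
  println(g(37)) // prints 42
}
```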
### Why are the changes needed?

UDFs are crucial for the completeness of the Spark Connect Scala/JVM client. We introduce this component incrementally.

### Does this PR introduce _any_ user-facing change?

Yes, users are now able to run "simple" scalar Scala UDFs through the Scala/JVM client.

### How was this patch tested?

Unit test + Integration test

Closes #39850 from vicennial/SPARK-42283.

Authored-by: vicennial <venkata.gud...@databricks.com>
Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../sql/expressions/UserDefinedFunction.scala      | 146 ++++++++++++
 .../scala/org/apache/spark/sql/functions.scala     | 255 +++++++++++++++++++++
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  |  14 ++
 .../spark/sql/UserDefinedFunctionSuite.scala       |  53 +++++
 connector/connect/common/pom.xml                   |  12 +
 .../main/protobuf/spark/connect/expressions.proto  |  12 +
 .../spark/sql/connect/common/UdfPacket.scala       |  70 ++++++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  27 +++
 .../pyspark/sql/connect/proto/expressions_pb2.py   |  22 +-
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  |  66 +++++-
 10 files changed, 671 insertions(+), 6 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala new file mode 100644 index 00000000000..0fe47092e4e --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.expressions + +import scala.collection.JavaConverters._ +import scala.reflect.runtime.universe.TypeTag + +import com.google.protobuf.ByteString + +import org.apache.spark.connect.proto +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder +import org.apache.spark.sql.connect.common.UdfPacket +import org.apache.spark.util.Utils + +/** + * A user-defined function. To create one, use the `udf` functions in `functions`. + * + * As an example: + * {{{ + * // Define a UDF that returns true or false based on some numeric score. + * val predict = udf((score: Double) => score > 0.5) + * + * // Projects a column that adds a prediction column based on the score column. + * df.select( predict(df("score")) ) + * }}} + * + * @since 3.4.0 + */ +sealed abstract class UserDefinedFunction { + + /** + * Returns true when the UDF can return a nullable value. + * + * @since 3.4.0 + */ + def nullable: Boolean + + /** + * Returns true iff the UDF is deterministic, i.e. the UDF produces the same output given the + * same input. + * + * @since 3.4.0 + */ + def deterministic: Boolean + + /** + * Returns an expression that invokes the UDF, using the given arguments. + * + * @since 3.4.0 + */ + @scala.annotation.varargs + def apply(exprs: Column*): Column + + /** + * Updates UserDefinedFunction with a given name. + * + * @since 3.4.0 + */ + def withName(name: String): UserDefinedFunction + + /** + * Updates UserDefinedFunction to non-nullable. + * + * @since 3.4.0 + */ + def asNonNullable(): UserDefinedFunction + + /** + * Updates UserDefinedFunction to nondeterministic. + * + * @since 3.4.0 + */ + def asNondeterministic(): UserDefinedFunction +} + +/** + * Holder class for a scalar user-defined function and its input/output encoder(s).
+ */ +case class ScalarUserDefinedFunction( + function: AnyRef, + inputEncoders: Seq[AgnosticEncoder[_]], + outputEncoder: AgnosticEncoder[_], + name: Option[String], + override val nullable: Boolean, + override val deterministic: Boolean) + extends UserDefinedFunction { + + private[this] lazy val udf = { + val udfPacketBytes = Utils.serialize(UdfPacket(function, inputEncoders, outputEncoder)) + val scalaUdfBuilder = proto.ScalarScalaUDF + .newBuilder() + .setPayload(ByteString.copyFrom(udfPacketBytes)) + .setNullable(nullable) + + scalaUdfBuilder.build() + } + + @scala.annotation.varargs + override def apply(exprs: Column*): Column = Column { builder => + val udfBuilder = builder.getCommonInlineUserDefinedFunctionBuilder + udfBuilder + .setDeterministic(deterministic) + .setScalarScalaUdf(udf) + .addAllArguments(exprs.map(_.expr).asJava) + + name.foreach(udfBuilder.setFunctionName) + } + + override def withName(name: String): ScalarUserDefinedFunction = copy(name = Option(name)) + + override def asNonNullable(): ScalarUserDefinedFunction = copy(nullable = false) + + override def asNondeterministic(): ScalarUserDefinedFunction = copy(deterministic = false) +} + +object ScalarUserDefinedFunction { + private[sql] def apply( + function: AnyRef, + returnType: TypeTag[_], + parameterTypes: TypeTag[_]*): ScalarUserDefinedFunction = { + + ScalarUserDefinedFunction( + function = function, + inputEncoders = parameterTypes.map(tag => ScalaReflection.encoderFor(tag)), + outputEncoder = ScalaReflection.encoderFor(returnType), + name = None, + nullable = true, + deterministic = true) + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala index bae394785be..61174f1921e 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala @@ -19,10 +19,13 @@ package org.apache.spark.sql import java.math.{BigDecimal => JBigDecimal} import java.time.LocalDate +import scala.reflect.runtime.universe.{typeTag, TypeTag} + import com.google.protobuf.ByteString import org.apache.spark.connect.proto import org.apache.spark.sql.connect.client.unsupported +import org.apache.spark.sql.expressions.{ScalarUserDefinedFunction, UserDefinedFunction} /** * Commonly used functions available for DataFrame operations. @@ -80,4 +83,256 @@ object functions { case _ => unsupported(s"literal $literal not supported (yet).") } } + + // scalastyle:off line.size.limit + + /** + * Defines a Scala closure of 0 arguments as user-defined function (UDF). The data types are + * automatically inferred based on the Scala closure's signature. By default the returned UDF is + * deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 3.4.0 + */ + def udf[RT: TypeTag](f: () => RT): UserDefinedFunction = { + ScalarUserDefinedFunction(f, typeTag[RT]) + } + + /** + * Defines a Scala closure of 1 arguments as user-defined function (UDF). The data types are + * automatically inferred based on the Scala closure's signature. By default the returned UDF is + * deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. 
+ * + * @group udf_funcs + * @since 3.4.0 + */ + def udf[RT: TypeTag, A1: TypeTag](f: A1 => RT): UserDefinedFunction = { + ScalarUserDefinedFunction(f, typeTag[RT], typeTag[A1]) + } + + /** + * Defines a Scala closure of 2 arguments as user-defined function (UDF). The data types are + * automatically inferred based on the Scala closure's signature. By default the returned UDF is + * deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 3.4.0 + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: (A1, A2) => RT): UserDefinedFunction = { + ScalarUserDefinedFunction(f, typeTag[RT], typeTag[A1], typeTag[A2]) + } + + /** + * Defines a Scala closure of 3 arguments as user-defined function (UDF). The data types are + * automatically inferred based on the Scala closure's signature. By default the returned UDF is + * deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 3.4.0 + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag]( + f: (A1, A2, A3) => RT): UserDefinedFunction = { + ScalarUserDefinedFunction(f, typeTag[RT], typeTag[A1], typeTag[A2], typeTag[A3]) + } + + /** + * Defines a Scala closure of 4 arguments as user-defined function (UDF). The data types are + * automatically inferred based on the Scala closure's signature. By default the returned UDF is + * deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 3.4.0 + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag]( + f: (A1, A2, A3, A4) => RT): UserDefinedFunction = { + ScalarUserDefinedFunction(f, typeTag[RT], typeTag[A1], typeTag[A2], typeTag[A3], typeTag[A4]) + } + + /** + * Defines a Scala closure of 5 arguments as user-defined function (UDF). The data types are + * automatically inferred based on the Scala closure's signature. By default the returned UDF is + * deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 3.4.0 + */ + def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag]( + f: (A1, A2, A3, A4, A5) => RT): UserDefinedFunction = { + ScalarUserDefinedFunction( + f, + typeTag[RT], + typeTag[A1], + typeTag[A2], + typeTag[A3], + typeTag[A4], + typeTag[A5]) + } + + /** + * Defines a Scala closure of 6 arguments as user-defined function (UDF). The data types are + * automatically inferred based on the Scala closure's signature. By default the returned UDF is + * deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 3.4.0 + */ + def udf[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag](f: (A1, A2, A3, A4, A5, A6) => RT): UserDefinedFunction = { + ScalarUserDefinedFunction( + f, + typeTag[RT], + typeTag[A1], + typeTag[A2], + typeTag[A3], + typeTag[A4], + typeTag[A5], + typeTag[A6]) + } + + /** + * Defines a Scala closure of 7 arguments as user-defined function (UDF). The data types are + * automatically inferred based on the Scala closure's signature. By default the returned UDF is + * deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. 
+ * + * @group udf_funcs + * @since 3.4.0 + */ + def udf[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag](f: (A1, A2, A3, A4, A5, A6, A7) => RT): UserDefinedFunction = { + ScalarUserDefinedFunction( + f, + typeTag[RT], + typeTag[A1], + typeTag[A2], + typeTag[A3], + typeTag[A4], + typeTag[A5], + typeTag[A6], + typeTag[A7]) + } + + /** + * Defines a Scala closure of 8 arguments as user-defined function (UDF). The data types are + * automatically inferred based on the Scala closure's signature. By default the returned UDF is + * deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 3.4.0 + */ + def udf[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag](f: (A1, A2, A3, A4, A5, A6, A7, A8) => RT): UserDefinedFunction = { + ScalarUserDefinedFunction( + f, + typeTag[RT], + typeTag[A1], + typeTag[A2], + typeTag[A3], + typeTag[A4], + typeTag[A5], + typeTag[A6], + typeTag[A7], + typeTag[A8]) + } + + /** + * Defines a Scala closure of 9 arguments as user-defined function (UDF). The data types are + * automatically inferred based on the Scala closure's signature. By default the returned UDF is + * deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 3.4.0 + */ + def udf[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag](f: (A1, A2, A3, A4, A5, A6, A7, A8, A9) => RT): UserDefinedFunction = { + ScalarUserDefinedFunction( + f, + typeTag[RT], + typeTag[A1], + typeTag[A2], + typeTag[A3], + typeTag[A4], + typeTag[A5], + typeTag[A6], + typeTag[A7], + typeTag[A8], + typeTag[A9]) + } + + /** + * Defines a Scala closure of 10 arguments as user-defined function (UDF). The data types are + * automatically inferred based on the Scala closure's signature. By default the returned UDF is + * deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. 
+ * + * @group udf_funcs + * @since 3.4.0 + */ + def udf[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag](f: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10) => RT): UserDefinedFunction = { + ScalarUserDefinedFunction( + f, + typeTag[RT], + typeTag[A1], + typeTag[A2], + typeTag[A3], + typeTag[A4], + typeTag[A5], + typeTag[A6], + typeTag[A7], + typeTag[A8], + typeTag[A9], + typeTag[A10]) + } + // scalastyle:off line.size.limit + } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index e31f121ca10..db2b8b26987 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql import org.apache.spark.sql.connect.client.util.RemoteSparkSession +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{StringType, StructField, StructType} class ClientE2ETestSuite extends RemoteSparkSession { @@ -48,6 +49,19 @@ class ClientE2ETestSuite extends RemoteSparkSession { assert(array(2).getLong(0) == 2) } + test("simple udf test") { + + def dummyUdf(x: Int): Int = x + 5 + val myUdf = udf(dummyUdf _) + val df = spark.range(5).select(myUdf(Column("id"))) + + val result = df.collectResult() + assert(result.length == 5) + result.toArray.zipWithIndex.foreach { case (v, idx) => + assert(v.getInt(0) == idx + 5) + } + } + // TODO test large result when we can create table or view // test("test spark large result") } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala new file mode 100644 index 00000000000..b0d92a223c6 --- /dev/null +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql + +import scala.reflect.runtime.universe.typeTag + +import org.scalatest.BeforeAndAfterEach +import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite + +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.connect.common.UdfPacket +import org.apache.spark.sql.functions.udf +import org.apache.spark.util.Utils + +class UserDefinedFunctionSuite + extends AnyFunSuite // scalastyle:ignore funsuite + with BeforeAndAfterEach { + + test("udf and encoder serialization") { + def func(x: Int): Int = x + 1 + + val myUdf = udf(func _) + val colWithUdf = myUdf(Column("dummy")) + + val udfExpr = colWithUdf.expr.getCommonInlineUserDefinedFunction + assert(udfExpr.getDeterministic) + assert(udfExpr.getArgumentsCount == 1) + assert(udfExpr.getArguments(0) == Column("dummy").expr) + val udfObj = udfExpr.getScalarScalaUdf + + assert(udfObj.getNullable) + + val deSer = Utils.deserialize[UdfPacket](udfObj.getPayload.toByteArray) + + assert(deSer.function.asInstanceOf[Int => Int](5) == func(5)) + assert(deSer.outputEncoder == ScalaReflection.encoderFor(typeTag[Int])) + assert(deSer.inputEncoders == Seq(ScalaReflection.encoderFor(typeTag[Int]))) + } +} diff --git a/connector/connect/common/pom.xml b/connector/connect/common/pom.xml index a37f87dda1e..eb1e4cae34d 100644 --- a/connector/connect/common/pom.xml +++ b/connector/connect/common/pom.xml @@ -38,6 +38,18 @@ <tomcat.annotations.api.version>6.0.53</tomcat.annotations.api.version> </properties> <dependencies> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-catalyst_${scala.binary.version}</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + <exclusions> + <exclusion> + <groupId>com.google.guava</groupId> + <artifactId>guava</artifactId> + </exclusion> + </exclusions> + </dependency> <dependency> <groupId>org.scala-lang</groupId> <artifactId>scala-library</artifactId> diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto index 5b27d4593db..66361883321 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto @@ -307,6 +307,7 @@ message CommonInlineUserDefinedFunction { // (Required) Indicate the function type of the user-defined function. oneof function { PythonUDF python_udf = 4; + ScalarScalaUDF scalar_scala_udf = 5; } } @@ -319,3 +320,14 @@ message PythonUDF { bytes command = 3; } +message ScalarScalaUDF { + // (Required) Serialized JVM object containing UDF definition, input encoders and output encoder + bytes payload = 1; + // (Optional) Input type(s) of the UDF + repeated DataType inputTypes = 2; + // (Required) Output type of the UDF + DataType outputType = 3; + // (Required) True if the UDF can return null value + bool nullable = 4; +} + diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfPacket.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfPacket.scala new file mode 100644 index 00000000000..6829b8d1b21 --- /dev/null +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfPacket.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.common + +import com.google.protobuf.ByteString +import java.io.{InputStream, ObjectInputStream, ObjectOutputStream, OutputStream} + +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder + +/** + * A wrapper class around the UDF and it's Input/Output [[AgnosticEncoder]](s). + * + * This class is shared between the client and the server to allow for serialization and + * deserialization of the JVM object. + * + * @param function + * The UDF + * @param inputEncoders + * A list of [[AgnosticEncoder]](s) for all input arguments of the UDF + * @param outputEncoder + * An [[AgnosticEncoder]] for the output of the UDF + */ +@SerialVersionUID(8866761834651399125L) +case class UdfPacket( + function: AnyRef, + inputEncoders: Seq[AgnosticEncoder[_]], + outputEncoder: AgnosticEncoder[_]) + extends Serializable { + + def writeTo(out: OutputStream): Unit = { + val oos = new ObjectOutputStream(out) + oos.writeObject(this) + oos.flush() + } + + def toByteString: ByteString = { + val out = ByteString.newOutput() + writeTo(out) + out.toByteString + } +} + +object UdfPacket { + def apply(in: InputStream): UdfPacket = { + val ois = new ObjectInputStream(in) + ois.readObject().asInstanceOf[UdfPacket] + } + + def apply(bytes: ByteString): UdfPacket = { + val in = bytes.newInput() + try UdfPacket(in) + finally { + in.close() + } + } +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 9b5c4b93f62..51d115ef1ca 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -29,6 +29,7 @@ import org.apache.spark.connect.proto import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier} import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils} @@ -36,6 +37,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, L import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Sample, Sort, 
SubqueryAlias, Union, Unpivot, UnresolvedHint} import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} +import org.apache.spark.sql.connect.common.UdfPacket import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.{toCatalystExpression, toCatalystValue} import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry import org.apache.spark.sql.errors.QueryCompilationErrors @@ -831,12 +833,37 @@ class SparkConnectPlanner(val session: SparkSession) { fun.getFunctionCase match { case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF => transformPythonUDF(fun) + case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF => + transformScalarScalaUDF(fun) case _ => throw InvalidPlanInput( s"Function with ID: ${fun.getFunctionCase.getNumber} is not supported") } } + /** + * Translates a Scalar Scala user-defined function from proto to the Catalyst expression. + * + * @param fun + * Proto representation of the Scalar Scalar user-defined function. + * @return + * ScalaUDF. + */ + private def transformScalarScalaUDF(fun: proto.CommonInlineUserDefinedFunction): ScalaUDF = { + val udf = fun.getScalarScalaUdf + val udfPacket = + Utils.deserialize[UdfPacket](udf.getPayload.toByteArray, Utils.getContextOrSparkClassLoader) + ScalaUDF( + function = udfPacket.function, + dataType = udfPacket.outputEncoder.dataType, + children = fun.getArgumentsList.asScala.map(transformExpression).toSeq, + inputEncoders = udfPacket.inputEncoders.map(e => Option(ExpressionEncoder(e))), + outputEncoder = Option(ExpressionEncoder(udfPacket.outputEncoder)), + udfName = Option(fun.getFunctionName), + nullable = udf.getNullable, + udfDeterministic = fun.getDeterministic) + } + /** * Translates a Python user-defined function from proto to the Catalyst expression. * diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py index f320eee54e0..3a06e80c21e 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.py +++ b/python/pyspark/sql/connect/proto/expressions_pb2.py @@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x92%\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...] + b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x92%\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...] 
) @@ -65,6 +65,7 @@ _COMMONINLINEUSERDEFINEDFUNCTION = DESCRIPTOR.message_types_by_name[ "CommonInlineUserDefinedFunction" ] _PYTHONUDF = DESCRIPTOR.message_types_by_name["PythonUDF"] +_SCALARSCALAUDF = DESCRIPTOR.message_types_by_name["ScalarScalaUDF"] _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE = _EXPRESSION_WINDOW_WINDOWFRAME.enum_types_by_name[ "FrameType" ] @@ -283,6 +284,17 @@ PythonUDF = _reflection.GeneratedProtocolMessageType( ) _sym_db.RegisterMessage(PythonUDF) +ScalarScalaUDF = _reflection.GeneratedProtocolMessageType( + "ScalarScalaUDF", + (_message.Message,), + { + "DESCRIPTOR": _SCALARSCALAUDF, + "__module__": "spark.connect.expressions_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.ScalarScalaUDF) + }, +) +_sym_db.RegisterMessage(ScalarScalaUDF) + if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None @@ -332,7 +344,9 @@ if _descriptor._USE_C_DESCRIPTORS == False: _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4784 _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4846 _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 4862 - _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5098 - _PYTHONUDF._serialized_start = 5100 - _PYTHONUDF._serialized_end = 5199 + _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5173 + _PYTHONUDF._serialized_start = 5175 + _PYTHONUDF._serialized_end = 5274 + _SCALARSCALAUDF._serialized_start = 5277 + _SCALARSCALAUDF._serialized_end = 5461 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi index d8b0485017c..604672a9ad7 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi +++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi @@ -1100,6 +1100,7 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message): DETERMINISTIC_FIELD_NUMBER: builtins.int ARGUMENTS_FIELD_NUMBER: builtins.int PYTHON_UDF_FIELD_NUMBER: builtins.int + SCALAR_SCALA_UDF_FIELD_NUMBER: builtins.int function_name: builtins.str """(Required) Name of the user-defined function.""" deterministic: builtins.bool @@ -1111,6 +1112,8 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message): """(Optional) Function arguments. Empty arguments are allowed.""" @property def python_udf(self) -> global___PythonUDF: ... + @property + def scalar_scala_udf(self) -> global___ScalarScalaUDF: ... def __init__( self, *, @@ -1118,10 +1121,18 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message): deterministic: builtins.bool = ..., arguments: collections.abc.Iterable[global___Expression] | None = ..., python_udf: global___PythonUDF | None = ..., + scalar_scala_udf: global___ScalarScalaUDF | None = ..., ) -> None: ... def HasField( self, - field_name: typing_extensions.Literal["function", b"function", "python_udf", b"python_udf"], + field_name: typing_extensions.Literal[ + "function", + b"function", + "python_udf", + b"python_udf", + "scalar_scala_udf", + b"scalar_scala_udf", + ], ) -> builtins.bool: ... def ClearField( self, @@ -1136,11 +1147,13 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message): b"function_name", "python_udf", b"python_udf", + "scalar_scala_udf", + b"scalar_scala_udf", ], ) -> None: ... def WhichOneof( self, oneof_group: typing_extensions.Literal["function", b"function"] - ) -> typing_extensions.Literal["python_udf"] | None: ... + ) -> typing_extensions.Literal["python_udf", "scalar_scala_udf"] | None: ... 
global___CommonInlineUserDefinedFunction = CommonInlineUserDefinedFunction @@ -1171,3 +1184,52 @@ class PythonUDF(google.protobuf.message.Message): ) -> None: ... global___PythonUDF = PythonUDF + +class ScalarScalaUDF(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + PAYLOAD_FIELD_NUMBER: builtins.int + INPUTTYPES_FIELD_NUMBER: builtins.int + OUTPUTTYPE_FIELD_NUMBER: builtins.int + NULLABLE_FIELD_NUMBER: builtins.int + payload: builtins.bytes + """(Required) Serialized JVM object containing UDF definition, input encoders and output encoder""" + @property + def inputTypes( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ + pyspark.sql.connect.proto.types_pb2.DataType + ]: + """(Optional) Input type(s) of the UDF""" + @property + def outputType(self) -> pyspark.sql.connect.proto.types_pb2.DataType: + """(Required) Output type of the UDF""" + nullable: builtins.bool + """(Required) True if the UDF can return null value""" + def __init__( + self, + *, + payload: builtins.bytes = ..., + inputTypes: collections.abc.Iterable[pyspark.sql.connect.proto.types_pb2.DataType] + | None = ..., + outputType: pyspark.sql.connect.proto.types_pb2.DataType | None = ..., + nullable: builtins.bool = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["outputType", b"outputType"] + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "inputTypes", + b"inputTypes", + "nullable", + b"nullable", + "outputType", + b"outputType", + "payload", + b"payload", + ], + ) -> None: ... + +global___ScalarScalaUDF = ScalarScalaUDF --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org