This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 7ecdad5c59c [SPARK-43995][SPARK-43996][CONNECT] Add support for UDFRegistration to the Connect Scala Client 7ecdad5c59c is described below commit 7ecdad5c59ce2eecd4686effeb10819a6d784844 Author: vicennial <venkata.gud...@databricks.com> AuthorDate: Fri Jul 14 10:52:12 2023 +0900 [SPARK-43995][SPARK-43996][CONNECT] Add support for UDFRegistration to the Connect Scala Client ### What changes were proposed in this pull request? This PR adds support to register a scala UDF from the scala/jvm client. The following APIs are implemented in `UDFRegistration`: - `def register(name: String, udf: UserDefinedFunction): UserDefinedFunction` - `def register[RT: TypeTag, A1: TypeTag ...](name: String, func: (A1, ...) => RT): UserDefinedFunction` for 0 to 22 arguments. The following API is implemented in `functions`: - `def call_udf(udfName: String, cols: Column*): Column` Note: This PR is stacked on https://github.com/apache/spark/pull/41959. ### Why are the changes needed? To reach parity with classic Spark. ### Does this PR introduce _any_ user-facing change? Yes. spark.udf.register() is added as shown below: ```scala class A(x: Int) { def get = x * 100 } val myUdf = udf((x: Int) => new A(x).get) spark.udf.register("dummyUdf", myUdf) spark.sql("select dummyUdf(id) from range(5)").as[Long].collect() ``` The output: ```scala Array[Long] = Array(0L, 100L, 200L, 300L, 400L) ``` ### How was this patch tested? New tests in `ReplE2ESuite`. Closes #41953 from vicennial/SPARK-43995. 
Authored-by: vicennial <venkata.gud...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../scala/org/apache/spark/sql/SparkSession.scala | 31 + .../org/apache/spark/sql/UDFRegistration.scala | 1028 ++++++++++++++++++++ .../sql/expressions/UserDefinedFunction.scala | 10 + .../scala/org/apache/spark/sql/functions.scala | 17 + .../spark/sql/application/ReplE2ESuite.scala | 31 + .../CheckConnectJvmClientCompatibility.scala | 1 - .../sql/connect/planner/SparkConnectPlanner.scala | 23 +- 7 files changed, 1139 insertions(+), 2 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index c27f0f32e0d..fb9959c9942 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -417,6 +417,30 @@ class SparkSession private[sql] ( range(start, end, step, Option(numPartitions)) } + /** + * A collection of methods for registering user-defined functions (UDF). + * + * The following example registers a Scala closure as UDF: + * {{{ + * sparkSession.udf.register("myUDF", (arg1: Int, arg2: String) => arg2 + arg1) + * }}} + * + * The following example registers a UDF in Java: + * {{{ + * sparkSession.udf().register("myUDF", + * (Integer arg1, String arg2) -> arg2 + arg1, + * DataTypes.StringType); + * }}} + * + * @note + * The user-defined functions must be deterministic. Due to optimization, duplicate + * invocations may be eliminated or the function may even be invoked more times than it is + * present in the query. 
+ * + * @since 3.5.0 + */ + lazy val udf: UDFRegistration = new UDFRegistration(this) + // scalastyle:off // Disable style checker so "implicits" object can start with lowercase i /** @@ -525,6 +549,13 @@ class SparkSession private[sql] ( client.execute(plan).asScala.toSeq } + private[sql] def registerUdf(udf: proto.CommonInlineUserDefinedFunction): Unit = { + val command = proto.Command.newBuilder().setRegisterFunction(udf).build() + val plan = proto.Plan.newBuilder().setCommand(command).build() + + client.execute(plan) + } + @DeveloperApi def execute(extension: com.google.protobuf.Any): Unit = { val command = proto.Command.newBuilder().setExtension(extension).build() diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/UDFRegistration.scala new file mode 100644 index 00000000000..426709b8f18 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -0,0 +1,1028 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import scala.reflect.runtime.universe.{typeTag, TypeTag} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.expressions.{ScalarUserDefinedFunction, UserDefinedFunction} + +/** + * Functions for registering user-defined functions. Use `SparkSession.udf` to access this: + * + * {{{ + * spark.udf + * }}} + * + * @since 3.5.0 + */ +class UDFRegistration(session: SparkSession) extends Logging { + + /** + * Registers a user-defined function (UDF), for a UDF that's already defined using the Dataset + * API (i.e. of type UserDefinedFunction). To change a UDF to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. To change a UDF to nonNullable, call the API + * `UserDefinedFunction.asNonNullable()`. + * + * Example: + * {{{ + * val foo = udf(() => Math.random()) + * spark.udf.register("random", foo.asNondeterministic()) + * + * val bar = udf(() => "bar") + * spark.udf.register("stringLit", bar.asNonNullable()) + * }}} + * + * @param name + * the name of the UDF. + * @param udf + * the UDF needs to be registered. + * @return + * the registered UDF. + * + * @since 3.5.0 + */ + def register(name: String, udf: UserDefinedFunction): UserDefinedFunction = { + udf.withName(name) match { + case scalarUdf: ScalarUserDefinedFunction => + session.registerUdf(scalarUdf.toProto) + scalarUdf + case other => + throw new UnsupportedOperationException( + s"Registering a UDF of type " + + s"${other.getClass.getSimpleName} is currently unsupported.") + } + } + + // scalastyle:off line.size.limit + + /* register 0-22 were generated by this script: + (0 to 22).foreach { x => + val params = (1 to x).map(num => s"A$num").mkString(", ") + val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) + println(s""" + |/** + | * Registers a deterministic Scala closure of $x arguments as user-defined function (UDF). + | * @tparam RT return type of UDF. 
+ | * @since 3.5.0 + | */ + |def register[$typeTags](name: String, func: ($params) => RT): UserDefinedFunction = { + | register(name, functions.udf(func)) + |}""".stripMargin) + } + */ + + /** + * Registers a deterministic Scala closure of 0 arguments as user-defined function (UDF). + * @tparam RT + * return type of UDF. + * @since 3.5.0 + */ + def register[RT: TypeTag](name: String, func: () => RT): UserDefinedFunction = { + val udf = ScalarUserDefinedFunction(func, typeTag[RT]) + register(name, udf) + } + + /** + * Registers a deterministic Scala closure of 1 arguments as user-defined function (UDF). + * @tparam RT + * return type of UDF. + * @since 3.5.0 + */ + def register[RT: TypeTag, A1: TypeTag](name: String, func: (A1) => RT): UserDefinedFunction = { + val udf = ScalarUserDefinedFunction(func, typeTag[RT], typeTag[A1]) + register(name, udf) + } + + /** + * Registers a deterministic Scala closure of 2 arguments as user-defined function (UDF). + * @tparam RT + * return type of UDF. + * @since 3.5.0 + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag]( + name: String, + func: (A1, A2) => RT): UserDefinedFunction = { + val udf = ScalarUserDefinedFunction(func, typeTag[RT], typeTag[A1], typeTag[A2]) + register(name, udf) + } + + /** + * Registers a deterministic Scala closure of 3 arguments as user-defined function (UDF). + * @tparam RT + * return type of UDF. + * @since 3.5.0 + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag]( + name: String, + func: (A1, A2, A3) => RT): UserDefinedFunction = { + val udf = ScalarUserDefinedFunction(func, typeTag[RT], typeTag[A1], typeTag[A2], typeTag[A3]) + register(name, udf) + } + + /** + * Registers a deterministic Scala closure of 4 arguments as user-defined function (UDF). + * @tparam RT + * return type of UDF. 
+ * @since 3.5.0 + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag]( + name: String, + func: (A1, A2, A3, A4) => RT): UserDefinedFunction = { + val udf = ScalarUserDefinedFunction( + func, + typeTag[RT], + typeTag[A1], + typeTag[A2], + typeTag[A3], + typeTag[A4]) + register(name, udf) + } + + /** + * Registers a deterministic Scala closure of 5 arguments as user-defined function (UDF). + * @tparam RT + * return type of UDF. + * @since 3.5.0 + */ + def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag]( + name: String, + func: (A1, A2, A3, A4, A5) => RT): UserDefinedFunction = { + val udf = ScalarUserDefinedFunction( + func, + typeTag[RT], + typeTag[A1], + typeTag[A2], + typeTag[A3], + typeTag[A4], + typeTag[A5]) + register(name, udf) + } + + /** + * Registers a deterministic Scala closure of 6 arguments as user-defined function (UDF). + * @tparam RT + * return type of UDF. + * @since 3.5.0 + */ + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag](name: String, func: (A1, A2, A3, A4, A5, A6) => RT): UserDefinedFunction = { + val udf = ScalarUserDefinedFunction( + func, + typeTag[RT], + typeTag[A1], + typeTag[A2], + typeTag[A3], + typeTag[A4], + typeTag[A5], + typeTag[A6]) + register(name, udf) + } + + /** + * Registers a deterministic Scala closure of 7 arguments as user-defined function (UDF). + * @tparam RT + * return type of UDF. 
+ * @since 3.5.0 + */ + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag]( + name: String, + func: (A1, A2, A3, A4, A5, A6, A7) => RT): UserDefinedFunction = { + val udf = ScalarUserDefinedFunction( + func, + typeTag[RT], + typeTag[A1], + typeTag[A2], + typeTag[A3], + typeTag[A4], + typeTag[A5], + typeTag[A6], + typeTag[A7]) + register(name, udf) + } + + /** + * Registers a deterministic Scala closure of 8 arguments as user-defined function (UDF). + * @tparam RT + * return type of UDF. + * @since 3.5.0 + */ + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag]( + name: String, + func: (A1, A2, A3, A4, A5, A6, A7, A8) => RT): UserDefinedFunction = { + val udf = ScalarUserDefinedFunction( + func, + typeTag[RT], + typeTag[A1], + typeTag[A2], + typeTag[A3], + typeTag[A4], + typeTag[A5], + typeTag[A6], + typeTag[A7], + typeTag[A8]) + register(name, udf) + } + + /** + * Registers a deterministic Scala closure of 9 arguments as user-defined function (UDF). + * @tparam RT + * return type of UDF. + * @since 3.5.0 + */ + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag]( + name: String, + func: (A1, A2, A3, A4, A5, A6, A7, A8, A9) => RT): UserDefinedFunction = { + val udf = ScalarUserDefinedFunction( + func, + typeTag[RT], + typeTag[A1], + typeTag[A2], + typeTag[A3], + typeTag[A4], + typeTag[A5], + typeTag[A6], + typeTag[A7], + typeTag[A8], + typeTag[A9]) + register(name, udf) + } + + /** + * Registers a deterministic Scala closure of 10 arguments as user-defined function (UDF). + * @tparam RT + * return type of UDF. 
+ * @since 3.5.0 + */ + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag]( + name: String, + func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10) => RT): UserDefinedFunction = { + val udf = ScalarUserDefinedFunction( + func, + typeTag[RT], + typeTag[A1], + typeTag[A2], + typeTag[A3], + typeTag[A4], + typeTag[A5], + typeTag[A6], + typeTag[A7], + typeTag[A8], + typeTag[A9], + typeTag[A10]) + register(name, udf) + } + + /** + * Registers a deterministic Scala closure of 11 arguments as user-defined function (UDF). + * @tparam RT + * return type of UDF. + * @since 3.5.0 + */ + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag, + A11: TypeTag]( + name: String, + func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11) => RT): UserDefinedFunction = { + val udf = ScalarUserDefinedFunction( + func, + typeTag[RT], + typeTag[A1], + typeTag[A2], + typeTag[A3], + typeTag[A4], + typeTag[A5], + typeTag[A6], + typeTag[A7], + typeTag[A8], + typeTag[A9], + typeTag[A10], + typeTag[A11]) + register(name, udf) + } + + /** + * Registers a deterministic Scala closure of 12 arguments as user-defined function (UDF). + * @tparam RT + * return type of UDF. 
+ * @since 3.5.0 + */ + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag, + A11: TypeTag, + A12: TypeTag]( + name: String, + func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12) => RT): UserDefinedFunction = { + val udf = ScalarUserDefinedFunction( + func, + typeTag[RT], + typeTag[A1], + typeTag[A2], + typeTag[A3], + typeTag[A4], + typeTag[A5], + typeTag[A6], + typeTag[A7], + typeTag[A8], + typeTag[A9], + typeTag[A10], + typeTag[A11], + typeTag[A12]) + register(name, udf) + } + + /** + * Registers a deterministic Scala closure of 13 arguments as user-defined function (UDF). + * @tparam RT + * return type of UDF. + * @since 3.5.0 + */ + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag, + A11: TypeTag, + A12: TypeTag, + A13: TypeTag]( + name: String, + func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13) => RT) + : UserDefinedFunction = { + val udf = ScalarUserDefinedFunction( + func, + typeTag[RT], + typeTag[A1], + typeTag[A2], + typeTag[A3], + typeTag[A4], + typeTag[A5], + typeTag[A6], + typeTag[A7], + typeTag[A8], + typeTag[A9], + typeTag[A10], + typeTag[A11], + typeTag[A12], + typeTag[A13]) + register(name, udf) + } + + /** + * Registers a deterministic Scala closure of 14 arguments as user-defined function (UDF). + * @tparam RT + * return type of UDF. 
+ * @since 3.5.0 + */ + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag, + A11: TypeTag, + A12: TypeTag, + A13: TypeTag, + A14: TypeTag]( + name: String, + func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14) => RT) + : UserDefinedFunction = { + val udf = ScalarUserDefinedFunction( + func, + typeTag[RT], + typeTag[A1], + typeTag[A2], + typeTag[A3], + typeTag[A4], + typeTag[A5], + typeTag[A6], + typeTag[A7], + typeTag[A8], + typeTag[A9], + typeTag[A10], + typeTag[A11], + typeTag[A12], + typeTag[A13], + typeTag[A14]) + register(name, udf) + } + + /** + * Registers a deterministic Scala closure of 15 arguments as user-defined function (UDF). + * @tparam RT + * return type of UDF. + * @since 3.5.0 + */ + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag, + A11: TypeTag, + A12: TypeTag, + A13: TypeTag, + A14: TypeTag, + A15: TypeTag]( + name: String, + func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15) => RT) + : UserDefinedFunction = { + val udf = ScalarUserDefinedFunction( + func, + typeTag[RT], + typeTag[A1], + typeTag[A2], + typeTag[A3], + typeTag[A4], + typeTag[A5], + typeTag[A6], + typeTag[A7], + typeTag[A8], + typeTag[A9], + typeTag[A10], + typeTag[A11], + typeTag[A12], + typeTag[A13], + typeTag[A14], + typeTag[A15]) + register(name, udf) + } + + /** + * Registers a deterministic Scala closure of 16 arguments as user-defined function (UDF). + * @tparam RT + * return type of UDF. 
+ * @since 3.5.0 + */ + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag, + A11: TypeTag, + A12: TypeTag, + A13: TypeTag, + A14: TypeTag, + A15: TypeTag, + A16: TypeTag]( + name: String, + func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16) => RT) + : UserDefinedFunction = { + val udf = ScalarUserDefinedFunction( + func, + typeTag[RT], + typeTag[A1], + typeTag[A2], + typeTag[A3], + typeTag[A4], + typeTag[A5], + typeTag[A6], + typeTag[A7], + typeTag[A8], + typeTag[A9], + typeTag[A10], + typeTag[A11], + typeTag[A12], + typeTag[A13], + typeTag[A14], + typeTag[A15], + typeTag[A16]) + register(name, udf) + } + + /** + * Registers a deterministic Scala closure of 17 arguments as user-defined function (UDF). + * @tparam RT + * return type of UDF. + * @since 3.5.0 + */ + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag, + A11: TypeTag, + A12: TypeTag, + A13: TypeTag, + A14: TypeTag, + A15: TypeTag, + A16: TypeTag, + A17: TypeTag]( + name: String, + func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17) => RT) + : UserDefinedFunction = { + val udf = ScalarUserDefinedFunction( + func, + typeTag[RT], + typeTag[A1], + typeTag[A2], + typeTag[A3], + typeTag[A4], + typeTag[A5], + typeTag[A6], + typeTag[A7], + typeTag[A8], + typeTag[A9], + typeTag[A10], + typeTag[A11], + typeTag[A12], + typeTag[A13], + typeTag[A14], + typeTag[A15], + typeTag[A16], + typeTag[A17]) + register(name, udf) + } + + /** + * Registers a deterministic Scala closure of 18 arguments as user-defined function (UDF). + * @tparam RT + * return type of UDF. 
+ * @since 3.5.0 + */ + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag, + A11: TypeTag, + A12: TypeTag, + A13: TypeTag, + A14: TypeTag, + A15: TypeTag, + A16: TypeTag, + A17: TypeTag, + A18: TypeTag]( + name: String, + func: ( + A1, + A2, + A3, + A4, + A5, + A6, + A7, + A8, + A9, + A10, + A11, + A12, + A13, + A14, + A15, + A16, + A17, + A18) => RT): UserDefinedFunction = { + val udf = ScalarUserDefinedFunction( + func, + typeTag[RT], + typeTag[A1], + typeTag[A2], + typeTag[A3], + typeTag[A4], + typeTag[A5], + typeTag[A6], + typeTag[A7], + typeTag[A8], + typeTag[A9], + typeTag[A10], + typeTag[A11], + typeTag[A12], + typeTag[A13], + typeTag[A14], + typeTag[A15], + typeTag[A16], + typeTag[A17], + typeTag[A18]) + register(name, udf) + } + + /** + * Registers a deterministic Scala closure of 19 arguments as user-defined function (UDF). + * @tparam RT + * return type of UDF. 
+ * @since 3.5.0 + */ + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag, + A11: TypeTag, + A12: TypeTag, + A13: TypeTag, + A14: TypeTag, + A15: TypeTag, + A16: TypeTag, + A17: TypeTag, + A18: TypeTag, + A19: TypeTag]( + name: String, + func: ( + A1, + A2, + A3, + A4, + A5, + A6, + A7, + A8, + A9, + A10, + A11, + A12, + A13, + A14, + A15, + A16, + A17, + A18, + A19) => RT): UserDefinedFunction = { + val udf = ScalarUserDefinedFunction( + func, + typeTag[RT], + typeTag[A1], + typeTag[A2], + typeTag[A3], + typeTag[A4], + typeTag[A5], + typeTag[A6], + typeTag[A7], + typeTag[A8], + typeTag[A9], + typeTag[A10], + typeTag[A11], + typeTag[A12], + typeTag[A13], + typeTag[A14], + typeTag[A15], + typeTag[A16], + typeTag[A17], + typeTag[A18], + typeTag[A19]) + register(name, udf) + } + + /** + * Registers a deterministic Scala closure of 20 arguments as user-defined function (UDF). + * @tparam RT + * return type of UDF. 
+ * @since 3.5.0 + */ + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag, + A11: TypeTag, + A12: TypeTag, + A13: TypeTag, + A14: TypeTag, + A15: TypeTag, + A16: TypeTag, + A17: TypeTag, + A18: TypeTag, + A19: TypeTag, + A20: TypeTag]( + name: String, + func: ( + A1, + A2, + A3, + A4, + A5, + A6, + A7, + A8, + A9, + A10, + A11, + A12, + A13, + A14, + A15, + A16, + A17, + A18, + A19, + A20) => RT): UserDefinedFunction = { + val udf = ScalarUserDefinedFunction( + func, + typeTag[RT], + typeTag[A1], + typeTag[A2], + typeTag[A3], + typeTag[A4], + typeTag[A5], + typeTag[A6], + typeTag[A7], + typeTag[A8], + typeTag[A9], + typeTag[A10], + typeTag[A11], + typeTag[A12], + typeTag[A13], + typeTag[A14], + typeTag[A15], + typeTag[A16], + typeTag[A17], + typeTag[A18], + typeTag[A19], + typeTag[A20]) + register(name, udf) + } + + /** + * Registers a deterministic Scala closure of 21 arguments as user-defined function (UDF). + * @tparam RT + * return type of UDF. 
+ * @since 3.5.0 + */ + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag, + A11: TypeTag, + A12: TypeTag, + A13: TypeTag, + A14: TypeTag, + A15: TypeTag, + A16: TypeTag, + A17: TypeTag, + A18: TypeTag, + A19: TypeTag, + A20: TypeTag, + A21: TypeTag]( + name: String, + func: ( + A1, + A2, + A3, + A4, + A5, + A6, + A7, + A8, + A9, + A10, + A11, + A12, + A13, + A14, + A15, + A16, + A17, + A18, + A19, + A20, + A21) => RT): UserDefinedFunction = { + val udf = ScalarUserDefinedFunction( + func, + typeTag[RT], + typeTag[A1], + typeTag[A2], + typeTag[A3], + typeTag[A4], + typeTag[A5], + typeTag[A6], + typeTag[A7], + typeTag[A8], + typeTag[A9], + typeTag[A10], + typeTag[A11], + typeTag[A12], + typeTag[A13], + typeTag[A14], + typeTag[A15], + typeTag[A16], + typeTag[A17], + typeTag[A18], + typeTag[A19], + typeTag[A20], + typeTag[A21]) + register(name, udf) + } + + /** + * Registers a deterministic Scala closure of 22 arguments as user-defined function (UDF). + * @tparam RT + * return type of UDF. 
+ * @since 3.5.0 + */ + def register[ + RT: TypeTag, + A1: TypeTag, + A2: TypeTag, + A3: TypeTag, + A4: TypeTag, + A5: TypeTag, + A6: TypeTag, + A7: TypeTag, + A8: TypeTag, + A9: TypeTag, + A10: TypeTag, + A11: TypeTag, + A12: TypeTag, + A13: TypeTag, + A14: TypeTag, + A15: TypeTag, + A16: TypeTag, + A17: TypeTag, + A18: TypeTag, + A19: TypeTag, + A20: TypeTag, + A21: TypeTag, + A22: TypeTag]( + name: String, + func: ( + A1, + A2, + A3, + A4, + A5, + A6, + A7, + A8, + A9, + A10, + A11, + A12, + A13, + A14, + A15, + A16, + A17, + A18, + A19, + A20, + A21, + A22) => RT): UserDefinedFunction = { + val udf = ScalarUserDefinedFunction( + func, + typeTag[RT], + typeTag[A1], + typeTag[A2], + typeTag[A3], + typeTag[A4], + typeTag[A5], + typeTag[A6], + typeTag[A7], + typeTag[A8], + typeTag[A9], + typeTag[A10], + typeTag[A11], + typeTag[A12], + typeTag[A13], + typeTag[A14], + typeTag[A15], + typeTag[A16], + typeTag[A17], + typeTag[A18], + typeTag[A19], + typeTag[A20], + typeTag[A21], + typeTag[A22]) + register(name, udf) + } + // scalastyle:on line.size.limit +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 7bce4b5b31a..18aef8a2e4c 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -130,6 +130,16 @@ case class ScalarUserDefinedFunction private ( override def asNonNullable(): ScalarUserDefinedFunction = copy(nullable = false) override def asNondeterministic(): ScalarUserDefinedFunction = copy(deterministic = false) + + def toProto: proto.CommonInlineUserDefinedFunction = { + val builder = proto.CommonInlineUserDefinedFunction.newBuilder() + builder + .setDeterministic(deterministic) + .setScalarScalaUdf(udf) + + 
name.foreach(builder.setFunctionName) + builder.build() + } } object ScalarUserDefinedFunction { diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala index b0ae4c9752a..17d1cdca350 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala @@ -7905,6 +7905,23 @@ object functions { } // scalastyle:off line.size.limit + /** + * Call an user-defined function. Example: + * {{{ + * import org.apache.spark.sql._ + * + * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") + * val spark = df.sparkSession + * spark.udf.register("simpleUDF", (v: Int) => v * v) + * df.select($"id", call_udf("simpleUDF", $"value")) + * }}} + * + * @group udf_funcs + * @since 3.5.0 + */ + @scala.annotation.varargs + def call_udf(udfName: String, cols: Column*): Column = call_function(udfName, cols: _*) + /** * Call a builtin or temp function. 
* diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala index 40841aa3b39..58758a13840 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala @@ -206,4 +206,35 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { assertContains("Array[Int] = Array(2, 2, 2, 2, 2)", output) // scalastyle:on classforname line.size.limit } + + test("UDF Registration") { + val input = """ + |class A(x: Int) { def get = x * 100 } + |val myUdf = udf((x: Int) => new A(x).get) + |spark.udf.register("dummyUdf", myUdf) + |spark.sql("select dummyUdf(id) from range(5)").as[Long].collect() + """.stripMargin + val output = runCommandsInShell(input) + assertContains("Array[Long] = Array(0L, 100L, 200L, 300L, 400L)", output) + } + + test("UDF closure registration") { + val input = """ + |class A(x: Int) { def get = x * 15 } + |spark.udf.register("directUdf", (x: Int) => new A(x).get) + |spark.sql("select directUdf(id) from range(5)").as[Long].collect() + """.stripMargin + val output = runCommandsInShell(input) + assertContains("Array[Long] = Array(0L, 15L, 30L, 45L, 60L)", output) + } + + test("call_udf") { + val input = """ + |val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") + |spark.udf.register("simpleUDF", (v: Int) => v * v) + |df.select($"id", call_udf("simpleUDF", $"value")).collect() + """.stripMargin + val output = runCommandsInShell(input) + assertContains("Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], [id3,25])", output) + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 921381caf53..130d22842b3 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -154,7 +154,6 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.SparkSessionExtensionsProvider"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UDTFRegistration"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UDFRegistration"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UDFRegistration$"), // DataFrame Reader & Writer diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 48d5e7509c3..e0bee824195 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -70,7 +70,7 @@ import org.apache.spark.sql.execution.python.{PythonForeachWriter, UserDefinedPy import org.apache.spark.sql.execution.stat.StatFunctions import org.apache.spark.sql.execution.streaming.GroupStateImpl.groupStateTimeoutFromString import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper -import org.apache.spark.sql.expressions.ReduceAggregator +import org.apache.spark.sql.expressions.{ReduceAggregator, SparkUserDefinedFunction} import org.apache.spark.sql.internal.{CatalogImpl, TypedAggUtils} import org.apache.spark.sql.protobuf.{CatalystDataToProtobuf, ProtobufDataToCatalyst} import 
org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StreamingQuery, StreamingQueryListener, StreamingQueryProgress, Trigger} @@ -1487,6 +1487,20 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { udfDeterministic = fun.getDeterministic) } + private def transformScalarScalaFunction( + fun: proto.CommonInlineUserDefinedFunction): SparkUserDefinedFunction = { + val udf = fun.getScalarScalaUdf + val udfPacket = unpackUdf(fun) + SparkUserDefinedFunction( + f = udfPacket.function, + dataType = transformDataType(udf.getOutputType), + inputEncoders = udfPacket.inputEncoders.map(e => Try(ExpressionEncoder(e)).toOption), + outputEncoder = Option(ExpressionEncoder(udfPacket.outputEncoder)), + name = Option(fun.getFunctionName), + nullable = udf.getNullable, + deterministic = fun.getDeterministic) + } + /** * Translates a Python user-defined function from proto to the Catalyst expression. * @@ -2415,6 +2429,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { handleRegisterPythonUDF(fun) case proto.CommonInlineUserDefinedFunction.FunctionCase.JAVA_UDF => handleRegisterJavaUDF(fun) + case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF => + handleRegisterScalarScalaUDF(fun) case _ => throw InvalidPlanInput( s"Function with ID: ${fun.getFunctionCase.getNumber} is not supported") @@ -2448,6 +2464,11 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { } } + private def handleRegisterScalarScalaUDF(fun: proto.CommonInlineUserDefinedFunction): Unit = { + val udf = transformScalarScalaFunction(fun) + session.udf.register(fun.getFunctionName, udf) + } + private def handleCommandPlugin(extension: ProtoAny): Unit = { SparkConnectPluginRegistry.commandRegistry // Lazily traverse the collection. 
--------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org