This is an automated email from the ASF dual-hosted git repository. hvanhovell pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new 0c27cb702c7 [SPARK-44715][CONNECT] Bring back callUdf and udf function 0c27cb702c7 is described below commit 0c27cb702c7b41b2518517d16d6d4108c6841271 Author: Herman van Hovell <her...@databricks.com> AuthorDate: Tue Aug 8 15:41:36 2023 +0200 [SPARK-44715][CONNECT] Bring back callUdf and udf function ### What changes were proposed in this pull request? This PR adds the `udf` (with a return type), and `callUDF` functions to `functions.scala` for the Spark Connect Scala Client. ### Why are the changes needed? We want the Spark Connect Scala Client to be as compatible as possible with the existing sql/core APIs. ### Does this PR introduce _any_ user-facing change? Yes. It adds more exposed functions. ### How was this patch tested? Added tests to `UserDefinedFunctionE2ETestSuite` and `FunctionTestSuite`. I have also updated the compatibility checks. Closes #42387 from hvanhovell/SPARK-44715. Authored-by: Herman van Hovell <her...@databricks.com> Signed-off-by: Herman van Hovell <her...@databricks.com> (cherry picked from commit 8c444f497137d5abb3a94b576ec0fea55dc18bbc) Signed-off-by: Herman van Hovell <her...@databricks.com> --- .../scala/org/apache/spark/sql/functions.scala | 40 ++++++++++++++++++++++ .../org/apache/spark/sql/FunctionTestSuite.scala | 2 ++ .../sql/UserDefinedFunctionE2ETestSuite.scala | 20 +++++++++++ .../CheckConnectJvmClientCompatibility.scala | 7 ---- 4 files changed, 62 insertions(+), 7 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala index 89bfc998179..fa8c5782e06 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala @@ -8056,6 +8056,46 @@ object functions { } // scalastyle:off line.size.limit + /** + * Defines a deterministic user-defined function (UDF) using a Scala closure. For this variant, + * the caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the API + * `UserDefinedFunction.asNondeterministic()`. + * + * Note that, although the Scala closure can have primitive-type function argument, it doesn't + * work well with null values. Because the Scala closure is passed in as Any type, there is no + * type information for the function arguments. Without the type information, Spark may blindly + * pass null to the Scala closure with primitive-type argument, and the closure will see the + * default value of the Java type for the null argument, e.g. `udf((x: Int) => x, IntegerType)`, + * the result is 0 for null input. + * + * @param f + * A closure in Scala + * @param dataType + * The output data type of the UDF + * + * @group udf_funcs + * @since 3.5.0 + */ + @deprecated( + "Scala `udf` method with return type parameter is deprecated. " + + "Please use Scala `udf` method without return type parameter.", + "3.0.0") + def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = { + ScalarUserDefinedFunction(f, dataType) + } + + /** + * Call an user-defined function. + * + * @group udf_funcs + * @since 3.5.0 + */ + @scala.annotation.varargs + @deprecated("Use call_udf") + def callUDF(udfName: String, cols: Column*): Column = + call_function(udfName, cols: _*) + /** * Call an user-defined function. Example: * {{{ diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala index 32004b6bcc1..4a8e108357f 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala @@ -249,6 +249,8 @@ class FunctionTestSuite extends ConnectFunSuite { pbFn.to_protobuf(a, "FakeMessage", "fakeBytes".getBytes(), Map.empty[String, String].asJava), pbFn.to_protobuf(a, "FakeMessage", "fakeBytes".getBytes())) + testEquals("call_udf", callUDF("bob", lit(1)), call_udf("bob", lit(1))) + test("assert_true no message") { val e = assert_true(a).expr assert(e.hasUnresolvedFunction) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala index 258fa1e7c74..3a931c9a6ba 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala @@ -24,9 +24,11 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.JavaConverters._ import org.apache.spark.api.java.function._ +import org.apache.spark.sql.api.java.UDF2 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveIntEncoder, PrimitiveLongEncoder} import org.apache.spark.sql.connect.client.util.QueryTest import org.apache.spark.sql.functions.{col, struct, udf} +import org.apache.spark.sql.types.IntegerType /** * All tests in this class requires client UDF defined in this test class synced with the server. @@ -250,4 +252,22 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest { "b", "c") } + + test("(deprecated) scala UDF with dataType") { + val session: SparkSession = spark + import session.implicits._ + val fn = udf(((i: Long) => (i + 1).toInt), IntegerType) + checkDataset(session.range(2).select(fn($"id")).as[Int], 1, 2) + } + + test("java UDF") { + val session: SparkSession = spark + import session.implicits._ + val fn = udf( + new UDF2[Long, Long, Int] { + override def call(t1: Long, t2: Long): Int = (t1 + t2 + 1).toInt + }, + IntegerType) + checkDataset(session.range(2).select(fn($"id", $"id" + 2)).as[Int], 3, 5) + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 2bf9c41fb2c..d380a1bbb65 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -191,8 +191,6 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.javaRDD"), // functions - ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udf"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.callUDF"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.unwrap_udt"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udaf"), @@ -214,14 +212,11 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.listenerManager"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.experimental"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.udtf"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.streams"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.createDataFrame"), ProblemFilters.exclude[Problem]( "org.apache.spark.sql.SparkSession.baseRelationToDataFrame"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.createDataset"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.executeCommand"), - // TODO(SPARK-44068): Support positional parameters in Scala connect client - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sql"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.this"), // SparkSession#implicits @@ -266,8 +261,6 @@ object CheckConnectJvmClientCompatibility { "org.apache.spark.sql.streaming.StreamingQueryException.time"), // Classes missing from streaming API - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ForeachWriter"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.GroupState"), ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.streaming.TestGroupState"), ProblemFilters.exclude[MissingClassProblem]( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org