Repository: spark
Updated Branches:
  refs/heads/master f561a76b2 -> d099f414d


[SPARK-20674][SQL] Support registering UserDefinedFunction as named UDF

## What changes were proposed in this pull request?
For some reason we don't have an API to register a UserDefinedFunction as a
named UDF. It is a no-brainer to add one, in addition to the existing
`register` functions we have.

## How was this patch tested?
Added a test case in UDFSuite for the new API.

Author: Reynold Xin <[email protected]>

Closes #17915 from rxin/SPARK-20674.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d099f414
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d099f414
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d099f414

Branch: refs/heads/master
Commit: d099f414d2cb53f5a61f6e77317c736be6f953a0
Parents: f561a76
Author: Reynold Xin <[email protected]>
Authored: Tue May 9 09:24:28 2017 -0700
Committer: Xiao Li <[email protected]>
Committed: Tue May 9 09:24:28 2017 -0700

----------------------------------------------------------------------
 .../org/apache/spark/sql/UDFRegistration.scala  | 22 +++++++++++++++++---
 .../scala/org/apache/spark/sql/UDFSuite.scala   |  7 +++++++
 2 files changed, 26 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d099f414/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index a576733..6accf1f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -70,15 +70,31 @@ class UDFRegistration private[sql] (functionRegistry: 
FunctionRegistry) extends
    * @param name the name of the UDAF.
    * @param udaf the UDAF needs to be registered.
    * @return the registered UDAF.
+   *
+   * @since 1.5.0
    */
-  def register(
-      name: String,
-      udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = {
+  def register(name: String, udaf: UserDefinedAggregateFunction): 
UserDefinedAggregateFunction = {
     def builder(children: Seq[Expression]) = ScalaUDAF(children, udaf)
     functionRegistry.registerFunction(name, builder)
     udaf
   }
 
+  /**
+   * Register a user-defined function (UDF), for a UDF that's already defined 
using the DataFrame
+   * API (i.e. of type UserDefinedFunction).
+   *
+   * @param name the name of the UDF.
+   * @param udf the UDF needs to be registered.
+   * @return the registered UDF.
+   *
+   * @since 2.2.0
+   */
+  def register(name: String, udf: UserDefinedFunction): UserDefinedFunction = {
+    def builder(children: Seq[Expression]) = 
udf.apply(children.map(Column.apply) : _*).expr
+    functionRegistry.registerFunction(name, builder)
+    udf
+  }
+
   // scalastyle:off line.size.limit
 
   /* register 0-22 were generated by this script

http://git-wip-us.apache.org/repos/asf/spark/blob/d099f414/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index ae6b2bc..6f8723a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -93,6 +93,13 @@ class UDFSuite extends QueryTest with SharedSQLContext {
     assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4)
   }
 
+  test("UDF defined using UserDefinedFunction") {
+    import functions.udf
+    val foo = udf((x: Int) => x + 1)
+    spark.udf.register("foo", foo)
+    assert(sql("select foo(5)").head().getInt(0) == 6)
+  }
+
   test("ZeroArgument UDF") {
     spark.udf.register("random0", () => { Math.random()})
     assert(sql("SELECT random0()").head().getDouble(0) >= 0.0)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to