Github user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/20288#discussion_r162031680
--- Diff: python/pyspark/sql/udf.py ---
@@ -181,3 +183,180 @@ def asNondeterministic(self):
"""
self.deterministic = False
return self
+
+
+class UDFRegistration(object):
+    """
+    Wrapper for user-defined function registration.
+
+    .. versionadded:: 1.3.1
+    """
+
+    def __init__(self, sparkSession):
+        self.sparkSession = sparkSession
+
+    @ignore_unicode_prefix
+    @since(1.3)
+    def register(self, name, f, returnType=None):
+        """Registers a Python function (including lambda function) or a user-defined function
+        in SQL statements.
+
+        :param name: name of the user-defined function in SQL statements.
+        :param f: a Python function, or a user-defined function. The user-defined function can
+            be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and
+            :meth:`pyspark.sql.functions.pandas_udf`.
+        :param returnType: the return type of the registered user-defined function.
+        :return: a user-defined function.
+
+        `returnType` can be optionally specified when `f` is a Python function but not
+        when `f` is a user-defined function. See below:
+
+        1. When `f` is a Python function:
+
+            `returnType` defaults to string type and can be optionally specified. The produced
+            object must match the specified type. In this case, this API works as if
+            `register(name, f, returnType=StringType())`.
+
+            >>> strlen = spark.udf.register("stringLengthString", lambda x: len(x))
+            >>> spark.sql("SELECT stringLengthString('test')").collect()
+            [Row(stringLengthString(test)=u'4')]
+
+            >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
+            [Row(stringLengthString(text)=u'3')]
+
+            >>> from pyspark.sql.types import IntegerType
+            >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
+            >>> spark.sql("SELECT stringLengthInt('test')").collect()
+            [Row(stringLengthInt(test)=4)]
+
+        2. When `f` is a user-defined function:
+
+            Spark uses the return type of the given user-defined function as the return type of
+            the registered user-defined function. `returnType` should not be specified.
+            In this case, this API works as if `register(name, f)`.
+
+            >>> from pyspark.sql.types import IntegerType
+            >>> from pyspark.sql.functions import udf
+            >>> slen = udf(lambda s: len(s), IntegerType())
+            >>> _ = spark.udf.register("slen", slen)
+            >>> spark.sql("SELECT slen('test')").collect()
+            [Row(slen(test)=4)]
+
+            >>> import random
+            >>> from pyspark.sql.functions import udf
+            >>> from pyspark.sql.types import IntegerType
+            >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
+            >>> new_random_udf = spark.udf.register("random_udf", random_udf)
+            >>> spark.sql("SELECT random_udf()").collect()  # doctest: +SKIP
+            [Row(random_udf()=82)]
+
+            >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+            >>> @pandas_udf("integer", PandasUDFType.SCALAR)  # doctest: +SKIP
+            ... def add_one(x):
+            ...     return x + 1
+            ...
+            >>> _ = spark.udf.register("add_one", add_one)  # doctest: +SKIP
+            >>> spark.sql("SELECT add_one(id) FROM range(3)").collect()  # doctest: +SKIP
+            [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
+
+        .. note:: Registration for a user-defined function (case 2.) was added in
+            Spark 2.3.0.
+        """
+
+        # Check whether the input function is a user-defined function or a
+        # plain Python function.
+        if hasattr(f, 'asNondeterministic'):
+            if returnType is not None:
+                raise TypeError(
+                    "Invalid returnType: data type can not be specified when f is "
+                    "a user-defined function, but got %s." % returnType)
+            if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF,
+                                  PythonEvalType.SQL_PANDAS_SCALAR_UDF]:
+                raise ValueError(
+                    "Invalid f: f must be either SQL_BATCHED_UDF or SQL_PANDAS_SCALAR_UDF")
+            register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name,
+                                               evalType=f.evalType,
+                                               deterministic=f.deterministic)
+            return_udf = f
+        else:
+            if returnType is None:
+                returnType = StringType()
+            register_udf = UserDefinedFunction(f, returnType=returnType, name=name,
+                                               evalType=PythonEvalType.SQL_BATCHED_UDF)
+            return_udf = register_udf._wrapped()
+        self.sparkSession._jsparkSession.udf().registerPython(name, register_udf._judf)
+        return return_udf
+
+    @ignore_unicode_prefix
+    @since(2.3)
+    def registerJavaFunction(self, name, javaClassName, returnType=None):
--- End diff --
`registerJavaFunction` and `registerJavaUDAF` look like they were introduced in 2.3.0 -
https://issues.apache.org/jira/browse/SPARK-19439
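
For context, a minimal usage sketch of the two methods those `@since` annotations should
cover. The Java class names below are placeholders for illustration (a JAR providing
them, e.g. classes implementing `UDF1` / a UDAF, has to be on the classpath):

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.getOrCreate()

# Added in Spark 2.3.0 (SPARK-19439): register a UDF implemented in Java,
# referenced by its fully qualified class name. returnType is optional.
spark.udf.registerJavaFunction(
    "javaStringLength",              # name to use in SQL statements
    "org.example.JavaStringLength",  # placeholder Java UDF class
    IntegerType())

# Also added in Spark 2.3.0: register a Java UDAF by class name.
spark.udf.registerJavaUDAF("myAvg", "org.example.MyDoubleAvg")  # placeholder class

spark.sql("SELECT javaStringLength('test')").show()
```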
---