Github user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/20288#discussion_r162035758
--- Diff: python/pyspark/sql/udf.py ---
@@ -181,3 +183,179 @@ def asNondeterministic(self):
"""
self.deterministic = False
return self
+
+
+class UDFRegistration(object):
+ """
+ Wrapper for user-defined function registration.
+
+ .. versionadded:: 1.3.1
+ """
+
+ def __init__(self, sparkSession):
+ self.sparkSession = sparkSession
+
+ @ignore_unicode_prefix
+ @since(1.3)
+ def register(self, name, f, returnType=None):
+ """Registers a Python function (including lambda function) or a user-defined function
+ in SQL statements.
+
+ :param name: name of the user-defined function in SQL statements.
+ :param f: a Python function, or a user-defined function. The user-defined function can
+ be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and
+ :meth:`pyspark.sql.functions.pandas_udf`.
+ :param returnType: the return type of the registered user-defined function.
+ :return: a user-defined function.
+
+ `returnType` can be optionally specified when `f` is a Python function but not
+ when `f` is a user-defined function. Please see below.
+
+ 1. When `f` is a Python function:
+
+ `returnType` defaults to string type and can be optionally specified. The produced
+ object must match the specified type. When `returnType` is omitted, this API works as if
+ `register(name, f, returnType=StringType())`.
+
+ >>> strlen = spark.udf.register("stringLengthString", lambda x: len(x))
+ >>> spark.sql("SELECT stringLengthString('test')").collect()
+ [Row(stringLengthString(test)=u'4')]
+
+ >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
+ [Row(stringLengthString(text)=u'3')]
+
+ >>> from pyspark.sql.types import IntegerType
+ >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
+ >>> spark.sql("SELECT stringLengthInt('test')").collect()
+ [Row(stringLengthInt(test)=4)]
+
+ 2. When `f` is a user-defined function:
+
+ Spark uses the return type of the given user-defined function as the return type of
+ the registered user-defined function. `returnType` should not be specified.
+ In this case, this API works as if `register(name, f)`.
+
+ >>> from pyspark.sql.types import IntegerType
+ >>> from pyspark.sql.functions import udf
+ >>> slen = udf(lambda s: len(s), IntegerType())
+ >>> _ = spark.udf.register("slen", slen)
+ >>> spark.sql("SELECT slen('test')").collect()
+ [Row(slen(test)=4)]
+
+ >>> import random
+ >>> from pyspark.sql.functions import udf
+ >>> from pyspark.sql.types import IntegerType
+ >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
+ >>> new_random_udf = spark.udf.register("random_udf", random_udf)
+ >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP
+ [Row(random_udf()=82)]
+
+ >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+ >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP
+ ... def add_one(x):
+ ... return x + 1
+ ...
+ >>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP
+ >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP
+ [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
+
+ .. note:: Registration for a user-defined function (case 2.) was added in
+ Spark 2.3.0.
+ """
--- End diff --
<img width="715" alt="2018-01-17 9 23 21"
src="https://user-images.githubusercontent.com/6477701/35042729-1acaa234-fbcd-11e7-9d3f-4e94dc200e2c.png">
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]