GitHub user viirya commented on a diff in the pull request:
https://github.com/apache/spark/pull/20288#discussion_r161966507
--- Diff: python/pyspark/sql/session.py ---
@@ -778,6 +778,146 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()
+class UDFRegistration(object):
+ """Wrapper for user-defined function registration."""
+
+ def __init__(self, sparkSession):
+ self.sparkSession = sparkSession
+
+ @ignore_unicode_prefix
+ def register(self, name, f, returnType=None):
+ """Registers a Python function (including lambda function) or a user-defined function
+ in SQL statements.
+
+ :param name: name of the user-defined function in SQL statements.
+ :param f: a Python function, or a user-defined function. The user-defined function can
+ be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and
+ :meth:`pyspark.sql.functions.pandas_udf`.
+ :param returnType: the return type of the registered user-defined function.
+ :return: a user-defined function.
+
+ `returnType` can be optionally specified when `f` is a Python function but not
+ when `f` is a user-defined function. See below:
+
+ 1. When `f` is a Python function, `returnType` defaults to string type and can be
+ optionally specified. The produced object must match the specified type. In this case,
+ this API works as if `register(name, f, returnType=StringType())`.
+
+ >>> strlen = spark.udf.register("stringLengthString", lambda x: len(x))
+ >>> spark.sql("SELECT stringLengthString('test')").collect()
+ [Row(stringLengthString(test)=u'4')]
+
+ >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
+ [Row(stringLengthString(text)=u'3')]
+
+ >>> from pyspark.sql.types import IntegerType
+ >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
+ >>> spark.sql("SELECT stringLengthInt('test')").collect()
+ [Row(stringLengthInt(test)=4)]
+
+ >>> from pyspark.sql.types import IntegerType
+ >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
+ >>> spark.sql("SELECT stringLengthInt('test')").collect()
+ [Row(stringLengthInt(test)=4)]
+
+ 2. When `f` is a user-defined function, Spark uses the return type of the given a
--- End diff --
`of the given a user-defined function` -> `of the given user-defined function` (stray "a").
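For context on the two cases this docstring describes, here is a minimal pure-Python sketch of the return-type rules. It does not use PySpark: `ToyUDF`, `ToyRegistry`, and the `STRING`/`INTEGER` names are hypothetical stand-ins for PySpark's UDF objects, `spark.udf.register`, and `StringType()`/`IntegerType()`. Raising `TypeError` for an explicit return type in case 2 is an assumption about how the rejection surfaces, not a quote of the PR.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional

# Hypothetical type names standing in for PySpark's StringType() and IntegerType().
STRING = "string"
INTEGER = "integer"

# Hypothetical converters that coerce a raw Python result to the declared type.
_CONVERTERS: Dict[str, Callable[[Any], Any]] = {STRING: str, INTEGER: int}


@dataclass
class ToyUDF:
    """Hypothetical stand-in for a user-defined function that carries its own return type."""

    func: Callable[..., Any]
    return_type: str

    def __call__(self, *args: Any) -> Any:
        # Run the wrapped function, then coerce its result to the declared type.
        return _CONVERTERS[self.return_type](self.func(*args))


class ToyRegistry:
    """Hypothetical registry modelling the two returnType rules in the docstring."""

    def __init__(self) -> None:
        self._functions: Dict[str, ToyUDF] = {}

    def register(self, name: str, f: Any, return_type: Optional[str] = None) -> ToyUDF:
        if isinstance(f, ToyUDF):
            # Case 2: f already has a return type, so an explicit one is rejected.
            if return_type is not None:
                raise TypeError(
                    "return type cannot be specified when f is a user-defined function")
            udf = f
        else:
            # Case 1: a plain Python function defaults to the string type.
            udf = ToyUDF(f, return_type if return_type is not None else STRING)
        self._functions[name] = udf
        return udf

    def call(self, name: str, *args: Any) -> Any:
        # Look up a registered function by name and apply it, like calling it from SQL.
        return self._functions[name](*args)


if __name__ == "__main__":
    registry = ToyRegistry()
    registry.register("stringLengthString", lambda x: len(x))
    print(repr(registry.call("stringLengthString", "test")))  # '4'
    registry.register("stringLengthInt", lambda x: len(x), INTEGER)
    print(repr(registry.call("stringLengthInt", "test")))  # 4
```

In case 2 the function's own return type is used as-is, which is why the sketch refuses a second, possibly conflicting, type instead of silently overriding it.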