Github user viirya commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20288#discussion_r161966469

    --- Diff: python/pyspark/sql/session.py ---
    @@ -778,6 +778,146 @@ def __exit__(self, exc_type, exc_val, exc_tb):
                 self.stop()
    +class UDFRegistration(object):
    +    """Wrapper for user-defined function registration."""
    +
    +    def __init__(self, sparkSession):
    +        self.sparkSession = sparkSession
    +
    +    @ignore_unicode_prefix
    +    def register(self, name, f, returnType=None):
    +        """Registers a Python function (including lambda function) or a user-defined function
    +        in SQL statements.
    +
    +        :param name: name of the user-defined function in SQL statements.
    +        :param f: a Python function, or a user-defined function. The user-defined function can
    +            be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and
    +            :meth:`pyspark.sql.functions.pandas_udf`.
    +        :param returnType: the return type of the registered user-defined function.
    +        :return: a user-defined function.
    +
    +        `returnType` can be optionally specified when `f` is a Python function but not
    +        when `f` is a user-defined function. See below:
    +
    +        1. When `f` is a Python function, `returnType` defaults to string type and can be
    +        optionally specified. The produced object must match the specified type. In this case,
    +        this API works as if `register(name, f, returnType=StringType())`.
    +
    +        >>> strlen = spark.udf.register("stringLengthString", lambda x: len(x))
    +        >>> spark.sql("SELECT stringLengthString('test')").collect()
    +        [Row(stringLengthString(test)=u'4')]
    +
    +        >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
    +        [Row(stringLengthString(text)=u'3')]
    +
    +        >>> from pyspark.sql.types import IntegerType
    +        >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
    +        >>> spark.sql("SELECT stringLengthInt('test')").collect()
    +        [Row(stringLengthInt(test)=4)]
    +
    +        >>> from pyspark.sql.types import IntegerType
    +        >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
    +        >>> spark.sql("SELECT stringLengthInt('test')").collect()
    +        [Row(stringLengthInt(test)=4)]
    +
    +        2. When `f` is a user-defined function, Spark uses the return type of the given a
    +        user-defined function as the return type of the registered a user-defined function.
    --- End diff --

    the registered a user-defined function -> the registered user-defined function
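The `returnType` rule in the quoted docstring can be modeled without a running Spark session. Below is a minimal, hypothetical pure-Python sketch. `DataType`, `StringType`, `IntegerType`, `SketchUDF` and `UDFRegistrySketch` are illustrative stand-ins, not the `pyspark.sql` classes. The sketch also assumes that "not when `f` is a user-defined function" is enforced by raising `TypeError`; that enforcement is an assumption, not something the diff shows.

```python
# Hypothetical, Spark-free model of the returnType rule in UDFRegistration.register.
# None of these classes are the real pyspark.sql classes.


class DataType:
    """Stand-in for a Spark SQL data type; instances of the same subclass compare equal."""

    def __eq__(self, other):
        return type(self) is type(other)

    def __hash__(self):
        return hash(type(self))

    def __repr__(self):
        return type(self).__name__ + "()"


class StringType(DataType):
    pass


class IntegerType(DataType):
    pass


class SketchUDF:
    """Stand-in for a user-defined function that already carries its return type."""

    def __init__(self, func, returnType):
        self.func = func
        self.returnType = returnType

    def __call__(self, *args):
        return self.func(*args)


class UDFRegistrySketch:
    """Stand-in registry mapping SQL names to SketchUDF objects."""

    def __init__(self):
        self.functions = {}

    def register(self, name, f, returnType=None):
        if isinstance(f, SketchUDF):
            # Case 2: reuse the UDF's own return type. Rejecting an explicit
            # returnType with TypeError is an assumption made for this sketch.
            if returnType is not None:
                raise TypeError("returnType must not be set when f is a user-defined function")
            registered = SketchUDF(f.func, f.returnType)
        else:
            # Case 1: plain Python function; returnType defaults to StringType().
            if returnType is None:
                returnType = StringType()
            registered = SketchUDF(f, returnType)
        self.functions[name] = registered
        return registered


registry = UDFRegistrySketch()
strlen = registry.register("stringLengthString", lambda x: len(x))
print(strlen.returnType)  # StringType()
```

Registering a `SketchUDF` built with `IntegerType()` keeps `IntegerType()` as the registered type. This mirrors item 2 of the docstring, where the given user-defined function's return type is reused.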