Github user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/20171#discussion_r161414948
--- Diff: python/pyspark/sql/catalog.py ---
@@ -256,27 +258,58 @@ def registerFunction(self, name, f,
returnType=StringType()):
>>> spark.sql("SELECT stringLengthInt('test')").collect()
[Row(stringLengthInt(test)=4)]
+ >>> from pyspark.sql.types import IntegerType
+ >>> from pyspark.sql.functions import udf
+ >>> slen = udf(lambda s: len(s), IntegerType())
+ >>> _ = spark.udf.register("slen", slen)
+ >>> spark.sql("SELECT slen('test')").collect()
+ [Row(slen(test)=4)]
+
>>> import random
>>> from pyspark.sql.functions import udf
- >>> from pyspark.sql.types import IntegerType, StringType
+ >>> from pyspark.sql.types import IntegerType
>>> random_udf = udf(lambda: random.randint(0, 100),
IntegerType()).asNondeterministic()
- >>> newRandom_udf = spark.catalog.registerFunction("random_udf",
random_udf, StringType())
+ >>> newRandom_udf = spark.udf.register("random_udf", random_udf)
>>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP
- [Row(random_udf()=u'82')]
+ [Row(random_udf()=82)]
>>> spark.range(1).select(newRandom_udf()).collect() # doctest:
+SKIP
- [Row(random_udf()=u'62')]
+ [Row(<lambda>()=26)]
+
+ >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+ >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP
+ ... def add_one(x):
+ ... return x + 1
+ ...
+ >>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP
+ >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() #
doctest: +SKIP
+ [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
"""
# This is to check whether the input function is a wrapped/native
UserDefinedFunction
if hasattr(f, 'asNondeterministic'):
- udf = UserDefinedFunction(f.func, returnType=returnType,
name=name,
-
evalType=PythonEvalType.SQL_BATCHED_UDF,
- deterministic=f.deterministic)
+ if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF,
+ PythonEvalType.SQL_PANDAS_SCALAR_UDF]:
+ raise ValueError(
+ "Invalid f: f must be either SQL_BATCHED_UDF or
SQL_PANDAS_SCALAR_UDF")
+ if returnType is not None and not isinstance(returnType,
DataType):
+ returnType = _parse_datatype_string(returnType)
+ if returnType is not None and returnType != f.returnType:
--- End diff --
I mean we can simply always throw an exception if `returnType` is given
(not `None`) but `f` is a udf. I thought we were trying to resemble an overloading
of `register(name, f)`.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]