GitHub user ueshin commented on a diff in the pull request:
https://github.com/apache/spark/pull/20171#discussion_r161711905
--- Diff: python/pyspark/sql/tests.py ---
@@ -3975,33 +4003,50 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self):
         finally:
             self.spark.conf.set("spark.sql.session.timeZone", orig_tz)

-    def test_nondeterministic_udf(self):
+    def test_nondeterministic_vectorized_udf(self):
         # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
         from pyspark.sql.functions import udf, pandas_udf, col

         @pandas_udf('double')
         def plus_ten(v):
             return v + 10
-        random_udf = self.random_udf
+        random_udf = self.nondeterministic_vectorized_udf

         df = self.spark.range(10).withColumn('rand', random_udf(col('id')))
         result1 = df.withColumn('plus_ten(rand)', plus_ten(df['rand'])).toPandas()

         self.assertEqual(random_udf.deterministic, False)
         self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10))

-    def test_nondeterministic_udf_in_aggregate(self):
+    def test_nondeterministic_vectorized_udf_in_aggregate(self):
         from pyspark.sql.functions import pandas_udf, sum

         df = self.spark.range(10)
-        random_udf = self.random_udf
+        random_udf = self.nondeterministic_vectorized_udf

         with QuietTest(self.sc):
             with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'):
                 df.groupby(df.id).agg(sum(random_udf(df.id))).collect()
             with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'):
                 df.agg(sum(random_udf(df.id))).collect()

+    def test_register_vectorized_udf_basic(self):
+        from pyspark.rdd import PythonEvalType
+        from pyspark.sql.functions import pandas_udf, col, expr
+
+        df = self.spark.range(10).select(
+            col('id').cast('int').alias('a'),
+            col('id').cast('int').alias('b'))
+        original_add = pandas_udf(lambda x, y: x + y, IntegerType())
+        self.assertEqual(original_add.deterministic, True)
+        self.assertEqual(original_add.evalType, PythonEvalType.SQL_PANDAS_SCALAR_UDF)
+        new_add = self.spark.catalog.registerFunction("add1", original_add)
--- End diff --
`spark.udf.register` instead of `spark.catalog.registerFunction`?
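
Just for illustration, a minimal sketch of what that could look like (not the PR's code; it assumes an active SparkSession named `spark` and reuses the `add1`/`original_add` names from the diff above):

```python
from pyspark.rdd import PythonEvalType
from pyspark.sql.functions import pandas_udf, col
from pyspark.sql.types import IntegerType

# Same vectorized UDF as in the diff above.
original_add = pandas_udf(lambda x, y: x + y, IntegerType())

# spark.udf.register returns the registered function, so the result can be used
# both from SQL (under the name "add1") and directly on DataFrames.
new_add = spark.udf.register("add1", original_add)
assert new_add.evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF

df = spark.range(10).select(
    col('id').cast('int').alias('a'),
    col('id').cast('int').alias('b'))
df.createOrReplaceTempView("t")
spark.sql("SELECT add1(a, b) AS s FROM t").show()
df.select(new_add(col('a'), col('b')).alias('s')).show()
```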
---