Github user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/19872#discussion_r162047047
--- Diff: python/pyspark/sql/tests.py ---
@@ -4279,6 +4272,425 @@ def test_unsupported_types(self):
df.groupby('id').apply(f).collect()
+@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+class GroupbyAggPandasUDFTests(ReusedSQLTestCase):
+
+ @property
+ def data(self):
+ from pyspark.sql.functions import array, explode, col, lit
+ return self.spark.range(10).toDF('id') \
+ .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \
+ .withColumn("v", explode(col('vs'))) \
+ .drop('vs') \
+ .withColumn('w', lit(1.0))
+
+ @property
+ def python_plus_one(self):
+ from pyspark.sql.functions import udf
+
+ @udf('double')
+ def plus_one(v):
+ assert isinstance(v, (int, float))
+ return v + 1
+ return plus_one
+
+ @property
+ def pandas_scalar_plus_two(self):
+ import pandas as pd
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ @pandas_udf('double', PandasUDFType.SCALAR)
+ def plus_two(v):
+ assert isinstance(v, pd.Series)
+ return v + 2
+ return plus_two
+
+ @property
+ def mean_udf(self):
--- End diff --
Shall we add a prefix here too, like `python_plus_one` and `pandas_scalar_plus_two` above (e.g. `pandas_agg_mean_udf`)?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]