GitHub user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/21082#discussion_r192140997
--- Diff: python/pyspark/sql/tests.py ---
@@ -5181,6 +5190,235 @@ def test_invalid_args(self):
'mixture.*aggregate function.*group aggregate pandas
UDF'):
df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
+
[email protected](
+ not _have_pandas or not _have_pyarrow,
+ _pandas_requirement_message or _pyarrow_requirement_message)
+class WindowPandasUDFTests(ReusedSQLTestCase):
+ @property
+ def data(self):
+ from pyspark.sql.functions import array, explode, col, lit
+ return self.spark.range(10).toDF('id') \
+ .withColumn("vs", array([lit(i * 1.0) + col('id') for i in
range(20, 30)])) \
+ .withColumn("v", explode(col('vs'))) \
+ .drop('vs') \
+ .withColumn('w', lit(1.0))
+
+ @property
+ def python_plus_one(self):
+ from pyspark.sql.functions import udf
+ return udf(lambda v: v + 1, 'double')
+
+ @property
+ def pandas_scalar_time_two(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+ return pandas_udf(lambda v: v * 2, 'double')
+
+ @property
+ def pandas_agg_mean_udf(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ @pandas_udf('double', PandasUDFType.GROUPED_AGG)
+ def avg(v):
+ return v.mean()
+ return avg
+
+ @property
+ def pandas_agg_max_udf(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ @pandas_udf('double', PandasUDFType.GROUPED_AGG)
+ def max(v):
+ return v.max()
+ return max
+
+ @property
+ def pandas_agg_min_udf(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ @pandas_udf('double', PandasUDFType.GROUPED_AGG)
+ def min(v):
+ return v.min()
+ return min
+
+ @property
+ def unbounded_window(self):
+ return Window.partitionBy('id') \
+ .rowsBetween(Window.unboundedPreceding,
Window.unboundedFollowing)
+
+ @property
+ def ordered_window(self):
+ return Window.partitionBy('id').orderBy('v')
+
+ @property
+ def unpartitioned_window(self):
+ return Window.partitionBy()
+
+ def test_simple(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType,
percent_rank, mean, max
+
+ df = self.data
+ w = self.unbounded_window
+
+ mean_udf = self.pandas_agg_mean_udf
+
+ result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w))
+ expected1 = df.withColumn('mean_v', mean(df['v']).over(w))
+
+ result2 = df.select(mean_udf(df['v']).over(w))
+ expected2 = df.select(mean(df['v']).over(w))
+
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+ self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+
+ def test_multiple_udfs(self):
+ from pyspark.sql.functions import max, min, mean
+
+ df = self.data
+ w = self.unbounded_window
+
+ result1 = df.withColumn('mean_v',
self.pandas_agg_mean_udf(df['v']).over(w)) \
+ .withColumn('max_v',
self.pandas_agg_max_udf(df['v']).over(w)) \
+ .withColumn('min_w',
self.pandas_agg_min_udf(df['w']).over(w)) \
--- End diff --
Trailing `\`.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]