GitHub user icexelloss commented on a diff in the pull request:
https://github.com/apache/spark/pull/19872#discussion_r165385572
--- Diff: python/pyspark/sql/tests.py ---
@@ -4353,6 +4347,446 @@ def test_unsupported_types(self):
df.groupby('id').apply(f).collect()
[email protected](not _have_pandas or not _have_arrow, "Pandas or Arrow not
installed")
+class GroupbyAggPandasUDFTests(ReusedSQLTestCase):
+
+ @property
+ def data(self):
+ from pyspark.sql.functions import array, explode, col, lit
+ return self.spark.range(10).toDF('id') \
+ .withColumn("vs", array([lit(i * 1.0) + col('id') for i in
range(20, 30)])) \
+ .withColumn("v", explode(col('vs'))) \
+ .drop('vs') \
+ .withColumn('w', lit(1.0))
+
+ @property
+ def python_plus_one(self):
+ from pyspark.sql.functions import udf
+
+ @udf('double')
+ def plus_one(v):
+ assert isinstance(v, (int, float))
+ return v + 1
+ return plus_one
+
+ @property
+ def pandas_scalar_plus_two(self):
+ import pandas as pd
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ @pandas_udf('double', PandasUDFType.SCALAR)
+ def plus_two(v):
+ assert isinstance(v, pd.Series)
+ return v + 2
+ return plus_two
+
+ @property
+ def pandas_agg_mean_udf(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ @pandas_udf('double', PandasUDFType.GROUP_AGG)
+ def avg(v):
+ return v.mean()
+ return avg
+
+ @property
+ def pandas_agg_sum_udf(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ @pandas_udf('double', PandasUDFType.GROUP_AGG)
+ def sum(v):
+ return v.sum()
+ return sum
+
+ @property
+ def pandas_agg_weighted_mean_udf(self):
+ import numpy as np
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ @pandas_udf('double', PandasUDFType.GROUP_AGG)
+ def weighted_mean(v, w):
+ return np.average(v, weights=w)
+ return weighted_mean
+
+ def test_manual(self):
+ df = self.data
+ sum_udf = self.pandas_agg_sum_udf
+ mean_udf = self.pandas_agg_mean_udf
+
+ result1 = df.groupby('id').agg(sum_udf(df.v),
mean_udf(df.v)).sort('id')
+ expected1 = self.spark.createDataFrame(
+ [[0, 245.0, 24.5],
+ [1, 255.0, 25.5],
+ [2, 265.0, 26.5],
+ [3, 275.0, 27.5],
+ [4, 285.0, 28.5],
+ [5, 295.0, 29.5],
+ [6, 305.0, 30.5],
+ [7, 315.0, 31.5],
+ [8, 325.0, 32.5],
+ [9, 335.0, 33.5]],
+ ['id', 'sum(v)', 'avg(v)'])
+
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+ def test_basic(self):
+ from pyspark.sql.functions import col, lit, sum, mean
+
+ df = self.data
+ weighted_mean_udf = self.pandas_agg_weighted_mean_udf
+
+ # Groupby one column and aggregate one UDF with literal
+ result1 = df.groupby('id').agg(weighted_mean_udf(df.v,
lit(1.0))).sort('id')
+ expected1 =
df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort('id')
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+ # Groupby one expression and aggregate one UDF with literal
+ result2 = df.groupby((col('id') + 1)).agg(weighted_mean_udf(df.v,
lit(1.0)))\
+ .sort(df.id + 1)
+ expected2 = df.groupby((col('id') + 1))\
+ .agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort(df.id + 1)
+ self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+
+ # Groupby one column and aggregate one UDF without literal
+ result3 = df.groupby('id').agg(weighted_mean_udf(df.v,
df.w)).sort('id')
+ expected3 =
df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, w)')).sort('id')
+ self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
+
+ # Groupby one expression and aggregate one UDF without literal
+ result4 = df.groupby((col('id') + 1).alias('id'))\
+ .agg(weighted_mean_udf(df.v, df.w))\
+ .sort('id')
+ expected4 = df.groupby((col('id') + 1).alias('id'))\
+ .agg(mean(df.v).alias('weighted_mean(v, w)'))\
+ .sort('id')
+ self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
+
+ def test_unsupported_types(self):
+ from pyspark.sql.types import ArrayType, DoubleType, MapType
+ from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+ with QuietTest(self.sc):
+ with self.assertRaisesRegex(NotImplementedError, 'not
supported'):
--- End diff --
@ueshin Thanks for fixing this. (I am late to the party)
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]