Github user gatorsmile commented on a diff in the pull request:
https://github.com/apache/spark/pull/20295#discussion_r214795846
--- Diff: python/pyspark/sql/tests.py ---
@@ -4588,6 +4613,80 @@ def test_timestamp_dst(self):
result = df.groupby('time').apply(foo_udf).sort('time')
self.assertPandasEqual(df.toPandas(), result.toPandas())
+ def test_udf_with_key(self):
+ from pyspark.sql.functions import pandas_udf, col, PandasUDFType
+ df = self.data
+ pdf = df.toPandas()
+
+ def foo1(key, pdf):
+ import numpy as np
+ assert type(key) == tuple
+ assert type(key[0]) == np.int64
+
+ return pdf.assign(v1=key[0],
+ v2=pdf.v * key[0],
+ v3=pdf.v * pdf.id,
+ v4=pdf.v * pdf.id.mean())
+
+ def foo2(key, pdf):
+ import numpy as np
+ assert type(key) == tuple
+ assert type(key[0]) == np.int64
+ assert type(key[1]) == np.int32
+
+ return pdf.assign(v1=key[0],
+ v2=key[1],
+ v3=pdf.v * key[0],
+ v4=pdf.v + key[1])
+
+ def foo3(key, pdf):
+ assert type(key) == tuple
+ assert len(key) == 0
+ return pdf.assign(v1=pdf.v * pdf.id)
+
+ # v2 is int because numpy.int64 * pd.Series<int32> results in pd.Series<int32>
+ # v3 is long because pd.Series<int64> * pd.Series<int32> results in pd.Series<int64>
+ udf1 = pandas_udf(
+ foo1,
+ 'id long, v int, v1 long, v2 int, v3 long, v4 double',
+ PandasUDFType.GROUPED_MAP)
+
+ udf2 = pandas_udf(
+ foo2,
+ 'id long, v int, v1 long, v2 int, v3 int, v4 int',
+ PandasUDFType.GROUPED_MAP)
+
+ udf3 = pandas_udf(
+ foo3,
+ 'id long, v int, v1 long',
+ PandasUDFType.GROUPED_MAP)
+
+ # Test groupby column
+ result1 = df.groupby('id').apply(udf1).sort('id', 'v').toPandas()
+ expected1 = pdf.groupby('id')\
+ .apply(lambda x: udf1.func((x.id.iloc[0],), x))\
+ .sort_values(['id', 'v']).reset_index(drop=True)
+ self.assertPandasEqual(expected1, result1)
+
+ # Test groupby expression
+ result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas()
+ expected2 = pdf.groupby(pdf.id % 2)\
+ .apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\
+ .sort_values(['id', 'v']).reset_index(drop=True)
+ self.assertPandasEqual(expected2, result2)
+
+ # Test complex groupby
+ result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas()
--- End diff --
Is there any negative test case for when the number of columns specified in the groupby
differs from the number of key columns the UDF (foo2) expects?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]