GitHub user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/21737#discussion_r201764065
--- Diff: python/pyspark/sql/tests.py ---
@@ -5925,6 +5925,22 @@ def test_invalid_args(self):
'mixture.*aggregate function.*group aggregate pandas
UDF'):
df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
+ def test_self_join_with_pandas(self):
+ import pyspark.sql.functions as F
+
+ @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP)
+ def dummy_pandas_udf(df):
+ return df[['key', 'col']]
+
+ df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1,
col='B'),
+ Row(key=2, col='C')])
+ dfWithPandas = df.groupBy('key').apply(dummy_pandas_udf)
--- End diff --
nit: rename `dfWithPandas` to `df_with_pandas` (PEP 8 uses snake_case for local variable names).
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]