Github user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/19872#discussion_r162046471
--- Diff: python/pyspark/sql/functions.py ---
@@ -2214,6 +2216,37 @@ def pandas_udf(f=None, returnType=None, functionType=None):
.. seealso:: :meth:`pyspark.sql.GroupedData.apply`
+ 3. GROUP_AGG
+
+ A group aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar
+ The returnType should be a primitive data type, e.g, `DoubleType()`.
+ The returned scalar can be either a python primitive type, e.g., `int` or `float`
+ or a numpy data type, e.g., `numpy.int64` or `numpy.float64`.
+
+ StructType and ArrayType are currently not supported.
+
+ Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg`
+
+ >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+ >>> df = spark.createDataFrame(
+ ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+ ... ("id", "v"))
+ >>> @pandas_udf("double", PandasUDFType.GROUP_AGG)
+ ... def mean_udf(v):
+ ... return v.mean()
+ >>> df.groupby("id").agg(mean_udf(df['v'])).show() # doctest: +SKIP
+ +---+-----------+
+ | id|mean_udf(v)|
+ +---+-----------+
+ | 1| 1.5|
+ | 2| 6.0|
+ +---+-----------+
+
+ .. note:: There is no partial aggregation with group aggregate UDFs, i.e.,
+ a full shuffle is required.
+
+ .. seealso:: :meth:`pyspark.sql.GroupedData.agg`
+
--- End diff --
Ah, this is going to conflict with
https://github.com/apache/spark/pull/20288. I am fine with merging this one
first; I will resolve the conflict. Will take a final look now.
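
For readers following the quoted docstring, the "one or more values per group -> a scalar" contract can be emulated without Spark. The following is only a minimal plain-Python sketch, not the PR's implementation: `group_agg` is a hypothetical helper, lists stand in for `pandas.Series`, and `statistics.mean` stands in for `mean_udf`. It reuses the rows from the docstring example.

```python
from collections import defaultdict
from statistics import mean

# Same (id, v) rows as the docstring example in the diff.
rows = [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)]


def group_agg(rows, agg):
    """Hypothetical helper: collect all values per key, then reduce each group.

    Every value of a group is gathered before `agg` runs, which mirrors the
    note in the docstring that there is no partial aggregation.
    """
    groups = defaultdict(list)
    for key, value in rows:
        groups[key].append(value)
    # `agg` maps the full list of values for one group to a single scalar.
    return {key: agg(values) for key, values in groups.items()}


result = group_agg(rows, mean)
print(result)
```

The per-group results match the table in the docstring (1.5 for id 1, 6.0 for id 2).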