Github user icexelloss commented on a diff in the pull request:
https://github.com/apache/spark/pull/20295#discussion_r172249785
--- Diff: python/pyspark/sql/functions.py ---
@@ -2253,6 +2253,30 @@ def pandas_udf(f=None, returnType=None, functionType=None):
| 2| 1.1094003924504583|
+---+-------------------+
+ Alternatively, the user can define a function that takes two arguments.
+ In this case, the grouping key will be passed as the first argument and the data will
+ be passed as the second argument. The grouping key will be passed as a tuple of numpy
+ data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in
+ as a `pandas.DataFrame` containing all columns from the original Spark DataFrame.
+ This is useful when the user does not want to hardcode the grouping key in the function.
+
+ >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+ >>> df = spark.createDataFrame(
+ ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+ ... ("id", "v")) # doctest: +SKIP
+ >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)  # doctest: +SKIP
+ ... def mean_udf(key, pdf):
+ ... # key is a tuple of one numpy.int64, which is the value
+ ... # of 'id' for the current group
+ ... return pd.DataFrame([key + (pdf.v.mean(),)])
--- End diff ---
Added
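
For reference, a minimal sketch of how the two-argument UDF above would be applied, assuming the `df` and `mean_udf` definitions from the diff (with pandas available as `pd` in the doctest namespace) and with output values computed by hand from the sample data:

    >>> df.groupby("id").apply(mean_udf).show()  # doctest: +SKIP
    +---+---+
    | id|  v|
    +---+---+
    |  1|1.5|
    |  2|6.0|
    +---+---+

Each group's key tuple is prepended to the computed mean, so the returned `pandas.DataFrame` rows line up with the declared "id long, v double" schema.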
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]