Github user BryanCutler commented on a diff in the pull request:
https://github.com/apache/spark/pull/18732#discussion_r142517310
--- Diff: python/pyspark/sql/group.py ---
@@ -194,6 +194,65 @@ def pivot(self, pivot_col, values=None):
        jgd = self._jgd.pivot(pivot_col, values)
        return GroupedData(jgd, self.sql_ctx)
+    def apply(self, udf):
+        """
+        Maps each group of the current :class:`DataFrame` using a pandas udf and
+        returns the result as a :class:`DataFrame`.
+
+        The user function should take a `pandas.DataFrame` and return another
+        `pandas.DataFrame`. Each group is passed as a `pandas.DataFrame` to the
+        user function, and the returned `pandas.DataFrame`s are combined into a
+        :class:`DataFrame`. The returned `pandas.DataFrame` can be of arbitrary
+        length, and its schema should match the returnType of the pandas udf.
+
+        :param udf: A wrapped function returned by `pandas_udf`
+
+        >>> df = spark.createDataFrame(
+        ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+        ...     ("id", "v"))
+        >>> @pandas_udf(returnType=df.schema)
+        ... def normalize(pdf):
+        ...     v = pdf.v
+        ...     return pdf.assign(v=(v - v.mean()) / v.std())
+        >>> df.groupby('id').apply(normalize).show()  # doctest: +SKIP
+        +---+-------------------+
+        | id|                  v|
+        +---+-------------------+
+        |  1|-0.7071067811865475|
+        |  1| 0.7071067811865475|
+        |  2|-0.8320502943378437|
+        |  2|-0.2773500981126146|
+        |  2| 1.1094003924504583|
+        +---+-------------------+
+
+        .. seealso:: :meth:`pyspark.sql.functions.pandas_udf`
+
+        """
+        from pyspark.sql.functions import pandas_udf
+
+        # Columns are special because hasattr always returns True
+        if isinstance(udf, Column) or not hasattr(udf, 'func') or not udf.vectorized:
+            raise ValueError("The argument to apply must be a pandas_udf")
+        if not isinstance(udf.returnType, StructType):
+            raise ValueError("The returnType of the pandas_udf must be a StructType")
+
+        df = DataFrame(self._jgd.df(), self.sql_ctx)
+        func = udf.func
+        returnType = udf.returnType
+
+        # The python executors expect the function to take a list of pd.Series as
+        # input, so we create a wrapper function that turns the list into a
+        # pd.DataFrame before passing it down to the user function.
+        columns = df.columns
+
+        def wrapped(*cols):
+            import pandas as pd
+            return func(pd.concat(cols, axis=1, keys=columns))
--- End diff --
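For context, the wrapping being discussed reassembles the list of `pandas.Series` that the executor hands to the udf into a single `pandas.DataFrame`. A minimal standalone illustration of that `pd.concat` call (not from the PR):

```python
import pandas as pd

# Each column of the group arrives as a separate pandas.Series;
# concatenating along axis=1 with `keys` rebuilds one DataFrame
# whose column names come from the keys list.
cols = [pd.Series([1, 1]), pd.Series([1.0, 2.0])]
pdf = pd.concat(cols, axis=1, keys=["id", "v"])
print(pdf)
#    id    v
# 0   1  1.0
# 1   1  2.0
```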
What do you all think about making the `ArrowPandasSerializer` also able to
serialize pandas.DataFrames? Then it wouldn't require this extra wrapping and
I think it could be useful for other things in the future as well.
@HyukjinKwon @ueshin @viirya?
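A rough sketch of what that could look like, since pyarrow can already round-trip a whole `pandas.DataFrame` (the helper names `dump_dataframe`/`load_dataframe` here are hypothetical, not the serializer's actual API):

```python
import pandas as pd
import pyarrow as pa

def dump_dataframe(pdf):
    # Serialize a whole pandas.DataFrame as one Arrow record batch,
    # instead of converting column-by-column from a list of Series.
    return pa.RecordBatch.from_pandas(pdf, preserve_index=False)

def load_dataframe(batch):
    # The reverse direction: an Arrow record batch back to a DataFrame.
    return batch.to_pandas()

pdf = pd.DataFrame({"id": [1, 1, 2], "v": [1.0, 2.0, 3.0]})
assert load_dataframe(dump_dataframe(pdf)).equals(pdf)
```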