GitHub user ueshin commented on a diff in the pull request:
https://github.com/apache/spark/pull/18732#discussion_r142581225
--- Diff: python/pyspark/sql/group.py ---
@@ -192,7 +193,66 @@ def pivot(self, pivot_col, values=None):
jgd = self._jgd.pivot(pivot_col)
else:
jgd = self._jgd.pivot(pivot_col, values)
- return GroupedData(jgd, self.sql_ctx)
+ return GroupedData(jgd, self)
+
+ def apply(self, udf):
+ """
+ Maps each group of the current :class:`DataFrame` using a pandas
udf and returns the result
+ as a :class:`DataFrame`.
+
+ The user-function should take a `pandas.DataFrame` and return
another `pandas.DataFrame`.
+ Each group is passed as a `pandas.DataFrame` to the user-function
and the returned
+ `pandas.DataFrame` are combined as a :class:`DataFrame`. The
returned `pandas.DataFrame`
+ can be arbitrary length and its schema should match the returnType
of the pandas udf.
+
+ :param udf: A wrapped function returned by `pandas_udf`
+
+ >>> df = spark.createDataFrame(
+ ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+ ... ("id", "v"))
+ >>> @pandas_udf(returnType=df.schema)
+ ... def normalize(pdf):
+ ... v = pdf.v
+ ... return pdf.assign(v=(v - v.mean()) / v.std())
+ >>> df.groupby('id').apply(normalize).show() # doctest: + SKIP
--- End diff --
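Should this be `# doctest: +SKIP` (no space between `+` and `SKIP`)? doctest does not allow whitespace between the `+`/`-` and the directive option name, so `+ SKIP` is not a valid directive.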
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]