GitHub user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/18732#discussion_r142947514
--- Diff: python/pyspark/sql/group.py ---
@@ -192,7 +193,67 @@ def pivot(self, pivot_col, values=None):
jgd = self._jgd.pivot(pivot_col)
else:
jgd = self._jgd.pivot(pivot_col, values)
- return GroupedData(jgd, self.sql_ctx)
+ return GroupedData(jgd, self._df)
+
+ def apply(self, udf):
+ """
+ Maps each group of the current :class:`DataFrame` using a pandas udf and returns the result
+ as a :class:`DataFrame`.
+
+ The user-defined function should take a `pandas.DataFrame` and return another
+ `pandas.DataFrame`. Each group is passed as a `pandas.DataFrame` to the user-function and
+ the returned`pandas.DataFrame` are combined as a :class:`DataFrame`. The returned
+ `pandas.DataFrame` can be arbitrary length and its schema should match the returnType of
+ the pandas udf.
+
+ :param udf: A wrapped function returned by `pandas_udf`
+
+ >>> df = spark.createDataFrame(
+ ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+ ... ("id", "v"))
+ >>> @pandas_udf(returnType=df.schema)
+ ... def normalize(pdf):
+ ... v = pdf.v
+ ... return pdf.assign(v=(v - v.mean()) / v.std())
+ >>> df.groupby('id').apply(normalize).show() # doctest: +SKIP
--- End diff --
Also, it looks like this file does not define `spark` as a global for use
in the doctests. I think we should add something like:
```diff
sc = spark.sparkContext
globs['sc'] = sc
+ globs['spark'] = spark
globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \
```
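To illustrate why the missing global matters, here is a small stdlib-only sketch: a doctest example resolves names against the `globs` dict it is run with, so an example that uses `spark` fails with a `NameError` until `spark` is added. `FakeSession` is a hypothetical stand-in for a real `SparkSession` so the snippet runs without Spark; the docstring and `run_doctest` helper are likewise illustrative, not taken from `group.py`.

```python
import doctest


class FakeSession:
    """Hypothetical stand-in for a SparkSession (illustration only)."""

    def range(self, n):
        return list(range(n))


# A docstring whose example refers to a global named `spark`.
DOCSTRING = """
>>> spark.range(3)
[0, 1, 2]
"""


def run_doctest(globs):
    """Run the examples in DOCSTRING against `globs` and return TestResults."""
    parser = doctest.DocTestParser()
    test = parser.get_doctest(DOCSTRING, globs, "example", None, 0)
    runner = doctest.DocTestRunner(verbose=False)
    # Discard the failure report; only the pass/fail counts matter here.
    return runner.run(test, out=lambda _: None)


globs = {"sc": object()}  # placeholder for spark.sparkContext
missing = run_doctest(globs)  # `spark` is undefined, so the example fails

globs["spark"] = FakeSession()  # analogous to globs['spark'] = spark
fixed = run_doctest(globs)

print(missing.failed, fixed.failed)
```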