Github user icexelloss commented on a diff in the pull request:
https://github.com/apache/spark/pull/20211#discussion_r160826435
--- Diff: python/pyspark/sql/group.py ---
@@ -233,6 +233,27 @@ def apply(self, udf):
| 2| 1.1094003924504583|
+---+-------------------+
+ Notes on grouping column:
--- End diff --
@cloud-fan That's what I thought too initially. Let's consider this use
case,
```
import pandas as pd
import statsmodels.api as sm
from pyspark.sql.functions import pandas_udf, PandasUDFType

# df has four columns: id, y, x1, x2
group_column = 'id'
y_column = 'y'
x_columns = ['x1', 'x2']
# The output schema has to include the grouping column
schema = df.select(group_column, *x_columns).schema

# Input/output are both a pandas.DataFrame
@pandas_udf(schema, PandasUDFType.GROUP_MAP)
def ols(pdf):
    # The UDF has to recover the group key from its input itself
    group_key = pdf[group_column].iloc[0]
    y = pdf[y_column]
    X = pdf[x_columns]
    X = sm.add_constant(X)
    model = sm.OLS(y, X).fit()
    return pd.DataFrame([[group_key] + [model.params[i] for i in x_columns]],
                        columns=[group_column] + x_columns)

beta = df.groupby(group_column).apply(ols)
```
This is a simple pandas UDF that runs a linear regression. The issue is that,
although the regression itself has nothing to do with the grouping column, the
user still has to handle the grouping column inside the UDF. In other words,
the UDF is coupled to the grouping column.
If we change it so that the grouping columns are prepended to the UDF result,
the user can write something like this:
```
import pandas as pd
import statsmodels.api as sm
from pyspark.sql.functions import pandas_udf, PandasUDFType

# df has four columns: id, y, x1, x2
group_column = 'id'
y_column = 'y'
x_columns = ['x1', 'x2']
# The output schema only describes the regression coefficients
schema = df.select(*x_columns).schema

# Input/output are both a pandas.DataFrame
@pandas_udf(schema, PandasUDFType.GROUP_MAP)
def ols(pdf):
    y = pdf[y_column]
    X = pdf[x_columns]
    X = sm.add_constant(X)
    model = sm.OLS(y, X).fit()
    return pd.DataFrame([[model.params[i] for i in x_columns]],
                        columns=x_columns)

beta = df.groupby(group_column).apply(ols)
```
Now the UDF is cleaner because it only deals with the columns that are relevant
to the regression. It also makes the UDF more reusable, as the user can now do
something like:
```
beta1 = df.groupby('a').apply(ols)
beta2 = df.groupby('a', 'b').apply(ols)
```
Because the UDF is now decoupled from the grouping column, the user can
reuse the same UDF with different groupings, which is not possible with the
current API.
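To make the proposed semantics concrete, here is a rough pandas-only sketch of
what "prepend the grouping columns to the UDF result" could mean. It is not
Spark code and not how the PR implements anything: `apply_prepending_keys` is
a hypothetical helper made up for illustration, the toy data is invented, and
`numpy.linalg.lstsq` stands in for statsmodels to keep the example small.
```python
import numpy as np
import pandas as pd

y_column = 'y'
x_columns = ['x1', 'x2']


def ols(pdf):
    # Key-agnostic function: it only touches the regression columns.
    y = pdf[y_column].to_numpy(dtype=float)
    # Design matrix: intercept column first, then x1 and x2.
    X = np.column_stack([np.ones(len(pdf)),
                         pdf[x_columns].to_numpy(dtype=float)])
    params, *_ = np.linalg.lstsq(X, y, rcond=None)
    # Drop the intercept and return one row of slope coefficients.
    return pd.DataFrame([params[1:]], columns=x_columns)


def apply_prepending_keys(df, keys, func):
    # Hypothetical helper (an assumption, not a Spark or pandas API):
    # run func on each group, then prepend the group's key values.
    keys = [keys] if isinstance(keys, str) else list(keys)
    pieces = []
    for key_values, pdf in df.groupby(keys, sort=True):
        if not isinstance(key_values, tuple):
            key_values = (key_values,)
        result = func(pdf).reset_index(drop=True)
        for position, (name, value) in enumerate(zip(keys, key_values)):
            result.insert(position, name, value)
        pieces.append(result)
    return pd.concat(pieces, ignore_index=True)


# Toy data: y = 1 + 2*x1 + 3*x2 when a == 1, y = 5 - x1 + 0.5*x2 when a == 2.
df = pd.DataFrame({
    'a':  [1, 1, 1, 1, 2, 2, 2, 2],
    'b':  [0, 0, 1, 1, 0, 0, 1, 1],
    'x1': [0, 1, 0, 1, 0, 1, 0, 1],
    'x2': [0, 0, 1, 1, 0, 0, 1, 1],
    'y':  [1, 3, 4, 6, 5, 4, 5.5, 4.5],
})

# The same function is reused with two different groupings.
beta1 = apply_prepending_keys(df, 'a', ols)         # columns: a, x1, x2
beta2 = apply_prepending_keys(df, ['a', 'b'], ols)  # columns: a, b, x1, x2
```
The point of the sketch is only that `ols` never mentions `a` or `b`; the
grouping columns are attached by whatever calls it, which is the behavior
being proposed for `apply`.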
@cloud-fan @HyukjinKwon What do you think?