This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new cfbf3c704f0f [SPARK-46976][PS] Implement `DataFrameGroupBy.corr`
cfbf3c704f0f is described below
commit cfbf3c704f0fd593ce383eaddada4d3fc3500659
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Feb 6 16:55:51 2024 +0800
[SPARK-46976][PS] Implement `DataFrameGroupBy.corr`
### What changes were proposed in this pull request?
Implement `DataFrameGroupBy.corr`
### Why are the changes needed?
for pandas parity
https://pandas.pydata.org/docs/reference/api/pandas.core.groupby.DataFrameGroupBy.corr.html
### Does this PR introduce _any_ user-facing change?
yes
```
In [5]: pdf = pd.DataFrame({'A': [0, 0, 0, 1, 1, 2], 'B': [-1, 2, 3, 5, 6, 0], 'C': [4, 6, 5, 1, 3, 0]}, columns=['A', 'B', 'C'])
In [6]: pdf.groupby("A").corr()
Out[6]:
B C
A
0 B 1.000000 0.720577
C 0.720577 1.000000
1 B 1.000000 1.000000
C 1.000000 1.000000
2 B NaN NaN
C NaN NaN
In [7]: psdf = ps.from_pandas(pdf)
In [8]: psdf.groupby("A").corr()
B C
A
0 B 1.000000 0.720577
C 0.720577 1.000000
1 B 1.000000 1.000000
C 1.000000 1.000000
2 B NaN NaN
C NaN NaN
```
### How was this patch tested?
added tests
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #45028 from zhengruifeng/ps_df_groupby_corr.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
dev/sparktestsupport/modules.py | 2 +
python/pyspark/pandas/groupby.py | 216 +++++++++++++++++++++
python/pyspark/pandas/missing/groupby.py | 1 -
.../tests/connect/groupby/test_parity_corr.py | 41 ++++
python/pyspark/pandas/tests/groupby/test_corr.py | 84 ++++++++
5 files changed, 343 insertions(+), 1 deletion(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 2ed2144fa64b..ff3b23ff573a 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -889,6 +889,7 @@ pyspark_pandas_slow = Module(
"pyspark.pandas.tests.indexes.test_reset_index",
"pyspark.pandas.tests.groupby.test_aggregate",
"pyspark.pandas.tests.groupby.test_apply_func",
+ "pyspark.pandas.tests.groupby.test_corr",
"pyspark.pandas.tests.groupby.test_cumulative",
"pyspark.pandas.tests.groupby.test_describe",
"pyspark.pandas.tests.groupby.test_groupby",
@@ -1174,6 +1175,7 @@ pyspark_pandas_connect_part1 = Module(
"pyspark.pandas.tests.connect.frame.test_parity_truncate",
"pyspark.pandas.tests.connect.groupby.test_parity_aggregate",
"pyspark.pandas.tests.connect.groupby.test_parity_apply_func",
+ "pyspark.pandas.tests.connect.groupby.test_parity_corr",
"pyspark.pandas.tests.connect.groupby.test_parity_cumulative",
"pyspark.pandas.tests.connect.groupby.test_parity_missing_data",
"pyspark.pandas.tests.connect.groupby.test_parity_split_apply",
diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py
index 4cce147b2606..ec47ab75c43c 100644
--- a/python/pyspark/pandas/groupby.py
+++ b/python/pyspark/pandas/groupby.py
@@ -76,6 +76,13 @@ from pyspark.pandas.missing.groupby import (
from pyspark.pandas.series import Series, first_series
from pyspark.pandas.spark import functions as SF
from pyspark.pandas.config import get_option
+from pyspark.pandas.correlation import (
+ compute,
+ CORRELATION_VALUE_1_COLUMN,
+ CORRELATION_VALUE_2_COLUMN,
+ CORRELATION_CORR_OUTPUT_COLUMN,
+ CORRELATION_COUNT_OUTPUT_COLUMN,
+)
from pyspark.pandas.utils import (
align_diff_frames,
is_name_like_tuple,
@@ -3928,6 +3935,215 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
# Cast columns to ``"float64"`` to match `pandas.DataFrame.groupby`.
return DataFrame(internal).astype("float64")
+ def corr(
+ self,
+ method: str = "pearson",
+ min_periods: int = 1,
+ numeric_only: bool = False,
+ ) -> "DataFrame":
+ """
+ Compute pairwise correlation of columns, excluding NA/null values.
+
+ .. versionadded:: 4.0.0
+
+ Parameters
+ ----------
+ method : {'pearson', 'spearman', 'kendall'}
+ * pearson : standard correlation coefficient
+ * spearman : Spearman rank correlation
+ * kendall : Kendall Tau correlation coefficient
+
+ min_periods : int, default 1
+ Minimum number of observations in window required to have a value
+ (otherwise result is NA).
+
+ numeric_only : bool, default False
+ Include only `float`, `int` or `boolean` data.
+
+ Returns
+ -------
+ DataFrame
+
+ See Also
+ --------
+ DataFrame.corrwith
+ Series.corr
+
+ Notes
+ -----
+        1. Pearson, Kendall and Spearman correlation are currently computed using pairwise
+           complete observations.
+
+        2. The complexity of Kendall correlation is O(#row * #row), if the dataset is too
+           large, sampling ahead of correlation computation is recommended.
+
+ Examples
+ --------
+ >>> df = ps.DataFrame(
+        ...     {"A": [0, 0, 0, 1, 1, 2], "B": [-1, 2, 3, 5, 6, 0], "C": [4, 6, 5, 1, 3, 0]},
+ ... columns=["A", "B", "C"])
+ >>> df.groupby("A").corr()
+ B C
+ A
+ 0 B 1.000000 0.720577
+ C 0.720577 1.000000
+ 1 B 1.000000 1.000000
+ C 1.000000 1.000000
+ 2 B NaN NaN
+ C NaN NaN
+
+ >>> df.groupby("A").corr(min_periods=2)
+ B C
+ A
+ 0 B 1.000000 0.720577
+ C 0.720577 1.000000
+ 1 B 1.000000 1.000000
+ C 1.000000 1.000000
+ 2 B NaN NaN
+ C NaN NaN
+
+ >>> df.groupby("A").corr("spearman")
+ B C
+ A
+ 0 B 1.0 0.5
+ C 0.5 1.0
+ 1 B 1.0 1.0
+ C 1.0 1.0
+ 2 B NaN NaN
+ C NaN NaN
+
+ >>> df.groupby("A").corr('kendall')
+ B C
+ A
+ 0 B 1.000000 0.333333
+ C 0.333333 1.000000
+ 1 B 1.000000 1.000000
+ C 1.000000 1.000000
+ 2 B 1.000000 NaN
+ C NaN 1.000000
+ """
+ if method not in ["pearson", "spearman", "kendall"]:
+ raise ValueError(f"Invalid method {method}")
+
+ groupkey_names: List[str] = [str(key.name) for key in self._groupkeys]
+ internal, agg_columns, sdf = self._prepare_reduce(
+ groupkey_names=groupkey_names,
+            accepted_spark_types=(NumericType, BooleanType) if numeric_only else None,
+ bool_to_numeric=False,
+ )
+
+        numeric_labels = [
+            label
+            for label in internal.column_labels
+            if isinstance(internal.spark_type_for(label), (NumericType, BooleanType))
+        ]
+        numeric_scols: List[Column] = [
+            internal.spark_column_for(label).cast("double") for label in numeric_labels
+        ]
+        numeric_col_names: List[str] = [name_like_string(label) for label in numeric_labels]
+ num_scols = len(numeric_scols)
+
+ sdf = internal.spark_frame
+        index_1_col_name = verify_temp_column_name(sdf, "__groupby_corr_index_1_temp_column__")
+        index_2_col_name = verify_temp_column_name(sdf, "__groupby_corr_index_2_temp_column__")
+
+ pair_scols: List[Column] = []
+ for i in range(0, num_scols):
+ for j in range(i, num_scols):
+ pair_scols.append(
+ F.struct(
+ F.lit(i).alias(index_1_col_name),
+ F.lit(j).alias(index_2_col_name),
+ numeric_scols[i].alias(CORRELATION_VALUE_1_COLUMN),
+ numeric_scols[j].alias(CORRELATION_VALUE_2_COLUMN),
+ )
+ )
+
+        sdf = sdf.select(*[F.col(key) for key in groupkey_names], *[F.inline(F.array(*pair_scols))])
+
+        sdf = compute(
+            sdf=sdf, groupKeys=groupkey_names + [index_1_col_name, index_2_col_name], method=method
+        )
+ if method == "kendall":
+ sdf = sdf.withColumn(
+ CORRELATION_CORR_OUTPUT_COLUMN,
+                F.when(F.col(index_1_col_name) == F.col(index_2_col_name), F.lit(1.0)).otherwise(
+ F.col(CORRELATION_CORR_OUTPUT_COLUMN)
+ ),
+ )
+
+ sdf = sdf.withColumn(
+ CORRELATION_CORR_OUTPUT_COLUMN,
+            F.when(F.col(CORRELATION_COUNT_OUTPUT_COLUMN) < min_periods, F.lit(None)).otherwise(
+ F.col(CORRELATION_CORR_OUTPUT_COLUMN)
+ ),
+ )
+
+        auxiliary_col_name = verify_temp_column_name(sdf, "__groupby_corr_auxiliary_temp_column__")
+ sdf = sdf.withColumn(
+ auxiliary_col_name,
+ F.explode(
+ F.when(
+ F.col(index_1_col_name) == F.col(index_2_col_name),
+ F.lit([0]),
+ ).otherwise(F.lit([0, 1]))
+ ),
+ ).select(
+ *[F.col(key) for key in groupkey_names],
+ *[
+ F.when(F.col(auxiliary_col_name) == 0, F.col(index_1_col_name))
+ .otherwise(F.col(index_2_col_name))
+ .alias(index_1_col_name),
+ F.when(F.col(auxiliary_col_name) == 0, F.col(index_2_col_name))
+ .otherwise(F.col(index_1_col_name))
+ .alias(index_2_col_name),
+ F.col(CORRELATION_CORR_OUTPUT_COLUMN),
+ ],
+ )
+
+        array_col_name = verify_temp_column_name(sdf, "__groupby_corr_array_temp_column__")
+ sdf = sdf.groupby(groupkey_names + [index_1_col_name]).agg(
+ F.array_sort(
+ F.collect_list(
+ F.struct(
+ F.col(index_2_col_name),
+ F.col(CORRELATION_CORR_OUTPUT_COLUMN),
+ )
+ )
+ ).alias(array_col_name)
+ )
+
+        for i in range(0, num_scols):
+            sdf = sdf.withColumn(auxiliary_col_name, F.get(F.col(array_col_name), i)).withColumn(
+                numeric_col_names[i],
+                F.col(f"{auxiliary_col_name}.{CORRELATION_CORR_OUTPUT_COLUMN}"),
+            )
+
+        sdf = sdf.orderBy(groupkey_names + [index_1_col_name])  # type: ignore[arg-type]
+
+        sdf = sdf.select(
+            *[F.col(col) for col in groupkey_names + numeric_col_names],
+            *[
+                F.get(F.lit(numeric_col_names), F.col(index_1_col_name)).alias(auxiliary_col_name),
+                F.monotonically_increasing_id().alias(NATURAL_ORDER_COLUMN_NAME),
+            ],
+        )
+
+        return DataFrame(
+            InternalFrame(
+                spark_frame=sdf,
+                index_spark_columns=[
+                    scol_for(sdf, key) for key in groupkey_names + [auxiliary_col_name]
+                ],
+                index_names=(
+                    [psser._column_label for psser in self._groupkeys]
+                    + self._psdf._internal.index_names
+                ),
+                column_labels=numeric_labels,
+                column_label_names=internal.column_label_names,
+            )
+        )
+
class SeriesGroupBy(GroupBy[Series]):
@staticmethod
diff --git a/python/pyspark/pandas/missing/groupby.py b/python/pyspark/pandas/missing/groupby.py
index 55a4a1d59674..a6b672df916c 100644
--- a/python/pyspark/pandas/missing/groupby.py
+++ b/python/pyspark/pandas/missing/groupby.py
@@ -41,7 +41,6 @@ class MissingPandasLikeDataFrameGroupBy:
# Documentation path: `python/docs/source/reference/pyspark.pandas/`.
# Properties
- corr = _unsupported_property("corr")
corrwith = _unsupported_property("corrwith")
cov = _unsupported_property("cov")
dtypes = _unsupported_property("dtypes")
diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_corr.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_corr.py
new file mode 100644
index 000000000000..53d4d53a7a35
--- /dev/null
+++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_corr.py
@@ -0,0 +1,41 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.pandas.tests.groupby.test_corr import CorrMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
+
+
+class CorrParityTests(
+ CorrMixin,
+ PandasOnSparkTestUtils,
+ ReusedConnectTestCase,
+):
+ pass
+
+
+if __name__ == "__main__":
+    from pyspark.pandas.tests.connect.groupby.test_parity_corr import *  # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/groupby/test_corr.py b/python/pyspark/pandas/tests/groupby/test_corr.py
new file mode 100644
index 000000000000..39d6d91de4b0
--- /dev/null
+++ b/python/pyspark/pandas/tests/groupby/test_corr.py
@@ -0,0 +1,84 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
+
+
+class CorrMixin:
+ @property
+ def pdf(self):
+ return pd.DataFrame(
+ {
+ "A": [0, 0, 0, 1, 1, 2],
+ "B": [-1, 2, 3, 5, 6, 0],
+ "C": [4, 6, 5, 1, 3, 0],
+ },
+ columns=["A", "B", "C"],
+ )
+
+ @property
+ def psdf(self):
+ return ps.from_pandas(self.pdf)
+
+ def test_corr(self):
+ for c in ["A", "B", "C"]:
+ self.assert_eq(
+ self.pdf.groupby(c).corr().sort_index(),
+ self.psdf.groupby(c).corr().sort_index(),
+ almost=True,
+ )
+
+ def test_method(self):
+ for m in ["pearson", "spearman", "kendall"]:
+ self.assert_eq(
+ self.pdf.groupby("A").corr(method=m).sort_index(),
+ self.psdf.groupby("A").corr(method=m).sort_index(),
+ almost=True,
+ )
+
+ def test_min_periods(self):
+ for m in [1, 2, 3]:
+ self.assert_eq(
+ self.pdf.groupby("A").corr(min_periods=m).sort_index(),
+ self.psdf.groupby("A").corr(min_periods=m).sort_index(),
+ almost=True,
+ )
+
+
+class CorrTests(
+ CorrMixin,
+ PandasOnSparkTestCase,
+ SQLTestUtils,
+):
+ pass
+
+
+if __name__ == "__main__":
+ from pyspark.pandas.tests.groupby.test_corr import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]