This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new cfbf3c704f0f [SPARK-46976][PS] Implement `DataFrameGroupBy.corr`
cfbf3c704f0f is described below

commit cfbf3c704f0fd593ce383eaddada4d3fc3500659
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Feb 6 16:55:51 2024 +0800

    [SPARK-46976][PS] Implement `DataFrameGroupBy.corr`
    
    ### What changes were proposed in this pull request?
    Implement `DataFrameGroupBy.corr`
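
    The implementation reuses the shared `pyspark.pandas.correlation.compute`
    helper: every pair of numeric columns is packed into a struct tagged with
    the pair's indices, all pairs are exploded per row with `F.inline`, and
    the correlation kernel is then applied with the group keys plus the pair
    indices as grouping keys. A condensed sketch of that core step (simplified
    from the patch below, not the verbatim code):
    ```
    # One struct per (i, j) column pair, tagged with the pair indices.
    pair_scols = []
    for i in range(num_scols):
        for j in range(i, num_scols):
            pair_scols.append(
                F.struct(
                    F.lit(i).alias(index_1_col_name),
                    F.lit(j).alias(index_2_col_name),
                    numeric_scols[i].alias(CORRELATION_VALUE_1_COLUMN),
                    numeric_scols[j].alias(CORRELATION_VALUE_2_COLUMN),
                )
            )

    # Explode to one row per (group, pair), then run the shared kernel
    # grouped by the group keys plus the pair indices.
    sdf = sdf.select(*[F.col(k) for k in groupkey_names], F.inline(F.array(*pair_scols)))
    sdf = compute(
        sdf=sdf,
        groupKeys=groupkey_names + [index_1_col_name, index_2_col_name],
        method=method,
    )
    ```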
    
    ### Why are the changes needed?
    For pandas parity; see
    https://pandas.pydata.org/docs/reference/api/pandas.core.groupby.DataFrameGroupBy.corr.html
    
    ### Does this PR introduce _any_ user-facing change?
    yes
    ```
    In [5]: pdf = pd.DataFrame({'A': [0, 0, 0, 1, 1, 2], 'B': [-1, 2, 3, 5, 6, 0], 'C': [4, 6, 5, 1, 3, 0]}, columns=['A', 'B', 'C'])
    
    In [6]: pdf.groupby("A").corr()
    Out[6]:
                B         C
    A
    0 B  1.000000  0.720577
      C  0.720577  1.000000
    1 B  1.000000  1.000000
      C  1.000000  1.000000
    2 B       NaN       NaN
      C       NaN       NaN
    
    In [7]: psdf = ps.from_pandas(pdf)
    
    In [8]: psdf.groupby("A").corr()
    
                B         C
    A
    0 B  1.000000  0.720577
      C  0.720577  1.000000
    1 B  1.000000  1.000000
      C  1.000000  1.000000
    2 B       NaN       NaN
      C       NaN       NaN
    ```
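
    The `method` and `min_periods` parameters are supported as well; the
    output below is copied from the docstring example added in this patch:
    ```
    In [9]: psdf.groupby("A").corr("spearman")
    Out[9]:
           B    C
    A
    0 B  1.0  0.5
      C  0.5  1.0
    1 B  1.0  1.0
      C  1.0  1.0
    2 B  NaN  NaN
      C  NaN  NaN
    ```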
    
    ### How was this patch tested?
    Added unit tests (`pyspark.pandas.tests.groupby.test_corr`) and a Spark Connect parity test (`pyspark.pandas.tests.connect.groupby.test_parity_corr`).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #45028 from zhengruifeng/ps_df_groupby_corr.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 dev/sparktestsupport/modules.py                    |   2 +
 python/pyspark/pandas/groupby.py                   | 216 +++++++++++++++++++++
 python/pyspark/pandas/missing/groupby.py           |   1 -
 .../tests/connect/groupby/test_parity_corr.py      |  41 ++++
 python/pyspark/pandas/tests/groupby/test_corr.py   |  84 ++++++++
 5 files changed, 343 insertions(+), 1 deletion(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 2ed2144fa64b..ff3b23ff573a 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -889,6 +889,7 @@ pyspark_pandas_slow = Module(
         "pyspark.pandas.tests.indexes.test_reset_index",
         "pyspark.pandas.tests.groupby.test_aggregate",
         "pyspark.pandas.tests.groupby.test_apply_func",
+        "pyspark.pandas.tests.groupby.test_corr",
         "pyspark.pandas.tests.groupby.test_cumulative",
         "pyspark.pandas.tests.groupby.test_describe",
         "pyspark.pandas.tests.groupby.test_groupby",
@@ -1174,6 +1175,7 @@ pyspark_pandas_connect_part1 = Module(
         "pyspark.pandas.tests.connect.frame.test_parity_truncate",
         "pyspark.pandas.tests.connect.groupby.test_parity_aggregate",
         "pyspark.pandas.tests.connect.groupby.test_parity_apply_func",
+        "pyspark.pandas.tests.connect.groupby.test_parity_corr",
         "pyspark.pandas.tests.connect.groupby.test_parity_cumulative",
         "pyspark.pandas.tests.connect.groupby.test_parity_missing_data",
         "pyspark.pandas.tests.connect.groupby.test_parity_split_apply",
diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py
index 4cce147b2606..ec47ab75c43c 100644
--- a/python/pyspark/pandas/groupby.py
+++ b/python/pyspark/pandas/groupby.py
@@ -76,6 +76,13 @@ from pyspark.pandas.missing.groupby import (
 from pyspark.pandas.series import Series, first_series
 from pyspark.pandas.spark import functions as SF
 from pyspark.pandas.config import get_option
+from pyspark.pandas.correlation import (
+    compute,
+    CORRELATION_VALUE_1_COLUMN,
+    CORRELATION_VALUE_2_COLUMN,
+    CORRELATION_CORR_OUTPUT_COLUMN,
+    CORRELATION_COUNT_OUTPUT_COLUMN,
+)
 from pyspark.pandas.utils import (
     align_diff_frames,
     is_name_like_tuple,
@@ -3928,6 +3935,215 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
         # Cast columns to ``"float64"`` to match `pandas.DataFrame.groupby`.
         return DataFrame(internal).astype("float64")
 
+    def corr(
+        self,
+        method: str = "pearson",
+        min_periods: int = 1,
+        numeric_only: bool = False,
+    ) -> "DataFrame":
+        """
+        Compute pairwise correlation of columns, excluding NA/null values.
+
+        .. versionadded:: 4.0.0
+
+        Parameters
+        ----------
+        method : {'pearson', 'spearman', 'kendall'}
+            * pearson : standard correlation coefficient
+            * spearman : Spearman rank correlation
+            * kendall : Kendall Tau correlation coefficient
+
+        min_periods : int, default 1
+            Minimum number of observations in window required to have a value
+            (otherwise result is NA).
+
+        numeric_only : bool, default False
+            Include only `float`, `int` or `boolean` data.
+
+        Returns
+        -------
+        DataFrame
+
+        See Also
+        --------
+        DataFrame.corrwith
+        Series.corr
+
+        Notes
+        -----
+        1. Pearson, Kendall and Spearman correlation are currently computed using pairwise
+           complete observations.
+
+        2. The complexity of Kendall correlation is O(#row * #row); if the dataset is too
+           large, sampling ahead of correlation computation is recommended.
+
+        Examples
+        --------
+        >>> df = ps.DataFrame(
+        ...     {"A": [0, 0, 0, 1, 1, 2], "B": [-1, 2, 3, 5, 6, 0], "C": [4, 
6, 5, 1, 3, 0]},
+        ...     columns=["A", "B", "C"])
+        >>> df.groupby("A").corr()
+                    B         C
+        A
+        0 B  1.000000  0.720577
+          C  0.720577  1.000000
+        1 B  1.000000  1.000000
+          C  1.000000  1.000000
+        2 B       NaN       NaN
+          C       NaN       NaN
+
+        >>> df.groupby("A").corr(min_periods=2)
+                    B         C
+        A
+        0 B  1.000000  0.720577
+          C  0.720577  1.000000
+        1 B  1.000000  1.000000
+          C  1.000000  1.000000
+        2 B       NaN       NaN
+          C       NaN       NaN
+
+        >>> df.groupby("A").corr("spearman")
+               B    C
+        A
+        0 B  1.0  0.5
+          C  0.5  1.0
+        1 B  1.0  1.0
+          C  1.0  1.0
+        2 B  NaN  NaN
+          C  NaN  NaN
+
+        >>> df.groupby("A").corr('kendall')
+                    B         C
+        A
+        0 B  1.000000  0.333333
+          C  0.333333  1.000000
+        1 B  1.000000  1.000000
+          C  1.000000  1.000000
+        2 B  1.000000       NaN
+          C       NaN  1.000000
+        """
+        if method not in ["pearson", "spearman", "kendall"]:
+            raise ValueError(f"Invalid method {method}")
+
+        groupkey_names: List[str] = [str(key.name) for key in self._groupkeys]
+        internal, agg_columns, sdf = self._prepare_reduce(
+            groupkey_names=groupkey_names,
+            accepted_spark_types=(NumericType, BooleanType) if numeric_only else None,
+            bool_to_numeric=False,
+        )
+
+        numeric_labels = [
+            label
+            for label in internal.column_labels
+            if isinstance(internal.spark_type_for(label), (NumericType, BooleanType))
+        ]
+        numeric_scols: List[Column] = [
+            internal.spark_column_for(label).cast("double") for label in numeric_labels
+        ]
+        numeric_col_names: List[str] = [name_like_string(label) for label in numeric_labels]
+        num_scols = len(numeric_scols)
+
+        sdf = internal.spark_frame
+        index_1_col_name = verify_temp_column_name(sdf, "__groupby_corr_index_1_temp_column__")
+        index_2_col_name = verify_temp_column_name(sdf, "__groupby_corr_index_2_temp_column__")
+
+        pair_scols: List[Column] = []
+        for i in range(0, num_scols):
+            for j in range(i, num_scols):
+                pair_scols.append(
+                    F.struct(
+                        F.lit(i).alias(index_1_col_name),
+                        F.lit(j).alias(index_2_col_name),
+                        numeric_scols[i].alias(CORRELATION_VALUE_1_COLUMN),
+                        numeric_scols[j].alias(CORRELATION_VALUE_2_COLUMN),
+                    )
+                )
+
+        sdf = sdf.select(*[F.col(key) for key in groupkey_names], *[F.inline(F.array(*pair_scols))])
+
+        sdf = compute(
+            sdf=sdf, groupKeys=groupkey_names + [index_1_col_name, index_2_col_name], method=method
+        )
+        if method == "kendall":
+            sdf = sdf.withColumn(
+                CORRELATION_CORR_OUTPUT_COLUMN,
+                F.when(F.col(index_1_col_name) == F.col(index_2_col_name), F.lit(1.0)).otherwise(
+                    F.col(CORRELATION_CORR_OUTPUT_COLUMN)
+                ),
+            )
+
+        sdf = sdf.withColumn(
+            CORRELATION_CORR_OUTPUT_COLUMN,
+            F.when(F.col(CORRELATION_COUNT_OUTPUT_COLUMN) < min_periods, F.lit(None)).otherwise(
+                F.col(CORRELATION_CORR_OUTPUT_COLUMN)
+            ),
+        )
+
+        auxiliary_col_name = verify_temp_column_name(sdf, "__groupby_corr_auxiliary_temp_column__")
+        sdf = sdf.withColumn(
+            auxiliary_col_name,
+            F.explode(
+                F.when(
+                    F.col(index_1_col_name) == F.col(index_2_col_name),
+                    F.lit([0]),
+                ).otherwise(F.lit([0, 1]))
+            ),
+        ).select(
+            *[F.col(key) for key in groupkey_names],
+            *[
+                F.when(F.col(auxiliary_col_name) == 0, F.col(index_1_col_name))
+                .otherwise(F.col(index_2_col_name))
+                .alias(index_1_col_name),
+                F.when(F.col(auxiliary_col_name) == 0, F.col(index_2_col_name))
+                .otherwise(F.col(index_1_col_name))
+                .alias(index_2_col_name),
+                F.col(CORRELATION_CORR_OUTPUT_COLUMN),
+            ],
+        )
+
+        array_col_name = verify_temp_column_name(sdf, "__groupby_corr_array_temp_column__")
+        sdf = sdf.groupby(groupkey_names + [index_1_col_name]).agg(
+            F.array_sort(
+                F.collect_list(
+                    F.struct(
+                        F.col(index_2_col_name),
+                        F.col(CORRELATION_CORR_OUTPUT_COLUMN),
+                    )
+                )
+            ).alias(array_col_name)
+        )
+
+        for i in range(0, num_scols):
+            sdf = sdf.withColumn(auxiliary_col_name, F.get(F.col(array_col_name), i)).withColumn(
+                numeric_col_names[i],
+                F.col(f"{auxiliary_col_name}.{CORRELATION_CORR_OUTPUT_COLUMN}"),
+            )
+
+        sdf = sdf.orderBy(groupkey_names + [index_1_col_name])  # type: ignore[arg-type]
+
+        sdf = sdf.select(
+            *[F.col(col) for col in groupkey_names + numeric_col_names],
+            *[
+                F.get(F.lit(numeric_col_names), F.col(index_1_col_name)).alias(auxiliary_col_name),
+                F.monotonically_increasing_id().alias(NATURAL_ORDER_COLUMN_NAME),
+            ],
+        )
+
+        return DataFrame(
+            InternalFrame(
+                spark_frame=sdf,
+                index_spark_columns=[
+                    scol_for(sdf, key) for key in groupkey_names + [auxiliary_col_name]
+                ],
+                index_names=(
+                    [psser._column_label for psser in self._groupkeys]
+                    + self._psdf._internal.index_names
+                ),
+                column_labels=numeric_labels,
+                column_label_names=internal.column_label_names,
+            )
+        )
+
 
 class SeriesGroupBy(GroupBy[Series]):
     @staticmethod
diff --git a/python/pyspark/pandas/missing/groupby.py b/python/pyspark/pandas/missing/groupby.py
index 55a4a1d59674..a6b672df916c 100644
--- a/python/pyspark/pandas/missing/groupby.py
+++ b/python/pyspark/pandas/missing/groupby.py
@@ -41,7 +41,6 @@ class MissingPandasLikeDataFrameGroupBy:
     # Documentation path: `python/docs/source/reference/pyspark.pandas/`.
 
     # Properties
-    corr = _unsupported_property("corr")
     corrwith = _unsupported_property("corrwith")
     cov = _unsupported_property("cov")
     dtypes = _unsupported_property("dtypes")
diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_corr.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_corr.py
new file mode 100644
index 000000000000..53d4d53a7a35
--- /dev/null
+++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_corr.py
@@ -0,0 +1,41 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.pandas.tests.groupby.test_corr import CorrMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
+
+
+class CorrParityTests(
+    CorrMixin,
+    PandasOnSparkTestUtils,
+    ReusedConnectTestCase,
+):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.pandas.tests.connect.groupby.test_parity_corr import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pandas/tests/groupby/test_corr.py b/python/pyspark/pandas/tests/groupby/test_corr.py
new file mode 100644
index 000000000000..39d6d91de4b0
--- /dev/null
+++ b/python/pyspark/pandas/tests/groupby/test_corr.py
@@ -0,0 +1,84 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+import pandas as pd
+
+from pyspark import pandas as ps
+from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.sqlutils import SQLTestUtils
+
+
+class CorrMixin:
+    @property
+    def pdf(self):
+        return pd.DataFrame(
+            {
+                "A": [0, 0, 0, 1, 1, 2],
+                "B": [-1, 2, 3, 5, 6, 0],
+                "C": [4, 6, 5, 1, 3, 0],
+            },
+            columns=["A", "B", "C"],
+        )
+
+    @property
+    def psdf(self):
+        return ps.from_pandas(self.pdf)
+
+    def test_corr(self):
+        for c in ["A", "B", "C"]:
+            self.assert_eq(
+                self.pdf.groupby(c).corr().sort_index(),
+                self.psdf.groupby(c).corr().sort_index(),
+                almost=True,
+            )
+
+    def test_method(self):
+        for m in ["pearson", "spearman", "kendall"]:
+            self.assert_eq(
+                self.pdf.groupby("A").corr(method=m).sort_index(),
+                self.psdf.groupby("A").corr(method=m).sort_index(),
+                almost=True,
+            )
+
+    def test_min_periods(self):
+        for m in [1, 2, 3]:
+            self.assert_eq(
+                self.pdf.groupby("A").corr(min_periods=m).sort_index(),
+                self.psdf.groupby("A").corr(min_periods=m).sort_index(),
+                almost=True,
+            )
+
+
+class CorrTests(
+    CorrMixin,
+    PandasOnSparkTestCase,
+    SQLTestUtils,
+):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.pandas.tests.groupby.test_corr import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)

