This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new cfbf3c704f0f [SPARK-46976][PS] Implement `DataFrameGroupBy.corr`
cfbf3c704f0f is described below

commit cfbf3c704f0fd593ce383eaddada4d3fc3500659
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Feb 6 16:55:51 2024 +0800

    [SPARK-46976][PS] Implement `DataFrameGroupBy.corr`

    ### What changes were proposed in this pull request?
    Implement `DataFrameGroupBy.corr`

    ### Why are the changes needed?
    for pandas parity
    https://pandas.pydata.org/docs/reference/api/pandas.core.groupby.DataFrameGroupBy.corr.html

    ### Does this PR introduce _any_ user-facing change?
    yes

    ```
    In [5]: pdf = pd.DataFrame({'A': [0, 0, 0, 1, 1, 2], 'B': [-1, 2, 3, 5, 6, 0], 'C': [4, 6, 5, 1, 3, 0]}, columns=['A', 'B', 'C'])

    In [6]: pdf.groupby("A").corr()
    Out[6]:
                B         C
    A
    0 B  1.000000  0.720577
      C  0.720577  1.000000
    1 B  1.000000  1.000000
      C  1.000000  1.000000
    2 B       NaN       NaN
      C       NaN       NaN

    In [7]: psdf = ps.from_pandas(pdf)

    In [8]: psdf.groupby("A").corr()
                B         C
    A
    0 B  1.000000  0.720577
      C  0.720577  1.000000
    1 B  1.000000  1.000000
      C  1.000000  1.000000
    2 B       NaN       NaN
      C       NaN       NaN
    ```

    ### How was this patch tested?
    added tests

    ### Was this patch authored or co-authored using generative AI tooling?
    no

    Closes #45028 from zhengruifeng/ps_df_groupby_corr.
Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- dev/sparktestsupport/modules.py | 2 + python/pyspark/pandas/groupby.py | 216 +++++++++++++++++++++ python/pyspark/pandas/missing/groupby.py | 1 - .../tests/connect/groupby/test_parity_corr.py | 41 ++++ python/pyspark/pandas/tests/groupby/test_corr.py | 84 ++++++++ 5 files changed, 343 insertions(+), 1 deletion(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 2ed2144fa64b..ff3b23ff573a 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -889,6 +889,7 @@ pyspark_pandas_slow = Module( "pyspark.pandas.tests.indexes.test_reset_index", "pyspark.pandas.tests.groupby.test_aggregate", "pyspark.pandas.tests.groupby.test_apply_func", + "pyspark.pandas.tests.groupby.test_corr", "pyspark.pandas.tests.groupby.test_cumulative", "pyspark.pandas.tests.groupby.test_describe", "pyspark.pandas.tests.groupby.test_groupby", @@ -1174,6 +1175,7 @@ pyspark_pandas_connect_part1 = Module( "pyspark.pandas.tests.connect.frame.test_parity_truncate", "pyspark.pandas.tests.connect.groupby.test_parity_aggregate", "pyspark.pandas.tests.connect.groupby.test_parity_apply_func", + "pyspark.pandas.tests.connect.groupby.test_parity_corr", "pyspark.pandas.tests.connect.groupby.test_parity_cumulative", "pyspark.pandas.tests.connect.groupby.test_parity_missing_data", "pyspark.pandas.tests.connect.groupby.test_parity_split_apply", diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py index 4cce147b2606..ec47ab75c43c 100644 --- a/python/pyspark/pandas/groupby.py +++ b/python/pyspark/pandas/groupby.py @@ -76,6 +76,13 @@ from pyspark.pandas.missing.groupby import ( from pyspark.pandas.series import Series, first_series from pyspark.pandas.spark import functions as SF from pyspark.pandas.config import get_option +from pyspark.pandas.correlation import ( + compute, + CORRELATION_VALUE_1_COLUMN, + 
CORRELATION_VALUE_2_COLUMN, + CORRELATION_CORR_OUTPUT_COLUMN, + CORRELATION_COUNT_OUTPUT_COLUMN, +) from pyspark.pandas.utils import ( align_diff_frames, is_name_like_tuple, @@ -3928,6 +3935,215 @@ class DataFrameGroupBy(GroupBy[DataFrame]): # Cast columns to ``"float64"`` to match `pandas.DataFrame.groupby`. return DataFrame(internal).astype("float64") + def corr( + self, + method: str = "pearson", + min_periods: int = 1, + numeric_only: bool = False, + ) -> "DataFrame": + """ + Compute pairwise correlation of columns, excluding NA/null values. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + method : {'pearson', 'spearman', 'kendall'} + * pearson : standard correlation coefficient + * spearman : Spearman rank correlation + * kendall : Kendall Tau correlation coefficient + + min_periods : int, default 1 + Minimum number of observations in window required to have a value + (otherwise result is NA). + + numeric_only : bool, default False + Include only `float`, `int` or `boolean` data. + + Returns + ------- + DataFrame + + See Also + -------- + DataFrame.corrwith + Series.corr + + Notes + ----- + 1. Pearson, Kendall and Spearman correlation are currently computed using pairwise + complete observations. + + 2. The complexity of Kendall correlation is O(#row * #row), if the dataset is too + large, sampling ahead of correlation computation is recommended. + + Examples + -------- + >>> df = ps.DataFrame( + ... {"A": [0, 0, 0, 1, 1, 2], "B": [-1, 2, 3, 5, 6, 0], "C": [4, 6, 5, 1, 3, 0]}, + ... 
columns=["A", "B", "C"]) + >>> df.groupby("A").corr() + B C + A + 0 B 1.000000 0.720577 + C 0.720577 1.000000 + 1 B 1.000000 1.000000 + C 1.000000 1.000000 + 2 B NaN NaN + C NaN NaN + + >>> df.groupby("A").corr(min_periods=2) + B C + A + 0 B 1.000000 0.720577 + C 0.720577 1.000000 + 1 B 1.000000 1.000000 + C 1.000000 1.000000 + 2 B NaN NaN + C NaN NaN + + >>> df.groupby("A").corr("spearman") + B C + A + 0 B 1.0 0.5 + C 0.5 1.0 + 1 B 1.0 1.0 + C 1.0 1.0 + 2 B NaN NaN + C NaN NaN + + >>> df.groupby("A").corr('kendall') + B C + A + 0 B 1.000000 0.333333 + C 0.333333 1.000000 + 1 B 1.000000 1.000000 + C 1.000000 1.000000 + 2 B 1.000000 NaN + C NaN 1.000000 + """ + if method not in ["pearson", "spearman", "kendall"]: + raise ValueError(f"Invalid method {method}") + + groupkey_names: List[str] = [str(key.name) for key in self._groupkeys] + internal, agg_columns, sdf = self._prepare_reduce( + groupkey_names=groupkey_names, + accepted_spark_types=(NumericType, BooleanType) if numeric_only else None, + bool_to_numeric=False, + ) + + numeric_labels = [ + label + for label in internal.column_labels + if isinstance(internal.spark_type_for(label), (NumericType, BooleanType)) + ] + numeric_scols: List[Column] = [ + internal.spark_column_for(label).cast("double") for label in numeric_labels + ] + numeric_col_names: List[str] = [name_like_string(label) for label in numeric_labels] + num_scols = len(numeric_scols) + + sdf = internal.spark_frame + index_1_col_name = verify_temp_column_name(sdf, "__groupby_corr_index_1_temp_column__") + index_2_col_name = verify_temp_column_name(sdf, "__groupby_corr_index_2_temp_column__") + + pair_scols: List[Column] = [] + for i in range(0, num_scols): + for j in range(i, num_scols): + pair_scols.append( + F.struct( + F.lit(i).alias(index_1_col_name), + F.lit(j).alias(index_2_col_name), + numeric_scols[i].alias(CORRELATION_VALUE_1_COLUMN), + numeric_scols[j].alias(CORRELATION_VALUE_2_COLUMN), + ) + ) + + sdf = sdf.select(*[F.col(key) for key in 
groupkey_names], *[F.inline(F.array(*pair_scols))]) + + sdf = compute( + sdf=sdf, groupKeys=groupkey_names + [index_1_col_name, index_2_col_name], method=method + ) + if method == "kendall": + sdf = sdf.withColumn( + CORRELATION_CORR_OUTPUT_COLUMN, + F.when(F.col(index_1_col_name) == F.col(index_2_col_name), F.lit(1.0)).otherwise( + F.col(CORRELATION_CORR_OUTPUT_COLUMN) + ), + ) + + sdf = sdf.withColumn( + CORRELATION_CORR_OUTPUT_COLUMN, + F.when(F.col(CORRELATION_COUNT_OUTPUT_COLUMN) < min_periods, F.lit(None)).otherwise( + F.col(CORRELATION_CORR_OUTPUT_COLUMN) + ), + ) + + auxiliary_col_name = verify_temp_column_name(sdf, "__groupby_corr_auxiliary_temp_column__") + sdf = sdf.withColumn( + auxiliary_col_name, + F.explode( + F.when( + F.col(index_1_col_name) == F.col(index_2_col_name), + F.lit([0]), + ).otherwise(F.lit([0, 1])) + ), + ).select( + *[F.col(key) for key in groupkey_names], + *[ + F.when(F.col(auxiliary_col_name) == 0, F.col(index_1_col_name)) + .otherwise(F.col(index_2_col_name)) + .alias(index_1_col_name), + F.when(F.col(auxiliary_col_name) == 0, F.col(index_2_col_name)) + .otherwise(F.col(index_1_col_name)) + .alias(index_2_col_name), + F.col(CORRELATION_CORR_OUTPUT_COLUMN), + ], + ) + + array_col_name = verify_temp_column_name(sdf, "__groupby_corr_array_temp_column__") + sdf = sdf.groupby(groupkey_names + [index_1_col_name]).agg( + F.array_sort( + F.collect_list( + F.struct( + F.col(index_2_col_name), + F.col(CORRELATION_CORR_OUTPUT_COLUMN), + ) + ) + ).alias(array_col_name) + ) + + for i in range(0, num_scols): + sdf = sdf.withColumn(auxiliary_col_name, F.get(F.col(array_col_name), i)).withColumn( + numeric_col_names[i], + F.col(f"{auxiliary_col_name}.{CORRELATION_CORR_OUTPUT_COLUMN}"), + ) + + sdf = sdf.orderBy(groupkey_names + [index_1_col_name]) # type: ignore[arg-type] + + sdf = sdf.select( + *[F.col(col) for col in groupkey_names + numeric_col_names], + *[ + F.get(F.lit(numeric_col_names), F.col(index_1_col_name)).alias(auxiliary_col_name), + 
F.monotonically_increasing_id().alias(NATURAL_ORDER_COLUMN_NAME), + ], + ) + + return DataFrame( + InternalFrame( + spark_frame=sdf, + index_spark_columns=[ + scol_for(sdf, key) for key in groupkey_names + [auxiliary_col_name] + ], + index_names=( + [psser._column_label for psser in self._groupkeys] + + self._psdf._internal.index_names + ), + column_labels=numeric_labels, + column_label_names=internal.column_label_names, + ) + ) + class SeriesGroupBy(GroupBy[Series]): @staticmethod diff --git a/python/pyspark/pandas/missing/groupby.py b/python/pyspark/pandas/missing/groupby.py index 55a4a1d59674..a6b672df916c 100644 --- a/python/pyspark/pandas/missing/groupby.py +++ b/python/pyspark/pandas/missing/groupby.py @@ -41,7 +41,6 @@ class MissingPandasLikeDataFrameGroupBy: # Documentation path: `python/docs/source/reference/pyspark.pandas/`. # Properties - corr = _unsupported_property("corr") corrwith = _unsupported_property("corrwith") cov = _unsupported_property("cov") dtypes = _unsupported_property("dtypes") diff --git a/python/pyspark/pandas/tests/connect/groupby/test_parity_corr.py b/python/pyspark/pandas/tests/connect/groupby/test_parity_corr.py new file mode 100644 index 000000000000..53d4d53a7a35 --- /dev/null +++ b/python/pyspark/pandas/tests/connect/groupby/test_parity_corr.py @@ -0,0 +1,41 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import unittest + +from pyspark.pandas.tests.groupby.test_corr import CorrMixin +from pyspark.testing.connectutils import ReusedConnectTestCase +from pyspark.testing.pandasutils import PandasOnSparkTestUtils + + +class CorrParityTests( + CorrMixin, + PandasOnSparkTestUtils, + ReusedConnectTestCase, +): + pass + + +if __name__ == "__main__": + from pyspark.pandas.tests.connect.groupby.test_parity_corr import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/pandas/tests/groupby/test_corr.py b/python/pyspark/pandas/tests/groupby/test_corr.py new file mode 100644 index 000000000000..39d6d91de4b0 --- /dev/null +++ b/python/pyspark/pandas/tests/groupby/test_corr.py @@ -0,0 +1,84 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import unittest + +import pandas as pd + +from pyspark import pandas as ps +from pyspark.testing.pandasutils import PandasOnSparkTestCase +from pyspark.testing.sqlutils import SQLTestUtils + + +class CorrMixin: + @property + def pdf(self): + return pd.DataFrame( + { + "A": [0, 0, 0, 1, 1, 2], + "B": [-1, 2, 3, 5, 6, 0], + "C": [4, 6, 5, 1, 3, 0], + }, + columns=["A", "B", "C"], + ) + + @property + def psdf(self): + return ps.from_pandas(self.pdf) + + def test_corr(self): + for c in ["A", "B", "C"]: + self.assert_eq( + self.pdf.groupby(c).corr().sort_index(), + self.psdf.groupby(c).corr().sort_index(), + almost=True, + ) + + def test_method(self): + for m in ["pearson", "spearman", "kendall"]: + self.assert_eq( + self.pdf.groupby("A").corr(method=m).sort_index(), + self.psdf.groupby("A").corr(method=m).sort_index(), + almost=True, + ) + + def test_min_periods(self): + for m in [1, 2, 3]: + self.assert_eq( + self.pdf.groupby("A").corr(min_periods=m).sort_index(), + self.psdf.groupby("A").corr(min_periods=m).sort_index(), + almost=True, + ) + + +class CorrTests( + CorrMixin, + PandasOnSparkTestCase, + SQLTestUtils, +): + pass + + +if __name__ == "__main__": + from pyspark.pandas.tests.groupby.test_corr import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For 
additional commands, e-mail: commits-h...@spark.apache.org