This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 8269bcbe152 [SPARK-38907][PYTHON] Implement DataFrame.corrwith
8269bcbe152 is described below
commit 8269bcbe152da178c4c9bd8b3f745754c0a510d7
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Apr 27 10:13:16 2022 +0900
[SPARK-38907][PYTHON] Implement DataFrame.corrwith
### What changes were proposed in this pull request?
Implement `DataFrame.corrwith`, with the following limitations:
- the `axis` and `numeric_only` parameters are not supported;
- only the default `pearson` correlation method is supported.
### Why are the changes needed?
Increase pandas API coverage in PySpark
### Does this PR introduce _any_ user-facing change?
Yes, a new method, `DataFrame.corrwith`, is added:
```
In [4]: ps.set_option("compute.ops_on_diff_frames", True)
In [5]: df1 = ps.DataFrame({"A":[1, 5, 7, 8], "X":[5, 8, 4, 3], "C":[10,
4, 9, 3]})
In [6]: df2 = ps.DataFrame({"A":[5, 3, 6, 4], "B":[11, 2, 4, 3], "C":[4,
3, 8, 5]})
In [7]: df1.corrwith(df2)
Out[7]:
A -0.041703
C 0.395437
X NaN
B NaN
dtype: float64
In [8]: df1.corrwith(df2.B)
Out[8]:
A -0.844007
X -0.151186
C 0.767234
dtype: float64
```
### How was this patch tested?
Added unit tests (`test_corrwith` in `test_dataframe.py` and `test_ops_on_diff_frames.py`).
Closes #36205 from zhengruifeng/impl_corr_with.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../docs/source/reference/pyspark.pandas/frame.rst | 1 +
.../pandas_on_spark/supported_pandas_api.rst | 2 +-
python/pyspark/pandas/frame.py | 151 +++++++++++++++++++++
python/pyspark/pandas/missing/frame.py | 1 -
python/pyspark/pandas/tests/test_dataframe.py | 27 ++++
.../pandas/tests/test_ops_on_diff_frames.py | 29 ++++
6 files changed, 209 insertions(+), 2 deletions(-)
diff --git a/python/docs/source/reference/pyspark.pandas/frame.rst
b/python/docs/source/reference/pyspark.pandas/frame.rst
index 9635115b0d7..05c215110c6 100644
--- a/python/docs/source/reference/pyspark.pandas/frame.rst
+++ b/python/docs/source/reference/pyspark.pandas/frame.rst
@@ -147,6 +147,7 @@ Computations / Descriptive Stats
DataFrame.any
DataFrame.clip
DataFrame.corr
+ DataFrame.corrwith
DataFrame.count
DataFrame.cov
DataFrame.describe
diff --git
a/python/docs/source/user_guide/pandas_on_spark/supported_pandas_api.rst
b/python/docs/source/user_guide/pandas_on_spark/supported_pandas_api.rst
index efcc55c7178..450742a20f7 100644
--- a/python/docs/source/user_guide/pandas_on_spark/supported_pandas_api.rst
+++ b/python/docs/source/user_guide/pandas_on_spark/supported_pandas_api.rst
@@ -121,7 +121,7 @@ Supported DataFrame APIs
+--------------------------------------------+-------------+--------------------------------------+
| :func:`corr` | P | ``min_periods``
|
+--------------------------------------------+-------------+--------------------------------------+
-| corrwith | N |
|
+| corrwith | P | ``axis``
|
+--------------------------------------------+-------------+--------------------------------------+
| :func:`count` | P | ``level``
|
+--------------------------------------------+-------------+--------------------------------------+
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 16f8e786b0f..9880e2a18d8 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -1310,6 +1310,157 @@ class DataFrame(Frame, Generic[T]):
"""
return cast(DataFrame, ps.from_pandas(corr(self, method)))
+ # TODO: add axis parameter and support more methods
+ def corrwith(
+ self, other: DataFrameOrSeries, drop: bool = False, method: str =
"pearson"
+ ) -> "Series":
+ """
+ Compute pairwise correlation.
+
+ Pairwise correlation is computed between rows or columns of
+ DataFrame with rows or columns of Series or DataFrame. DataFrames
+ are first aligned along both axes before computing the
+ correlations.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ other : DataFrame, Series
+ Object with which to compute correlations.
+
+ drop : bool, default False
+ Drop missing indices from result.
+
+ method : str, default 'pearson'
+ Method of correlation, one of:
+
+ * pearson : standard correlation coefficient
+
+ Returns
+ -------
+ Series
+ Pairwise correlations.
+
+ See Also
+ --------
+ DataFrame.corr : Compute pairwise correlation of columns.
+
+ Examples
+ --------
+ >>> df1 = ps.DataFrame({
+ ... "A":[1, 5, 7, 8],
+ ... "X":[5, 8, 4, 3],
+ ... "C":[10, 4, 9, 3]})
+ >>> df1.corrwith(df1[["X", "C"]])
+ X 1.0
+ C 1.0
+ A NaN
+ dtype: float64
+
+ >>> df2 = ps.DataFrame({
+ ... "A":[5, 3, 6, 4],
+ ... "B":[11, 2, 4, 3],
+ ... "C":[4, 3, 8, 5]})
+
+ >>> with ps.option_context("compute.ops_on_diff_frames", True):
+ ... df1.corrwith(df2)
+ A -0.041703
+ C 0.395437
+ X NaN
+ B NaN
+ dtype: float64
+
+ >>> with ps.option_context("compute.ops_on_diff_frames", True):
+ ... df2.corrwith(df1.X)
+ A -0.597614
+ B -0.151186
+ C -0.642857
+ dtype: float64
+ """
+ from pyspark.pandas.series import Series, first_series
+
+ if (method is not None) and (method not in ["pearson"]):
+ raise NotImplementedError("corrwith currently works only for
method='pearson'")
+ if not isinstance(other, (DataFrame, Series)):
+ raise TypeError("unsupported type:
{}".format(type(other).__name__))
+
+ right_is_series = isinstance(other, Series)
+
+ if same_anchor(self, other):
+ combined = self
+ this = self
+ that = other
+ else:
+ combined = combine_frames(self, other, how="inner")
+ this = combined["this"]
+ that = combined["that"]
+
+ this_numeric_column_labels: List[Label] = []
+ for column_label in this._internal.column_labels:
+ if isinstance(this._internal.spark_type_for(column_label),
(NumericType, BooleanType)):
+ this_numeric_column_labels.append(column_label)
+
+ that_numeric_column_labels: List[Label] = []
+ for column_label in that._internal.column_labels:
+ if isinstance(that._internal.spark_type_for(column_label),
(NumericType, BooleanType)):
+ that_numeric_column_labels.append(column_label)
+
+ intersect_numeric_column_labels: List[Label] = []
+ diff_numeric_column_labels: List[Label] = []
+ corr_scols = []
+ if right_is_series:
+ intersect_numeric_column_labels = this_numeric_column_labels
+ that_scol =
that._internal.spark_column_for(that_numeric_column_labels[0])
+ for numeric_column_label in intersect_numeric_column_labels:
+ this_scol =
this._internal.spark_column_for(numeric_column_label)
+ corr_scols.append(
+ F.corr(this_scol.cast("double"),
that_scol.cast("double")).alias(
+ name_like_string(numeric_column_label)
+ )
+ )
+ else:
+ for numeric_column_label in this_numeric_column_labels:
+ if numeric_column_label in that_numeric_column_labels:
+
intersect_numeric_column_labels.append(numeric_column_label)
+ else:
+ diff_numeric_column_labels.append(numeric_column_label)
+ for numeric_column_label in that_numeric_column_labels:
+ if numeric_column_label not in this_numeric_column_labels:
+ diff_numeric_column_labels.append(numeric_column_label)
+ for numeric_column_label in intersect_numeric_column_labels:
+ this_scol =
this._internal.spark_column_for(numeric_column_label)
+ that_scol =
that._internal.spark_column_for(numeric_column_label)
+ corr_scols.append(
+ F.corr(this_scol.cast("double"),
that_scol.cast("double")).alias(
+ name_like_string(numeric_column_label)
+ )
+ )
+
+ corr_labels: List[Label] = intersect_numeric_column_labels
+ if not drop:
+ for numeric_column_label in diff_numeric_column_labels:
+ corr_scols.append(
+
SF.lit(None).cast("double").alias(name_like_string(numeric_column_label))
+ )
+ corr_labels.append(numeric_column_label)
+
+ sdf = combined._internal.spark_frame.select(
+
*[SF.lit(None).cast(StringType()).alias(SPARK_DEFAULT_INDEX_NAME)], *corr_scols
+ ).limit(
+ 1
+ ) # limit(1) to avoid returning more than 1 row when intersection is
empty
+
+ # The data is expected to be small so it's fine to transpose/use
default index.
+ with ps.option_context("compute.max_rows", 1):
+ internal = InternalFrame(
+ spark_frame=sdf,
+ index_spark_columns=[scol_for(sdf, SPARK_DEFAULT_INDEX_NAME)],
+ column_labels=corr_labels,
+ column_label_names=self._internal.column_label_names,
+ )
+ return first_series(DataFrame(internal).transpose())
+
def iteritems(self) -> Iterator[Tuple[Name, "Series"]]:
"""
Iterator over (column name, Series) pairs.
diff --git a/python/pyspark/pandas/missing/frame.py
b/python/pyspark/pandas/missing/frame.py
index 23fb06f03ce..ba2d01c5225 100644
--- a/python/pyspark/pandas/missing/frame.py
+++ b/python/pyspark/pandas/missing/frame.py
@@ -40,7 +40,6 @@ class _MissingPandasLikeDataFrame:
combine = _unsupported_function("combine")
compare = _unsupported_function("compare")
convert_dtypes = _unsupported_function("convert_dtypes")
- corrwith = _unsupported_function("corrwith")
infer_objects = _unsupported_function("infer_objects")
mode = _unsupported_function("mode")
reorder_levels = _unsupported_function("reorder_levels")
diff --git a/python/pyspark/pandas/tests/test_dataframe.py
b/python/pyspark/pandas/tests/test_dataframe.py
index 008da92c9a9..8915ec1ca64 100644
--- a/python/pyspark/pandas/tests/test_dataframe.py
+++ b/python/pyspark/pandas/tests/test_dataframe.py
@@ -5547,6 +5547,33 @@ class DataFrameTest(ComparisonTestBase, SQLTestUtils):
self.assert_eq(abs(psdf), abs(pdf))
self.assert_eq(np.abs(psdf), np.abs(pdf))
+ def test_corrwith(self):
+ df1 = ps.DataFrame({"A": [1, np.nan, 7, 8], "X": [5, 8, np.nan, 3],
"C": [10, 4, 9, 3]})
+ df2 = df1[["A", "C"]]
+ self._test_corrwith(df1, df2)
+ self._test_corrwith((df1 + 1), df2.A)
+ self._test_corrwith((df1 + 1), (df2.C + 2))
+
+ with self.assertRaisesRegex(
+ NotImplementedError, "corrwith currently works only for
method='pearson'"
+ ):
+ df1.corrwith(df2, method="kendall")
+
+ with self.assertRaisesRegex(TypeError, "unsupported type"):
+ df1.corrwith(123)
+
+ df_bool = ps.DataFrame({"A": [True, True, False, False], "B": [True,
False, False, True]})
+ self._test_corrwith(df_bool, df_bool.A)
+ self._test_corrwith(df_bool, df_bool.B)
+
+ def _test_corrwith(self, psdf, psobj):
+ pdf = psdf.to_pandas()
+ pobj = psobj.to_pandas()
+ for drop in [True, False]:
+ p_corr = pdf.corrwith(pobj, drop=drop)
+ ps_corr = psdf.corrwith(psobj, drop=drop)
+ self.assert_eq(p_corr.sort_index(), ps_corr.sort_index(),
almost=True)
+
def test_iteritems(self):
pdf = pd.DataFrame(
{"species": ["bear", "bear", "marsupial"], "population": [1864,
22000, 80000]},
diff --git a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
index 96473769475..3ef3c676ad8 100644
--- a/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
+++ b/python/pyspark/pandas/tests/test_ops_on_diff_frames.py
@@ -1842,6 +1842,35 @@ class OpsOnDiffFramesEnabledTest(PandasOnSparkTestCase,
SQLTestUtils):
pscov = psser1.cov(psser2, min_periods=3)
self.assert_eq(pcov, pscov, almost=True)
+ def test_corrwith(self):
+ df1 = ps.DataFrame({"A": [1, np.nan, 7, 8], "X": [5, 8, np.nan, 3],
"C": [10, 4, 9, 3]})
+ df2 = ps.DataFrame({"A": [5, 3, 6, 4], "B": [11, 2, 4, 3], "C": [4, 3,
8, np.nan]})
+ self._test_corrwith(df1, df2)
+ self._test_corrwith((df1 + 1), df2.B)
+ self._test_corrwith((df1 + 1), (df2.B + 2))
+
+ df_bool = ps.DataFrame({"A": [True, True, False, False], "B": [True,
False, False, True]})
+ ser_bool = ps.Series([True, True, False, True])
+ self._test_corrwith(df_bool, ser_bool)
+
+ self._test_corrwith(self.psdf1, self.psdf1)
+ self._test_corrwith(self.psdf1, self.psdf2)
+ self._test_corrwith(self.psdf2, self.psdf3)
+ self._test_corrwith(self.psdf3, self.psdf4)
+
+ self._test_corrwith(self.psdf1, self.psdf1.a)
+ self._test_corrwith(self.psdf1, self.psdf2.b)
+ self._test_corrwith(self.psdf2, self.psdf3.c)
+ self._test_corrwith(self.psdf3, self.psdf4.f)
+
+ def _test_corrwith(self, psdf, psobj):
+ pdf = psdf.to_pandas()
+ pobj = psobj.to_pandas()
+ for drop in [True, False]:
+ p_corr = pdf.corrwith(pobj, drop=drop)
+ ps_corr = psdf.corrwith(psobj, drop=drop)
+ self.assert_eq(p_corr.sort_index(), ps_corr.sort_index(),
almost=True)
+
def test_series_eq(self):
pser = pd.Series([1, 2, 3, 4, 5, 6], name="x")
psser = ps.from_pandas(pser)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]