This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5b2bd1c9c0c [SPARK-40447][PS] Implement `kendall` correlation in `DataFrame.corr`
5b2bd1c9c0c is described below

commit 5b2bd1c9c0cb109f8a801dfcfb6ba1305bf864c6
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Sat Sep 17 07:30:31 2022 +0800

    [SPARK-40447][PS] Implement `kendall` correlation in `DataFrame.corr`
    
    ### What changes were proposed in this pull request?
    Implement `kendall` correlation in `DataFrame.corr`
    
    ### Why are the changes needed?
    For pandas API coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, a new correlation option:
    ```
    In [1]: import pyspark.pandas as ps
    
    In [2]: df = ps.DataFrame([(.2, .3), (.0, .6), (.6, .0), (.2, .1)], columns=['dogs', 'cats'])
    
    In [3]: df.corr('kendall')
    
              dogs      cats
    dogs  1.000000 -0.912871
    cats -0.912871  1.000000
    
    In [4]: df.to_pandas().corr('kendall')
    /Users/ruifeng.zheng/Dev/spark/python/pyspark/pandas/utils.py:975: PandasAPIOnSparkAdviceWarning: `to_pandas` loads all data into the driver's memory. It should only be used if the resulting pandas DataFrame is expected to be small.
      warnings.warn(message, PandasAPIOnSparkAdviceWarning)
    Out[4]:
              dogs      cats
    dogs  1.000000 -0.912871
    cats -0.912871  1.000000
    ```
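
    For reference, here is a minimal pure-Python sketch of the tau-b counting that
    the new implementation expresses in Spark SQL. The `P`/`Q`/`T`/`U` terms mirror
    the `tau_b = (P - Q) / sqrt((P + Q + T) * (P + Q + U))` comment in the patch;
    this standalone script is illustrative only and is not part of the change:
    ```
    from itertools import combinations
    from math import sqrt

    dogs = [0.2, 0.0, 0.6, 0.2]
    cats = [0.3, 0.6, 0.0, 0.1]

    # P: concordant pairs, Q: discordant pairs,
    # T: pairs tied only in x, U: pairs tied only in y
    # (pairs tied in both x and y are counted in neither)
    P = Q = T = U = 0
    for (x1, y1), (x2, y2) in combinations(zip(dogs, cats), 2):
        if x1 == x2 and y1 != y2:
            T += 1
        elif x1 != x2 and y1 == y2:
            U += 1
        elif (x1 - x2) * (y1 - y2) > 0:
            P += 1
        elif (x1 - x2) * (y1 - y2) < 0:
            Q += 1

    tau_b = (P - Q) / sqrt((P + Q + T) * (P + Q + U))
    print(P, Q, T, U, tau_b)  # 0 5 1 0 -0.9128709291752769
    ```
    On this data P=0, Q=5, T=1, U=0, so tau_b = -5 / sqrt(30) ≈ -0.912871, matching
    the output above.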
    
    ### How was this patch tested?
    Added unit tests in `python/pyspark/pandas/tests/test_stats.py`.
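
    For reviewers: assuming a standard Spark development environment, the touched
    suite can be run with something like:
    ```
    python/run-tests --testnames 'pyspark.pandas.tests.test_stats'
    ```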
    
    Closes #37913 from zhengruifeng/ps_df_kendall.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/pandas/frame.py            | 260 +++++++++++++++++++++---------
 python/pyspark/pandas/tests/test_stats.py |  32 ++--
 2 files changed, 204 insertions(+), 88 deletions(-)

diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 4149868dde9..d7b26cacda3 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -1424,9 +1424,10 @@ class DataFrame(Frame, Generic[T]):
 
         Parameters
         ----------
-        method : {'pearson', 'spearman'}
+        method : {'pearson', 'spearman', 'kendall'}
             * pearson : standard correlation coefficient
             * spearman : Spearman rank correlation
+            * kendall : Kendall Tau correlation coefficient
         min_periods : int, optional
             Minimum number of observations required per pair of columns
             to have a valid result.
@@ -1435,12 +1436,21 @@ class DataFrame(Frame, Generic[T]):
 
         Returns
         -------
-        y : DataFrame
+        DataFrame
 
         See Also
         --------
+        DataFrame.corrwith
         Series.corr
 
+        Notes
+        -----
+        1. Pearson, Kendall and Spearman correlations are currently computed using
+           pairwise complete observations.
+
+        2. The complexity of Spearman correlation is O(#row * #row); if the dataset
+           is too large, sampling ahead of the correlation computation is recommended.
+
         Examples
         --------
         >>> df = ps.DataFrame([(.2, .3), (.0, .6), (.6, .0), (.2, .1)],
@@ -1455,16 +1465,13 @@ class DataFrame(Frame, Generic[T]):
         dogs  1.000000 -0.948683
         cats -0.948683  1.000000
 
-        Notes
-        -----
-        There are behavior differences between pandas-on-Spark and pandas.
-
-        * the `method` argument only accepts 'pearson', 'spearman'
+        >>> df.corr('kendall')
+                  dogs      cats
+        dogs  1.000000 -0.912871
+        cats -0.912871  1.000000
         """
         if method not in ["pearson", "spearman", "kendall"]:
             raise ValueError(f"Invalid method {method}")
-        if method == "kendall":
-            raise NotImplementedError("method doesn't support kendall for now")
         if min_periods is not None and not isinstance(min_periods, int):
             raise TypeError(f"Invalid min_periods type {type(min_periods).__name__}")
 
@@ -1537,87 +1544,196 @@ class DataFrame(Frame, Generic[T]):
             .otherwise(F.col(f"{tmp_tuple_col_name}.{tmp_value_2_col_name}"))
             .alias(tmp_value_2_col_name),
         )
+        not_null_cond = (
+            F.col(tmp_value_1_col_name).isNotNull() & F.col(tmp_value_2_col_name).isNotNull()
+        )
 
-        # convert values to avg ranks for spearman correlation
-        if method == "spearman":
-            tmp_row_number_col_name = verify_temp_column_name(sdf, "__tmp_row_number_col__")
-            tmp_dense_rank_col_name = verify_temp_column_name(sdf, "__tmp_dense_rank_col__")
-            window = Window.partitionBy(tmp_index_1_col_name, tmp_index_2_col_name)
-
-            # tmp_value_1_col_name: value -> avg rank
-            # for example:
-            # values:       3, 4, 5, 7, 7, 7, 9, 9, 10
-            # avg ranks:    1.0, 2.0, 3.0, 5.0, 5.0, 5.0, 7.5, 7.5, 9.0
-            sdf = (
-                sdf.withColumn(
-                    tmp_row_number_col_name,
-                    F.row_number().over(window.orderBy(F.asc_nulls_last(tmp_value_1_col_name))),
+        tmp_count_col_name = verify_temp_column_name(sdf, "__tmp_count_col__")
+        tmp_corr_col_name = verify_temp_column_name(sdf, "__tmp_corr_col__")
+        if method in ["pearson", "spearman"]:
+            # convert values to avg ranks for spearman correlation
+            if method == "spearman":
+                tmp_row_number_col_name = verify_temp_column_name(sdf, "__tmp_row_number_col__")
+                tmp_dense_rank_col_name = verify_temp_column_name(sdf, "__tmp_dense_rank_col__")
+                window = Window.partitionBy(tmp_index_1_col_name, tmp_index_2_col_name)
+
+                # tmp_value_1_col_name: value -> avg rank
+                # for example:
+                # values:       3, 4, 5, 7, 7, 7, 9, 9, 10
+                # avg ranks:    1.0, 2.0, 3.0, 5.0, 5.0, 5.0, 7.5, 7.5, 9.0
+                sdf = (
+                    sdf.withColumn(
+                        tmp_row_number_col_name,
+                        F.row_number().over(window.orderBy(F.asc_nulls_last(tmp_value_1_col_name))),
+                    )
+                    .withColumn(
+                        tmp_dense_rank_col_name,
+                        F.dense_rank().over(window.orderBy(F.asc_nulls_last(tmp_value_1_col_name))),
+                    )
+                    .withColumn(
+                        tmp_value_1_col_name,
+                        F.when(F.isnull(F.col(tmp_value_1_col_name)), F.lit(None)).otherwise(
+                            F.avg(tmp_row_number_col_name).over(
+                                window.orderBy(F.asc(tmp_dense_rank_col_name)).rangeBetween(0, 0)
+                            )
+                        ),
+                    )
                 )
-                .withColumn(
-                    tmp_dense_rank_col_name,
-                    F.dense_rank().over(window.orderBy(F.asc_nulls_last(tmp_value_1_col_name))),
+
+                # tmp_value_2_col_name: value -> avg rank
+                sdf = (
+                    sdf.withColumn(
+                        tmp_row_number_col_name,
+                        F.row_number().over(window.orderBy(F.asc_nulls_last(tmp_value_2_col_name))),
+                    )
+                    .withColumn(
+                        tmp_dense_rank_col_name,
+                        F.dense_rank().over(window.orderBy(F.asc_nulls_last(tmp_value_2_col_name))),
+                    )
+                    .withColumn(
+                        tmp_value_2_col_name,
+                        F.when(F.isnull(F.col(tmp_value_2_col_name)), F.lit(None)).otherwise(
+                            F.avg(tmp_row_number_col_name).over(
+                                window.orderBy(F.asc(tmp_dense_rank_col_name)).rangeBetween(0, 0)
+                            )
+                        ),
+                    )
                 )
-                .withColumn(
+
+                sdf = sdf.select(
+                    tmp_index_1_col_name,
+                    tmp_index_2_col_name,
                     tmp_value_1_col_name,
-                    F.when(F.isnull(F.col(tmp_value_1_col_name)), F.lit(None)).otherwise(
-                        F.avg(tmp_row_number_col_name).over(
-                            window.orderBy(F.asc(tmp_dense_rank_col_name)).rangeBetween(0, 0)
-                        )
-                    ),
+                    tmp_value_2_col_name,
                 )
+
+            # +-------------------+-------------------+----------------+-----------------+
+            # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_corr_col__|__tmp_count_col__|
+            # +-------------------+-------------------+----------------+-----------------+
+            # |                  2|                  2|            null|                1|
+            # |                  1|                  2|            null|                1|
+            # |                  1|                  1|             1.0|                2|
+            # |                  0|                  0|             1.0|                2|
+            # |                  0|                  1|            -1.0|                2|
+            # |                  0|                  2|            null|                1|
+            # +-------------------+-------------------+----------------+-----------------+
+            sdf = sdf.groupby(tmp_index_1_col_name, tmp_index_2_col_name).agg(
+                F.corr(tmp_value_1_col_name, tmp_value_2_col_name).alias(tmp_corr_col_name),
+                F.count(
+                    F.when(
+                        not_null_cond,
+                        1,
+                    )
+                ).alias(tmp_count_col_name),
+            )
+
+        else:
+            # kendall correlation
+            tmp_row_number_12_col_name = verify_temp_column_name(sdf, "__tmp_row_number_12_col__")
+            sdf = sdf.withColumn(
+                tmp_row_number_12_col_name,
+                F.row_number().over(
+                    Window.partitionBy(tmp_index_1_col_name, tmp_index_2_col_name).orderBy(
+                        F.asc_nulls_last(tmp_value_1_col_name),
+                        F.asc_nulls_last(tmp_value_2_col_name),
+                    )
+                ),
+            )
+
+            # drop nulls but make sure each partition contains at least one row
+            sdf = sdf.where(not_null_cond | (F.col(tmp_row_number_12_col_name) == 1))
+
+            tmp_value_x_col_name = verify_temp_column_name(sdf, "__tmp_value_x_col__")
+            tmp_value_y_col_name = verify_temp_column_name(sdf, "__tmp_value_y_col__")
+            tmp_row_number_xy_col_name = verify_temp_column_name(sdf, "__tmp_row_number_xy_col__")
+            sdf2 = sdf.select(
+                F.col(tmp_index_1_col_name),
+                F.col(tmp_index_2_col_name),
+                F.col(tmp_value_1_col_name).alias(tmp_value_x_col_name),
+                F.col(tmp_value_2_col_name).alias(tmp_value_y_col_name),
+                F.col(tmp_row_number_12_col_name).alias(tmp_row_number_xy_col_name),
+            )
+
+            sdf = sdf.join(sdf2, [tmp_index_1_col_name, tmp_index_2_col_name], "inner").where(
+                F.col(tmp_row_number_12_col_name) <= F.col(tmp_row_number_xy_col_name)
+            )
+
+            # compute P, Q, T, U in tau_b = (P - Q) / sqrt((P + Q + T) * (P + Q + U))
+            # see https://github.com/scipy/scipy/blob/v1.9.1/scipy/stats/_stats_py.py#L5015-L5222
+            tmp_tau_b_p_col_name = verify_temp_column_name(sdf, "__tmp_tau_b_p_col__")
+            tmp_tau_b_q_col_name = verify_temp_column_name(sdf, "__tmp_tau_b_q_col__")
+            tmp_tau_b_t_col_name = verify_temp_column_name(sdf, "__tmp_tau_b_t_col__")
+            tmp_tau_b_u_col_name = verify_temp_column_name(sdf, "__tmp_tau_b_u_col__")
+
+            pair_cond = not_null_cond & (
+                F.col(tmp_row_number_12_col_name) < F.col(tmp_row_number_xy_col_name)
+            )
+
+            p_cond = (
+                (F.col(tmp_value_1_col_name) < F.col(tmp_value_x_col_name))
+                & (F.col(tmp_value_2_col_name) < F.col(tmp_value_y_col_name))
+            ) | (
+                (F.col(tmp_value_1_col_name) > F.col(tmp_value_x_col_name))
+                & (F.col(tmp_value_2_col_name) > F.col(tmp_value_y_col_name))
+            )
+            q_cond = (
+                (F.col(tmp_value_1_col_name) < F.col(tmp_value_x_col_name))
+                & (F.col(tmp_value_2_col_name) > F.col(tmp_value_y_col_name))
+            ) | (
+                (F.col(tmp_value_1_col_name) > F.col(tmp_value_x_col_name))
+                & (F.col(tmp_value_2_col_name) < F.col(tmp_value_y_col_name))
+            )
+            t_cond = (F.col(tmp_value_1_col_name) == F.col(tmp_value_x_col_name)) & (
+                F.col(tmp_value_2_col_name) != F.col(tmp_value_y_col_name)
+            )
+            u_cond = (F.col(tmp_value_1_col_name) != F.col(tmp_value_x_col_name)) & (
+                F.col(tmp_value_2_col_name) == F.col(tmp_value_y_col_name)
             )
 
-            # tmp_value_2_col_name: value -> avg rank
             sdf = (
-                sdf.withColumn(
-                    tmp_row_number_col_name,
-                    F.row_number().over(window.orderBy(F.asc_nulls_last(tmp_value_2_col_name))),
-                )
-                .withColumn(
-                    tmp_dense_rank_col_name,
-                    F.dense_rank().over(window.orderBy(F.asc_nulls_last(tmp_value_2_col_name))),
+                sdf.groupby(tmp_index_1_col_name, tmp_index_2_col_name)
+                .agg(
+                    F.count(F.when(pair_cond & p_cond, 1)).alias(tmp_tau_b_p_col_name),
+                    F.count(F.when(pair_cond & q_cond, 1)).alias(tmp_tau_b_q_col_name),
+                    F.count(F.when(pair_cond & t_cond, 1)).alias(tmp_tau_b_t_col_name),
+                    F.count(F.when(pair_cond & u_cond, 1)).alias(tmp_tau_b_u_col_name),
+                    F.max(
+                        F.when(not_null_cond, F.col(tmp_row_number_xy_col_name)).otherwise(F.lit(0))
+                    ).alias(tmp_count_col_name),
                 )
                 .withColumn(
-                    tmp_value_2_col_name,
-                    F.when(F.isnull(F.col(tmp_value_2_col_name)), F.lit(None)).otherwise(
-                        F.avg(tmp_row_number_col_name).over(
-                            window.orderBy(F.asc(tmp_dense_rank_col_name)).rangeBetween(0, 0)
+                    tmp_corr_col_name,
+                    F.when(
+                        F.col(tmp_index_1_col_name) == F.col(tmp_index_2_col_name), F.lit(1.0)
+                    ).otherwise(
+                        (F.col(tmp_tau_b_p_col_name) - F.col(tmp_tau_b_q_col_name))
+                        / F.sqrt(
+                            (
+                                (
+                                    F.col(tmp_tau_b_p_col_name)
+                                    + F.col(tmp_tau_b_q_col_name)
+                                    + (F.col(tmp_tau_b_t_col_name))
+                                )
+                            )
+                            * (
+                                (
+                                    F.col(tmp_tau_b_p_col_name)
+                                    + F.col(tmp_tau_b_q_col_name)
+                                    + (F.col(tmp_tau_b_u_col_name))
+                                )
+                            )
                         )
                     ),
                 )
             )
 
             sdf = sdf.select(
-                tmp_index_1_col_name,
-                tmp_index_2_col_name,
-                tmp_value_1_col_name,
-                tmp_value_2_col_name,
+                F.col(tmp_index_1_col_name),
+                F.col(tmp_index_2_col_name),
+                F.col(tmp_corr_col_name),
+                F.col(tmp_count_col_name),
             )
 
-        # +-------------------+-------------------+----------------+-----------------+
-        # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_corr_col__|__tmp_count_col__|
-        # +-------------------+-------------------+----------------+-----------------+
-        # |                  2|                  2|            null|                1|
-        # |                  1|                  2|            null|                1|
-        # |                  1|                  1|             1.0|                2|
-        # |                  0|                  0|             1.0|                2|
-        # |                  0|                  1|            -1.0|                2|
-        # |                  0|                  2|            null|                1|
-        # +-------------------+-------------------+----------------+-----------------+
-        tmp_corr_col_name = verify_temp_column_name(sdf, "__tmp_corr_col__")
-        tmp_count_col_name = verify_temp_column_name(sdf, "__tmp_count_col__")
-
-        sdf = sdf.groupby(tmp_index_1_col_name, tmp_index_2_col_name).agg(
-            F.corr(tmp_value_1_col_name, tmp_value_2_col_name).alias(tmp_corr_col_name),
-            F.count(
-                F.when(
-                    F.col(tmp_value_1_col_name).isNotNull()
-                    & F.col(tmp_value_2_col_name).isNotNull(),
-                    1,
-                )
-            ).alias(tmp_count_col_name),
-        )
-
         # +-------------------+-------------------+----------------+
         # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_corr_col__|
         # +-------------------+-------------------+----------------+
diff --git a/python/pyspark/pandas/tests/test_stats.py b/python/pyspark/pandas/tests/test_stats.py
index fbe16146ff2..c6ca71696b6 100644
--- a/python/pyspark/pandas/tests/test_stats.py
+++ b/python/pyspark/pandas/tests/test_stats.py
@@ -265,12 +265,10 @@ class StatsTest(PandasOnSparkTestCase, SQLTestUtils):
 
         with self.assertRaisesRegex(ValueError, "Invalid method"):
             psdf.corr("std")
-        with self.assertRaisesRegex(NotImplementedError, "kendall for now"):
-            psdf.corr("kendall")
         with self.assertRaisesRegex(TypeError, "Invalid min_periods type"):
             psdf.corr(min_periods="3")
 
-        for method in ["pearson", "spearman"]:
+        for method in ["pearson", "spearman", "kendall"]:
             self.assert_eq(psdf.corr(method=method), pdf.corr(method=method), check_exact=False)
             self.assert_eq(
                 psdf.corr(method=method, min_periods=1),
@@ -293,7 +291,7 @@ class StatsTest(PandasOnSparkTestCase, SQLTestUtils):
         pdf.columns = columns
         psdf.columns = columns
 
-        for method in ["pearson", "spearman"]:
+        for method in ["pearson", "spearman", "kendall"]:
             self.assert_eq(psdf.corr(method=method), pdf.corr(method=method), check_exact=False)
             self.assert_eq(
                 psdf.corr(method=method, min_periods=1),
@@ -311,7 +309,7 @@ class StatsTest(PandasOnSparkTestCase, SQLTestUtils):
                 check_exact=False,
             )
 
-        # test spearman with identical values
+        # test with identical values
         pdf = pd.DataFrame(
             {
                 "a": [0, 1, 1, 1, 0],
@@ -321,17 +319,19 @@ class StatsTest(PandasOnSparkTestCase, SQLTestUtils):
             }
         )
         psdf = ps.from_pandas(pdf)
-        self.assert_eq(psdf.corr(method="spearman"), pdf.corr(method="spearman"), check_exact=False)
-        self.assert_eq(
-            psdf.corr(method="spearman", min_periods=1),
-            pdf.corr(method="spearman", min_periods=1),
-            check_exact=False,
-        )
-        self.assert_eq(
-            psdf.corr(method="spearman", min_periods=3),
-            pdf.corr(method="spearman", min_periods=3),
-            check_exact=False,
-        )
+
+        for method in ["pearson", "spearman", "kendall"]:
+            self.assert_eq(psdf.corr(method=method), pdf.corr(method=method), check_exact=False)
+            self.assert_eq(
+                psdf.corr(method=method, min_periods=1),
+                pdf.corr(method=method, min_periods=1),
+                check_exact=False,
+            )
+            self.assert_eq(
+                psdf.corr(method=method, min_periods=3),
+                pdf.corr(method=method, min_periods=3),
+                check_exact=False,
+            )
 
     def test_corr(self):
         # Disable arrow execution since corr() is using UDT internally which is not supported.

