This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 5b2bd1c9c0c [SPARK-40447][PS] Implement `kendall` correlation in `DataFrame.corr` 5b2bd1c9c0c is described below commit 5b2bd1c9c0cb109f8a801dfcfb6ba1305bf864c6 Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Sat Sep 17 07:30:31 2022 +0800 [SPARK-40447][PS] Implement `kendall` correlation in `DataFrame.corr` ### What changes were proposed in this pull request? Implement `kendall` correlation in `DataFrame.corr` ### Why are the changes needed? for API coverage ### Does this PR introduce _any_ user-facing change? yes, new correlation option: ``` In [1]: import pyspark.pandas as ps In [2]: df = ps.DataFrame([(.2, .3), (.0, .6), (.6, .0), (.2, .1)], columns=['dogs', 'cats']) In [3]: df.corr('kendall') dogs cats dogs 1.000000 -0.912871 cats -0.912871 1.000000 In [4]: df.to_pandas().corr('kendall') /Users/ruifeng.zheng/Dev/spark/python/pyspark/pandas/utils.py:975: PandasAPIOnSparkAdviceWarning: `to_pandas` loads all data into the driver's memory. It should only be used if the resulting pandas DataFrame is expected to be small. warnings.warn(message, PandasAPIOnSparkAdviceWarning) Out[4]: dogs cats dogs 1.000000 -0.912871 cats -0.912871 1.000000 ``` ### How was this patch tested? added UT Closes #37913 from zhengruifeng/ps_df_kendall. 
Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- python/pyspark/pandas/frame.py | 260 +++++++++++++++++++++--------- python/pyspark/pandas/tests/test_stats.py | 32 ++-- 2 files changed, 204 insertions(+), 88 deletions(-) diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index 4149868dde9..d7b26cacda3 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -1424,9 +1424,10 @@ class DataFrame(Frame, Generic[T]): Parameters ---------- - method : {'pearson', 'spearman'} + method : {'pearson', 'spearman', 'kendall'} * pearson : standard correlation coefficient * spearman : Spearman rank correlation + * kendall : Kendall Tau correlation coefficient min_periods : int, optional Minimum number of observations required per pair of columns to have a valid result. @@ -1435,12 +1436,21 @@ class DataFrame(Frame, Generic[T]): Returns ------- - y : DataFrame + DataFrame See Also -------- + DataFrame.corrwith Series.corr + Notes + ----- + 1. Pearson, Kendall and Spearman correlation are currently computed using pairwise + complete observations. + + 2. The complexity of Kendall correlation is O(#row * #row), if the dataset is too + large, sampling ahead of correlation computation is recommended. + Examples -------- >>> df = ps.DataFrame([(.2, .3), (.0, .6), (.6, .0), (.2, .1)], @@ -1455,16 +1465,13 @@ class DataFrame(Frame, Generic[T]): dogs 1.000000 -0.948683 cats -0.948683 1.000000 - Notes - ----- - There are behavior differences between pandas-on-Spark and pandas. 
- - * the `method` argument only accepts 'pearson', 'spearman' + >>> df.corr('kendall') + dogs cats + dogs 1.000000 -0.912871 + cats -0.912871 1.000000 """ if method not in ["pearson", "spearman", "kendall"]: raise ValueError(f"Invalid method {method}") - if method == "kendall": - raise NotImplementedError("method doesn't support kendall for now") if min_periods is not None and not isinstance(min_periods, int): raise TypeError(f"Invalid min_periods type {type(min_periods).__name__}") @@ -1537,87 +1544,196 @@ class DataFrame(Frame, Generic[T]): .otherwise(F.col(f"{tmp_tuple_col_name}.{tmp_value_2_col_name}")) .alias(tmp_value_2_col_name), ) + not_null_cond = ( + F.col(tmp_value_1_col_name).isNotNull() & F.col(tmp_value_2_col_name).isNotNull() + ) - # convert values to avg ranks for spearman correlation - if method == "spearman": - tmp_row_number_col_name = verify_temp_column_name(sdf, "__tmp_row_number_col__") - tmp_dense_rank_col_name = verify_temp_column_name(sdf, "__tmp_dense_rank_col__") - window = Window.partitionBy(tmp_index_1_col_name, tmp_index_2_col_name) - - # tmp_value_1_col_name: value -> avg rank - # for example: - # values: 3, 4, 5, 7, 7, 7, 9, 9, 10 - # avg ranks: 1.0, 2.0, 3.0, 5.0, 5.0, 5.0, 7.5, 7.5, 9.0 - sdf = ( - sdf.withColumn( - tmp_row_number_col_name, - F.row_number().over(window.orderBy(F.asc_nulls_last(tmp_value_1_col_name))), + tmp_count_col_name = verify_temp_column_name(sdf, "__tmp_count_col__") + tmp_corr_col_name = verify_temp_column_name(sdf, "__tmp_corr_col__") + if method in ["pearson", "spearman"]: + # convert values to avg ranks for spearman correlation + if method == "spearman": + tmp_row_number_col_name = verify_temp_column_name(sdf, "__tmp_row_number_col__") + tmp_dense_rank_col_name = verify_temp_column_name(sdf, "__tmp_dense_rank_col__") + window = Window.partitionBy(tmp_index_1_col_name, tmp_index_2_col_name) + + # tmp_value_1_col_name: value -> avg rank + # for example: + # values: 3, 4, 5, 7, 7, 7, 9, 9, 10 + # avg ranks: 
1.0, 2.0, 3.0, 5.0, 5.0, 5.0, 7.5, 7.5, 9.0 + sdf = ( + sdf.withColumn( + tmp_row_number_col_name, + F.row_number().over(window.orderBy(F.asc_nulls_last(tmp_value_1_col_name))), + ) + .withColumn( + tmp_dense_rank_col_name, + F.dense_rank().over(window.orderBy(F.asc_nulls_last(tmp_value_1_col_name))), + ) + .withColumn( + tmp_value_1_col_name, + F.when(F.isnull(F.col(tmp_value_1_col_name)), F.lit(None)).otherwise( + F.avg(tmp_row_number_col_name).over( + window.orderBy(F.asc(tmp_dense_rank_col_name)).rangeBetween(0, 0) + ) + ), + ) ) - .withColumn( - tmp_dense_rank_col_name, - F.dense_rank().over(window.orderBy(F.asc_nulls_last(tmp_value_1_col_name))), + + # tmp_value_2_col_name: value -> avg rank + sdf = ( + sdf.withColumn( + tmp_row_number_col_name, + F.row_number().over(window.orderBy(F.asc_nulls_last(tmp_value_2_col_name))), + ) + .withColumn( + tmp_dense_rank_col_name, + F.dense_rank().over(window.orderBy(F.asc_nulls_last(tmp_value_2_col_name))), + ) + .withColumn( + tmp_value_2_col_name, + F.when(F.isnull(F.col(tmp_value_2_col_name)), F.lit(None)).otherwise( + F.avg(tmp_row_number_col_name).over( + window.orderBy(F.asc(tmp_dense_rank_col_name)).rangeBetween(0, 0) + ) + ), + ) ) - .withColumn( + + sdf = sdf.select( + tmp_index_1_col_name, + tmp_index_2_col_name, tmp_value_1_col_name, - F.when(F.isnull(F.col(tmp_value_1_col_name)), F.lit(None)).otherwise( - F.avg(tmp_row_number_col_name).over( - window.orderBy(F.asc(tmp_dense_rank_col_name)).rangeBetween(0, 0) - ) - ), + tmp_value_2_col_name, ) + + # +-------------------+-------------------+----------------+-----------------+ + # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_corr_col__|__tmp_count_col__| + # +-------------------+-------------------+----------------+-----------------+ + # | 2| 2| null| 1| + # | 1| 2| null| 1| + # | 1| 1| 1.0| 2| + # | 0| 0| 1.0| 2| + # | 0| 1| -1.0| 2| + # | 0| 2| null| 1| + # +-------------------+-------------------+----------------+-----------------+ + sdf = 
sdf.groupby(tmp_index_1_col_name, tmp_index_2_col_name).agg( + F.corr(tmp_value_1_col_name, tmp_value_2_col_name).alias(tmp_corr_col_name), + F.count( + F.when( + not_null_cond, + 1, + ) + ).alias(tmp_count_col_name), + ) + + else: + # kendall correlation + tmp_row_number_12_col_name = verify_temp_column_name(sdf, "__tmp_row_number_12_col__") + sdf = sdf.withColumn( + tmp_row_number_12_col_name, + F.row_number().over( + Window.partitionBy(tmp_index_1_col_name, tmp_index_2_col_name).orderBy( + F.asc_nulls_last(tmp_value_1_col_name), + F.asc_nulls_last(tmp_value_2_col_name), + ) + ), + ) + + # drop nulls but make sure each partition contains at least one row + sdf = sdf.where(not_null_cond | (F.col(tmp_row_number_12_col_name) == 1)) + + tmp_value_x_col_name = verify_temp_column_name(sdf, "__tmp_value_x_col__") + tmp_value_y_col_name = verify_temp_column_name(sdf, "__tmp_value_y_col__") + tmp_row_number_xy_col_name = verify_temp_column_name(sdf, "__tmp_row_number_xy_col__") + sdf2 = sdf.select( + F.col(tmp_index_1_col_name), + F.col(tmp_index_2_col_name), + F.col(tmp_value_1_col_name).alias(tmp_value_x_col_name), + F.col(tmp_value_2_col_name).alias(tmp_value_y_col_name), + F.col(tmp_row_number_12_col_name).alias(tmp_row_number_xy_col_name), + ) + + sdf = sdf.join(sdf2, [tmp_index_1_col_name, tmp_index_2_col_name], "inner").where( + F.col(tmp_row_number_12_col_name) <= F.col(tmp_row_number_xy_col_name) + ) + + # compute P, Q, T, U in tau_b = (P - Q) / sqrt((P + Q + T) * (P + Q + U)) + # see https://github.com/scipy/scipy/blob/v1.9.1/scipy/stats/_stats_py.py#L5015-L5222 + tmp_tau_b_p_col_name = verify_temp_column_name(sdf, "__tmp_tau_b_p_col__") + tmp_tau_b_q_col_name = verify_temp_column_name(sdf, "__tmp_tau_b_q_col__") + tmp_tau_b_t_col_name = verify_temp_column_name(sdf, "__tmp_tau_b_t_col__") + tmp_tau_b_u_col_name = verify_temp_column_name(sdf, "__tmp_tau_b_u_col__") + + pair_cond = not_null_cond & ( + F.col(tmp_row_number_12_col_name) < 
F.col(tmp_row_number_xy_col_name) + ) + + p_cond = ( + (F.col(tmp_value_1_col_name) < F.col(tmp_value_x_col_name)) + & (F.col(tmp_value_2_col_name) < F.col(tmp_value_y_col_name)) + ) | ( + (F.col(tmp_value_1_col_name) > F.col(tmp_value_x_col_name)) + & (F.col(tmp_value_2_col_name) > F.col(tmp_value_y_col_name)) + ) + q_cond = ( + (F.col(tmp_value_1_col_name) < F.col(tmp_value_x_col_name)) + & (F.col(tmp_value_2_col_name) > F.col(tmp_value_y_col_name)) + ) | ( + (F.col(tmp_value_1_col_name) > F.col(tmp_value_x_col_name)) + & (F.col(tmp_value_2_col_name) < F.col(tmp_value_y_col_name)) + ) + t_cond = (F.col(tmp_value_1_col_name) == F.col(tmp_value_x_col_name)) & ( + F.col(tmp_value_2_col_name) != F.col(tmp_value_y_col_name) + ) + u_cond = (F.col(tmp_value_1_col_name) != F.col(tmp_value_x_col_name)) & ( + F.col(tmp_value_2_col_name) == F.col(tmp_value_y_col_name) ) - # tmp_value_2_col_name: value -> avg rank sdf = ( - sdf.withColumn( - tmp_row_number_col_name, - F.row_number().over(window.orderBy(F.asc_nulls_last(tmp_value_2_col_name))), - ) - .withColumn( - tmp_dense_rank_col_name, - F.dense_rank().over(window.orderBy(F.asc_nulls_last(tmp_value_2_col_name))), + sdf.groupby(tmp_index_1_col_name, tmp_index_2_col_name) + .agg( + F.count(F.when(pair_cond & p_cond, 1)).alias(tmp_tau_b_p_col_name), + F.count(F.when(pair_cond & q_cond, 1)).alias(tmp_tau_b_q_col_name), + F.count(F.when(pair_cond & t_cond, 1)).alias(tmp_tau_b_t_col_name), + F.count(F.when(pair_cond & u_cond, 1)).alias(tmp_tau_b_u_col_name), + F.max( + F.when(not_null_cond, F.col(tmp_row_number_xy_col_name)).otherwise(F.lit(0)) + ).alias(tmp_count_col_name), ) .withColumn( - tmp_value_2_col_name, - F.when(F.isnull(F.col(tmp_value_2_col_name)), F.lit(None)).otherwise( - F.avg(tmp_row_number_col_name).over( - window.orderBy(F.asc(tmp_dense_rank_col_name)).rangeBetween(0, 0) + tmp_corr_col_name, + F.when( + F.col(tmp_index_1_col_name) == F.col(tmp_index_2_col_name), F.lit(1.0) + ).otherwise( + 
(F.col(tmp_tau_b_p_col_name) - F.col(tmp_tau_b_q_col_name)) + / F.sqrt( + ( + ( + F.col(tmp_tau_b_p_col_name) + + F.col(tmp_tau_b_q_col_name) + + (F.col(tmp_tau_b_t_col_name)) + ) + ) + * ( + ( + F.col(tmp_tau_b_p_col_name) + + F.col(tmp_tau_b_q_col_name) + + (F.col(tmp_tau_b_u_col_name)) + ) + ) ) ), ) ) sdf = sdf.select( - tmp_index_1_col_name, - tmp_index_2_col_name, - tmp_value_1_col_name, - tmp_value_2_col_name, + F.col(tmp_index_1_col_name), + F.col(tmp_index_2_col_name), + F.col(tmp_corr_col_name), + F.col(tmp_count_col_name), ) - # +-------------------+-------------------+----------------+-----------------+ - # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_corr_col__|__tmp_count_col__| - # +-------------------+-------------------+----------------+-----------------+ - # | 2| 2| null| 1| - # | 1| 2| null| 1| - # | 1| 1| 1.0| 2| - # | 0| 0| 1.0| 2| - # | 0| 1| -1.0| 2| - # | 0| 2| null| 1| - # +-------------------+-------------------+----------------+-----------------+ - tmp_corr_col_name = verify_temp_column_name(sdf, "__tmp_corr_col__") - tmp_count_col_name = verify_temp_column_name(sdf, "__tmp_count_col__") - - sdf = sdf.groupby(tmp_index_1_col_name, tmp_index_2_col_name).agg( - F.corr(tmp_value_1_col_name, tmp_value_2_col_name).alias(tmp_corr_col_name), - F.count( - F.when( - F.col(tmp_value_1_col_name).isNotNull() - & F.col(tmp_value_2_col_name).isNotNull(), - 1, - ) - ).alias(tmp_count_col_name), - ) - # +-------------------+-------------------+----------------+ # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_corr_col__| # +-------------------+-------------------+----------------+ diff --git a/python/pyspark/pandas/tests/test_stats.py b/python/pyspark/pandas/tests/test_stats.py index fbe16146ff2..c6ca71696b6 100644 --- a/python/pyspark/pandas/tests/test_stats.py +++ b/python/pyspark/pandas/tests/test_stats.py @@ -265,12 +265,10 @@ class StatsTest(PandasOnSparkTestCase, SQLTestUtils): with self.assertRaisesRegex(ValueError, "Invalid method"): 
psdf.corr("std") - with self.assertRaisesRegex(NotImplementedError, "kendall for now"): - psdf.corr("kendall") with self.assertRaisesRegex(TypeError, "Invalid min_periods type"): psdf.corr(min_periods="3") - for method in ["pearson", "spearman"]: + for method in ["pearson", "spearman", "kendall"]: self.assert_eq(psdf.corr(method=method), pdf.corr(method=method), check_exact=False) self.assert_eq( psdf.corr(method=method, min_periods=1), @@ -293,7 +291,7 @@ class StatsTest(PandasOnSparkTestCase, SQLTestUtils): pdf.columns = columns psdf.columns = columns - for method in ["pearson", "spearman"]: + for method in ["pearson", "spearman", "kendall"]: self.assert_eq(psdf.corr(method=method), pdf.corr(method=method), check_exact=False) self.assert_eq( psdf.corr(method=method, min_periods=1), @@ -311,7 +309,7 @@ class StatsTest(PandasOnSparkTestCase, SQLTestUtils): check_exact=False, ) - # test spearman with identical values + # test with identical values pdf = pd.DataFrame( { "a": [0, 1, 1, 1, 0], @@ -321,17 +319,19 @@ class StatsTest(PandasOnSparkTestCase, SQLTestUtils): } ) psdf = ps.from_pandas(pdf) - self.assert_eq(psdf.corr(method="spearman"), pdf.corr(method="spearman"), check_exact=False) - self.assert_eq( - psdf.corr(method="spearman", min_periods=1), - pdf.corr(method="spearman", min_periods=1), - check_exact=False, - ) - self.assert_eq( - psdf.corr(method="spearman", min_periods=3), - pdf.corr(method="spearman", min_periods=3), - check_exact=False, - ) + + for method in ["pearson", "spearman", "kendall"]: + self.assert_eq(psdf.corr(method=method), pdf.corr(method=method), check_exact=False) + self.assert_eq( + psdf.corr(method=method, min_periods=1), + pdf.corr(method=method, min_periods=1), + check_exact=False, + ) + self.assert_eq( + psdf.corr(method=method, min_periods=3), + pdf.corr(method=method, min_periods=3), + check_exact=False, + ) def test_corr(self): # Disable arrow execution since corr() is using UDT internally which is not supported. 
--------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org