This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a3c837ae2ea [SPARK-41068][CONNECT][PYTHON] Implement 
`DataFrame.stat.corr`
a3c837ae2ea is described below

commit a3c837ae2eaf2c7ba08563b7afa0f96df8a4e80b
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Fri Dec 30 13:09:55 2022 +0800

    [SPARK-41068][CONNECT][PYTHON] Implement `DataFrame.stat.corr`
    
    ### What changes were proposed in this pull request?
    Implement `DataFrame.stat.corr` with a proto message
    
    Implement `DataFrame.stat.corr` for scala API
    Implement `DataFrame.stat.corr` for python API
    
    ### Why are the changes needed?
    To improve Spark Connect API coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    No. It only adds a new API.
    
    ### How was this patch tested?
    New test cases.
    
    Closes #39236 from beliefer/SPARK-41068.
    
    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    |  20 +++
 .../org/apache/spark/sql/connect/dsl/package.scala |  16 ++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  14 ++
 python/pyspark/sql/connect/dataframe.py            |  27 +++
 python/pyspark/sql/connect/plan.py                 |  18 ++
 python/pyspark/sql/connect/proto/relations_pb2.py  | 194 +++++++++++----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  68 ++++++++
 .../sql/tests/connect/test_connect_basic.py        |  24 +++
 8 files changed, 291 insertions(+), 90 deletions(-)

diff --git 
a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto 
b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 2d0837b4924..8a604f0702c 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -70,6 +70,7 @@ message Relation {
     StatCrosstab crosstab = 101;
     StatDescribe describe = 102;
     StatCov cov = 103;
+    StatCorr corr = 104;
 
     // Catalog API (experimental / unstable)
     Catalog catalog = 200;
@@ -481,6 +482,25 @@ message StatCov {
   string col2 = 3;
 }
 
+// Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson
+// Correlation Coefficient. It will invoke 'Dataset.stat.corr' (same as
+// 'StatFunctions.pearsonCorrelation') to compute the results.
+//
+// The server returns the coefficient as a single-row relation with one non-nullable double
+// column named 'corr'.
+message StatCorr {
+  // (Required) The input relation.
+  Relation input = 1;
+
+  // (Required) The name of the first column.
+  string col1 = 2;
+
+  // (Required) The name of the second column.
+  string col2 = 3;
+
+  // (Optional) The correlation method. Default value is 'pearson'.
+  //
+  // Currently only supports the Pearson Correlation Coefficient. When this field is unset, the
+  // server calls the two-argument form of 'Dataset.stat.corr'.
+  optional string method = 4;
+}
+
 // Replaces null values.
 // It will invoke 'Dataset.na.fill' (same as 'DataFrameNaFunctions.fill') to 
compute the results.
 // Following 3 parameter combinations are supported:
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 9e3346d9364..3bd713a9710 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -387,6 +387,22 @@ package object dsl {
           .build()
       }
 
+      /**
+       * Builds a `StatCorr` relation over this logical plan for the given pair of columns,
+       * carrying the requested correlation method.
+       */
+      def corr(col1: String, col2: String, method: String): Relation = {
+        // Assemble the StatCorr message first, then wrap it in the enclosing Relation.
+        val statCorr = proto.StatCorr
+          .newBuilder()
+          .setInput(logicalPlan)
+          .setCol1(col1)
+          .setCol2(col2)
+          .setMethod(method)
+          .build()
+        Relation.newBuilder().setCorr(statCorr).build()
+      }
+
+      /** Same as the three-argument `corr`, with the method fixed to "pearson". */
+      def corr(col1: String, col2: String): Relation = {
+        val defaultMethod = "pearson"
+        corr(col1, col2, defaultMethod)
+      }
+
       def crosstab(col1: String, col2: String): Relation = {
         Relation
           .newBuilder()
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index d06787e6b14..bb582e92755 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -89,6 +89,7 @@ class SparkConnectPlanner(session: SparkSession) {
       case proto.Relation.RelTypeCase.SUMMARY => 
transformStatSummary(rel.getSummary)
       case proto.Relation.RelTypeCase.DESCRIBE => 
transformStatDescribe(rel.getDescribe)
       case proto.Relation.RelTypeCase.COV => transformStatCov(rel.getCov)
+      case proto.Relation.RelTypeCase.CORR => transformStatCorr(rel.getCorr)
       case proto.Relation.RelTypeCase.CROSSTAB =>
         transformStatCrosstab(rel.getCrosstab)
       case proto.Relation.RelTypeCase.TO_SCHEMA => 
transformToSchema(rel.getToSchema)
@@ -352,6 +353,19 @@ class SparkConnectPlanner(session: SparkSession) {
       data = Tuple1.apply(cov) :: Nil)
   }
 
+  /**
+   * Evaluates `Dataset.stat.corr` on the input relation and wraps the resulting value in a
+   * single-row local relation whose only column is a non-nullable double named "corr".
+   */
+  private def transformStatCorr(rel: proto.StatCorr): LogicalPlan = {
+    val input = Dataset.ofRows(session, transformRelation(rel.getInput))
+    // Only forward the method when the client actually set it; otherwise use the
+    // two-argument overload.
+    val coefficient =
+      if (!rel.hasMethod) {
+        input.stat.corr(rel.getCol1, rel.getCol2)
+      } else {
+        input.stat.corr(rel.getCol1, rel.getCol2, rel.getMethod)
+      }
+
+    LocalRelation.fromProduct(
+      output = AttributeReference("corr", DoubleType, nullable = false)() :: Nil,
+      data = Tuple1.apply(coefficient) :: Nil)
+  }
+
   private def transformStatCrosstab(rel: proto.StatCrosstab): LogicalPlan = {
     Dataset
       .ofRows(session, transformRelation(rel.getInput))
diff --git a/python/pyspark/sql/connect/dataframe.py 
b/python/pyspark/sql/connect/dataframe.py
index 08db6b61871..5b5a6c3f4b5 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -856,6 +856,28 @@ class DataFrame:
 
     cov.__doc__ = PySparkDataFrame.cov.__doc__
 
+    def corr(self, col1: str, col2: str, method: Optional[str] = None) -> 
float:
+        if not isinstance(col1, str):
+            raise TypeError("col1 should be a string.")
+        if not isinstance(col2, str):
+            raise TypeError("col2 should be a string.")
+        if not method:
+            method = "pearson"
+        if not method == "pearson":
+            raise ValueError(
+                "Currently only the calculation of the Pearson Correlation "
+                + "coefficient is supported."
+            )
+        pdf = DataFrame.withPlan(
+            plan.StatCorr(child=self._plan, col1=col1, col2=col2, 
method=method),
+            session=self._session,
+        ).toPandas()
+
+        assert pdf is not None
+        return pdf["corr"][0]
+
+    corr.__doc__ = PySparkDataFrame.corr.__doc__
+
     def crosstab(self, col1: str, col2: str) -> "DataFrame":
         if not isinstance(col1, str):
             raise TypeError(f"'col1' must be str, but got 
{type(col1).__name__}")
@@ -1216,6 +1238,11 @@ class DataFrameStatFunctions:
 
     cov.__doc__ = DataFrame.cov.__doc__
 
+    def corr(self, col1: str, col2: str, method: Optional[str] = None) -> 
float:
+        return self.df.corr(col1, col2, method)
+
+    corr.__doc__ = DataFrame.corr.__doc__
+
     def crosstab(self, col1: str, col2: str) -> DataFrame:
         return self.df.crosstab(col1, col2)
 
diff --git a/python/pyspark/sql/connect/plan.py 
b/python/pyspark/sql/connect/plan.py
index 616e4ce283b..e1b9fa0d0e4 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1095,6 +1095,24 @@ class StatCrosstab(LogicalPlan):
         return plan
 
 
+class StatCorr(LogicalPlan):
+    def __init__(self, child: Optional["LogicalPlan"], col1: str, col2: str, 
method: str) -> None:
+        super().__init__(child)
+        self._col1 = col1
+        self._col2 = col2
+        self._method = method
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        assert self._child is not None
+
+        plan = proto.Relation()
+        plan.corr.input.CopyFrom(self._child.plan(session))
+        plan.corr.col1 = self._col1
+        plan.corr.col2 = self._col2
+        plan.corr.method = self._method
+        return plan
+
+
 class RenameColumns(LogicalPlan):
     def __init__(self, child: Optional["LogicalPlan"], cols: Sequence[str]) -> 
None:
         super().__init__(child)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py 
b/python/pyspark/sql/connect/proto/relations_pb2.py
index 1fbb284ec37..7c938831882 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -35,7 +35,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as 
spark_dot_connect_dot_catal
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xbe\x10\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
 
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
 
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04
 \x01(\x0b\x32\x15.spa [...]
+    
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xed\x10\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
 
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
 
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04
 \x01(\x0b\x32\x15.spa [...]
 )
 
 
@@ -69,6 +69,7 @@ _STATSUMMARY = DESCRIPTOR.message_types_by_name["StatSummary"]
 _STATDESCRIBE = DESCRIPTOR.message_types_by_name["StatDescribe"]
 _STATCROSSTAB = DESCRIPTOR.message_types_by_name["StatCrosstab"]
 _STATCOV = DESCRIPTOR.message_types_by_name["StatCov"]
+_STATCORR = DESCRIPTOR.message_types_by_name["StatCorr"]
 _NAFILL = DESCRIPTOR.message_types_by_name["NAFill"]
 _NADROP = DESCRIPTOR.message_types_by_name["NADrop"]
 _NAREPLACE = DESCRIPTOR.message_types_by_name["NAReplace"]
@@ -412,6 +413,17 @@ StatCov = _reflection.GeneratedProtocolMessageType(
 )
 _sym_db.RegisterMessage(StatCov)
 
+StatCorr = _reflection.GeneratedProtocolMessageType(
+    "StatCorr",
+    (_message.Message,),
+    {
+        "DESCRIPTOR": _STATCORR,
+        "__module__": "spark.connect.relations_pb2"
+        # @@protoc_insertion_point(class_scope:spark.connect.StatCorr)
+    },
+)
+_sym_db.RegisterMessage(StatCorr)
+
 NAFill = _reflection.GeneratedProtocolMessageType(
     "NAFill",
     (_message.Message,),
@@ -551,93 +563,95 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._options = None
     _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_options = 
b"8\001"
     _RELATION._serialized_start = 138
-    _RELATION._serialized_end = 2248
-    _UNKNOWN._serialized_start = 2250
-    _UNKNOWN._serialized_end = 2259
-    _RELATIONCOMMON._serialized_start = 2261
-    _RELATIONCOMMON._serialized_end = 2310
-    _SQL._serialized_start = 2312
-    _SQL._serialized_end = 2339
-    _READ._serialized_start = 2342
-    _READ._serialized_end = 2768
-    _READ_NAMEDTABLE._serialized_start = 2484
-    _READ_NAMEDTABLE._serialized_end = 2545
-    _READ_DATASOURCE._serialized_start = 2548
-    _READ_DATASOURCE._serialized_end = 2755
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2686
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2744
-    _PROJECT._serialized_start = 2770
-    _PROJECT._serialized_end = 2887
-    _FILTER._serialized_start = 2889
-    _FILTER._serialized_end = 3001
-    _JOIN._serialized_start = 3004
-    _JOIN._serialized_end = 3475
-    _JOIN_JOINTYPE._serialized_start = 3267
-    _JOIN_JOINTYPE._serialized_end = 3475
-    _SETOPERATION._serialized_start = 3478
-    _SETOPERATION._serialized_end = 3874
-    _SETOPERATION_SETOPTYPE._serialized_start = 3737
-    _SETOPERATION_SETOPTYPE._serialized_end = 3851
-    _LIMIT._serialized_start = 3876
-    _LIMIT._serialized_end = 3952
-    _OFFSET._serialized_start = 3954
-    _OFFSET._serialized_end = 4033
-    _TAIL._serialized_start = 4035
-    _TAIL._serialized_end = 4110
-    _AGGREGATE._serialized_start = 4113
-    _AGGREGATE._serialized_end = 4695
-    _AGGREGATE_PIVOT._serialized_start = 4452
-    _AGGREGATE_PIVOT._serialized_end = 4563
-    _AGGREGATE_GROUPTYPE._serialized_start = 4566
-    _AGGREGATE_GROUPTYPE._serialized_end = 4695
-    _SORT._serialized_start = 4698
-    _SORT._serialized_end = 4858
-    _DROP._serialized_start = 4860
-    _DROP._serialized_end = 4960
-    _DEDUPLICATE._serialized_start = 4963
-    _DEDUPLICATE._serialized_end = 5134
-    _LOCALRELATION._serialized_start = 5137
-    _LOCALRELATION._serialized_end = 5274
-    _SAMPLE._serialized_start = 5277
-    _SAMPLE._serialized_end = 5572
-    _RANGE._serialized_start = 5575
-    _RANGE._serialized_end = 5720
-    _SUBQUERYALIAS._serialized_start = 5722
-    _SUBQUERYALIAS._serialized_end = 5836
-    _REPARTITION._serialized_start = 5839
-    _REPARTITION._serialized_end = 5981
-    _SHOWSTRING._serialized_start = 5984
-    _SHOWSTRING._serialized_end = 6126
-    _STATSUMMARY._serialized_start = 6128
-    _STATSUMMARY._serialized_end = 6220
-    _STATDESCRIBE._serialized_start = 6222
-    _STATDESCRIBE._serialized_end = 6303
-    _STATCROSSTAB._serialized_start = 6305
-    _STATCROSSTAB._serialized_end = 6406
-    _STATCOV._serialized_start = 6408
-    _STATCOV._serialized_end = 6504
-    _NAFILL._serialized_start = 6507
-    _NAFILL._serialized_end = 6641
-    _NADROP._serialized_start = 6644
-    _NADROP._serialized_end = 6778
-    _NAREPLACE._serialized_start = 6781
-    _NAREPLACE._serialized_end = 7077
-    _NAREPLACE_REPLACEMENT._serialized_start = 6936
-    _NAREPLACE_REPLACEMENT._serialized_end = 7077
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 7079
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 7193
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 7196
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 7455
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 
7388
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 7455
-    _WITHCOLUMNS._serialized_start = 7458
-    _WITHCOLUMNS._serialized_end = 7589
-    _HINT._serialized_start = 7592
-    _HINT._serialized_end = 7732
-    _UNPIVOT._serialized_start = 7735
-    _UNPIVOT._serialized_end = 7981
-    _TOSCHEMA._serialized_start = 7983
-    _TOSCHEMA._serialized_end = 8089
-    _REPARTITIONBYEXPRESSION._serialized_start = 8092
-    _REPARTITIONBYEXPRESSION._serialized_end = 8295
+    _RELATION._serialized_end = 2295
+    _UNKNOWN._serialized_start = 2297
+    _UNKNOWN._serialized_end = 2306
+    _RELATIONCOMMON._serialized_start = 2308
+    _RELATIONCOMMON._serialized_end = 2357
+    _SQL._serialized_start = 2359
+    _SQL._serialized_end = 2386
+    _READ._serialized_start = 2389
+    _READ._serialized_end = 2815
+    _READ_NAMEDTABLE._serialized_start = 2531
+    _READ_NAMEDTABLE._serialized_end = 2592
+    _READ_DATASOURCE._serialized_start = 2595
+    _READ_DATASOURCE._serialized_end = 2802
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2733
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2791
+    _PROJECT._serialized_start = 2817
+    _PROJECT._serialized_end = 2934
+    _FILTER._serialized_start = 2936
+    _FILTER._serialized_end = 3048
+    _JOIN._serialized_start = 3051
+    _JOIN._serialized_end = 3522
+    _JOIN_JOINTYPE._serialized_start = 3314
+    _JOIN_JOINTYPE._serialized_end = 3522
+    _SETOPERATION._serialized_start = 3525
+    _SETOPERATION._serialized_end = 3921
+    _SETOPERATION_SETOPTYPE._serialized_start = 3784
+    _SETOPERATION_SETOPTYPE._serialized_end = 3898
+    _LIMIT._serialized_start = 3923
+    _LIMIT._serialized_end = 3999
+    _OFFSET._serialized_start = 4001
+    _OFFSET._serialized_end = 4080
+    _TAIL._serialized_start = 4082
+    _TAIL._serialized_end = 4157
+    _AGGREGATE._serialized_start = 4160
+    _AGGREGATE._serialized_end = 4742
+    _AGGREGATE_PIVOT._serialized_start = 4499
+    _AGGREGATE_PIVOT._serialized_end = 4610
+    _AGGREGATE_GROUPTYPE._serialized_start = 4613
+    _AGGREGATE_GROUPTYPE._serialized_end = 4742
+    _SORT._serialized_start = 4745
+    _SORT._serialized_end = 4905
+    _DROP._serialized_start = 4907
+    _DROP._serialized_end = 5007
+    _DEDUPLICATE._serialized_start = 5010
+    _DEDUPLICATE._serialized_end = 5181
+    _LOCALRELATION._serialized_start = 5184
+    _LOCALRELATION._serialized_end = 5321
+    _SAMPLE._serialized_start = 5324
+    _SAMPLE._serialized_end = 5619
+    _RANGE._serialized_start = 5622
+    _RANGE._serialized_end = 5767
+    _SUBQUERYALIAS._serialized_start = 5769
+    _SUBQUERYALIAS._serialized_end = 5883
+    _REPARTITION._serialized_start = 5886
+    _REPARTITION._serialized_end = 6028
+    _SHOWSTRING._serialized_start = 6031
+    _SHOWSTRING._serialized_end = 6173
+    _STATSUMMARY._serialized_start = 6175
+    _STATSUMMARY._serialized_end = 6267
+    _STATDESCRIBE._serialized_start = 6269
+    _STATDESCRIBE._serialized_end = 6350
+    _STATCROSSTAB._serialized_start = 6352
+    _STATCROSSTAB._serialized_end = 6453
+    _STATCOV._serialized_start = 6455
+    _STATCOV._serialized_end = 6551
+    _STATCORR._serialized_start = 6554
+    _STATCORR._serialized_end = 6691
+    _NAFILL._serialized_start = 6694
+    _NAFILL._serialized_end = 6828
+    _NADROP._serialized_start = 6831
+    _NADROP._serialized_end = 6965
+    _NAREPLACE._serialized_start = 6968
+    _NAREPLACE._serialized_end = 7264
+    _NAREPLACE_REPLACEMENT._serialized_start = 7123
+    _NAREPLACE_REPLACEMENT._serialized_end = 7264
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 7266
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 7380
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 7383
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 7642
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 
7575
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 7642
+    _WITHCOLUMNS._serialized_start = 7645
+    _WITHCOLUMNS._serialized_end = 7776
+    _HINT._serialized_start = 7779
+    _HINT._serialized_end = 7919
+    _UNPIVOT._serialized_start = 7922
+    _UNPIVOT._serialized_end = 8168
+    _TOSCHEMA._serialized_start = 8170
+    _TOSCHEMA._serialized_end = 8276
+    _REPARTITIONBYEXPRESSION._serialized_start = 8279
+    _REPARTITIONBYEXPRESSION._serialized_end = 8482
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi 
b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index e1a37abbd6c..63ccfa18559 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -95,6 +95,7 @@ class Relation(google.protobuf.message.Message):
     CROSSTAB_FIELD_NUMBER: builtins.int
     DESCRIBE_FIELD_NUMBER: builtins.int
     COV_FIELD_NUMBER: builtins.int
+    CORR_FIELD_NUMBER: builtins.int
     CATALOG_FIELD_NUMBER: builtins.int
     UNKNOWN_FIELD_NUMBER: builtins.int
     @property
@@ -168,6 +169,8 @@ class Relation(google.protobuf.message.Message):
     @property
     def cov(self) -> global___StatCov: ...
     @property
+    def corr(self) -> global___StatCorr: ...
+    @property
     def catalog(self) -> pyspark.sql.connect.proto.catalog_pb2.Catalog:
         """Catalog API (experimental / unstable)"""
     @property
@@ -209,6 +212,7 @@ class Relation(google.protobuf.message.Message):
         crosstab: global___StatCrosstab | None = ...,
         describe: global___StatDescribe | None = ...,
         cov: global___StatCov | None = ...,
+        corr: global___StatCorr | None = ...,
         catalog: pyspark.sql.connect.proto.catalog_pb2.Catalog | None = ...,
         unknown: global___Unknown | None = ...,
     ) -> None: ...
@@ -221,6 +225,8 @@ class Relation(google.protobuf.message.Message):
             b"catalog",
             "common",
             b"common",
+            "corr",
+            b"corr",
             "cov",
             b"cov",
             "crosstab",
@@ -300,6 +306,8 @@ class Relation(google.protobuf.message.Message):
             b"catalog",
             "common",
             b"common",
+            "corr",
+            b"corr",
             "cov",
             b"cov",
             "crosstab",
@@ -406,6 +414,7 @@ class Relation(google.protobuf.message.Message):
         "crosstab",
         "describe",
         "cov",
+        "corr",
         "catalog",
         "unknown",
     ] | None: ...
@@ -1710,6 +1719,65 @@ class StatCov(google.protobuf.message.Message):
 
 global___StatCov = StatCov
 
+class StatCorr(google.protobuf.message.Message):
+    """Calculates the correlation of two columns of a DataFrame. Currently 
only supports the Pearson
+    Correlation Coefficient. It will invoke 'Dataset.stat.corr' (same as
+    'StatFunctions.pearsonCorrelation') to compute the results.
+    """
+
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    INPUT_FIELD_NUMBER: builtins.int
+    COL1_FIELD_NUMBER: builtins.int
+    COL2_FIELD_NUMBER: builtins.int
+    METHOD_FIELD_NUMBER: builtins.int
+    @property
+    def input(self) -> global___Relation:
+        """(Required) The input relation."""
+    col1: builtins.str
+    """(Required) The name of the first column."""
+    col2: builtins.str
+    """(Required) The name of the second column."""
+    method: builtins.str
+    """(Optional) Default value is 'pearson'.
+
+    Currently only supports the Pearson Correlation Coefficient.
+    """
+    def __init__(
+        self,
+        *,
+        input: global___Relation | None = ...,
+        col1: builtins.str = ...,
+        col2: builtins.str = ...,
+        method: builtins.str | None = ...,
+    ) -> None: ...
+    def HasField(
+        self,
+        field_name: typing_extensions.Literal[
+            "_method", b"_method", "input", b"input", "method", b"method"
+        ],
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "_method",
+            b"_method",
+            "col1",
+            b"col1",
+            "col2",
+            b"col2",
+            "input",
+            b"input",
+            "method",
+            b"method",
+        ],
+    ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_method", b"_method"]
+    ) -> typing_extensions.Literal["method"] | None: ...
+
+global___StatCorr = StatCorr
+
 class NAFill(google.protobuf.message.Message):
     """Replaces null values.
     It will invoke 'Dataset.na.fill' (same as 'DataFrameNaFunctions.fill') to 
compute the results.
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py 
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 84c1baea80d..99ee54a87fa 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1013,6 +1013,30 @@ class SparkConnectTests(SparkConnectSQLTestCase):
             self.spark.read.table(self.tbl_name2).stat.cov("col1", "col3"),
         )
 
+    def test_stat_corr(self):
+        # SPARK-41068: Test the stat.corr method
+        self.assertEqual(
+            self.connect.read.table(self.tbl_name2).stat.corr("col1", "col3"),
+            self.spark.read.table(self.tbl_name2).stat.corr("col1", "col3"),
+        )
+
+        self.assertEqual(
+            self.connect.read.table(self.tbl_name2).stat.corr("col1", "col3", 
"pearson"),
+            self.spark.read.table(self.tbl_name2).stat.corr("col1", "col3", 
"pearson"),
+        )
+
+        with self.assertRaisesRegex(TypeError, "col1 should be a string."):
+            self.connect.read.table(self.tbl_name2).stat.corr(1, "col3", 
"pearson")
+        with self.assertRaisesRegex(TypeError, "col2 should be a string."):
+            self.connect.read.table(self.tbl_name).stat.corr("col1", 1, 
"pearson")
+        with self.assertRaises(ValueError) as context:
+            self.connect.read.table(self.tbl_name2).stat.corr("col1", "col3", 
"spearman"),
+            self.assertTrue(
+                "Currently only the calculation of the Pearson Correlation "
+                + "coefficient is supported."
+                in str(context.exception)
+            )
+
     def test_repr(self):
         # SPARK-41213: Test the __repr__ method
         query = """SELECT * FROM VALUES (1L, NULL), (3L, "Z") AS tab(a, b)"""


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to