This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a3c837ae2ea [SPARK-41068][CONNECT][PYTHON] Implement `DataFrame.stat.corr`
a3c837ae2ea is described below
commit a3c837ae2eaf2c7ba08563b7afa0f96df8a4e80b
Author: Jiaan Geng <[email protected]>
AuthorDate: Fri Dec 30 13:09:55 2022 +0800
[SPARK-41068][CONNECT][PYTHON] Implement `DataFrame.stat.corr`
### What changes were proposed in this pull request?
Implement `DataFrame.stat.corr` with a proto message
Implement `DataFrame.stat.corr` for the Scala API
Implement `DataFrame.stat.corr` for the Python API
### Why are the changes needed?
For Connect API coverage.
### Does this PR introduce _any_ user-facing change?
'No'. This is a new API.
### How was this patch tested?
New test cases.
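For illustration, a minimal usage sketch of the new API against a Spark Connect session (the `spark` session and column names here are hypothetical, not part of this patch):

    df = spark.sql("SELECT id AS a, id * 2 AS b FROM RANGE(10)")
    df.stat.corr("a", "b")             # 1.0 for two perfectly linear columns
    df.stat.corr("a", "b", "pearson")  # explicit method; only 'pearson' is accepted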
Closes #39236 from beliefer/SPARK-41068.
Authored-by: Jiaan Geng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../main/protobuf/spark/connect/relations.proto | 20 +++
.../org/apache/spark/sql/connect/dsl/package.scala | 16 ++
.../sql/connect/planner/SparkConnectPlanner.scala | 14 ++
python/pyspark/sql/connect/dataframe.py | 27 +++
python/pyspark/sql/connect/plan.py | 18 ++
python/pyspark/sql/connect/proto/relations_pb2.py | 194 +++++++++++----------
python/pyspark/sql/connect/proto/relations_pb2.pyi | 68 ++++++++
.../sql/tests/connect/test_connect_basic.py | 24 +++
8 files changed, 291 insertions(+), 90 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 2d0837b4924..8a604f0702c 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -70,6 +70,7 @@ message Relation {
StatCrosstab crosstab = 101;
StatDescribe describe = 102;
StatCov cov = 103;
+ StatCorr corr = 104;
// Catalog API (experimental / unstable)
Catalog catalog = 200;
@@ -481,6 +482,25 @@ message StatCov {
string col2 = 3;
}
+// Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson
+// Correlation Coefficient. It will invoke 'Dataset.stat.corr' (same as
+// 'StatFunctions.pearsonCorrelation') to compute the results.
+message StatCorr {
+ // (Required) The input relation.
+ Relation input = 1;
+
+ // (Required) The name of the first column.
+ string col1 = 2;
+
+ // (Required) The name of the second column.
+ string col2 = 3;
+
+ // (Optional) Default value is 'pearson'.
+ //
+ // Currently only supports the Pearson Correlation Coefficient.
+ optional string method = 4;
+}
+
// Replaces null values.
// It will invoke 'Dataset.na.fill' (same as 'DataFrameNaFunctions.fill') to compute the results.
// Following 3 parameter combinations are supported:
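For reference, a rough sketch of populating this message from Python via the generated bindings (field values are placeholders; the input relation is omitted for brevity):

    from pyspark.sql.connect.proto import relations_pb2

    corr = relations_pb2.StatCorr(col1="col1", col2="col2")
    corr.method = "pearson"  # optional; the server falls back to Pearson when unset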
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 9e3346d9364..3bd713a9710 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -387,6 +387,22 @@ package object dsl {
.build()
}
+ def corr(col1: String, col2: String, method: String): Relation = {
+ Relation
+ .newBuilder()
+ .setCorr(
+ proto.StatCorr
+ .newBuilder()
+ .setInput(logicalPlan)
+ .setCol1(col1)
+ .setCol2(col2)
+ .setMethod(method)
+ .build())
+ .build()
+ }
+
+ def corr(col1: String, col2: String): Relation = corr(col1, col2, "pearson")
+
def crosstab(col1: String, col2: String): Relation = {
Relation
.newBuilder()
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index d06787e6b14..bb582e92755 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -89,6 +89,7 @@ class SparkConnectPlanner(session: SparkSession) {
case proto.Relation.RelTypeCase.SUMMARY => transformStatSummary(rel.getSummary)
case proto.Relation.RelTypeCase.DESCRIBE => transformStatDescribe(rel.getDescribe)
case proto.Relation.RelTypeCase.COV => transformStatCov(rel.getCov)
+ case proto.Relation.RelTypeCase.CORR => transformStatCorr(rel.getCorr)
case proto.Relation.RelTypeCase.CROSSTAB => transformStatCrosstab(rel.getCrosstab)
case proto.Relation.RelTypeCase.TO_SCHEMA => transformToSchema(rel.getToSchema)
@@ -352,6 +353,19 @@ class SparkConnectPlanner(session: SparkSession) {
data = Tuple1.apply(cov) :: Nil)
}
+ private def transformStatCorr(rel: proto.StatCorr): LogicalPlan = {
+ val df = Dataset.ofRows(session, transformRelation(rel.getInput))
+ val corr = if (rel.hasMethod) {
+ df.stat.corr(rel.getCol1, rel.getCol2, rel.getMethod)
+ } else {
+ df.stat.corr(rel.getCol1, rel.getCol2)
+ }
+
+ LocalRelation.fromProduct(
+ output = AttributeReference("corr", DoubleType, false)() :: Nil,
+ data = Tuple1.apply(corr) :: Nil)
+ }
+
private def transformStatCrosstab(rel: proto.StatCrosstab): LogicalPlan = {
Dataset
.ofRows(session, transformRelation(rel.getInput))
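The planner evaluates the correlation eagerly and wraps the scalar in a one-row local relation with a single non-nullable double column named "corr". A plain-PySpark sketch of the equivalent result shape (the DataFrame `df` is a hypothetical stand-in):

    corr = df.stat.corr("col1", "col2")  # Pearson when no method is given
    result = spark.createDataFrame([(corr,)], "corr double")  # one row, one double column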
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 08db6b61871..5b5a6c3f4b5 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -856,6 +856,28 @@ class DataFrame:
cov.__doc__ = PySparkDataFrame.cov.__doc__
+ def corr(self, col1: str, col2: str, method: Optional[str] = None) -> float:
+ if not isinstance(col1, str):
+ raise TypeError("col1 should be a string.")
+ if not isinstance(col2, str):
+ raise TypeError("col2 should be a string.")
+ if not method:
+ method = "pearson"
+ if not method == "pearson":
+ raise ValueError(
+ "Currently only the calculation of the Pearson Correlation "
+ + "coefficient is supported."
+ )
+ pdf = DataFrame.withPlan(
+ plan.StatCorr(child=self._plan, col1=col1, col2=col2, method=method),
+ session=self._session,
+ ).toPandas()
+
+ assert pdf is not None
+ return pdf["corr"][0]
+
+ corr.__doc__ = PySparkDataFrame.corr.__doc__
+
def crosstab(self, col1: str, col2: str) -> "DataFrame":
if not isinstance(col1, str):
raise TypeError(f"'col1' must be str, but got
{type(col1).__name__}")
@@ -1216,6 +1238,11 @@ class DataFrameStatFunctions:
cov.__doc__ = DataFrame.cov.__doc__
+ def corr(self, col1: str, col2: str, method: Optional[str] = None) -> float:
+ return self.df.corr(col1, col2, method)
+
+ corr.__doc__ = DataFrame.corr.__doc__
+
def crosstab(self, col1: str, col2: str) -> DataFrame:
return self.df.crosstab(col1, col2)
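A short interaction sketch of the client-side behavior above (the DataFrame `df` is hypothetical):

    df.stat.corr("col1", "col2")              # method defaults to "pearson"
    df.stat.corr("col1", "col2", "spearman")  # ValueError: only Pearson is supported
    df.stat.corr(1, "col2")                   # TypeError: col1 should be a string.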
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 616e4ce283b..e1b9fa0d0e4 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1095,6 +1095,24 @@ class StatCrosstab(LogicalPlan):
return plan
+class StatCorr(LogicalPlan):
+ def __init__(self, child: Optional["LogicalPlan"], col1: str, col2: str, method: str) -> None:
+ super().__init__(child)
+ self._col1 = col1
+ self._col2 = col2
+ self._method = method
+
+ def plan(self, session: "SparkConnectClient") -> proto.Relation:
+ assert self._child is not None
+
+ plan = proto.Relation()
+ plan.corr.input.CopyFrom(self._child.plan(session))
+ plan.corr.col1 = self._col1
+ plan.corr.col2 = self._col2
+ plan.corr.method = self._method
+ return plan
+
+
class RenameColumns(LogicalPlan):
def __init__(self, child: Optional["LogicalPlan"], cols: Sequence[str]) -> None:
super().__init__(child)
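For reference, a minimal serialization sketch for this plan node (the `plan.Read` child and the `client` object are hypothetical stand-ins for a real child plan and SparkConnectClient):

    from pyspark.sql.connect import plan

    node = plan.StatCorr(child=plan.Read("my_table"), col1="col1", col2="col2", method="pearson")
    rel = node.plan(session=client)  # proto.Relation with rel.corr.{input, col1, col2, method} set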
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 1fbb284ec37..7c938831882 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -35,7 +35,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xbe\x10\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spa [...]
+b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xed\x10\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spa [...]
)
@@ -69,6 +69,7 @@ _STATSUMMARY = DESCRIPTOR.message_types_by_name["StatSummary"]
_STATDESCRIBE = DESCRIPTOR.message_types_by_name["StatDescribe"]
_STATCROSSTAB = DESCRIPTOR.message_types_by_name["StatCrosstab"]
_STATCOV = DESCRIPTOR.message_types_by_name["StatCov"]
+_STATCORR = DESCRIPTOR.message_types_by_name["StatCorr"]
_NAFILL = DESCRIPTOR.message_types_by_name["NAFill"]
_NADROP = DESCRIPTOR.message_types_by_name["NADrop"]
_NAREPLACE = DESCRIPTOR.message_types_by_name["NAReplace"]
@@ -412,6 +413,17 @@ StatCov = _reflection.GeneratedProtocolMessageType(
)
_sym_db.RegisterMessage(StatCov)
+StatCorr = _reflection.GeneratedProtocolMessageType(
+ "StatCorr",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _STATCORR,
+ "__module__": "spark.connect.relations_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.StatCorr)
+ },
+)
+_sym_db.RegisterMessage(StatCorr)
+
NAFill = _reflection.GeneratedProtocolMessageType(
"NAFill",
(_message.Message,),
@@ -551,93 +563,95 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._options = None
_RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_options = b"8\001"
_RELATION._serialized_start = 138
- _RELATION._serialized_end = 2248
- _UNKNOWN._serialized_start = 2250
- _UNKNOWN._serialized_end = 2259
- _RELATIONCOMMON._serialized_start = 2261
- _RELATIONCOMMON._serialized_end = 2310
- _SQL._serialized_start = 2312
- _SQL._serialized_end = 2339
- _READ._serialized_start = 2342
- _READ._serialized_end = 2768
- _READ_NAMEDTABLE._serialized_start = 2484
- _READ_NAMEDTABLE._serialized_end = 2545
- _READ_DATASOURCE._serialized_start = 2548
- _READ_DATASOURCE._serialized_end = 2755
- _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2686
- _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2744
- _PROJECT._serialized_start = 2770
- _PROJECT._serialized_end = 2887
- _FILTER._serialized_start = 2889
- _FILTER._serialized_end = 3001
- _JOIN._serialized_start = 3004
- _JOIN._serialized_end = 3475
- _JOIN_JOINTYPE._serialized_start = 3267
- _JOIN_JOINTYPE._serialized_end = 3475
- _SETOPERATION._serialized_start = 3478
- _SETOPERATION._serialized_end = 3874
- _SETOPERATION_SETOPTYPE._serialized_start = 3737
- _SETOPERATION_SETOPTYPE._serialized_end = 3851
- _LIMIT._serialized_start = 3876
- _LIMIT._serialized_end = 3952
- _OFFSET._serialized_start = 3954
- _OFFSET._serialized_end = 4033
- _TAIL._serialized_start = 4035
- _TAIL._serialized_end = 4110
- _AGGREGATE._serialized_start = 4113
- _AGGREGATE._serialized_end = 4695
- _AGGREGATE_PIVOT._serialized_start = 4452
- _AGGREGATE_PIVOT._serialized_end = 4563
- _AGGREGATE_GROUPTYPE._serialized_start = 4566
- _AGGREGATE_GROUPTYPE._serialized_end = 4695
- _SORT._serialized_start = 4698
- _SORT._serialized_end = 4858
- _DROP._serialized_start = 4860
- _DROP._serialized_end = 4960
- _DEDUPLICATE._serialized_start = 4963
- _DEDUPLICATE._serialized_end = 5134
- _LOCALRELATION._serialized_start = 5137
- _LOCALRELATION._serialized_end = 5274
- _SAMPLE._serialized_start = 5277
- _SAMPLE._serialized_end = 5572
- _RANGE._serialized_start = 5575
- _RANGE._serialized_end = 5720
- _SUBQUERYALIAS._serialized_start = 5722
- _SUBQUERYALIAS._serialized_end = 5836
- _REPARTITION._serialized_start = 5839
- _REPARTITION._serialized_end = 5981
- _SHOWSTRING._serialized_start = 5984
- _SHOWSTRING._serialized_end = 6126
- _STATSUMMARY._serialized_start = 6128
- _STATSUMMARY._serialized_end = 6220
- _STATDESCRIBE._serialized_start = 6222
- _STATDESCRIBE._serialized_end = 6303
- _STATCROSSTAB._serialized_start = 6305
- _STATCROSSTAB._serialized_end = 6406
- _STATCOV._serialized_start = 6408
- _STATCOV._serialized_end = 6504
- _NAFILL._serialized_start = 6507
- _NAFILL._serialized_end = 6641
- _NADROP._serialized_start = 6644
- _NADROP._serialized_end = 6778
- _NAREPLACE._serialized_start = 6781
- _NAREPLACE._serialized_end = 7077
- _NAREPLACE_REPLACEMENT._serialized_start = 6936
- _NAREPLACE_REPLACEMENT._serialized_end = 7077
- _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 7079
- _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 7193
- _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 7196
- _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 7455
- _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 7388
- _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 7455
- _WITHCOLUMNS._serialized_start = 7458
- _WITHCOLUMNS._serialized_end = 7589
- _HINT._serialized_start = 7592
- _HINT._serialized_end = 7732
- _UNPIVOT._serialized_start = 7735
- _UNPIVOT._serialized_end = 7981
- _TOSCHEMA._serialized_start = 7983
- _TOSCHEMA._serialized_end = 8089
- _REPARTITIONBYEXPRESSION._serialized_start = 8092
- _REPARTITIONBYEXPRESSION._serialized_end = 8295
+ _RELATION._serialized_end = 2295
+ _UNKNOWN._serialized_start = 2297
+ _UNKNOWN._serialized_end = 2306
+ _RELATIONCOMMON._serialized_start = 2308
+ _RELATIONCOMMON._serialized_end = 2357
+ _SQL._serialized_start = 2359
+ _SQL._serialized_end = 2386
+ _READ._serialized_start = 2389
+ _READ._serialized_end = 2815
+ _READ_NAMEDTABLE._serialized_start = 2531
+ _READ_NAMEDTABLE._serialized_end = 2592
+ _READ_DATASOURCE._serialized_start = 2595
+ _READ_DATASOURCE._serialized_end = 2802
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2733
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2791
+ _PROJECT._serialized_start = 2817
+ _PROJECT._serialized_end = 2934
+ _FILTER._serialized_start = 2936
+ _FILTER._serialized_end = 3048
+ _JOIN._serialized_start = 3051
+ _JOIN._serialized_end = 3522
+ _JOIN_JOINTYPE._serialized_start = 3314
+ _JOIN_JOINTYPE._serialized_end = 3522
+ _SETOPERATION._serialized_start = 3525
+ _SETOPERATION._serialized_end = 3921
+ _SETOPERATION_SETOPTYPE._serialized_start = 3784
+ _SETOPERATION_SETOPTYPE._serialized_end = 3898
+ _LIMIT._serialized_start = 3923
+ _LIMIT._serialized_end = 3999
+ _OFFSET._serialized_start = 4001
+ _OFFSET._serialized_end = 4080
+ _TAIL._serialized_start = 4082
+ _TAIL._serialized_end = 4157
+ _AGGREGATE._serialized_start = 4160
+ _AGGREGATE._serialized_end = 4742
+ _AGGREGATE_PIVOT._serialized_start = 4499
+ _AGGREGATE_PIVOT._serialized_end = 4610
+ _AGGREGATE_GROUPTYPE._serialized_start = 4613
+ _AGGREGATE_GROUPTYPE._serialized_end = 4742
+ _SORT._serialized_start = 4745
+ _SORT._serialized_end = 4905
+ _DROP._serialized_start = 4907
+ _DROP._serialized_end = 5007
+ _DEDUPLICATE._serialized_start = 5010
+ _DEDUPLICATE._serialized_end = 5181
+ _LOCALRELATION._serialized_start = 5184
+ _LOCALRELATION._serialized_end = 5321
+ _SAMPLE._serialized_start = 5324
+ _SAMPLE._serialized_end = 5619
+ _RANGE._serialized_start = 5622
+ _RANGE._serialized_end = 5767
+ _SUBQUERYALIAS._serialized_start = 5769
+ _SUBQUERYALIAS._serialized_end = 5883
+ _REPARTITION._serialized_start = 5886
+ _REPARTITION._serialized_end = 6028
+ _SHOWSTRING._serialized_start = 6031
+ _SHOWSTRING._serialized_end = 6173
+ _STATSUMMARY._serialized_start = 6175
+ _STATSUMMARY._serialized_end = 6267
+ _STATDESCRIBE._serialized_start = 6269
+ _STATDESCRIBE._serialized_end = 6350
+ _STATCROSSTAB._serialized_start = 6352
+ _STATCROSSTAB._serialized_end = 6453
+ _STATCOV._serialized_start = 6455
+ _STATCOV._serialized_end = 6551
+ _STATCORR._serialized_start = 6554
+ _STATCORR._serialized_end = 6691
+ _NAFILL._serialized_start = 6694
+ _NAFILL._serialized_end = 6828
+ _NADROP._serialized_start = 6831
+ _NADROP._serialized_end = 6965
+ _NAREPLACE._serialized_start = 6968
+ _NAREPLACE._serialized_end = 7264
+ _NAREPLACE_REPLACEMENT._serialized_start = 7123
+ _NAREPLACE_REPLACEMENT._serialized_end = 7264
+ _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 7266
+ _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 7380
+ _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 7383
+ _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 7642
+ _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 7575
+ _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 7642
+ _WITHCOLUMNS._serialized_start = 7645
+ _WITHCOLUMNS._serialized_end = 7776
+ _HINT._serialized_start = 7779
+ _HINT._serialized_end = 7919
+ _UNPIVOT._serialized_start = 7922
+ _UNPIVOT._serialized_end = 8168
+ _TOSCHEMA._serialized_start = 8170
+ _TOSCHEMA._serialized_end = 8276
+ _REPARTITIONBYEXPRESSION._serialized_start = 8279
+ _REPARTITIONBYEXPRESSION._serialized_end = 8482
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index e1a37abbd6c..63ccfa18559 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -95,6 +95,7 @@ class Relation(google.protobuf.message.Message):
CROSSTAB_FIELD_NUMBER: builtins.int
DESCRIBE_FIELD_NUMBER: builtins.int
COV_FIELD_NUMBER: builtins.int
+ CORR_FIELD_NUMBER: builtins.int
CATALOG_FIELD_NUMBER: builtins.int
UNKNOWN_FIELD_NUMBER: builtins.int
@property
@@ -168,6 +169,8 @@ class Relation(google.protobuf.message.Message):
@property
def cov(self) -> global___StatCov: ...
@property
+ def corr(self) -> global___StatCorr: ...
+ @property
def catalog(self) -> pyspark.sql.connect.proto.catalog_pb2.Catalog:
"""Catalog API (experimental / unstable)"""
@property
@@ -209,6 +212,7 @@ class Relation(google.protobuf.message.Message):
crosstab: global___StatCrosstab | None = ...,
describe: global___StatDescribe | None = ...,
cov: global___StatCov | None = ...,
+ corr: global___StatCorr | None = ...,
catalog: pyspark.sql.connect.proto.catalog_pb2.Catalog | None = ...,
unknown: global___Unknown | None = ...,
) -> None: ...
@@ -221,6 +225,8 @@ class Relation(google.protobuf.message.Message):
b"catalog",
"common",
b"common",
+ "corr",
+ b"corr",
"cov",
b"cov",
"crosstab",
@@ -300,6 +306,8 @@ class Relation(google.protobuf.message.Message):
b"catalog",
"common",
b"common",
+ "corr",
+ b"corr",
"cov",
b"cov",
"crosstab",
@@ -406,6 +414,7 @@ class Relation(google.protobuf.message.Message):
"crosstab",
"describe",
"cov",
+ "corr",
"catalog",
"unknown",
] | None: ...
@@ -1710,6 +1719,65 @@ class StatCov(google.protobuf.message.Message):
global___StatCov = StatCov
+class StatCorr(google.protobuf.message.Message):
+ """Calculates the correlation of two columns of a DataFrame. Currently
only supports the Pearson
+ Correlation Coefficient. It will invoke 'Dataset.stat.corr' (same as
+ 'StatFunctions.pearsonCorrelation') to compute the results.
+ """
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ INPUT_FIELD_NUMBER: builtins.int
+ COL1_FIELD_NUMBER: builtins.int
+ COL2_FIELD_NUMBER: builtins.int
+ METHOD_FIELD_NUMBER: builtins.int
+ @property
+ def input(self) -> global___Relation:
+ """(Required) The input relation."""
+ col1: builtins.str
+ """(Required) The name of the first column."""
+ col2: builtins.str
+ """(Required) The name of the second column."""
+ method: builtins.str
+ """(Optional) Default value is 'pearson'.
+
+ Currently only supports the Pearson Correlation Coefficient.
+ """
+ def __init__(
+ self,
+ *,
+ input: global___Relation | None = ...,
+ col1: builtins.str = ...,
+ col2: builtins.str = ...,
+ method: builtins.str | None = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_method", b"_method", "input", b"input", "method", b"method"
+ ],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_method",
+ b"_method",
+ "col1",
+ b"col1",
+ "col2",
+ b"col2",
+ "input",
+ b"input",
+ "method",
+ b"method",
+ ],
+ ) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_method", b"_method"]
+ ) -> typing_extensions.Literal["method"] | None: ...
+
+global___StatCorr = StatCorr
+
class NAFill(google.protobuf.message.Message):
"""Replaces null values.
It will invoke 'Dataset.na.fill' (same as 'DataFrameNaFunctions.fill') to compute the results.
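The `_method` oneof in the stub above is how protobuf stubs expose a proto3 `optional` field; a quick presence-check sketch (reusing the bindings from the earlier example):

    msg = relations_pb2.StatCorr(col1="a", col2="b")
    msg.HasField("method")   # False: method was never set
    msg.method = "pearson"
    msg.HasField("method")   # True: the field now has explicit presence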
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 84c1baea80d..99ee54a87fa 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1013,6 +1013,30 @@ class SparkConnectTests(SparkConnectSQLTestCase):
self.spark.read.table(self.tbl_name2).stat.cov("col1", "col3"),
)
+ def test_stat_corr(self):
+ # SPARK-41068: Test the stat.corr method
+ self.assertEqual(
+ self.connect.read.table(self.tbl_name2).stat.corr("col1", "col3"),
+ self.spark.read.table(self.tbl_name2).stat.corr("col1", "col3"),
+ )
+
+ self.assertEqual(
+ self.connect.read.table(self.tbl_name2).stat.corr("col1", "col3", "pearson"),
+ self.spark.read.table(self.tbl_name2).stat.corr("col1", "col3", "pearson"),
+ )
+
+ with self.assertRaisesRegex(TypeError, "col1 should be a string."):
+ self.connect.read.table(self.tbl_name2).stat.corr(1, "col3", "pearson")
+ with self.assertRaisesRegex(TypeError, "col2 should be a string."):
+ self.connect.read.table(self.tbl_name).stat.corr("col1", 1, "pearson")
+ with self.assertRaises(ValueError) as context:
+ self.connect.read.table(self.tbl_name2).stat.corr("col1", "col3", "spearman")
+ self.assertTrue(
+ "Currently only the calculation of the Pearson Correlation "
+ + "coefficient is supported."
+ in str(context.exception)
+ )
+
def test_repr(self):
# SPARK-41213: Test the __repr__ method
query = """SELECT * FROM VALUES (1L, NULL), (3L, "Z") AS tab(a, b)"""