This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new a3c837ae2ea  [SPARK-41068][CONNECT][PYTHON] Implement `DataFrame.stat.corr`
a3c837ae2ea is described below

commit a3c837ae2eaf2c7ba08563b7afa0f96df8a4e80b
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Fri Dec 30 13:09:55 2022 +0800

    [SPARK-41068][CONNECT][PYTHON] Implement `DataFrame.stat.corr`

    ### What changes were proposed in this pull request?
    Implement `DataFrame.stat.corr` with a proto message
    Implement `DataFrame.stat.corr` for scala API
    Implement `DataFrame.stat.corr` for python API

    ### Why are the changes needed?
    for Connect API coverage

    ### Does this PR introduce _any_ user-facing change?
    'No'. New API

    ### How was this patch tested?
    New test cases.

    Closes #39236 from beliefer/SPARK-41068.

    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    |  20 +++
 .../org/apache/spark/sql/connect/dsl/package.scala |  16 ++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  14 ++
 python/pyspark/sql/connect/dataframe.py            |  27 +++
 python/pyspark/sql/connect/plan.py                 |  18 ++
 python/pyspark/sql/connect/proto/relations_pb2.py  | 194 +++++++++++----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  68 ++++++++
 .../sql/tests/connect/test_connect_basic.py        |  24 +++
 8 files changed, 291 insertions(+), 90 deletions(-)
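For orientation, a minimal usage sketch of the API this patch adds, as seen from the
PySpark client. The snippet is illustrative only and not part of the commit: `spark` is
assumed to be an existing (Connect) SparkSession, and the table and column names are
placeholders.

    # Illustrative sketch, not part of this commit: `spark` is an existing Spark
    # (Connect) session and "tbl" is a table with numeric columns col1 and col2.
    df = spark.read.table("tbl")

    # corr() computes the Pearson correlation coefficient by default; "pearson" is
    # currently the only accepted method, anything else raises a ValueError.
    r1 = df.stat.corr("col1", "col2")
    r2 = df.stat.corr("col1", "col2", "pearson")
    print(r1, r2)  # both are plain Python floats
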
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 2d0837b4924..8a604f0702c 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -70,6 +70,7 @@ message Relation {
     StatCrosstab crosstab = 101;
     StatDescribe describe = 102;
     StatCov cov = 103;
+    StatCorr corr = 104;
 
     // Catalog API (experimental / unstable)
     Catalog catalog = 200;
@@ -481,6 +482,25 @@ message StatCov {
   string col2 = 3;
 }
 
+// Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson
+// Correlation Coefficient. It will invoke 'Dataset.stat.corr' (same as
+// 'StatFunctions.pearsonCorrelation') to compute the results.
+message StatCorr {
+  // (Required) The input relation.
+  Relation input = 1;
+
+  // (Required) The name of the first column.
+  string col1 = 2;
+
+  // (Required) The name of the second column.
+  string col2 = 3;
+
+  // (Optional) Default value is 'pearson'.
+  //
+  // Currently only supports the Pearson Correlation Coefficient.
+  optional string method = 4;
+}
+
 // Replaces null values.
 // It will invoke 'Dataset.na.fill' (same as 'DataFrameNaFunctions.fill') to compute the results.
 // Following 3 parameter combinations are supported:
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 9e3346d9364..3bd713a9710 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -387,6 +387,22 @@ package object dsl {
           .build()
       }
 
+      def corr(col1: String, col2: String, method: String): Relation = {
+        Relation
+          .newBuilder()
+          .setCorr(
+            proto.StatCorr
+              .newBuilder()
+              .setInput(logicalPlan)
+              .setCol1(col1)
+              .setCol2(col2)
+              .setMethod(method)
+              .build())
+          .build()
+      }
+
+      def corr(col1: String, col2: String): Relation = corr(col1, col2, "pearson")
+
       def crosstab(col1: String, col2: String): Relation = {
         Relation
           .newBuilder()
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index d06787e6b14..bb582e92755 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -89,6 +89,7 @@ class SparkConnectPlanner(session: SparkSession) {
       case proto.Relation.RelTypeCase.SUMMARY => transformStatSummary(rel.getSummary)
       case proto.Relation.RelTypeCase.DESCRIBE => transformStatDescribe(rel.getDescribe)
       case proto.Relation.RelTypeCase.COV => transformStatCov(rel.getCov)
+      case proto.Relation.RelTypeCase.CORR => transformStatCorr(rel.getCorr)
       case proto.Relation.RelTypeCase.CROSSTAB =>
         transformStatCrosstab(rel.getCrosstab)
       case proto.Relation.RelTypeCase.TO_SCHEMA => transformToSchema(rel.getToSchema)
@@ -352,6 +353,19 @@ class SparkConnectPlanner(session: SparkSession) {
       data = Tuple1.apply(cov) :: Nil)
   }
 
+  private def transformStatCorr(rel: proto.StatCorr): LogicalPlan = {
+    val df = Dataset.ofRows(session, transformRelation(rel.getInput))
+    val corr = if (rel.hasMethod) {
+      df.stat.corr(rel.getCol1, rel.getCol2, rel.getMethod)
+    } else {
+      df.stat.corr(rel.getCol1, rel.getCol2)
+    }
+
+    LocalRelation.fromProduct(
+      output = AttributeReference("corr", DoubleType, false)() :: Nil,
+      data = Tuple1.apply(corr) :: Nil)
+  }
+
   private def transformStatCrosstab(rel: proto.StatCrosstab): LogicalPlan = {
     Dataset
       .ofRows(session, transformRelation(rel.getInput))
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 08db6b61871..5b5a6c3f4b5 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -856,6 +856,28 @@ class DataFrame:
 
     cov.__doc__ = PySparkDataFrame.cov.__doc__
 
+    def corr(self, col1: str, col2: str, method: Optional[str] = None) -> float:
+        if not isinstance(col1, str):
+            raise TypeError("col1 should be a string.")
+        if not isinstance(col2, str):
+            raise TypeError("col2 should be a string.")
+        if not method:
+            method = "pearson"
+        if not method == "pearson":
+            raise ValueError(
+                "Currently only the calculation of the Pearson Correlation "
+                + "coefficient is supported."
+            )
+        pdf = DataFrame.withPlan(
+            plan.StatCorr(child=self._plan, col1=col1, col2=col2, method=method),
+            session=self._session,
+        ).toPandas()
+
+        assert pdf is not None
+        return pdf["corr"][0]
+
+    corr.__doc__ = PySparkDataFrame.corr.__doc__
+
     def crosstab(self, col1: str, col2: str) -> "DataFrame":
         if not isinstance(col1, str):
             raise TypeError(f"'col1' must be str, but got {type(col1).__name__}")
@@ -1216,6 +1238,11 @@ class DataFrameStatFunctions:
 
     cov.__doc__ = DataFrame.cov.__doc__
 
+    def corr(self, col1: str, col2: str, method: Optional[str] = None) -> float:
+        return self.df.corr(col1, col2, method)
+
+    corr.__doc__ = DataFrame.corr.__doc__
+
     def crosstab(self, col1: str, col2: str) -> DataFrame:
         return self.df.crosstab(col1, col2)
 
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 616e4ce283b..e1b9fa0d0e4 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1095,6 +1095,24 @@ class StatCrosstab(LogicalPlan):
         return plan
 
 
+class StatCorr(LogicalPlan):
+    def __init__(self, child: Optional["LogicalPlan"], col1: str, col2: str, method: str) -> None:
+        super().__init__(child)
+        self._col1 = col1
+        self._col2 = col2
+        self._method = method
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        assert self._child is not None
+
+        plan = proto.Relation()
+        plan.corr.input.CopyFrom(self._child.plan(session))
+        plan.corr.col1 = self._col1
+        plan.corr.col2 = self._col2
+        plan.corr.method = self._method
+        return plan
+
+
 class RenameColumns(LogicalPlan):
     def __init__(self, child: Optional["LogicalPlan"], cols: Sequence[str]) -> None:
         super().__init__(child)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 1fbb284ec37..7c938831882 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -35,7 +35,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xbe\x10\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spa [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xed\x10\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spa [...]
 )
 
 
@@ -69,6 +69,7 @@ _STATSUMMARY = DESCRIPTOR.message_types_by_name["StatSummary"]
 _STATDESCRIBE = DESCRIPTOR.message_types_by_name["StatDescribe"]
 _STATCROSSTAB = DESCRIPTOR.message_types_by_name["StatCrosstab"]
 _STATCOV = DESCRIPTOR.message_types_by_name["StatCov"]
+_STATCORR = DESCRIPTOR.message_types_by_name["StatCorr"]
 _NAFILL = DESCRIPTOR.message_types_by_name["NAFill"]
 _NADROP = DESCRIPTOR.message_types_by_name["NADrop"]
 _NAREPLACE = DESCRIPTOR.message_types_by_name["NAReplace"]
@@ -412,6 +413,17 @@ StatCov = _reflection.GeneratedProtocolMessageType(
 )
 _sym_db.RegisterMessage(StatCov)
 
+StatCorr = _reflection.GeneratedProtocolMessageType(
+    "StatCorr",
+    (_message.Message,),
+    {
+        "DESCRIPTOR": _STATCORR,
+        "__module__": "spark.connect.relations_pb2"
+        # @@protoc_insertion_point(class_scope:spark.connect.StatCorr)
+    },
+)
+_sym_db.RegisterMessage(StatCorr)
+
 NAFill = _reflection.GeneratedProtocolMessageType(
     "NAFill",
     (_message.Message,),
     {
@@ -551,93 +563,95 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._options = None
     _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_options = b"8\001"
     _RELATION._serialized_start = 138
-    _RELATION._serialized_end = 2248
-    _UNKNOWN._serialized_start = 2250
-    _UNKNOWN._serialized_end = 2259
-    _RELATIONCOMMON._serialized_start = 2261
-    _RELATIONCOMMON._serialized_end = 2310
-    _SQL._serialized_start = 2312
-    _SQL._serialized_end = 2339
-    _READ._serialized_start = 2342
-    _READ._serialized_end = 2768
-    _READ_NAMEDTABLE._serialized_start = 2484
-    _READ_NAMEDTABLE._serialized_end = 2545
-    _READ_DATASOURCE._serialized_start = 2548
-    _READ_DATASOURCE._serialized_end = 2755
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2686
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2744
-    _PROJECT._serialized_start = 2770
-    _PROJECT._serialized_end = 2887
-    _FILTER._serialized_start = 2889
-    _FILTER._serialized_end = 3001
-    _JOIN._serialized_start = 3004
-    _JOIN._serialized_end = 3475
-    _JOIN_JOINTYPE._serialized_start = 3267
-    _JOIN_JOINTYPE._serialized_end = 3475
-    _SETOPERATION._serialized_start = 3478
-    _SETOPERATION._serialized_end = 3874
-    _SETOPERATION_SETOPTYPE._serialized_start = 3737
-    _SETOPERATION_SETOPTYPE._serialized_end = 3851
-    _LIMIT._serialized_start = 3876
-    _LIMIT._serialized_end = 3952
-    _OFFSET._serialized_start = 3954
-    _OFFSET._serialized_end = 4033
-    _TAIL._serialized_start = 4035
-    _TAIL._serialized_end = 4110
-    _AGGREGATE._serialized_start = 4113
-    _AGGREGATE._serialized_end = 4695
-    _AGGREGATE_PIVOT._serialized_start = 4452
-    _AGGREGATE_PIVOT._serialized_end = 4563
-    _AGGREGATE_GROUPTYPE._serialized_start = 4566
-    _AGGREGATE_GROUPTYPE._serialized_end = 4695
-    _SORT._serialized_start = 4698
-    _SORT._serialized_end = 4858
-    _DROP._serialized_start = 4860
-    _DROP._serialized_end = 4960
-    _DEDUPLICATE._serialized_start = 4963
-    _DEDUPLICATE._serialized_end = 5134
-    _LOCALRELATION._serialized_start = 5137
-    _LOCALRELATION._serialized_end = 5274
-    _SAMPLE._serialized_start = 5277
-    _SAMPLE._serialized_end = 5572
-    _RANGE._serialized_start = 5575
-    _RANGE._serialized_end = 5720
-    _SUBQUERYALIAS._serialized_start = 5722
-    _SUBQUERYALIAS._serialized_end = 5836
-    _REPARTITION._serialized_start = 5839
-    _REPARTITION._serialized_end = 5981
-    _SHOWSTRING._serialized_start = 5984
-    _SHOWSTRING._serialized_end = 6126
-    _STATSUMMARY._serialized_start = 6128
-    _STATSUMMARY._serialized_end = 6220
-    _STATDESCRIBE._serialized_start = 6222
-    _STATDESCRIBE._serialized_end = 6303
-    _STATCROSSTAB._serialized_start = 6305
-    _STATCROSSTAB._serialized_end = 6406
-    _STATCOV._serialized_start = 6408
-    _STATCOV._serialized_end = 6504
-    _NAFILL._serialized_start = 6507
-    _NAFILL._serialized_end = 6641
-    _NADROP._serialized_start = 6644
-    _NADROP._serialized_end = 6778
-    _NAREPLACE._serialized_start = 6781
-    _NAREPLACE._serialized_end = 7077
-    _NAREPLACE_REPLACEMENT._serialized_start = 6936
-    _NAREPLACE_REPLACEMENT._serialized_end = 7077
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 7079
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 7193
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 7196
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 7455
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 7388
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 7455
-    _WITHCOLUMNS._serialized_start = 7458
-    _WITHCOLUMNS._serialized_end = 7589
-    _HINT._serialized_start = 7592
-    _HINT._serialized_end = 7732
-    _UNPIVOT._serialized_start = 7735
-    _UNPIVOT._serialized_end = 7981
-    _TOSCHEMA._serialized_start = 7983
-    _TOSCHEMA._serialized_end = 8089
-    _REPARTITIONBYEXPRESSION._serialized_start = 8092
-    _REPARTITIONBYEXPRESSION._serialized_end = 8295
+    _RELATION._serialized_end = 2295
+    _UNKNOWN._serialized_start = 2297
+    _UNKNOWN._serialized_end = 2306
+    _RELATIONCOMMON._serialized_start = 2308
+    _RELATIONCOMMON._serialized_end = 2357
+    _SQL._serialized_start = 2359
+    _SQL._serialized_end = 2386
+    _READ._serialized_start = 2389
+    _READ._serialized_end = 2815
+    _READ_NAMEDTABLE._serialized_start = 2531
+    _READ_NAMEDTABLE._serialized_end = 2592
+    _READ_DATASOURCE._serialized_start = 2595
+    _READ_DATASOURCE._serialized_end = 2802
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2733
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2791
+    _PROJECT._serialized_start = 2817
+    _PROJECT._serialized_end = 2934
+    _FILTER._serialized_start = 2936
+    _FILTER._serialized_end = 3048
+    _JOIN._serialized_start = 3051
+    _JOIN._serialized_end = 3522
+    _JOIN_JOINTYPE._serialized_start = 3314
+    _JOIN_JOINTYPE._serialized_end = 3522
+    _SETOPERATION._serialized_start = 3525
+    _SETOPERATION._serialized_end = 3921
+    _SETOPERATION_SETOPTYPE._serialized_start = 3784
+    _SETOPERATION_SETOPTYPE._serialized_end = 3898
+    _LIMIT._serialized_start = 3923
+    _LIMIT._serialized_end = 3999
+    _OFFSET._serialized_start = 4001
+    _OFFSET._serialized_end = 4080
+    _TAIL._serialized_start = 4082
+    _TAIL._serialized_end = 4157
+    _AGGREGATE._serialized_start = 4160
+    _AGGREGATE._serialized_end = 4742
+    _AGGREGATE_PIVOT._serialized_start = 4499
+    _AGGREGATE_PIVOT._serialized_end = 4610
+    _AGGREGATE_GROUPTYPE._serialized_start = 4613
+    _AGGREGATE_GROUPTYPE._serialized_end = 4742
+    _SORT._serialized_start = 4745
+    _SORT._serialized_end = 4905
+    _DROP._serialized_start = 4907
+    _DROP._serialized_end = 5007
+    _DEDUPLICATE._serialized_start = 5010
+    _DEDUPLICATE._serialized_end = 5181
+    _LOCALRELATION._serialized_start = 5184
+    _LOCALRELATION._serialized_end = 5321
+    _SAMPLE._serialized_start = 5324
+    _SAMPLE._serialized_end = 5619
+    _RANGE._serialized_start = 5622
+    _RANGE._serialized_end = 5767
+    _SUBQUERYALIAS._serialized_start = 5769
+    _SUBQUERYALIAS._serialized_end = 5883
+    _REPARTITION._serialized_start = 5886
+    _REPARTITION._serialized_end = 6028
+    _SHOWSTRING._serialized_start = 6031
+    _SHOWSTRING._serialized_end = 6173
+    _STATSUMMARY._serialized_start = 6175
+    _STATSUMMARY._serialized_end = 6267
+    _STATDESCRIBE._serialized_start = 6269
+    _STATDESCRIBE._serialized_end = 6350
+    _STATCROSSTAB._serialized_start = 6352
+    _STATCROSSTAB._serialized_end = 6453
+    _STATCOV._serialized_start = 6455
+    _STATCOV._serialized_end = 6551
+    _STATCORR._serialized_start = 6554
+    _STATCORR._serialized_end = 6691
+    _NAFILL._serialized_start = 6694
+    _NAFILL._serialized_end = 6828
+    _NADROP._serialized_start = 6831
+    _NADROP._serialized_end = 6965
+    _NAREPLACE._serialized_start = 6968
+    _NAREPLACE._serialized_end = 7264
+    _NAREPLACE_REPLACEMENT._serialized_start = 7123
+    _NAREPLACE_REPLACEMENT._serialized_end = 7264
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 7266
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 7380
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 7383
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 7642
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 7575
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 7642
+    _WITHCOLUMNS._serialized_start = 7645
+    _WITHCOLUMNS._serialized_end = 7776
+    _HINT._serialized_start = 7779
+    _HINT._serialized_end = 7919
+    _UNPIVOT._serialized_start = 7922
+    _UNPIVOT._serialized_end = 8168
+    _TOSCHEMA._serialized_start = 8170
+    _TOSCHEMA._serialized_end = 8276
+    _REPARTITIONBYEXPRESSION._serialized_start = 8279
+    _REPARTITIONBYEXPRESSION._serialized_end = 8482
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index e1a37abbd6c..63ccfa18559 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -95,6 +95,7 @@ class Relation(google.protobuf.message.Message):
     CROSSTAB_FIELD_NUMBER: builtins.int
     DESCRIBE_FIELD_NUMBER: builtins.int
     COV_FIELD_NUMBER: builtins.int
+    CORR_FIELD_NUMBER: builtins.int
     CATALOG_FIELD_NUMBER: builtins.int
     UNKNOWN_FIELD_NUMBER: builtins.int
     @property
@@ -168,6 +169,8 @@ class Relation(google.protobuf.message.Message):
     @property
     def cov(self) -> global___StatCov: ...
     @property
+    def corr(self) -> global___StatCorr: ...
+    @property
     def catalog(self) -> pyspark.sql.connect.proto.catalog_pb2.Catalog:
         """Catalog API (experimental / unstable)"""
     @property
@@ -209,6 +212,7 @@ class Relation(google.protobuf.message.Message):
         crosstab: global___StatCrosstab | None = ...,
         describe: global___StatDescribe | None = ...,
         cov: global___StatCov | None = ...,
+        corr: global___StatCorr | None = ...,
         catalog: pyspark.sql.connect.proto.catalog_pb2.Catalog | None = ...,
         unknown: global___Unknown | None = ...,
     ) -> None: ...
@@ -221,6 +225,8 @@ class Relation(google.protobuf.message.Message):
             b"catalog",
             "common",
             b"common",
+            "corr",
+            b"corr",
             "cov",
             b"cov",
             "crosstab",
@@ -300,6 +306,8 @@ class Relation(google.protobuf.message.Message):
             b"catalog",
             "common",
             b"common",
+            "corr",
+            b"corr",
             "cov",
             b"cov",
             "crosstab",
@@ -406,6 +414,7 @@ class Relation(google.protobuf.message.Message):
         "crosstab",
         "describe",
         "cov",
+        "corr",
         "catalog",
         "unknown",
     ] | None: ...
@@ -1710,6 +1719,65 @@ class StatCov(google.protobuf.message.Message):
 
 global___StatCov = StatCov
 
+class StatCorr(google.protobuf.message.Message):
+    """Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson
+    Correlation Coefficient. It will invoke 'Dataset.stat.corr' (same as
+    'StatFunctions.pearsonCorrelation') to compute the results.
+ """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + INPUT_FIELD_NUMBER: builtins.int + COL1_FIELD_NUMBER: builtins.int + COL2_FIELD_NUMBER: builtins.int + METHOD_FIELD_NUMBER: builtins.int + @property + def input(self) -> global___Relation: + """(Required) The input relation.""" + col1: builtins.str + """(Required) The name of the first column.""" + col2: builtins.str + """(Required) The name of the second column.""" + method: builtins.str + """(Optional) Default value is 'pearson'. + + Currently only supports the Pearson Correlation Coefficient. + """ + def __init__( + self, + *, + input: global___Relation | None = ..., + col1: builtins.str = ..., + col2: builtins.str = ..., + method: builtins.str | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "_method", b"_method", "input", b"input", "method", b"method" + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "_method", + b"_method", + "col1", + b"col1", + "col2", + b"col2", + "input", + b"input", + "method", + b"method", + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_method", b"_method"] + ) -> typing_extensions.Literal["method"] | None: ... + +global___StatCorr = StatCorr + class NAFill(google.protobuf.message.Message): """Replaces null values. It will invoke 'Dataset.na.fill' (same as 'DataFrameNaFunctions.fill') to compute the results. diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 84c1baea80d..99ee54a87fa 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -1013,6 +1013,30 @@ class SparkConnectTests(SparkConnectSQLTestCase): self.spark.read.table(self.tbl_name2).stat.cov("col1", "col3"), ) + def test_stat_corr(self): + # SPARK-41068: Test the stat.corr method + self.assertEqual( + self.connect.read.table(self.tbl_name2).stat.corr("col1", "col3"), + self.spark.read.table(self.tbl_name2).stat.corr("col1", "col3"), + ) + + self.assertEqual( + self.connect.read.table(self.tbl_name2).stat.corr("col1", "col3", "pearson"), + self.spark.read.table(self.tbl_name2).stat.corr("col1", "col3", "pearson"), + ) + + with self.assertRaisesRegex(TypeError, "col1 should be a string."): + self.connect.read.table(self.tbl_name2).stat.corr(1, "col3", "pearson") + with self.assertRaisesRegex(TypeError, "col2 should be a string."): + self.connect.read.table(self.tbl_name).stat.corr("col1", 1, "pearson") + with self.assertRaises(ValueError) as context: + self.connect.read.table(self.tbl_name2).stat.corr("col1", "col3", "spearman"), + self.assertTrue( + "Currently only the calculation of the Pearson Correlation " + + "coefficient is supported." + in str(context.exception) + ) + def test_repr(self): # SPARK-41213: Test the __repr__ method query = """SELECT * FROM VALUES (1L, NULL), (3L, "Z") AS tab(a, b)""" --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org