This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new c5d27603f29 [SPARK-41064][CONNECT][PYTHON] Implement `DataFrame.crosstab` and `DataFrame.stat.crosstab`
c5d27603f29 is described below

commit c5d27603f29437f1686cac70727594c19410a273
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Nov 10 18:15:54 2022 +0800

    [SPARK-41064][CONNECT][PYTHON] Implement `DataFrame.crosstab` and `DataFrame.stat.crosstab`

    ### What changes were proposed in this pull request?
    Implement `DataFrame.crosstab` and `DataFrame.stat.crosstab`.

    ### Why are the changes needed?
    For API coverage.

    ### Does this PR introduce _any_ user-facing change?
    Yes, new API.

    ### How was this patch tested?
    Added unit tests.

    Closes #38578 from zhengruifeng/connect_df_crosstab.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    |  19 +++
 .../org/apache/spark/sql/connect/dsl/package.scala |  17 +++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  10 ++
 .../connect/planner/SparkConnectProtoSuite.scala   |   6 +
 python/pyspark/sql/connect/dataframe.py            |  62 ++++++++++
 python/pyspark/sql/connect/plan.py                 |  32 +++++
 python/pyspark/sql/connect/proto/relations_pb2.py  | 134 +++++++++++----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  50 ++++++++
 .../sql/tests/connect/test_connect_plan_only.py    |  10 ++
 9 files changed, 274 insertions(+), 66 deletions(-)

diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index b3613fc908d..639d1bafce5 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -52,6 +52,7 @@ message Relation {
 
     // stat functions
     StatSummary summary = 100;
+    StatCrosstab crosstab = 101;
 
     Unknown unknown = 999;
   }
@@ -284,6 +285,24 @@ message StatSummary {
   repeated string statistics = 2;
 }
 
+// Computes a pair-wise frequency table of the given columns. Also known as a contingency table.
+// It will invoke 'Dataset.stat.crosstab' (same as 'StatFunctions.crossTabulate')
+// to compute the results.
+message StatCrosstab {
+  // (Required) The input relation.
+  Relation input = 1;
+
+  // (Required) The name of the first column.
+  //
+  // Distinct items will make the first item of each row.
+  string col1 = 2;
+
+  // (Required) The name of the second column.
+  //
+  // Distinct items will make the column names of the DataFrame.
+  string col2 = 3;
+}
+
 // Rename columns on the input relation by the same length of names.
 message RenameColumnsBySameLengthNames {
   // Required. The input relation.
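For orientation, a minimal sketch of populating the new message from Python through the regenerated bindings; the import path follows this PR's generated code, and the column names are illustrative:

    # Accessing a submessage inside the `rel_type` oneof selects that variant.
    from pyspark.sql.connect.proto import relations_pb2

    rel = relations_pb2.Relation()
    rel.crosstab.col1 = "id"    # distinct values become the first item of each row
    rel.crosstab.col2 = "name"  # distinct values become the column names
    # rel.crosstab.input would receive the child Relation via CopyFrom(...).
    print(rel.WhichOneof("rel_type"))  # -> "crosstab"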
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 381cbf7a9a8..5e7a94da347 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -227,6 +227,21 @@ package object dsl {
       }
     }
 
+    implicit class DslStatFunctions(val logicalPlan: Relation) {
+      def crosstab(col1: String, col2: String): Relation = {
+        Relation
+          .newBuilder()
+          .setCrosstab(
+            proto.StatCrosstab
+              .newBuilder()
+              .setInput(logicalPlan)
+              .setCol1(col1)
+              .setCol2(col2)
+              .build())
+          .build()
+      }
+    }
+
     implicit class DslLogicalPlan(val logicalPlan: Relation) {
       def select(exprs: Expression*): Relation = {
         Relation
@@ -463,6 +478,8 @@ package object dsl {
             Repartition.newBuilder().setInput(logicalPlan).setNumPartitions(num).setShuffle(true))
           .build()
 
+      def stat: DslStatFunctions = new DslStatFunctions(logicalPlan)
+
       def summary(statistics: String*): Relation = {
         Relation
           .newBuilder()
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 148f5569683..04ce880a925 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -67,6 +67,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
         transformSubqueryAlias(rel.getSubqueryAlias)
       case proto.Relation.RelTypeCase.REPARTITION => transformRepartition(rel.getRepartition)
       case proto.Relation.RelTypeCase.SUMMARY => transformStatSummary(rel.getSummary)
+      case proto.Relation.RelTypeCase.CROSSTAB =>
+        transformStatCrosstab(rel.getCrosstab)
       case proto.Relation.RelTypeCase.RENAME_COLUMNS_BY_SAME_LENGTH_NAMES =>
         transformRenameColumnsBySamelenghtNames(rel.getRenameColumnsBySameLengthNames)
       case proto.Relation.RelTypeCase.RENAME_COLUMNS_BY_NAME_TO_NAME_MAP =>
@@ -129,6 +131,14 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
       .logicalPlan
   }
 
+  private def transformStatCrosstab(rel: proto.StatCrosstab): LogicalPlan = {
+    Dataset
+      .ofRows(session, transformRelation(rel.getInput))
+      .stat
+      .crosstab(rel.getCol1, rel.getCol2)
+      .logicalPlan
+  }
+
   private def transformRenameColumnsBySamelenghtNames(
       rel: proto.RenameColumnsBySameLengthNames): LogicalPlan = {
     Dataset
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 3612c5e0d0a..5052b451047 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -273,6 +273,12 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
       sparkTestRelation.summary("count", "mean", "stddev"))
   }
 
+  test("Test crosstab") {
+    comparePlans(
+      connectTestRelation.stat.crosstab("id", "name"),
+      sparkTestRelation.stat.crosstab("id", "name"))
+  }
+
   test("Test toDF") {
     comparePlans(connectTestRelation.toDF("col1", "col2"), sparkTestRelation.toDF("col1", "col2"))
   }
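The planner change above is a thin delegation: it materializes the child relation and calls the existing Dataset.stat.crosstab. For readers new to that function, an illustrative example of its semantics in regular PySpark (Spark Connect mirrors it; a `spark` session is assumed, and row/column order may vary):

    df = spark.createDataFrame(
        [(1, "a"), (1, "b"), (2, "a"), (2, "a")], ["id", "name"]
    )
    df.stat.crosstab("id", "name").show()
    # +-------+---+---+
    # |id_name|  a|  b|
    # +-------+---+---+
    # |      1|  1|  1|
    # |      2|  2|  0|
    # +-------+---+---+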
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 6bf3ce0dcc9..e3116ea1250 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -501,6 +501,18 @@ class DataFrame(object):
     def where(self, condition: Expression) -> "DataFrame":
         return self.filter(condition)
 
+    @property
+    def stat(self) -> "DataFrameStatFunctions":
+        """Returns a :class:`DataFrameStatFunctions` for statistic functions.
+
+        .. versionadded:: 3.4.0
+
+        Returns
+        -------
+        :class:`DataFrameStatFunctions`
+        """
+        return DataFrameStatFunctions(self)
+
     def summary(self, *statistics: str) -> "DataFrame":
         _statistics: List[str] = list(statistics)
         for s in _statistics:
@@ -511,6 +523,41 @@ class DataFrame(object):
             session=self._session,
         )
 
+    def crosstab(self, col1: str, col2: str) -> "DataFrame":
+        """
+        Computes a pair-wise frequency table of the given columns. Also known as a contingency
+        table. The number of distinct values for each column should be less than 1e4. At most 1e6
+        non-zero pair frequencies will be returned.
+        The first column of each row will be the distinct values of `col1` and the column names
+        will be the distinct values of `col2`. The name of the first column will be `$col1_$col2`.
+        Pairs that have no occurrences will have zero as their counts.
+        :func:`DataFrame.crosstab` and :func:`DataFrameStatFunctions.crosstab` are aliases.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        col1 : str
+            The name of the first column. Distinct items will make the first item of
+            each row.
+        col2 : str
+            The name of the second column. Distinct items will make the column names
+            of the :class:`DataFrame`.
+
+        Returns
+        -------
+        :class:`DataFrame`
+            Frequency matrix of two columns.
+        """
+        if not isinstance(col1, str):
+            raise TypeError(f"'col1' must be str, but got {type(col1).__name__}")
+        if not isinstance(col2, str):
+            raise TypeError(f"'col2' must be str, but got {type(col2).__name__}")
+        return DataFrame.withPlan(
+            plan.StatCrosstab(child=self._plan, col1=col1, col2=col2),
+            session=self._session,
+        )
+
     def _get_alias(self) -> Optional[str]:
         p = self._plan
         while p is not None:
@@ -579,3 +626,18 @@ class DataFrame(object):
             return self._session.explain_string(query)
         else:
             return ""
+
+
+class DataFrameStatFunctions:
+    """Functionality for statistic functions with :class:`DataFrame`.
+
+    .. versionadded:: 3.4.0
+    """
+
+    def __init__(self, df: DataFrame):
+        self.df = df
+
+    def crosstab(self, col1: str, col2: str) -> DataFrame:
+        return self.df.crosstab(col1, col2)
+
+    crosstab.__doc__ = DataFrame.crosstab.__doc__
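Both client entry points added here build the same StatCrosstab plan node, and invalid argument types fail fast on the client. A hedged usage sketch, assuming `df` is a pyspark.sql.connect DataFrame:

    ct1 = df.crosstab("col_a", "col_b")
    ct2 = df.stat.crosstab("col_a", "col_b")  # alias of the above

    # Non-string column names are rejected before any RPC is made:
    try:
        df.crosstab(1, "col_b")
    except TypeError as e:
        print(e)  # 'col1' must be str, but got int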
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 047c9f2ce0f..926119c5457 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -830,3 +830,35 @@ class StatSummary(LogicalPlan):
             </li>
         </ul>
         """
+
+
+class StatCrosstab(LogicalPlan):
+    def __init__(self, child: Optional["LogicalPlan"], col1: str, col2: str) -> None:
+        super().__init__(child)
+        self.col1 = col1
+        self.col2 = col2
+
+    def plan(self, session: "RemoteSparkSession") -> proto.Relation:
+        assert self._child is not None
+
+        plan = proto.Relation()
+        plan.crosstab.input.CopyFrom(self._child.plan(session))
+        plan.crosstab.col1 = self.col1
+        plan.crosstab.col2 = self.col2
+        return plan
+
+    def print(self, indent: int = 0) -> str:
+        i = " " * indent
+        return f"""{i}<Crosstab col1='{self.col1}' col2='{self.col2}'>"""
+
+    def _repr_html_(self) -> str:
+        return f"""
+        <ul>
+            <li>
+                <b>Crosstab</b><br />
+                Col1: {self.col1} <br />
+                Col2: {self.col2} <br />
+                {self._child_repr_()}
+            </li>
+        </ul>
+        """
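A minimal sketch of composing the new plan node directly, following the LogicalPlan API in this file; the `Read` leaf node and its single table-name argument are assumed from plan.py:

    from pyspark.sql.connect import plan

    node = plan.StatCrosstab(child=plan.Read("my_table"), col1="col_a", col2="col_b")
    print(node.print())  # <Crosstab col1='col_a' col2='col_b'>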
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index d8c85596727..323eb8e7690 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -32,7 +32,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xfb\t\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\ [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xb6\n\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\ [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -46,69 +46,71 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._options = None
     _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_options = b"8\001"
     _RELATION._serialized_start = 82
-    _RELATION._serialized_end = 1357
-    _UNKNOWN._serialized_start = 1359
-    _UNKNOWN._serialized_end = 1368
-    _RELATIONCOMMON._serialized_start = 1370
-    _RELATIONCOMMON._serialized_end = 1419
-    _SQL._serialized_start = 1421
-    _SQL._serialized_end = 1448
-    _READ._serialized_start = 1451
-    _READ._serialized_end = 1861
-    _READ_NAMEDTABLE._serialized_start = 1593
-    _READ_NAMEDTABLE._serialized_end = 1654
-    _READ_DATASOURCE._serialized_start = 1657
-    _READ_DATASOURCE._serialized_end = 1848
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1790
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 1848
-    _PROJECT._serialized_start = 1863
-    _PROJECT._serialized_end = 1980
-    _FILTER._serialized_start = 1982
-    _FILTER._serialized_end = 2094
-    _JOIN._serialized_start = 2097
-    _JOIN._serialized_end = 2547
-    _JOIN_JOINTYPE._serialized_start = 2360
-    _JOIN_JOINTYPE._serialized_end = 2547
-    _SETOPERATION._serialized_start = 2550
-    _SETOPERATION._serialized_end = 2913
-    _SETOPERATION_SETOPTYPE._serialized_start = 2799
-    _SETOPERATION_SETOPTYPE._serialized_end = 2913
-    _LIMIT._serialized_start = 2915
-    _LIMIT._serialized_end = 2991
-    _OFFSET._serialized_start = 2993
-    _OFFSET._serialized_end = 3072
-    _AGGREGATE._serialized_start = 3075
-    _AGGREGATE._serialized_end = 3285
-    _SORT._serialized_start = 3288
-    _SORT._serialized_end = 3819
-    _SORT_SORTFIELD._serialized_start = 3437
-    _SORT_SORTFIELD._serialized_end = 3625
-    _SORT_SORTDIRECTION._serialized_start = 3627
-    _SORT_SORTDIRECTION._serialized_end = 3735
-    _SORT_SORTNULLS._serialized_start = 3737
-    _SORT_SORTNULLS._serialized_end = 3819
-    _DEDUPLICATE._serialized_start = 3822
-    _DEDUPLICATE._serialized_end = 3964
-    _LOCALRELATION._serialized_start = 3966
-    _LOCALRELATION._serialized_end = 4059
-    _SAMPLE._serialized_start = 4062
-    _SAMPLE._serialized_end = 4302
-    _SAMPLE_SEED._serialized_start = 4276
-    _SAMPLE_SEED._serialized_end = 4302
-    _RANGE._serialized_start = 4305
-    _RANGE._serialized_end = 4503
-    _RANGE_NUMPARTITIONS._serialized_start = 4449
-    _RANGE_NUMPARTITIONS._serialized_end = 4503
-    _SUBQUERYALIAS._serialized_start = 4505
-    _SUBQUERYALIAS._serialized_end = 4619
-    _REPARTITION._serialized_start = 4621
-    _REPARTITION._serialized_end = 4746
-    _STATSUMMARY._serialized_start = 4748
-    _STATSUMMARY._serialized_end = 4840
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 4842
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 4956
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 4959
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 5218
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 5151
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 5218
+    _RELATION._serialized_end = 1416
+    _UNKNOWN._serialized_start = 1418
+    _UNKNOWN._serialized_end = 1427
+    _RELATIONCOMMON._serialized_start = 1429
+    _RELATIONCOMMON._serialized_end = 1478
+    _SQL._serialized_start = 1480
+    _SQL._serialized_end = 1507
+    _READ._serialized_start = 1510
+    _READ._serialized_end = 1920
+    _READ_NAMEDTABLE._serialized_start = 1652
+    _READ_NAMEDTABLE._serialized_end = 1713
+    _READ_DATASOURCE._serialized_start = 1716
+    _READ_DATASOURCE._serialized_end = 1907
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1849
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 1907
+    _PROJECT._serialized_start = 1922
+    _PROJECT._serialized_end = 2039
+    _FILTER._serialized_start = 2041
+    _FILTER._serialized_end = 2153
+    _JOIN._serialized_start = 2156
+    _JOIN._serialized_end = 2606
+    _JOIN_JOINTYPE._serialized_start = 2419
+    _JOIN_JOINTYPE._serialized_end = 2606
+    _SETOPERATION._serialized_start = 2609
+    _SETOPERATION._serialized_end = 2972
+    _SETOPERATION_SETOPTYPE._serialized_start = 2858
+    _SETOPERATION_SETOPTYPE._serialized_end = 2972
+    _LIMIT._serialized_start = 2974
+    _LIMIT._serialized_end = 3050
+    _OFFSET._serialized_start = 3052
+    _OFFSET._serialized_end = 3131
+    _AGGREGATE._serialized_start = 3134
+    _AGGREGATE._serialized_end = 3344
+    _SORT._serialized_start = 3347
+    _SORT._serialized_end = 3878
+    _SORT_SORTFIELD._serialized_start = 3496
+    _SORT_SORTFIELD._serialized_end = 3684
+    _SORT_SORTDIRECTION._serialized_start = 3686
+    _SORT_SORTDIRECTION._serialized_end = 3794
+    _SORT_SORTNULLS._serialized_start = 3796
+    _SORT_SORTNULLS._serialized_end = 3878
+    _DEDUPLICATE._serialized_start = 3881
+    _DEDUPLICATE._serialized_end = 4023
+    _LOCALRELATION._serialized_start = 4025
+    _LOCALRELATION._serialized_end = 4118
+    _SAMPLE._serialized_start = 4121
+    _SAMPLE._serialized_end = 4361
+    _SAMPLE_SEED._serialized_start = 4335
+    _SAMPLE_SEED._serialized_end = 4361
+    _RANGE._serialized_start = 4364
+    _RANGE._serialized_end = 4562
+    _RANGE_NUMPARTITIONS._serialized_start = 4508
+    _RANGE_NUMPARTITIONS._serialized_end = 4562
+    _SUBQUERYALIAS._serialized_start = 4564
+    _SUBQUERYALIAS._serialized_end = 4678
+    _REPARTITION._serialized_start = 4680
+    _REPARTITION._serialized_end = 4805
+    _STATSUMMARY._serialized_start = 4807
+    _STATSUMMARY._serialized_end = 4899
+    _STATCROSSTAB._serialized_start = 4901
+    _STATCROSSTAB._serialized_end = 5002
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 5004
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 5118
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 5121
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 5380
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 5313
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 5380
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 5569e4db4ef..53f75b7520f 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -79,6 +79,7 @@ class Relation(google.protobuf.message.Message):
     RENAME_COLUMNS_BY_SAME_LENGTH_NAMES_FIELD_NUMBER: builtins.int
     RENAME_COLUMNS_BY_NAME_TO_NAME_MAP_FIELD_NUMBER: builtins.int
     SUMMARY_FIELD_NUMBER: builtins.int
+    CROSSTAB_FIELD_NUMBER: builtins.int
     UNKNOWN_FIELD_NUMBER: builtins.int
     @property
     def common(self) -> global___RelationCommon: ...
@@ -122,6 +123,8 @@ class Relation(google.protobuf.message.Message):
     def summary(self) -> global___StatSummary:
         """stat functions"""
     @property
+    def crosstab(self) -> global___StatCrosstab: ...
+    @property
     def unknown(self) -> global___Unknown: ...
    def __init__(
         self,
@@ -146,6 +149,7 @@ class Relation(google.protobuf.message.Message):
         rename_columns_by_same_length_names: global___RenameColumnsBySameLengthNames | None = ...,
         rename_columns_by_name_to_name_map: global___RenameColumnsByNameToNameMap | None = ...,
         summary: global___StatSummary | None = ...,
+        crosstab: global___StatCrosstab | None = ...,
         unknown: global___Unknown | None = ...,
     ) -> None: ...
     def HasField(
@@ -155,6 +159,8 @@ class Relation(google.protobuf.message.Message):
             b"aggregate",
             "common",
             b"common",
+            "crosstab",
+            b"crosstab",
             "deduplicate",
             b"deduplicate",
             "filter",
@@ -204,6 +210,8 @@ class Relation(google.protobuf.message.Message):
             b"aggregate",
             "common",
             b"common",
+            "crosstab",
+            b"crosstab",
             "deduplicate",
             b"deduplicate",
             "filter",
@@ -268,6 +276,7 @@ class Relation(google.protobuf.message.Message):
         "rename_columns_by_same_length_names",
         "rename_columns_by_name_to_name_map",
         "summary",
+        "crosstab",
         "unknown",
     ]
     | None: ...
@@ -1141,6 +1150,47 @@ class StatSummary(google.protobuf.message.Message):
 
 global___StatSummary = StatSummary
 
+class StatCrosstab(google.protobuf.message.Message):
+    """Computes a pair-wise frequency table of the given columns. Also known as a contingency table.
+    It will invoke 'Dataset.stat.crosstab' (same as 'StatFunctions.crossTabulate')
+    to compute the results.
+    """
+
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    INPUT_FIELD_NUMBER: builtins.int
+    COL1_FIELD_NUMBER: builtins.int
+    COL2_FIELD_NUMBER: builtins.int
+    @property
+    def input(self) -> global___Relation:
+        """(Required) The input relation."""
+    col1: builtins.str
+    """(Required) The name of the first column.
+
+    Distinct items will make the first item of each row.
+    """
+    col2: builtins.str
+    """(Required) The name of the second column.
+
+    Distinct items will make the column names of the DataFrame.
+    """
+    def __init__(
+        self,
+        *,
+        input: global___Relation | None = ...,
+        col1: builtins.str = ...,
+        col2: builtins.str = ...,
+    ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["input", b"input"]
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal["col1", b"col1", "col2", b"col2", "input", b"input"],
+    ) -> None: ...
+
+global___StatCrosstab = StatCrosstab
+
 class RenameColumnsBySameLengthNames(google.protobuf.message.Message):
     """Rename columns on the input relation by the same length of names."""
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 0164fec11ff..c46d4d10624 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -85,6 +85,16 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
             ["count", "mean", "stddev", "min", "25%"],
         )
 
+    def test_crosstab(self):
+        df = self.connect.readTable(table_name=self.tbl_name)
+        plan = df.filter(df.col_name > 3).crosstab("col_a", "col_b")._plan.to_proto(self.connect)
+        self.assertEqual(plan.root.crosstab.col1, "col_a")
+        self.assertEqual(plan.root.crosstab.col2, "col_b")
+
+        plan = df.stat.crosstab("col_a", "col_b")._plan.to_proto(self.connect)
+        self.assertEqual(plan.root.crosstab.col1, "col_a")
+        self.assertEqual(plan.root.crosstab.col2, "col_b")
+
     def test_limit(self):
         df = self.connect.readTable(table_name=self.tbl_name)
         limit_plan = df.limit(10)._plan.to_proto(self.connect)
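A hedged sketch of the round trip the new test exercises: the client-side plan serializes into the `crosstab` field of the Relation oneof (`df` and `session` stand in for the fixtures in PlanOnlyTestFixture):

    plan = df.stat.crosstab("col_a", "col_b")._plan.to_proto(session)
    assert plan.root.WhichOneof("rel_type") == "crosstab"
    assert plan.root.crosstab.col1 == "col_a"
    assert plan.root.crosstab.col2 == "col_b"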