This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3d819389819 [SPARK-41065][CONNECT][PYTHON] Implement `DataFrame.freqItems` and `DataFrame.stat.freqItems`
3d819389819 is described below

commit 3d819389819557523f373c192f88a594b665734d
Author: Jiaan Geng <[email protected]>
AuthorDate: Sun Jan 1 08:55:33 2023 +0800

    [SPARK-41065][CONNECT][PYTHON] Implement `DataFrame.freqItems` and `DataFrame.stat.freqItems`
    
    ### What changes were proposed in this pull request?
    Implement `DataFrame.freqItems` and `DataFrame.stat.freqItems` with a proto message.
    
    ~~Implement `DataFrame.freqItems` and `DataFrame.stat.freqItems` for the Scala API~~
    Implement `DataFrame.freqItems` and `DataFrame.stat.freqItems` for the Python API.
    
    ### Why are the changes needed?
    For Spark Connect API coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    'No'. This PR adds a new API.
    
    ### How was this patch tested?
    New test cases.
    
    Closes #39325 from beliefer/SPARK-41065.
    
    Authored-by: Jiaan Geng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../main/protobuf/spark/connect/relations.proto    |  16 ++
 .../org/apache/spark/sql/connect/dsl/package.scala |  20 ++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  11 ++
 .../connect/planner/SparkConnectProtoSuite.scala   |  10 +
 python/pyspark/sql/connect/dataframe.py            |  23 +++
 python/pyspark/sql/connect/plan.py                 |  21 +++
 python/pyspark/sql/connect/proto/relations_pb2.py  | 202 +++++++++++----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  57 ++++++
 .../sql/tests/connect/test_connect_basic.py        |  19 ++
 .../sql/tests/connect/test_connect_plan_only.py    |  18 ++
 10 files changed, 303 insertions(+), 94 deletions(-)
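
For context, a minimal usage sketch of the new API from PySpark; the session setup (Spark 3.4+ style) and sample data below are illustrative assumptions, not part of this commit:

    from pyspark.sql import SparkSession

    # Illustrative: connect to a Spark Connect server at an assumed address.
    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

    df = spark.createDataFrame(
        [(1, "a"), (1, "a"), (2, "b"), (1, "a")], ["id", "name"]
    )

    # Items appearing in at least 40% of rows; omitting `support` uses 0.01.
    df.stat.freqItems(["id", "name"], 0.4).show()
    df.freqItems(["id", "name"]).show()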

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 7aa098a53b4..db3565eda61 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -73,6 +73,7 @@ message Relation {
     StatCov cov = 103;
     StatCorr corr = 104;
     StatApproxQuantile approx_quantile = 105;
+    StatFreqItems freq_items = 106;
 
     // Catalog API (experimental / unstable)
     Catalog catalog = 200;
@@ -530,6 +531,21 @@ message StatApproxQuantile {
   double relative_error = 4;
 }
 
+// Finding frequent items for columns, possibly with false positives.
+// It will invoke 'Dataset.stat.freqItems' (same as 'StatFunctions.freqItems')
+// to compute the results.
+message StatFreqItems {
+  // (Required) The input relation.
+  Relation input = 1;
+
+  // (Required) The names of the columns to search frequent items in.
+  repeated string cols = 2;
+
+  // (Optional) The minimum frequency for an item to be considered `frequent`.
+  // Should be greater than 1e-4.
+  optional double support = 3;
+}
+
 // Replaces null values.
 // It will invoke 'Dataset.na.fill' (same as 'DataFrameNaFunctions.fill') to compute the results.
 // Following 3 parameter combinations are supported:
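
Since `support` is declared `optional`, field presence is tracked explicitly in the generated bindings. A small sketch with the Python classes generated from this proto (illustrative, not part of the commit):

    from pyspark.sql.connect.proto import relations_pb2 as proto

    rel = proto.Relation()
    rel.freq_items.cols.extend(["id", "name"])
    print(rel.freq_items.HasField("support"))  # False: server will use its 0.01 default
    rel.freq_items.support = 0.2
    print(rel.freq_items.HasField("support"))  # True: the explicit support is sent
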
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 84d46817b08..0b54a9c9d92 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -432,6 +432,26 @@ package object dsl {
               .build())
           .build()
       }
+
+      def freqItems(cols: Array[String], support: Double): Relation = {
+        Relation
+          .newBuilder()
+          .setFreqItems(
+            proto.StatFreqItems
+              .newBuilder()
+              .setInput(logicalPlan)
+              .addAllCols(cols.toSeq.asJava)
+              .setSupport(support)
+              .build())
+          .build()
+      }
+
+      def freqItems(cols: Array[String]): Relation = freqItems(cols, 0.01)
+
+      def freqItems(cols: Seq[String], support: Double): Relation =
+        freqItems(cols.toArray, support)
+
+      def freqItems(cols: Seq[String]): Relation = freqItems(cols, 0.01)
     }
 
     def select(exprs: Expression*): Relation = {
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 98c27f1ea93..dcfdc3f8b52 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -97,6 +97,7 @@ class SparkConnectPlanner(session: SparkSession) {
         transformStatApproxQuantile(rel.getApproxQuantile)
       case proto.Relation.RelTypeCase.CROSSTAB =>
         transformStatCrosstab(rel.getCrosstab)
+      case proto.Relation.RelTypeCase.FREQ_ITEMS => transformStatFreqItems(rel.getFreqItems)
       case proto.Relation.RelTypeCase.TO_SCHEMA => transformToSchema(rel.getToSchema)
       case proto.Relation.RelTypeCase.RENAME_COLUMNS_BY_SAME_LENGTH_NAMES =>
         transformRenameColumnsBySamelenghtNames(rel.getRenameColumnsBySameLengthNames)
@@ -408,6 +409,16 @@ class SparkConnectPlanner(session: SparkSession) {
       .logicalPlan
   }
 
+  private def transformStatFreqItems(rel: proto.StatFreqItems): LogicalPlan = {
+    val cols = rel.getColsList.asScala.toSeq
+    val df = Dataset.ofRows(session, transformRelation(rel.getInput))
+    if (rel.hasSupport) {
+      df.stat.freqItems(cols, rel.getSupport).logicalPlan
+    } else {
+      df.stat.freqItems(cols).logicalPlan
+    }
+  }
+
   private def transformToSchema(rel: proto.ToSchema): LogicalPlan = {
     val schema = DataTypeProtoConverter.toCatalystType(rel.getSchema)
     assert(schema.isInstanceOf[StructType])
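
The planner dispatches on field presence: an explicit `support` selects the two-argument `freqItems` overload, otherwise the server-side default (0.01) applies. A rough Python rendering of that decision, using a hypothetical helper for illustration only:

    from pyspark.sql.connect.proto import relations_pb2 as proto

    def describe_freq_items(rel: proto.StatFreqItems) -> str:
        # Mirrors transformStatFreqItems: presence of `support` picks the overload.
        cols = list(rel.cols)
        if rel.HasField("support"):
            return f"freqItems({cols}, support={rel.support})"
        return f"freqItems({cols})  # server-side default support = 0.01"
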
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 86e7f978e5d..4c4a070bb4f 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -488,6 +488,16 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
       sparkTestRelation.stat.crosstab("id", "name"))
   }
 
+  test("Test freqItems") {
+    comparePlans(
+      connectTestRelation.stat.freqItems(Seq("id", "name"), 1),
+      sparkTestRelation.stat.freqItems(Seq("id", "name"), 1))
+
+    comparePlans(
+      connectTestRelation.stat.freqItems(Seq("id", "name")),
+      sparkTestRelation.stat.freqItems(Seq("id", "name")))
+  }
+
   test("Test to") {
     val dataTypes: Seq[DataType] = Seq(
       StringType,
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index a309998d245..c5ab22b34bd 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -936,6 +936,22 @@ class DataFrame:
 
     crosstab.__doc__ = PySparkDataFrame.crosstab.__doc__
 
+    def freqItems(
+        self, cols: Union[List[str], Tuple[str]], support: Optional[float] = None
+    ) -> "DataFrame":
+        if isinstance(cols, tuple):
+            cols = list(cols)
+        if not isinstance(cols, list):
+            raise TypeError("cols must be a list or tuple of column names as strings.")
+        if not support:
+            support = 0.01
+        return DataFrame.withPlan(
+            plan.StatFreqItems(child=self._plan, cols=cols, support=support),
+            session=self._session,
+        )
+
+    freqItems.__doc__ = PySparkDataFrame.freqItems.__doc__
+
     def _get_alias(self) -> Optional[str]:
         p = self._plan
         while p is not None:
@@ -1321,5 +1337,12 @@ class DataFrameStatFunctions:
 
     crosstab.__doc__ = DataFrame.crosstab.__doc__
 
+    def freqItems(
+        self, cols: Union[List[str], Tuple[str]], support: Optional[float] = None
+    ) -> DataFrame:
+        return self.df.freqItems(cols, support)
+
+    freqItems.__doc__ = DataFrame.freqItems.__doc__
+
 
 DataFrameStatFunctions.__doc__ = PySparkDataFrameStatFunctions.__doc__
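
One subtlety in the defaulting above: `if not support` treats a falsy `0.0` the same as `None`, so an explicit `support=0.0` silently becomes the 0.01 default rather than being rejected. A self-contained illustration of that behavior:

    def effective_support(support=None):
        # Same defaulting logic as DataFrame.freqItems above.
        if not support:
            support = 0.01
        return support

    assert effective_support() == 0.01
    assert effective_support(0.0) == 0.01   # 0.0 is coerced to the default, not rejected
    assert effective_support(0.25) == 0.25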
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 07b266bb46c..f567d88137a 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1119,6 +1119,27 @@ class StatCrosstab(LogicalPlan):
         return plan
 
 
+class StatFreqItems(LogicalPlan):
+    def __init__(
+        self,
+        child: Optional["LogicalPlan"],
+        cols: List[str],
+        support: float,
+    ) -> None:
+        super().__init__(child)
+        self._cols = cols
+        self._support = support
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        assert self._child is not None
+
+        plan = proto.Relation()
+        plan.freq_items.input.CopyFrom(self._child.plan(session))
+        plan.freq_items.cols.extend(self._cols)
+        plan.freq_items.support = self._support
+        return plan
+
+
 class StatCorr(LogicalPlan):
     def __init__(self, child: Optional["LogicalPlan"], col1: str, col2: str, method: str) -> None:
         super().__init__(child)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 0d89c76287e..6e2904b0294 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf2\x11\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xb1\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
 )
 
 
@@ -72,6 +72,7 @@ _STATCROSSTAB = DESCRIPTOR.message_types_by_name["StatCrosstab"]
 _STATCOV = DESCRIPTOR.message_types_by_name["StatCov"]
 _STATCORR = DESCRIPTOR.message_types_by_name["StatCorr"]
 _STATAPPROXQUANTILE = DESCRIPTOR.message_types_by_name["StatApproxQuantile"]
+_STATFREQITEMS = DESCRIPTOR.message_types_by_name["StatFreqItems"]
 _NAFILL = DESCRIPTOR.message_types_by_name["NAFill"]
 _NADROP = DESCRIPTOR.message_types_by_name["NADrop"]
 _NAREPLACE = DESCRIPTOR.message_types_by_name["NAReplace"]
@@ -437,6 +438,17 @@ StatApproxQuantile = _reflection.GeneratedProtocolMessageType(
 )
 _sym_db.RegisterMessage(StatApproxQuantile)
 
+StatFreqItems = _reflection.GeneratedProtocolMessageType(
+    "StatFreqItems",
+    (_message.Message,),
+    {
+        "DESCRIPTOR": _STATFREQITEMS,
+        "__module__": "spark.connect.relations_pb2"
+        # @@protoc_insertion_point(class_scope:spark.connect.StatFreqItems)
+    },
+)
+_sym_db.RegisterMessage(StatFreqItems)
+
 NAFill = _reflection.GeneratedProtocolMessageType(
     "NAFill",
     (_message.Message,),
@@ -576,97 +588,99 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._options = None
     _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_options = b"8\001"
     _RELATION._serialized_start = 165
-    _RELATION._serialized_end = 2455
-    _UNKNOWN._serialized_start = 2457
-    _UNKNOWN._serialized_end = 2466
-    _RELATIONCOMMON._serialized_start = 2468
-    _RELATIONCOMMON._serialized_end = 2517
-    _SQL._serialized_start = 2519
-    _SQL._serialized_end = 2546
-    _READ._serialized_start = 2549
-    _READ._serialized_end = 2975
-    _READ_NAMEDTABLE._serialized_start = 2691
-    _READ_NAMEDTABLE._serialized_end = 2752
-    _READ_DATASOURCE._serialized_start = 2755
-    _READ_DATASOURCE._serialized_end = 2962
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2893
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2951
-    _PROJECT._serialized_start = 2977
-    _PROJECT._serialized_end = 3094
-    _FILTER._serialized_start = 3096
-    _FILTER._serialized_end = 3208
-    _JOIN._serialized_start = 3211
-    _JOIN._serialized_end = 3682
-    _JOIN_JOINTYPE._serialized_start = 3474
-    _JOIN_JOINTYPE._serialized_end = 3682
-    _SETOPERATION._serialized_start = 3685
-    _SETOPERATION._serialized_end = 4081
-    _SETOPERATION_SETOPTYPE._serialized_start = 3944
-    _SETOPERATION_SETOPTYPE._serialized_end = 4058
-    _LIMIT._serialized_start = 4083
-    _LIMIT._serialized_end = 4159
-    _OFFSET._serialized_start = 4161
-    _OFFSET._serialized_end = 4240
-    _TAIL._serialized_start = 4242
-    _TAIL._serialized_end = 4317
-    _AGGREGATE._serialized_start = 4320
-    _AGGREGATE._serialized_end = 4902
-    _AGGREGATE_PIVOT._serialized_start = 4659
-    _AGGREGATE_PIVOT._serialized_end = 4770
-    _AGGREGATE_GROUPTYPE._serialized_start = 4773
-    _AGGREGATE_GROUPTYPE._serialized_end = 4902
-    _SORT._serialized_start = 4905
-    _SORT._serialized_end = 5065
-    _DROP._serialized_start = 5067
-    _DROP._serialized_end = 5167
-    _DEDUPLICATE._serialized_start = 5170
-    _DEDUPLICATE._serialized_end = 5341
-    _LOCALRELATION._serialized_start = 5344
-    _LOCALRELATION._serialized_end = 5481
-    _SAMPLE._serialized_start = 5484
-    _SAMPLE._serialized_end = 5757
-    _RANGE._serialized_start = 5760
-    _RANGE._serialized_end = 5905
-    _SUBQUERYALIAS._serialized_start = 5907
-    _SUBQUERYALIAS._serialized_end = 6021
-    _REPARTITION._serialized_start = 6024
-    _REPARTITION._serialized_end = 6166
-    _SHOWSTRING._serialized_start = 6169
-    _SHOWSTRING._serialized_end = 6311
-    _STATSUMMARY._serialized_start = 6313
-    _STATSUMMARY._serialized_end = 6405
-    _STATDESCRIBE._serialized_start = 6407
-    _STATDESCRIBE._serialized_end = 6488
-    _STATCROSSTAB._serialized_start = 6490
-    _STATCROSSTAB._serialized_end = 6591
-    _STATCOV._serialized_start = 6593
-    _STATCOV._serialized_end = 6689
-    _STATCORR._serialized_start = 6692
-    _STATCORR._serialized_end = 6829
-    _STATAPPROXQUANTILE._serialized_start = 6832
-    _STATAPPROXQUANTILE._serialized_end = 6996
-    _NAFILL._serialized_start = 6999
-    _NAFILL._serialized_end = 7133
-    _NADROP._serialized_start = 7136
-    _NADROP._serialized_end = 7270
-    _NAREPLACE._serialized_start = 7273
-    _NAREPLACE._serialized_end = 7569
-    _NAREPLACE_REPLACEMENT._serialized_start = 7428
-    _NAREPLACE_REPLACEMENT._serialized_end = 7569
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 7571
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 7685
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 7688
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 7947
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 7880
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 7947
-    _WITHCOLUMNS._serialized_start = 7950
-    _WITHCOLUMNS._serialized_end = 8081
-    _HINT._serialized_start = 8084
-    _HINT._serialized_end = 8224
-    _UNPIVOT._serialized_start = 8227
-    _UNPIVOT._serialized_end = 8473
-    _TOSCHEMA._serialized_start = 8475
-    _TOSCHEMA._serialized_end = 8581
-    _REPARTITIONBYEXPRESSION._serialized_start = 8584
-    _REPARTITIONBYEXPRESSION._serialized_end = 8787
+    _RELATION._serialized_end = 2518
+    _UNKNOWN._serialized_start = 2520
+    _UNKNOWN._serialized_end = 2529
+    _RELATIONCOMMON._serialized_start = 2531
+    _RELATIONCOMMON._serialized_end = 2580
+    _SQL._serialized_start = 2582
+    _SQL._serialized_end = 2609
+    _READ._serialized_start = 2612
+    _READ._serialized_end = 3038
+    _READ_NAMEDTABLE._serialized_start = 2754
+    _READ_NAMEDTABLE._serialized_end = 2815
+    _READ_DATASOURCE._serialized_start = 2818
+    _READ_DATASOURCE._serialized_end = 3025
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2956
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3014
+    _PROJECT._serialized_start = 3040
+    _PROJECT._serialized_end = 3157
+    _FILTER._serialized_start = 3159
+    _FILTER._serialized_end = 3271
+    _JOIN._serialized_start = 3274
+    _JOIN._serialized_end = 3745
+    _JOIN_JOINTYPE._serialized_start = 3537
+    _JOIN_JOINTYPE._serialized_end = 3745
+    _SETOPERATION._serialized_start = 3748
+    _SETOPERATION._serialized_end = 4144
+    _SETOPERATION_SETOPTYPE._serialized_start = 4007
+    _SETOPERATION_SETOPTYPE._serialized_end = 4121
+    _LIMIT._serialized_start = 4146
+    _LIMIT._serialized_end = 4222
+    _OFFSET._serialized_start = 4224
+    _OFFSET._serialized_end = 4303
+    _TAIL._serialized_start = 4305
+    _TAIL._serialized_end = 4380
+    _AGGREGATE._serialized_start = 4383
+    _AGGREGATE._serialized_end = 4965
+    _AGGREGATE_PIVOT._serialized_start = 4722
+    _AGGREGATE_PIVOT._serialized_end = 4833
+    _AGGREGATE_GROUPTYPE._serialized_start = 4836
+    _AGGREGATE_GROUPTYPE._serialized_end = 4965
+    _SORT._serialized_start = 4968
+    _SORT._serialized_end = 5128
+    _DROP._serialized_start = 5130
+    _DROP._serialized_end = 5230
+    _DEDUPLICATE._serialized_start = 5233
+    _DEDUPLICATE._serialized_end = 5404
+    _LOCALRELATION._serialized_start = 5407
+    _LOCALRELATION._serialized_end = 5544
+    _SAMPLE._serialized_start = 5547
+    _SAMPLE._serialized_end = 5820
+    _RANGE._serialized_start = 5823
+    _RANGE._serialized_end = 5968
+    _SUBQUERYALIAS._serialized_start = 5970
+    _SUBQUERYALIAS._serialized_end = 6084
+    _REPARTITION._serialized_start = 6087
+    _REPARTITION._serialized_end = 6229
+    _SHOWSTRING._serialized_start = 6232
+    _SHOWSTRING._serialized_end = 6374
+    _STATSUMMARY._serialized_start = 6376
+    _STATSUMMARY._serialized_end = 6468
+    _STATDESCRIBE._serialized_start = 6470
+    _STATDESCRIBE._serialized_end = 6551
+    _STATCROSSTAB._serialized_start = 6553
+    _STATCROSSTAB._serialized_end = 6654
+    _STATCOV._serialized_start = 6656
+    _STATCOV._serialized_end = 6752
+    _STATCORR._serialized_start = 6755
+    _STATCORR._serialized_end = 6892
+    _STATAPPROXQUANTILE._serialized_start = 6895
+    _STATAPPROXQUANTILE._serialized_end = 7059
+    _STATFREQITEMS._serialized_start = 7061
+    _STATFREQITEMS._serialized_end = 7186
+    _NAFILL._serialized_start = 7189
+    _NAFILL._serialized_end = 7323
+    _NADROP._serialized_start = 7326
+    _NADROP._serialized_end = 7460
+    _NAREPLACE._serialized_start = 7463
+    _NAREPLACE._serialized_end = 7759
+    _NAREPLACE_REPLACEMENT._serialized_start = 7618
+    _NAREPLACE_REPLACEMENT._serialized_end = 7759
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 7761
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 7875
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 7878
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 8137
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 8070
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 8137
+    _WITHCOLUMNS._serialized_start = 8140
+    _WITHCOLUMNS._serialized_end = 8271
+    _HINT._serialized_start = 8274
+    _HINT._serialized_end = 8414
+    _UNPIVOT._serialized_start = 8417
+    _UNPIVOT._serialized_end = 8663
+    _TOSCHEMA._serialized_start = 8665
+    _TOSCHEMA._serialized_end = 8771
+    _REPARTITIONBYEXPRESSION._serialized_start = 8774
+    _REPARTITIONBYEXPRESSION._serialized_end = 8977
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 1ed9e62edcc..96915de60dc 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -98,6 +98,7 @@ class Relation(google.protobuf.message.Message):
     COV_FIELD_NUMBER: builtins.int
     CORR_FIELD_NUMBER: builtins.int
     APPROX_QUANTILE_FIELD_NUMBER: builtins.int
+    FREQ_ITEMS_FIELD_NUMBER: builtins.int
     CATALOG_FIELD_NUMBER: builtins.int
     EXTENSION_FIELD_NUMBER: builtins.int
     UNKNOWN_FIELD_NUMBER: builtins.int
@@ -176,6 +177,8 @@ class Relation(google.protobuf.message.Message):
     @property
     def approx_quantile(self) -> global___StatApproxQuantile: ...
     @property
+    def freq_items(self) -> global___StatFreqItems: ...
+    @property
     def catalog(self) -> pyspark.sql.connect.proto.catalog_pb2.Catalog:
         """Catalog API (experimental / unstable)"""
     @property
@@ -224,6 +227,7 @@ class Relation(google.protobuf.message.Message):
         cov: global___StatCov | None = ...,
         corr: global___StatCorr | None = ...,
         approx_quantile: global___StatApproxQuantile | None = ...,
+        freq_items: global___StatFreqItems | None = ...,
         catalog: pyspark.sql.connect.proto.catalog_pb2.Catalog | None = ...,
         extension: google.protobuf.any_pb2.Any | None = ...,
         unknown: global___Unknown | None = ...,
@@ -259,6 +263,8 @@ class Relation(google.protobuf.message.Message):
             b"fill_na",
             "filter",
             b"filter",
+            "freq_items",
+            b"freq_items",
             "hint",
             b"hint",
             "join",
@@ -344,6 +350,8 @@ class Relation(google.protobuf.message.Message):
             b"fill_na",
             "filter",
             b"filter",
+            "freq_items",
+            b"freq_items",
             "hint",
             b"hint",
             "join",
@@ -436,6 +444,7 @@ class Relation(google.protobuf.message.Message):
         "cov",
         "corr",
         "approx_quantile",
+        "freq_items",
         "catalog",
         "extension",
         "unknown",
@@ -1853,6 +1862,54 @@ class StatApproxQuantile(google.protobuf.message.Message):
 
 global___StatApproxQuantile = StatApproxQuantile
 
+class StatFreqItems(google.protobuf.message.Message):
+    """Finding frequent items for columns, possibly with false positives.
+    It will invoke 'Dataset.stat.freqItems' (same as 'StatFunctions.freqItems')
+    to compute the results.
+    """
+
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    INPUT_FIELD_NUMBER: builtins.int
+    COLS_FIELD_NUMBER: builtins.int
+    SUPPORT_FIELD_NUMBER: builtins.int
+    @property
+    def input(self) -> global___Relation:
+        """(Required) The input relation."""
+    @property
+    def cols(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+        """(Required) The names of the columns to search frequent items in."""
+    support: builtins.float
+    """(Optional) The minimum frequency for an item to be considered 
`frequent`.
+    Should be greater than 1e-4.
+    """
+    def __init__(
+        self,
+        *,
+        input: global___Relation | None = ...,
+        cols: collections.abc.Iterable[builtins.str] | None = ...,
+        support: builtins.float | None = ...,
+    ) -> None: ...
+    def HasField(
+        self,
+        field_name: typing_extensions.Literal[
+            "_support", b"_support", "input", b"input", "support", b"support"
+        ],
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "_support", b"_support", "cols", b"cols", "input", b"input", 
"support", b"support"
+        ],
+    ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_support", b"_support"]
+    ) -> typing_extensions.Literal["support"] | None: ...
+
+global___StatFreqItems = StatFreqItems
+
 class NAFill(google.protobuf.message.Message):
     """Replaces null values.
     It will invoke 'Dataset.na.fill' (same as 'DataFrameNaFunctions.fill') to compute the results.
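
The stub also exposes the synthetic `_support` oneof that proto3 generates for the `optional` field, so presence can be probed with `WhichOneof` as well as `HasField` (illustrative snippet, not part of the commit):

    from pyspark.sql.connect.proto import relations_pb2 as proto

    msg = proto.StatFreqItems(cols=["id"])
    print(msg.WhichOneof("_support"))  # None: support is unset
    msg.support = 0.1
    print(msg.WhichOneof("_support"))  # 'support'
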
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 21f29a7eb4d..6cdef25d5bc 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1175,6 +1175,25 @@ class SparkConnectTests(SparkConnectSQLTestCase):
                 ["col1", "col3"], [0.1, 0.5, 0.9], -0.1
             )
 
+    def test_stat_freq_items(self):
+        # SPARK-41065: Test the stat.freqItems method
+        self.assert_eq(
+            self.connect.read.table(self.tbl_name2).stat.freqItems(["col1", "col3"]).toPandas(),
+            self.spark.read.table(self.tbl_name2).stat.freqItems(["col1", "col3"]).toPandas(),
+        )
+
+        self.assert_eq(
+            self.connect.read.table(self.tbl_name2)
+            .stat.freqItems(["col1", "col3"], 0.4)
+            .toPandas(),
+            self.spark.read.table(self.tbl_name2).stat.freqItems(["col1", "col3"], 0.4).toPandas(),
+        )
+
+        with self.assertRaisesRegex(
+            TypeError, "cols must be a list or tuple of column names as 
strings"
+        ):
+            self.connect.read.table(self.tbl_name2).stat.freqItems("col1")
+
     def test_repr(self):
         # SPARK-41213: Test the __repr__ method
         query = """SELECT * FROM VALUES (1L, NULL), (3L, "Z") AS tab(a, b)"""
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 529e3ca3eda..5e3c6661e52 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -309,6 +309,24 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         self.assertEqual(plan.root.crosstab.col1, "col_a")
         self.assertEqual(plan.root.crosstab.col2, "col_b")
 
+    def test_freqItems(self):
+        df = self.connect.readTable(table_name=self.tbl_name)
+        plan = (
+            df.filter(df.col_name > 3).freqItems(["col_a", "col_b"], 1)._plan.to_proto(self.connect)
+        )
+        self.assertEqual(plan.root.freq_items.cols, ["col_a", "col_b"])
+        self.assertEqual(plan.root.freq_items.support, 1)
+        plan = df.filter(df.col_name > 3).freqItems(["col_a", "col_b"])._plan.to_proto(self.connect)
+        self.assertEqual(plan.root.freq_items.cols, ["col_a", "col_b"])
+        self.assertEqual(plan.root.freq_items.support, 0.01)
+
+        plan = df.stat.freqItems(["col_a", "col_b"], 1)._plan.to_proto(self.connect)
+        self.assertEqual(plan.root.freq_items.cols, ["col_a", "col_b"])
+        self.assertEqual(plan.root.freq_items.support, 1)
+        plan = df.stat.freqItems(["col_a", "col_b"])._plan.to_proto(self.connect)
+        self.assertEqual(plan.root.freq_items.cols, ["col_a", "col_b"])
+        self.assertEqual(plan.root.freq_items.support, 0.01)
+
     def test_limit(self):
         df = self.connect.readTable(table_name=self.tbl_name)
         limit_plan = df.limit(10)._plan.to_proto(self.connect)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
