This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 3d819389819 [SPARK-41065][CONNECT][PYTHON] Implement `DataFrame.freqItems` and `DataFrame.stat.freqItems`
3d819389819 is described below
commit 3d819389819557523f373c192f88a594b665734d
Author: Jiaan Geng <[email protected]>
AuthorDate: Sun Jan 1 08:55:33 2023 +0800
[SPARK-41065][CONNECT][PYTHON] Implement `DataFrame.freqItems` and `DataFrame.stat.freqItems`
### What changes were proposed in this pull request?
Implement `DataFrame.freqItems` and `DataFrame.stat.freqItems` with a proto message.
~~Implement `DataFrame.freqItems` and `DataFrame.stat.freqItems` for the Scala API~~
Implement `DataFrame.freqItems` and `DataFrame.stat.freqItems` for the Python API.
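For illustration, a minimal usage sketch of the new Python API against a Spark Connect session (the session, table, and column names below are hypothetical, not part of this patch):

```python
# Sketch only: assumes `spark` is a Spark Connect SparkSession and that a
# table named "transactions" with columns `payment_type` and `status`
# exists; all three names are hypothetical.
df = spark.read.table("transactions")

# Default minimum frequency (support) is 0.01, i.e. items occurring in at
# least 1% of rows; results may contain false positives.
freq = df.stat.freqItems(["payment_type", "status"])

# An explicit support can be passed; it should be greater than 1e-4.
freq_40 = df.freqItems(["payment_type", "status"], support=0.4)

print(freq.toPandas())
```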
### Why are the changes needed?
For Connect API coverage.
### Does this PR introduce _any_ user-facing change?
'No'. New API.
### How was this patch tested?
New test cases.
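As a note on the protocol design: `support` is declared `optional` in the proto message, so the server can tell an unset value (falling back to the `Dataset.stat.freqItems(cols)` overload with its 0.01 default) apart from an explicit one via `rel.hasSupport`. A small sketch of that presence behavior using the generated Python bindings (import style as in `plan.py`):

```python
import pyspark.sql.connect.proto as proto

rel = proto.Relation()
rel.freq_items.cols.extend(["col_a", "col_b"])

# Left unset, the field reports no presence; the server would use the
# default freqItems(cols) overload.
assert not rel.freq_items.HasField("support")

# Once set, presence is reported and the server passes the value through.
rel.freq_items.support = 0.4
assert rel.freq_items.HasField("support")
```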
Closes #39325 from beliefer/SPARK-41065.
Authored-by: Jiaan Geng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../main/protobuf/spark/connect/relations.proto | 16 ++
.../org/apache/spark/sql/connect/dsl/package.scala | 20 ++
.../sql/connect/planner/SparkConnectPlanner.scala | 11 ++
.../connect/planner/SparkConnectProtoSuite.scala | 10 +
python/pyspark/sql/connect/dataframe.py | 23 +++
python/pyspark/sql/connect/plan.py | 21 +++
python/pyspark/sql/connect/proto/relations_pb2.py | 202 +++++++++++----------
python/pyspark/sql/connect/proto/relations_pb2.pyi | 57 ++++++
.../sql/tests/connect/test_connect_basic.py | 19 ++
.../sql/tests/connect/test_connect_plan_only.py | 18 ++
10 files changed, 303 insertions(+), 94 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 7aa098a53b4..db3565eda61 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -73,6 +73,7 @@ message Relation {
StatCov cov = 103;
StatCorr corr = 104;
StatApproxQuantile approx_quantile = 105;
+ StatFreqItems freq_items = 106;
// Catalog API (experimental / unstable)
Catalog catalog = 200;
@@ -530,6 +531,21 @@ message StatApproxQuantile {
double relative_error = 4;
}
+// Finding frequent items for columns, possibly with false positives.
+// It will invoke 'Dataset.stat.freqItems' (same as 'StatFunctions.freqItems')
+// to compute the results.
+message StatFreqItems {
+ // (Required) The input relation.
+ Relation input = 1;
+
+ // (Required) The names of the columns to search frequent items in.
+ repeated string cols = 2;
+
+ // (Optional) The minimum frequency for an item to be considered `frequent`.
+ // Should be greater than 1e-4.
+ optional double support = 3;
+}
+
// Replaces null values.
// It will invoke 'Dataset.na.fill' (same as 'DataFrameNaFunctions.fill') to compute the results.
// Following 3 parameter combinations are supported:
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 84d46817b08..0b54a9c9d92 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -432,6 +432,26 @@ package object dsl {
.build())
.build()
}
+
+ def freqItems(cols: Array[String], support: Double): Relation = {
+ Relation
+ .newBuilder()
+ .setFreqItems(
+ proto.StatFreqItems
+ .newBuilder()
+ .setInput(logicalPlan)
+ .addAllCols(cols.toSeq.asJava)
+ .setSupport(support)
+ .build())
+ .build()
+ }
+
+ def freqItems(cols: Array[String]): Relation = freqItems(cols, 0.01)
+
+ def freqItems(cols: Seq[String], support: Double): Relation =
+ freqItems(cols.toArray, support)
+
+ def freqItems(cols: Seq[String]): Relation = freqItems(cols, 0.01)
}
def select(exprs: Expression*): Relation = {
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 98c27f1ea93..dcfdc3f8b52 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -97,6 +97,7 @@ class SparkConnectPlanner(session: SparkSession) {
transformStatApproxQuantile(rel.getApproxQuantile)
case proto.Relation.RelTypeCase.CROSSTAB =>
transformStatCrosstab(rel.getCrosstab)
+ case proto.Relation.RelTypeCase.FREQ_ITEMS => transformStatFreqItems(rel.getFreqItems)
case proto.Relation.RelTypeCase.TO_SCHEMA =>
transformToSchema(rel.getToSchema)
case proto.Relation.RelTypeCase.RENAME_COLUMNS_BY_SAME_LENGTH_NAMES =>
transformRenameColumnsBySamelenghtNames(rel.getRenameColumnsBySameLengthNames)
@@ -408,6 +409,16 @@ class SparkConnectPlanner(session: SparkSession) {
.logicalPlan
}
+ private def transformStatFreqItems(rel: proto.StatFreqItems): LogicalPlan = {
+ val cols = rel.getColsList.asScala.toSeq
+ val df = Dataset.ofRows(session, transformRelation(rel.getInput))
+ if (rel.hasSupport) {
+ df.stat.freqItems(cols, rel.getSupport).logicalPlan
+ } else {
+ df.stat.freqItems(cols).logicalPlan
+ }
+ }
+
private def transformToSchema(rel: proto.ToSchema): LogicalPlan = {
val schema = DataTypeProtoConverter.toCatalystType(rel.getSchema)
assert(schema.isInstanceOf[StructType])
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 86e7f978e5d..4c4a070bb4f 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -488,6 +488,16 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
sparkTestRelation.stat.crosstab("id", "name"))
}
+ test("Test freqItems") {
+ comparePlans(
+ connectTestRelation.stat.freqItems(Seq("id", "name"), 1),
+ sparkTestRelation.stat.freqItems(Seq("id", "name"), 1))
+
+ comparePlans(
+ connectTestRelation.stat.freqItems(Seq("id", "name")),
+ sparkTestRelation.stat.freqItems(Seq("id", "name")))
+ }
+
test("Test to") {
val dataTypes: Seq[DataType] = Seq(
StringType,
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index a309998d245..c5ab22b34bd 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -936,6 +936,22 @@ class DataFrame:
crosstab.__doc__ = PySparkDataFrame.crosstab.__doc__
+ def freqItems(
+ self, cols: Union[List[str], Tuple[str]], support: Optional[float] = None
+ ) -> "DataFrame":
+ if isinstance(cols, tuple):
+ cols = list(cols)
+ if not isinstance(cols, list):
+ raise TypeError("cols must be a list or tuple of column names as strings.")
+ if not support:
+ support = 0.01
+ return DataFrame.withPlan(
+ plan.StatFreqItems(child=self._plan, cols=cols, support=support),
+ session=self._session,
+ )
+
+ freqItems.__doc__ = PySparkDataFrame.freqItems.__doc__
+
def _get_alias(self) -> Optional[str]:
p = self._plan
while p is not None:
@@ -1321,5 +1337,12 @@ class DataFrameStatFunctions:
crosstab.__doc__ = DataFrame.crosstab.__doc__
+ def freqItems(
+ self, cols: Union[List[str], Tuple[str]], support: Optional[float] = None
+ ) -> DataFrame:
+ return self.df.freqItems(cols, support)
+
+ freqItems.__doc__ = DataFrame.freqItems.__doc__
+
DataFrameStatFunctions.__doc__ = PySparkDataFrameStatFunctions.__doc__
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 07b266bb46c..f567d88137a 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1119,6 +1119,27 @@ class StatCrosstab(LogicalPlan):
return plan
+class StatFreqItems(LogicalPlan):
+ def __init__(
+ self,
+ child: Optional["LogicalPlan"],
+ cols: List[str],
+ support: float,
+ ) -> None:
+ super().__init__(child)
+ self._cols = cols
+ self._support = support
+
+ def plan(self, session: "SparkConnectClient") -> proto.Relation:
+ assert self._child is not None
+
+ plan = proto.Relation()
+ plan.freq_items.input.CopyFrom(self._child.plan(session))
+ plan.freq_items.cols.extend(self._cols)
+ plan.freq_items.support = self._support
+ return plan
+
+
class StatCorr(LogicalPlan):
def __init__(self, child: Optional["LogicalPlan"], col1: str, col2: str, method: str) -> None:
super().__init__(child)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 0d89c76287e..6e2904b0294 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf2\x11\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xb1\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
)
@@ -72,6 +72,7 @@ _STATCROSSTAB = DESCRIPTOR.message_types_by_name["StatCrosstab"]
_STATCOV = DESCRIPTOR.message_types_by_name["StatCov"]
_STATCORR = DESCRIPTOR.message_types_by_name["StatCorr"]
_STATAPPROXQUANTILE = DESCRIPTOR.message_types_by_name["StatApproxQuantile"]
+_STATFREQITEMS = DESCRIPTOR.message_types_by_name["StatFreqItems"]
_NAFILL = DESCRIPTOR.message_types_by_name["NAFill"]
_NADROP = DESCRIPTOR.message_types_by_name["NADrop"]
_NAREPLACE = DESCRIPTOR.message_types_by_name["NAReplace"]
@@ -437,6 +438,17 @@ StatApproxQuantile = _reflection.GeneratedProtocolMessageType(
)
_sym_db.RegisterMessage(StatApproxQuantile)
+StatFreqItems = _reflection.GeneratedProtocolMessageType(
+ "StatFreqItems",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _STATFREQITEMS,
+ "__module__": "spark.connect.relations_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.StatFreqItems)
+ },
+)
+_sym_db.RegisterMessage(StatFreqItems)
+
NAFill = _reflection.GeneratedProtocolMessageType(
"NAFill",
(_message.Message,),
@@ -576,97 +588,99 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._options = None
_RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_options = b"8\001"
_RELATION._serialized_start = 165
- _RELATION._serialized_end = 2455
- _UNKNOWN._serialized_start = 2457
- _UNKNOWN._serialized_end = 2466
- _RELATIONCOMMON._serialized_start = 2468
- _RELATIONCOMMON._serialized_end = 2517
- _SQL._serialized_start = 2519
- _SQL._serialized_end = 2546
- _READ._serialized_start = 2549
- _READ._serialized_end = 2975
- _READ_NAMEDTABLE._serialized_start = 2691
- _READ_NAMEDTABLE._serialized_end = 2752
- _READ_DATASOURCE._serialized_start = 2755
- _READ_DATASOURCE._serialized_end = 2962
- _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2893
- _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2951
- _PROJECT._serialized_start = 2977
- _PROJECT._serialized_end = 3094
- _FILTER._serialized_start = 3096
- _FILTER._serialized_end = 3208
- _JOIN._serialized_start = 3211
- _JOIN._serialized_end = 3682
- _JOIN_JOINTYPE._serialized_start = 3474
- _JOIN_JOINTYPE._serialized_end = 3682
- _SETOPERATION._serialized_start = 3685
- _SETOPERATION._serialized_end = 4081
- _SETOPERATION_SETOPTYPE._serialized_start = 3944
- _SETOPERATION_SETOPTYPE._serialized_end = 4058
- _LIMIT._serialized_start = 4083
- _LIMIT._serialized_end = 4159
- _OFFSET._serialized_start = 4161
- _OFFSET._serialized_end = 4240
- _TAIL._serialized_start = 4242
- _TAIL._serialized_end = 4317
- _AGGREGATE._serialized_start = 4320
- _AGGREGATE._serialized_end = 4902
- _AGGREGATE_PIVOT._serialized_start = 4659
- _AGGREGATE_PIVOT._serialized_end = 4770
- _AGGREGATE_GROUPTYPE._serialized_start = 4773
- _AGGREGATE_GROUPTYPE._serialized_end = 4902
- _SORT._serialized_start = 4905
- _SORT._serialized_end = 5065
- _DROP._serialized_start = 5067
- _DROP._serialized_end = 5167
- _DEDUPLICATE._serialized_start = 5170
- _DEDUPLICATE._serialized_end = 5341
- _LOCALRELATION._serialized_start = 5344
- _LOCALRELATION._serialized_end = 5481
- _SAMPLE._serialized_start = 5484
- _SAMPLE._serialized_end = 5757
- _RANGE._serialized_start = 5760
- _RANGE._serialized_end = 5905
- _SUBQUERYALIAS._serialized_start = 5907
- _SUBQUERYALIAS._serialized_end = 6021
- _REPARTITION._serialized_start = 6024
- _REPARTITION._serialized_end = 6166
- _SHOWSTRING._serialized_start = 6169
- _SHOWSTRING._serialized_end = 6311
- _STATSUMMARY._serialized_start = 6313
- _STATSUMMARY._serialized_end = 6405
- _STATDESCRIBE._serialized_start = 6407
- _STATDESCRIBE._serialized_end = 6488
- _STATCROSSTAB._serialized_start = 6490
- _STATCROSSTAB._serialized_end = 6591
- _STATCOV._serialized_start = 6593
- _STATCOV._serialized_end = 6689
- _STATCORR._serialized_start = 6692
- _STATCORR._serialized_end = 6829
- _STATAPPROXQUANTILE._serialized_start = 6832
- _STATAPPROXQUANTILE._serialized_end = 6996
- _NAFILL._serialized_start = 6999
- _NAFILL._serialized_end = 7133
- _NADROP._serialized_start = 7136
- _NADROP._serialized_end = 7270
- _NAREPLACE._serialized_start = 7273
- _NAREPLACE._serialized_end = 7569
- _NAREPLACE_REPLACEMENT._serialized_start = 7428
- _NAREPLACE_REPLACEMENT._serialized_end = 7569
- _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 7571
- _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 7685
- _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 7688
- _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 7947
- _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 7880
- _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 7947
- _WITHCOLUMNS._serialized_start = 7950
- _WITHCOLUMNS._serialized_end = 8081
- _HINT._serialized_start = 8084
- _HINT._serialized_end = 8224
- _UNPIVOT._serialized_start = 8227
- _UNPIVOT._serialized_end = 8473
- _TOSCHEMA._serialized_start = 8475
- _TOSCHEMA._serialized_end = 8581
- _REPARTITIONBYEXPRESSION._serialized_start = 8584
- _REPARTITIONBYEXPRESSION._serialized_end = 8787
+ _RELATION._serialized_end = 2518
+ _UNKNOWN._serialized_start = 2520
+ _UNKNOWN._serialized_end = 2529
+ _RELATIONCOMMON._serialized_start = 2531
+ _RELATIONCOMMON._serialized_end = 2580
+ _SQL._serialized_start = 2582
+ _SQL._serialized_end = 2609
+ _READ._serialized_start = 2612
+ _READ._serialized_end = 3038
+ _READ_NAMEDTABLE._serialized_start = 2754
+ _READ_NAMEDTABLE._serialized_end = 2815
+ _READ_DATASOURCE._serialized_start = 2818
+ _READ_DATASOURCE._serialized_end = 3025
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2956
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3014
+ _PROJECT._serialized_start = 3040
+ _PROJECT._serialized_end = 3157
+ _FILTER._serialized_start = 3159
+ _FILTER._serialized_end = 3271
+ _JOIN._serialized_start = 3274
+ _JOIN._serialized_end = 3745
+ _JOIN_JOINTYPE._serialized_start = 3537
+ _JOIN_JOINTYPE._serialized_end = 3745
+ _SETOPERATION._serialized_start = 3748
+ _SETOPERATION._serialized_end = 4144
+ _SETOPERATION_SETOPTYPE._serialized_start = 4007
+ _SETOPERATION_SETOPTYPE._serialized_end = 4121
+ _LIMIT._serialized_start = 4146
+ _LIMIT._serialized_end = 4222
+ _OFFSET._serialized_start = 4224
+ _OFFSET._serialized_end = 4303
+ _TAIL._serialized_start = 4305
+ _TAIL._serialized_end = 4380
+ _AGGREGATE._serialized_start = 4383
+ _AGGREGATE._serialized_end = 4965
+ _AGGREGATE_PIVOT._serialized_start = 4722
+ _AGGREGATE_PIVOT._serialized_end = 4833
+ _AGGREGATE_GROUPTYPE._serialized_start = 4836
+ _AGGREGATE_GROUPTYPE._serialized_end = 4965
+ _SORT._serialized_start = 4968
+ _SORT._serialized_end = 5128
+ _DROP._serialized_start = 5130
+ _DROP._serialized_end = 5230
+ _DEDUPLICATE._serialized_start = 5233
+ _DEDUPLICATE._serialized_end = 5404
+ _LOCALRELATION._serialized_start = 5407
+ _LOCALRELATION._serialized_end = 5544
+ _SAMPLE._serialized_start = 5547
+ _SAMPLE._serialized_end = 5820
+ _RANGE._serialized_start = 5823
+ _RANGE._serialized_end = 5968
+ _SUBQUERYALIAS._serialized_start = 5970
+ _SUBQUERYALIAS._serialized_end = 6084
+ _REPARTITION._serialized_start = 6087
+ _REPARTITION._serialized_end = 6229
+ _SHOWSTRING._serialized_start = 6232
+ _SHOWSTRING._serialized_end = 6374
+ _STATSUMMARY._serialized_start = 6376
+ _STATSUMMARY._serialized_end = 6468
+ _STATDESCRIBE._serialized_start = 6470
+ _STATDESCRIBE._serialized_end = 6551
+ _STATCROSSTAB._serialized_start = 6553
+ _STATCROSSTAB._serialized_end = 6654
+ _STATCOV._serialized_start = 6656
+ _STATCOV._serialized_end = 6752
+ _STATCORR._serialized_start = 6755
+ _STATCORR._serialized_end = 6892
+ _STATAPPROXQUANTILE._serialized_start = 6895
+ _STATAPPROXQUANTILE._serialized_end = 7059
+ _STATFREQITEMS._serialized_start = 7061
+ _STATFREQITEMS._serialized_end = 7186
+ _NAFILL._serialized_start = 7189
+ _NAFILL._serialized_end = 7323
+ _NADROP._serialized_start = 7326
+ _NADROP._serialized_end = 7460
+ _NAREPLACE._serialized_start = 7463
+ _NAREPLACE._serialized_end = 7759
+ _NAREPLACE_REPLACEMENT._serialized_start = 7618
+ _NAREPLACE_REPLACEMENT._serialized_end = 7759
+ _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 7761
+ _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 7875
+ _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 7878
+ _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 8137
+ _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 8070
+ _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 8137
+ _WITHCOLUMNS._serialized_start = 8140
+ _WITHCOLUMNS._serialized_end = 8271
+ _HINT._serialized_start = 8274
+ _HINT._serialized_end = 8414
+ _UNPIVOT._serialized_start = 8417
+ _UNPIVOT._serialized_end = 8663
+ _TOSCHEMA._serialized_start = 8665
+ _TOSCHEMA._serialized_end = 8771
+ _REPARTITIONBYEXPRESSION._serialized_start = 8774
+ _REPARTITIONBYEXPRESSION._serialized_end = 8977
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 1ed9e62edcc..96915de60dc 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -98,6 +98,7 @@ class Relation(google.protobuf.message.Message):
COV_FIELD_NUMBER: builtins.int
CORR_FIELD_NUMBER: builtins.int
APPROX_QUANTILE_FIELD_NUMBER: builtins.int
+ FREQ_ITEMS_FIELD_NUMBER: builtins.int
CATALOG_FIELD_NUMBER: builtins.int
EXTENSION_FIELD_NUMBER: builtins.int
UNKNOWN_FIELD_NUMBER: builtins.int
@@ -176,6 +177,8 @@ class Relation(google.protobuf.message.Message):
@property
def approx_quantile(self) -> global___StatApproxQuantile: ...
@property
+ def freq_items(self) -> global___StatFreqItems: ...
+ @property
def catalog(self) -> pyspark.sql.connect.proto.catalog_pb2.Catalog:
"""Catalog API (experimental / unstable)"""
@property
@@ -224,6 +227,7 @@ class Relation(google.protobuf.message.Message):
cov: global___StatCov | None = ...,
corr: global___StatCorr | None = ...,
approx_quantile: global___StatApproxQuantile | None = ...,
+ freq_items: global___StatFreqItems | None = ...,
catalog: pyspark.sql.connect.proto.catalog_pb2.Catalog | None = ...,
extension: google.protobuf.any_pb2.Any | None = ...,
unknown: global___Unknown | None = ...,
@@ -259,6 +263,8 @@ class Relation(google.protobuf.message.Message):
b"fill_na",
"filter",
b"filter",
+ "freq_items",
+ b"freq_items",
"hint",
b"hint",
"join",
@@ -344,6 +350,8 @@ class Relation(google.protobuf.message.Message):
b"fill_na",
"filter",
b"filter",
+ "freq_items",
+ b"freq_items",
"hint",
b"hint",
"join",
@@ -436,6 +444,7 @@ class Relation(google.protobuf.message.Message):
"cov",
"corr",
"approx_quantile",
+ "freq_items",
"catalog",
"extension",
"unknown",
@@ -1853,6 +1862,54 @@ class StatApproxQuantile(google.protobuf.message.Message):
global___StatApproxQuantile = StatApproxQuantile
+class StatFreqItems(google.protobuf.message.Message):
+ """Finding frequent items for columns, possibly with false positives.
+ It will invoke 'Dataset.stat.freqItems' (same as 'StatFunctions.freqItems')
+ to compute the results.
+ """
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ INPUT_FIELD_NUMBER: builtins.int
+ COLS_FIELD_NUMBER: builtins.int
+ SUPPORT_FIELD_NUMBER: builtins.int
+ @property
+ def input(self) -> global___Relation:
+ """(Required) The input relation."""
+ @property
+ def cols(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+ """(Required) The names of the columns to search frequent items in."""
+ support: builtins.float
+ """(Optional) The minimum frequency for an item to be considered
`frequent`.
+ Should be greater than 1e-4.
+ """
+ def __init__(
+ self,
+ *,
+ input: global___Relation | None = ...,
+ cols: collections.abc.Iterable[builtins.str] | None = ...,
+ support: builtins.float | None = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_support", b"_support", "input", b"input", "support", b"support"
+ ],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_support", b"_support", "cols", b"cols", "input", b"input",
"support", b"support"
+ ],
+ ) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_support", b"_support"]
+ ) -> typing_extensions.Literal["support"] | None: ...
+
+global___StatFreqItems = StatFreqItems
+
class NAFill(google.protobuf.message.Message):
"""Replaces null values.
It will invoke 'Dataset.na.fill' (same as 'DataFrameNaFunctions.fill') to compute the results.
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 21f29a7eb4d..6cdef25d5bc 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1175,6 +1175,25 @@ class SparkConnectTests(SparkConnectSQLTestCase):
["col1", "col3"], [0.1, 0.5, 0.9], -0.1
)
+ def test_stat_freq_items(self):
+ # SPARK-41065: Test the stat.freqItems method
+ self.assert_eq(
+ self.connect.read.table(self.tbl_name2).stat.freqItems(["col1", "col3"]).toPandas(),
+ self.spark.read.table(self.tbl_name2).stat.freqItems(["col1", "col3"]).toPandas(),
+ )
+
+ self.assert_eq(
+ self.connect.read.table(self.tbl_name2)
+ .stat.freqItems(["col1", "col3"], 0.4)
+ .toPandas(),
+ self.spark.read.table(self.tbl_name2).stat.freqItems(["col1", "col3"], 0.4).toPandas(),
+ )
+
+ with self.assertRaisesRegex(
+ TypeError, "cols must be a list or tuple of column names as strings"
+ ):
+ self.connect.read.table(self.tbl_name2).stat.freqItems("col1")
+
def test_repr(self):
# SPARK-41213: Test the __repr__ method
query = """SELECT * FROM VALUES (1L, NULL), (3L, "Z") AS tab(a, b)"""
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 529e3ca3eda..5e3c6661e52 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -309,6 +309,24 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
self.assertEqual(plan.root.crosstab.col1, "col_a")
self.assertEqual(plan.root.crosstab.col2, "col_b")
+ def test_freqItems(self):
+ df = self.connect.readTable(table_name=self.tbl_name)
+ plan = (
+ df.filter(df.col_name > 3).freqItems(["col_a", "col_b"], 1)._plan.to_proto(self.connect)
+ )
+ self.assertEqual(plan.root.freq_items.cols, ["col_a", "col_b"])
+ self.assertEqual(plan.root.freq_items.support, 1)
+ plan = df.filter(df.col_name > 3).freqItems(["col_a", "col_b"])._plan.to_proto(self.connect)
+ self.assertEqual(plan.root.freq_items.cols, ["col_a", "col_b"])
+ self.assertEqual(plan.root.freq_items.support, 0.01)
+
+ plan = df.stat.freqItems(["col_a", "col_b"], 1)._plan.to_proto(self.connect)
+ self.assertEqual(plan.root.freq_items.cols, ["col_a", "col_b"])
+ self.assertEqual(plan.root.freq_items.support, 1)
+ plan = df.stat.freqItems(["col_a", "col_b"])._plan.to_proto(self.connect)
+ self.assertEqual(plan.root.freq_items.cols, ["col_a", "col_b"])
+ self.assertEqual(plan.root.freq_items.support, 0.01)
+
def test_limit(self):
df = self.connect.readTable(table_name=self.tbl_name)
limit_plan = df.limit(10)._plan.to_proto(self.connect)