This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d14410c6777 [SPARK-46048][PYTHON][CONNECT] Support
DataFrame.groupingSets in Python Spark Connect
d14410c6777 is described below
commit d14410c6777e7de7f61e1957fab749da2793f4b8
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Thu Nov 23 16:38:52 2023 +0900
[SPARK-46048][PYTHON][CONNECT] Support DataFrame.groupingSets in Python
Spark Connect
### What changes were proposed in this pull request?
This PR adds `DataFrame.groupingSets` in Python Spark Connect.
### Why are the changes needed?
For feature parity with non-Spark Connect.
### Does this PR introduce _any_ user-facing change?
Yes, it adds the new API `DataFrame.groupingSets` in Python Spark Connect.
### How was this patch tested?
Unittests were added.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #43967 from HyukjinKwon/SPARK-46048.
Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../main/protobuf/spark/connect/relations.proto | 9 +
.../org/apache/spark/sql/connect/dsl/package.scala | 21 +++
.../sql/connect/planner/SparkConnectPlanner.scala | 11 ++
.../connect/planner/SparkConnectProtoSuite.scala | 12 ++
python/pyspark/sql/connect/dataframe.py | 39 +++++
python/pyspark/sql/connect/group.py | 16 +-
python/pyspark/sql/connect/plan.py | 23 ++-
python/pyspark/sql/connect/proto/relations_pb2.py | 194 +++++++++++----------
python/pyspark/sql/connect/proto/relations_pb2.pyi | 36 ++++
python/pyspark/sql/dataframe.py | 1 -
10 files changed, 262 insertions(+), 100 deletions(-)
diff --git
a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index deb33978386..43f692671df 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -327,12 +327,16 @@ message Aggregate {
// (Optional) Pivots a column of the current `DataFrame` and performs the
specified aggregation.
Pivot pivot = 5;
+ // (Optional) The grouping sets to aggregate over when group_type is
GROUP_TYPE_GROUPING_SETS.
+ repeated GroupingSets grouping_sets = 6;
+
enum GroupType {
GROUP_TYPE_UNSPECIFIED = 0;
GROUP_TYPE_GROUPBY = 1;
GROUP_TYPE_ROLLUP = 2;
GROUP_TYPE_CUBE = 3;
GROUP_TYPE_PIVOT = 4;
+ GROUP_TYPE_GROUPING_SETS = 5;
}
message Pivot {
@@ -345,6 +349,11 @@ message Aggregate {
// the distinct values of the column.
repeated Expression.Literal values = 2;
}
+
+ message GroupingSets {
+ // (Required) Individual grouping set
+ repeated Expression grouping_set = 1;
+ }
}
// Relation of type [[Sort]].
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 5fd1a035385..18c71ae4ace 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -800,6 +800,27 @@ package object dsl {
Relation.newBuilder().setAggregate(agg.build()).build()
}
+ def groupingSets(groupingSets: Seq[Seq[Expression]], groupingExprs:
Expression*)(
+ aggregateExprs: Expression*): Relation = {
+ val agg = Aggregate.newBuilder()
+ agg.setInput(logicalPlan)
+ agg.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS)
+ for (groupingSet <- groupingSets) {
+ val groupingSetMsg = Aggregate.GroupingSets.newBuilder()
+ for (groupCol <- groupingSet) {
+ groupingSetMsg.addGroupingSet(groupCol)
+ }
+ agg.addGroupingSets(groupingSetMsg)
+ }
+ for (groupingExpr <- groupingExprs) {
+ agg.addGroupingExpressions(groupingExpr)
+ }
+ for (aggregateExpr <- aggregateExprs) {
+ agg.addAggregateExpressions(aggregateExpr)
+ }
+ Relation.newBuilder().setAggregate(agg.build()).build()
+ }
+
def except(otherPlan: Relation, isAll: Boolean): Relation = {
Relation
.newBuilder()
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 4a0aa7e5589..95c5acc803d 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -2445,6 +2445,17 @@ class SparkConnectPlanner(
aggregates = aggExprs,
child = input)
+ case proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS =>
+ val groupingSetsExprs = rel.getGroupingSetsList.asScala.toSeq.map {
getGroupingSets =>
+
getGroupingSets.getGroupingSetList.asScala.toSeq.map(transformExpression)
+ }
+ logical.Aggregate(
+ groupingExpressions = Seq(
+ GroupingSets(
+ groupingSets = groupingSetsExprs,
+ userGivenGroupByExprs = groupingExprs)),
+ aggregateExpressions = aliasedAgg,
+ child = input)
case other => throw InvalidPlanInput(s"Unknown Group Type $other")
}
}
diff --git
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index c54aa496c66..0b27ccdbef8 100644
---
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -307,6 +307,18 @@ class SparkConnectProtoSuite extends PlanTest with
SparkConnectPlanTest {
comparePlans(connectPlan2, sparkPlan2)
}
+ test("GroupingSets expressions") {
+ val connectPlan1 =
+ connectTestRelation.groupingSets(Seq(Seq("id".protoAttr), Seq.empty),
"id".protoAttr)(
+
proto_min(proto.Expression.newBuilder().setLiteral(toLiteralProto(1)).build())
+ .as("agg1"))
+ val sparkPlan1 =
+ sparkTestRelation
+ .groupingSets(Seq(Seq(Column("id")), Seq.empty), Column("id"))
+ .agg(min(lit(1)).as("agg1"))
+ comparePlans(connectPlan1, sparkPlan1)
+ }
+
test("Test as(alias: String)") {
val connectPlan = connectTestRelation.as("target_table")
val sparkPlan = sparkTestRelation.as("target_table")
diff --git a/python/pyspark/sql/connect/dataframe.py
b/python/pyspark/sql/connect/dataframe.py
index c7b51205363..b3bec44428b 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -550,6 +550,45 @@ class DataFrame:
cube.__doc__ = PySparkDataFrame.cube.__doc__
+ def groupingSets(
+ self, groupingSets: Sequence[Sequence["ColumnOrName"]], *cols:
"ColumnOrName"
+ ) -> "GroupedData":
+ gsets: List[List[Column]] = []
+ for grouping_set in groupingSets:
+ gset: List[Column] = []
+ for c in grouping_set:
+ if isinstance(c, Column):
+ gset.append(c)
+ elif isinstance(c, str):
+ gset.append(self[c])
+ else:
+ raise PySparkTypeError(
+ error_class="NOT_COLUMN_OR_STR",
+ message_parameters={
+ "arg_name": "groupingSets",
+ "arg_type": type(c).__name__,
+ },
+ )
+ gsets.append(gset)
+
+ gcols: List[Column] = []
+ for c in cols:
+ if isinstance(c, Column):
+ gcols.append(c)
+ elif isinstance(c, str):
+ gcols.append(self[c])
+ else:
+ raise PySparkTypeError(
+ error_class="NOT_COLUMN_OR_STR",
+ message_parameters={"arg_name": "cols", "arg_type":
type(c).__name__},
+ )
+
+ return GroupedData(
+ df=self, group_type="grouping_sets", grouping_cols=gcols,
grouping_sets=gsets
+ )
+
+ groupingSets.__doc__ = PySparkDataFrame.groupingSets.__doc__
+
@overload
def head(self) -> Optional[Row]:
...
diff --git a/python/pyspark/sql/connect/group.py
b/python/pyspark/sql/connect/group.py
index 7b71a43c112..481b7981a15 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -63,13 +63,20 @@ class GroupedData:
grouping_cols: Sequence["Column"],
pivot_col: Optional["Column"] = None,
pivot_values: Optional[Sequence["LiteralType"]] = None,
+ grouping_sets: Optional[Sequence[Sequence["Column"]]] = None,
) -> None:
from pyspark.sql.connect.dataframe import DataFrame
assert isinstance(df, DataFrame)
self._df = df
- assert isinstance(group_type, str) and group_type in ["groupby",
"rollup", "cube", "pivot"]
+ assert isinstance(group_type, str) and group_type in [
+ "groupby",
+ "rollup",
+ "cube",
+ "pivot",
+ "grouping_sets",
+ ]
self._group_type = group_type
assert isinstance(grouping_cols, list) and all(isinstance(g, Column)
for g in grouping_cols)
@@ -83,6 +90,11 @@ class GroupedData:
self._pivot_col = pivot_col
self._pivot_values = pivot_values
+ self._grouping_sets: Optional[Sequence[Sequence["Column"]]] = None
+ if group_type == "grouping_sets":
+ assert grouping_sets is None or isinstance(grouping_sets, list)
+ self._grouping_sets = grouping_sets
+
def __repr__(self) -> str:
# the expressions are not resolved here,
# so the string representation can be different from vanilla PySpark.
@@ -130,6 +142,7 @@ class GroupedData:
aggregate_cols=aggregate_cols,
pivot_col=self._pivot_col,
pivot_values=self._pivot_values,
+ grouping_sets=self._grouping_sets,
),
session=self._df._session,
)
@@ -171,6 +184,7 @@ class GroupedData:
aggregate_cols=[_invoke_function(function, col(c)) for c in
agg_cols],
pivot_col=self._pivot_col,
pivot_values=self._pivot_values,
+ grouping_sets=self._grouping_sets,
),
session=self._df._session,
)
diff --git a/python/pyspark/sql/connect/plan.py
b/python/pyspark/sql/connect/plan.py
index 607d1429a9e..7d63f8714a9 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -778,10 +778,17 @@ class Aggregate(LogicalPlan):
aggregate_cols: Sequence[Column],
pivot_col: Optional[Column],
pivot_values: Optional[Sequence[Any]],
+ grouping_sets: Optional[Sequence[Sequence[Column]]],
) -> None:
super().__init__(child)
- assert isinstance(group_type, str) and group_type in ["groupby",
"rollup", "cube", "pivot"]
+ assert isinstance(group_type, str) and group_type in [
+ "groupby",
+ "rollup",
+ "cube",
+ "pivot",
+ "grouping_sets",
+ ]
self._group_type = group_type
assert isinstance(grouping_cols, list) and all(isinstance(c, Column)
for c in grouping_cols)
@@ -795,12 +802,16 @@ class Aggregate(LogicalPlan):
if group_type == "pivot":
assert pivot_col is not None and isinstance(pivot_col, Column)
assert pivot_values is None or isinstance(pivot_values, list)
+ elif group_type == "grouping_sets":
+ assert grouping_sets is None or isinstance(grouping_sets, list)
else:
assert pivot_col is None
assert pivot_values is None
+ assert grouping_sets is None
self._pivot_col = pivot_col
self._pivot_values = pivot_values
+ self._grouping_sets = grouping_sets
def plan(self, session: "SparkConnectClient") -> proto.Relation:
from pyspark.sql.connect.functions import lit
@@ -829,7 +840,15 @@ class Aggregate(LogicalPlan):
plan.aggregate.pivot.values.extend(
[lit(v).to_plan(session).literal for v in
self._pivot_values]
)
-
+ elif self._group_type == "grouping_sets":
+ plan.aggregate.group_type =
proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS
+ assert self._grouping_sets is not None
+ for grouping_set in self._grouping_sets:
+ plan.aggregate.grouping_sets.append(
+ proto.Aggregate.GroupingSets(
+ grouping_set=[c.to_plan(session) for c in grouping_set]
+ )
+ )
return plan
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py
b/python/pyspark/sql/connect/proto/relations_pb2.py
index fc70cdea402..f79ee786afb 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -35,7 +35,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as
spark_dot_connect_dot_catal
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\x9a\x19\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\x9a\x19\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -104,101 +104,103 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_TAIL._serialized_start = 6182
_TAIL._serialized_end = 6257
_AGGREGATE._serialized_start = 6260
- _AGGREGATE._serialized_end = 6842
- _AGGREGATE_PIVOT._serialized_start = 6599
- _AGGREGATE_PIVOT._serialized_end = 6710
- _AGGREGATE_GROUPTYPE._serialized_start = 6713
- _AGGREGATE_GROUPTYPE._serialized_end = 6842
- _SORT._serialized_start = 6845
- _SORT._serialized_end = 7005
- _DROP._serialized_start = 7008
- _DROP._serialized_end = 7149
- _DEDUPLICATE._serialized_start = 7152
- _DEDUPLICATE._serialized_end = 7392
- _LOCALRELATION._serialized_start = 7394
- _LOCALRELATION._serialized_end = 7483
- _CACHEDLOCALRELATION._serialized_start = 7485
- _CACHEDLOCALRELATION._serialized_end = 7557
- _CACHEDREMOTERELATION._serialized_start = 7559
- _CACHEDREMOTERELATION._serialized_end = 7614
- _SAMPLE._serialized_start = 7617
- _SAMPLE._serialized_end = 7890
- _RANGE._serialized_start = 7893
- _RANGE._serialized_end = 8038
- _SUBQUERYALIAS._serialized_start = 8040
- _SUBQUERYALIAS._serialized_end = 8154
- _REPARTITION._serialized_start = 8157
- _REPARTITION._serialized_end = 8299
- _SHOWSTRING._serialized_start = 8302
- _SHOWSTRING._serialized_end = 8444
- _HTMLSTRING._serialized_start = 8446
- _HTMLSTRING._serialized_end = 8560
- _STATSUMMARY._serialized_start = 8562
- _STATSUMMARY._serialized_end = 8654
- _STATDESCRIBE._serialized_start = 8656
- _STATDESCRIBE._serialized_end = 8737
- _STATCROSSTAB._serialized_start = 8739
- _STATCROSSTAB._serialized_end = 8840
- _STATCOV._serialized_start = 8842
- _STATCOV._serialized_end = 8938
- _STATCORR._serialized_start = 8941
- _STATCORR._serialized_end = 9078
- _STATAPPROXQUANTILE._serialized_start = 9081
- _STATAPPROXQUANTILE._serialized_end = 9245
- _STATFREQITEMS._serialized_start = 9247
- _STATFREQITEMS._serialized_end = 9372
- _STATSAMPLEBY._serialized_start = 9375
- _STATSAMPLEBY._serialized_end = 9684
- _STATSAMPLEBY_FRACTION._serialized_start = 9576
- _STATSAMPLEBY_FRACTION._serialized_end = 9675
- _NAFILL._serialized_start = 9687
- _NAFILL._serialized_end = 9821
- _NADROP._serialized_start = 9824
- _NADROP._serialized_end = 9958
- _NAREPLACE._serialized_start = 9961
- _NAREPLACE._serialized_end = 10257
- _NAREPLACE_REPLACEMENT._serialized_start = 10116
- _NAREPLACE_REPLACEMENT._serialized_end = 10257
- _TODF._serialized_start = 10259
- _TODF._serialized_end = 10347
- _WITHCOLUMNSRENAMED._serialized_start = 10350
- _WITHCOLUMNSRENAMED._serialized_end = 10589
- _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 10522
- _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 10589
- _WITHCOLUMNS._serialized_start = 10591
- _WITHCOLUMNS._serialized_end = 10710
- _WITHWATERMARK._serialized_start = 10713
- _WITHWATERMARK._serialized_end = 10847
- _HINT._serialized_start = 10850
- _HINT._serialized_end = 10982
- _UNPIVOT._serialized_start = 10985
- _UNPIVOT._serialized_end = 11312
- _UNPIVOT_VALUES._serialized_start = 11242
- _UNPIVOT_VALUES._serialized_end = 11301
- _TOSCHEMA._serialized_start = 11314
- _TOSCHEMA._serialized_end = 11420
- _REPARTITIONBYEXPRESSION._serialized_start = 11423
- _REPARTITIONBYEXPRESSION._serialized_end = 11626
- _MAPPARTITIONS._serialized_start = 11629
- _MAPPARTITIONS._serialized_end = 11810
- _GROUPMAP._serialized_start = 11813
- _GROUPMAP._serialized_end = 12448
- _COGROUPMAP._serialized_start = 12451
- _COGROUPMAP._serialized_end = 12977
- _APPLYINPANDASWITHSTATE._serialized_start = 12980
- _APPLYINPANDASWITHSTATE._serialized_end = 13337
- _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_start = 13340
- _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_end = 13584
- _PYTHONUDTF._serialized_start = 13587
- _PYTHONUDTF._serialized_end = 13764
- _COLLECTMETRICS._serialized_start = 13767
- _COLLECTMETRICS._serialized_end = 13903
- _PARSE._serialized_start = 13906
- _PARSE._serialized_end = 14294
+ _AGGREGATE._serialized_end = 7026
+ _AGGREGATE_PIVOT._serialized_start = 6675
+ _AGGREGATE_PIVOT._serialized_end = 6786
+ _AGGREGATE_GROUPINGSETS._serialized_start = 6788
+ _AGGREGATE_GROUPINGSETS._serialized_end = 6864
+ _AGGREGATE_GROUPTYPE._serialized_start = 6867
+ _AGGREGATE_GROUPTYPE._serialized_end = 7026
+ _SORT._serialized_start = 7029
+ _SORT._serialized_end = 7189
+ _DROP._serialized_start = 7192
+ _DROP._serialized_end = 7333
+ _DEDUPLICATE._serialized_start = 7336
+ _DEDUPLICATE._serialized_end = 7576
+ _LOCALRELATION._serialized_start = 7578
+ _LOCALRELATION._serialized_end = 7667
+ _CACHEDLOCALRELATION._serialized_start = 7669
+ _CACHEDLOCALRELATION._serialized_end = 7741
+ _CACHEDREMOTERELATION._serialized_start = 7743
+ _CACHEDREMOTERELATION._serialized_end = 7798
+ _SAMPLE._serialized_start = 7801
+ _SAMPLE._serialized_end = 8074
+ _RANGE._serialized_start = 8077
+ _RANGE._serialized_end = 8222
+ _SUBQUERYALIAS._serialized_start = 8224
+ _SUBQUERYALIAS._serialized_end = 8338
+ _REPARTITION._serialized_start = 8341
+ _REPARTITION._serialized_end = 8483
+ _SHOWSTRING._serialized_start = 8486
+ _SHOWSTRING._serialized_end = 8628
+ _HTMLSTRING._serialized_start = 8630
+ _HTMLSTRING._serialized_end = 8744
+ _STATSUMMARY._serialized_start = 8746
+ _STATSUMMARY._serialized_end = 8838
+ _STATDESCRIBE._serialized_start = 8840
+ _STATDESCRIBE._serialized_end = 8921
+ _STATCROSSTAB._serialized_start = 8923
+ _STATCROSSTAB._serialized_end = 9024
+ _STATCOV._serialized_start = 9026
+ _STATCOV._serialized_end = 9122
+ _STATCORR._serialized_start = 9125
+ _STATCORR._serialized_end = 9262
+ _STATAPPROXQUANTILE._serialized_start = 9265
+ _STATAPPROXQUANTILE._serialized_end = 9429
+ _STATFREQITEMS._serialized_start = 9431
+ _STATFREQITEMS._serialized_end = 9556
+ _STATSAMPLEBY._serialized_start = 9559
+ _STATSAMPLEBY._serialized_end = 9868
+ _STATSAMPLEBY_FRACTION._serialized_start = 9760
+ _STATSAMPLEBY_FRACTION._serialized_end = 9859
+ _NAFILL._serialized_start = 9871
+ _NAFILL._serialized_end = 10005
+ _NADROP._serialized_start = 10008
+ _NADROP._serialized_end = 10142
+ _NAREPLACE._serialized_start = 10145
+ _NAREPLACE._serialized_end = 10441
+ _NAREPLACE_REPLACEMENT._serialized_start = 10300
+ _NAREPLACE_REPLACEMENT._serialized_end = 10441
+ _TODF._serialized_start = 10443
+ _TODF._serialized_end = 10531
+ _WITHCOLUMNSRENAMED._serialized_start = 10534
+ _WITHCOLUMNSRENAMED._serialized_end = 10773
+ _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 10706
+ _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 10773
+ _WITHCOLUMNS._serialized_start = 10775
+ _WITHCOLUMNS._serialized_end = 10894
+ _WITHWATERMARK._serialized_start = 10897
+ _WITHWATERMARK._serialized_end = 11031
+ _HINT._serialized_start = 11034
+ _HINT._serialized_end = 11166
+ _UNPIVOT._serialized_start = 11169
+ _UNPIVOT._serialized_end = 11496
+ _UNPIVOT_VALUES._serialized_start = 11426
+ _UNPIVOT_VALUES._serialized_end = 11485
+ _TOSCHEMA._serialized_start = 11498
+ _TOSCHEMA._serialized_end = 11604
+ _REPARTITIONBYEXPRESSION._serialized_start = 11607
+ _REPARTITIONBYEXPRESSION._serialized_end = 11810
+ _MAPPARTITIONS._serialized_start = 11813
+ _MAPPARTITIONS._serialized_end = 11994
+ _GROUPMAP._serialized_start = 11997
+ _GROUPMAP._serialized_end = 12632
+ _COGROUPMAP._serialized_start = 12635
+ _COGROUPMAP._serialized_end = 13161
+ _APPLYINPANDASWITHSTATE._serialized_start = 13164
+ _APPLYINPANDASWITHSTATE._serialized_end = 13521
+ _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_start = 13524
+ _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_end = 13768
+ _PYTHONUDTF._serialized_start = 13771
+ _PYTHONUDTF._serialized_end = 13948
+ _COLLECTMETRICS._serialized_start = 13951
+ _COLLECTMETRICS._serialized_end = 14087
+ _PARSE._serialized_start = 14090
+ _PARSE._serialized_end = 14478
_PARSE_OPTIONSENTRY._serialized_start = 4291
_PARSE_OPTIONSENTRY._serialized_end = 4349
- _PARSE_PARSEFORMAT._serialized_start = 14195
- _PARSE_PARSEFORMAT._serialized_end = 14283
- _ASOFJOIN._serialized_start = 14297
- _ASOFJOIN._serialized_end = 14772
+ _PARSE_PARSEFORMAT._serialized_start = 14379
+ _PARSE_PARSEFORMAT._serialized_end = 14467
+ _ASOFJOIN._serialized_start = 14481
+ _ASOFJOIN._serialized_end = 14956
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi
b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 5bca4f21b2e..f8b7a2ad1cd 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -1380,6 +1380,7 @@ class Aggregate(google.protobuf.message.Message):
GROUP_TYPE_ROLLUP: Aggregate._GroupType.ValueType # 2
GROUP_TYPE_CUBE: Aggregate._GroupType.ValueType # 3
GROUP_TYPE_PIVOT: Aggregate._GroupType.ValueType # 4
+ GROUP_TYPE_GROUPING_SETS: Aggregate._GroupType.ValueType # 5
class GroupType(_GroupType, metaclass=_GroupTypeEnumTypeWrapper): ...
GROUP_TYPE_UNSPECIFIED: Aggregate.GroupType.ValueType # 0
@@ -1387,6 +1388,7 @@ class Aggregate(google.protobuf.message.Message):
GROUP_TYPE_ROLLUP: Aggregate.GroupType.ValueType # 2
GROUP_TYPE_CUBE: Aggregate.GroupType.ValueType # 3
GROUP_TYPE_PIVOT: Aggregate.GroupType.ValueType # 4
+ GROUP_TYPE_GROUPING_SETS: Aggregate.GroupType.ValueType # 5
class Pivot(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -1423,11 +1425,35 @@ class Aggregate(google.protobuf.message.Message):
self, field_name: typing_extensions.Literal["col", b"col",
"values", b"values"]
) -> None: ...
+ class GroupingSets(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ GROUPING_SET_FIELD_NUMBER: builtins.int
+ @property
+ def grouping_set(
+ self,
+ ) ->
google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+ pyspark.sql.connect.proto.expressions_pb2.Expression
+ ]:
+ """(Required) Individual grouping set"""
+ def __init__(
+ self,
+ *,
+ grouping_set: collections.abc.Iterable[
+ pyspark.sql.connect.proto.expressions_pb2.Expression
+ ]
+ | None = ...,
+ ) -> None: ...
+ def ClearField(
+ self, field_name: typing_extensions.Literal["grouping_set",
b"grouping_set"]
+ ) -> None: ...
+
INPUT_FIELD_NUMBER: builtins.int
GROUP_TYPE_FIELD_NUMBER: builtins.int
GROUPING_EXPRESSIONS_FIELD_NUMBER: builtins.int
AGGREGATE_EXPRESSIONS_FIELD_NUMBER: builtins.int
PIVOT_FIELD_NUMBER: builtins.int
+ GROUPING_SETS_FIELD_NUMBER: builtins.int
@property
def input(self) -> global___Relation:
"""(Required) Input relation for a RelationalGroupedDataset."""
@@ -1450,6 +1476,13 @@ class Aggregate(google.protobuf.message.Message):
@property
def pivot(self) -> global___Aggregate.Pivot:
"""(Optional) Pivots a column of the current `DataFrame` and performs
the specified aggregation."""
+ @property
+ def grouping_sets(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+ global___Aggregate.GroupingSets
+ ]:
        """(Optional) The grouping sets to aggregate over when group_type is
GROUP_TYPE_GROUPING_SETS."""
def __init__(
self,
*,
@@ -1464,6 +1497,7 @@ class Aggregate(google.protobuf.message.Message):
]
| None = ...,
pivot: global___Aggregate.Pivot | None = ...,
+ grouping_sets:
collections.abc.Iterable[global___Aggregate.GroupingSets] | None = ...,
) -> None: ...
def HasField(
self, field_name: typing_extensions.Literal["input", b"input",
"pivot", b"pivot"]
@@ -1477,6 +1511,8 @@ class Aggregate(google.protobuf.message.Message):
b"group_type",
"grouping_expressions",
b"grouping_expressions",
+ "grouping_sets",
+ b"grouping_sets",
"input",
b"input",
"pivot",
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 383a5566ded..82087adc82f 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -4204,7 +4204,6 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
return GroupedData(jgd, self)
- # TODO(SPARK-46048): Add it to Python Spark Connect client.
def groupingSets(
self, groupingSets: Sequence[Sequence["ColumnOrName"]], *cols:
"ColumnOrName"
) -> "GroupedData":
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]