This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5211f6b140a [SPARK-46085][CONNECT] Dataset.groupingSets in Scala Spark Connect client
5211f6b140a is described below
commit 5211f6b140a74bd28f7e05934508bdafdbe7f237
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Fri Nov 24 17:52:23 2023 -0800
[SPARK-46085][CONNECT] Dataset.groupingSets in Scala Spark Connect client
### What changes were proposed in this pull request?
This PR proposes to add the `Dataset.groupingSets` API, introduced in
https://github.com/apache/spark/pull/43813, to the Scala Spark Connect client.
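For context, a minimal usage sketch of the new API (column names are illustrative, borrowed from the Scaladoc example in the diff below):

    // Average all numeric columns over the grouping sets (department, group)
    // and the grand total (), i.e. SQL's GROUP BY ... GROUPING SETS.
    ds.groupingSets(Seq(Seq($"department", $"group"), Seq()), $"department", $"group").avg()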
### Why are the changes needed?
For feature parity.
### Does this PR introduce _any_ user-facing change?
Yes, it adds a new API to Scala Spark Connect client.
### How was this patch tested?
A unit test was added.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #43995 from HyukjinKwon/SPARK-46085.
Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../main/scala/org/apache/spark/sql/Dataset.scala | 35 +++++++++++++++
.../spark/sql/RelationalGroupedDataset.scala | 8 +++-
.../apache/spark/sql/PlanGenerationTestSuite.scala | 6 +++
.../explain-results/groupingSets.explain | 4 ++
.../query-tests/queries/groupingSets.json | 50 +++++++++++++++++++++
.../query-tests/queries/groupingSets.proto.bin | Bin 0 -> 106 bytes
6 files changed, 102 insertions(+), 1 deletion(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index a1e57226e53..d760c9d9769 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1532,6 +1532,41 @@ class Dataset[T] private[sql] (
proto.Aggregate.GroupType.GROUP_TYPE_CUBE)
}
+  /**
+   * Create multi-dimensional aggregation for the current Dataset using the specified grouping
+   * sets, so we can run aggregation on them. See [[RelationalGroupedDataset]] for all the
+   * available aggregate functions.
+   *
+   * {{{
+   *   // Compute the average for all numeric columns group by specific grouping sets.
+   *   ds.groupingSets(Seq(Seq($"department", $"group"), Seq()), $"department", $"group").avg()
+   *
+   *   // Compute the max age and average salary, group by specific grouping sets.
+   *   ds.groupingSets(Seq(Seq($"department", $"gender"), Seq()), $"department", $"group").agg(Map(
+   *     "salary" -> "avg",
+   *     "age" -> "max"
+   *   ))
+   * }}}
+   *
+   * @group untypedrel
+   * @since 4.0.0
+   */
+  @scala.annotation.varargs
+  def groupingSets(groupingSets: Seq[Seq[Column]], cols: Column*): RelationalGroupedDataset = {
+    val groupingSetMsgs = groupingSets.map { groupingSet =>
+      val groupingSetMsg = proto.Aggregate.GroupingSets.newBuilder()
+      for (groupCol <- groupingSet) {
+        groupingSetMsg.addGroupingSet(groupCol.expr)
+      }
+      groupingSetMsg.build()
+    }
+    new RelationalGroupedDataset(
+      toDF(),
+      cols,
+      proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS,
+      groupingSets = Some(groupingSetMsgs))
+  }
+
/**
* (Scala-specific) Aggregates on the entire Dataset without groups.
* {{{
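For reference, a hedged sketch of the SQL counterpart of the new API, matching the test case added further below (table name `tbl` is hypothetical):

    // ds.groupingSets(Seq(Seq($"a"), Seq.empty), $"a").agg("a" -> "max", "a" -> "count")
    // corresponds to this GROUPING SETS query:
    spark.sql("SELECT a, max(a), count(a) FROM tbl GROUP BY GROUPING SETS ((a), ())")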
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 5ed97e45c77..776a6231eae 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -39,7 +39,8 @@ class RelationalGroupedDataset private[sql] (
private[sql] val df: DataFrame,
private[sql] val groupingExprs: Seq[Column],
groupType: proto.Aggregate.GroupType,
- pivot: Option[proto.Aggregate.Pivot] = None) {
+ pivot: Option[proto.Aggregate.Pivot] = None,
+ groupingSets: Option[Seq[proto.Aggregate.GroupingSets]] = None) {
private[this] def toDF(aggExprs: Seq[Column]): DataFrame = {
df.sparkSession.newDataFrame { builder =>
@@ -60,6 +61,11 @@ class RelationalGroupedDataset private[sql] (
builder.getAggregateBuilder
.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_PIVOT)
.setPivot(pivot.get)
+ case proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS =>
+ assert(groupingSets.isDefined)
+ val aggBuilder = builder.getAggregateBuilder
+ .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS)
+ groupingSets.get.foreach(aggBuilder.addGroupingSets)
case g => throw new UnsupportedOperationException(g.toString)
}
}
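A minimal sketch of the messages this new branch serializes for two grouping sets, using only the builder calls visible in the diff above (`colA` is an illustrative Column):

    // One proto.Aggregate.GroupingSets message per set; a builder with no
    // addGroupingSet calls encodes the grand-total set ().
    val setA  = proto.Aggregate.GroupingSets.newBuilder().addGroupingSet(colA.expr).build()
    val total = proto.Aggregate.GroupingSets.newBuilder().build()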
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 5cc63bc45a0..c5c917ebfa9 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -3017,6 +3017,12 @@ class PlanGenerationTestSuite
simple.groupBy(Column("id")).pivot("a").agg(functions.count(Column("b")))
}
+ test("groupingSets") {
+ simple
+ .groupingSets(Seq(Seq(fn.col("a")), Seq.empty[Column]), fn.col("a"))
+ .agg("a" -> "max", "a" -> "count")
+ }
+
test("width_bucket") {
simple.select(fn.width_bucket(fn.col("b"), fn.col("b"), fn.col("b"), fn.col("a")))
}
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/groupingSets.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/groupingSets.explain
new file mode 100644
index 00000000000..1e3fe1a987e
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/groupingSets.explain
@@ -0,0 +1,4 @@
+Aggregate [a#0, spark_grouping_id#0L], [a#0, max(a#0) AS max(a)#0, count(a#0) AS count(a)#0L]
++- Expand [[id#0L, a#0, b#0, a#0, 0], [id#0L, a#0, b#0, null, 1]], [id#0L, a#0, b#0, a#0, spark_grouping_id#0L]
+ +- Project [id#0L, a#0, b#0, a#0 AS a#0]
+ +- LocalRelation <empty>, [id#0L, a#0, b#0]
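The golden plan above shows how grouping sets execute: Expand replays each input row once per grouping set, nulling out columns absent from that set and tagging each copy with spark_grouping_id. A hedged model of that expansion (types and names are illustrative):

    // For grouping sets ((a), ()), each input row (id, a, b) yields two rows.
    case class In(id: Long, a: Int, b: Double)
    def expand(r: In): Seq[(Long, Int, Double, Option[Int], Long)] = Seq(
      (r.id, r.a, r.b, Some(r.a), 0L), // set (a): keep a, spark_grouping_id = 0
      (r.id, r.a, r.b, None, 1L)       // set (): null out a, spark_grouping_id = 1
    )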
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/groupingSets.json b/connector/connect/common/src/test/resources/query-tests/queries/groupingSets.json
new file mode 100644
index 00000000000..6e84824ec7a
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/groupingSets.json
@@ -0,0 +1,50 @@
+{
+ "common": {
+ "planId": "1"
+ },
+ "aggregate": {
+ "input": {
+ "common": {
+ "planId": "0"
+ },
+ "localRelation": {
+ "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+ }
+ },
+ "groupType": "GROUP_TYPE_GROUPING_SETS",
+ "groupingExpressions": [{
+ "unresolvedAttribute": {
+ "unparsedIdentifier": "a"
+ }
+ }],
+ "aggregateExpressions": [{
+ "unresolvedFunction": {
+ "functionName": "max",
+ "arguments": [{
+ "unresolvedAttribute": {
+ "unparsedIdentifier": "a",
+ "planId": "0"
+ }
+ }]
+ }
+ }, {
+ "unresolvedFunction": {
+ "functionName": "count",
+ "arguments": [{
+ "unresolvedAttribute": {
+ "unparsedIdentifier": "a",
+ "planId": "0"
+ }
+ }]
+ }
+ }],
+ "groupingSets": [{
+ "groupingSet": [{
+ "unresolvedAttribute": {
+ "unparsedIdentifier": "a"
+ }
+ }]
+ }, {
+ }]
+ }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/groupingSets.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/groupingSets.proto.bin
new file mode 100644
index 00000000000..ce029409670
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/groupingSets.proto.bin differ