This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 847199fb6d9 [SPARK-45929][SQL] Support groupingSets operation in dataframe api
847199fb6d9 is described below
commit 847199fb6d95910ef624815cfad0be2f8ab8d9d7
Author: JacobZheng0927 <[email protected]>
AuthorDate: Tue Nov 21 10:41:17 2023 +0900
[SPARK-45929][SQL] Support groupingSets operation in dataframe api
### What changes were proposed in this pull request?
Add a `groupingSets` method to the Dataset API.
`SELECT col1, col2, col3, sum(col4) FROM t GROUP BY col1, col2, col3 GROUPING SETS ((col1, col2), ())`
This SQL can be expressed equivalently with the following code:
`df.groupingSets(Seq(Seq(col("col1"), col("col2")), Seq()), col("col1"), col("col2"), col("col3")).sum("col4")`
### Why are the changes needed?
Currently, grouping sets can only be used in Spark SQL; this feature is not available when developing with the Dataset API.
### Does this PR introduce _any_ user-facing change?
Yes. This PR adds a `groupingSets` method to the Dataset API.
### How was this patch tested?
Tests added in `DataFrameAggregateSuite.scala`.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #43813 from JacobZheng0927/SPARK-45929.
Authored-by: JacobZheng0927 <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../main/scala/org/apache/spark/sql/Dataset.scala | 27 ++++++++++++++++++++++
.../spark/sql/RelationalGroupedDataset.scala | 10 ++++++++
.../apache/spark/sql/DataFrameAggregateSuite.scala | 15 ++++++++++++
3 files changed, 52 insertions(+)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 5a372f9a0f9..062c4c6bcad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1825,6 +1825,33 @@ class Dataset[T] private[sql](
     RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.CubeType)
   }
 
+  /**
+   * Create multi-dimensional aggregation for the current Dataset using the specified grouping sets,
+   * so we can run aggregation on them.
+   * See [[RelationalGroupedDataset]] for all the available aggregate functions.
+   *
+   * {{{
+   *   // Compute the average for all numeric columns grouped by specific grouping sets.
+   *   ds.groupingSets(Seq(Seq($"department", $"group"), Seq()), $"department", $"group").avg()
+   *
+   *   // Compute the max age and average salary, grouped by specific grouping sets.
+   *   ds.groupingSets(Seq(Seq($"department", $"gender"), Seq()), $"department", $"gender").agg(Map(
+   *     "salary" -> "avg",
+   *     "age" -> "max"
+   *   ))
+   * }}}
+   *
+   * @group untypedrel
+   * @since 4.0.0
+   */
+  @scala.annotation.varargs
+  def groupingSets(groupingSets: Seq[Seq[Column]], cols: Column*): RelationalGroupedDataset = {
+    RelationalGroupedDataset(
+      toDF(),
+      cols.map(_.expr),
+      RelationalGroupedDataset.GroupingSetsType(groupingSets.map(_.map(_.expr))))
+  }
+
   /**
    * Groups the Dataset using the specified columns, so that we can run aggregation on them.
    * See [[RelationalGroupedDataset]] for all the available aggregate functions.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 7e15c0baf52..bf1b2814270 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -80,6 +80,11 @@ class RelationalGroupedDataset protected[sql](
         Dataset.ofRows(
           df.sparkSession, Aggregate(Seq(Cube(groupingExprs.map(Seq(_)))),
             aliasedAgg, df.logicalPlan))
+      case RelationalGroupedDataset.GroupingSetsType(groupingSets) =>
+        Dataset.ofRows(
+          df.sparkSession,
+          Aggregate(Seq(GroupingSets(groupingSets, groupingExprs)),
+            aliasedAgg, df.logicalPlan))
       case RelationalGroupedDataset.PivotType(pivotCol, values) =>
         val aliasedGrps = groupingExprs.map(alias)
         Dataset.ofRows(
@@ -732,6 +737,11 @@ private[sql] object RelationalGroupedDataset {
    */
   private[sql] object RollupType extends GroupType
 
+  /**
+   * To indicate it's the GroupingSets
+   */
+  private[sql] case class GroupingSetsType(groupingSets: Seq[Seq[Expression]]) extends GroupType
+
   /**
    * To indicate it's the PIVOT
    */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index c8eea985c10..3691d76d251 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -161,6 +161,21 @@ class DataFrameAggregateSuite extends QueryTest
     assert(cube0.where("date IS NULL").count() > 0)
   }
 
+  test("SPARK-45929 support grouping set operation in dataframe api") {
+    checkAnswer(
+      courseSales
+        .groupingSets(
+          Seq(Seq(Column("course"), Column("year")), Seq()),
+          Column("course"),
+          Column("year"))
+        .agg(sum(Column("earnings")), grouping_id()),
+      Row("Java", 2012, 20000.0, 0) ::
+        Row("Java", 2013, 30000.0, 0) ::
+        Row("dotNET", 2012, 15000.0, 0) ::
+        Row("dotNET", 2013, 48000.0, 0) ::
+        Row(null, null, 113000.0, 3) :: Nil)
+  }
+
   test("grouping and grouping_id") {
     checkAnswer(
       courseSales.cube("course", "year")
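
A note on the expected rows in the new test: `grouping_id()` returns a bitmask with one bit per grouping column, set to 1 when that column is aggregated away in the current grouping set, so the `(course, year)` rows get 0 and the grand-total `()` row gets binary `11` = 3. A minimal spark-shell sketch against hypothetical data shaped like the suite's `courseSales` fixture:

```scala
import org.apache.spark.sql.functions.{col, grouping_id, sum}

// Hypothetical rows standing in for the courseSales test fixture.
val sales = Seq(("Java", 2012, 20000.0), ("dotNET", 2012, 15000.0))
  .toDF("course", "year", "earnings")

sales
  .groupingSets(Seq(Seq(col("course"), col("year")), Seq()), col("course"), col("year"))
  .agg(sum(col("earnings")), grouping_id())
  .show()
// Fully grouped rows carry grouping_id 0; the grand-total row carries binary 11 = 3.
```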