This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 21767d29b36 [SPARK-42529][CONNECT] Support Cube and Rollup in Scala client
21767d29b36 is described below

commit 21767d29b36c3c8d812bb3ea8946a21a8ef6e65c
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Wed Feb 22 23:56:38 2023 -0400

    [SPARK-42529][CONNECT] Support Cube and Rollup in Scala client
    
    ### What changes were proposed in this pull request?
    
    Support Cube and Rollup in Scala client.
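
    For illustration, a minimal usage sketch of the new API surface (`df` and the
    column names below are hypothetical, not part of this patch):

    ```scala
    // rollup/cube mirror groupBy but return a RelationalGroupedDataset whose
    // aggregation runs over the rollup/cube grouping sets.
    // Assumes spark.implicits._ is in scope for the $"col" syntax.
    df.rollup($"department", $"gender").count()

    df.cube("department", "gender").agg(Map(
      "salary" -> "avg",
      "age" -> "max"
    ))
    ```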
    
    ### Why are the changes needed?
    
    API coverage: adds cube and rollup to the Spark Connect Scala client for parity with the existing Dataset API.
    
    ### Does this PR introduce _any_ user-facing change?
    
    NO
    
    ### How was this patch tested?
    
    UT: new cases in PlanGenerationTestSuite, with golden query and explain files.
    
    Closes #40129 from amaliujia/support_cube_rollup_pivot.
    
    Authored-by: Rui Wang <rui.w...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 120 ++++++++++++++++++++-
 .../spark/sql/RelationalGroupedDataset.scala       |  16 ++-
 .../apache/spark/sql/PlanGenerationTestSuite.scala |  16 +++
 .../explain-results/cube_column.explain            |   4 +
 .../explain-results/cube_string.explain            |   4 +
 .../explain-results/rollup_column.explain          |   4 +
 .../explain-results/rollup_string.explain          |   4 +
 .../resources/query-tests/queries/cube_column.json |  34 ++++++
 .../query-tests/queries/cube_column.proto.bin      |   7 ++
 .../resources/query-tests/queries/cube_string.json |  34 ++++++
 .../query-tests/queries/cube_string.proto.bin      |   7 ++
 .../query-tests/queries/rollup_column.json         |  34 ++++++
 .../query-tests/queries/rollup_column.proto.bin    |   7 ++
 .../query-tests/queries/rollup_string.json         |  34 ++++++
 .../query-tests/queries/rollup_string.proto.bin    |   7 ++
 15 files changed, 328 insertions(+), 4 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index c7ded04a963..560276d154e 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1055,7 +1055,125 @@ class Dataset[T] private[sql] (val session: SparkSession, private[sql] val plan:
    */
   @scala.annotation.varargs
   def groupBy(cols: Column*): RelationalGroupedDataset = {
-    new RelationalGroupedDataset(toDF(), cols.map(_.expr))
+    new RelationalGroupedDataset(
+      toDF(),
+      cols.map(_.expr),
+      proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
+  }
+
+  /**
+   * Create a multi-dimensional rollup for the current Dataset using the specified columns, so we
+   * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate
+   * functions.
+   *
+   * {{{
+   *   // Compute the average for all numeric columns rolled up by department and group.
+   *   ds.rollup($"department", $"group").avg()
+   *
+   *   // Compute the max age and average salary, rolled up by department and gender.
+   *   ds.rollup($"department", $"gender").agg(Map(
+   *     "salary" -> "avg",
+   *     "age" -> "max"
+   *   ))
+   * }}}
+   *
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  @scala.annotation.varargs
+  def rollup(cols: Column*): RelationalGroupedDataset = {
+    new RelationalGroupedDataset(
+      toDF(),
+      cols.map(_.expr),
+      proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP)
+  }
+
+  /**
+   * Create a multi-dimensional rollup for the current Dataset using the specified columns, so we
+   * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate
+   * functions.
+   *
+   * This is a variant of rollup that can only group by existing columns using column names (i.e.
+   * cannot construct expressions).
+   *
+   * {{{
+   *   // Compute the average for all numeric columns rolled up by department and group.
+   *   ds.rollup("department", "group").avg()
+   *
+   *   // Compute the max age and average salary, rolled up by department and gender.
+   *   ds.rollup("department", "gender").agg(Map(
+   *     "salary" -> "avg",
+   *     "age" -> "max"
+   *   ))
+   * }}}
+   *
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  @scala.annotation.varargs
+  def rollup(col1: String, cols: String*): RelationalGroupedDataset = {
+    val colNames: Seq[String] = col1 +: cols
+    new RelationalGroupedDataset(
+      toDF(),
+      colNames.map(colName => Column(colName).expr),
+      proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP)
+  }
+
+  /**
+   * Create a multi-dimensional cube for the current Dataset using the specified columns, so we
+   * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate
+   * functions.
+   *
+   * {{{
+   *   // Compute the average for all numeric columns cubed by department and group.
+   *   ds.cube($"department", $"group").avg()
+   *
+   *   // Compute the max age and average salary, cubed by department and gender.
+   *   ds.cube($"department", $"gender").agg(Map(
+   *     "salary" -> "avg",
+   *     "age" -> "max"
+   *   ))
+   * }}}
+   *
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  @scala.annotation.varargs
+  def cube(cols: Column*): RelationalGroupedDataset = {
+    new RelationalGroupedDataset(
+      toDF(),
+      cols.map(_.expr),
+      proto.Aggregate.GroupType.GROUP_TYPE_CUBE)
+  }
+
+  /**
+   * Create a multi-dimensional cube for the current Dataset using the specified columns, so we
+   * can run aggregation on them. See [[RelationalGroupedDataset]] for all the available aggregate
+   * functions.
+   *
+   * This is a variant of cube that can only group by existing columns using column names (i.e.
+   * cannot construct expressions).
+   *
+   * {{{
+   *   // Compute the average for all numeric columns cubed by department and group.
+   *   ds.cube("department", "group").avg()
+   *
+   *   // Compute the max age and average salary, cubed by department and gender.
+   *   ds.cube("department", "gender").agg(Map(
+   *     "salary" -> "avg",
+   *     "age" -> "max"
+   *   ))
+   * }}}
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  @scala.annotation.varargs
+  def cube(col1: String, cols: String*): RelationalGroupedDataset = {
+    val colNames: Seq[String] = col1 +: cols
+    new RelationalGroupedDataset(
+      toDF(),
+      colNames.map(colName => Column(colName).expr),
+      proto.Aggregate.GroupType.GROUP_TYPE_CUBE)
   }
 
   /**
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index a6d3dc2e468..76db231db9e 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -37,16 +37,26 @@ import org.apache.spark.connect.proto
  */
 class RelationalGroupedDataset protected[sql] (
     private[sql] val df: DataFrame,
-    private[sql] val groupingExprs: Seq[proto.Expression]) {
+    private[sql] val groupingExprs: Seq[proto.Expression],
+    groupType: proto.Aggregate.GroupType) {
 
   private[this] def toDF(aggExprs: Seq[Column]): DataFrame = {
-    // TODO: support other GroupByType such as Rollup, Cube, Pivot.
     df.session.newDataset { builder =>
       builder.getAggregateBuilder
-        .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
         .setInput(df.plan.getRoot)
         .addAllGroupingExpressions(groupingExprs.asJava)
         .addAllAggregateExpressions(aggExprs.map(e => e.expr).asJava)
+
+      // TODO: support Pivot.
+      groupType match {
+        case proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP =>
+          builder.getAggregateBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP)
+        case proto.Aggregate.GroupType.GROUP_TYPE_CUBE =>
+          builder.getAggregateBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_CUBE)
+        case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY =>
+          builder.getAggregateBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
+        case g => throw new UnsupportedOperationException(g.toString)
+      }
     }
   }
 
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 42572f8427e..9ca91942567 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -1663,6 +1663,22 @@ class PlanGenerationTestSuite extends ConnectFunSuite with BeforeAndAfterAll wit
       .count()
   }
 
+  test("rollup column") {
+    simple.rollup(Column("a"), Column("b")).count()
+  }
+
+  test("cube column") {
+    simple.cube(Column("a"), Column("b")).count()
+  }
+
+  test("rollup string") {
+    simple.rollup("a", "b").count()
+  }
+
+  test("cube string") {
+    simple.cube("a", "b").count()
+  }
+
   test("function lit") {
     simple.select(
       fn.lit(fn.col("id")),
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/cube_column.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/cube_column.explain
new file mode 100644
index 00000000000..1721162f478
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/cube_column.explain
@@ -0,0 +1,4 @@
+Aggregate [a#0, b#0, spark_grouping_id#0L], [a#0, b#0, count(1) AS count#0L]
++- Expand [[id#0L, a#0, b#0, a#0, b#0, 0], [id#0L, a#0, b#0, a#0, null, 1], [id#0L, a#0, b#0, null, b#0, 2], [id#0L, a#0, b#0, null, null, 3]], [id#0L, a#0, b#0, a#0, b#0, spark_grouping_id#0L]
+   +- Project [id#0L, a#0, b#0, a#0 AS a#0, b#0 AS b#0]
+      +- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/cube_string.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/cube_string.explain
new file mode 100644
index 00000000000..1721162f478
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/cube_string.explain
@@ -0,0 +1,4 @@
+Aggregate [a#0, b#0, spark_grouping_id#0L], [a#0, b#0, count(1) AS count#0L]
++- Expand [[id#0L, a#0, b#0, a#0, b#0, 0], [id#0L, a#0, b#0, a#0, null, 1], [id#0L, a#0, b#0, null, b#0, 2], [id#0L, a#0, b#0, null, null, 3]], [id#0L, a#0, b#0, a#0, b#0, spark_grouping_id#0L]
+   +- Project [id#0L, a#0, b#0, a#0 AS a#0, b#0 AS b#0]
+      +- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/rollup_column.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/rollup_column.explain
new file mode 100644
index 00000000000..c8f0f1e2aeb
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/rollup_column.explain
@@ -0,0 +1,4 @@
+Aggregate [a#0, b#0, spark_grouping_id#0L], [a#0, b#0, count(1) AS count#0L]
++- Expand [[id#0L, a#0, b#0, a#0, b#0, 0], [id#0L, a#0, b#0, a#0, null, 1], [id#0L, a#0, b#0, null, null, 3]], [id#0L, a#0, b#0, a#0, b#0, spark_grouping_id#0L]
+   +- Project [id#0L, a#0, b#0, a#0 AS a#0, b#0 AS b#0]
+      +- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/rollup_string.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/rollup_string.explain
new file mode 100644
index 00000000000..c8f0f1e2aeb
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/rollup_string.explain
@@ -0,0 +1,4 @@
+Aggregate [a#0, b#0, spark_grouping_id#0L], [a#0, b#0, count(1) AS count#0L]
++- Expand [[id#0L, a#0, b#0, a#0, b#0, 0], [id#0L, a#0, b#0, a#0, null, 1], [id#0L, a#0, b#0, null, null, 3]], [id#0L, a#0, b#0, a#0, b#0, spark_grouping_id#0L]
+   +- Project [id#0L, a#0, b#0, a#0 AS a#0, b#0 AS b#0]
+      +- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/cube_column.json b/connector/connect/common/src/test/resources/query-tests/queries/cube_column.json
new file mode 100644
index 00000000000..49016593a34
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/cube_column.json
@@ -0,0 +1,34 @@
+{
+  "aggregate": {
+    "input": {
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+      }
+    },
+    "groupType": "GROUP_TYPE_CUBE",
+    "groupingExpressions": [{
+      "unresolvedAttribute": {
+        "unparsedIdentifier": "a"
+      }
+    }, {
+      "unresolvedAttribute": {
+        "unparsedIdentifier": "b"
+      }
+    }],
+    "aggregateExpressions": [{
+      "alias": {
+        "expr": {
+          "unresolvedFunction": {
+            "functionName": "count",
+            "arguments": [{
+              "literal": {
+                "integer": 1
+              }
+            }]
+          }
+        },
+        "name": ["count"]
+      }
+    }]
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/cube_column.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/cube_column.proto.bin
new file mode 100644
index 00000000000..c706144de59
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/cube_column.proto.bin
@@ -0,0 +1,7 @@
+JR
+$Z" struct<id:bigint,a:int,b:double>
+a
+b"2
+
+count
+0count
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/cube_string.json b/connector/connect/common/src/test/resources/query-tests/queries/cube_string.json
new file mode 100644
index 00000000000..49016593a34
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/cube_string.json
@@ -0,0 +1,34 @@
+{
+  "aggregate": {
+    "input": {
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+      }
+    },
+    "groupType": "GROUP_TYPE_CUBE",
+    "groupingExpressions": [{
+      "unresolvedAttribute": {
+        "unparsedIdentifier": "a"
+      }
+    }, {
+      "unresolvedAttribute": {
+        "unparsedIdentifier": "b"
+      }
+    }],
+    "aggregateExpressions": [{
+      "alias": {
+        "expr": {
+          "unresolvedFunction": {
+            "functionName": "count",
+            "arguments": [{
+              "literal": {
+                "integer": 1
+              }
+            }]
+          }
+        },
+        "name": ["count"]
+      }
+    }]
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/cube_string.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/cube_string.proto.bin
new file mode 100644
index 00000000000..c706144de59
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/cube_string.proto.bin
@@ -0,0 +1,7 @@
+JR
+$Z" struct<id:bigint,a:int,b:double>
+a
+b"2
+
+count
+0count
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/rollup_column.json b/connector/connect/common/src/test/resources/query-tests/queries/rollup_column.json
new file mode 100644
index 00000000000..f976e4ea10f
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/rollup_column.json
@@ -0,0 +1,34 @@
+{
+  "aggregate": {
+    "input": {
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+      }
+    },
+    "groupType": "GROUP_TYPE_ROLLUP",
+    "groupingExpressions": [{
+      "unresolvedAttribute": {
+        "unparsedIdentifier": "a"
+      }
+    }, {
+      "unresolvedAttribute": {
+        "unparsedIdentifier": "b"
+      }
+    }],
+    "aggregateExpressions": [{
+      "alias": {
+        "expr": {
+          "unresolvedFunction": {
+            "functionName": "count",
+            "arguments": [{
+              "literal": {
+                "integer": 1
+              }
+            }]
+          }
+        },
+        "name": ["count"]
+      }
+    }]
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/rollup_column.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/rollup_column.proto.bin
new file mode 100644
index 00000000000..89ef8ff947b
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/rollup_column.proto.bin
@@ -0,0 +1,7 @@
+JR
+$Z" struct<id:bigint,a:int,b:double>
+a
+b"2
+
+count
+0count
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/rollup_string.json b/connector/connect/common/src/test/resources/query-tests/queries/rollup_string.json
new file mode 100644
index 00000000000..f976e4ea10f
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/rollup_string.json
@@ -0,0 +1,34 @@
+{
+  "aggregate": {
+    "input": {
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+      }
+    },
+    "groupType": "GROUP_TYPE_ROLLUP",
+    "groupingExpressions": [{
+      "unresolvedAttribute": {
+        "unparsedIdentifier": "a"
+      }
+    }, {
+      "unresolvedAttribute": {
+        "unparsedIdentifier": "b"
+      }
+    }],
+    "aggregateExpressions": [{
+      "alias": {
+        "expr": {
+          "unresolvedFunction": {
+            "functionName": "count",
+            "arguments": [{
+              "literal": {
+                "integer": 1
+              }
+            }]
+          }
+        },
+        "name": ["count"]
+      }
+    }]
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/rollup_string.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/rollup_string.proto.bin
new file mode 100644
index 00000000000..89ef8ff947b
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/rollup_string.proto.bin
@@ -0,0 +1,7 @@
+JR
+$Z" struct<id:bigint,a:int,b:double>
+a
+b"2
+
+count
+0count
\ No newline at end of file


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
