This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 8e589f38826 [SPARK-42541][CONNECT] Support Pivot with provided pivot column values
8e589f38826 is described below

commit 8e589f3882609059a9726be6a72c343a112eb12f
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Fri Feb 24 22:30:20 2023 -0400

    [SPARK-42541][CONNECT] Support Pivot with provided pivot column values
    
    ### What changes were proposed in this pull request?
    
    Support Pivot with explicitly provided pivot column values. Pivot without
    provided column values is not supported yet, because it requires a check
    against the maximum number of distinct pivot values, and that check depends
    on Spark configuration support in Spark Connect.
    
    ### Why are the changes needed?
    
    To improve API coverage of the Spark Connect JVM client.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit tests.
    
    Closes #40145 from amaliujia/rw-pivot.
    
    Authored-by: Rui Wang <rui.w...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
    (cherry picked from commit 34a2d95dadfca2ee643eb937d50f12e3b8b148eb)
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../spark/sql/RelationalGroupedDataset.scala       | 138 ++++++++++++++++++++-
 .../scala/org/apache/spark/sql/DatasetSuite.scala  |   7 ++
 .../apache/spark/sql/PlanGenerationTestSuite.scala |   4 +
 .../query-tests/explain-results/pivot.explain      |   4 +
 .../test/resources/query-tests/queries/pivot.json  |  45 +++++++
 .../resources/query-tests/queries/pivot.proto.bin  | Bin 0 -> 97 bytes
 6 files changed, 196 insertions(+), 2 deletions(-)

diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 76db231db9e..c918061ac46 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -38,7 +38,8 @@ import org.apache.spark.connect.proto
 class RelationalGroupedDataset protected[sql] (
     private[sql] val df: DataFrame,
     private[sql] val groupingExprs: Seq[proto.Expression],
-    groupType: proto.Aggregate.GroupType) {
+    groupType: proto.Aggregate.GroupType,
+    pivot: Option[proto.Aggregate.Pivot] = None) {
 
   private[this] def toDF(aggExprs: Seq[Column]): DataFrame = {
     df.session.newDataset { builder =>
@@ -47,7 +48,6 @@ class RelationalGroupedDataset protected[sql] (
         .addAllGroupingExpressions(groupingExprs.asJava)
         .addAllAggregateExpressions(aggExprs.map(e => e.expr).asJava)
 
-      // TODO: support Pivot.
       groupType match {
         case proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP =>
           
builder.getAggregateBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP)
@@ -55,6 +55,11 @@ class RelationalGroupedDataset protected[sql] (
           
builder.getAggregateBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_CUBE)
         case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY =>
           
builder.getAggregateBuilder.setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
+        case proto.Aggregate.GroupType.GROUP_TYPE_PIVOT =>
+          assert(pivot.isDefined)
+          builder.getAggregateBuilder
+            .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_PIVOT)
+            .setPivot(pivot.get)
         case g => throw new UnsupportedOperationException(g.toString)
       }
     }
@@ -234,4 +239,133 @@ class RelationalGroupedDataset protected[sql] (
   def sum(colNames: String*): DataFrame = {
     toDF(colNames.map(colName => functions.sum(colName)))
   }
+
+  /**
+   * Pivots a column of the current `DataFrame` and performs the specified 
aggregation. There are
+   * two versions of pivot function: one that requires the caller to specify 
the list of distinct
+   * values to pivot on, and one that does not. The latter is more concise but 
less efficient,
+   * because Spark needs to first compute the list of distinct values 
internally.
+   *
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course 
as a separate column
+   *   df.groupBy("year").pivot("course", Seq("dotNET", 
"Java")).sum("earnings")
+   *
+   *   // Or without specifying column values (less efficient)
+   *   df.groupBy("year").pivot("course").sum("earnings")
+   * }}}
+   *
+   * From Spark 3.0.0, values can be literal columns, for instance, struct. 
For pivoting by
+   * multiple columns, use the `struct` function to combine the columns and 
values:
+   *
+   * {{{
+   *   df.groupBy("year")
+   *     .pivot("trainingCourse", Seq(struct(lit("java"), lit("Experts"))))
+   *     .agg(sum($"earnings"))
+   * }}}
+   *
+   * @see
+   *   `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, 
except for the
+   *   aggregation.
+   *
+   * @param pivotColumn
+   *   Name of the column to pivot.
+   * @param values
+   *   List of values that will be translated to columns in the output 
DataFrame.
+   * @since 3.4.0
+   */
+  def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = 
{
+    pivot(Column(pivotColumn), values)
+  }
+
+  /**
+   * (Java-specific) Pivots a column of the current `DataFrame` and performs 
the specified
+   * aggregation.
+   *
+   * There are two versions of pivot function: one that requires the caller to 
specify the list of
+   * distinct values to pivot on, and one that does not. The latter is more 
concise but less
+   * efficient, because Spark needs to first compute the list of distinct 
values internally.
+   *
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course 
as a separate column
+   *   df.groupBy("year").pivot("course", Arrays.<Object>asList("dotNET", 
"Java")).sum("earnings");
+   *
+   *   // Or without specifying column values (less efficient)
+   *   df.groupBy("year").pivot("course").sum("earnings");
+   * }}}
+   *
+   * @see
+   *   `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, 
except for the
+   *   aggregation.
+   *
+   * @param pivotColumn
+   *   Name of the column to pivot.
+   * @param values
+   *   List of values that will be translated to columns in the output 
DataFrame.
+   * @since 3.4.0
+   */
+  def pivot(pivotColumn: String, values: java.util.List[Any]): 
RelationalGroupedDataset = {
+    pivot(Column(pivotColumn), values)
+  }
+
+  /**
+   * Pivots a column of the current `DataFrame` and performs the specified 
aggregation. This is an
+   * overloaded version of the `pivot` method with `pivotColumn` of the 
`String` type.
+   *
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course 
as a separate column
+   *   df.groupBy($"year").pivot($"course", Seq("dotNET", 
"Java")).sum($"earnings")
+   * }}}
+   *
+   * @see
+   *   `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, 
except for the
+   *   aggregation.
+   *
+   * @param pivotColumn
+   *   the column to pivot.
+   * @param values
+   *   List of values that will be translated to columns in the output 
DataFrame.
+   * @since 3.4.0
+   */
+  def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = 
{
+    groupType match {
+      case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY =>
+        val valueExprs = values.map(_ match {
+          case c: Column if c.expr.hasLiteral => c.expr.getLiteral
+          case c: Column if !c.expr.hasLiteral =>
+            throw new IllegalArgumentException("values only accept literal 
Column")
+          case v => functions.lit(v).expr.getLiteral
+        })
+        new RelationalGroupedDataset(
+          df,
+          groupingExprs,
+          proto.Aggregate.GroupType.GROUP_TYPE_PIVOT,
+          Some(
+            proto.Aggregate.Pivot
+              .newBuilder()
+              .setCol(pivotColumn.expr)
+              .addAllValues(valueExprs.asJava)
+              .build()))
+      case _ =>
+        throw new UnsupportedOperationException()
+    }
+  }
+
+  /**
+   * (Java-specific) Pivots a column of the current `DataFrame` and performs 
the specified
+   * aggregation. This is an overloaded version of the `pivot` method with 
`pivotColumn` of the
+   * `String` type.
+   *
+   * @see
+   *   `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, 
except for the
+   *   aggregation.
+   *
+   * @param pivotColumn
+   *   the column to pivot.
+   * @param values
+   *   List of values that will be translated to columns in the output 
DataFrame.
+   * @since 3.4.0
+   */
+  def pivot(pivotColumn: Column, values: java.util.List[Any]): 
RelationalGroupedDataset = {
+    pivot(pivotColumn, values.asScala.toSeq)
+  }
 }
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index c3c80a08379..be69959beac 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -129,4 +129,11 @@ class DatasetSuite
     val actualPlan = service.getAndClearLatestInputPlan()
     assert(actualPlan.equals(expectedPlan))
   }
+
+  test("Pivot") {
+    val df = ss.newDataset(_ => ())
+    intercept[IllegalArgumentException] {
+      df.groupBy().pivot(Column("c"), Seq(Column("col")))
+    }
+  }
 }
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index f7589d957ca..465d2091ca2 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -1949,6 +1949,10 @@ class PlanGenerationTestSuite
     simple.cube("a", "b").count()
   }
 
+  test("pivot") {
+    simple.groupBy(Column("id")).pivot("a", Seq(1, 2, 
3)).agg(functions.count(Column("b")))
+  }
+
   test("function lit") {
     simple.select(
       fn.lit(fn.col("id")),
diff --git 
a/connector/connect/common/src/test/resources/query-tests/explain-results/pivot.explain
 
b/connector/connect/common/src/test/resources/query-tests/explain-results/pivot.explain
new file mode 100644
index 00000000000..b8cd8441237
--- /dev/null
+++ 
b/connector/connect/common/src/test/resources/query-tests/explain-results/pivot.explain
@@ -0,0 +1,4 @@
+Project [id#0L, __pivot_count(b) AS `count(b)`#0[0] AS 1#0L, __pivot_count(b) 
AS `count(b)`#0[1] AS 2#0L, __pivot_count(b) AS `count(b)`#0[2] AS 3#0L]
++- Aggregate [id#0L], [id#0L, pivotfirst(a#0, count(b)#0L, 1, 2, 3, 0, 0) AS 
__pivot_count(b) AS `count(b)`#0]
+   +- Aggregate [id#0L, a#0], [id#0L, a#0, count(b#0) AS count(b)#0L]
+      +- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git 
a/connector/connect/common/src/test/resources/query-tests/queries/pivot.json 
b/connector/connect/common/src/test/resources/query-tests/queries/pivot.json
new file mode 100644
index 00000000000..30bff04c531
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/pivot.json
@@ -0,0 +1,45 @@
+{
+  "common": {
+    "planId": "1"
+  },
+  "aggregate": {
+    "input": {
+      "common": {
+        "planId": "0"
+      },
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+      }
+    },
+    "groupType": "GROUP_TYPE_PIVOT",
+    "groupingExpressions": [{
+      "unresolvedAttribute": {
+        "unparsedIdentifier": "id"
+      }
+    }],
+    "aggregateExpressions": [{
+      "unresolvedFunction": {
+        "functionName": "count",
+        "arguments": [{
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "b"
+          }
+        }]
+      }
+    }],
+    "pivot": {
+      "col": {
+        "unresolvedAttribute": {
+          "unparsedIdentifier": "a"
+        }
+      },
+      "values": [{
+        "integer": 1
+      }, {
+        "integer": 2
+      }, {
+        "integer": 3
+      }]
+    }
+  }
+}
\ No newline at end of file
diff --git 
a/connector/connect/common/src/test/resources/query-tests/queries/pivot.proto.bin
 
b/connector/connect/common/src/test/resources/query-tests/queries/pivot.proto.bin
new file mode 100644
index 00000000000..67063209a18
Binary files /dev/null and 
b/connector/connect/common/src/test/resources/query-tests/queries/pivot.proto.bin
 differ


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to