hvanhovell commented on code in PR #40145:
URL: https://github.com/apache/spark/pull/40145#discussion_r1116470992


##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala:
##########
@@ -234,4 +239,132 @@ class RelationalGroupedDataset protected[sql] (
   def sum(colNames: String*): DataFrame = {
     toDF(colNames.map(colName => functions.sum(colName)))
   }
+
+  /**
+   * Pivots a column of the current `DataFrame` and performs the specified 
aggregation. There are
+   * two versions of pivot function: one that requires the caller to specify 
the list of distinct
+   * values to pivot on, and one that does not. The latter is more concise but 
less efficient,
+   * because Spark needs to first compute the list of distinct values 
internally.
+   *
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course 
as a separate column
+   *   df.groupBy("year").pivot("course", Seq("dotNET", 
"Java")).sum("earnings")
+   *
+   *   // Or without specifying column values (less efficient)
+   *   df.groupBy("year").pivot("course").sum("earnings")
+   * }}}
+   *
+   * From Spark 3.0.0, values can be literal columns, for instance, struct. 
For pivoting by
+   * multiple columns, use the `struct` function to combine the columns and 
values:
+   *
+   * {{{
+   *   df.groupBy("year")
+   *     .pivot("trainingCourse", Seq(struct(lit("java"), lit("Experts"))))
+   *     .agg(sum($"earnings"))
+   * }}}
+   *
+   * @see
+   *   `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, 
except for the
+   *   aggregation.
+   *
+   * @param pivotColumn
+   *   Name of the column to pivot.
+   * @param values
+   *   List of values that will be translated to columns in the output 
DataFrame.
+   * @since 3.4.0
+   */
+  def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = 
{
+    pivot(Column(pivotColumn), values)
+  }
+
+  /**
+   * (Java-specific) Pivots a column of the current `DataFrame` and performs 
the specified
+   * aggregation.
+   *
+   * There are two versions of pivot function: one that requires the caller to 
specify the list of
+   * distinct values to pivot on, and one that does not. The latter is more 
concise but less
+   * efficient, because Spark needs to first compute the list of distinct 
values internally.
+   *
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course 
as a separate column
+   *   df.groupBy("year").pivot("course", Arrays.<Object>asList("dotNET", 
"Java")).sum("earnings");
+   *
+   *   // Or without specifying column values (less efficient)
+   *   df.groupBy("year").pivot("course").sum("earnings");
+   * }}}
+   *
+   * @see
+   *   `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, 
except for the
+   *   aggregation.
+   *
+   * @param pivotColumn
+   *   Name of the column to pivot.
+   * @param values
+   *   List of values that will be translated to columns in the output 
DataFrame.
+   * @since 3.4.0
+   */
+  def pivot(pivotColumn: String, values: java.util.List[Any]): 
RelationalGroupedDataset = {
+    pivot(Column(pivotColumn), values)
+  }
+
+  /**
+   * Pivots a column of the current `DataFrame` and performs the specified 
aggregation. This is an
+   * overloaded version of the `pivot` method with `pivotColumn` of the 
`String` type.
+   *
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course 
as a separate column
+   *   df.groupBy($"year").pivot($"course", Seq("dotNET", 
"Java")).sum($"earnings")
+   * }}}
+   *
+   * @see
+   *   `org.apache.spark.sql.Dataset.unpivot` for the reverse operation, 
except for the
+   *   aggregation.
+   *
+   * @param pivotColumn
+   *   the column to pivot.
+   * @param values
+   *   List of values that will be translated to columns in the output 
DataFrame.
+   * @since 3.4.0
+   */
+  def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = 
{
+    groupType match {
+      case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY =>
+        val valueExprs = values.map(_ match {
+          case c: Column if c.expr.hasLiteral => c.expr.getLiteral
+          case v => functions.lit(v).expr.getLiteral

Review Comment:
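   If `v` is a `Column` whose expression is not a literal (this can happen 
because the `c.expr.hasLiteral` guard in the previous case statement lets such 
columns fall through), then `functions.lit(v).expr.getLiteral` will produce an 
empty literal. The same applies to a `scala.Symbol`, which `lit` turns into a 
column reference rather than a literal.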



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to