This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new af45902d33c4 [SPARK-49422][CONNECT][SQL] Add groupByKey to sql/api
af45902d33c4 is described below

commit af45902d33c4d8e38a6427ac1d0c46fe057bb45a
Author: Herman van Hovell <[email protected]>
AuthorDate: Wed Sep 18 20:11:21 2024 -0400

    [SPARK-49422][CONNECT][SQL] Add groupByKey to sql/api
    
    ### What changes were proposed in this pull request?
    This PR adds `Dataset.groupByKey(..)` to the shared interface. I forgot to 
add it in the previous PR.
    
    ### Why are the changes needed?
    The shared interface needs to support all functionality.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #48147 from hvanhovell/SPARK-49422-follow-up.
    
    Authored-by: Herman van Hovell <[email protected]>
    Signed-off-by: Herman van Hovell <[email protected]>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 24 +++-----
 .../scala/org/apache/spark/sql/api/Dataset.scala   | 22 +++++++
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 68 ++++------------------
 3 files changed, 39 insertions(+), 75 deletions(-)

diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 161a0d9d265f..accfff9f2b07 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -524,27 +524,11 @@ class Dataset[T] private[sql] (
     result(0)
   }
 
-  /**
-   * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is 
grouped by the given
-   * key `func`.
-   *
-   * @group typedrel
-   * @since 3.5.0
-   */
+  /** @inheritdoc */
   def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = {
     KeyValueGroupedDatasetImpl[K, T](this, encoderFor[K], func)
   }
 
-  /**
-   * (Java-specific) Returns a [[KeyValueGroupedDataset]] where the data is 
grouped by the given
-   * key `func`.
-   *
-   * @group typedrel
-   * @since 3.5.0
-   */
-  def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): 
KeyValueGroupedDataset[K, T] =
-    groupByKey(ToScalaUDF(func))(encoder)
-
   /** @inheritdoc */
   @scala.annotation.varargs
   def rollup(cols: Column*): RelationalGroupedDataset = {
@@ -1480,4 +1464,10 @@ class Dataset[T] private[sql] (
   /** @inheritdoc */
   @scala.annotation.varargs
   override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, 
exprs: _*)
+
+  /** @inheritdoc */
+  override def groupByKey[K](
+      func: MapFunction[T, K],
+      encoder: Encoder[K]): KeyValueGroupedDataset[K, T] =
+    super.groupByKey(func, encoder).asInstanceOf[KeyValueGroupedDataset[K, T]]
 }
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
index 284a69fe6ee3..7a3d6b0e0387 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
@@ -1422,6 +1422,28 @@ abstract class Dataset[T] extends Serializable {
    */
   def reduce(func: ReduceFunction[T]): T = reduce(ToScalaUDF(func))
 
+  /**
+   * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is 
grouped by the given
+   * key `func`.
+   *
+   * @group typedrel
+   * @since 2.0.0
+   */
+  def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T, DS]
+
+  /**
+   * (Java-specific) Returns a [[KeyValueGroupedDataset]] where the data is 
grouped by the given
+   * key `func`.
+   *
+   * @group typedrel
+   * @since 2.0.0
+   */
+  def groupByKey[K](
+      func: MapFunction[T, K],
+      encoder: Encoder[K]): KeyValueGroupedDataset[K, T, DS] = {
+    groupByKey(ToScalaUDF(func))(encoder)
+  }
+
   /**
    * Unpivot a DataFrame from wide format to long format, optionally leaving 
identifier columns
    * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except 
for the aggregation,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 61f9e6ff7c04..ef628ca612b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -62,7 +62,7 @@ import 
org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, 
DataSourceV2ScanRelation, FileTable}
 import org.apache.spark.sql.execution.python.EvaluatePython
 import org.apache.spark.sql.execution.stat.StatFunctions
-import org.apache.spark.sql.internal.{DataFrameWriterImpl, 
DataFrameWriterV2Impl, MergeIntoWriterImpl, SQLConf, ToScalaUDF}
+import org.apache.spark.sql.internal.{DataFrameWriterImpl, 
DataFrameWriterV2Impl, MergeIntoWriterImpl, SQLConf}
 import org.apache.spark.sql.internal.ExpressionUtils.column
 import org.apache.spark.sql.internal.TypedAggUtils.withInputType
 import org.apache.spark.sql.streaming.DataStreamWriter
@@ -865,24 +865,7 @@ class Dataset[T] private[sql](
     Filter(condition.expr, logicalPlan)
   }
 
-  /**
-   * Groups the Dataset using the specified columns, so we can run aggregation 
on them. See
-   * [[RelationalGroupedDataset]] for all the available aggregate functions.
-   *
-   * {{{
-   *   // Compute the average for all numeric columns grouped by department.
-   *   ds.groupBy($"department").avg()
-   *
-   *   // Compute the max age and average salary, grouped by department and 
gender.
-   *   ds.groupBy($"department", $"gender").agg(Map(
-   *     "salary" -> "avg",
-   *     "age" -> "max"
-   *   ))
-   * }}}
-   *
-   * @group untypedrel
-   * @since 2.0.0
-   */
+  /** @inheritdoc */
   @scala.annotation.varargs
   def groupBy(cols: Column*): RelationalGroupedDataset = {
     RelationalGroupedDataset(toDF(), cols.map(_.expr), 
RelationalGroupedDataset.GroupByType)
@@ -914,13 +897,7 @@ class Dataset[T] private[sql](
     rdd.reduce(func)
   }
 
-  /**
-   * (Scala-specific)
-   * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the 
given key `func`.
-   *
-   * @group typedrel
-   * @since 2.0.0
-   */
+  /** @inheritdoc */
   def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = {
     val withGroupingKey = AppendColumns(func, logicalPlan)
     val executed = sparkSession.sessionState.executePlan(withGroupingKey)
@@ -933,16 +910,6 @@ class Dataset[T] private[sql](
       withGroupingKey.newColumns)
   }
 
-  /**
-   * (Java-specific)
-   * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the 
given key `func`.
-   *
-   * @group typedrel
-   * @since 2.0.0
-   */
-  def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): 
KeyValueGroupedDataset[K, T] =
-    groupByKey(ToScalaUDF(func))(encoder)
-
   /** @inheritdoc */
   def unpivot(
       ids: Array[Column],
@@ -1640,28 +1607,7 @@ class Dataset[T] private[sql](
     new DataFrameWriterV2Impl[T](table, this)
   }
 
-  /**
-   * Merges a set of updates, insertions, and deletions based on a source 
table into
-   * a target table.
-   *
-   * Scala Examples:
-   * {{{
-   *   spark.table("source")
-   *     .mergeInto("target", $"source.id" === $"target.id")
-   *     .whenMatched($"salary" === 100)
-   *     .delete()
-   *     .whenNotMatched()
-   *     .insertAll()
-   *     .whenNotMatchedBySource($"salary" === 100)
-   *     .update(Map(
-   *       "salary" -> lit(200)
-   *     ))
-   *     .merge()
-   * }}}
-   *
-   * @group basic
-   * @since 4.0.0
-   */
+  /** @inheritdoc */
   def mergeInto(table: String, condition: Column): MergeIntoWriter[T] = {
     if (isStreaming) {
       logicalPlan.failAnalysis(
@@ -2024,6 +1970,12 @@ class Dataset[T] private[sql](
   @scala.annotation.varargs
   override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, 
exprs: _*)
 
+  /** @inheritdoc */
+  override def groupByKey[K](
+      func: MapFunction[T, K],
+      encoder: Encoder[K]): KeyValueGroupedDataset[K, T] =
+    super.groupByKey(func, encoder).asInstanceOf[KeyValueGroupedDataset[K, T]]
+
   ////////////////////////////////////////////////////////////////////////////
   // For Python API
   ////////////////////////////////////////////////////////////////////////////


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to