This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ba8cae2031f [SPARK-43223][CONNECT] Typed agg, reduce functions, 
RelationalGroupedDataset#as
ba8cae2031f is described below

commit ba8cae2031f81dc326d386cbe7d19c1f0a8f239e
Author: Zhen Li <zhenli...@users.noreply.github.com>
AuthorDate: Mon May 15 11:05:33 2023 -0400

    [SPARK-43223][CONNECT] Typed agg, reduce functions, 
RelationalGroupedDataset#as
    
    ### What changes were proposed in this pull request?
    Added `agg` and `reduce` support in `KeyValueGroupedDataset`.
    Added `Dataset#reduce`.
    Added `RelationalGroupedDataset#as`.
    
    Summary:
    * `KVGDS#agg`: `KVGDS#agg` and `RelationalGroupedDS#agg` share the exact same proto. The only difference is that KVGDS always passes a UDF as the first grouping expression. That is also how we tell them apart in this PR.
    * `KVGDS#reduce`: Reduce is a special aggregation. The client uses an UnresolvedFunc "reduce" to mark that the agg operator is a `ReduceAggregator` and calls `KVGDS#agg` directly. The server picks this func up and reuses the agg code path by instantiating a `ReduceAggregator`.
    * `Dataset#reduce`: This comes for free after `KVGDS#reduce`.
    * `RelationalGroupedDS#as`: The only difference between a `KVGDS` created via `ds#groupByKey` and one created via `ds#groupBy#as` is the grouping expressions. The former uses a single grouping func as its grouping expression, while the latter uses a dummy func (to pass encoders/types to the server) followed by the grouping expressions. Thus the server can count how many grouping expressions it received and decide whether the `KVGDS` was created via `ds#groupByKey` or via `ds#groupBy#as` (see the usage sketch below).
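    
    For illustration, a minimal client-side usage sketch of the typed agg and reduce APIs added here (adapted from the E2E tests in this patch; it assumes an existing Spark Connect `SparkSession` named `spark`):
    
    ```scala
    import org.apache.spark.sql.functions.sum
    
    val session = spark
    import session.implicits._
    
    // Typed per-key aggregation via KeyValueGroupedDataset#agg.
    val ds = Seq(("a", 10), ("a", 20), ("b", 1)).toDS()
    val sums = ds.groupByKey(_._1).agg(sum("_2").as[Long]) // Dataset[(String, Long)]
    
    // Per-group reduce via KeyValueGroupedDataset#reduceGroups.
    val reduced = Seq("abc", "xyz", "hello").toDS()
      .groupByKey(_.length)
      .reduceGroups(_ + _) // Dataset[(Int, String)]
    
    // Whole-dataset reduce via Dataset#reduce.
    val total = spark.range(10).map(_ + 1).reduce(_ + _) // 55
    ```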
    
    Followups:
    * [SPARK-43415] Support mapValues in the Agg functions.
    * [SPARK-43416] The tupled ProductEncoder does not pick up the field names from the server.
    
    ### Why are the changes needed?
    Missing APIs in Scala Client
    
    ### Does this PR introduce _any_ user-facing change?
    Added `KeyValueGroupedDataset#agg`/`reduceGroups`, `Dataset#reduce`, and `RelationalGroupedDataset#as` methods for the Scala client.
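    
    For illustration, a minimal sketch of the `RelationalGroupedDataset#as` flow (adapted from the E2E tests in this patch; it assumes the same `spark` session and implicits as the sketch above):
    
    ```scala
    // Turn an untyped grouping into a typed KeyValueGroupedDataset.
    val df = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1))
      .toDF("key", "seq", "value")
    val grouped = df.groupBy($"key").as[String, (String, Int, Int)]
    
    // The typed handle supports the usual KVGDS operations, e.g. sorted group maps.
    val rows = grouped
      .flatMapSortedGroups($"seq") { (key, iter) =>
        Iterator(key, iter.mkString(", "))
      }
      .collect()
    ```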
    
    ### How was this patch tested?
    E2E tests
    
    Closes #40796 from zhenlineo/typed-agg.
    
    Authored-by: Zhen Li <zhenli...@users.noreply.github.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  66 +++--
 .../apache/spark/sql/KeyValueGroupedDataset.scala  | 255 ++++++++++++++++--
 .../spark/sql/RelationalGroupedDataset.scala       |  14 +-
 .../sql/KeyValueGroupedDatasetE2ETestSuite.scala   | 290 ++++++++++++++++++---
 .../sql/UserDefinedFunctionE2ETestSuite.scala      |  18 ++
 .../CheckConnectJvmClientCompatibility.scala       |   8 -
 .../spark/sql/connect/client/util/QueryTest.scala  |  36 ++-
 .../apache/spark/sql/connect/common/UdfUtils.scala |   4 +
 .../sql/connect/planner/SparkConnectPlanner.scala  | 209 +++++++++++----
 .../spark/sql/catalyst/plans/logical/object.scala  |  16 ++
 .../main/scala/org/apache/spark/sql/Column.scala   |  13 +-
 .../apache/spark/sql/KeyValueGroupedDataset.scala  |  15 +-
 .../spark/sql/RelationalGroupedDataset.scala       |  53 ++--
 .../spark/sql/expressions/ReduceAggregator.scala   |   6 +
 .../apache/spark/sql/internal/TypedAggUtils.scala  |  62 +++++
 15 files changed, 883 insertions(+), 182 deletions(-)

diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 555f6c312c5..7a680bde7d3 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1242,10 +1242,7 @@ class Dataset[T] private[sql] (
    */
   @scala.annotation.varargs
   def groupBy(cols: Column*): RelationalGroupedDataset = {
-    new RelationalGroupedDataset(
-      toDF(),
-      cols.map(_.expr),
-      proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
+    new RelationalGroupedDataset(toDF(), cols, 
proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
   }
 
   /**
@@ -1273,10 +1270,45 @@ class Dataset[T] private[sql] (
     val colNames: Seq[String] = col1 +: cols
     new RelationalGroupedDataset(
       toDF(),
-      colNames.map(colName => Column(colName).expr),
+      colNames.map(colName => Column(colName)),
       proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
   }
 
+  /**
+   * (Scala-specific) Reduces the elements of this Dataset using the specified 
binary function.
+   * The given `func` must be commutative and associative or the result may be 
non-deterministic.
+   *
+   * @group action
+   * @since 3.5.0
+   */
+  def reduce(func: (T, T) => T): T = {
+    val udf = ScalarUserDefinedFunction(
+      function = func,
+      inputEncoders = encoder :: encoder :: Nil,
+      outputEncoder = encoder)
+    val reduceExpr = Column.fn("reduce", udf.apply(col("*"), col("*"))).expr
+
+    val result = sparkSession
+      .newDataset(encoder) { builder =>
+        builder.getAggregateBuilder
+          .setInput(plan.getRoot)
+          .addAggregateExpressions(reduceExpr)
+          .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
+      }
+      .collect()
+    assert(result.length == 1)
+    result(0)
+  }
+
+  /**
+   * (Java-specific) Reduces the elements of this Dataset using the specified 
binary function. The
+   * given `func` must be commutative and associative or the result may be 
non-deterministic.
+   *
+   * @group action
+   * @since 3.5.0
+   */
+  def reduce(func: ReduceFunction[T]): T = 
reduce(UdfUtils.mapReduceFuncToScalaFunc(func))
+
   /**
    * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is 
grouped by the given
    * key `func`.
@@ -1285,15 +1317,7 @@ class Dataset[T] private[sql] (
    * @since 3.5.0
    */
   def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = {
-    val kEncoder = encoderFor[K]
-    new KeyValueGroupedDatasetImpl[K, T, K, T](
-      this,
-      sparkSession,
-      plan,
-      kEncoder,
-      kEncoder,
-      func,
-      UdfUtils.identical())
+    KeyValueGroupedDatasetImpl[K, T](this, encoderFor[K], func)
   }
 
   /**
@@ -1327,10 +1351,7 @@ class Dataset[T] private[sql] (
    */
   @scala.annotation.varargs
   def rollup(cols: Column*): RelationalGroupedDataset = {
-    new RelationalGroupedDataset(
-      toDF(),
-      cols.map(_.expr),
-      proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP)
+    new RelationalGroupedDataset(toDF(), cols, 
proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP)
   }
 
   /**
@@ -1360,7 +1381,7 @@ class Dataset[T] private[sql] (
     val colNames: Seq[String] = col1 +: cols
     new RelationalGroupedDataset(
       toDF(),
-      colNames.map(colName => Column(colName).expr),
+      colNames.map(colName => Column(colName)),
       proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP)
   }
 
@@ -1385,10 +1406,7 @@ class Dataset[T] private[sql] (
    */
   @scala.annotation.varargs
   def cube(cols: Column*): RelationalGroupedDataset = {
-    new RelationalGroupedDataset(
-      toDF(),
-      cols.map(_.expr),
-      proto.Aggregate.GroupType.GROUP_TYPE_CUBE)
+    new RelationalGroupedDataset(toDF(), cols, 
proto.Aggregate.GroupType.GROUP_TYPE_CUBE)
   }
 
   /**
@@ -1417,7 +1435,7 @@ class Dataset[T] private[sql] (
     val colNames: Seq[String] = col1 +: cols
     new RelationalGroupedDataset(
       toDF(),
-      colNames.map(colName => Column(colName).expr),
+      colNames.map(colName => Column(colName)),
       proto.Aggregate.GroupType.GROUP_TYPE_CUBE)
   }
 
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 2d712bc4c51..7b2fa3b52be 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -25,6 +25,7 @@ import scala.language.existentials
 import org.apache.spark.api.java.function._
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder
 import org.apache.spark.sql.connect.common.UdfUtils
 import org.apache.spark.sql.expressions.ScalarUserDefinedFunction
 import org.apache.spark.sql.functions.col
@@ -239,6 +240,153 @@ abstract class KeyValueGroupedDataset[K, V] private[sql] 
() extends Serializable
     mapGroups(UdfUtils.mapGroupsFuncToScalaFunc(f))(encoder)
   }
 
+  /**
+   * (Scala-specific) Reduces the elements of each group of data using the 
specified binary
+   * function. The given function must be commutative and associative or the 
result may be
+   * non-deterministic.
+   *
+   * @since 3.5.0
+   */
+  def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * (Java-specific) Reduces the elements of each group of data using the 
specified binary
+   * function. The given function must be commutative and associative or the 
result may be
+   * non-deterministic.
+   *
+   * @since 3.5.0
+   */
+  def reduceGroups(f: ReduceFunction[V]): Dataset[(K, V)] = {
+    reduceGroups(UdfUtils.mapReduceFuncToScalaFunc(f))
+  }
+
+  /**
+   * Internal helper function for building typed aggregations that return 
tuples. For simplicity
+   * and code reuse, we do this without the help of the type system and then 
use helper functions
+   * that cast appropriately for the user facing interface.
+   */
+  protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * Computes the given aggregation, returning a [[Dataset]] of tuples for 
each unique key and the
+   * result of computing this aggregation over all elements in the group.
+   *
+   * @since 3.5.0
+   */
+  def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] =
+    aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]]
+
+  /**
+   * Computes the given aggregations, returning a [[Dataset]] of tuples for 
each unique key and
+   * the result of computing these aggregations over all elements in the group.
+   *
+   * @since 3.5.0
+   */
+  def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): 
Dataset[(K, U1, U2)] =
+    aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]]
+
+  /**
+   * Computes the given aggregations, returning a [[Dataset]] of tuples for 
each unique key and
+   * the result of computing these aggregations over all elements in the group.
+   *
+   * @since 3.5.0
+   */
+  def agg[U1, U2, U3](
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] =
+    aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]]
+
+  /**
+   * Computes the given aggregations, returning a [[Dataset]] of tuples for 
each unique key and
+   * the result of computing these aggregations over all elements in the group.
+   *
+   * @since 3.5.0
+   */
+  def agg[U1, U2, U3, U4](
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3],
+      col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] =
+    aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, 
U4)]]
+
+  /**
+   * Computes the given aggregations, returning a [[Dataset]] of tuples for 
each unique key and
+   * the result of computing these aggregations over all elements in the group.
+   *
+   * @since 3.5.0
+   */
+  def agg[U1, U2, U3, U4, U5](
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3],
+      col4: TypedColumn[V, U4],
+      col5: TypedColumn[V, U5]): Dataset[(K, U1, U2, U3, U4, U5)] =
+    aggUntyped(col1, col2, col3, col4, col5).asInstanceOf[Dataset[(K, U1, U2, 
U3, U4, U5)]]
+
+  /**
+   * Computes the given aggregations, returning a [[Dataset]] of tuples for 
each unique key and
+   * the result of computing these aggregations over all elements in the group.
+   *
+   * @since 3.5.0
+   */
+  def agg[U1, U2, U3, U4, U5, U6](
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3],
+      col4: TypedColumn[V, U4],
+      col5: TypedColumn[V, U5],
+      col6: TypedColumn[V, U6]): Dataset[(K, U1, U2, U3, U4, U5, U6)] =
+    aggUntyped(col1, col2, col3, col4, col5, col6)
+      .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6)]]
+
+  /**
+   * Computes the given aggregations, returning a [[Dataset]] of tuples for 
each unique key and
+   * the result of computing these aggregations over all elements in the group.
+   *
+   * @since 3.5.0
+   */
+  def agg[U1, U2, U3, U4, U5, U6, U7](
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3],
+      col4: TypedColumn[V, U4],
+      col5: TypedColumn[V, U5],
+      col6: TypedColumn[V, U6],
+      col7: TypedColumn[V, U7]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7)] =
+    aggUntyped(col1, col2, col3, col4, col5, col6, col7)
+      .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7)]]
+
+  /**
+   * Computes the given aggregations, returning a [[Dataset]] of tuples for 
each unique key and
+   * the result of computing these aggregations over all elements in the group.
+   *
+   * @since 3.5.0
+   */
+  def agg[U1, U2, U3, U4, U5, U6, U7, U8](
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3],
+      col4: TypedColumn[V, U4],
+      col5: TypedColumn[V, U5],
+      col6: TypedColumn[V, U6],
+      col7: TypedColumn[V, U7],
+      col8: TypedColumn[V, U8]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)] =
+    aggUntyped(col1, col2, col3, col4, col5, col6, col7, col8)
+      .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)]]
+
+  /**
+   * Returns a [[Dataset]] that contains a tuple with each key and the number 
of items present for
+   * that key.
+   *
+   * @since 3.5.0
+   */
+  def count(): Dataset[(K, Long)] = agg(functions.count("*"))
+
   /**
    * (Scala-specific) Applies the given function to each cogrouped data. For 
each unique group,
    * the function will be passed the grouping key and 2 iterators containing 
all elements in the
@@ -322,41 +470,45 @@ abstract class KeyValueGroupedDataset[K, V] private[sql] 
() extends Serializable
  * [[KeyValueGroupedDataset]] behaves on the server.
  */
 private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
-    private val ds: Dataset[IV],
     private val sparkSession: SparkSession,
     private val plan: proto.Plan,
     private val ikEncoder: AgnosticEncoder[IK],
     private val kEncoder: AgnosticEncoder[K],
-    private val groupingFunc: IV => IK,
-    private val valueMapFunc: IV => V)
+    private val ivEncoder: AgnosticEncoder[IV],
+    private val vEncoder: AgnosticEncoder[V],
+    private val groupingExprs: java.util.List[proto.Expression],
+    private val valueMapFunc: IV => V,
+    private val keysFunc: () => Dataset[IK])
     extends KeyValueGroupedDataset[K, V] {
 
-  private val ivEncoder = ds.encoder
-
   override def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] = {
     new KeyValueGroupedDatasetImpl[L, V, IK, IV](
-      ds,
       sparkSession,
       plan,
       ikEncoder,
       encoderFor[L],
-      groupingFunc,
-      valueMapFunc)
+      ivEncoder,
+      vEncoder,
+      groupingExprs,
+      valueMapFunc,
+      keysFunc)
   }
 
   override def mapValues[W: Encoder](valueFunc: V => W): 
KeyValueGroupedDataset[K, W] = {
     new KeyValueGroupedDatasetImpl[K, W, IK, IV](
-      ds,
       sparkSession,
       plan,
       ikEncoder,
       kEncoder,
-      groupingFunc,
-      valueMapFunc.andThen(valueFunc))
+      ivEncoder,
+      encoderFor[W],
+      groupingExprs,
+      valueMapFunc.andThen(valueFunc),
+      keysFunc)
   }
 
   override def keys: Dataset[K] = {
-    ds.map(groupingFunc)(ikEncoder)
+    keysFunc()
       .dropDuplicates()
       .as(kEncoder)
   }
@@ -371,7 +523,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
       builder.getGroupMapBuilder
         .setInput(plan.getRoot)
         .addAllSortingExpressions(sortExprs.map(e => e.expr).asJava)
-        .addAllGroupingExpressions(getGroupingExpressions)
+        .addAllGroupingExpressions(groupingExprs)
         .setFunc(getUdf(nf, outputEncoder)(ivEncoder))
     }
   }
@@ -387,21 +539,37 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
     sparkSession.newDataset[R](outputEncoder) { builder =>
       builder.getCoGroupMapBuilder
         .setInput(plan.getRoot)
-        .addAllInputGroupingExpressions(getGroupingExpressions)
+        .addAllInputGroupingExpressions(groupingExprs)
         .addAllInputSortingExpressions(thisSortExprs.map(e => e.expr).asJava)
         .setOther(otherImpl.plan.getRoot)
-        .addAllOtherGroupingExpressions(otherImpl.getGroupingExpressions)
+        .addAllOtherGroupingExpressions(otherImpl.groupingExprs)
         .addAllOtherSortingExpressions(otherSortExprs.map(e => e.expr).asJava)
         .setFunc(getUdf(nf, outputEncoder)(ivEncoder, otherImpl.ivEncoder))
     }
   }
 
-  private def getGroupingExpressions = {
-    val gf = ScalarUserDefinedFunction(
-      function = groupingFunc,
-      inputEncoders = ivEncoder :: Nil, // Using the original value and key 
encoders
-      outputEncoder = ikEncoder)
-    Arrays.asList(gf.apply(col("*")).expr)
+  override protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = 
{
+    // TODO(SPARK-43415): For each column, apply the valueMap func first
+    val rEnc = ProductEncoder.tuple(kEncoder +: columns.map(_.encoder)) // 
apply keyAs change
+    sparkSession.newDataset(rEnc) { builder =>
+      builder.getAggregateBuilder
+        .setInput(plan.getRoot)
+        .setGroupType(proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY)
+        .addAllGroupingExpressions(groupingExprs)
+        .addAllAggregateExpressions(columns.map(_.expr).asJava)
+    }
+  }
+
+  override def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = {
+    val inputEncoders = Seq(vEncoder, vEncoder)
+    val udf = ScalarUserDefinedFunction(
+      function = f,
+      inputEncoders = inputEncoders,
+      outputEncoder = vEncoder)
+    val input = udf.apply(inputEncoders.map(_ => col("*")): _*)
+    val expr = Column.fn("reduce", input).expr
+    val aggregator: TypedColumn[V, V] = new TypedColumn[V, V](expr, vEncoder)
+    agg(aggregator)
   }
 
   private def getUdf[U: Encoder](nf: AnyRef, outputEncoder: 
AgnosticEncoder[U])(
@@ -414,3 +582,48 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
     udf.apply(inputEncoders.map(_ => col("*")): 
_*).expr.getCommonInlineUserDefinedFunction
   }
 }
+
+private object KeyValueGroupedDatasetImpl {
+  def apply[K, V](
+      ds: Dataset[V],
+      kEncoder: AgnosticEncoder[K],
+      groupingFunc: V => K): KeyValueGroupedDatasetImpl[K, V, K, V] = {
+    val gf = ScalarUserDefinedFunction(
+      function = groupingFunc,
+      inputEncoders = ds.encoder :: Nil, // Using the original value and key 
encoders
+      outputEncoder = kEncoder)
+    new KeyValueGroupedDatasetImpl(
+      ds.sparkSession,
+      ds.plan,
+      kEncoder,
+      kEncoder,
+      ds.encoder,
+      ds.encoder,
+      Arrays.asList(gf.apply(col("*")).expr),
+      UdfUtils.identical(),
+      () => ds.map(groupingFunc)(kEncoder))
+  }
+
+  def apply[K, V](
+      df: DataFrame,
+      kEncoder: AgnosticEncoder[K],
+      vEncoder: AgnosticEncoder[V],
+      groupingExprs: Seq[Column]): KeyValueGroupedDatasetImpl[K, V, K, V] = {
+    // Use a dummy udf to pass the K V encoders
+    val dummyGroupingFunc = ScalarUserDefinedFunction(
+      function = UdfUtils.noOp[V, K](),
+      inputEncoders = vEncoder :: Nil,
+      outputEncoder = kEncoder).apply(col("*"))
+
+    new KeyValueGroupedDatasetImpl(
+      df.sparkSession,
+      df.plan,
+      kEncoder,
+      kEncoder,
+      vEncoder,
+      vEncoder,
+      (Seq(dummyGroupingFunc) ++ groupingExprs).map(_.expr).asJava,
+      UdfUtils.identical(),
+      () => df.select(groupingExprs: _*).as(kEncoder))
+  }
+}
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 5a10e1d52eb..c19314a0d5c 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -37,7 +37,7 @@ import org.apache.spark.connect.proto
  */
 class RelationalGroupedDataset private[sql] (
     private[sql] val df: DataFrame,
-    private[sql] val groupingExprs: Seq[proto.Expression],
+    private[sql] val groupingExprs: Seq[Column],
     groupType: proto.Aggregate.GroupType,
     pivot: Option[proto.Aggregate.Pivot] = None) {
 
@@ -45,7 +45,7 @@ class RelationalGroupedDataset private[sql] (
     df.sparkSession.newDataFrame { builder =>
       builder.getAggregateBuilder
         .setInput(df.plan.getRoot)
-        .addAllGroupingExpressions(groupingExprs.asJava)
+        .addAllGroupingExpressions(groupingExprs.map(_.expr).asJava)
         .addAllAggregateExpressions(aggExprs.map(e => e.expr).asJava)
 
       groupType match {
@@ -65,6 +65,16 @@ class RelationalGroupedDataset private[sql] (
     }
   }
 
+  /**
+   * Returns a `KeyValueGroupedDataset` where the data is grouped by the grouping
+   * expressions of the current `RelationalGroupedDataset`.
+   *
+   * @since 3.5.0
+   */
+  def as[K: Encoder, T: Encoder]: KeyValueGroupedDataset[K, T] = {
+    KeyValueGroupedDatasetImpl[K, T](df, encoderFor[K], encoderFor[T], 
groupingExprs)
+  }
+
   /**
    * (Scala-specific) Compute aggregates by specifying the column names and 
aggregate methods. The
    * resulting `DataFrame` will also contain the grouping columns.
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
index 097efa01a42..e7a77eed70d 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
@@ -20,13 +20,17 @@ import java.util.Arrays
 
 import io.grpc.StatusRuntimeException
 
-import org.apache.spark.sql.connect.client.util.RemoteSparkSession
+import org.apache.spark.sql.connect.client.util.QueryTest
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
 
 /**
- * All tests in this class requires client UDF artifacts synced with the 
server. TODO: It means
- * these tests only works with SBT for now.
+ * All tests in this class requires client UDF artifacts synced with the 
server.
  */
-class KeyValueGroupedDatasetE2ETestSuite extends RemoteSparkSession {
+class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper {
+
+  lazy val session: SparkSession = spark
+  import session.implicits._
 
   test("mapGroups") {
     val session: SparkSession = spark
@@ -40,8 +44,6 @@ class KeyValueGroupedDatasetE2ETestSuite extends 
RemoteSparkSession {
   }
 
   test("flatGroupMap") {
-    val session: SparkSession = spark
-    import session.implicits._
     val values = spark
       .range(10)
       .groupByKey(v => v % 2)
@@ -51,8 +53,6 @@ class KeyValueGroupedDatasetE2ETestSuite extends 
RemoteSparkSession {
   }
 
   test("keys") {
-    val session: SparkSession = spark
-    import session.implicits._
     val values = spark
       .range(10)
       .groupByKey(v => v % 2)
@@ -63,8 +63,6 @@ class KeyValueGroupedDatasetE2ETestSuite extends 
RemoteSparkSession {
 
   test("keyAs - keys") {
     // It is okay to cast from Long to Double, but not Long to Int.
-    val session: SparkSession = spark
-    import session.implicits._
     val values = spark
       .range(10)
       .groupByKey(v => v % 2)
@@ -75,8 +73,6 @@ class KeyValueGroupedDatasetE2ETestSuite extends 
RemoteSparkSession {
   }
 
   test("keyAs - flatGroupMap") {
-    val session: SparkSession = spark
-    import session.implicits._
     val values = spark
       .range(10)
       .groupByKey(v => v % 2)
@@ -87,8 +83,6 @@ class KeyValueGroupedDatasetE2ETestSuite extends 
RemoteSparkSession {
   }
 
   test("keyAs mapValues - cogroup") {
-    val session: SparkSession = spark
-    import session.implicits._
     val grouped = spark
       .range(10)
       .groupByKey(v => v % 2)
@@ -120,8 +114,6 @@ class KeyValueGroupedDatasetE2ETestSuite extends 
RemoteSparkSession {
   }
 
   test("mapValues - flatGroupMap") {
-    val session: SparkSession = spark
-    import session.implicits._
     val values = spark
       .range(10)
       .groupByKey(v => v % 2)
@@ -132,8 +124,6 @@ class KeyValueGroupedDatasetE2ETestSuite extends 
RemoteSparkSession {
   }
 
   test("mapValues - keys") {
-    val session: SparkSession = spark
-    import session.implicits._
     val values = spark
       .range(10)
       .groupByKey(v => v % 2)
@@ -144,13 +134,11 @@ class KeyValueGroupedDatasetE2ETestSuite extends 
RemoteSparkSession {
   }
 
   test("flatMapSortedGroups") {
-    val session: SparkSession = spark
-    import session.implicits._
     val grouped = spark
       .range(10)
       .groupByKey(v => v % 2)
     val values = grouped
-      .flatMapSortedGroups(functions.desc("id")) { (g, iter) =>
+      .flatMapSortedGroups(desc("id")) { (g, iter) =>
         Iterator(String.valueOf(g), iter.mkString(","))
       }
       .collectAsList()
@@ -160,7 +148,7 @@ class KeyValueGroupedDatasetE2ETestSuite extends 
RemoteSparkSession {
     // Star is not allowed as group sort column
     val message = intercept[StatusRuntimeException] {
       grouped
-        .flatMapSortedGroups(functions.col("*")) { (g, iter) =>
+        .flatMapSortedGroups(col("*")) { (g, iter) =>
           Iterator(String.valueOf(g), iter.mkString(","))
         }
         .collectAsList()
@@ -169,8 +157,6 @@ class KeyValueGroupedDatasetE2ETestSuite extends 
RemoteSparkSession {
   }
 
   test("cogroup") {
-    val session: SparkSession = spark
-    import session.implicits._
     val grouped = spark
       .range(10)
       .groupByKey(v => v % 2)
@@ -187,8 +173,6 @@ class KeyValueGroupedDatasetE2ETestSuite extends 
RemoteSparkSession {
   }
 
   test("cogroupSorted") {
-    val session: SparkSession = spark
-    import session.implicits._
     val grouped = spark
       .range(10)
       .groupByKey(v => v % 2)
@@ -196,9 +180,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends 
RemoteSparkSession {
       .range(10)
       .groupByKey(v => v / 2)
     val values = grouped
-      .cogroupSorted(otherGrouped)(functions.desc("id"))(functions.desc("id")) 
{
-        (k, it, otherIt) =>
-          Iterator(String.valueOf(k), it.mkString(",") + ";" + 
otherIt.mkString(","))
+      .cogroupSorted(otherGrouped)(desc("id"))(desc("id")) { (k, it, otherIt) 
=>
+        Iterator(String.valueOf(k), it.mkString(",") + ";" + 
otherIt.mkString(","))
       }
       .collectAsList()
 
@@ -215,4 +198,253 @@ class KeyValueGroupedDatasetE2ETestSuite extends 
RemoteSparkSession {
         "4",
         ";9,8"))
   }
+
+  test("agg, keyAs") {
+    val ds = spark
+      .range(10)
+      .groupByKey(v => v % 2)
+      .keyAs[Double]
+      .agg(count("*"))
+
+    checkDatasetUnorderly(ds, (0.0, 5L), (1.0, 5L))
+  }
+
+  test("typed aggregation: expr") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+    checkDatasetUnorderly(
+      ds.groupByKey(_._1).agg(sum("_2").as[Long]),
+      ("a", 30L),
+      ("b", 3L),
+      ("c", 1L))
+  }
+
+  test("typed aggregation: expr, expr") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+    checkDatasetUnorderly(
+      ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]),
+      ("a", 30L, 32L),
+      ("b", 3L, 5L),
+      ("c", 1L, 2L))
+  }
+
+  test("typed aggregation: expr, expr, expr") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+    checkDatasetUnorderly(
+      ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], 
count("*")),
+      ("a", 30L, 32L, 2L),
+      ("b", 3L, 5L, 2L),
+      ("c", 1L, 2L, 1L))
+  }
+
+  test("typed aggregation: expr, expr, expr, expr") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+    checkDatasetUnorderly(
+      ds.groupByKey(_._1)
+        .agg(
+          sum("_2").as[Long],
+          sum($"_2" + 1).as[Long],
+          count("*").as[Long],
+          avg("_2").as[Double]),
+      ("a", 30L, 32L, 2L, 15.0),
+      ("b", 3L, 5L, 2L, 1.5),
+      ("c", 1L, 2L, 1L, 1.0))
+  }
+
+  test("typed aggregation: expr, expr, expr, expr, expr") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+    checkDatasetUnorderly(
+      ds.groupByKey(_._1)
+        .agg(
+          sum("_2").as[Long],
+          sum($"_2" + 1).as[Long],
+          count("*").as[Long],
+          avg("_2").as[Double],
+          countDistinct("*").as[Long]),
+      ("a", 30L, 32L, 2L, 15.0, 2L),
+      ("b", 3L, 5L, 2L, 1.5, 2L),
+      ("c", 1L, 2L, 1L, 1.0, 1L))
+  }
+
+  test("typed aggregation: expr, expr, expr, expr, expr, expr") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+    checkDatasetUnorderly(
+      ds.groupByKey(_._1)
+        .agg(
+          sum("_2").as[Long],
+          sum($"_2" + 1).as[Long],
+          count("*").as[Long],
+          avg("_2").as[Double],
+          countDistinct("*").as[Long],
+          max("_2").as[Long]),
+      ("a", 30L, 32L, 2L, 15.0, 2L, 20L),
+      ("b", 3L, 5L, 2L, 1.5, 2L, 2L),
+      ("c", 1L, 2L, 1L, 1.0, 1L, 1L))
+  }
+
+  test("typed aggregation: expr, expr, expr, expr, expr, expr, expr") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+    checkDatasetUnorderly(
+      ds.groupByKey(_._1)
+        .agg(
+          sum("_2").as[Long],
+          sum($"_2" + 1).as[Long],
+          count("*").as[Long],
+          avg("_2").as[Double],
+          countDistinct("*").as[Long],
+          max("_2").as[Long],
+          min("_2").as[Long]),
+      ("a", 30L, 32L, 2L, 15.0, 2L, 20L, 10L),
+      ("b", 3L, 5L, 2L, 1.5, 2L, 2L, 1L),
+      ("c", 1L, 2L, 1L, 1.0, 1L, 1L, 1L))
+  }
+
+  test("typed aggregation: expr, expr, expr, expr, expr, expr, expr, expr") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+    checkDatasetUnorderly(
+      ds.groupByKey(_._1)
+        .agg(
+          sum("_2").as[Long],
+          sum($"_2" + 1).as[Long],
+          count("*").as[Long],
+          avg("_2").as[Double],
+          countDistinct("*").as[Long],
+          max("_2").as[Long],
+          min("_2").as[Long],
+          mean("_2").as[Double]),
+      ("a", 30L, 32L, 2L, 15.0, 2L, 20L, 10L, 15.0),
+      ("b", 3L, 5L, 2L, 1.5, 2L, 2L, 1L, 1.5),
+      ("c", 1L, 2L, 1L, 1.0, 1L, 1L, 1L, 1.0))
+  }
+
+  test("SPARK-24762: Enable top-level Option of Product encoders") {
+    val data = Seq(Some((1, "a")), Some((2, "b")), None)
+    val ds = data.toDS()
+
+    checkDataset(ds, data: _*)
+
+    val schema = new StructType().add(
+      "value",
+      new StructType()
+        .add("_1", IntegerType, nullable = false)
+        .add("_2", StringType, nullable = true),
+      nullable = true)
+
+    assert(ds.schema == schema)
+
+    val nestedOptData = Seq(Some((Some((1, "a")), 2.0)), Some((Some((2, "b")), 
3.0)))
+    val nestedDs = nestedOptData.toDS()
+
+    checkDataset(nestedDs, nestedOptData: _*)
+
+    val nestedSchema = StructType(
+      Seq(StructField(
+        "value",
+        StructType(Seq(
+          StructField(
+            "_1",
+            StructType(Seq(
+              StructField("_1", IntegerType, nullable = false),
+              StructField("_2", StringType, nullable = true)))),
+          StructField("_2", DoubleType, nullable = false))),
+        nullable = true)))
+    assert(nestedDs.schema == nestedSchema)
+  }
+
+  test("SPARK-24762: Resolving Option[Product] field") {
+    val ds = Seq((1, ("a", 1.0)), (2, ("b", 2.0)), (3, null))
+      .toDS()
+      .as[(Int, Option[(String, Double)])]
+    checkDataset(ds, (1, Some(("a", 1.0))), (2, Some(("b", 2.0))), (3, None))
+  }
+
+  test("SPARK-24762: select Option[Product] field") {
+    val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
+    val ds1 = ds.select(expr("struct(_2, _2 + 1)").as[Option[(Int, Int)]])
+    checkDataset(ds1, Some((1, 2)), Some((2, 3)), Some((3, 4)))
+
+    val ds2 = ds.select(expr("if(_2 > 2, struct(_2, _2 + 1), 
null)").as[Option[(Int, Int)]])
+    checkDataset(ds2, None, None, Some((3, 4)))
+  }
+
+  test("SPARK-24762: typed agg on Option[Product] type") {
+    val ds = Seq(Some((1, 2)), Some((2, 3)), Some((1, 3))).toDS()
+    assert(ds.groupByKey(_.get._1).count().collect() === Seq((1, 2), (2, 1)))
+
+    assert(
+      ds.groupByKey(x => x).count().collect() ===
+        Seq((Some((1, 2)), 1), (Some((2, 3)), 1), (Some((1, 3)), 1)))
+  }
+
+  test("SPARK-25942: typed aggregation on primitive type") {
+    val ds = Seq(1, 2, 3).toDS()
+
+    val agg = ds
+      .groupByKey(_ >= 2)
+      .agg(sum("value").as[Long], sum($"value" + 1).as[Long])
+    checkDatasetUnorderly(agg, (false, 1L, 2L), (true, 5L, 7L))
+  }
+
+  test("SPARK-25942: typed aggregation on product type") {
+    val ds = Seq((1, 2), (2, 3), (3, 4)).toDS()
+    val agg = ds.groupByKey(x => x).agg(sum("_1").as[Long], sum($"_2" + 
1).as[Long])
+    checkDatasetUnorderly(agg, ((1, 2), 1L, 3L), ((2, 3), 2L, 4L), ((3, 4), 
3L, 5L))
+  }
+
+  test("SPARK-26085: fix key attribute name for atomic type for typed 
aggregation") {
+    // TODO(SPARK-43416): Recursively rename the position based tuple to the 
schema name from the
+    //  server.
+    val ds = Seq(1, 2, 3).toDS()
+    assert(ds.groupByKey(x => x).count().schema.head.name == "_1")
+
+    // Enable legacy flag to follow previous Spark behavior
+    withSQLConf("spark.sql.legacy.dataset.nameNonStructGroupingKeyAsValue" -> 
"true") {
+      assert(ds.groupByKey(x => x).count().schema.head.name == "_1")
+    }
+  }
+
+  test("reduceGroups") {
+    val ds = Seq("abc", "xyz", "hello").toDS()
+    checkDatasetUnorderly(
+      ds.groupByKey(_.length).reduceGroups(_ + _),
+      (3, "abcxyz"),
+      (5, "hello"))
+  }
+
+  test("groupby") {
+    val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 
1, 1))
+      .toDF("key", "seq", "value")
+    val grouped = ds.groupBy($"key").as[String, (String, Int, Int)]
+    val aggregated = grouped
+      .flatMapSortedGroups($"seq", expr("length(key)"), $"value") { (g, iter) 
=>
+        Iterator(g, iter.mkString(", "))
+      }
+
+    checkDatasetUnorderly(
+      aggregated,
+      "a",
+      "(a,1,10), (a,2,20)",
+      "b",
+      "(b,1,2), (b,2,1)",
+      "c",
+      "(c,1,1)")
+  }
+
+  test("groupby - keyAs, keys") {
+    val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 
1, 1))
+      .toDF("key", "seq", "value")
+    val grouped = ds.groupBy($"value").as[String, (String, Int, Int)]
+    val keys = grouped.keyAs[String].keys.sort($"value")
+
+    checkDataset(keys, "1", "2", "10", "20")
+  }
 }
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
index b07d1459df5..b5bbee67803 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
@@ -197,4 +197,22 @@ class UserDefinedFunctionE2ETestSuite extends 
RemoteSparkSession {
     spark.range(10).repartition(1).foreachPartition(func)
     assert(sum.get() == 0) // The value is not 45
   }
+
+  test("Dataset reduce") {
+    val session: SparkSession = spark
+    import session.implicits._
+    assert(spark.range(10).map(_ + 1).reduce(_ + _) == 55)
+  }
+
+  test("Dataset reduce - java") {
+    val session: SparkSession = spark
+    import session.implicits._
+    assert(
+      spark
+        .range(10)
+        .map(_ + 1)
+        .reduce(new ReduceFunction[Long] {
+          override def call(v1: Long, v2: Long): Long = v1 + v2
+        }) == 55)
+  }
 }
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 28a28994a76..32a44c350d9 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -165,7 +165,6 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.joinWith"),
       
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.metadataColumn"),
       
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.selectUntyped"), 
// protected
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.reduce"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.explode"), 
// deprecated
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.rdd"),
       
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.toJavaRDD"),
@@ -190,17 +189,10 @@ object CheckConnectJvmClientCompatibility {
       ), // streaming
       ProblemFilters.exclude[Problem](
         "org.apache.spark.sql.KeyValueGroupedDataset.queryExecution"),
-      
ProblemFilters.exclude[Problem]("org.apache.spark.sql.KeyValueGroupedDataset.reduceGroups"),
-      
ProblemFilters.exclude[Problem]("org.apache.spark.sql.KeyValueGroupedDataset.agg"),
-      ProblemFilters.exclude[Problem](
-        "org.apache.spark.sql.KeyValueGroupedDataset.aggUntyped"
-      ), // protected internal
-      
ProblemFilters.exclude[Problem]("org.apache.spark.sql.KeyValueGroupedDataset.count"),
       
ProblemFilters.exclude[Problem]("org.apache.spark.sql.KeyValueGroupedDataset.this"),
 
       // RelationalGroupedDataset
       
ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.apply"),
-      
ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.as"),
       
ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.this"),
 
       // SparkSession
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/QueryTest.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/QueryTest.scala
index 1c3f49f897f..fdbb3edbf84 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/QueryTest.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/QueryTest.scala
@@ -21,7 +21,7 @@ import java.util.TimeZone
 
 import org.scalatest.Assertions
 
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.catalyst.util.sideBySide
 
 abstract class QueryTest extends RemoteSparkSession {
@@ -45,6 +45,40 @@ abstract class QueryTest extends RemoteSparkSession {
   protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit 
= {
     checkAnswer(df, expectedAnswer.collect())
   }
+
+  /**
+   * Evaluates a dataset to make sure that the result of calling collect 
matches the given
+   * expected answer.
+   */
+  protected def checkDataset[T](ds: => Dataset[T], expectedAnswer: T*): Unit = 
{
+    val result = ds.collect()
+
+    if (!QueryTest.compare(result.toSeq, expectedAnswer)) {
+      fail(s"""
+              |Decoded objects do not match expected objects:
+              |expected: $expectedAnswer
+              |actual:   ${result.toSeq}
+       """.stripMargin)
+    }
+  }
+
+  /**
+   * Evaluates a dataset to make sure that the result of calling collect 
matches the given
+   * expected answer, after sort.
+   */
+  protected def checkDatasetUnorderly[T: Ordering](
+      ds: => Dataset[T],
+      expectedAnswer: T*): Unit = {
+    val result = ds.collect()
+
+    if (!QueryTest.compare(result.toSeq.sorted, expectedAnswer.sorted)) {
+      fail(s"""
+              |Decoded objects do not match expected objects:
+              |expected: $expectedAnswer
+              |actual:   ${result.toSeq}
+       """.stripMargin)
+    }
+  }
 }
 
 object QueryTest extends Assertions {
diff --git 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
index 7cd251b245f..06a6c74f268 100644
--- 
a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
+++ 
b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
@@ -95,5 +95,9 @@ private[sql] object UdfUtils extends Serializable {
       }
   }
 
+  def mapReduceFuncToScalaFunc[T](func: ReduceFunction[T]): (T, T) => T = 
func.call
+
   def identical[T](): T => T = t => t
+
+  def noOp[V, K](): V => K = _ => null.asInstanceOf[K]
 }
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index b86ed866d6e..8562722a95b 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -40,7 +40,7 @@ import 
org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.Streami
 import org.apache.spark.connect.proto.WriteStreamOperationStart
 import org.apache.spark.connect.proto.WriteStreamOperationStart.TriggerCase
 import org.apache.spark.ml.{functions => MLFunctions}
-import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession}
+import org.apache.spark.sql.{Column, Dataset, Encoders, 
RelationalGroupedDataset, SparkSession}
 import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, 
FunctionIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, 
MultiAlias, ParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, 
UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, 
UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
@@ -67,7 +67,8 @@ import 
org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartiti
 import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
 import org.apache.spark.sql.execution.stat.StatFunctions
 import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper
-import org.apache.spark.sql.internal.CatalogImpl
+import org.apache.spark.sql.expressions.ReduceAggregator
+import org.apache.spark.sql.internal.{CatalogImpl, TypedAggUtils}
 import org.apache.spark.sql.streaming.Trigger
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -526,7 +527,7 @@ class SparkConnectPlanner(val session: SparkSession) {
   private def transformTypedMapPartitions(
       fun: proto.CommonInlineUserDefinedFunction,
       child: LogicalPlan): LogicalPlan = {
-    val udf = ScalaUdf(fun)
+    val udf = TypedScalaUdf(fun)
     val deserialized = DeserializeToObject(udf.inputDeserializer(), 
udf.inputObjAttr, child)
     val mapped = MapPartitions(
       udf.function.asInstanceOf[Iterator[Any] => Iterator[Any]],
@@ -562,7 +563,7 @@ class SparkConnectPlanner(val session: SparkSession) {
   private def transformTypedGroupMap(
       rel: proto.GroupMap,
       commonUdf: proto.CommonInlineUserDefinedFunction): LogicalPlan = {
-    val udf = ScalaUdf(commonUdf)
+    val udf = TypedScalaUdf(commonUdf)
     val ds = UntypedKeyValueGroupedDataset(
       rel.getInput,
       rel.getGroupingExpressionsList,
@@ -614,7 +615,7 @@ class SparkConnectPlanner(val session: SparkSession) {
   private def transformTypedCoGroupMap(
       rel: proto.CoGroupMap,
       commonUdf: proto.CommonInlineUserDefinedFunction): LogicalPlan = {
-    val udf = ScalaUdf(commonUdf)
+    val udf = TypedScalaUdf(commonUdf)
     val left = UntypedKeyValueGroupedDataset(
       rel.getInput,
       rel.getInputGroupingExpressionsList,
@@ -644,57 +645,89 @@ class SparkConnectPlanner(val session: SparkSession) {
   }
 
   /**
-   * This is the untyped version of [[KeyValueGroupedDataset]].
+   * This is the untyped version of 
[[org.apache.spark.sql.KeyValueGroupedDataset]].
    */
   private case class UntypedKeyValueGroupedDataset(
       kEncoder: ExpressionEncoder[_],
       vEncoder: ExpressionEncoder[_],
-      valueDeserializer: Expression,
       analyzed: LogicalPlan,
       dataAttributes: Seq[Attribute],
       groupingAttributes: Seq[Attribute],
-      sortOrder: Seq[SortOrder])
+      sortOrder: Seq[SortOrder]) {
+    val valueDeserializer: Expression =
+      UnresolvedDeserializer(vEncoder.deserializer, dataAttributes)
+  }
+
   private object UntypedKeyValueGroupedDataset {
     def apply(
         input: proto.Relation,
         groupingExprs: java.util.List[proto.Expression],
         sortingExprs: java.util.List[proto.Expression]): 
UntypedKeyValueGroupedDataset = {
-      val logicalPlan = transformRelation(input)
-      assert(groupingExprs.size() == 1)
-      val groupFunc = groupingExprs.asScala.toSeq
-        .map(expr => unpackUdf(expr.getCommonInlineUserDefinedFunction))
-        .head
-
-      assert(groupFunc.inputEncoders.size == 1)
-      val vEnc = ExpressionEncoder(groupFunc.inputEncoders.head)
-      val kEnc = ExpressionEncoder(groupFunc.outputEncoder)
-
-      val withGroupingKey = new AppendColumns(
-        groupFunc.function.asInstanceOf[Any => Any],
-        vEnc.clsTag.runtimeClass,
-        vEnc.schema,
-        UnresolvedDeserializer(vEnc.deserializer),
-        kEnc.namedExpressions,
-        logicalPlan)
-
-      // The input logical plan of KeyValueGroupedDataset need to be executed 
and analyzed
-      val analyzed = session.sessionState.executePlan(withGroupingKey).analyzed
-      val dataAttributes = logicalPlan.output
-      val groupingAttributes = withGroupingKey.newColumns
-      val valueDeserializer = UnresolvedDeserializer(vEnc.deserializer, 
dataAttributes)
 
       // Compute sort order
       val sortExprs =
         sortingExprs.asScala.toSeq.map(expr => transformExpression(expr))
       val sortOrder: Seq[SortOrder] = MapGroups.sortOrder(sortExprs)
 
+      apply(transformRelation(input), groupingExprs, sortOrder)
+    }
+
+    def apply(
+        logicalPlan: LogicalPlan,
+        groupingExprs: java.util.List[proto.Expression],
+        sortOrder: Seq[SortOrder]): UntypedKeyValueGroupedDataset = {
+      // If created via ds#groupByKey, then there should be only one 
groupingFunc.
+      // If created via relationalGroupedDS#as, then we are expecting a dummy groupingFunc
+      // (for types) + groupingExprs.
+      if (groupingExprs.size() == 1) {
+        createFromGroupByKeyFunc(logicalPlan, groupingExprs, sortOrder)
+      } else if (groupingExprs.size() > 1) {
+        createFromRelationalDataset(logicalPlan, groupingExprs, sortOrder)
+      } else {
+        throw InvalidPlanInput(
+          "The grouping expression cannot be absent for 
KeyValueGroupedDataset")
+      }
+    }
+
+    private def createFromRelationalDataset(
+        logicalPlan: LogicalPlan,
+        groupingExprs: java.util.List[proto.Expression],
+        sortOrder: Seq[SortOrder]): UntypedKeyValueGroupedDataset = {
+      assert(groupingExprs.size() >= 1)
+      val dummyFunc = TypedScalaUdf(groupingExprs.get(0))
+      val groupExprs = groupingExprs.asScala.toSeq.drop(1).map(expr => 
transformExpression(expr))
+
+      val (qe, aliasedGroupings) =
+        RelationalGroupedDataset.handleGroupingExpression(logicalPlan, 
session, groupExprs)
+
+      UntypedKeyValueGroupedDataset(
+        dummyFunc.outEnc,
+        dummyFunc.inEnc,
+        qe.analyzed,
+        logicalPlan.output,
+        aliasedGroupings,
+        sortOrder)
+    }
+
+    private def createFromGroupByKeyFunc(
+        logicalPlan: LogicalPlan,
+        groupingExprs: java.util.List[proto.Expression],
+        sortOrder: Seq[SortOrder]): UntypedKeyValueGroupedDataset = {
+      assert(groupingExprs.size() == 1)
+      val groupFunc = TypedScalaUdf(groupingExprs.get(0))
+      val vEnc = groupFunc.inEnc
+      val kEnc = groupFunc.outEnc
+
+      val withGroupingKey = AppendColumns(groupFunc.function, vEnc, kEnc, 
logicalPlan)
+      // The input logical plan of KeyValueGroupedDataset needs to be executed and analyzed.
+      val analyzed = session.sessionState.executePlan(withGroupingKey).analyzed
+
       UntypedKeyValueGroupedDataset(
         kEnc,
         vEnc,
-        valueDeserializer,
         analyzed,
-        dataAttributes,
-        groupingAttributes,
+        logicalPlan.output,
+        withGroupingKey.newColumns,
         sortOrder)
     }
   }
@@ -702,7 +735,7 @@ class SparkConnectPlanner(val session: SparkSession) {
   /**
    * The UDF used in typed APIs, where the input column is absent.
    */
-  private case class ScalaUdf(
+  private case class TypedScalaUdf(
       function: AnyRef,
       outEnc: ExpressionEncoder[_],
       outputObjAttr: Attribute,
@@ -713,17 +746,26 @@ class SparkConnectPlanner(val session: SparkSession) {
       UnresolvedDeserializer(inEnc.deserializer, inputAttributes)
     }
   }
-  private object ScalaUdf {
-    def apply(commonUdf: proto.CommonInlineUserDefinedFunction): ScalaUdf = {
+  private object TypedScalaUdf {
+    def apply(expr: proto.Expression): TypedScalaUdf = {
+      if (expr.hasCommonInlineUserDefinedFunction
+        && expr.getCommonInlineUserDefinedFunction.hasScalarScalaUdf) {
+        apply(expr.getCommonInlineUserDefinedFunction)
+      } else {
+        throw InvalidPlanInput(s"Expecting a Scala UDF, but got ${expr.getExprTypeCase}")
+      }
+    }
+
+    def apply(commonUdf: proto.CommonInlineUserDefinedFunction): TypedScalaUdf 
= {
       val udf = unpackUdf(commonUdf)
       val outEnc = ExpressionEncoder(udf.outputEncoder)
       // There might be more than one inputs, but we only interested in the 
first one.
       // Most typed API takes one UDF input.
       // For the few that takes more than one inputs, e.g. grouping function 
mapping UDFs,
-      // we only interested in the first input which is the key of the 
grouping function.
+      // only the first input, which is the key of the grouping function, is used.
       assert(udf.inputEncoders.nonEmpty)
       val inEnc = ExpressionEncoder(udf.inputEncoders.head) // single input 
encoder or key encoder
-      ScalaUdf(udf.function, outEnc, generateObjAttr(outEnc), inEnc, 
generateObjAttr(inEnc))
+      TypedScalaUdf(udf.function, outEnc, generateObjAttr(outEnc), inEnc, 
generateObjAttr(inEnc))
     }
   }
 
@@ -1117,27 +1159,31 @@ class SparkConnectPlanner(val session: SparkSession) {
     assert(rel.hasInput)
     val baseRel = transformRelation(rel.getInput)
     val cond = rel.getCondition
-    cond.getExprTypeCase match {
-      case proto.Expression.ExprTypeCase.COMMON_INLINE_USER_DEFINED_FUNCTION
-          if isTypedFilter(cond.getCommonInlineUserDefinedFunction) =>
-        transformTypedFilter(cond.getCommonInlineUserDefinedFunction, baseRel)
-      case _ =>
-        logical.Filter(condition = transformExpression(cond), child = baseRel)
+    if (isTypedScalaUdfExpr(cond)) {
+      transformTypedFilter(cond.getCommonInlineUserDefinedFunction, baseRel)
+    } else {
+      logical.Filter(condition = transformExpression(cond), child = baseRel)
     }
   }
 
-  private def isTypedFilter(udf: proto.CommonInlineUserDefinedFunction): 
Boolean = {
-    // It is a scala udf && the udf argument is an unresolved start.
-    // This means the udf is a typed filter to filter on all inputs
-    udf.getFunctionCase == 
proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF &&
-    udf.getArgumentsCount == 1 &&
-    udf.getArguments(0).getExprTypeCase == 
proto.Expression.ExprTypeCase.UNRESOLVED_STAR
+  private def isTypedScalaUdfExpr(expr: proto.Expression): Boolean = {
+    expr.getExprTypeCase match {
+      case proto.Expression.ExprTypeCase.COMMON_INLINE_USER_DEFINED_FUNCTION =>
+        val udf = expr.getCommonInlineUserDefinedFunction
+        // A typed scala udf is a scala udf && the udf argument is an unresolved star.
+        udf.getFunctionCase ==
+          proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF 
&&
+          udf.getArgumentsCount == 1 &&
+          udf.getArguments(0).getExprTypeCase == 
proto.Expression.ExprTypeCase.UNRESOLVED_STAR
+      case _ =>
+        false
+    }
   }
 
   private def transformTypedFilter(
       fun: proto.CommonInlineUserDefinedFunction,
       child: LogicalPlan): TypedFilter = {
-    val udf = ScalaUdf(fun)
+    val udf = TypedScalaUdf(fun)
     TypedFilter(udf.function, child)(udf.inEnc)
   }
 
@@ -1853,13 +1899,39 @@ class SparkConnectPlanner(val session: SparkSession) {
   }
 
   private def transformAggregate(rel: proto.Aggregate): LogicalPlan = {
+    rel.getGroupType match {
+      case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY
+          // This relies on the assumption that a KVGDS always has a typed UDF as its
+          // first grouping expression. This is the case for datasets created via
+          // groupByKey, and also via RelationalGroupedDS#as, where the first grouping
+          // expression is currently a dummy UDF.
+          if rel.getGroupingExpressionsList.size() >= 1 &&
+            isTypedScalaUdfExpr(rel.getGroupingExpressionsList.get(0)) =>
+        transformKeyValueGroupedAggregate(rel)
+      case _ =>
+        transformRelationalGroupedAggregate(rel)
+    }
+  }
+
+  private def transformKeyValueGroupedAggregate(rel: proto.Aggregate): 
LogicalPlan = {
+    val input = transformRelation(rel.getInput)
+    val ds = UntypedKeyValueGroupedDataset(input, 
rel.getGroupingExpressionsList, Seq.empty)
+
+    val keyColumn = TypedAggUtils.aggKeyColumn(ds.kEncoder, 
ds.groupingAttributes)
+    val namedColumns = rel.getAggregateExpressionsList.asScala.toSeq
+      .map(expr => transformExpressionWithTypedReduceExpression(expr, input))
+      .map(toNamedExpression)
+    logical.Aggregate(ds.groupingAttributes, keyColumn +: namedColumns, 
ds.analyzed)
+  }
+
+  private def transformRelationalGroupedAggregate(rel: proto.Aggregate): 
LogicalPlan = {
     if (!rel.hasInput) {
       throw InvalidPlanInput("Aggregate needs a plan input")
     }
     val input = transformRelation(rel.getInput)
 
     val groupingExprs = 
rel.getGroupingExpressionsList.asScala.toSeq.map(transformExpression)
-    val aggExprs = 
rel.getAggregateExpressionsList.asScala.toSeq.map(transformExpression)
+    val aggExprs = rel.getAggregateExpressionsList.asScala.toSeq
+      .map(expr => transformExpressionWithTypedReduceExpression(expr, input))
     val aliasedAgg = (groupingExprs ++ aggExprs).map(toNamedExpression)
 
     rel.getGroupType match {
@@ -1917,6 +1989,37 @@ class SparkConnectPlanner(val session: SparkSession) {
     }
   }
 
+  private def transformTypedReduceExpression(
+      fun: proto.Expression.UnresolvedFunction,
+      dataAttributes: Seq[Attribute]): Expression = {
+    assert(fun.getFunctionName == "reduce")
+    if (fun.getArgumentsCount != 1) {
+      throw InvalidPlanInput("reduce requires single child expression")
+    }
+    val udf = fun.getArgumentsList.asScala.toSeq.map(transformExpression) match {
+      case Seq(f: ScalaUDF) =>
+        f
+      case other =>
+        throw InvalidPlanInput(s"reduce should carry a scalar scala udf, but got $other")
+    }
+    assert(udf.outputEncoder.isDefined)
+    val tEncoder = udf.outputEncoder.get // (T, T) => T
+    val reduce = ReduceAggregator(udf.function)(tEncoder).toColumn.expr
+    TypedAggUtils.withInputType(reduce, tEncoder, dataAttributes)
+  }
+
+  private def transformExpressionWithTypedReduceExpression(
+      expr: proto.Expression,
+      plan: LogicalPlan): Expression = {
+    expr.getExprTypeCase match {
+      case proto.Expression.ExprTypeCase.UNRESOLVED_FUNCTION
+          if expr.getUnresolvedFunction.getFunctionName == "reduce" =>
+        // The reduce func needs the input data attributes, so handle it specially here.
+        transformTypedReduceExpression(expr.getUnresolvedFunction, plan.output)
+      case _ => transformExpression(expr)
+    }
+  }
+
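
For reference, a hedged sketch of the client calls that arrive here as the special "reduce" UnresolvedFunction (Connect Scala client, assuming a session `spark`):

    import spark.implicits._

    val ds = spark.range(5).as[Long]
    ds.groupByKey(_ % 2).reduceGroups(_ + _)  // KeyValueGroupedDataset reduce
    ds.reduce(_ + _)                          // Dataset#reduce reuses the same code path
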
   def process(
       command: proto.Command,
       userId: String,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 35640c4aec8..54c0b84ff52 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -337,6 +337,22 @@ object AppendColumns {
       encoderFor[U].namedExpressions,
       child)
   }
+
+  private[sql] def apply(
+      func: AnyRef,
+      inEncoder: ExpressionEncoder[_],
+      outEncoder: ExpressionEncoder[_],
+      child: LogicalPlan,
+      inputAttributes: Seq[Attribute] = Nil): AppendColumns = {
+    new AppendColumns(
+      func.asInstanceOf[Any => Any],
+      inEncoder.clsTag.runtimeClass,
+      inEncoder.schema,
+      UnresolvedDeserializer(inEncoder.deserializer, inputAttributes),
+      outEncoder.namedExpressions,
+      child
+    )
+  }
 }
 
 /**
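
The new AppendColumns factory above takes ExpressionEncoders directly, presumably so callers such as the Connect planner can supply encoders that are only available at runtime instead of TypeTags. A hedged sketch of calling it (internal catalyst API; `child` is a placeholder for any resolved LogicalPlan and an assumption of this example):

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    val appended = AppendColumns(
      ((l: Long) => l % 2): AnyRef,   // key function, passed untyped
      ExpressionEncoder[Long](),      // input encoder
      ExpressionEncoder[Long](),      // output (key) encoder
      child)                          // child: LogicalPlan, assumed to be in scope
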
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 3c9f3e58cec..8d4cb50dfa2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils}
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.internal.TypedAggUtils
 import org.apache.spark.sql.types._
 
 private[sql] object Column {
@@ -81,17 +82,7 @@ class TypedColumn[-T, U](
   private[sql] def withInputType(
       inputEncoder: ExpressionEncoder[_],
       inputAttributes: Seq[Attribute]): TypedColumn[T, U] = {
-    val unresolvedDeserializer = UnresolvedDeserializer(inputEncoder.deserializer, inputAttributes)
-
-    // This only inserts inputs into typed aggregate expressions. For untyped aggregate expressions,
-    // the resolving is handled in the analyzer directly.
-    val newExpr = expr transform {
-      case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty =>
-        ta.withInputInfo(
-          deser = unresolvedDeserializer,
-          cls = inputEncoder.clsTag.runtimeClass,
-          schema = inputEncoder.schema)
-    }
+    val newExpr = TypedAggUtils.withInputType(expr, inputEncoder, inputAttributes)
     new TypedColumn[T, U](newExpr, encoder)
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index a8e78043c28..4c2ccb27eab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -21,11 +21,11 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.api.java.function._
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, CreateStruct, Expression, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.expressions.ReduceAggregator
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.TypedAggUtils
 import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}
 
 /**
@@ -673,16 +673,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
     val encoders = columns.map(_.encoder)
     val namedColumns =
       columns.map(_.withInputType(vExprEnc, dataAttributes).named)
-    val keyColumn = if (!kExprEnc.isSerializedAsStructForTopLevel) {
-      assert(groupingAttributes.length == 1)
-      if (SQLConf.get.nameNonStructGroupingKeyAsValue) {
-        groupingAttributes.head
-      } else {
-        Alias(groupingAttributes.head, "key")()
-      }
-    } else {
-      Alias(CreateStruct(groupingAttributes), "key")()
-    }
+    val keyColumn = TypedAggUtils.aggKeyColumn(kExprEnc, groupingAttributes)
+    val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan)
     val execution = new QueryExecution(sparkSession, aggregate)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 31c303921f3..29138b5bf58 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.catalyst.util.toPrettySQL
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
+import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.{NumericType, StructType}
@@ -53,6 +54,7 @@ class RelationalGroupedDataset protected[sql](
     private[sql] val df: DataFrame,
     private[sql] val groupingExprs: Seq[Expression],
     groupType: RelationalGroupedDataset.GroupType) {
+  import RelationalGroupedDataset._
 
   private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {
     val aggregates = if (df.sparkSession.sessionState.conf.dataFrameRetainGroupColumns) {
@@ -85,14 +87,6 @@ class RelationalGroupedDataset protected[sql](
     }
   }
 
-  private[this] def alias(expr: Expression): NamedExpression = expr match {
-    case expr: NamedExpression => expr
-    case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] =>
-      UnresolvedAlias(a, Some(Column.generateAlias))
-    case u: UnresolvedFunction => UnresolvedAlias(expr, None)
-    case expr: Expression => Alias(expr, toPrettySQL(expr))()
-  }
-
-  private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction)
     : DataFrame = {
 
@@ -143,25 +137,15 @@ class RelationalGroupedDataset protected[sql](
     val keyEncoder = encoderFor[K]
     val valueEncoder = encoderFor[T]
 
-    // Resolves grouping expressions.
-    val dummyPlan = Project(groupingExprs.map(alias), LocalRelation(df.logicalPlan.output))
-    val analyzedPlan = df.sparkSession.sessionState.analyzer.execute(dummyPlan)
-      .asInstanceOf[Project]
-    df.sparkSession.sessionState.analyzer.checkAnalysis(analyzedPlan)
-    val aliasedGroupings = analyzedPlan.projectList
-
-    // Adds the grouping expressions that are not in base DataFrame into outputs.
-    val addedCols = aliasedGroupings.filter(g => !df.logicalPlan.outputSet.contains(g.toAttribute))
-    val qe = Dataset.ofRows(
-      df.sparkSession,
-      Project(df.logicalPlan.output ++ addedCols, df.logicalPlan)).queryExecution
+    val (qe, groupingAttributes) =
+      handleGroupingExpression(df.logicalPlan, df.sparkSession, groupingExprs)
+    val (qe, groupingAttributes) =
+      handleGroupingExpression(df.logicalPlan, df.sparkSession, groupingExprs)
 
     new KeyValueGroupedDataset(
       keyEncoder,
       valueEncoder,
       qe,
       df.logicalPlan.output,
-      aliasedGroupings.map(_.toAttribute))
+      groupingAttributes)
   }
 
   /**
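
For context, a hedged usage sketch of RelationalGroupedDataset#as, whose grouping resolution the refactor above moves into handleGroupingExpression (the `Sale` case class and values are assumptions of this example; assumes `import spark.implicits._`):

    case class Sale(city: String, amount: Long)

    val sales = Seq(Sale("NY", 3L), Sale("NY", 4L), Sale("SF", 1L)).toDS()
    // KeyValueGroupedDataset[String, Sale], keyed by the resolved `city` attribute.
    val byCity = sales.groupBy($"city").as[String, Sale]
    byCity.mapGroups((city, rows) => city -> rows.map(_.amount).sum)
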
@@ -700,6 +684,33 @@ private[sql] object RelationalGroupedDataset {
     new RelationalGroupedDataset(df, groupingExprs, groupType: GroupType)
   }
 
+  private[sql] def handleGroupingExpression(
+      logicalPlan: LogicalPlan,
+      sparkSession: SparkSession,
+      groupingExprs: Seq[Expression]): (QueryExecution, Seq[Attribute]) = {
+    // Resolves grouping expressions.
+    val dummyPlan = Project(groupingExprs.map(alias), LocalRelation(logicalPlan.output))
+    val analyzedPlan = sparkSession.sessionState.analyzer.execute(dummyPlan)
+      .asInstanceOf[Project]
+    sparkSession.sessionState.analyzer.checkAnalysis(analyzedPlan)
+    val aliasedGroupings = analyzedPlan.projectList
+
+    // Adds the grouping expressions that are not in base DataFrame into outputs.
+    val addedCols = aliasedGroupings.filter(g => !logicalPlan.outputSet.contains(g.toAttribute))
+    val newPlan = Project(logicalPlan.output ++ addedCols, logicalPlan)
+    val qe = sparkSession.sessionState.executePlan(newPlan)
+
+    (qe, aliasedGroupings.map(_.toAttribute))
+  }
+
+  private def alias(expr: Expression): NamedExpression = expr match {
+    case expr: NamedExpression => expr
+    case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] =>
+      UnresolvedAlias(a, Some(Column.generateAlias))
+    case u: UnresolvedFunction => UnresolvedAlias(expr, None)
+    case expr: Expression => Alias(expr, toPrettySQL(expr))()
+  }
+
   /**
    * The Grouping Type
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
index e266ae55cc4..41306cd0a99 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
@@ -66,3 +66,9 @@ private[sql] class ReduceAggregator[T: Encoder](func: (T, T) => T)
     reduction._2
   }
 }
+
+private[sql] object ReduceAggregator {
+  def apply[T: Encoder](f: AnyRef): ReduceAggregator[T] = {
+    new ReduceAggregator(f.asInstanceOf[(T, T) => T])
+  }
+}
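
A hedged sketch of the new factory (ReduceAggregator is private[sql], so this only compiles inside the org.apache.spark.sql package; the explicit encoder is an assumption of this example):

    import org.apache.spark.sql.{Encoder, Encoders}

    implicit val longEnc: Encoder[Long] = Encoders.scalaLong
    // The function is accepted as AnyRef and cast to (T, T) => T internally.
    val sumAgg = ReduceAggregator[Long]((a: Long, b: Long) => a + b)
    val reduceCol = sumAgg.toColumn  // what transformTypedReduceExpression builds server-side
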
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala
new file mode 100644
index 00000000000..68bda47cf8c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/TypedAggUtils.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
+
+private[sql] object TypedAggUtils {
+
+  def aggKeyColumn[A](
+      encoder: ExpressionEncoder[A],
+      groupingAttributes: Seq[Attribute]): NamedExpression = {
+    if (!encoder.isSerializedAsStructForTopLevel) {
+      assert(groupingAttributes.length == 1)
+      if (SQLConf.get.nameNonStructGroupingKeyAsValue) {
+        groupingAttributes.head
+      } else {
+        Alias(groupingAttributes.head, "key")()
+      }
+    } else {
+      Alias(CreateStruct(groupingAttributes), "key")()
+    }
+  }
+
+  /**
+   * Insert inputs into typed aggregate expressions. For untyped aggregate expressions,
+   * the resolving is handled in the analyzer directly.
+   */
+  private[sql] def withInputType(
+      expr: Expression,
+      inputEncoder: ExpressionEncoder[_],
+      inputAttributes: Seq[Attribute]): Expression = {
+    val unresolvedDeserializer = UnresolvedDeserializer(inputEncoder.deserializer, inputAttributes)
+
+    expr transform {
+      case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty =>
+        ta.withInputInfo(
+          deser = unresolvedDeserializer,
+          cls = inputEncoder.clsTag.runtimeClass,
+          schema = inputEncoder.schema
+        )
+    }
+  }
+}
+
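
A hedged illustration of what aggKeyColumn produces for the two key shapes (internal API, callable only from the org.apache.spark.sql package; the attribute names are assumptions of this example):

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.internal.TypedAggUtils
    import org.apache.spark.sql.types.{IntegerType, StringType}

    // Non-struct key: returns the grouping attribute itself, or Alias(attr, "key"),
    // depending on SQLConf.get.nameNonStructGroupingKeyAsValue.
    val scalarKey = AttributeReference("value", IntegerType)()
    TypedAggUtils.aggKeyColumn(ExpressionEncoder[Int](), Seq(scalarKey))

    // Struct key: wrapped as Alias(CreateStruct(groupingAttributes), "key").
    val structKey = Seq(AttributeReference("a", IntegerType)(), AttributeReference("b", StringType)())
    TypedAggUtils.aggKeyColumn(ExpressionEncoder[(Int, String)](), structKey)
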

