This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e20db137d2de [SPARK-48510] 2/2] Support UDAF `toColumn` API in Spark 
Connect
e20db137d2de is described below

commit e20db137d2de26594b38c7e257a3d863de882022
Author: Paddy Xu <[email protected]>
AuthorDate: Fri Jul 12 20:21:03 2024 +0900

    [SPARK-48510] 2/2] Support UDAF `toColumn` API in Spark Connect
    
    ### What changes were proposed in this pull request?
    
    This PR follows https://github.com/apache/spark/pull/46245 to add support 
for the `udaf.toColumn` API in Spark Connect.
    
    Here we introduce a new Protobuf message, `proto.TypedAggregateExpression`, 
that includes a serialized UDF packet. On the server, we unpack it into an 
`Aggregator` object and generate a real `TypedAggregateExpression` instance 
with the encoder information passed along with the UDF.
    
    ### Why are the changes needed?
    
    Because the `toColumn` API was not supported in the previous PR.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, from now on users can create typed UDAFs using the `udaf.toColumn` API.
    
    ### How was this patch tested?
    
    New tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Nope.
    
    Closes #46849 from xupefei/connect-udaf-tocolumn.
    
    Authored-by: Paddy Xu <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../main/protobuf/spark/connect/expressions.proto  |   6 +
 .../sql/connect/planner/SparkConnectPlanner.scala  |  63 ++++++++-
 .../apache/spark/sql/expressions/Aggregator.scala  |  53 +++++++-
 .../sql/expressions/UserDefinedFunction.scala      |   2 +-
 .../sql/UserDefinedFunctionE2ETestSuite.scala      |  72 ++++++++--
 .../pyspark/sql/connect/proto/expressions_pb2.py   | 146 +++++++++++----------
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  |  30 +++++
 7 files changed, 278 insertions(+), 94 deletions(-)

diff --git a/connect/common/src/main/protobuf/spark/connect/expressions.proto 
b/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 860e92357616..3a91371fd3b2 100644
--- a/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -51,6 +51,7 @@ message Expression {
     CallFunction call_function = 16;
     NamedArgumentExpression named_argument_expression = 17;
     MergeAction merge_action = 19;
+    TypedAggregateExpression typed_aggregate_expression = 20;
 
     // This field is used to mark extensions to the protocol. When plugins 
generate arbitrary
     // relations they can add them here. During the planning the correct 
resolution is done.
@@ -402,6 +403,11 @@ message JavaUDF {
   bool aggregate = 3;
 }
 
+message TypedAggregateExpression {
+  // (Required) The aggregate function object packed into bytes.
+  ScalarScalaUDF scalar_scala_udf = 1;
+}
+
 message CallFunction {
   // (Required) Unparsed name of the SQL function.
   string function_name = 1;
diff --git 
a/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 449e923beae3..4702f09a14c2 100644
--- 
a/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -51,7 +51,7 @@ import 
org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, Mu
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, 
ExpressionEncoder, RowEncoder}
 import 
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
+import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
BloomFilterAggregate}
 import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, 
LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical
@@ -67,6 +67,7 @@ import org.apache.spark.sql.connect.service.{ExecuteHolder, 
SessionHolder, Spark
 import org.apache.spark.sql.connect.utils.MetricGenerator
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.execution.command.CreateViewCommand
 import org.apache.spark.sql.execution.datasources.LogicalRelation
@@ -1455,7 +1456,7 @@ class SparkConnectPlanner(
     }
 
     val projection = rel.getExpressionsList.asScala.toSeq
-      .map(transformExpression)
+      .map(transformExpression(_, Some(baseRel)))
       .map(toNamedExpression)
 
     logical.Project(projectList = projection, child = baseRel)
@@ -1472,21 +1473,40 @@ class SparkConnectPlanner(
    *   Catalyst expression
    */
   @DeveloperApi
-  def transformExpression(exp: proto.Expression): Expression = if 
(exp.hasCommon) {
+  def transformExpression(exp: proto.Expression): Expression = 
transformExpression(exp, None)
+
+  /**
+   * Transforms an input protobuf expression into the Catalyst expression. 
This is usually not
+   * called directly. Typically the planner will traverse the expressions 
automatically, only
+   * plugins are expected to manually perform expression transformations.
+   *
+   * @param exp
+   *   the input expression
+   * @param baseRelationOpt
+   *   inputs of the base relation that contains this expression
+   * @return
+   *   Catalyst expression
+   */
+  @DeveloperApi
+  def transformExpression(
+      exp: proto.Expression,
+      baseRelationOpt: Option[LogicalPlan]): Expression = if (exp.hasCommon) {
     try {
       val origin = exp.getCommon.getOrigin
       PySparkCurrentOrigin.set(
         origin.getPythonOrigin.getFragment,
         origin.getPythonOrigin.getCallSite)
-      withOrigin { doTransformExpression(exp) }
+      withOrigin { doTransformExpression(exp, baseRelationOpt) }
     } finally {
       PySparkCurrentOrigin.clear()
     }
   } else {
-    doTransformExpression(exp)
+    doTransformExpression(exp, baseRelationOpt)
   }
 
-  private def doTransformExpression(exp: proto.Expression): Expression = {
+  private def doTransformExpression(
+      exp: proto.Expression,
+      baseRelationOpt: Option[LogicalPlan]): Expression = {
     exp.getExprTypeCase match {
       case proto.Expression.ExprTypeCase.LITERAL => 
transformLiteral(exp.getLiteral)
       case proto.Expression.ExprTypeCase.UNRESOLVED_ATTRIBUTE =>
@@ -1523,6 +1543,8 @@ class SparkConnectPlanner(
         transformNamedArgumentExpression(exp.getNamedArgumentExpression)
       case proto.Expression.ExprTypeCase.MERGE_ACTION =>
         transformMergeAction(exp.getMergeAction)
+      case proto.Expression.ExprTypeCase.TYPED_AGGREGATE_EXPRESSION =>
+        transformTypedAggregateExpression(exp.getTypedAggregateExpression, 
baseRelationOpt)
       case _ =>
         throw InvalidPlanInput(
           s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not 
supported")
@@ -2584,8 +2606,35 @@ class SparkConnectPlanner(
           if expr.getUnresolvedFunction.getFunctionName == "reduce" =>
         // The reduce func needs the input data attribute, thus handle it 
specially here
         transformTypedReduceExpression(expr.getUnresolvedFunction, plan.output)
-      case _ => transformExpression(expr)
+      case _ => transformExpression(expr, Some(plan))
+    }
+  }
+
+  private def transformTypedAggregateExpression(
+      expr: proto.TypedAggregateExpression,
+      baseRelationOpt: Option[LogicalPlan]): AggregateExpression = {
+    val udf = expr.getScalarScalaUdf
+    assert(udf.getAggregate)
+
+    val udfPacket = unpackScalaUDF[UdfPacket](udf)
+    assert(udfPacket.inputEncoders.size == 1, "UDAF should have exactly one 
input encoder")
+
+    val aggregator = udfPacket.function.asInstanceOf[Aggregator[Any, Any, Any]]
+    val tae =
+      TypedAggregateExpression(aggregator)(aggregator.bufferEncoder, 
aggregator.outputEncoder)
+    val taeWithInput = baseRelationOpt match {
+      case Some(baseRelation) =>
+        val inputEncoder = TypedScalaUdf.encoderFor(
+          udfPacket.inputEncoders.head,
+          "input",
+          Some(baseRelation.output))
+        TypedAggUtils
+          .withInputType(tae, inputEncoder, baseRelation.output)
+          .asInstanceOf[TypedAggregateExpression]
+      case _ =>
+        tae
     }
+    taeWithInput.toAggregateExpression()
   }
 
   private def transformMergeAction(action: proto.MergeAction): MergeAction = {
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 91c8fb57c31b..3dabcdef1567 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -17,7 +17,11 @@
 
 package org.apache.spark.sql.expressions
 
-import org.apache.spark.sql.{Encoder, TypedColumn}
+import scala.reflect.runtime.universe._
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.{encoderFor, Encoder, TypedColumn}
+import org.apache.spark.sql.catalyst.ScalaReflection
 
 /**
  * A base class for user-defined aggregations, which can be used in `Dataset` 
operations to take
@@ -92,9 +96,52 @@ abstract class Aggregator[-IN, BUF, OUT] extends 
Serializable {
   def outputEncoder: Encoder[OUT]
 
   /**
-   * Returns this `Aggregator` as a `TypedColumn` that can be used in 
`Dataset`. operations.
+   * Returns this `Aggregator` as a `TypedColumn` that can be used in 
`Dataset` operations.
+   * @since 4.0.0
    */
   def toColumn: TypedColumn[IN, OUT] = {
-    throw new UnsupportedOperationException("toColumn is not implemented.")
+    val ttpe = getInputTypeTag[IN]
+    val inputEncoder = ScalaReflection.encoderFor(ttpe)
+    val udaf =
+      ScalaUserDefinedFunction(
+        this,
+        Seq(inputEncoder),
+        encoderFor(outputEncoder),
+        aggregate = true)
+
+    val builder = proto.TypedAggregateExpression.newBuilder()
+    builder.setScalarScalaUdf(udaf.udf)
+    val expr = 
proto.Expression.newBuilder().setTypedAggregateExpression(builder).build()
+
+    new TypedColumn(expr, encoderFor(outputEncoder))
+  }
+
+  private final def getInputTypeTag[T]: TypeTag[T] = {
+    val mirror = runtimeMirror(this.getClass.getClassLoader)
+    val tpe = mirror.classSymbol(this.getClass).toType
+    // Find the most generic (last in the tree) Aggregator class
+    val baseAgg =
+      tpe.baseClasses
+        .findLast(_.asClass.toType <:< typeOf[Aggregator[_, _, _]])
+        .getOrElse(throw new IllegalStateException("Could not find the 
Aggregator base class."))
+    val typeArgs = tpe.baseType(baseAgg).typeArgs
+    assert(
+      typeArgs.length == 3,
+      s"Aggregator should have 3 type arguments, " +
+        s"but found ${typeArgs.length}: ${typeArgs.mkString}.")
+    val inType = typeArgs.head
+
+    import scala.reflect.api._
+    TypeTag(
+      mirror,
+      new TypeCreator {
+        def apply[U <: Universe with Singleton](m: Mirror[U]): U#Type =
+          if (m eq mirror) {
+            inType.asInstanceOf[U#Type]
+          } else {
+            throw new IllegalArgumentException(
+              s"Type tag defined in $mirror cannot be migrated to other 
mirrors.")
+          }
+      })
   }
 }
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index f4499858306a..dcf7f67551d3 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -107,7 +107,7 @@ case class ScalaUserDefinedFunction private[sql] (
     aggregate: Boolean)
     extends UserDefinedFunction {
 
-  private[this] lazy val udf = {
+  private[expressions] lazy val udf = {
     val scalaUdfBuilder = proto.ScalarScalaUDF
       .newBuilder()
       .setPayload(ByteString.copyFrom(serializedUdfPacket))
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
index 4032a9499c44..4aec0e6348c0 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
@@ -367,17 +367,7 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest 
with RemoteSparkSession
   test("UDAF custom Aggregator - case class as input types") {
     val session: SparkSession = spark
     import session.implicits._
-    val agg = new Aggregator[UdafTestInput, (Long, Long), Long] {
-      override def zero: (Long, Long) = (0L, 0L)
-      override def reduce(b: (Long, Long), a: UdafTestInput): (Long, Long) =
-        (b._1 + a.id, b._2 + a.extra)
-      override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) =
-        (b1._1 + b2._1, b1._2 + b2._2)
-      override def finish(reduction: (Long, Long)): Long = reduction._1 + 
reduction._2
-      override def bufferEncoder: Encoder[(Long, Long)] =
-        Encoders.tuple(Encoders.scalaLong, Encoders.scalaLong)
-      override def outputEncoder: Encoder[Long] = Encoders.scalaLong
-    }
+    val agg = new CompleteUdafTestInputAggregator()
     spark.udf.register("agg", udaf(agg))
     val result = spark
       .range(10)
@@ -388,6 +378,66 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest 
with RemoteSparkSession
       .head()
     assert(result == 135) // 45 + 90
   }
+
+  test("UDAF custom Aggregator - toColumn") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val aggCol = new CompleteUdafTestInputAggregator().toColumn
+    val ds = spark.range(10).withColumn("extra", col("id") * 
2).as[UdafTestInput]
+
+    assert(ds.select(aggCol).head() == 135) // 45 + 90
+    assert(ds.agg(aggCol).head().getLong(0) == 135) // 45 + 90
+  }
+
+  test("UDAF custom Aggregator - multiple extends - toColumn") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val aggCol = new CompleteGrandChildUdafTestInputAggregator().toColumn
+    val ds = spark.range(10).withColumn("extra", col("id") * 
2).as[UdafTestInput]
+
+    assert(ds.select(aggCol).head() == 540) // (45 + 90) * 4
+    assert(ds.agg(aggCol).head().getLong(0) == 540) // (45 + 90) * 4
+  }
 }
 
 case class UdafTestInput(id: Long, extra: Long)
+
+// An Aggregator that takes [[UdafTestInput]] as input.
+final class CompleteUdafTestInputAggregator
+    extends Aggregator[UdafTestInput, (Long, Long), Long] {
+  override def zero: (Long, Long) = (0L, 0L)
+  override def reduce(b: (Long, Long), a: UdafTestInput): (Long, Long) =
+    (b._1 + a.id, b._2 + a.extra)
+  override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) =
+    (b1._1 + b2._1, b1._2 + b2._2)
+  override def finish(reduction: (Long, Long)): Long = reduction._1 + 
reduction._2
+  override def bufferEncoder: Encoder[(Long, Long)] =
+    Encoders.tuple(Encoders.scalaLong, Encoders.scalaLong)
+  override def outputEncoder: Encoder[Long] = Encoders.scalaLong
+}
+
+// Same as [[CompleteUdafTestInputAggregator]] but the input type is not 
defined.
+abstract class IncompleteUdafTestInputAggregator[T] extends Aggregator[T, 
(Long, Long), Long] {
+  override def zero: (Long, Long) = (0L, 0L)
+  override def reduce(b: (Long, Long), a: T): (Long, Long) // Incomplete!
+  override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) =
+    (b1._1 + b2._1, b1._2 + b2._2)
+  override def finish(reduction: (Long, Long)): Long = reduction._1 + 
reduction._2
+  override def bufferEncoder: Encoder[(Long, Long)] =
+    Encoders.tuple(Encoders.scalaLong, Encoders.scalaLong)
+  override def outputEncoder: Encoder[Long] = Encoders.scalaLong
+}
+
+// A layer over [[IncompleteUdafTestInputAggregator]] but the input type is 
still not defined.
+abstract class IncompleteChildUdafTestInputAggregator[T]
+    extends IncompleteUdafTestInputAggregator[T] {
+  override def finish(reduction: (Long, Long)): Long = (reduction._1 + 
reduction._2) * 2
+}
+
+// Another layer that finally defines the input type.
+final class CompleteGrandChildUdafTestInputAggregator
+    extends IncompleteChildUdafTestInputAggregator[UdafTestInput] {
+  override def reduce(b: (Long, Long), a: UdafTestInput): (Long, Long) =
+    (b._1 + a.id, b._2 + a.extra)
+  override def finish(reduction: (Long, Long)): Long = (reduction._1 + 
reduction._2) * 4
+}
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py 
b/python/pyspark/sql/connect/proto/expressions_pb2.py
index b4c4b48de268..1c1ad2b6ecec 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import common_pb2 as 
spark_dot_connect_dot_common
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto\x1a\x1aspark/connect/common.proto"\xd8/\n\nExpression\x12\x37\n\x06\x63ommon\x18\x12
 
\x01(\x0b\x32\x1f.spark.connect.ExpressionCommonR\x06\x63ommon\x12=\n\x07literal\x18\x01
 
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
 
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAtt
 [...]
+    
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto\x1a\x1aspark/connect/common.proto"\xc1\x30\n\nExpression\x12\x37\n\x06\x63ommon\x18\x12
 
\x01(\x0b\x32\x1f.spark.connect.ExpressionCommonR\x06\x63ommon\x12=\n\x07literal\x18\x01
 
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
 
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolved 
[...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -47,75 +47,77 @@ if _descriptor._USE_C_DESCRIPTORS == False:
         b"\n\036org.apache.spark.connect.protoP\001Z\022internal/generated"
     )
     _EXPRESSION._serialized_start = 133
-    _EXPRESSION._serialized_end = 6237
-    _EXPRESSION_WINDOW._serialized_start = 1795
-    _EXPRESSION_WINDOW._serialized_end = 2578
-    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 2085
-    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 2578
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 2352
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 2497
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 2499
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 2578
-    _EXPRESSION_SORTORDER._serialized_start = 2581
-    _EXPRESSION_SORTORDER._serialized_end = 3006
-    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2811
-    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2919
-    _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2921
-    _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 3006
-    _EXPRESSION_CAST._serialized_start = 3009
-    _EXPRESSION_CAST._serialized_end = 3324
-    _EXPRESSION_CAST_EVALMODE._serialized_start = 3210
-    _EXPRESSION_CAST_EVALMODE._serialized_end = 3308
-    _EXPRESSION_LITERAL._serialized_start = 3327
-    _EXPRESSION_LITERAL._serialized_end = 4890
-    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 4162
-    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 4279
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 4281
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 4379
-    _EXPRESSION_LITERAL_ARRAY._serialized_start = 4382
-    _EXPRESSION_LITERAL_ARRAY._serialized_end = 4512
-    _EXPRESSION_LITERAL_MAP._serialized_start = 4515
-    _EXPRESSION_LITERAL_MAP._serialized_end = 4742
-    _EXPRESSION_LITERAL_STRUCT._serialized_start = 4745
-    _EXPRESSION_LITERAL_STRUCT._serialized_end = 4874
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4893
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 5079
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 5082
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 5286
-    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 5288
-    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 5338
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 5340
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 5464
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 5466
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5552
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5555
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5687
-    _EXPRESSION_UPDATEFIELDS._serialized_start = 5690
-    _EXPRESSION_UPDATEFIELDS._serialized_end = 5877
-    _EXPRESSION_ALIAS._serialized_start = 5879
-    _EXPRESSION_ALIAS._serialized_end = 5999
-    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 6002
-    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 6160
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 6162
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 6224
-    _EXPRESSIONCOMMON._serialized_start = 6239
-    _EXPRESSIONCOMMON._serialized_end = 6304
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 6307
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6671
-    _PYTHONUDF._serialized_start = 6674
-    _PYTHONUDF._serialized_end = 6878
-    _SCALARSCALAUDF._serialized_start = 6881
-    _SCALARSCALAUDF._serialized_end = 7095
-    _JAVAUDF._serialized_start = 7098
-    _JAVAUDF._serialized_end = 7247
-    _CALLFUNCTION._serialized_start = 7249
-    _CALLFUNCTION._serialized_end = 7357
-    _NAMEDARGUMENTEXPRESSION._serialized_start = 7359
-    _NAMEDARGUMENTEXPRESSION._serialized_end = 7451
-    _MERGEACTION._serialized_start = 7454
-    _MERGEACTION._serialized_end = 7966
-    _MERGEACTION_ASSIGNMENT._serialized_start = 7676
-    _MERGEACTION_ASSIGNMENT._serialized_end = 7782
-    _MERGEACTION_ACTIONTYPE._serialized_start = 7785
-    _MERGEACTION_ACTIONTYPE._serialized_end = 7952
+    _EXPRESSION._serialized_end = 6342
+    _EXPRESSION_WINDOW._serialized_start = 1900
+    _EXPRESSION_WINDOW._serialized_end = 2683
+    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 2190
+    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 2683
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 2457
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 2602
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 2604
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 2683
+    _EXPRESSION_SORTORDER._serialized_start = 2686
+    _EXPRESSION_SORTORDER._serialized_end = 3111
+    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2916
+    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 3024
+    _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 3026
+    _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 3111
+    _EXPRESSION_CAST._serialized_start = 3114
+    _EXPRESSION_CAST._serialized_end = 3429
+    _EXPRESSION_CAST_EVALMODE._serialized_start = 3315
+    _EXPRESSION_CAST_EVALMODE._serialized_end = 3413
+    _EXPRESSION_LITERAL._serialized_start = 3432
+    _EXPRESSION_LITERAL._serialized_end = 4995
+    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 4267
+    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 4384
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 4386
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 4484
+    _EXPRESSION_LITERAL_ARRAY._serialized_start = 4487
+    _EXPRESSION_LITERAL_ARRAY._serialized_end = 4617
+    _EXPRESSION_LITERAL_MAP._serialized_start = 4620
+    _EXPRESSION_LITERAL_MAP._serialized_end = 4847
+    _EXPRESSION_LITERAL_STRUCT._serialized_start = 4850
+    _EXPRESSION_LITERAL_STRUCT._serialized_end = 4979
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4998
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 5184
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 5187
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 5391
+    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 5393
+    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 5443
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 5445
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 5569
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 5571
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5657
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5660
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5792
+    _EXPRESSION_UPDATEFIELDS._serialized_start = 5795
+    _EXPRESSION_UPDATEFIELDS._serialized_end = 5982
+    _EXPRESSION_ALIAS._serialized_start = 5984
+    _EXPRESSION_ALIAS._serialized_end = 6104
+    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 6107
+    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 6265
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 6267
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 6329
+    _EXPRESSIONCOMMON._serialized_start = 6344
+    _EXPRESSIONCOMMON._serialized_end = 6409
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 6412
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6776
+    _PYTHONUDF._serialized_start = 6779
+    _PYTHONUDF._serialized_end = 6983
+    _SCALARSCALAUDF._serialized_start = 6986
+    _SCALARSCALAUDF._serialized_end = 7200
+    _JAVAUDF._serialized_start = 7203
+    _JAVAUDF._serialized_end = 7352
+    _TYPEDAGGREGATEEXPRESSION._serialized_start = 7354
+    _TYPEDAGGREGATEEXPRESSION._serialized_end = 7453
+    _CALLFUNCTION._serialized_start = 7455
+    _CALLFUNCTION._serialized_end = 7563
+    _NAMEDARGUMENTEXPRESSION._serialized_start = 7565
+    _NAMEDARGUMENTEXPRESSION._serialized_end = 7657
+    _MERGEACTION._serialized_start = 7660
+    _MERGEACTION._serialized_end = 8172
+    _MERGEACTION_ASSIGNMENT._serialized_start = 7882
+    _MERGEACTION_ASSIGNMENT._serialized_end = 7988
+    _MERGEACTION_ACTIONTYPE._serialized_start = 7991
+    _MERGEACTION_ACTIONTYPE._serialized_end = 8158
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi 
b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 2c80be6c8fb5..1566eb1b1e9e 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -1183,6 +1183,7 @@ class Expression(google.protobuf.message.Message):
     CALL_FUNCTION_FIELD_NUMBER: builtins.int
     NAMED_ARGUMENT_EXPRESSION_FIELD_NUMBER: builtins.int
     MERGE_ACTION_FIELD_NUMBER: builtins.int
+    TYPED_AGGREGATE_EXPRESSION_FIELD_NUMBER: builtins.int
     EXTENSION_FIELD_NUMBER: builtins.int
     @property
     def common(self) -> global___ExpressionCommon: ...
@@ -1225,6 +1226,8 @@ class Expression(google.protobuf.message.Message):
     @property
     def merge_action(self) -> global___MergeAction: ...
     @property
+    def typed_aggregate_expression(self) -> global___TypedAggregateExpression: 
...
+    @property
     def extension(self) -> google.protobuf.any_pb2.Any:
         """This field is used to mark extensions to the protocol. When plugins 
generate arbitrary
         relations they can add them here. During the planning the correct 
resolution is done.
@@ -1252,6 +1255,7 @@ class Expression(google.protobuf.message.Message):
         call_function: global___CallFunction | None = ...,
         named_argument_expression: global___NamedArgumentExpression | None = 
...,
         merge_action: global___MergeAction | None = ...,
+        typed_aggregate_expression: global___TypedAggregateExpression | None = 
...,
         extension: google.protobuf.any_pb2.Any | None = ...,
     ) -> None: ...
     def HasField(
@@ -1283,6 +1287,8 @@ class Expression(google.protobuf.message.Message):
             b"named_argument_expression",
             "sort_order",
             b"sort_order",
+            "typed_aggregate_expression",
+            b"typed_aggregate_expression",
             "unresolved_attribute",
             b"unresolved_attribute",
             "unresolved_extract_value",
@@ -1330,6 +1336,8 @@ class Expression(google.protobuf.message.Message):
             b"named_argument_expression",
             "sort_order",
             b"sort_order",
+            "typed_aggregate_expression",
+            b"typed_aggregate_expression",
             "unresolved_attribute",
             b"unresolved_attribute",
             "unresolved_extract_value",
@@ -1370,6 +1378,7 @@ class Expression(google.protobuf.message.Message):
             "call_function",
             "named_argument_expression",
             "merge_action",
+            "typed_aggregate_expression",
             "extension",
         ]
         | None
@@ -1620,6 +1629,27 @@ class JavaUDF(google.protobuf.message.Message):
 
 global___JavaUDF = JavaUDF
 
+class TypedAggregateExpression(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    SCALAR_SCALA_UDF_FIELD_NUMBER: builtins.int
+    @property
+    def scalar_scala_udf(self) -> global___ScalarScalaUDF:
+        """(Required) The aggregate function object packed into bytes."""
+    def __init__(
+        self,
+        *,
+        scalar_scala_udf: global___ScalarScalaUDF | None = ...,
+    ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["scalar_scala_udf", 
b"scalar_scala_udf"]
+    ) -> builtins.bool: ...
+    def ClearField(
+        self, field_name: typing_extensions.Literal["scalar_scala_udf", 
b"scalar_scala_udf"]
+    ) -> None: ...
+
+global___TypedAggregateExpression = TypedAggregateExpression
+
 class CallFunction(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to