This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e20db137d2de [SPARK-48510] 2/2] Support UDAF `toColumn` API in Spark
Connect
e20db137d2de is described below
commit e20db137d2de26594b38c7e257a3d863de882022
Author: Paddy Xu <[email protected]>
AuthorDate: Fri Jul 12 20:21:03 2024 +0900
[SPARK-48510] 2/2] Support UDAF `toColumn` API in Spark Connect
### What changes were proposed in this pull request?
This PR follows up on https://github.com/apache/spark/pull/46245 to add support for the
`udaf.toColumn` API in Spark Connect.
Here we introduce a new Protobuf message, `proto.TypedAggregateExpression`,
that includes a serialized UDF packet. On the server, we unpack it into an
`Aggregator` object and generate a real `TypedAggregateExpression` instance
with the encoder information passed along with the UDF.
### Why are the changes needed?
The `toColumn` API was not supported by the previous PR; calling it on the Spark Connect client threw `UnsupportedOperationException`.
### Does this PR introduce _any_ user-facing change?
Yes. Users can now create typed UDAF columns using the `udaf.toColumn` API.
### How was this patch tested?
New end-to-end tests in `UserDefinedFunctionE2ETestSuite`.
### Was this patch authored or co-authored using generative AI tooling?
Nope.
Closes #46849 from xupefei/connect-udaf-tocolumn.
Authored-by: Paddy Xu <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../main/protobuf/spark/connect/expressions.proto | 6 +
.../sql/connect/planner/SparkConnectPlanner.scala | 63 ++++++++-
.../apache/spark/sql/expressions/Aggregator.scala | 53 +++++++-
.../sql/expressions/UserDefinedFunction.scala | 2 +-
.../sql/UserDefinedFunctionE2ETestSuite.scala | 72 ++++++++--
.../pyspark/sql/connect/proto/expressions_pb2.py | 146 +++++++++++----------
.../pyspark/sql/connect/proto/expressions_pb2.pyi | 30 +++++
7 files changed, 278 insertions(+), 94 deletions(-)
diff --git a/connect/common/src/main/protobuf/spark/connect/expressions.proto
b/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 860e92357616..3a91371fd3b2 100644
--- a/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -51,6 +51,7 @@ message Expression {
CallFunction call_function = 16;
NamedArgumentExpression named_argument_expression = 17;
MergeAction merge_action = 19;
+ TypedAggregateExpression typed_aggregate_expression = 20;
// This field is used to mark extensions to the protocol. When plugins
generate arbitrary
// relations they can add them here. During the planning the correct
resolution is done.
@@ -402,6 +403,11 @@ message JavaUDF {
bool aggregate = 3;
}
+message TypedAggregateExpression {
+ // (Required) The Aggregator object, serialized as a Scala UDF packet. The wrapped
+ // ScalarScalaUDF must have `aggregate` set to true.
+ ScalarScalaUDF scalar_scala_udf = 1;
+}
+
message CallFunction {
// (Required) Unparsed name of the SQL function.
string function_name = 1;
diff --git
a/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 449e923beae3..4702f09a14c2 100644
---
a/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -51,7 +51,7 @@ import
org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, Mu
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder,
ExpressionEncoder, RowEncoder}
import
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
+import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
BloomFilterAggregate}
import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType,
LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
import org.apache.spark.sql.catalyst.plans.logical
@@ -67,6 +67,7 @@ import org.apache.spark.sql.connect.service.{ExecuteHolder,
SessionHolder, Spark
import org.apache.spark.sql.connect.utils.MetricGenerator
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.execution.command.CreateViewCommand
import org.apache.spark.sql.execution.datasources.LogicalRelation
@@ -1455,7 +1456,7 @@ class SparkConnectPlanner(
}
val projection = rel.getExpressionsList.asScala.toSeq
- .map(transformExpression)
+ .map(transformExpression(_, Some(baseRel)))
.map(toNamedExpression)
logical.Project(projectList = projection, child = baseRel)
@@ -1472,21 +1473,40 @@ class SparkConnectPlanner(
* Catalyst expression
*/
@DeveloperApi
- def transformExpression(exp: proto.Expression): Expression = if
(exp.hasCommon) {
+ def transformExpression(exp: proto.Expression): Expression =
transformExpression(exp, None)
+
+ /**
+ * Transforms an input protobuf expression into the Catalyst expression.
This is usually not
+ * called directly. Typically the planner will traverse the expressions
automatically, only
+ * plugins are expected to manually perform expression transformations.
+ *
+ * @param exp
+ * the input expression
+ * @param baseRelationOpt
+ * inputs of the base relation that contains this expression
+ * @return
+ * Catalyst expression
+ */
+ @DeveloperApi
+ def transformExpression(
+ exp: proto.Expression,
+ baseRelationOpt: Option[LogicalPlan]): Expression = if (exp.hasCommon) {
try {
val origin = exp.getCommon.getOrigin
PySparkCurrentOrigin.set(
origin.getPythonOrigin.getFragment,
origin.getPythonOrigin.getCallSite)
- withOrigin { doTransformExpression(exp) }
+ withOrigin { doTransformExpression(exp, baseRelationOpt) }
} finally {
PySparkCurrentOrigin.clear()
}
} else {
- doTransformExpression(exp)
+ doTransformExpression(exp, baseRelationOpt)
}
- private def doTransformExpression(exp: proto.Expression): Expression = {
+ private def doTransformExpression(
+ exp: proto.Expression,
+ baseRelationOpt: Option[LogicalPlan]): Expression = {
exp.getExprTypeCase match {
case proto.Expression.ExprTypeCase.LITERAL =>
transformLiteral(exp.getLiteral)
case proto.Expression.ExprTypeCase.UNRESOLVED_ATTRIBUTE =>
@@ -1523,6 +1543,8 @@ class SparkConnectPlanner(
transformNamedArgumentExpression(exp.getNamedArgumentExpression)
case proto.Expression.ExprTypeCase.MERGE_ACTION =>
transformMergeAction(exp.getMergeAction)
+ case proto.Expression.ExprTypeCase.TYPED_AGGREGATE_EXPRESSION =>
+ transformTypedAggregateExpression(exp.getTypedAggregateExpression,
baseRelationOpt)
case _ =>
throw InvalidPlanInput(
s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not
supported")
@@ -2584,8 +2606,35 @@ class SparkConnectPlanner(
if expr.getUnresolvedFunction.getFunctionName == "reduce" =>
// The reduce func needs the input data attribute, thus handle it
specially here
transformTypedReduceExpression(expr.getUnresolvedFunction, plan.output)
- case _ => transformExpression(expr)
+ case _ => transformExpression(expr, Some(plan))
+ }
+ }
+
+ /**
+ * Transforms a [[proto.TypedAggregateExpression]], as built by the client's
+ * `Aggregator.toColumn`, into a Catalyst [[AggregateExpression]].
+ *
+ * @param expr
+ * the protobuf expression; it wraps a Scala UDF packet whose function must be an
+ * `Aggregator` with exactly one input encoder
+ * @param baseRelationOpt
+ * the plan whose output feeds this aggregate. When present, the aggregator's input is bound to
+ * that output via `TypedAggUtils.withInputType`; when absent, the typed aggregate is returned
+ * without an input binding
+ * @return
+ * the typed aggregate wrapped in an [[AggregateExpression]]
+ * @throws InvalidPlanInput
+ * if the UDF is not marked as an aggregate, does not carry exactly one input encoder, or its
+ * function is not an `Aggregator`
+ */
+ private def transformTypedAggregateExpression(
+ expr: proto.TypedAggregateExpression,
+ baseRelationOpt: Option[LogicalPlan]): AggregateExpression = {
+ val udf = expr.getScalarScalaUdf
+ // This payload is client-supplied, so malformed input is reported as InvalidPlanInput (as the
+ // rest of the planner does) instead of via `assert`, which can be elided at compile time and
+ // otherwise surfaces as an AssertionError rather than a plan error.
+ if (!udf.getAggregate) {
+ throw InvalidPlanInput("TypedAggregateExpression must wrap an aggregate Scala UDF")
+ }
+
+ val udfPacket = unpackScalaUDF[UdfPacket](udf)
+ if (udfPacket.inputEncoders.size != 1) {
+ throw InvalidPlanInput(
+ s"UDAF should have exactly one input encoder, but found ${udfPacket.inputEncoders.size}")
+ }
+
+ val aggregator = udfPacket.function match {
+ case agg: Aggregator[_, _, _] => agg.asInstanceOf[Aggregator[Any, Any, Any]]
+ case other =>
+ throw InvalidPlanInput(
+ "TypedAggregateExpression expects an Aggregator, but got: " +
+ s"${Option(other).map(_.getClass.getName).orNull}")
+ }
+ val tae =
+ TypedAggregateExpression(aggregator)(aggregator.bufferEncoder, aggregator.outputEncoder)
+ val taeWithInput = baseRelationOpt match {
+ case Some(baseRelation) =>
+ val inputEncoder = TypedScalaUdf.encoderFor(
+ udfPacket.inputEncoders.head,
+ "input",
+ Some(baseRelation.output))
+ TypedAggUtils
+ .withInputType(tae, inputEncoder, baseRelation.output)
+ .asInstanceOf[TypedAggregateExpression]
+ case None =>
+ tae
}
+ taeWithInput.toAggregateExpression()
}
private def transformMergeAction(action: proto.MergeAction): MergeAction = {
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 91c8fb57c31b..3dabcdef1567 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -17,7 +17,11 @@
package org.apache.spark.sql.expressions
-import org.apache.spark.sql.{Encoder, TypedColumn}
+import scala.reflect.runtime.universe._
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.{encoderFor, Encoder, TypedColumn}
+import org.apache.spark.sql.catalyst.ScalaReflection
/**
* A base class for user-defined aggregations, which can be used in `Dataset`
operations to take
@@ -92,9 +96,52 @@ abstract class Aggregator[-IN, BUF, OUT] extends
Serializable {
def outputEncoder: Encoder[OUT]
/**
- * Returns this `Aggregator` as a `TypedColumn` that can be used in
`Dataset`. operations.
+ * Returns this `Aggregator` as a `TypedColumn` that can be used in
`Dataset` operations.
+ * @since 4.0.0
*/
def toColumn: TypedColumn[IN, OUT] = {
- throw new UnsupportedOperationException("toColumn is not implemented.")
+ val inputEncoder = ScalaReflection.encoderFor(getInputTypeTag[IN])
+ // Resolved once: the same output encoder is used by the UDF and by the returned column.
+ val resolvedOutputEncoder = encoderFor(outputEncoder)
+ val udaf =
+ ScalaUserDefinedFunction(
+ this,
+ Seq(inputEncoder),
+ resolvedOutputEncoder,
+ aggregate = true)
+
+ val builder = proto.TypedAggregateExpression.newBuilder()
+ builder.setScalarScalaUdf(udaf.udf)
+ val expr = proto.Expression.newBuilder().setTypedAggregateExpression(builder).build()
+
+ new TypedColumn(expr, resolvedOutputEncoder)
+ }
+
+ /**
+ * Returns a `TypeTag` for this aggregator's input type, i.e. the first type argument of
+ * `Aggregator` as seen from the runtime class of `this`. `IN` is erased at runtime and
+ * `toColumn` takes no implicit `TypeTag`, so the type is recovered by reflection.
+ *
+ * `T` is not constrained by this method: the tag always describes the input type, so callers
+ * must pass `IN` as `T`. The tag is bound to the runtime mirror of this class's class loader
+ * and cannot be migrated to other mirrors.
+ */
+ private final def getInputTypeTag[T]: TypeTag[T] = {
+ val mirror = runtimeMirror(this.getClass.getClassLoader)
+ val tpe = mirror.classSymbol(this.getClass).toType
+ // `baseClasses` is in linearization order (most specific first), so the last entry that
+ // conforms to Aggregator[_, _, _] is `Aggregator` itself. Viewing `tpe` as that base type
+ // yields the concrete type arguments, even through intermediate generic subclasses.
+ val baseAgg =
+ tpe.baseClasses
+ .findLast(_.asClass.toType <:< typeOf[Aggregator[_, _, _]])
+ .getOrElse(throw new IllegalStateException("Could not find the Aggregator base class."))
+ val typeArgs = tpe.baseType(baseAgg).typeArgs
+ assert(
+ typeArgs.length == 3,
+ s"Aggregator should have 3 type arguments, " +
+ s"but found ${typeArgs.length}: ${typeArgs.mkString(", ")}.")
+ val inType = typeArgs.head
+
+ import scala.reflect.api._
+ TypeTag(
+ mirror,
+ new TypeCreator {
+ def apply[U <: Universe with Singleton](m: Mirror[U]): U#Type =
+ if (m eq mirror) {
+ inType.asInstanceOf[U#Type]
+ } else {
+ throw new IllegalArgumentException(
+ s"Type tag defined in $mirror cannot be migrated to other mirrors.")
+ }
+ })
}
}
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index f4499858306a..dcf7f67551d3 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -107,7 +107,7 @@ case class ScalaUserDefinedFunction private[sql] (
aggregate: Boolean)
extends UserDefinedFunction {
- private[this] lazy val udf = {
+ private[expressions] lazy val udf = {
val scalaUdfBuilder = proto.ScalarScalaUDF
.newBuilder()
.setPayload(ByteString.copyFrom(serializedUdfPacket))
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
index 4032a9499c44..4aec0e6348c0 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
@@ -367,17 +367,7 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest
with RemoteSparkSession
test("UDAF custom Aggregator - case class as input types") {
val session: SparkSession = spark
import session.implicits._
- val agg = new Aggregator[UdafTestInput, (Long, Long), Long] {
- override def zero: (Long, Long) = (0L, 0L)
- override def reduce(b: (Long, Long), a: UdafTestInput): (Long, Long) =
- (b._1 + a.id, b._2 + a.extra)
- override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) =
- (b1._1 + b2._1, b1._2 + b2._2)
- override def finish(reduction: (Long, Long)): Long = reduction._1 +
reduction._2
- override def bufferEncoder: Encoder[(Long, Long)] =
- Encoders.tuple(Encoders.scalaLong, Encoders.scalaLong)
- override def outputEncoder: Encoder[Long] = Encoders.scalaLong
- }
+ val agg = new CompleteUdafTestInputAggregator()
spark.udf.register("agg", udaf(agg))
val result = spark
.range(10)
@@ -388,6 +378,66 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest
with RemoteSparkSession
.head()
assert(result == 135) // 45 + 90
}
+
+ test("UDAF custom Aggregator - toColumn") {
+ val session: SparkSession = spark
+ import session.implicits._
+ val aggCol = new CompleteUdafTestInputAggregator().toColumn
+ val ds = spark.range(10).withColumn("extra", col("id") *
2).as[UdafTestInput]
+
+ assert(ds.select(aggCol).head() == 135) // 45 + 90
+ assert(ds.agg(aggCol).head().getLong(0) == 135) // 45 + 90
+ }
+
+ test("UDAF custom Aggregator - multiple extends - toColumn") {
+ val session: SparkSession = spark
+ import session.implicits._
+ val aggCol = new CompleteGrandChildUdafTestInputAggregator().toColumn
+ val ds = spark.range(10).withColumn("extra", col("id") *
2).as[UdafTestInput]
+
+ assert(ds.select(aggCol).head() == 540) // (45 + 90) * 4
+ assert(ds.agg(aggCol).head().getLong(0) == 540) // (45 + 90) * 4
+ }
}
case class UdafTestInput(id: Long, extra: Long)
+
+// An Aggregator that takes [[UdafTestInput]] as input.
+final class CompleteUdafTestInputAggregator
+ extends Aggregator[UdafTestInput, (Long, Long), Long] {
+ override def zero: (Long, Long) = (0L, 0L)
+ override def reduce(b: (Long, Long), a: UdafTestInput): (Long, Long) =
+ (b._1 + a.id, b._2 + a.extra)
+ override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) =
+ (b1._1 + b2._1, b1._2 + b2._2)
+ override def finish(reduction: (Long, Long)): Long = reduction._1 +
reduction._2
+ override def bufferEncoder: Encoder[(Long, Long)] =
+ Encoders.tuple(Encoders.scalaLong, Encoders.scalaLong)
+ override def outputEncoder: Encoder[Long] = Encoders.scalaLong
+}
+
+// Same as [[CompleteUdafTestInputAggregator]] but the input type is not
defined.
+abstract class IncompleteUdafTestInputAggregator[T] extends Aggregator[T,
(Long, Long), Long] {
+ override def zero: (Long, Long) = (0L, 0L)
+ override def reduce(b: (Long, Long), a: T): (Long, Long) // Incomplete!
+ override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) =
+ (b1._1 + b2._1, b1._2 + b2._2)
+ override def finish(reduction: (Long, Long)): Long = reduction._1 +
reduction._2
+ override def bufferEncoder: Encoder[(Long, Long)] =
+ Encoders.tuple(Encoders.scalaLong, Encoders.scalaLong)
+ override def outputEncoder: Encoder[Long] = Encoders.scalaLong
+}
+
+// A layer over [[IncompleteUdafTestInputAggregator]] but the input type is
still not defined.
+abstract class IncompleteChildUdafTestInputAggregator[T]
+ extends IncompleteUdafTestInputAggregator[T] {
+ override def finish(reduction: (Long, Long)): Long = (reduction._1 +
reduction._2) * 2
+}
+
+// Another layer that finally defines the input type.
+final class CompleteGrandChildUdafTestInputAggregator
+ extends IncompleteChildUdafTestInputAggregator[UdafTestInput] {
+ override def reduce(b: (Long, Long), a: UdafTestInput): (Long, Long) =
+ (b._1 + a.id, b._2 + a.extra)
+ override def finish(reduction: (Long, Long)): Long = (reduction._1 +
reduction._2) * 4
+}
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py
b/python/pyspark/sql/connect/proto/expressions_pb2.py
index b4c4b48de268..1c1ad2b6ecec 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import common_pb2 as
spark_dot_connect_dot_common
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto\x1a\x1aspark/connect/common.proto"\xd8/\n\nExpression\x12\x37\n\x06\x63ommon\x18\x12
\x01(\x0b\x32\x1f.spark.connect.ExpressionCommonR\x06\x63ommon\x12=\n\x07literal\x18\x01
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAtt
[...]
+
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto\x1a\x1aspark/connect/common.proto"\xc1\x30\n\nExpression\x12\x37\n\x06\x63ommon\x18\x12
\x01(\x0b\x32\x1f.spark.connect.ExpressionCommonR\x06\x63ommon\x12=\n\x07literal\x18\x01
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolved
[...]
)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -47,75 +47,77 @@ if _descriptor._USE_C_DESCRIPTORS == False:
b"\n\036org.apache.spark.connect.protoP\001Z\022internal/generated"
)
_EXPRESSION._serialized_start = 133
- _EXPRESSION._serialized_end = 6237
- _EXPRESSION_WINDOW._serialized_start = 1795
- _EXPRESSION_WINDOW._serialized_end = 2578
- _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 2085
- _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 2578
- _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 2352
- _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 2497
- _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 2499
- _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 2578
- _EXPRESSION_SORTORDER._serialized_start = 2581
- _EXPRESSION_SORTORDER._serialized_end = 3006
- _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2811
- _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2919
- _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2921
- _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 3006
- _EXPRESSION_CAST._serialized_start = 3009
- _EXPRESSION_CAST._serialized_end = 3324
- _EXPRESSION_CAST_EVALMODE._serialized_start = 3210
- _EXPRESSION_CAST_EVALMODE._serialized_end = 3308
- _EXPRESSION_LITERAL._serialized_start = 3327
- _EXPRESSION_LITERAL._serialized_end = 4890
- _EXPRESSION_LITERAL_DECIMAL._serialized_start = 4162
- _EXPRESSION_LITERAL_DECIMAL._serialized_end = 4279
- _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 4281
- _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 4379
- _EXPRESSION_LITERAL_ARRAY._serialized_start = 4382
- _EXPRESSION_LITERAL_ARRAY._serialized_end = 4512
- _EXPRESSION_LITERAL_MAP._serialized_start = 4515
- _EXPRESSION_LITERAL_MAP._serialized_end = 4742
- _EXPRESSION_LITERAL_STRUCT._serialized_start = 4745
- _EXPRESSION_LITERAL_STRUCT._serialized_end = 4874
- _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4893
- _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 5079
- _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 5082
- _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 5286
- _EXPRESSION_EXPRESSIONSTRING._serialized_start = 5288
- _EXPRESSION_EXPRESSIONSTRING._serialized_end = 5338
- _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 5340
- _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 5464
- _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 5466
- _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5552
- _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5555
- _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5687
- _EXPRESSION_UPDATEFIELDS._serialized_start = 5690
- _EXPRESSION_UPDATEFIELDS._serialized_end = 5877
- _EXPRESSION_ALIAS._serialized_start = 5879
- _EXPRESSION_ALIAS._serialized_end = 5999
- _EXPRESSION_LAMBDAFUNCTION._serialized_start = 6002
- _EXPRESSION_LAMBDAFUNCTION._serialized_end = 6160
- _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 6162
- _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 6224
- _EXPRESSIONCOMMON._serialized_start = 6239
- _EXPRESSIONCOMMON._serialized_end = 6304
- _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 6307
- _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6671
- _PYTHONUDF._serialized_start = 6674
- _PYTHONUDF._serialized_end = 6878
- _SCALARSCALAUDF._serialized_start = 6881
- _SCALARSCALAUDF._serialized_end = 7095
- _JAVAUDF._serialized_start = 7098
- _JAVAUDF._serialized_end = 7247
- _CALLFUNCTION._serialized_start = 7249
- _CALLFUNCTION._serialized_end = 7357
- _NAMEDARGUMENTEXPRESSION._serialized_start = 7359
- _NAMEDARGUMENTEXPRESSION._serialized_end = 7451
- _MERGEACTION._serialized_start = 7454
- _MERGEACTION._serialized_end = 7966
- _MERGEACTION_ASSIGNMENT._serialized_start = 7676
- _MERGEACTION_ASSIGNMENT._serialized_end = 7782
- _MERGEACTION_ACTIONTYPE._serialized_start = 7785
- _MERGEACTION_ACTIONTYPE._serialized_end = 7952
+ _EXPRESSION._serialized_end = 6342
+ _EXPRESSION_WINDOW._serialized_start = 1900
+ _EXPRESSION_WINDOW._serialized_end = 2683
+ _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 2190
+ _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 2683
+ _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 2457
+ _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 2602
+ _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 2604
+ _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 2683
+ _EXPRESSION_SORTORDER._serialized_start = 2686
+ _EXPRESSION_SORTORDER._serialized_end = 3111
+ _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2916
+ _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 3024
+ _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 3026
+ _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 3111
+ _EXPRESSION_CAST._serialized_start = 3114
+ _EXPRESSION_CAST._serialized_end = 3429
+ _EXPRESSION_CAST_EVALMODE._serialized_start = 3315
+ _EXPRESSION_CAST_EVALMODE._serialized_end = 3413
+ _EXPRESSION_LITERAL._serialized_start = 3432
+ _EXPRESSION_LITERAL._serialized_end = 4995
+ _EXPRESSION_LITERAL_DECIMAL._serialized_start = 4267
+ _EXPRESSION_LITERAL_DECIMAL._serialized_end = 4384
+ _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 4386
+ _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 4484
+ _EXPRESSION_LITERAL_ARRAY._serialized_start = 4487
+ _EXPRESSION_LITERAL_ARRAY._serialized_end = 4617
+ _EXPRESSION_LITERAL_MAP._serialized_start = 4620
+ _EXPRESSION_LITERAL_MAP._serialized_end = 4847
+ _EXPRESSION_LITERAL_STRUCT._serialized_start = 4850
+ _EXPRESSION_LITERAL_STRUCT._serialized_end = 4979
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4998
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 5184
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 5187
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 5391
+ _EXPRESSION_EXPRESSIONSTRING._serialized_start = 5393
+ _EXPRESSION_EXPRESSIONSTRING._serialized_end = 5443
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 5445
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 5569
+ _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 5571
+ _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5657
+ _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5660
+ _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5792
+ _EXPRESSION_UPDATEFIELDS._serialized_start = 5795
+ _EXPRESSION_UPDATEFIELDS._serialized_end = 5982
+ _EXPRESSION_ALIAS._serialized_start = 5984
+ _EXPRESSION_ALIAS._serialized_end = 6104
+ _EXPRESSION_LAMBDAFUNCTION._serialized_start = 6107
+ _EXPRESSION_LAMBDAFUNCTION._serialized_end = 6265
+ _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 6267
+ _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 6329
+ _EXPRESSIONCOMMON._serialized_start = 6344
+ _EXPRESSIONCOMMON._serialized_end = 6409
+ _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 6412
+ _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6776
+ _PYTHONUDF._serialized_start = 6779
+ _PYTHONUDF._serialized_end = 6983
+ _SCALARSCALAUDF._serialized_start = 6986
+ _SCALARSCALAUDF._serialized_end = 7200
+ _JAVAUDF._serialized_start = 7203
+ _JAVAUDF._serialized_end = 7352
+ _TYPEDAGGREGATEEXPRESSION._serialized_start = 7354
+ _TYPEDAGGREGATEEXPRESSION._serialized_end = 7453
+ _CALLFUNCTION._serialized_start = 7455
+ _CALLFUNCTION._serialized_end = 7563
+ _NAMEDARGUMENTEXPRESSION._serialized_start = 7565
+ _NAMEDARGUMENTEXPRESSION._serialized_end = 7657
+ _MERGEACTION._serialized_start = 7660
+ _MERGEACTION._serialized_end = 8172
+ _MERGEACTION_ASSIGNMENT._serialized_start = 7882
+ _MERGEACTION_ASSIGNMENT._serialized_end = 7988
+ _MERGEACTION_ACTIONTYPE._serialized_start = 7991
+ _MERGEACTION_ACTIONTYPE._serialized_end = 8158
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 2c80be6c8fb5..1566eb1b1e9e 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -1183,6 +1183,7 @@ class Expression(google.protobuf.message.Message):
CALL_FUNCTION_FIELD_NUMBER: builtins.int
NAMED_ARGUMENT_EXPRESSION_FIELD_NUMBER: builtins.int
MERGE_ACTION_FIELD_NUMBER: builtins.int
+ TYPED_AGGREGATE_EXPRESSION_FIELD_NUMBER: builtins.int
EXTENSION_FIELD_NUMBER: builtins.int
@property
def common(self) -> global___ExpressionCommon: ...
@@ -1225,6 +1226,8 @@ class Expression(google.protobuf.message.Message):
@property
def merge_action(self) -> global___MergeAction: ...
@property
+ def typed_aggregate_expression(self) -> global___TypedAggregateExpression:
...
+ @property
def extension(self) -> google.protobuf.any_pb2.Any:
"""This field is used to mark extensions to the protocol. When plugins
generate arbitrary
relations they can add them here. During the planning the correct
resolution is done.
@@ -1252,6 +1255,7 @@ class Expression(google.protobuf.message.Message):
call_function: global___CallFunction | None = ...,
named_argument_expression: global___NamedArgumentExpression | None =
...,
merge_action: global___MergeAction | None = ...,
+ typed_aggregate_expression: global___TypedAggregateExpression | None =
...,
extension: google.protobuf.any_pb2.Any | None = ...,
) -> None: ...
def HasField(
@@ -1283,6 +1287,8 @@ class Expression(google.protobuf.message.Message):
b"named_argument_expression",
"sort_order",
b"sort_order",
+ "typed_aggregate_expression",
+ b"typed_aggregate_expression",
"unresolved_attribute",
b"unresolved_attribute",
"unresolved_extract_value",
@@ -1330,6 +1336,8 @@ class Expression(google.protobuf.message.Message):
b"named_argument_expression",
"sort_order",
b"sort_order",
+ "typed_aggregate_expression",
+ b"typed_aggregate_expression",
"unresolved_attribute",
b"unresolved_attribute",
"unresolved_extract_value",
@@ -1370,6 +1378,7 @@ class Expression(google.protobuf.message.Message):
"call_function",
"named_argument_expression",
"merge_action",
+ "typed_aggregate_expression",
"extension",
]
| None
@@ -1620,6 +1629,27 @@ class JavaUDF(google.protobuf.message.Message):
global___JavaUDF = JavaUDF
+class TypedAggregateExpression(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ SCALAR_SCALA_UDF_FIELD_NUMBER: builtins.int
+ @property
+ def scalar_scala_udf(self) -> global___ScalarScalaUDF:
+ """(Required) The aggregate function object packed into bytes."""
+ def __init__(
+ self,
+ *,
+ scalar_scala_udf: global___ScalarScalaUDF | None = ...,
+ ) -> None: ...
+ def HasField(
+ self, field_name: typing_extensions.Literal["scalar_scala_udf",
b"scalar_scala_udf"]
+ ) -> builtins.bool: ...
+ def ClearField(
+ self, field_name: typing_extensions.Literal["scalar_scala_udf",
b"scalar_scala_udf"]
+ ) -> None: ...
+
+global___TypedAggregateExpression = TypedAggregateExpression
+
class CallFunction(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]