This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 9ada1e7e68db [SPARK-51820][FOLLOWUP][CONNECT] Replace literal in `SortOrder` only under `Sort` operator 9ada1e7e68db is described below commit 9ada1e7e68db2c1a79c528d1ff6482f58cd6ccb7 Author: Mihailo Timotic <mihailo.timo...@databricks.com> AuthorDate: Tue Sep 2 21:05:30 2025 +0800 [SPARK-51820][FOLLOWUP][CONNECT] Replace literal in `SortOrder` only under `Sort` operator ### What changes were proposed in this pull request? Replace literal in `SortOrder` only under `Sort` operator ### Why are the changes needed? SPARK-51820 introduced a bug where literals under all `SortOrder` expressions were treated as ordinals, breaking Windows in Spark Connect. This PR fixes that. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added a test case ### Was this patch authored or co-authored using generative AI tooling? No Closes #52189 from mihailotim-db/mihailotim-db/fix_window_ordinal. Authored-by: Mihailo Timotic <mihailo.timo...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../sql/connect/planner/SparkConnectPlanner.scala | 23 +++++--- .../connect/planner/SparkConnectPlannerSuite.scala | 63 ++++++++++++++++++++++ 2 files changed, 80 insertions(+), 6 deletions(-) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index c0b1fd01616a..a394c5cb80fe 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -1752,7 +1752,8 @@ class SparkConnectPlanner( transformUnresolvedExtractValue(exp.getUnresolvedExtractValue) case proto.Expression.ExprTypeCase.UPDATE_FIELDS => transformUpdateFields(exp.getUpdateFields) - case 
proto.Expression.ExprTypeCase.SORT_ORDER => transformSortOrder(exp.getSortOrder) + case proto.Expression.ExprTypeCase.SORT_ORDER => + transformSortOrder(order = exp.getSortOrder, shouldReplaceOrdinals = false) case proto.Expression.ExprTypeCase.LAMBDA_FUNCTION => transformLambdaFunction(exp.getLambdaFunction) case proto.Expression.ExprTypeCase.UNRESOLVED_NAMED_LAMBDA_VARIABLE => @@ -2230,7 +2231,8 @@ class SparkConnectPlanner( val windowSpec = WindowSpecDefinition( partitionSpec = window.getPartitionSpecList.asScala.toSeq.map(transformExpression), - orderSpec = window.getOrderSpecList.asScala.toSeq.map(transformSortOrder), + orderSpec = window.getOrderSpecList.asScala.toSeq.map(orderSpec => + transformSortOrder(order = orderSpec, shouldReplaceOrdinals = false)), frameSpecification = frameSpec) WindowExpression( @@ -2382,12 +2384,20 @@ class SparkConnectPlanner( logical.Sort( child = transformRelation(sort.getInput), global = sort.getIsGlobal, - order = sort.getOrderList.asScala.toSeq.map(transformSortOrder)) + order = sort.getOrderList.asScala.toSeq.map(order => + transformSortOrder(order = order, shouldReplaceOrdinals = true))) } - private def transformSortOrder(order: proto.Expression.SortOrder) = { + private def transformSortOrder( + order: proto.Expression.SortOrder, + shouldReplaceOrdinals: Boolean = false) = { + val childWithReplacedOrdinals = if (shouldReplaceOrdinals) { + transformSortOrderAndReplaceOrdinals(order.getChild) + } else { + transformExpression(order.getChild) + } expressions.SortOrder( - child = transformSortOrderAndReplaceOrdinals(order.getChild), + child = childWithReplacedOrdinals, direction = order.getDirection match { case proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_ASCENDING => expressions.Ascending @@ -4081,7 +4091,8 @@ class SparkConnectPlanner( .map(transformExpression) .toSeq, orderSpec = options.getOrderSpecList.asScala - .map(transformSortOrder) + .map(orderSpec => + transformSortOrder(order = orderSpec, 
shouldReplaceOrdinals = false)) .toSeq, withSinglePartition = options.hasWithSinglePartition && options.getWithSinglePartition) diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index 16cdd7da8279..66ff45b553c7 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -962,6 +962,69 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { !aggregateExpression.containsPattern(TreePattern.UNRESOLVED_ORDINAL))) } + test("SPARK-51820 Literals in SortOrder should only be replaced under Sort node") { + val schema = StructType(Seq(StructField("col1", IntegerType))) + val data = Seq(InternalRow(1)) + val inputRows = data.map { row => + val proj = UnsafeProjection.create(schema) + proj(row).copy() + } + val localRelation = createLocalRelationProto(schema, inputRows) + + val sumFunction = proto.Expression + .newBuilder() + .setUnresolvedFunction( + proto.Expression.UnresolvedFunction + .newBuilder() + .setFunctionName("sum") + .addArguments( + proto.Expression + .newBuilder() + .setUnresolvedAttribute(proto.Expression.UnresolvedAttribute + .newBuilder() + .setUnparsedIdentifier("col1")))) + .build() + + val windowExpression = proto.Expression + .newBuilder() + .setWindow( + proto.Expression.Window + .newBuilder() + .setWindowFunction(sumFunction) + .addOrderSpec( + proto.Expression.SortOrder + .newBuilder() + .setChild(proto.Expression + .newBuilder() + .setLiteral(proto.Expression.Literal.newBuilder().setInteger(4))) + .setDirection(proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_ASCENDING) + .setNullOrdering(proto.Expression.SortOrder.NullOrdering.SORT_NULLS_FIRST))) + .build() + + val 
aliasedWindowExpression = proto.Expression + .newBuilder() + .setAlias( + proto.Expression.Alias + .newBuilder() + .setExpr(windowExpression) + .addName("sum_over")) + .build() + + val project = proto.Project + .newBuilder() + .setInput(localRelation) + .addExpressions(aliasedWindowExpression) + .build() + + val result = transform(proto.Relation.newBuilder().setProject(project).build()) + val df = Dataset.ofRows(spark, result) + + val collected = df.collect() + assert(collected.length == 1) + assert(df.schema.fields.head.name == "sum_over") + assert(collected(0).getAs[Long]("sum_over") == 1L) + } + test("Time literal") { val project = proto.Project.newBuilder .addExpressions( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org