This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9ada1e7e68db [SPARK-51820][FOLLOWUP][CONNECT] Replace literal in 
`SortOrder` only under `Sort` operator
9ada1e7e68db is described below

commit 9ada1e7e68db2c1a79c528d1ff6482f58cd6ccb7
Author: Mihailo Timotic <mihailo.timo...@databricks.com>
AuthorDate: Tue Sep 2 21:05:30 2025 +0800

    [SPARK-51820][FOLLOWUP][CONNECT] Replace literal in `SortOrder` only under 
`Sort` operator
    
    ### What changes were proposed in this pull request?
    Replace literal in `SortOrder` only under `Sort` operator
    
    ### Why are the changes needed?
    SPARK-51820 introduced a bug where literals under all `SortOrder` 
expressions were treated as ordinals, breaking window expressions in Spark 
Connect. This PR fixes that.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added a test case
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #52189 from mihailotim-db/mihailotim-db/fix_window_ordinal.
    
    Authored-by: Mihailo Timotic <mihailo.timo...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/connect/planner/SparkConnectPlanner.scala  | 23 +++++---
 .../connect/planner/SparkConnectPlannerSuite.scala | 63 ++++++++++++++++++++++
 2 files changed, 80 insertions(+), 6 deletions(-)

diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index c0b1fd01616a..a394c5cb80fe 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1752,7 +1752,8 @@ class SparkConnectPlanner(
         transformUnresolvedExtractValue(exp.getUnresolvedExtractValue)
       case proto.Expression.ExprTypeCase.UPDATE_FIELDS =>
         transformUpdateFields(exp.getUpdateFields)
-      case proto.Expression.ExprTypeCase.SORT_ORDER => 
transformSortOrder(exp.getSortOrder)
+      case proto.Expression.ExprTypeCase.SORT_ORDER =>
+        transformSortOrder(order = exp.getSortOrder, shouldReplaceOrdinals = 
false)
       case proto.Expression.ExprTypeCase.LAMBDA_FUNCTION =>
         transformLambdaFunction(exp.getLambdaFunction)
       case proto.Expression.ExprTypeCase.UNRESOLVED_NAMED_LAMBDA_VARIABLE =>
@@ -2230,7 +2231,8 @@ class SparkConnectPlanner(
 
     val windowSpec = WindowSpecDefinition(
       partitionSpec = 
window.getPartitionSpecList.asScala.toSeq.map(transformExpression),
-      orderSpec = 
window.getOrderSpecList.asScala.toSeq.map(transformSortOrder),
+      orderSpec = window.getOrderSpecList.asScala.toSeq.map(orderSpec =>
+        transformSortOrder(order = orderSpec, shouldReplaceOrdinals = false)),
       frameSpecification = frameSpec)
 
     WindowExpression(
@@ -2382,12 +2384,20 @@ class SparkConnectPlanner(
     logical.Sort(
       child = transformRelation(sort.getInput),
       global = sort.getIsGlobal,
-      order = sort.getOrderList.asScala.toSeq.map(transformSortOrder))
+      order = sort.getOrderList.asScala.toSeq.map(order =>
+        transformSortOrder(order = order, shouldReplaceOrdinals = true)))
   }
 
-  private def transformSortOrder(order: proto.Expression.SortOrder) = {
+  private def transformSortOrder(
+      order: proto.Expression.SortOrder,
+      shouldReplaceOrdinals: Boolean = false) = {
+    val childWithReplacedOrdinals = if (shouldReplaceOrdinals) {
+      transformSortOrderAndReplaceOrdinals(order.getChild)
+    } else {
+      transformExpression(order.getChild)
+    }
     expressions.SortOrder(
-      child = transformSortOrderAndReplaceOrdinals(order.getChild),
+      child = childWithReplacedOrdinals,
       direction = order.getDirection match {
         case proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_ASCENDING 
=>
           expressions.Ascending
@@ -4081,7 +4091,8 @@ class SparkConnectPlanner(
               .map(transformExpression)
               .toSeq,
             orderSpec = options.getOrderSpecList.asScala
-              .map(transformSortOrder)
+              .map(orderSpec =>
+                transformSortOrder(order = orderSpec, shouldReplaceOrdinals = 
false))
               .toSeq,
             withSinglePartition =
               options.hasWithSinglePartition && options.getWithSinglePartition)
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 16cdd7da8279..66ff45b553c7 100644
--- 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -962,6 +962,69 @@ class SparkConnectPlannerSuite extends SparkFunSuite with 
SparkConnectPlanTest {
       !aggregateExpression.containsPattern(TreePattern.UNRESOLVED_ORDINAL)))
   }
 
+  test("SPARK-51820 Literals in SortOrder should only be replaced under Sort 
node") {
+    val schema = StructType(Seq(StructField("col1", IntegerType)))
+    val data = Seq(InternalRow(1))
+    val inputRows = data.map { row =>
+      val proj = UnsafeProjection.create(schema)
+      proj(row).copy()
+    }
+    val localRelation = createLocalRelationProto(schema, inputRows)
+
+    val sumFunction = proto.Expression
+      .newBuilder()
+      .setUnresolvedFunction(
+        proto.Expression.UnresolvedFunction
+          .newBuilder()
+          .setFunctionName("sum")
+          .addArguments(
+            proto.Expression
+              .newBuilder()
+              .setUnresolvedAttribute(proto.Expression.UnresolvedAttribute
+                .newBuilder()
+                .setUnparsedIdentifier("col1"))))
+      .build()
+
+    val windowExpression = proto.Expression
+      .newBuilder()
+      .setWindow(
+        proto.Expression.Window
+          .newBuilder()
+          .setWindowFunction(sumFunction)
+          .addOrderSpec(
+            proto.Expression.SortOrder
+              .newBuilder()
+              .setChild(proto.Expression
+                .newBuilder()
+                
.setLiteral(proto.Expression.Literal.newBuilder().setInteger(4)))
+              
.setDirection(proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_ASCENDING)
+              
.setNullOrdering(proto.Expression.SortOrder.NullOrdering.SORT_NULLS_FIRST)))
+      .build()
+
+    val aliasedWindowExpression = proto.Expression
+      .newBuilder()
+      .setAlias(
+        proto.Expression.Alias
+          .newBuilder()
+          .setExpr(windowExpression)
+          .addName("sum_over"))
+      .build()
+
+    val project = proto.Project
+      .newBuilder()
+      .setInput(localRelation)
+      .addExpressions(aliasedWindowExpression)
+      .build()
+
+    val result = 
transform(proto.Relation.newBuilder().setProject(project).build())
+    val df = Dataset.ofRows(spark, result)
+
+    val collected = df.collect()
+    assert(collected.length == 1)
+    assert(df.schema.fields.head.name == "sum_over")
+    assert(collected(0).getAs[Long]("sum_over") == 1L)
+  }
+
   test("Time literal") {
     val project = proto.Project.newBuilder
       .addExpressions(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to