This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8b209d188f6c [SPARK-50124][SQL][FOLLOWUP] InsertSortForLimitAndOffset should propagate missing ordering columns
8b209d188f6c is described below

commit 8b209d188f6c3cb8ca59d41625c49f9a10bb16f2
Author: Wenchen Fan <[email protected]>
AuthorDate: Fri Jan 10 16:23:01 2025 +0800

    [SPARK-50124][SQL][FOLLOWUP] InsertSortForLimitAndOffset should propagate missing ordering columns
    
    ### What changes were proposed in this pull request?
    
    This is a follow-up of https://github.com/apache/spark/pull/48661 that extends the fix to handle a common query pattern: applying a project/filter after the sort.
    
    The implementation is a bit more involved, as we now need to propagate the ordering columns pre-shuffle so that we can perform the local sort post-shuffle.
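
    For illustration, here is a minimal sketch of the query shape this follow-up targets, mirroring the new tests below (it assumes a SparkSession with `spark.implicits._` and `org.apache.spark.sql.functions.rand` in scope):

    ```scala
    val df = 1.to(10).map(v => v -> v).toDF("c1", "c2")
      .orderBy($"c1" % 8)       // global sort
      .filter($"c2" > rand())   // filter after the sort
      .select($"c2")            // project that drops the ordering column c1
      .limit(2)                 // middle LIMIT
      .distinct()               // extra shuffle, so the LIMIT is not planned as CollectLimitExec
    ```

    The rule must propagate `c1` through the pre-shuffle ProjectExec so that the post-shuffle local sort can still use it, and then project it back out to keep the original output schema.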
    
    ### Why are the changes needed?
    
    Extends the bug fix to cover more query patterns.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, the fix has not been released yet.
    
    ### How was this patch tested?
    
    New tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #49416 from cloud-fan/fix.
    
    Authored-by: Wenchen Fan <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../execution/InsertSortForLimitAndOffset.scala    | 81 ++++++++++++++--------
 .../InsertSortForLimitAndOffsetSuite.scala         | 47 ++++++++++---
 2 files changed, 90 insertions(+), 38 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/InsertSortForLimitAndOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/InsertSortForLimitAndOffset.scala
index fa63e04d91b0..aa29128cda7e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/InsertSortForLimitAndOffset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/InsertSortForLimitAndOffset.scala
@@ -17,13 +17,12 @@
 
 package org.apache.spark.sql.execution
 
-import scala.annotation.tailrec
-
 import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.catalyst.plans.logical.{Project, Sort}
 import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.python.EvalPythonExec
 import org.apache.spark.sql.internal.SQLConf
 
 /**
@@ -43,33 +42,61 @@ object InsertSortForLimitAndOffset extends Rule[SparkPlan] {
     plan transform {
       case l @ GlobalLimitExec(
           _,
-          SinglePartitionShuffleWithGlobalOrdering(ordering),
-          _) =>
-        val newChild = SortExec(ordering, global = false, child = l.child)
-        l.withNewChildren(Seq(newChild))
-    }
-  }
-
-  object SinglePartitionShuffleWithGlobalOrdering {
-    @tailrec
-    def unapply(plan: SparkPlan): Option[Seq[SortOrder]] = plan match {
-      case ShuffleExchangeExec(SinglePartition, SparkPlanWithGlobalOrdering(ordering), _, _) =>
-        Some(ordering)
-      case p: AQEShuffleReadExec => unapply(p.child)
-      case p: ShuffleQueryStageExec => unapply(p.plan)
-      case _ => None
+          // Should not match AQE shuffle stage because we only target un-submitted stages which
+          // we can still rewrite the query plan.
+          s @ ShuffleExchangeExec(SinglePartition, child, _, _),
+          _) if child.logicalLink.isDefined =>
+        extractOrderingAndPropagateOrderingColumns(child) match {
+          case Some((ordering, newChild)) =>
+            val newShuffle = s.withNewChildren(Seq(newChild))
+            val sorted = SortExec(ordering, global = false, child = newShuffle)
+            // We must set the logical plan link to avoid losing the added SortExec and ProjectExec
+            // during AQE re-optimization, where we turn physical plan back to logical plan.
+            val logicalSort = Sort(ordering, global = false, child = s.child.logicalLink.get)
+            sorted.setLogicalLink(logicalSort)
+            val projected = if (sorted.output == s.output) {
+              sorted
+            } else {
+              val p = ProjectExec(s.output, sorted)
+              p.setLogicalLink(Project(s.output, logicalSort))
+              p
+            }
+            l.withNewChildren(Seq(projected))
+          case _ => l
+        }
     }
   }
 
  // Note: this is not implementing a generalized notion of "global order preservation", but just
-  // tackles the regular ORDER BY semantics with optional LIMIT (top-K).
-  object SparkPlanWithGlobalOrdering {
-    @tailrec
-    def unapply(plan: SparkPlan): Option[Seq[SortOrder]] = plan match {
-      case p: SortExec if p.global => Some(p.sortOrder)
-      case p: LocalLimitExec => unapply(p.child)
-      case p: WholeStageCodegenExec => unapply(p.child)
-      case _ => None
-    }
+  // a best effort to catch the common query patterns that the data ordering should be preserved.
+  private def extractOrderingAndPropagateOrderingColumns(
+      plan: SparkPlan): Option[(Seq[SortOrder], SparkPlan)] = plan match {
+    case p: SortExec if p.global => Some(p.sortOrder, p)
+    case p: UnaryExecNode if
+        p.isInstanceOf[LocalLimitExec] ||
+          p.isInstanceOf[WholeStageCodegenExec] ||
+          p.isInstanceOf[FilterExec] ||
+          p.isInstanceOf[EvalPythonExec] =>
+      extractOrderingAndPropagateOrderingColumns(p.child) match {
+        case Some((ordering, newChild)) => Some((ordering, p.withNewChildren(Seq(newChild))))
+        case _ => None
+      }
+    case p: ProjectExec =>
+      extractOrderingAndPropagateOrderingColumns(p.child) match {
+        case Some((ordering, newChild)) =>
+          val orderingCols = ordering.flatMap(_.references)
+          if (orderingCols.forall(p.outputSet.contains)) {
+            Some((ordering, p.withNewChildren(Seq(newChild))))
+          } else {
+            // In order to do the sort after shuffle, we must propagate the ordering columns in the
+            // pre-shuffle ProjectExec.
+            val missingCols = orderingCols.filterNot(p.outputSet.contains)
+            val newProj = p.copy(projectList = p.projectList ++ missingCols, child = newChild)
+            newProj.copyTagsFrom(p)
+            Some((ordering, newProj))
+          }
+        case _ => None
+      }
+    case _ => None
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/InsertSortForLimitAndOffsetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/InsertSortForLimitAndOffsetSuite.scala
index 8d640a1840f4..d1b11a74cf35 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/InsertSortForLimitAndOffsetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/InsertSortForLimitAndOffsetSuite.scala
@@ -17,10 +17,13 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.{Dataset, QueryTest}
+import org.apache.spark.sql.IntegratedUDFTestUtils._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.functions.rand
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.IntegerType
 
 class InsertSortForLimitAndOffsetSuite extends QueryTest
   with SharedSparkSession
@@ -51,6 +54,7 @@ class InsertSortForLimitAndOffsetSuite extends QueryTest
   private def hasLocalSort(plan: SparkPlan): Boolean = {
     find(plan) {
       case GlobalLimitExec(_, s: SortExec, _) => !s.global
+      case GlobalLimitExec(_, ProjectExec(_, s: SortExec), _) => !s.global
       case _ => false
     }.isDefined
   }
@@ -91,12 +95,16 @@ class InsertSortForLimitAndOffsetSuite extends QueryTest
       // one partition to read the range-partition shuffle and there is only one shuffle block for
       // the final single-partition shuffle, random fetch order is no longer an issue.
       SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "false") {
-      val df = spark.range(10).orderBy($"id" % 8).limit(2).distinct()
-      df.collect()
-      val physicalPlan = df.queryExecution.executedPlan
-      assertHasGlobalLimitExec(physicalPlan)
-      // Extra local sort is needed for middle LIMIT
-      assert(hasLocalSort(physicalPlan))
+      val df = 1.to(10).map(v => v -> v).toDF("c1", "c2").orderBy($"c1" % 8)
+      verifySortAdded(df.limit(2))
+      verifySortAdded(df.filter($"c2" > rand()).limit(2))
+      verifySortAdded(df.select($"c2").limit(2))
+      verifySortAdded(df.filter($"c2" > rand()).select($"c2").limit(2))
+
+      assume(shouldTestPythonUDFs)
+      val pythonTestUDF = TestPythonUDF(name = "pyUDF", Some(IntegerType))
+      verifySortAdded(df.filter(pythonTestUDF($"c2") > rand()).limit(2))
+      verifySortAdded(df.select(pythonTestUDF($"c2")).limit(2))
     }
   }
 
@@ -110,11 +118,28 @@ class InsertSortForLimitAndOffsetSuite extends QueryTest
   }
 
   test("middle OFFSET preserves data ordering with the extra sort") {
-    val df = spark.range(10).orderBy($"id" % 8).offset(2).distinct()
-    df.collect()
-    val physicalPlan = df.queryExecution.executedPlan
+    val df = 1.to(10).map(v => v -> v).toDF("c1", "c2").orderBy($"c1" % 8)
+    verifySortAdded(df.offset(2))
+    verifySortAdded(df.filter($"c2" > rand()).offset(2))
+    verifySortAdded(df.select($"c2").offset(2))
+    verifySortAdded(df.filter($"c2" > rand()).select($"c2").offset(2))
+
+    assume(shouldTestPythonUDFs)
+    val pythonTestUDF = TestPythonUDF(name = "pyUDF", Some(IntegerType))
+    verifySortAdded(df.filter(pythonTestUDF($"c2") > rand()).offset(2))
+    verifySortAdded(df.select(pythonTestUDF($"c2")).offset(2))
+  }
+
+  private def verifySortAdded(df: Dataset[_]): Unit = {
+    // Do distinct to trigger a shuffle, so that the LIMIT/OFFSET below won't be planned as
+    // `CollectLimitExec`
+    val shuffled = df.distinct()
+    shuffled.collect()
+    val physicalPlan = shuffled.queryExecution.executedPlan
     assertHasGlobalLimitExec(physicalPlan)
-    // Extra local sort is needed for middle OFFSET
+    // Extra local sort is needed for middle LIMIT/OFFSET
     assert(hasLocalSort(physicalPlan))
+    // Make sure the schema does not change.
+    assert(physicalPlan.schema == shuffled.schema)
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
