This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 8b209d188f6c [SPARK-50124][SQL][FOLLOWUP] InsertSortForLimitAndOffset should propagate missing ordering columns
8b209d188f6c is described below
commit 8b209d188f6c3cb8ca59d41625c49f9a10bb16f2
Author: Wenchen Fan <[email protected]>
AuthorDate: Fri Jan 10 16:23:01 2025 +0800
[SPARK-50124][SQL][FOLLOWUP] InsertSortForLimitAndOffset should propagate missing ordering columns
### What changes were proposed in this pull request?
This is a follow-up of https://github.com/apache/spark/pull/48661 that extends
the fix to handle a common query pattern: a project/filter applied after the sort.
The implementation is a bit more involved, as we now need to propagate the
ordering columns pre-shuffle so that we can perform the local sort post-shuffle.
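For context, a rough sketch of the newly covered pattern (this mirrors the new test cases below; it assumes a `SparkSession` named `spark` with its implicits in scope, and the column names are only illustrative):
```scala
import spark.implicits._
import org.apache.spark.sql.functions.rand

// A globally sorted DataFrame, followed by a filter/project that drops the
// ordering column, then a middle LIMIT. The trailing distinct() adds a shuffle
// so that the limit is not planned as CollectLimitExec.
val df = 1.to(10).map(v => v -> v).toDF("c1", "c2")
  .orderBy($"c1" % 8)
  .filter($"c2" > rand())
  .select($"c2")      // c1, the ordering column, is no longer in the output
  .limit(2)
  .distinct()

// With this follow-up, c1 is propagated through the pre-shuffle ProjectExec,
// a local SortExec is inserted after the single-partition shuffle, and a final
// ProjectExec restores the original output schema.
df.collect()
```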
### Why are the changes needed?
This extends the bug fix to cover more query patterns.
### Does this PR introduce _any_ user-facing change?
No, the fix has not been released yet.
### How was this patch tested?
New tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49416 from cloud-fan/fix.
Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../execution/InsertSortForLimitAndOffset.scala | 81 ++++++++++++++--------
.../InsertSortForLimitAndOffsetSuite.scala | 47 ++++++++++---
2 files changed, 90 insertions(+), 38 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/InsertSortForLimitAndOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/InsertSortForLimitAndOffset.scala
index fa63e04d91b0..aa29128cda7e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/InsertSortForLimitAndOffset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/InsertSortForLimitAndOffset.scala
@@ -17,13 +17,12 @@
package org.apache.spark.sql.execution
-import scala.annotation.tailrec
-
import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.catalyst.plans.logical.{Project, Sort}
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.python.EvalPythonExec
import org.apache.spark.sql.internal.SQLConf
/**
@@ -43,33 +42,61 @@ object InsertSortForLimitAndOffset extends Rule[SparkPlan] {
plan transform {
case l @ GlobalLimitExec(
_,
- SinglePartitionShuffleWithGlobalOrdering(ordering),
- _) =>
- val newChild = SortExec(ordering, global = false, child = l.child)
- l.withNewChildren(Seq(newChild))
- }
- }
-
- object SinglePartitionShuffleWithGlobalOrdering {
- @tailrec
- def unapply(plan: SparkPlan): Option[Seq[SortOrder]] = plan match {
-      case ShuffleExchangeExec(SinglePartition, SparkPlanWithGlobalOrdering(ordering), _, _) =>
- Some(ordering)
- case p: AQEShuffleReadExec => unapply(p.child)
- case p: ShuffleQueryStageExec => unapply(p.plan)
- case _ => None
+          // Should not match AQE shuffle stage because we only target un-submitted stages,
+          // for which we can still rewrite the query plan.
+ s @ ShuffleExchangeExec(SinglePartition, child, _, _),
+ _) if child.logicalLink.isDefined =>
+ extractOrderingAndPropagateOrderingColumns(child) match {
+ case Some((ordering, newChild)) =>
+ val newShuffle = s.withNewChildren(Seq(newChild))
+ val sorted = SortExec(ordering, global = false, child = newShuffle)
+            // We must set the logical plan link to avoid losing the added SortExec and ProjectExec
+            // during AQE re-optimization, where we turn physical plan back to logical plan.
+            val logicalSort = Sort(ordering, global = false, child = s.child.logicalLink.get)
+ sorted.setLogicalLink(logicalSort)
+ val projected = if (sorted.output == s.output) {
+ sorted
+ } else {
+ val p = ProjectExec(s.output, sorted)
+ p.setLogicalLink(Project(s.output, logicalSort))
+ p
+ }
+            l.withNewChildren(Seq(projected))
+          case _ => l
+        }
}
}
  // Note: this is not implementing a generalized notion of "global order preservation", but just
- // tackles the regular ORDER BY semantics with optional LIMIT (top-K).
- object SparkPlanWithGlobalOrdering {
- @tailrec
- def unapply(plan: SparkPlan): Option[Seq[SortOrder]] = plan match {
- case p: SortExec if p.global => Some(p.sortOrder)
- case p: LocalLimitExec => unapply(p.child)
- case p: WholeStageCodegenExec => unapply(p.child)
- case _ => None
- }
+  // a best effort to catch the common query patterns where the data ordering should be preserved.
+ private def extractOrderingAndPropagateOrderingColumns(
+ plan: SparkPlan): Option[(Seq[SortOrder], SparkPlan)] = plan match {
+ case p: SortExec if p.global => Some(p.sortOrder, p)
+ case p: UnaryExecNode if
+ p.isInstanceOf[LocalLimitExec] ||
+ p.isInstanceOf[WholeStageCodegenExec] ||
+ p.isInstanceOf[FilterExec] ||
+ p.isInstanceOf[EvalPythonExec] =>
+ extractOrderingAndPropagateOrderingColumns(p.child) match {
+        case Some((ordering, newChild)) => Some((ordering, p.withNewChildren(Seq(newChild))))
+ case _ => None
+ }
+ case p: ProjectExec =>
+ extractOrderingAndPropagateOrderingColumns(p.child) match {
+ case Some((ordering, newChild)) =>
+ val orderingCols = ordering.flatMap(_.references)
+ if (orderingCols.forall(p.outputSet.contains)) {
+ Some((ordering, p.withNewChildren(Seq(newChild))))
+ } else {
+            // In order to do the sort after shuffle, we must propagate the ordering columns in the
+ // pre-shuffle ProjectExec.
+ val missingCols = orderingCols.filterNot(p.outputSet.contains)
+            val newProj = p.copy(projectList = p.projectList ++ missingCols, child = newChild)
+ newProj.copyTagsFrom(p)
+ Some((ordering, newProj))
+ }
+ case _ => None
+ }
+ case _ => None
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/InsertSortForLimitAndOffsetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/InsertSortForLimitAndOffsetSuite.scala
index 8d640a1840f4..d1b11a74cf35 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/InsertSortForLimitAndOffsetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/InsertSortForLimitAndOffsetSuite.scala
@@ -17,10 +17,13 @@
package org.apache.spark.sql.execution
-import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.{Dataset, QueryTest}
+import org.apache.spark.sql.IntegratedUDFTestUtils._
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.functions.rand
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.IntegerType
class InsertSortForLimitAndOffsetSuite extends QueryTest
with SharedSparkSession
@@ -51,6 +54,7 @@ class InsertSortForLimitAndOffsetSuite extends QueryTest
private def hasLocalSort(plan: SparkPlan): Boolean = {
find(plan) {
case GlobalLimitExec(_, s: SortExec, _) => !s.global
+ case GlobalLimitExec(_, ProjectExec(_, s: SortExec), _) => !s.global
case _ => false
}.isDefined
}
@@ -91,12 +95,16 @@ class InsertSortForLimitAndOffsetSuite extends QueryTest
      // one partition to read the range-partition shuffle and there is only one shuffle block for
      // the final single-partition shuffle, random fetch order is no longer an issue.
SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "false") {
- val df = spark.range(10).orderBy($"id" % 8).limit(2).distinct()
- df.collect()
- val physicalPlan = df.queryExecution.executedPlan
- assertHasGlobalLimitExec(physicalPlan)
- // Extra local sort is needed for middle LIMIT
- assert(hasLocalSort(physicalPlan))
+ val df = 1.to(10).map(v => v -> v).toDF("c1", "c2").orderBy($"c1" % 8)
+ verifySortAdded(df.limit(2))
+ verifySortAdded(df.filter($"c2" > rand()).limit(2))
+ verifySortAdded(df.select($"c2").limit(2))
+ verifySortAdded(df.filter($"c2" > rand()).select($"c2").limit(2))
+
+ assume(shouldTestPythonUDFs)
+ val pythonTestUDF = TestPythonUDF(name = "pyUDF", Some(IntegerType))
+ verifySortAdded(df.filter(pythonTestUDF($"c2") > rand()).limit(2))
+ verifySortAdded(df.select(pythonTestUDF($"c2")).limit(2))
}
}
@@ -110,11 +118,28 @@ class InsertSortForLimitAndOffsetSuite extends QueryTest
}
test("middle OFFSET preserves data ordering with the extra sort") {
- val df = spark.range(10).orderBy($"id" % 8).offset(2).distinct()
- df.collect()
- val physicalPlan = df.queryExecution.executedPlan
+ val df = 1.to(10).map(v => v -> v).toDF("c1", "c2").orderBy($"c1" % 8)
+ verifySortAdded(df.offset(2))
+ verifySortAdded(df.filter($"c2" > rand()).offset(2))
+ verifySortAdded(df.select($"c2").offset(2))
+ verifySortAdded(df.filter($"c2" > rand()).select($"c2").offset(2))
+
+ assume(shouldTestPythonUDFs)
+ val pythonTestUDF = TestPythonUDF(name = "pyUDF", Some(IntegerType))
+ verifySortAdded(df.filter(pythonTestUDF($"c2") > rand()).offset(2))
+ verifySortAdded(df.select(pythonTestUDF($"c2")).offset(2))
+ }
+
+ private def verifySortAdded(df: Dataset[_]): Unit = {
+    // Do distinct to trigger a shuffle, so that the LIMIT/OFFSET below won't be planned as
+ // `CollectLimitExec`
+ val shuffled = df.distinct()
+ shuffled.collect()
+ val physicalPlan = shuffled.queryExecution.executedPlan
assertHasGlobalLimitExec(physicalPlan)
- // Extra local sort is needed for middle OFFSET
+ // Extra local sort is needed for middle LIMIT/OFFSET
assert(hasLocalSort(physicalPlan))
+ // Make sure the schema does not change.
+ assert(physicalPlan.schema == shuffled.schema)
}
}