This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 0e94f340a63a [SPARK-46378][SQL][FOLLOWUP] Do not rely on TreeNodeTag
in Project
0e94f340a63a is described below
commit 0e94f340a63af07f1b105c61e3f884993ee305e6
Author: Wenchen Fan <[email protected]>
AuthorDate: Thu Dec 21 16:23:25 2023 +0800
[SPARK-46378][SQL][FOLLOWUP] Do not rely on TreeNodeTag in Project
### What changes were proposed in this pull request?
This is a followup of https://github.com/apache/spark/pull/44310 . It turns
out that relying on a `TreeNodeTag` in `Project` is too fragile: `Project` is a very
basic node that is easily removed or transformed during plan optimization.
This PR switches to a different approach: since we can't retain the
information (input data order doesn't matter) from `Aggregate`, let's leverage
this information immediately. We pull out the expensive part of
`EliminateSorts` into a new rule, `RemoveRedundantSorts`, so that we can safely call `EliminateSorts`
right before we turn `Aggregate` into `Project`.
### Why are the changes needed?
To make the optimizer more robust.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
existing tests
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #44429 from cloud-fan/sort.
Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/catalyst/optimizer/Optimizer.scala | 35 ++++--------
.../catalyst/optimizer/RemoveRedundantSorts.scala | 62 ++++++++++++++++++++++
.../plans/logical/basicLogicalOperators.scala | 3 --
.../catalyst/optimizer/EliminateSortsSuite.scala | 3 +-
.../datasources/V1WriteCommandSuite.scala | 54 ++++++++++++-------
5 files changed, 111 insertions(+), 46 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 5a19c5e3c241..1a831b958ef2 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -211,7 +211,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
Batch("Join Reorder", FixedPoint(1),
CostBasedJoinReorder) :+
Batch("Eliminate Sorts", Once,
- EliminateSorts) :+
+ EliminateSorts,
+ RemoveRedundantSorts) :+
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates) :+
// This batch must run after "Decimal Optimizations", as that one may
change the
@@ -771,11 +772,11 @@ object LimitPushDown extends Rule[LogicalPlan] {
LocalLimit(exp, project.copy(child = pushLocalLimitThroughJoin(exp,
join)))
// Push down limit 1 through Aggregate and turn Aggregate into Project if
it is group only.
case Limit(le @ IntegerLiteral(1), a: Aggregate) if a.groupOnly =>
- val project = Project(a.aggregateExpressions, LocalLimit(le, a.child))
- project.setTagValue(Project.dataOrderIrrelevantTag, ())
- Limit(le, project)
+ val newAgg = EliminateSorts(a.copy(child = LocalLimit(le,
a.child))).asInstanceOf[Aggregate]
+ Limit(le, Project(newAgg.aggregateExpressions, newAgg.child))
case Limit(le @ IntegerLiteral(1), p @ Project(_, a: Aggregate)) if
a.groupOnly =>
- Limit(le, p.copy(child = Project(a.aggregateExpressions, LocalLimit(le,
a.child))))
+ val newAgg = EliminateSorts(a.copy(child = LocalLimit(le,
a.child))).asInstanceOf[Aggregate]
+ Limit(le, p.copy(child = Project(newAgg.aggregateExpressions,
newAgg.child)))
// Merge offset value and limit value into LocalLimit and pushes down
LocalLimit through Offset.
case LocalLimit(le, Offset(oe, grandChild)) =>
Offset(oe, LocalLimit(Add(le, oe), grandChild))
@@ -1557,38 +1558,30 @@ object CombineFilters extends Rule[LogicalPlan] with
PredicateHelper {
* Note that changes in the final output ordering may affect the file size
(SPARK-32318).
* This rule handles the following cases:
* 1) if the sort order is empty or the sort order does not have any reference
- * 2) if the Sort operator is a local sort and the child is already sorted
- * 3) if there is another Sort operator separated by 0...n Project, Filter,
Repartition or
+ * 2) if there is another Sort operator separated by 0...n Project, Filter,
Repartition or
* RepartitionByExpression, RebalancePartitions (with deterministic
expressions) operators
- * 4) if the Sort operator is within Join separated by 0...n Project, Filter,
Repartition or
+ * 3) if the Sort operator is within Join separated by 0...n Project, Filter,
Repartition or
* RepartitionByExpression, RebalancePartitions (with deterministic
expressions) operators only
* and the Join condition is deterministic
- * 5) if the Sort operator is within GroupBy separated by 0...n Project,
Filter, Repartition or
+ * 4) if the Sort operator is within GroupBy separated by 0...n Project,
Filter, Repartition or
* RepartitionByExpression, RebalancePartitions (with deterministic
expressions) operators only
* and the aggregate function is order irrelevant
*/
object EliminateSorts extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
- _.containsPattern(SORT))(applyLocally)
-
- private val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
+ def apply(plan: LogicalPlan): LogicalPlan =
plan.transformUpWithPruning(_.containsPattern(SORT)) {
case s @ Sort(orders, _, child) if orders.isEmpty ||
orders.exists(_.child.foldable) =>
val newOrders = orders.filterNot(_.child.foldable)
if (newOrders.isEmpty) {
- applyLocally.lift(child).getOrElse(child)
+ child
} else {
s.copy(order = newOrders)
}
- case Sort(orders, false, child) if
SortOrder.orderingSatisfies(child.outputOrdering, orders) =>
- applyLocally.lift(child).getOrElse(child)
case s @ Sort(_, global, child) => s.copy(child =
recursiveRemoveSort(child, global))
case j @ Join(originLeft, originRight, _, cond, _) if
cond.forall(_.deterministic) =>
j.copy(left = recursiveRemoveSort(originLeft, true),
right = recursiveRemoveSort(originRight, true))
case g @ Aggregate(_, aggs, originChild) if isOrderIrrelevantAggs(aggs) =>
g.copy(child = recursiveRemoveSort(originChild, true))
- case p: Project if p.getTagValue(Project.dataOrderIrrelevantTag).isDefined
=>
- p.copy(child = recursiveRemoveSort(p.child, true))
}
/**
@@ -1604,12 +1597,6 @@ object EliminateSorts extends Rule[LogicalPlan] {
plan match {
case Sort(_, global, child) if canRemoveGlobalSort || !global =>
recursiveRemoveSort(child, canRemoveGlobalSort)
- case Sort(sortOrder, true, child) =>
- // For this case, the upper sort is local so the ordering of present
sort is unnecessary,
- // so here we only preserve its output partitioning using
`RepartitionByExpression`.
- // We should use `None` as the optNumPartitions so AQE can coalesce
shuffle partitions.
- // This behavior is same with original global sort.
- RepartitionByExpression(sortOrder, recursiveRemoveSort(child, true),
None)
case other if canEliminateSort(other) =>
other.withNewChildren(other.children.map(c => recursiveRemoveSort(c,
canRemoveGlobalSort)))
case other if canEliminateGlobalSort(other) =>
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSorts.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSorts.scala
new file mode 100644
index 000000000000..204d2a34675b
--- /dev/null
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSorts.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan,
RepartitionByExpression, Sort}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.SORT
+
+/**
+ * Remove redundant local [[Sort]] from the logical plan if its child is
already sorted, and also
+ * rewrite global [[Sort]] under local [[Sort]] into
[[RepartitionByExpression]].
+ */
+object RemoveRedundantSorts extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = {
+ recursiveRemoveSort(plan, optimizeGlobalSort = false)
+ }
+
+ private def recursiveRemoveSort(plan: LogicalPlan, optimizeGlobalSort:
Boolean): LogicalPlan = {
+ if (!plan.containsPattern(SORT)) {
+ return plan
+ }
+ plan match {
+ case s @ Sort(orders, false, child) =>
+ if (SortOrder.orderingSatisfies(child.outputOrdering, orders)) {
+ recursiveRemoveSort(child, optimizeGlobalSort = false)
+ } else {
+ s.withNewChildren(Seq(recursiveRemoveSort(child, optimizeGlobalSort
= true)))
+ }
+
+ case s @ Sort(orders, true, child) =>
+ val newChild = recursiveRemoveSort(child, optimizeGlobalSort = false)
+ if (optimizeGlobalSort) {
+ // For this case, the upper sort is local so the ordering of present
sort is unnecessary,
+ // so here we only preserve its output partitioning using
`RepartitionByExpression`.
+ // We should use `None` as the optNumPartitions so AQE can coalesce
shuffle partitions.
+ // This behavior is same with original global sort.
+ RepartitionByExpression(orders, newChild, None)
+ } else {
+ s.withNewChildren(Seq(newChild))
+ }
+
+ case _ =>
+ plan.withNewChildren(plan.children.map(recursiveRemoveSort(_,
optimizeGlobalSort = false)))
+ }
+ }
+}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 65f4151c0c96..497f485b67fe 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -101,9 +101,6 @@ case class Project(projectList: Seq[NamedExpression],
child: LogicalPlan)
object Project {
val hiddenOutputTag: TreeNodeTag[Seq[Attribute]] =
TreeNodeTag[Seq[Attribute]]("hidden_output")
- // Project with this tag means it doesn't care about the data order of its
input. We only set
- // this tag when the Project was converted from grouping-only Aggregate.
- val dataOrderIrrelevantTag: TreeNodeTag[Unit] =
TreeNodeTag[Unit]("data_order_irrelevant")
def matchSchema(plan: LogicalPlan, schema: StructType, conf: SQLConf):
Project = {
assert(plan.resolved)
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala
index c6312fa1b1aa..5cfe4a7bf462 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala
@@ -39,7 +39,8 @@ class EliminateSortsSuite extends AnalysisTest {
FoldablePropagation,
LimitPushDown) ::
Batch("Eliminate Sorts", Once,
- EliminateSorts) ::
+ EliminateSorts,
+ RemoveRedundantSorts) ::
Batch("Collapse Project", Once,
CollapseProject) :: Nil
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
index ce43edb79c12..3ca516463d36 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
@@ -223,15 +223,24 @@ class V1WriteCommandSuite extends QueryTest with
SharedSparkSession with V1Write
}
// assert the outer most sort in the executed plan
- assert(plan.collectFirst {
- case s: SortExec => s
- }.exists {
- case SortExec(Seq(
- SortOrder(AttributeReference("key", IntegerType, _, _),
Ascending, NullsFirst, _),
- SortOrder(AttributeReference("value", StringType, _, _),
Ascending, NullsFirst, _)
- ), false, _, _) => true
- case _ => false
- }, plan)
+ val sort = plan.collectFirst { case s: SortExec => s }
+ if (enabled) {
+ // With planned write, optimizer is more efficient and can
eliminate the `SORT BY`.
+ assert(sort.exists {
+ case SortExec(Seq(
+ SortOrder(AttributeReference("key", IntegerType, _, _),
Ascending, NullsFirst, _)
+ ), false, _, _) => true
+ case _ => false
+ }, plan)
+ } else {
+ assert(sort.exists {
+ case SortExec(Seq(
+ SortOrder(AttributeReference("key", IntegerType, _, _),
Ascending, NullsFirst, _),
+ SortOrder(AttributeReference("value", StringType, _, _),
Ascending, NullsFirst, _)
+ ), false, _, _) => true
+ case _ => false
+ }, plan)
+ }
}
}
}
@@ -270,15 +279,24 @@ class V1WriteCommandSuite extends QueryTest with
SharedSparkSession with V1Write
}
// assert the outer most sort in the executed plan
- assert(plan.collectFirst {
- case s: SortExec => s
- }.exists {
- case SortExec(Seq(
- SortOrder(AttributeReference("value", StringType, _, _),
Ascending, NullsFirst, _),
- SortOrder(AttributeReference("key", IntegerType, _, _), Ascending,
NullsFirst, _)
- ), false, _, _) => true
- case _ => false
- }, plan)
+ val sort = plan.collectFirst { case s: SortExec => s }
+ if (enabled) {
+ // With planned write, optimizer is more efficient and can eliminate
the `SORT BY`.
+ assert(sort.exists {
+ case SortExec(Seq(
+ SortOrder(AttributeReference("value", StringType, _, _),
Ascending, NullsFirst, _)
+ ), false, _, _) => true
+ case _ => false
+ }, plan)
+ } else {
+ assert(sort.exists {
+ case SortExec(Seq(
+ SortOrder(AttributeReference("value", StringType, _, _),
Ascending, NullsFirst, _),
+ SortOrder(AttributeReference("key", IntegerType, _, _),
Ascending, NullsFirst, _)
+ ), false, _, _) => true
+ case _ => false
+ }, plan)
+ }
}
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]