This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 0e94f340a63a [SPARK-46378][SQL][FOLLOWUP] Do not rely on TreeNodeTag in Project 0e94f340a63a is described below commit 0e94f340a63af07f1b105c61e3f884993ee305e6 Author: Wenchen Fan <wenc...@databricks.com> AuthorDate: Thu Dec 21 16:23:25 2023 +0800 [SPARK-46378][SQL][FOLLOWUP] Do not rely on TreeNodeTag in Project ### What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/44310. It turns out that relying on a `TreeNodeTag` in `Project` is too fragile: `Project` is a very basic node and is easily removed or transformed during plan optimization. This PR switches to a different approach: since we can't retain the information (that the input data order doesn't matter) once `Aggregate` has been converted, let's leverage this information immediately. We pull the expensive part of `EliminateSorts` out into a new rule (`RemoveRedundantSorts`), so that we can safely call `EliminateSorts` right before we turn `Aggregate` into `Project`. ### Why are the changes needed? To make the optimizer more robust. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #44429 from cloud-fan/sort. 
Authored-by: Wenchen Fan <wenc...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 35 ++++-------- .../catalyst/optimizer/RemoveRedundantSorts.scala | 62 ++++++++++++++++++++++ .../plans/logical/basicLogicalOperators.scala | 3 -- .../catalyst/optimizer/EliminateSortsSuite.scala | 3 +- .../datasources/V1WriteCommandSuite.scala | 54 ++++++++++++------- 5 files changed, 111 insertions(+), 46 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 5a19c5e3c241..1a831b958ef2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -211,7 +211,8 @@ abstract class Optimizer(catalogManager: CatalogManager) Batch("Join Reorder", FixedPoint(1), CostBasedJoinReorder) :+ Batch("Eliminate Sorts", Once, - EliminateSorts) :+ + EliminateSorts, + RemoveRedundantSorts) :+ Batch("Decimal Optimizations", fixedPoint, DecimalAggregates) :+ // This batch must run after "Decimal Optimizations", as that one may change the @@ -771,11 +772,11 @@ object LimitPushDown extends Rule[LogicalPlan] { LocalLimit(exp, project.copy(child = pushLocalLimitThroughJoin(exp, join))) // Push down limit 1 through Aggregate and turn Aggregate into Project if it is group only. 
case Limit(le @ IntegerLiteral(1), a: Aggregate) if a.groupOnly => - val project = Project(a.aggregateExpressions, LocalLimit(le, a.child)) - project.setTagValue(Project.dataOrderIrrelevantTag, ()) - Limit(le, project) + val newAgg = EliminateSorts(a.copy(child = LocalLimit(le, a.child))).asInstanceOf[Aggregate] + Limit(le, Project(newAgg.aggregateExpressions, newAgg.child)) case Limit(le @ IntegerLiteral(1), p @ Project(_, a: Aggregate)) if a.groupOnly => - Limit(le, p.copy(child = Project(a.aggregateExpressions, LocalLimit(le, a.child)))) + val newAgg = EliminateSorts(a.copy(child = LocalLimit(le, a.child))).asInstanceOf[Aggregate] + Limit(le, p.copy(child = Project(newAgg.aggregateExpressions, newAgg.child))) // Merge offset value and limit value into LocalLimit and pushes down LocalLimit through Offset. case LocalLimit(le, Offset(oe, grandChild)) => Offset(oe, LocalLimit(Add(le, oe), grandChild)) @@ -1557,38 +1558,30 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper { * Note that changes in the final output ordering may affect the file size (SPARK-32318). 
* This rule handles the following cases: * 1) if the sort order is empty or the sort order does not have any reference - * 2) if the Sort operator is a local sort and the child is already sorted - * 3) if there is another Sort operator separated by 0...n Project, Filter, Repartition or + * 2) if there is another Sort operator separated by 0...n Project, Filter, Repartition or * RepartitionByExpression, RebalancePartitions (with deterministic expressions) operators - * 4) if the Sort operator is within Join separated by 0...n Project, Filter, Repartition or + * 3) if the Sort operator is within Join separated by 0...n Project, Filter, Repartition or * RepartitionByExpression, RebalancePartitions (with deterministic expressions) operators only * and the Join condition is deterministic - * 5) if the Sort operator is within GroupBy separated by 0...n Project, Filter, Repartition or + * 4) if the Sort operator is within GroupBy separated by 0...n Project, Filter, Repartition or * RepartitionByExpression, RebalancePartitions (with deterministic expressions) operators only * and the aggregate function is order irrelevant */ object EliminateSorts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( - _.containsPattern(SORT))(applyLocally) - - private val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(_.containsPattern(SORT)) { case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) => val newOrders = orders.filterNot(_.child.foldable) if (newOrders.isEmpty) { - applyLocally.lift(child).getOrElse(child) + child } else { s.copy(order = newOrders) } - case Sort(orders, false, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) => - applyLocally.lift(child).getOrElse(child) case s @ Sort(_, global, child) => s.copy(child = recursiveRemoveSort(child, global)) case j @ Join(originLeft, originRight, _, 
cond, _) if cond.forall(_.deterministic) => j.copy(left = recursiveRemoveSort(originLeft, true), right = recursiveRemoveSort(originRight, true)) case g @ Aggregate(_, aggs, originChild) if isOrderIrrelevantAggs(aggs) => g.copy(child = recursiveRemoveSort(originChild, true)) - case p: Project if p.getTagValue(Project.dataOrderIrrelevantTag).isDefined => - p.copy(child = recursiveRemoveSort(p.child, true)) } /** @@ -1604,12 +1597,6 @@ object EliminateSorts extends Rule[LogicalPlan] { plan match { case Sort(_, global, child) if canRemoveGlobalSort || !global => recursiveRemoveSort(child, canRemoveGlobalSort) - case Sort(sortOrder, true, child) => - // For this case, the upper sort is local so the ordering of present sort is unnecessary, - // so here we only preserve its output partitioning using `RepartitionByExpression`. - // We should use `None` as the optNumPartitions so AQE can coalesce shuffle partitions. - // This behavior is same with original global sort. - RepartitionByExpression(sortOrder, recursiveRemoveSort(child, true), None) case other if canEliminateSort(other) => other.withNewChildren(other.children.map(c => recursiveRemoveSort(c, canRemoveGlobalSort))) case other if canEliminateGlobalSort(other) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSorts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSorts.scala new file mode 100644 index 000000000000..204d2a34675b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSorts.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RepartitionByExpression, Sort} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.SORT + +/** + * Remove redundant local [[Sort]] from the logical plan if its child is already sorted, and also + * rewrite global [[Sort]] under local [[Sort]] into [[RepartitionByExpression]]. 
+ */ +object RemoveRedundantSorts extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = { + recursiveRemoveSort(plan, optimizeGlobalSort = false) + } + + private def recursiveRemoveSort(plan: LogicalPlan, optimizeGlobalSort: Boolean): LogicalPlan = { + if (!plan.containsPattern(SORT)) { + return plan + } + plan match { + case s @ Sort(orders, false, child) => + if (SortOrder.orderingSatisfies(child.outputOrdering, orders)) { + recursiveRemoveSort(child, optimizeGlobalSort = false) + } else { + s.withNewChildren(Seq(recursiveRemoveSort(child, optimizeGlobalSort = true))) + } + + case s @ Sort(orders, true, child) => + val newChild = recursiveRemoveSort(child, optimizeGlobalSort = false) + if (optimizeGlobalSort) { + // For this case, the upper sort is local so the ordering of present sort is unnecessary, + // so here we only preserve its output partitioning using `RepartitionByExpression`. + // We should use `None` as the optNumPartitions so AQE can coalesce shuffle partitions. + // This behavior is same with original global sort. 
+ RepartitionByExpression(orders, newChild, None) + } else { + s.withNewChildren(Seq(newChild)) + } + + case _ => + plan.withNewChildren(plan.children.map(recursiveRemoveSort(_, optimizeGlobalSort = false))) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 65f4151c0c96..497f485b67fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -101,9 +101,6 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) object Project { val hiddenOutputTag: TreeNodeTag[Seq[Attribute]] = TreeNodeTag[Seq[Attribute]]("hidden_output") - // Project with this tag means it doesn't care about the data order of its input. We only set - // this tag when the Project was converted from grouping-only Aggregate. 
- val dataOrderIrrelevantTag: TreeNodeTag[Unit] = TreeNodeTag[Unit]("data_order_irrelevant") def matchSchema(plan: LogicalPlan, schema: StructType, conf: SQLConf): Project = { assert(plan.resolved) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index c6312fa1b1aa..5cfe4a7bf462 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -39,7 +39,8 @@ class EliminateSortsSuite extends AnalysisTest { FoldablePropagation, LimitPushDown) :: Batch("Eliminate Sorts", Once, - EliminateSorts) :: + EliminateSorts, + RemoveRedundantSorts) :: Batch("Collapse Project", Once, CollapseProject) :: Nil } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala index ce43edb79c12..3ca516463d36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala @@ -223,15 +223,24 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write } // assert the outer most sort in the executed plan - assert(plan.collectFirst { - case s: SortExec => s - }.exists { - case SortExec(Seq( - SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _), - SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _) - ), false, _, _) => true - case _ => false - }, plan) + val sort = plan.collectFirst { case s: SortExec => s } + if (enabled) { + // With planned write, optimizer is more efficient and can eliminate the `SORT BY`. 
+ assert(sort.exists { + case SortExec(Seq( + SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _) + ), false, _, _) => true + case _ => false + }, plan) + } else { + assert(sort.exists { + case SortExec(Seq( + SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _), + SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _) + ), false, _, _) => true + case _ => false + }, plan) + } } } } @@ -270,15 +279,24 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write } // assert the outer most sort in the executed plan - assert(plan.collectFirst { - case s: SortExec => s - }.exists { - case SortExec(Seq( - SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _), - SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _) - ), false, _, _) => true - case _ => false - }, plan) + val sort = plan.collectFirst { case s: SortExec => s } + if (enabled) { + // With planned write, optimizer is more efficient and can eliminate the `SORT BY`. + assert(sort.exists { + case SortExec(Seq( + SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _) + ), false, _, _) => true + case _ => false + }, plan) + } else { + assert(sort.exists { + case SortExec(Seq( + SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _), + SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _) + ), false, _, _) => true + case _ => false + }, plan) + } } } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org