This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0e94f340a63a [SPARK-46378][SQL][FOLLOWUP] Do not rely on TreeNodeTag in Project
0e94f340a63a is described below

commit 0e94f340a63af07f1b105c61e3f884993ee305e6
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Thu Dec 21 16:23:25 2023 +0800

    [SPARK-46378][SQL][FOLLOWUP] Do not rely on TreeNodeTag in Project
    
    ### What changes were proposed in this pull request?
    
    This is a followup of https://github.com/apache/spark/pull/44310 . It turns out that
    relying on a `TreeNodeTag` in `Project` is too fragile: `Project` is a very basic node
    that is easily removed or transformed during plan optimization.
    
    This PR switches to a different approach: since we can't reliably retain the information
    from `Aggregate` that its input data order doesn't matter, we leverage that information
    immediately instead. We pull the expensive part of `EliminateSorts` out into a new rule,
    `RemoveRedundantSorts`, so that we can safely call `EliminateSorts` right before we turn
    `Aggregate` into `Project`.
    
    ### Why are the changes needed?
    
    To make the optimizer more robust.
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
    existing tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    no
    
    Closes #44429 from cloud-fan/sort.
    
    Authored-by: Wenchen Fan <wenc...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/optimizer/Optimizer.scala   | 35 ++++--------
 .../catalyst/optimizer/RemoveRedundantSorts.scala  | 62 ++++++++++++++++++++++
 .../plans/logical/basicLogicalOperators.scala      |  3 --
 .../catalyst/optimizer/EliminateSortsSuite.scala   |  3 +-
 .../datasources/V1WriteCommandSuite.scala          | 54 ++++++++++++-------
 5 files changed, 111 insertions(+), 46 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 5a19c5e3c241..1a831b958ef2 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -211,7 +211,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
     Batch("Join Reorder", FixedPoint(1),
       CostBasedJoinReorder) :+
     Batch("Eliminate Sorts", Once,
-      EliminateSorts) :+
+      EliminateSorts,
+      RemoveRedundantSorts) :+
     Batch("Decimal Optimizations", fixedPoint,
       DecimalAggregates) :+
     // This batch must run after "Decimal Optimizations", as that one may 
change the
@@ -771,11 +772,11 @@ object LimitPushDown extends Rule[LogicalPlan] {
       LocalLimit(exp, project.copy(child = pushLocalLimitThroughJoin(exp, 
join)))
     // Push down limit 1 through Aggregate and turn Aggregate into Project if 
it is group only.
     case Limit(le @ IntegerLiteral(1), a: Aggregate) if a.groupOnly =>
-      val project = Project(a.aggregateExpressions, LocalLimit(le, a.child))
-      project.setTagValue(Project.dataOrderIrrelevantTag, ())
-      Limit(le, project)
+      val newAgg = EliminateSorts(a.copy(child = LocalLimit(le, 
a.child))).asInstanceOf[Aggregate]
+      Limit(le, Project(newAgg.aggregateExpressions, newAgg.child))
     case Limit(le @ IntegerLiteral(1), p @ Project(_, a: Aggregate)) if 
a.groupOnly =>
-      Limit(le, p.copy(child = Project(a.aggregateExpressions, LocalLimit(le, 
a.child))))
+      val newAgg = EliminateSorts(a.copy(child = LocalLimit(le, 
a.child))).asInstanceOf[Aggregate]
+      Limit(le, p.copy(child = Project(newAgg.aggregateExpressions, 
newAgg.child)))
     // Merge offset value and limit value into LocalLimit and pushes down 
LocalLimit through Offset.
     case LocalLimit(le, Offset(oe, grandChild)) =>
       Offset(oe, LocalLimit(Add(le, oe), grandChild))
@@ -1557,38 +1558,30 @@ object CombineFilters extends Rule[LogicalPlan] with 
PredicateHelper {
  * Note that changes in the final output ordering may affect the file size 
(SPARK-32318).
  * This rule handles the following cases:
  * 1) if the sort order is empty or the sort order does not have any reference
- * 2) if the Sort operator is a local sort and the child is already sorted
- * 3) if there is another Sort operator separated by 0...n Project, Filter, 
Repartition or
+ * 2) if there is another Sort operator separated by 0...n Project, Filter, 
Repartition or
  *    RepartitionByExpression, RebalancePartitions (with deterministic 
expressions) operators
- * 4) if the Sort operator is within Join separated by 0...n Project, Filter, 
Repartition or
+ * 3) if the Sort operator is within Join separated by 0...n Project, Filter, 
Repartition or
  *    RepartitionByExpression, RebalancePartitions (with deterministic 
expressions) operators only
  *    and the Join condition is deterministic
- * 5) if the Sort operator is within GroupBy separated by 0...n Project, 
Filter, Repartition or
+ * 4) if the Sort operator is within GroupBy separated by 0...n Project, 
Filter, Repartition or
  *    RepartitionByExpression, RebalancePartitions (with deterministic 
expressions) operators only
  *    and the aggregate function is order irrelevant
  */
 object EliminateSorts extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
-    _.containsPattern(SORT))(applyLocally)
-
-  private val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
+  def apply(plan: LogicalPlan): LogicalPlan = 
plan.transformUpWithPruning(_.containsPattern(SORT)) {
     case s @ Sort(orders, _, child) if orders.isEmpty || 
orders.exists(_.child.foldable) =>
       val newOrders = orders.filterNot(_.child.foldable)
       if (newOrders.isEmpty) {
-        applyLocally.lift(child).getOrElse(child)
+        child
       } else {
         s.copy(order = newOrders)
       }
-    case Sort(orders, false, child) if 
SortOrder.orderingSatisfies(child.outputOrdering, orders) =>
-      applyLocally.lift(child).getOrElse(child)
     case s @ Sort(_, global, child) => s.copy(child = 
recursiveRemoveSort(child, global))
     case j @ Join(originLeft, originRight, _, cond, _) if 
cond.forall(_.deterministic) =>
       j.copy(left = recursiveRemoveSort(originLeft, true),
         right = recursiveRemoveSort(originRight, true))
     case g @ Aggregate(_, aggs, originChild) if isOrderIrrelevantAggs(aggs) =>
       g.copy(child = recursiveRemoveSort(originChild, true))
-    case p: Project if p.getTagValue(Project.dataOrderIrrelevantTag).isDefined 
=>
-      p.copy(child = recursiveRemoveSort(p.child, true))
   }
 
   /**
@@ -1604,12 +1597,6 @@ object EliminateSorts extends Rule[LogicalPlan] {
     plan match {
       case Sort(_, global, child) if canRemoveGlobalSort || !global =>
         recursiveRemoveSort(child, canRemoveGlobalSort)
-      case Sort(sortOrder, true, child) =>
-        // For this case, the upper sort is local so the ordering of present 
sort is unnecessary,
-        // so here we only preserve its output partitioning using 
`RepartitionByExpression`.
-        // We should use `None` as the optNumPartitions so AQE can coalesce 
shuffle partitions.
-        // This behavior is same with original global sort.
-        RepartitionByExpression(sortOrder, recursiveRemoveSort(child, true), 
None)
       case other if canEliminateSort(other) =>
         other.withNewChildren(other.children.map(c => recursiveRemoveSort(c, 
canRemoveGlobalSort)))
       case other if canEliminateGlobalSort(other) =>
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSorts.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSorts.scala
new file mode 100644
index 000000000000..204d2a34675b
--- /dev/null
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSorts.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, 
RepartitionByExpression, Sort}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.SORT
+
+/**
+ * Removes redundant local [[Sort]]s from the logical plan, and rewrites a global [[Sort]] that
+ * sits directly under a retained local [[Sort]] into a [[RepartitionByExpression]].
+ *
+ * Two cases are handled:
+ *  - A local [[Sort]] whose child's `outputOrdering` already satisfies its order is dropped.
+ *  - A global [[Sort]] whose direct parent is a local [[Sort]] that is kept only needs to
+ *    contribute its output partitioning (the local sort re-sorts each partition anyway), so it
+ *    is replaced by a [[RepartitionByExpression]] over the same sort order.
+ *
+ * Only the direct child of a local [[Sort]] is considered: any other operator in between
+ * (e.g. a Project or Filter) resets the flag, so a global [[Sort]] below it is left as is.
+ *
+ * This logic was split out of [[EliminateSorts]] (so that rule can be invoked on its own, e.g.
+ * from `LimitPushDown`) and runs right after it in the "Eliminate Sorts" batch.
+ */
+object RemoveRedundantSorts extends Rule[LogicalPlan] {
+  // The root has no parent Sort, so a global Sort at the root is never rewritten.
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    recursiveRemoveSort(plan, optimizeGlobalSort = false)
+  }
+
+  /**
+   * Recursively removes redundant sorts from `plan`.
+   *
+   * @param plan the (sub)plan to rewrite.
+   * @param optimizeGlobalSort true only when the direct parent of `plan` is a local [[Sort]]
+   *                           that is being kept; a global [[Sort]] at `plan` may then be
+   *                           replaced by a [[RepartitionByExpression]].
+   * @return the rewritten plan; `plan` itself is returned unchanged if it contains no [[Sort]].
+   */
+  private def recursiveRemoveSort(plan: LogicalPlan, optimizeGlobalSort: 
Boolean): LogicalPlan = {
+    // Prune: skip subtrees that contain no Sort node at all.
+    if (!plan.containsPattern(SORT)) {
+      return plan
+    }
+    plan match {
+      case s @ Sort(orders, false, child) =>
+        if (SortOrder.orderingSatisfies(child.outputOrdering, orders)) {
+          // The child is already sorted as required, so this local Sort is a no-op; drop it.
+          // NOTE(review): the flag is reset to false here instead of forwarding the caller's
+          // `optimizeGlobalSort`; this is the conservative choice — confirm before changing.
+          recursiveRemoveSort(child, optimizeGlobalSort = false)
+        } else {
+          // This local Sort is kept and re-sorts every partition, so a global Sort that is its
+          // direct child only needs to preserve its partitioning.
+          s.withNewChildren(Seq(recursiveRemoveSort(child, optimizeGlobalSort 
= true)))
+        }
+
+      case s @ Sort(orders, true, child) =>
+        val newChild = recursiveRemoveSort(child, optimizeGlobalSort = false)
+        if (optimizeGlobalSort) {
+          // The parent sort is local, so the ordering produced by this global sort is
+          // unnecessary; only its output partitioning is preserved, via `RepartitionByExpression`.
+          // `None` is used as optNumPartitions so that AQE can coalesce shuffle partitions,
+          // matching the behavior of the original global sort.
+          RepartitionByExpression(orders, newChild, None)
+        } else {
+          s.withNewChildren(Seq(newChild))
+        }
+
+      // Any other operator: recurse into its children with the flag reset, because their
+      // direct parent is not a local Sort.
+      case _ =>
+        plan.withNewChildren(plan.children.map(recursiveRemoveSort(_, 
optimizeGlobalSort = false)))
+    }
+  }
+}
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 65f4151c0c96..497f485b67fe 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -101,9 +101,6 @@ case class Project(projectList: Seq[NamedExpression], 
child: LogicalPlan)
 
 object Project {
   val hiddenOutputTag: TreeNodeTag[Seq[Attribute]] = 
TreeNodeTag[Seq[Attribute]]("hidden_output")
-  // Project with this tag means it doesn't care about the data order of its 
input. We only set
-  // this tag when the Project was converted from grouping-only Aggregate.
-  val dataOrderIrrelevantTag: TreeNodeTag[Unit] = 
TreeNodeTag[Unit]("data_order_irrelevant")
 
   def matchSchema(plan: LogicalPlan, schema: StructType, conf: SQLConf): 
Project = {
     assert(plan.resolved)
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala
index c6312fa1b1aa..5cfe4a7bf462 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala
@@ -39,7 +39,8 @@ class EliminateSortsSuite extends AnalysisTest {
         FoldablePropagation,
         LimitPushDown) ::
       Batch("Eliminate Sorts", Once,
-        EliminateSorts) ::
+        EliminateSorts,
+        RemoveRedundantSorts) ::
       Batch("Collapse Project", Once,
         CollapseProject) :: Nil
   }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
index ce43edb79c12..3ca516463d36 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
@@ -223,15 +223,24 @@ class V1WriteCommandSuite extends QueryTest with 
SharedSparkSession with V1Write
           }
 
           // assert the outer most sort in the executed plan
-          assert(plan.collectFirst {
-            case s: SortExec => s
-          }.exists {
-            case SortExec(Seq(
-              SortOrder(AttributeReference("key", IntegerType, _, _), 
Ascending, NullsFirst, _),
-              SortOrder(AttributeReference("value", StringType, _, _), 
Ascending, NullsFirst, _)
-            ), false, _, _) => true
-            case _ => false
-          }, plan)
+          val sort = plan.collectFirst { case s: SortExec => s }
+          if (enabled) {
+            // With planned write, optimizer is more efficient and can 
eliminate the `SORT BY`.
+            assert(sort.exists {
+              case SortExec(Seq(
+                SortOrder(AttributeReference("key", IntegerType, _, _), 
Ascending, NullsFirst, _)
+              ), false, _, _) => true
+              case _ => false
+            }, plan)
+          } else {
+            assert(sort.exists {
+              case SortExec(Seq(
+                SortOrder(AttributeReference("key", IntegerType, _, _), 
Ascending, NullsFirst, _),
+                SortOrder(AttributeReference("value", StringType, _, _), 
Ascending, NullsFirst, _)
+              ), false, _, _) => true
+              case _ => false
+            }, plan)
+          }
         }
       }
     }
@@ -270,15 +279,24 @@ class V1WriteCommandSuite extends QueryTest with 
SharedSparkSession with V1Write
         }
 
         // assert the outer most sort in the executed plan
-        assert(plan.collectFirst {
-          case s: SortExec => s
-        }.exists {
-          case SortExec(Seq(
-            SortOrder(AttributeReference("value", StringType, _, _), 
Ascending, NullsFirst, _),
-            SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, 
NullsFirst, _)
-          ), false, _, _) => true
-          case _ => false
-        }, plan)
+        val sort = plan.collectFirst { case s: SortExec => s }
+        if (enabled) {
+          // With planned write, optimizer is more efficient and can eliminate 
the `SORT BY`.
+          assert(sort.exists {
+            case SortExec(Seq(
+              SortOrder(AttributeReference("value", StringType, _, _), 
Ascending, NullsFirst, _)
+            ), false, _, _) => true
+            case _ => false
+          }, plan)
+        } else {
+          assert(sort.exists {
+            case SortExec(Seq(
+              SortOrder(AttributeReference("value", StringType, _, _), 
Ascending, NullsFirst, _),
+              SortOrder(AttributeReference("key", IntegerType, _, _), 
Ascending, NullsFirst, _)
+            ), false, _, _) => true
+            case _ => false
+          }, plan)
+        }
       }
     }
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to