This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 63c7ca4df297 [SPARK-50597][SQL] Refactor batch construction in Optimizer.scala and SparkOptimizer.scala
63c7ca4df297 is described below

commit 63c7ca4df2970d12574ad3b542ec17eb5276ef86
Author: Anton Lykov <[email protected]>
AuthorDate: Tue Dec 17 22:39:11 2024 +0800

    [SPARK-50597][SQL] Refactor batch construction in Optimizer.scala and SparkOptimizer.scala
    
    ### What changes were proposed in this pull request?
    
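    Previously, it was cumbersome to reorder batches, or to guard individual batches (or
    sequences of batches) behind a flag. This was primarily due to the heavy use of `::`,
    `:::`, and `:+` to juggle rules and batches around, which imposed syntactic limitations:
    single `Batch` elements and `Seq[Batch]` groups (such as `operatorOptimizationBatch`)
    had to be spliced together with carefully parenthesized operator chains.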
    
    After this change, the batches are listed in a single `Seq` whose elements can be
    either a `Batch` or a `Seq[Batch]` (to allow further grouping). The new helper
    `flattenBatches` then flattens this list into a single `Seq[Batch]`.
    
    The refactored batch and rule lists no longer use `::`, `:::`, or `:+`.
    
    To add or replace a flag-guarded batch or sequence of batches, write a function that
    returns either a `Batch` or a `Seq[Batch]` with the desired behavior (for example, an
    empty `Seq` when the flag is disabled), and add it to, or swap it in at, the relevant
    position of the list passed to `flattenBatches`.
    
    ### Why are the changes needed?
    
    This simplifies further restructuring and reordering of batches.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    No new tests were added; this is a refactoring intended to preserve the existing batch order.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Closes #49208 from anton5798/batch-refactor.
    
    Authored-by: Anton Lykov <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/catalyst/optimizer/Optimizer.scala   | 113 ++++++++++++---------
 .../spark/sql/execution/SparkOptimizer.scala       |  63 ++++++------
 2 files changed, 98 insertions(+), 78 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 31c1f8917763..b141d2be04c3 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -73,6 +73,21 @@ abstract class Optimizer(catalogManager: CatalogManager)
       conf.optimizerMaxIterations,
       maxIterationsSetting = SQLConf.OPTIMIZER_MAX_ITERATIONS.key)
 
+  /**
+   * A helper method that takes as input a Seq of Batch or Seq[Batch], and 
flattens it out.
+   */
+  def flattenBatches(nestedBatchSequence: Seq[Any]): Seq[Batch] = {
+    assert(nestedBatchSequence.forall {
+      case _: Batch => true
+      case s: Seq[_] => s.forall(_.isInstanceOf[Batch])
+      case _ => false
+    })
+    nestedBatchSequence.flatMap {
+      case batches: Seq[Batch @unchecked] => batches
+      case batch: Batch => Seq(batch)
+    }
+  }
+
   /**
    * Defines the default rule batches in the Optimizer.
    *
@@ -143,39 +158,38 @@ abstract class Optimizer(catalogManager: CatalogManager)
         PushdownPredicatesAndPruneColumnsForCTEDef) ++
         extendedOperatorOptimizationRules
 
-    val operatorOptimizationBatch: Seq[Batch] = {
+    val operatorOptimizationBatch: Seq[Batch] = Seq(
       Batch("Operator Optimization before Inferring Filters", fixedPoint,
-        operatorOptimizationRuleSet: _*) ::
+        operatorOptimizationRuleSet: _*),
       Batch("Infer Filters", Once,
         InferFiltersFromGenerate,
-        InferFiltersFromConstraints) ::
+        InferFiltersFromConstraints),
       Batch("Operator Optimization after Inferring Filters", fixedPoint,
-        operatorOptimizationRuleSet: _*) ::
+        operatorOptimizationRuleSet: _*),
       Batch("Push extra predicate through join", fixedPoint,
         PushExtraPredicateThroughJoin,
-        PushDownPredicates) :: Nil
-    }
+        PushDownPredicates))
 
-    val batches = (
-    Batch("Finish Analysis", FixedPoint(1), FinishAnalysis) ::
+    val batches: Seq[Batch] = flattenBatches(Seq(
+    Batch("Finish Analysis", FixedPoint(1), FinishAnalysis),
     // We must run this batch after `ReplaceExpressions`, as 
`RuntimeReplaceable` expression
     // may produce `With` expressions that need to be rewritten.
-    Batch("Rewrite With expression", fixedPoint, RewriteWithExpression) ::
+    Batch("Rewrite With expression", fixedPoint, RewriteWithExpression),
     
//////////////////////////////////////////////////////////////////////////////////////////
     // Optimizer rules start here
     
//////////////////////////////////////////////////////////////////////////////////////////
-    Batch("Eliminate Distinct", Once, EliminateDistinct) ::
+    Batch("Eliminate Distinct", Once, EliminateDistinct),
     // - Do the first call of CombineUnions before starting the major 
Optimizer rules,
     //   since it can reduce the number of iteration and the other rules could 
add/move
     //   extra operators between two adjacent Union operators.
     // - Call CombineUnions again in Batch("Operator Optimizations"),
     //   since the other rules might make two separate Unions operators 
adjacent.
     Batch("Inline CTE", Once,
-      InlineCTE()) ::
+      InlineCTE()),
     Batch("Union", fixedPoint,
       RemoveNoopOperators,
       CombineUnions,
-      RemoveNoopUnion) ::
+      RemoveNoopUnion),
     // Run this once earlier. This might simplify the plan and reduce cost of 
optimizer.
     // For example, a query such as Filter(LocalRelation) would go through all 
the heavy
     // optimizer rules that are triggered when there is a filter
@@ -186,16 +200,16 @@ abstract class Optimizer(catalogManager: CatalogManager)
       PropagateEmptyRelation,
       // PropagateEmptyRelation can change the nullability of an attribute 
from nullable to
       // non-nullable when an empty relation child of a Union is removed
-      UpdateAttributeNullability) ::
+      UpdateAttributeNullability),
     Batch("Pullup Correlated Expressions", Once,
       OptimizeOneRowRelationSubquery,
       PullOutNestedDataOuterRefExpressions,
-      PullupCorrelatedPredicates) ::
+      PullupCorrelatedPredicates),
     // Subquery batch applies the optimizer rules recursively. Therefore, it 
makes no sense
     // to enforce idempotence on it and we change this batch from Once to 
FixedPoint(1).
     Batch("Subquery", FixedPoint(1),
       OptimizeSubqueries,
-      OptimizeOneRowRelationSubquery) ::
+      OptimizeOneRowRelationSubquery),
     Batch("Replace Operators", fixedPoint,
       RewriteExceptAll,
       RewriteIntersectAll,
@@ -203,48 +217,48 @@ abstract class Optimizer(catalogManager: CatalogManager)
       ReplaceExceptWithFilter,
       ReplaceExceptWithAntiJoin,
       ReplaceDistinctWithAggregate,
-      ReplaceDeduplicateWithAggregate) ::
+      ReplaceDeduplicateWithAggregate),
     Batch("Aggregate", fixedPoint,
       RemoveLiteralFromGroupExpressions,
-      RemoveRepetitionFromGroupExpressions) :: Nil ++
-    operatorOptimizationBatch) :+
-    Batch("Clean Up Temporary CTE Info", Once, CleanUpTempCTEInfo) :+
+      RemoveRepetitionFromGroupExpressions),
+    operatorOptimizationBatch,
+    Batch("Clean Up Temporary CTE Info", Once, CleanUpTempCTEInfo),
     // This batch rewrites plans after the operator optimization and
     // before any batches that depend on stats.
-    Batch("Pre CBO Rules", Once, preCBORules: _*) :+
+    Batch("Pre CBO Rules", Once, preCBORules: _*),
     // This batch pushes filters and projections into scan nodes. Before this 
batch, the logical
     // plan may contain nodes that do not report stats. Anything that uses 
stats must run after
     // this batch.
-    Batch("Early Filter and Projection Push-Down", Once, 
earlyScanPushDownRules: _*) :+
-    Batch("Update CTE Relation Stats", Once, UpdateCTERelationStats) :+
+    Batch("Early Filter and Projection Push-Down", Once, 
earlyScanPushDownRules: _*),
+    Batch("Update CTE Relation Stats", Once, UpdateCTERelationStats),
     // Since join costs in AQP can change between multiple runs, there is no 
reason that we have an
     // idempotence enforcement on this batch. We thus make it FixedPoint(1) 
instead of Once.
     Batch("Join Reorder", FixedPoint(1),
-      CostBasedJoinReorder) :+
+      CostBasedJoinReorder),
     Batch("Eliminate Sorts", Once,
       EliminateSorts,
-      RemoveRedundantSorts) :+
+      RemoveRedundantSorts),
     Batch("Decimal Optimizations", fixedPoint,
-      DecimalAggregates) :+
+      DecimalAggregates),
     // This batch must run after "Decimal Optimizations", as that one may 
change the
     // aggregate distinct column
     Batch("Distinct Aggregate Rewrite", Once,
-      RewriteDistinctAggregates) :+
+      RewriteDistinctAggregates),
     Batch("Object Expressions Optimization", fixedPoint,
       EliminateMapObjects,
       CombineTypedFilters,
       ObjectSerializerPruning,
-      ReassignLambdaVariableID) :+
+      ReassignLambdaVariableID),
     Batch("LocalRelation", fixedPoint,
       ConvertToLocalRelation,
       PropagateEmptyRelation,
       // PropagateEmptyRelation can change the nullability of an attribute 
from nullable to
       // non-nullable when an empty relation child of a Union is removed
-      UpdateAttributeNullability) :+
-    Batch("Optimize One Row Plan", fixedPoint, OptimizeOneRowPlan) :+
+      UpdateAttributeNullability),
+    Batch("Optimize One Row Plan", fixedPoint, OptimizeOneRowPlan),
     // The following batch should be executed after batch "Join Reorder" and 
"LocalRelation".
     Batch("Check Cartesian Products", Once,
-      CheckCartesianProducts) :+
+      CheckCartesianProducts),
     Batch("RewriteSubquery", Once,
       RewritePredicateSubquery,
       PushPredicateThroughJoin,
@@ -252,10 +266,10 @@ abstract class Optimizer(catalogManager: CatalogManager)
       ColumnPruning,
       CollapseProject,
       RemoveRedundantAliases,
-      RemoveNoopOperators) :+
+      RemoveNoopOperators),
     // This batch must be executed after the `RewriteSubquery` batch, which 
creates joins.
-    Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
-    Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression)
+    Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers),
+    Batch("ReplaceUpdateFieldsExpression", Once, 
ReplaceUpdateFieldsExpression)))
 
     // remove any batches with no rules. this may happen when subclasses do 
not add optional rules.
     batches.filter(_.rules.nonEmpty)
@@ -270,22 +284,23 @@ abstract class Optimizer(catalogManager: CatalogManager)
    * (defaultBatches - (excludedRules - nonExcludableRules)).
    */
   def nonExcludableRules: Seq[String] =
-    FinishAnalysis.ruleName ::
-      RewriteDistinctAggregates.ruleName ::
-      ReplaceDeduplicateWithAggregate.ruleName ::
-      ReplaceIntersectWithSemiJoin.ruleName ::
-      ReplaceExceptWithFilter.ruleName ::
-      ReplaceExceptWithAntiJoin.ruleName ::
-      RewriteExceptAll.ruleName ::
-      RewriteIntersectAll.ruleName ::
-      ReplaceDistinctWithAggregate.ruleName ::
-      PullupCorrelatedPredicates.ruleName ::
-      RewriteCorrelatedScalarSubquery.ruleName ::
-      RewritePredicateSubquery.ruleName ::
-      NormalizeFloatingNumbers.ruleName ::
-      ReplaceUpdateFieldsExpression.ruleName ::
-      RewriteLateralSubquery.ruleName ::
-      OptimizeSubqueries.ruleName :: Nil
+    Seq(
+      FinishAnalysis.ruleName,
+      RewriteDistinctAggregates.ruleName,
+      ReplaceDeduplicateWithAggregate.ruleName,
+      ReplaceIntersectWithSemiJoin.ruleName,
+      ReplaceExceptWithFilter.ruleName,
+      ReplaceExceptWithAntiJoin.ruleName,
+      RewriteExceptAll.ruleName,
+      RewriteIntersectAll.ruleName,
+      ReplaceDistinctWithAggregate.ruleName,
+      PullupCorrelatedPredicates.ruleName,
+      RewriteCorrelatedScalarSubquery.ruleName,
+      RewritePredicateSubquery.ruleName,
+      NormalizeFloatingNumbers.ruleName,
+      ReplaceUpdateFieldsExpression.ruleName,
+      RewriteLateralSubquery.ruleName,
+      OptimizeSubqueries.ruleName)
 
   /**
    * Apply finish-analysis rules for the entire plan including all subqueries.
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 6173703ef3cd..6ceb363b41ae 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -36,38 +36,41 @@ class SparkOptimizer(
 
   override def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] =
     // TODO: move SchemaPruning into catalyst
-    Seq(SchemaPruning) :+
-      GroupBasedRowLevelOperationScanPlanning :+
-      V1Writes :+
-      V2ScanRelationPushDown :+
-      V2ScanPartitioningAndOrdering :+
-      V2Writes :+
-      PruneFileSourcePartitions
+    Seq(
+      SchemaPruning,
+      GroupBasedRowLevelOperationScanPlanning,
+      V1Writes,
+      V2ScanRelationPushDown,
+      V2ScanPartitioningAndOrdering,
+      V2Writes,
+      PruneFileSourcePartitions)
 
   override def preCBORules: Seq[Rule[LogicalPlan]] =
-    OptimizeMetadataOnlyDeleteFromTable :: Nil
+    Seq(OptimizeMetadataOnlyDeleteFromTable)
 
-  override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ 
super.defaultBatches :+
-    Batch("Optimize Metadata Only Query", Once, 
OptimizeMetadataOnlyQuery(catalog)) :+
+  override def defaultBatches: Seq[Batch] = flattenBatches(Seq(
+    preOptimizationBatches,
+    super.defaultBatches,
+    Batch("Optimize Metadata Only Query", Once, 
OptimizeMetadataOnlyQuery(catalog)),
     Batch("PartitionPruning", Once,
       PartitionPruning,
       // We can't run `OptimizeSubqueries` in this batch, as it will optimize 
the subqueries
       // twice which may break some optimizer rules that can only be applied 
once. The rule below
       // only invokes `OptimizeSubqueries` to optimize newly added subqueries.
-      new RowLevelOperationRuntimeGroupFiltering(OptimizeSubqueries)) :+
+      new RowLevelOperationRuntimeGroupFiltering(OptimizeSubqueries)),
     Batch("InjectRuntimeFilter", FixedPoint(1),
-      InjectRuntimeFilter) :+
+      InjectRuntimeFilter),
     Batch("MergeScalarSubqueries", Once,
       MergeScalarSubqueries,
-      RewriteDistinctAggregates) :+
+      RewriteDistinctAggregates),
     Batch("Pushdown Filters from PartitionPruning", fixedPoint,
-      PushDownPredicates) :+
+      PushDownPredicates),
     Batch("Cleanup filters that cannot be pushed down", Once,
       CleanupDynamicPruningFilters,
       // cleanup the unnecessary TrueLiteral predicates
       BooleanSimplification,
-      PruneFilters)) ++
-    postHocOptimizationBatches :+
+      PruneFilters),
+    postHocOptimizationBatches,
     Batch("Extract Python UDFs", Once,
       ExtractPythonUDFFromJoinCondition,
       // `ExtractPythonUDFFromJoinCondition` can convert a join to a cartesian 
product.
@@ -84,25 +87,27 @@ class SparkOptimizer(
       LimitPushDown,
       PushPredicateThroughNonJoin,
       PushProjectionThroughLimit,
-      RemoveNoopOperators) :+
+      RemoveNoopOperators),
     Batch("Infer window group limit", Once,
       InferWindowGroupLimit,
       LimitPushDown,
       LimitPushDownThroughWindow,
       EliminateLimits,
-      ConstantFolding) :+
-    Batch("User Provided Optimizers", fixedPoint, 
experimentalMethods.extraOptimizations: _*) :+
-    Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition)
+      ConstantFolding),
+    Batch("User Provided Optimizers", fixedPoint, 
experimentalMethods.extraOptimizations: _*),
+    Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition)))
 
-  override def nonExcludableRules: Seq[String] = super.nonExcludableRules :+
-    ExtractPythonUDFFromJoinCondition.ruleName :+
-    ExtractPythonUDFFromAggregate.ruleName :+ 
ExtractGroupingPythonUDFFromAggregate.ruleName :+
-    ExtractPythonUDFs.ruleName :+
-    GroupBasedRowLevelOperationScanPlanning.ruleName :+
-    V2ScanRelationPushDown.ruleName :+
-    V2ScanPartitioningAndOrdering.ruleName :+
-    V2Writes.ruleName :+
-    ReplaceCTERefWithRepartition.ruleName
+  override def nonExcludableRules: Seq[String] = super.nonExcludableRules ++
+    Seq(
+      ExtractPythonUDFFromJoinCondition.ruleName,
+      ExtractPythonUDFFromAggregate.ruleName,
+      ExtractGroupingPythonUDFFromAggregate.ruleName,
+      ExtractPythonUDFs.ruleName,
+      GroupBasedRowLevelOperationScanPlanning.ruleName,
+      V2ScanRelationPushDown.ruleName,
+      V2ScanPartitioningAndOrdering.ruleName,
+      V2Writes.ruleName,
+      ReplaceCTERefWithRepartition.ruleName)
 
   /**
    * Optimization batches that are executed before the regular optimization 
batches (also before


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to