This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 52c0a2978 [VL] Untangle code of TransformPreOverrides (#4888)
52c0a2978 is described below

commit 52c0a29785582a446b71dbf20f5271a865fec2cf
Author: Hongze Zhang <[email protected]>
AuthorDate: Thu Mar 14 16:43:00 2024 +0800

    [VL] Untangle code of TransformPreOverrides (#4888)
---
 .../clickhouse/CHSparkPlanExecApi.scala            |   8 +-
 .../extension/columnar/MiscColumnarRules.scala     | 789 +++++++++++----------
 2 files changed, 428 insertions(+), 369 deletions(-)
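
Before the diff, a note on the shape of the change: the old monolithic, self-recursive replaceWithTransformerPlan is split into small single-purpose sub-rules (FilterTransformRule, RegularTransformRule, AggregationTransformRule) that TransformPreOverrides now chains in sequence with foldLeft. The following is a minimal standalone sketch of that composition pattern only; the Plan/Rule types are hypothetical stand-ins, not Gluten's or Spark's classes.

    // Hypothetical stand-in types; one concern per sub-rule, chained in order.
    object RuleChainSketch {
      trait Plan
      trait Rule { def apply(plan: Plan): Plan }

      case class Scan(offloaded: Boolean = false) extends Plan
      case class Filter(child: Plan, offloaded: Boolean = false) extends Plan

      // Offloads Filter nodes only, like FilterTransformRule in the patch.
      object FilterRule extends Rule {
        def apply(plan: Plan): Plan = plan match {
          case Filter(child, _) => Filter(apply(child), offloaded = true)
          case other => other
        }
      }

      // Offloads Scan nodes only, recursing through already-handled Filters.
      object ScanRule extends Rule {
        def apply(plan: Plan): Plan = plan match {
          case Scan(_) => Scan(offloaded = true)
          case Filter(child, t) => Filter(apply(child), t)
          case other => other
        }
      }

      private val subRules: List[Rule] = List(FilterRule, ScanRule)

      // Counterpart of the new TransformPreOverrides.apply below: thread the
      // plan through every sub-rule in order.
      def apply(plan: Plan): Plan = subRules.foldLeft(plan)((p, rule) => rule(p))

      def main(args: Array[String]): Unit =
        println(apply(Filter(Scan()))) // prints Filter(Scan(true),true)
    }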

diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 17a9e8d67..aca06cf88 100644
--- a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -197,7 +197,9 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
         if (projectExpressions.isEmpty) {
           return (0, plan.outputPartitioning, plan.child)
         }
-        val project = TransformPreOverrides().replaceWithTransformerPlan(
+        // FIXME: The operation happens inside ReplaceSingleNode().
+        //  Caller may not know it adds project on top of the shuffle.
+        val project = TransformPreOverrides().apply(
           AddTransformHintRule().apply(
             ProjectExec(plan.child.output ++ projectExpressions, plan.child)))
         var newExprs = Seq[Expression]()
@@ -220,7 +222,9 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
         if (projectExpressions.isEmpty) {
           return (0, plan.outputPartitioning, plan.child)
         }
-        val project = TransformPreOverrides().replaceWithTransformerPlan(
+        // FIXME: The operation happens inside ReplaceSingleNode().
+        //  Caller may not know it adds project on top of the shuffle.
+        val project = TransformPreOverrides().apply(
           AddTransformHintRule().apply(
             ProjectExec(plan.child.output ++ projectExpressions, plan.child)))
         var newOrderings = Seq[SortOrder]()
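
The two hunks above are the only caller-facing change in this file: the ClickHouse backend no longer reaches into the rule's internal replaceWithTransformerPlan helper (which moves into ReplaceSingleNode below) and instead applies the whole rule. The ordering matters: AddTransformHintRule has to run first, because the transform sub-rules consult TransformHints tags (see the isNotTransformable checks later in this diff). A compressed view of the call-site pattern, assuming plan and projectExpressions are in scope as in the hunks:

    // Tag first, then transform: the transform sub-rules read the hints.
    val project = TransformPreOverrides().apply(
      AddTransformHintRule().apply(
        ProjectExec(plan.child.output ++ projectExpressions, plan.child)))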
diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala
index 5d740153a..79fd37f3f 100644
--- a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala
@@ -25,6 +25,7 @@ import io.glutenproject.sql.shims.SparkShimLoader
 import io.glutenproject.utils.{LogLevelUtil, PlanUtil}
 
 import org.apache.spark.api.python.EvalPythonExecTransformer
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
@@ -42,406 +43,461 @@ import org.apache.spark.sql.execution.window.WindowExec
 import org.apache.spark.sql.hive.HiveTableScanExecTransformer
 
 object MiscColumnarRules {
-  // This rule will conduct the conversion from Spark plan to the plan transformer.
-  case class TransformPreOverrides() extends Rule[SparkPlan] with LogLevelUtil {
-    val columnarConf: GlutenConfig = GlutenConfig.getConf
-    @transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]()
+  object TransformPreOverrides {
+    // Sub-rules of TransformPreOverrides.
 
-    /**
-     * Generate a plan for hash aggregation.
-     * @param plan:
-     *   the original Spark plan.
-     * @return
-     *   the actually used plan for execution.
-     */
-    private def genHashAggregateExec(plan: HashAggregateExec): SparkPlan = {
-      val newChild = replaceWithTransformerPlan(plan.child)
-      def transformHashAggregate(): GlutenPlan = {
-        BackendsApiManager.getSparkPlanExecApiInstance
-          .genHashAggregateExecTransformer(
-            plan.requiredChildDistributionExpressions,
-            plan.groupingExpressions,
-            plan.aggregateExpressions,
-            plan.aggregateAttributes,
-            plan.initialInputBufferOffset,
-            plan.resultExpressions,
-            newChild
-          )
+    // Aggregation transformation.
+    private case class AggregationTransformRule() extends Rule[SparkPlan] with LogLevelUtil {
+      override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
+        case agg: HashAggregateExec =>
+          genHashAggregateExec(agg)
       }
 
-      // If child's output is empty, fallback or offload both the child and aggregation.
-      if (
-        plan.child.output.isEmpty && BackendsApiManager.getSettings.fallbackAggregateWithChild()
-      ) {
-        newChild match {
-          case _: TransformSupport =>
-            // If the child is transformable, transform aggregation as well.
-            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-            transformHashAggregate()
-          case p: SparkPlan if PlanUtil.isGlutenTableCache(p) =>
-            transformHashAggregate()
-          case _ =>
-            // If the child is not transformable, transform the grandchildren only.
-            TransformHints.tagNotTransformable(plan, "child output schema is empty")
-            val grandChildren = plan.child.children.map(child => replaceWithTransformerPlan(child))
-            plan.withNewChildren(Seq(plan.child.withNewChildren(grandChildren)))
+      /**
+       * Generate a plan for hash aggregation.
+       *
+       * @param plan
+       *   : the original Spark plan.
+       * @return
+       *   the actually used plan for execution.
+       */
+      private def genHashAggregateExec(plan: HashAggregateExec): SparkPlan = {
+        if (TransformHints.isNotTransformable(plan)) {
+          return plan
         }
-      } else {
-        logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-        transformHashAggregate()
-      }
-    }
 
-    /**
-     * Generate a plan for filter.
-     * @param plan:
-     *   the original Spark plan.
-     * @return
-     *   the actually used plan for execution.
-     */
-    private def genFilterExec(plan: FilterExec): SparkPlan = {
-      // FIXME: Filter push-down should be better done by Vanilla Spark's planner or by
-      //  an individual rule.
-      val scan = plan.child
-      // Push down the left conditions in Filter into FileSourceScan.
-      val newChild: SparkPlan = scan match {
-        case _: FileSourceScanExec | _: BatchScanExec =>
-          if (TransformHints.isTransformable(scan)) {
-            val newScan = FilterHandler.applyFilterPushdownToScan(plan)
-            newScan match {
-              case ts: TransformSupport if ts.doValidate().isValid => ts
-              case _ => replaceWithTransformerPlan(scan)
-            }
-          } else {
-            replaceWithTransformerPlan(scan)
+        val aggChild = plan.child
+
+        def transformHashAggregate(): GlutenPlan = {
+          BackendsApiManager.getSparkPlanExecApiInstance
+            .genHashAggregateExecTransformer(
+              plan.requiredChildDistributionExpressions,
+              plan.groupingExpressions,
+              plan.aggregateExpressions,
+              plan.aggregateAttributes,
+              plan.initialInputBufferOffset,
+              plan.resultExpressions,
+              aggChild
+            )
+        }
+
+        // If child's output is empty, fallback or offload both the child and aggregation.
+        if (
+          plan.child.output.isEmpty && BackendsApiManager.getSettings.fallbackAggregateWithChild()
+        ) {
+          aggChild match {
+            case _: TransformSupport =>
+              // If the child is transformable, transform aggregation as well.
+              logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
+              transformHashAggregate()
+            case p: SparkPlan if PlanUtil.isGlutenTableCache(p) =>
+              transformHashAggregate()
+            case _ =>
+              // If the child is not transformable, do not transform the agg.
+              TransformHints.tagNotTransformable(plan, "child output schema is empty")
+              plan
           }
-        case _ => replaceWithTransformerPlan(plan.child)
+        } else {
+          logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
+          transformHashAggregate()
+        }
       }
-      logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-      BackendsApiManager.getSparkPlanExecApiInstance
-        .genFilterExecTransformer(plan.condition, newChild)
     }
 
-    def applyScanNotTransformable(plan: SparkPlan): SparkPlan = plan match {
-      case plan: FileSourceScanExec =>
-        val newPartitionFilters =
-          ExpressionConverter.transformDynamicPruningExpr(plan.partitionFilters)
-        val newSource = plan.copy(partitionFilters = newPartitionFilters)
-        if (plan.logicalLink.nonEmpty) {
-          newSource.setLogicalLink(plan.logicalLink.get)
-        }
-        TransformHints.tag(newSource, TransformHints.getHint(plan))
-        newSource
-      case plan: BatchScanExec =>
-        val newPartitionFilters: Seq[Expression] = plan.scan match {
-          case scan: FileScan =>
-            ExpressionConverter.transformDynamicPruningExpr(scan.partitionFilters)
-          case _ =>
-            ExpressionConverter.transformDynamicPruningExpr(plan.runtimeFilters)
-        }
-        val newSource = plan.copy(runtimeFilters = newPartitionFilters)
-        if (plan.logicalLink.nonEmpty) {
-          newSource.setLogicalLink(plan.logicalLink.get)
+    // Filter transformation.
+    private case class FilterTransformRule() extends Rule[SparkPlan] with LogLevelUtil {
+      private val replace = new ReplaceSingleNode()
+
+      override def apply(plan: SparkPlan): SparkPlan = plan.transformDown {
+        case filter: FilterExec =>
+          genFilterExec(filter)
+      }
+
+      /**
+       * Generate a plan for filter.
+       *
+       * @param plan
+       *   : the original Spark plan.
+       * @return
+       *   the actually used plan for execution.
+       */
+      private def genFilterExec(plan: FilterExec): SparkPlan = {
+        if (TransformHints.isNotTransformable(plan)) {
+          return plan
         }
-        TransformHints.tag(newSource, TransformHints.getHint(plan))
-        newSource
-      case plan if HiveTableScanExecTransformer.isHiveTableScan(plan) =>
-        val newPartitionFilters: Seq[Expression] = ExpressionConverter.transformDynamicPruningExpr(
-          HiveTableScanExecTransformer.getPartitionFilters(plan))
-        val newSource = HiveTableScanExecTransformer.copyWith(plan, newPartitionFilters)
-        if (plan.logicalLink.nonEmpty) {
-          newSource.setLogicalLink(plan.logicalLink.get)
+
+        // FIXME: Filter push-down should be better done by Vanilla Spark's planner or by
+        //  an individual rule.
+        val scan = plan.child
+        // Push down the left conditions in Filter into FileSourceScan.
+        val newChild: SparkPlan = scan match {
+          case _: FileSourceScanExec | _: BatchScanExec =>
+            if (TransformHints.isTransformable(scan)) {
+              val newScan = FilterHandler.applyFilterPushdownToScan(plan)
+              newScan match {
+                case ts: TransformSupport if ts.doValidate().isValid => ts
+                // TODO remove the call
+                case _ => replace.replaceWithTransformerPlan(scan)
+              }
+            } else {
+              replace.replaceWithTransformerPlan(scan)
+            }
+          case _ => replace.replaceWithTransformerPlan(plan.child)
         }
-        TransformHints.tag(newSource, TransformHints.getHint(plan))
-        newSource
-      case other =>
-        throw new UnsupportedOperationException(s"${other.getClass.toString} is not supported.")
+        logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
+        BackendsApiManager.getSparkPlanExecApiInstance
+          .genFilterExecTransformer(plan.condition, newChild)
+      }
     }
 
-    def replaceWithTransformerPlan(plan: SparkPlan): SparkPlan = {
-      if (TransformHints.isNotTransformable(plan)) {
-        logDebug(s"Columnar Processing for ${plan.getClass} is under row 
guard.")
+    // Other transformations.
+    private case class RegularTransformRule() extends Rule[SparkPlan] with LogLevelUtil {
+      private val replace = new ReplaceSingleNode()
+
+      override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
+        case plan => replace.replaceWithTransformerPlan(plan)
+      }
+    }
+
+    // Utility to replace single node within transformed Gluten node.
+    // Children will be preserved as they are as children of the output node.
+    class ReplaceSingleNode() extends LogLevelUtil with Logging {
+      private val columnarConf: GlutenConfig = GlutenConfig.getConf
+
+      def replaceWithTransformerPlan(p: SparkPlan): SparkPlan = {
+        val plan = p
+        if (TransformHints.isNotTransformable(plan)) {
+          logDebug(s"Columnar Processing for ${plan.getClass} is under row 
guard.")
+          plan match {
+            case shj: ShuffledHashJoinExec =>
+              if (BackendsApiManager.getSettings.recreateJoinExecOnFallback()) {
+                // Because we manually removed the build side limitation for LeftOuter, LeftSemi and
+                // RightOuter, we need to change the build side back if this join falls back to vanilla
+                // Spark for execution.
+                return ShuffledHashJoinExec(
+                  shj.leftKeys,
+                  shj.rightKeys,
+                  shj.joinType,
+                  getSparkSupportedBuildSide(shj),
+                  shj.condition,
+                  shj.left,
+                  shj.right,
+                  shj.isSkewJoin
+                )
+              } else {
+                return shj
+              }
+            case plan: BatchScanExec =>
+              return applyScanNotTransformable(plan)
+            case plan: FileSourceScanExec =>
+              return applyScanNotTransformable(plan)
+            case plan if HiveTableScanExecTransformer.isHiveTableScan(plan) =>
+              return applyScanNotTransformable(plan)
+            case p =>
+              return p
+          }
+        }
         plan match {
-          case shj: ShuffledHashJoinExec =>
-            if (BackendsApiManager.getSettings.recreateJoinExecOnFallback()) {
-              // Because we manually removed the build side limitation for LeftOuter, LeftSemi and
-              // RightOuter, we need to change the build side back if this join falls back to vanilla
-              // Spark for execution.
-              return ShuffledHashJoinExec(
-                shj.leftKeys,
-                shj.rightKeys,
-                shj.joinType,
-                getSparkSupportedBuildSide(shj),
-                shj.condition,
-                replaceWithTransformerPlan(shj.left),
-                replaceWithTransformerPlan(shj.right),
-                shj.isSkewJoin
-              )
-            } else {
-              return shj.withNewChildren(shj.children.map(replaceWithTransformerPlan))
-            }
           case plan: BatchScanExec =>
-            return applyScanNotTransformable(plan)
+            applyScanTransformer(plan)
           case plan: FileSourceScanExec =>
-            return applyScanNotTransformable(plan)
+            applyScanTransformer(plan)
           case plan if HiveTableScanExecTransformer.isHiveTableScan(plan) =>
-            return applyScanNotTransformable(plan)
-          case p =>
-            return p.withNewChildren(p.children.map(replaceWithTransformerPlan))
-        }
-      }
-      plan match {
-        case plan: BatchScanExec =>
-          applyScanTransformer(plan)
-        case plan: FileSourceScanExec =>
-          applyScanTransformer(plan)
-        case plan if HiveTableScanExecTransformer.isHiveTableScan(plan) =>
-          applyScanTransformer(plan)
-        case plan: CoalesceExec =>
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-          CoalesceExecTransformer(plan.numPartitions, 
replaceWithTransformerPlan(plan.child))
-        case plan: ProjectExec =>
-          val columnarChild = replaceWithTransformerPlan(plan.child)
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-          ProjectExecTransformer(plan.projectList, columnarChild)
-        case plan: FilterExec =>
-          genFilterExec(plan)
-        case plan: HashAggregateExec =>
-          genHashAggregateExec(plan)
-        case plan: SortAggregateExec =>
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-          BackendsApiManager.getSparkPlanExecApiInstance
-            .genHashAggregateExecTransformer(
-              plan.requiredChildDistributionExpressions,
-              plan.groupingExpressions,
-              plan.aggregateExpressions,
-              plan.aggregateAttributes,
-              plan.initialInputBufferOffset,
-              plan.resultExpressions,
-              plan.child match {
-                case sort: SortExecTransformer if !sort.global =>
-                  replaceWithTransformerPlan(sort.child)
-                case sort: SortExec if !sort.global =>
-                  replaceWithTransformerPlan(sort.child)
-                case _ => replaceWithTransformerPlan(plan.child)
-              }
-            )
-        case plan: ObjectHashAggregateExec =>
-          val child = replaceWithTransformerPlan(plan.child)
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-          BackendsApiManager.getSparkPlanExecApiInstance
-            .genHashAggregateExecTransformer(
-              plan.requiredChildDistributionExpressions,
-              plan.groupingExpressions,
-              plan.aggregateExpressions,
-              plan.aggregateAttributes,
-              plan.initialInputBufferOffset,
-              plan.resultExpressions,
-              child
+            applyScanTransformer(plan)
+          case plan: CoalesceExec =>
+            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
+            CoalesceExecTransformer(plan.numPartitions, plan.child)
+          case plan: ProjectExec =>
+            val columnarChild = plan.child
+            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
+            ProjectExecTransformer(plan.projectList, columnarChild)
+          case plan: SortAggregateExec =>
+            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
+            BackendsApiManager.getSparkPlanExecApiInstance
+              .genHashAggregateExecTransformer(
+                plan.requiredChildDistributionExpressions,
+                plan.groupingExpressions,
+                plan.aggregateExpressions,
+                plan.aggregateAttributes,
+                plan.initialInputBufferOffset,
+                plan.resultExpressions,
+                plan.child match {
+                  case sort: SortExecTransformer if !sort.global =>
+                    sort.child
+                  case sort: SortExec if !sort.global =>
+                    sort.child
+                  case _ => plan.child
+                }
+              )
+          case plan: ObjectHashAggregateExec =>
+            val child = plan.child
+            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
+            BackendsApiManager.getSparkPlanExecApiInstance
+              .genHashAggregateExecTransformer(
+                plan.requiredChildDistributionExpressions,
+                plan.groupingExpressions,
+                plan.aggregateExpressions,
+                plan.aggregateAttributes,
+                plan.initialInputBufferOffset,
+                plan.resultExpressions,
+                child
+              )
+          case plan: UnionExec =>
+            val children = plan.children
+            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
+            ColumnarUnionExec(children)
+          case plan: ExpandExec =>
+            val child = plan.child
+            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
+            ExpandExecTransformer(plan.projections, plan.output, child)
+          case plan: WriteFilesExec =>
+            val child = plan.child
+            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
+            val writeTransformer = WriteFilesExecTransformer(
+              child,
+              plan.fileFormat,
+              plan.partitionColumns,
+              plan.bucketSpec,
+              plan.options,
+              plan.staticPartitions)
+            BackendsApiManager.getSparkPlanExecApiInstance.createColumnarWriteFilesExec(
+              writeTransformer,
+              plan.fileFormat,
+              plan.partitionColumns,
+              plan.bucketSpec,
+              plan.options,
+              plan.staticPartitions
             )
-        case plan: UnionExec =>
-          val children = plan.children.map(replaceWithTransformerPlan)
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-          ColumnarUnionExec(children)
-        case plan: ExpandExec =>
-          val child = replaceWithTransformerPlan(plan.child)
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-          ExpandExecTransformer(plan.projections, plan.output, child)
-        case plan: WriteFilesExec =>
-          val child = replaceWithTransformerPlan(plan.child)
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-          val writeTransformer = WriteFilesExecTransformer(
-            child,
-            plan.fileFormat,
-            plan.partitionColumns,
-            plan.bucketSpec,
-            plan.options,
-            plan.staticPartitions)
-          BackendsApiManager.getSparkPlanExecApiInstance.createColumnarWriteFilesExec(
-            writeTransformer,
-            plan.fileFormat,
-            plan.partitionColumns,
-            plan.bucketSpec,
-            plan.options,
-            plan.staticPartitions
-          )
-        case plan: SortExec =>
-          val child = replaceWithTransformerPlan(plan.child)
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-          SortExecTransformer(plan.sortOrder, plan.global, child, 
plan.testSpillFrequency)
-        case plan: TakeOrderedAndProjectExec =>
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-          val child = replaceWithTransformerPlan(plan.child)
-          val (limit, offset) = SparkShimLoader.getSparkShims.getLimitAndOffsetFromTopK(plan)
-          TakeOrderedAndProjectExecTransformer(
-            limit,
-            plan.sortOrder,
-            plan.projectList,
-            child,
-            offset)
-        case plan: ShuffleExchangeExec =>
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-          val child = replaceWithTransformerPlan(plan.child)
-          if (
-            (child.supportsColumnar || columnarConf.enablePreferColumnar) &&
-            BackendsApiManager.getSettings.supportColumnarShuffleExec()
-          ) {
-            BackendsApiManager.getSparkPlanExecApiInstance.genColumnarShuffleExchange(plan, child)
-          } else {
-            plan.withNewChildren(Seq(child))
-          }
-        case plan: ShuffledHashJoinExec =>
-          val left = replaceWithTransformerPlan(plan.left)
-          val right = replaceWithTransformerPlan(plan.right)
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-          BackendsApiManager.getSparkPlanExecApiInstance
-            .genShuffledHashJoinExecTransformer(
+          case plan: SortExec =>
+            val child = plan.child
+            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
+            SortExecTransformer(plan.sortOrder, plan.global, child, 
plan.testSpillFrequency)
+          case plan: TakeOrderedAndProjectExec =>
+            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
+            val child = plan.child
+            val (limit, offset) = SparkShimLoader.getSparkShims.getLimitAndOffsetFromTopK(plan)
+            TakeOrderedAndProjectExecTransformer(
+              limit,
+              plan.sortOrder,
+              plan.projectList,
+              child,
+              offset)
+          case plan: ShuffleExchangeExec =>
+            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
+            val child = plan.child
+            if (
+              (child.supportsColumnar || columnarConf.enablePreferColumnar) &&
+              BackendsApiManager.getSettings.supportColumnarShuffleExec()
+            ) {
+              BackendsApiManager.getSparkPlanExecApiInstance.genColumnarShuffleExchange(plan, child)
+            } else {
+              plan.withNewChildren(Seq(child))
+            }
+          case plan: ShuffledHashJoinExec =>
+            val left = plan.left
+            val right = plan.right
+            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
+            BackendsApiManager.getSparkPlanExecApiInstance
+              .genShuffledHashJoinExecTransformer(
+                plan.leftKeys,
+                plan.rightKeys,
+                plan.joinType,
+                plan.buildSide,
+                plan.condition,
+                left,
+                right,
+                plan.isSkewJoin)
+          case plan: SortMergeJoinExec =>
+            val left = plan.left
+            val right = plan.right
+            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
+            SortMergeJoinExecTransformer(
               plan.leftKeys,
               plan.rightKeys,
               plan.joinType,
-              plan.buildSide,
               plan.condition,
               left,
               right,
               plan.isSkewJoin)
-        case plan: SortMergeJoinExec =>
-          val left = replaceWithTransformerPlan(plan.left)
-          val right = replaceWithTransformerPlan(plan.right)
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-          SortMergeJoinExecTransformer(
-            plan.leftKeys,
-            plan.rightKeys,
-            plan.joinType,
-            plan.condition,
-            left,
-            right,
-            plan.isSkewJoin)
-        case plan: BroadcastExchangeExec =>
-          val child = replaceWithTransformerPlan(plan.child)
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-          ColumnarBroadcastExchangeExec(plan.mode, child)
-        case plan: BroadcastHashJoinExec =>
-          val left = replaceWithTransformerPlan(plan.left)
-          val right = replaceWithTransformerPlan(plan.right)
-          BackendsApiManager.getSparkPlanExecApiInstance
-            .genBroadcastHashJoinExecTransformer(
-              plan.leftKeys,
-              plan.rightKeys,
-              plan.joinType,
-              plan.buildSide,
-              plan.condition,
-              left,
-              right,
-              isNullAwareAntiJoin = plan.isNullAwareAntiJoin)
-        case plan: CartesianProductExec =>
-          val left = replaceWithTransformerPlan(plan.left)
-          val right = replaceWithTransformerPlan(plan.right)
-          BackendsApiManager.getSparkPlanExecApiInstance
-            .genCartesianProductExecTransformer(left, right, plan.condition)
-        case plan: BroadcastNestedLoopJoinExec =>
-          val left = replaceWithTransformerPlan(plan.left)
-          val right = replaceWithTransformerPlan(plan.right)
-          BackendsApiManager.getSparkPlanExecApiInstance
-            .genBroadcastNestedLoopJoinExecTransformer(
-              left,
-              right,
-              plan.buildSide,
-              plan.joinType,
-              plan.condition)
-        case plan: WindowExec =>
-          WindowExecTransformer(
-            plan.windowExpression,
-            plan.partitionSpec,
-            plan.orderSpec,
-            replaceWithTransformerPlan(plan.child))
-        case plan: GlobalLimitExec =>
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-          val child = replaceWithTransformerPlan(plan.child)
-          val (limit, offset) = 
SparkShimLoader.getSparkShims.getLimitAndOffsetFromGlobalLimit(plan)
-          LimitTransformer(child, offset, limit)
-        case plan: LocalLimitExec =>
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-          val child = replaceWithTransformerPlan(plan.child)
-          LimitTransformer(child, 0L, plan.limit)
-        case plan: GenerateExec =>
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-          val child = replaceWithTransformerPlan(plan.child)
-          
BackendsApiManager.getSparkPlanExecApiInstance.genGenerateTransformer(
-            plan.generator,
-            plan.requiredChildOutput,
-            plan.outer,
-            plan.generatorOutput,
-            child)
-        case plan: EvalPythonExec =>
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-          val child = replaceWithTransformerPlan(plan.child)
-          EvalPythonExecTransformer(plan.udfs, plan.resultAttrs, child)
-        case p =>
-          logDebug(s"Transformation for ${p.getClass} is currently not 
supported.")
-          val children = plan.children.map(replaceWithTransformerPlan)
-          p.withNewChildren(children)
+          case plan: BroadcastExchangeExec =>
+            val child = plan.child
+            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
+            ColumnarBroadcastExchangeExec(plan.mode, child)
+          case plan: BroadcastHashJoinExec =>
+            val left = plan.left
+            val right = plan.right
+            BackendsApiManager.getSparkPlanExecApiInstance
+              .genBroadcastHashJoinExecTransformer(
+                plan.leftKeys,
+                plan.rightKeys,
+                plan.joinType,
+                plan.buildSide,
+                plan.condition,
+                left,
+                right,
+                isNullAwareAntiJoin = plan.isNullAwareAntiJoin)
+          case plan: CartesianProductExec =>
+            val left = plan.left
+            val right = plan.right
+            BackendsApiManager.getSparkPlanExecApiInstance
+              .genCartesianProductExecTransformer(left, right, plan.condition)
+          case plan: BroadcastNestedLoopJoinExec =>
+            val left = plan.left
+            val right = plan.right
+            BackendsApiManager.getSparkPlanExecApiInstance
+              .genBroadcastNestedLoopJoinExecTransformer(
+                left,
+                right,
+                plan.buildSide,
+                plan.joinType,
+                plan.condition)
+          case plan: WindowExec =>
+            WindowExecTransformer(
+              plan.windowExpression,
+              plan.partitionSpec,
+              plan.orderSpec,
+              plan.child)
+          case plan: GlobalLimitExec =>
+            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
+            val child = plan.child
+            val (limit, offset) =
+              SparkShimLoader.getSparkShims.getLimitAndOffsetFromGlobalLimit(plan)
+            LimitTransformer(child, offset, limit)
+          case plan: LocalLimitExec =>
+            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
+            val child = plan.child
+            LimitTransformer(child, 0L, plan.limit)
+          case plan: GenerateExec =>
+            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
+            val child = plan.child
+            
BackendsApiManager.getSparkPlanExecApiInstance.genGenerateTransformer(
+              plan.generator,
+              plan.requiredChildOutput,
+              plan.outer,
+              plan.generatorOutput,
+              child)
+          case plan: EvalPythonExec =>
+            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
+            val child = plan.child
+            EvalPythonExecTransformer(plan.udfs, plan.resultAttrs, child)
+          case p if !p.isInstanceOf[GlutenPlan] =>
+            logDebug(s"Transformation for ${p.getClass} is currently not 
supported.")
+            val children = plan.children
+            p.withNewChildren(children)
+          case other => other
+        }
       }
-    }
 
-    /**
-     * Get the build side supported by the execution of vanilla Spark.
-     *
-     * @param plan:
-     *   shuffled hash join plan
-     * @return
-     *   the supported build side
-     */
-    private def getSparkSupportedBuildSide(plan: ShuffledHashJoinExec): BuildSide = {
-      plan.joinType match {
-        case LeftOuter | LeftSemi => BuildRight
-        case RightOuter => BuildLeft
-        case _ => plan.buildSide
+      /**
+       * Get the build side supported by the execution of vanilla Spark.
+       *
+       * @param plan
+       *   : shuffled hash join plan
+       * @return
+       *   the supported build side
+       */
+      private def getSparkSupportedBuildSide(plan: ShuffledHashJoinExec): BuildSide = {
+        plan.joinType match {
+          case LeftOuter | LeftSemi => BuildRight
+          case RightOuter => BuildLeft
+          case _ => plan.buildSide
+        }
       }
-    }
 
-    /**
-     * Apply scan transformer for file source and batch source,
-     *   1. create new filter and scan transformer, 2. validate, tag new scan as unsupported if
-     *      failed, 3. return new source.
-     */
-    def applyScanTransformer(plan: SparkPlan): SparkPlan = plan match {
-      case plan: FileSourceScanExec =>
-        val transformer = ScanTransformerFactory.createFileSourceScanTransformer(plan)
-        val validationResult = transformer.doValidate()
-        if (validationResult.isValid) {
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-          transformer
-        } else {
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently 
unsupported.")
-          val newSource = plan.copy(partitionFilters = 
transformer.getPartitionFilters())
-          TransformHints.tagNotTransformable(newSource, 
validationResult.reason.get)
+      private def applyScanNotTransformable(plan: SparkPlan): SparkPlan = plan 
match {
+        case plan: FileSourceScanExec =>
+          val newPartitionFilters =
+            ExpressionConverter.transformDynamicPruningExpr(plan.partitionFilters)
+          val newSource = plan.copy(partitionFilters = newPartitionFilters)
+          if (plan.logicalLink.nonEmpty) {
+            newSource.setLogicalLink(plan.logicalLink.get)
+          }
+          TransformHints.tag(newSource, TransformHints.getHint(plan))
           newSource
-        }
-      case plan: BatchScanExec =>
-        ScanTransformerFactory.createBatchScanTransformer(plan)
+        case plan: BatchScanExec =>
+          val newPartitionFilters: Seq[Expression] = plan.scan match {
+            case scan: FileScan =>
+              ExpressionConverter.transformDynamicPruningExpr(scan.partitionFilters)
+            case _ =>
+              ExpressionConverter.transformDynamicPruningExpr(plan.runtimeFilters)
+          }
+          val newSource = plan.copy(runtimeFilters = newPartitionFilters)
+          if (plan.logicalLink.nonEmpty) {
+            newSource.setLogicalLink(plan.logicalLink.get)
+          }
+          TransformHints.tag(newSource, TransformHints.getHint(plan))
+          newSource
+        case plan if HiveTableScanExecTransformer.isHiveTableScan(plan) =>
+          val newPartitionFilters: Seq[Expression] =
+            ExpressionConverter.transformDynamicPruningExpr(
+              HiveTableScanExecTransformer.getPartitionFilters(plan))
+          val newSource = HiveTableScanExecTransformer.copyWith(plan, newPartitionFilters)
+          if (plan.logicalLink.nonEmpty) {
+            newSource.setLogicalLink(plan.logicalLink.get)
+          }
+          TransformHints.tag(newSource, TransformHints.getHint(plan))
+          newSource
+        case other =>
+          throw new UnsupportedOperationException(s"${other.getClass.toString} is not supported.")
+      }
 
-      case plan if HiveTableScanExecTransformer.isHiveTableScan(plan) =>
-        // TODO: Add DynamicPartitionPruningHiveScanSuite.scala
-        val newPartitionFilters: Seq[Expression] = ExpressionConverter.transformDynamicPruningExpr(
-          HiveTableScanExecTransformer.getPartitionFilters(plan))
-        val hiveTableScanExecTransformer =
-          BackendsApiManager.getSparkPlanExecApiInstance.genHiveTableScanExecTransformer(plan)
-        val validateResult = hiveTableScanExecTransformer.doValidate()
-        if (validateResult.isValid) {
-          logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-          return hiveTableScanExecTransformer
-        }
-        logDebug(s"Columnar Processing for ${plan.getClass} is currently 
unsupported.")
-        val newSource = HiveTableScanExecTransformer.copyWith(plan, 
newPartitionFilters)
-        TransformHints.tagNotTransformable(newSource, 
validateResult.reason.get)
-        newSource
-      case other =>
-        throw new UnsupportedOperationException(s"${other.getClass.toString} is not supported.")
+      /**
+       * Apply scan transformer for file source and batch source,
+       *   1. create new filter and scan transformer, 2. validate, tag new scan as unsupported if
+       *      failed, 3. return new source.
+       */
+      private def applyScanTransformer(plan: SparkPlan): SparkPlan = plan match {
+        case plan: FileSourceScanExec =>
+          val transformer = ScanTransformerFactory.createFileSourceScanTransformer(plan)
+          val validationResult = transformer.doValidate()
+          if (validationResult.isValid) {
+            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
+            transformer
+          } else {
+            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
unsupported.")
+            val newSource = plan.copy(partitionFilters = 
transformer.getPartitionFilters())
+            TransformHints.tagNotTransformable(newSource, 
validationResult.reason.get)
+            newSource
+          }
+        case plan: BatchScanExec =>
+          ScanTransformerFactory.createBatchScanTransformer(plan)
+
+        case plan if HiveTableScanExecTransformer.isHiveTableScan(plan) =>
+          // TODO: Add DynamicPartitionPruningHiveScanSuite.scala
+          val newPartitionFilters: Seq[Expression] =
+            ExpressionConverter.transformDynamicPruningExpr(
+              HiveTableScanExecTransformer.getPartitionFilters(plan))
+          val hiveTableScanExecTransformer =
+            BackendsApiManager.getSparkPlanExecApiInstance.genHiveTableScanExecTransformer(plan)
+          val validateResult = hiveTableScanExecTransformer.doValidate()
+          if (validateResult.isValid) {
+            logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
+            return hiveTableScanExecTransformer
+          }
+          logDebug(s"Columnar Processing for ${plan.getClass} is currently 
unsupported.")
+          val newSource = HiveTableScanExecTransformer.copyWith(plan, 
newPartitionFilters)
+          TransformHints.tagNotTransformable(newSource, 
validateResult.reason.get)
+          newSource
+        case other =>
+          throw new UnsupportedOperationException(s"${other.getClass.toString} is not supported.")
+      }
     }
+  }
 
-    def apply(plan: SparkPlan): SparkPlan = {
-      val newPlan = replaceWithTransformerPlan(plan)
+  // This rule will conduct the conversion from Spark plan to the plan transformer.
+  case class TransformPreOverrides() extends Rule[SparkPlan] with LogLevelUtil {
+    import TransformPreOverrides._
+
+    private val subRules = List(
+      FilterTransformRule(),
+      RegularTransformRule(),
+      AggregationTransformRule()
+    )
 
+    @transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]()
+
+    def apply(plan: SparkPlan): SparkPlan = {
+      val newPlan = subRules.foldLeft(plan)((p, rule) => rule.apply(p))
       planChangeLogger.logRule(ruleName, plan, newPlan)
       newPlan
     }
@@ -450,7 +506,6 @@ object MiscColumnarRules {
   // This rule will try to convert the row-to-columnar and columnar-to-row
   // into native implementations.
   case class TransformPostOverrides() extends Rule[SparkPlan] {
-    val columnarConf: GlutenConfig = GlutenConfig.getConf
     @transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]()
 
     def replaceWithTransformerPlan(plan: SparkPlan): SparkPlan = plan.transformDown {
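
A closing note on why the patch can delete every explicit replaceWithTransformerPlan(child) call: the new sub-rules are built on Spark's transformUp/transformDown tree traversals, so children are rewritten by the traversal itself and each case only swaps the node in hand (hence the name ReplaceSingleNode). A standalone sketch of the bottom-up variant, using hypothetical Node types rather than Spark's TreeNode:

    object TransformUpSketch {
      sealed trait Node {
        def children: Seq[Node]
        def withChildren(c: Seq[Node]): Node
        // Bottom-up: rewrite children first, then offer the rebuilt node to f.
        def transformUp(f: PartialFunction[Node, Node]): Node = {
          val rebuilt = withChildren(children.map(_.transformUp(f)))
          f.applyOrElse(rebuilt, (n: Node) => n)
        }
      }
      case class Scan() extends Node {
        def children: Seq[Node] = Nil
        def withChildren(c: Seq[Node]): Node = this
      }
      case class Limit(n: Int, child: Node) extends Node {
        def children: Seq[Node] = Seq(child)
        def withChildren(c: Seq[Node]): Node = copy(child = c.head)
      }
      case class ScanTransformer() extends Node {
        def children: Seq[Node] = Nil
        def withChildren(c: Seq[Node]): Node = this
      }
      case class LimitTransformer(n: Int, child: Node) extends Node {
        def children: Seq[Node] = Seq(child)
        def withChildren(c: Seq[Node]): Node = copy(child = c.head)
      }

      def main(args: Array[String]): Unit = {
        val replaced = Limit(10, Scan()).transformUp {
          // By the time a node is visited its children are already replaced,
          // so the case can reuse `child` as-is, just like the patch does.
          case Limit(n, child) => LimitTransformer(n, child)
          case Scan() => ScanTransformer()
        }
        println(replaced) // prints LimitTransformer(10,ScanTransformer())
      }
    }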


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

