This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 18af4bc3c [VL] RAS: Include rewrite rules used by 
RewriteSparkPlanRulesManager in EnumeratedTransform (#5575)
18af4bc3c is described below

commit 18af4bc3ce4c3e685ad63c869880f8b63d48dc1c
Author: Hongze Zhang <[email protected]>
AuthorDate: Tue May 7 14:49:03 2024 +0800

    [VL] RAS: Include rewrite rules used by RewriteSparkPlanRulesManager in 
EnumeratedTransform (#5575)
---
 .../clickhouse/CHSparkPlanExecApi.scala            |  3 +-
 .../execution/CHHashAggregateExecTransformer.scala |  6 +-
 .../backendsapi/velox/VeloxSparkPlanExecApi.scala  |  3 +-
 .../execution/HashAggregateExecTransformer.scala   | 14 ++--
 .../gluten/backendsapi/SparkPlanExecApi.scala      |  1 -
 .../HashAggregateExecBaseTransformer.scala         | 30 +++++--
 .../org/apache/gluten/extension/RewriteIn.scala    |  7 +-
 .../extension/columnar/MiscColumnarRules.scala     | 18 ++--
 ...ormSingleNode.scala => OffloadSingleNode.scala} | 95 ++++++++--------------
 .../extension/columnar/TransformHintRule.scala     | 33 +-------
 .../columnar/enumerated/ConditionedRule.scala      | 51 ------------
 .../columnar/enumerated/EnumeratedApplier.scala    |  5 +-
 .../columnar/enumerated/EnumeratedTransform.scala  | 56 ++++---------
 .../columnar/enumerated/PushFilterToScan.scala     | 27 +++---
 .../extension/columnar/enumerated/RasOffload.scala | 84 +++++++++++++++++++
 ...ntAggregate.scala => RasOffloadAggregate.scala} | 35 ++------
 ...mplementFilter.scala => RasOffloadFilter.scala} | 16 ++--
 .../{FilterRemoveRule.scala => RemoveFilter.scala} |  2 +-
 .../columnar/heuristic/HeuristicApplier.scala      |  1 +
 .../{ => rewrite}/PullOutPostProject.scala         | 14 ++--
 .../columnar/{ => rewrite}/PullOutPreProject.scala |  9 +-
 .../{ => columnar/rewrite}/RewriteCollect.scala    |  9 +-
 .../{ => rewrite}/RewriteMultiChildrenCount.scala  |  7 +-
 .../columnar/rewrite/RewriteSingleNode.scala       | 48 +++++++++++
 .../RewriteSparkPlanRulesManager.scala             | 17 ++--
 .../RewriteTypedImperativeAggregate.scala          |  7 +-
 .../gluten/planner/cost/GlutenCostModel.scala      |  6 +-
 .../GlutenFormatWriterInjectsBase.scala            |  3 +-
 28 files changed, 295 insertions(+), 312 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 64090af28..a9a12a3ea 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -204,10 +204,9 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
 
   /** Generate HashAggregateExecPullOutHelper */
   override def genHashAggregateExecPullOutHelper(
-      groupingExpressions: Seq[NamedExpression],
       aggregateExpressions: Seq[AggregateExpression],
       aggregateAttributes: Seq[Attribute]): HashAggregateExecPullOutBaseHelper 
=
-    CHHashAggregateExecPullOutHelper(groupingExpressions, 
aggregateExpressions, aggregateAttributes)
+    CHHashAggregateExecPullOutHelper(aggregateExpressions, aggregateAttributes)
 
   /**
    * If there are expressions (not field reference) in the partitioning's 
children, add a projection
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
index 82c492f4c..d4f2f9eb3 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashAggregateExecTransformer.scala
@@ -411,13 +411,9 @@ case class CHHashAggregateExecTransformer(
 }
 
 case class CHHashAggregateExecPullOutHelper(
-    groupingExpressions: Seq[NamedExpression],
     aggregateExpressions: Seq[AggregateExpression],
     aggregateAttributes: Seq[Attribute])
-  extends HashAggregateExecPullOutBaseHelper(
-    groupingExpressions,
-    aggregateExpressions,
-    aggregateAttributes) {
+  extends HashAggregateExecPullOutBaseHelper {
 
   /** This method calculates the output attributes of Aggregation. */
   override protected def getAttrForAggregateExprs: List[Attribute] = {
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index a55aa1817..0a9f3ef65 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -330,10 +330,9 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
 
   /** Generate HashAggregateExecPullOutHelper */
   override def genHashAggregateExecPullOutHelper(
-      groupingExpressions: Seq[NamedExpression],
       aggregateExpressions: Seq[AggregateExpression],
       aggregateAttributes: Seq[Attribute]): HashAggregateExecPullOutBaseHelper 
=
-    HashAggregateExecPullOutHelper(groupingExpressions, aggregateExpressions, 
aggregateAttributes)
+    HashAggregateExecPullOutHelper(aggregateExpressions, aggregateAttributes)
 
   override def genColumnarShuffleExchange(
       shuffle: ShuffleExchangeExec,
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
index 0a9904206..f0a7ea180 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
@@ -20,7 +20,7 @@ import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.exception.GlutenNotSupportException
 import org.apache.gluten.expression._
 import org.apache.gluten.expression.ConverterUtils.FunctionConfig
-import org.apache.gluten.extension.columnar.RewriteTypedImperativeAggregate
+import 
org.apache.gluten.extension.columnar.rewrite.RewriteTypedImperativeAggregate
 import org.apache.gluten.substrait.`type`.{TypeBuilder, TypeNode}
 import org.apache.gluten.substrait.{AggregationParams, SubstraitContext}
 import org.apache.gluten.substrait.expression.{AggregateFunctionNode, 
ExpressionBuilder, ExpressionNode, ScalarFunctionNode}
@@ -60,6 +60,12 @@ abstract class HashAggregateExecTransformer(
     resultExpressions,
     child) {
 
+  override def output: Seq[Attribute] = {
+    // TODO: We should have a check to make sure the returned schema actually 
matches the output
+    //  data. Since "resultExpressions" is not actually in use by Velox.
+    super.output
+  }
+
   override def doTransform(context: SubstraitContext): TransformContext = {
     val childCtx = child.asInstanceOf[TransformSupport].doTransform(context)
 
@@ -793,13 +799,9 @@ case class FlushableHashAggregateExecTransformer(
 }
 
 case class HashAggregateExecPullOutHelper(
-    groupingExpressions: Seq[NamedExpression],
     aggregateExpressions: Seq[AggregateExpression],
     aggregateAttributes: Seq[Attribute])
-  extends HashAggregateExecPullOutBaseHelper(
-    groupingExpressions,
-    aggregateExpressions,
-    aggregateAttributes) {
+  extends HashAggregateExecPullOutBaseHelper {
 
   /** This method calculates the output attributes of Aggregation. */
   override protected def getAttrForAggregateExprs: List[Attribute] = {
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index cfa1a4e53..f5e08a05d 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -102,7 +102,6 @@ trait SparkPlanExecApi {
 
   /** Generate HashAggregateExecPullOutHelper */
   def genHashAggregateExecPullOutHelper(
-      groupingExpressions: Seq[NamedExpression],
       aggregateExpressions: Seq[AggregateExpression],
       aggregateAttributes: Seq[Attribute]): HashAggregateExecPullOutBaseHelper
 
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala
index baf88c727..49a9ee1e8 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala
@@ -172,12 +172,32 @@ abstract class HashAggregateExecBaseTransformer(
       validation: Boolean = false): RelNode
 }
 
-abstract class HashAggregateExecPullOutBaseHelper(
-    groupingExpressions: Seq[NamedExpression],
-    aggregateExpressions: Seq[AggregateExpression],
-    aggregateAttributes: Seq[Attribute]) {
+object HashAggregateExecBaseTransformer {
+
+  private def getInitialInputBufferOffset(agg: BaseAggregateExec): Int = agg 
match {
+    case a: HashAggregateExec => a.initialInputBufferOffset
+    case a: ObjectHashAggregateExec => a.initialInputBufferOffset
+    case a: SortAggregateExec => a.initialInputBufferOffset
+  }
+
+  def from(agg: BaseAggregateExec)(
+      childConverter: SparkPlan => SparkPlan = p => p): 
HashAggregateExecBaseTransformer = {
+    BackendsApiManager.getSparkPlanExecApiInstance
+      .genHashAggregateExecTransformer(
+        agg.requiredChildDistributionExpressions,
+        agg.groupingExpressions,
+        agg.aggregateExpressions,
+        agg.aggregateAttributes,
+        getInitialInputBufferOffset(agg),
+        agg.resultExpressions,
+        childConverter(agg.child)
+      )
+  }
+}
+
+trait HashAggregateExecPullOutBaseHelper {
   // The direct outputs of Aggregation.
-  lazy val allAggregateResultAttributes: List[Attribute] =
+  def allAggregateResultAttributes(groupingExpressions: Seq[NamedExpression]): 
List[Attribute] =
     groupingExpressions.map(ConverterUtils.getAttrFromExpr(_)).toList :::
       getAttrForAggregateExprs
 
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/RewriteIn.scala 
b/gluten-core/src/main/scala/org/apache/gluten/extension/RewriteIn.scala
index b508f3eff..565b9bb19 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/RewriteIn.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/RewriteIn.scala
@@ -16,8 +16,9 @@
  */
 package org.apache.gluten.extension
 
+import org.apache.gluten.extension.columnar.rewrite.RewriteSingleNode
+
 import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, In, Or}
-import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{FileSourceScanExec, FilterExec, 
SparkPlan}
 import org.apache.spark.sql.types.StructType
 
@@ -32,7 +33,7 @@ import org.apache.spark.sql.types.StructType
  *
 * TODO: Remove this rule once Velox supports the case that the list option in 
`In` is not a literal.
  */
-object RewriteIn extends Rule[SparkPlan] {
+object RewriteIn extends RewriteSingleNode {
 
   private def shouldRewrite(e: Expression): Boolean = {
     e match {
@@ -58,7 +59,7 @@ object RewriteIn extends Rule[SparkPlan] {
     }
   }
 
-  override def apply(plan: SparkPlan): SparkPlan = {
+  override def rewrite(plan: SparkPlan): SparkPlan = {
     plan match {
       // TODO: Support datasource v2
       case scan: FileSourceScanExec if 
scan.dataFilters.exists(_.find(shouldRewrite).isDefined) =>
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala
index 02a466b6a..068f62e49 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala
@@ -30,12 +30,12 @@ object MiscColumnarRules {
   object TransformPreOverrides {
     def apply(): TransformPreOverrides = {
       TransformPreOverrides(
-        List(TransformFilter()),
+        List(OffloadFilter()),
         List(
-          TransformOthers(),
-          TransformAggregate(),
-          TransformExchange(),
-          TransformJoin()
+          OffloadOthers(),
+          OffloadAggregate(),
+          OffloadExchange(),
+          OffloadJoin()
         )
       )
     }
@@ -43,17 +43,17 @@ object MiscColumnarRules {
 
   // This rule will conduct the conversion from Spark plan to the plan 
transformer.
   case class TransformPreOverrides(
-      topDownRules: Seq[TransformSingleNode],
-      bottomUpRules: Seq[TransformSingleNode])
+      topDownRules: Seq[OffloadSingleNode],
+      bottomUpRules: Seq[OffloadSingleNode])
     extends Rule[SparkPlan]
     with LogLevelUtil {
     @transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]()
 
     def apply(plan: SparkPlan): SparkPlan = {
       val plan0 =
-        topDownRules.foldLeft(plan)((p, rule) => p.transformDown { case p => 
rule.impl(p) })
+        topDownRules.foldLeft(plan)((p, rule) => p.transformDown { case p => 
rule.offload(p) })
       val plan1 =
-        bottomUpRules.foldLeft(plan0)((p, rule) => p.transformUp { case p => 
rule.impl(p) })
+        bottomUpRules.foldLeft(plan0)((p, rule) => p.transformUp { case p => 
rule.offload(p) })
       planChangeLogger.logRule(ruleName, plan, plan1)
       plan1
     }
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformSingleNode.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
similarity index 88%
rename from 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformSingleNode.scala
rename to 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
index 760929bbd..84a2ec5c6 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformSingleNode.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
@@ -40,13 +40,20 @@ import 
org.apache.spark.sql.execution.python.{ArrowEvalPythonExec, BatchEvalPyth
 import org.apache.spark.sql.execution.window.{WindowExec, 
WindowGroupLimitExecShim}
 import org.apache.spark.sql.hive.HiveTableScanExecTransformer
 
-sealed trait TransformSingleNode extends Logging {
-  def impl(plan: SparkPlan): SparkPlan
+/**
+ * Converts a vanilla Spark plan node into a Gluten plan node. The Gluten plan 
is supposed to be
+ * executed natively, and the internals of execution are subject to the 
backend's implementation.
+ *
+ * Note: Only the current plan node is supposed to be open to modification. Do 
not access or modify
+ * the child nodes. Tree-walking is done by the caller of this trait.
+ */
+sealed trait OffloadSingleNode extends Logging {
+  def offload(plan: SparkPlan): SparkPlan
 }
 
 // Aggregation transformation.
-case class TransformAggregate() extends TransformSingleNode with LogLevelUtil {
-  override def impl(plan: SparkPlan): SparkPlan = plan match {
+case class OffloadAggregate() extends OffloadSingleNode with LogLevelUtil {
+  override def offload(plan: SparkPlan): SparkPlan = plan match {
     case plan if TransformHints.isNotTransformable(plan) =>
       plan
     case agg: HashAggregateExec =>
@@ -69,19 +76,6 @@ case class TransformAggregate() extends TransformSingleNode 
with LogLevelUtil {
 
     val aggChild = plan.child
 
-    def transformHashAggregate(): GlutenPlan = {
-      BackendsApiManager.getSparkPlanExecApiInstance
-        .genHashAggregateExecTransformer(
-          plan.requiredChildDistributionExpressions,
-          plan.groupingExpressions,
-          plan.aggregateExpressions,
-          plan.aggregateAttributes,
-          plan.initialInputBufferOffset,
-          plan.resultExpressions,
-          aggChild
-        )
-    }
-
     // If child's output is empty, fallback or offload both the child and 
aggregation.
     if (
       aggChild.output.isEmpty && BackendsApiManager.getSettings
@@ -91,9 +85,9 @@ case class TransformAggregate() extends TransformSingleNode 
with LogLevelUtil {
         case _: TransformSupport =>
           // If the child is transformable, transform aggregation as well.
           logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-          transformHashAggregate()
+          HashAggregateExecBaseTransformer.from(plan)()
         case p: SparkPlan if PlanUtil.isGlutenTableCache(p) =>
-          transformHashAggregate()
+          HashAggregateExecBaseTransformer.from(plan)()
         case _ =>
           // If the child is not transformable, do not transform the agg.
           TransformHints.tagNotTransformable(plan, "child output schema is 
empty")
@@ -101,14 +95,14 @@ case class TransformAggregate() extends 
TransformSingleNode with LogLevelUtil {
       }
     } else {
       logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-      transformHashAggregate()
+      HashAggregateExecBaseTransformer.from(plan)()
     }
   }
 }
 
 // Exchange transformation.
-case class TransformExchange() extends TransformSingleNode with LogLevelUtil {
-  override def impl(plan: SparkPlan): SparkPlan = plan match {
+case class OffloadExchange() extends OffloadSingleNode with LogLevelUtil {
+  override def offload(plan: SparkPlan): SparkPlan = plan match {
     case plan if TransformHints.isNotTransformable(plan) =>
       plan
     case plan: ShuffleExchangeExec =>
@@ -131,10 +125,10 @@ case class TransformExchange() extends 
TransformSingleNode with LogLevelUtil {
 }
 
 // Join transformation.
-case class TransformJoin() extends TransformSingleNode with LogLevelUtil {
-  import TransformJoin._
+case class OffloadJoin() extends OffloadSingleNode with LogLevelUtil {
+  import OffloadJoin._
 
-  override def impl(plan: SparkPlan): SparkPlan = {
+  override def offload(plan: SparkPlan): SparkPlan = {
     if (TransformHints.isNotTransformable(plan)) {
       logDebug(s"Columnar Processing for ${plan.getClass} is under row guard.")
       plan match {
@@ -223,7 +217,7 @@ case class TransformJoin() extends TransformSingleNode with 
LogLevelUtil {
 
 }
 
-object TransformJoin {
+object OffloadJoin {
   private def getSparkSupportedBuildSide(plan: ShuffledHashJoinExec): 
BuildSide = {
     plan.joinType match {
       case LeftOuter | LeftSemi => BuildRight
@@ -238,11 +232,11 @@ object TransformJoin {
 }
 
 // Filter transformation.
-case class TransformFilter() extends TransformSingleNode with LogLevelUtil {
-  import TransformOthers._
+case class OffloadFilter() extends OffloadSingleNode with LogLevelUtil {
+  import OffloadOthers._
   private val replace = new ReplaceSingleNode()
 
-  override def impl(plan: SparkPlan): SparkPlan = plan match {
+  override def offload(plan: SparkPlan): SparkPlan = plan match {
     case filter: FilterExec =>
       genFilterExec(filter)
     case other => other
@@ -286,14 +280,14 @@ case class TransformFilter() extends TransformSingleNode 
with LogLevelUtil {
 }
 
 // Other transformations.
-case class TransformOthers() extends TransformSingleNode with LogLevelUtil {
-  import TransformOthers._
+case class OffloadOthers() extends OffloadSingleNode with LogLevelUtil {
+  import OffloadOthers._
   private val replace = new ReplaceSingleNode()
 
-  override def impl(plan: SparkPlan): SparkPlan = replace.doReplace(plan)
+  override def offload(plan: SparkPlan): SparkPlan = replace.doReplace(plan)
 }
 
-object TransformOthers {
+object OffloadOthers {
   // Utility to replace single node within transformed Gluten node.
   // Children will be preserved as they are as children of the output node.
   //
@@ -333,35 +327,16 @@ object TransformOthers {
           ProjectExecTransformer(plan.projectList, columnarChild)
         case plan: SortAggregateExec =>
           logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-          BackendsApiManager.getSparkPlanExecApiInstance
-            .genHashAggregateExecTransformer(
-              plan.requiredChildDistributionExpressions,
-              plan.groupingExpressions,
-              plan.aggregateExpressions,
-              plan.aggregateAttributes,
-              plan.initialInputBufferOffset,
-              plan.resultExpressions,
-              plan.child match {
-                case sort: SortExecTransformer if !sort.global =>
-                  sort.child
-                case sort: SortExec if !sort.global =>
-                  sort.child
-                case _ => plan.child
-              }
-            )
+          HashAggregateExecBaseTransformer.from(plan) {
+            case sort: SortExecTransformer if !sort.global =>
+              sort.child
+            case sort: SortExec if !sort.global =>
+              sort.child
+            case other => other
+          }
         case plan: ObjectHashAggregateExec =>
-          val child = plan.child
           logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
-          BackendsApiManager.getSparkPlanExecApiInstance
-            .genHashAggregateExecTransformer(
-              plan.requiredChildDistributionExpressions,
-              plan.groupingExpressions,
-              plan.aggregateExpressions,
-              plan.aggregateAttributes,
-              plan.initialInputBufferOffset,
-              plan.resultExpressions,
-              child
-            )
+          HashAggregateExecBaseTransformer.from(plan)()
         case plan: UnionExec =>
           val children = plan.children
           logDebug(s"Columnar Processing for ${plan.getClass} is currently 
supported.")
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala
index ea934425f..3c3d23ccc 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala
@@ -355,40 +355,13 @@ case class AddTransformHintRule() extends Rule[SparkPlan] 
{
             .genFilterExecTransformer(plan.condition, plan.child)
           transformer.doValidate().tagOnFallback(plan)
         case plan: HashAggregateExec =>
-          val transformer = BackendsApiManager.getSparkPlanExecApiInstance
-            .genHashAggregateExecTransformer(
-              plan.requiredChildDistributionExpressions,
-              plan.groupingExpressions,
-              plan.aggregateExpressions,
-              plan.aggregateAttributes,
-              plan.initialInputBufferOffset,
-              plan.resultExpressions,
-              plan.child
-            )
+          val transformer = HashAggregateExecBaseTransformer.from(plan)()
           transformer.doValidate().tagOnFallback(plan)
         case plan: SortAggregateExec =>
-          val transformer = BackendsApiManager.getSparkPlanExecApiInstance
-            .genHashAggregateExecTransformer(
-              plan.requiredChildDistributionExpressions,
-              plan.groupingExpressions,
-              plan.aggregateExpressions,
-              plan.aggregateAttributes,
-              plan.initialInputBufferOffset,
-              plan.resultExpressions,
-              plan.child
-            )
+          val transformer = HashAggregateExecBaseTransformer.from(plan)()
           transformer.doValidate().tagOnFallback(plan)
         case plan: ObjectHashAggregateExec =>
-          val transformer = BackendsApiManager.getSparkPlanExecApiInstance
-            .genHashAggregateExecTransformer(
-              plan.requiredChildDistributionExpressions,
-              plan.groupingExpressions,
-              plan.aggregateExpressions,
-              plan.aggregateAttributes,
-              plan.initialInputBufferOffset,
-              plan.resultExpressions,
-              plan.child
-            )
+          val transformer = HashAggregateExecBaseTransformer.from(plan)()
           transformer.doValidate().tagOnFallback(plan)
         case plan: UnionExec =>
           val transformer = ColumnarUnionExec(plan.children)
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ConditionedRule.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ConditionedRule.scala
deleted file mode 100644
index 33d99f5f7..000000000
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ConditionedRule.scala
+++ /dev/null
@@ -1,51 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.extension.columnar.enumerated
-
-import org.apache.gluten.extension.columnar.validator.Validator
-import org.apache.gluten.ras.rule.{RasRule, Shape}
-
-import org.apache.spark.sql.execution.SparkPlan
-
-object ConditionedRule {
-  trait PreCondition {
-    def apply(node: SparkPlan): Boolean
-  }
-
-  object PreCondition {
-    implicit class FromValidator(validator: Validator) extends PreCondition {
-      override def apply(node: SparkPlan): Boolean = {
-        validator.validate(node) match {
-          case Validator.Passed => true
-          case Validator.Failed(reason) => false
-        }
-      }
-    }
-  }
-
-  def wrap(rule: RasRule[SparkPlan], cond: ConditionedRule.PreCondition): 
RasRule[SparkPlan] = {
-    new RasRule[SparkPlan] {
-      override def shift(node: SparkPlan): Iterable[SparkPlan] = {
-        val out = List(node)
-          .filter(cond.apply)
-          .flatMap(rule.shift)
-        out
-      }
-      override def shape(): Shape[SparkPlan] = rule.shape()
-    }
-  }
-}
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala
index dfc2d474f..92d64abf3 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala
@@ -118,10 +118,7 @@ class EnumeratedApplier(session: SparkSession)
       (_: SparkSession) => FallbackEmptySchemaRelation()
     ) :::
       
BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarValidationRules()
 :::
-      List(
-        (spark: SparkSession) => MergeTwoPhasesHashBaseAggregate(spark),
-        (_: SparkSession) => RewriteSparkPlanRulesManager()
-      ) :::
+      List((spark: SparkSession) => MergeTwoPhasesHashBaseAggregate(spark)) :::
       List(
         (session: SparkSession) => EnumeratedTransform(session, 
outputsColumnar),
         (_: SparkSession) => RemoveTransitions
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
index 973020438..dc34bc1af 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
@@ -16,13 +16,10 @@
  */
 package org.apache.gluten.extension.columnar.enumerated
 
-import org.apache.gluten.extension.GlutenPlan
-import org.apache.gluten.extension.columnar.{TransformExchange, TransformJoin, 
TransformOthers, TransformSingleNode}
-import org.apache.gluten.extension.columnar.validator.{Validator, Validators}
+import org.apache.gluten.extension.columnar.{OffloadExchange, OffloadJoin, 
OffloadOthers, OffloadSingleNode}
 import org.apache.gluten.planner.GlutenOptimization
 import org.apache.gluten.planner.property.Conventions
 import org.apache.gluten.ras.property.PropertySet
-import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
 import org.apache.gluten.utils.LogLevelUtil
 
 import org.apache.spark.sql.SparkSession
@@ -34,31 +31,22 @@ case class EnumeratedTransform(session: SparkSession, 
outputsColumnar: Boolean)
   with LogLevelUtil {
   import EnumeratedTransform._
 
-  private val validator = Validators
-    .builder()
-    .fallbackByHint()
-    .fallbackIfScanOnly()
-    .fallbackComplexExpressions()
-    .fallbackByBackendSettings()
-    .fallbackByUserOptions()
-    .build()
-
   private val rules = List(
-    PushFilterToScan,
-    FilterRemoveRule
+    new PushFilterToScan(RasOffload.validator),
+    RemoveFilter
   )
 
   // TODO: Should obey ReplaceSingleNode#applyScanNotTransformable to select
   //  (vanilla) scan with cheaper sub-query plan through cost model.
-  private val implRules = List(
-    AsRasImplement(TransformOthers()),
-    AsRasImplement(TransformExchange()),
-    AsRasImplement(TransformJoin()),
-    ImplementAggregate,
-    ImplementFilter
-  ).map(_.withValidator(validator))
+  private val offloadRules = List(
+    new AsRasOffload(OffloadOthers()),
+    new AsRasOffload(OffloadExchange()),
+    new AsRasOffload(OffloadJoin()),
+    RasOffloadAggregate,
+    RasOffloadFilter
+  )
 
-  private val optimization = GlutenOptimization(rules ++ implRules)
+  private val optimization = GlutenOptimization(rules ++ offloadRules)
 
   private val reqConvention = Conventions.ANY
   private val altConventions =
@@ -75,24 +63,12 @@ case class EnumeratedTransform(session: SparkSession, 
outputsColumnar: Boolean)
 }
 
 object EnumeratedTransform {
-  private case class AsRasImplement(delegate: TransformSingleNode) extends 
RasRule[SparkPlan] {
-    override def shift(node: SparkPlan): Iterable[SparkPlan] = {
-      val out = delegate.impl(node)
-      out match {
-        case t: GlutenPlan if !t.doValidate().isValid =>
-          List.empty
-        case other =>
-          List(other)
-      }
-    }
-
-    override def shape(): Shape[SparkPlan] = Shapes.fixedHeight(1)
-  }
 
-  // TODO: Currently not in use. Prepared for future development.
-  implicit private class RasRuleImplicits(rasRule: RasRule[SparkPlan]) {
-    def withValidator(v: Validator): RasRule[SparkPlan] = {
-      ConditionedRule.wrap(rasRule, v)
+  /** Accepts a [[OffloadSingleNode]] rule to convert it into a RAS offload 
rule. */
+  private class AsRasOffload(delegate: OffloadSingleNode) extends RasOffload {
+    override protected def offload(node: SparkPlan): SparkPlan = {
+      val out = delegate.offload(node)
+      out
     }
   }
 }
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala
index f04f572c1..7306b734a 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala
@@ -17,28 +17,31 @@
 package org.apache.gluten.extension.columnar.enumerated
 
 import org.apache.gluten.execution.{FilterHandler, TransformSupport}
-import org.apache.gluten.extension.columnar.TransformHints
+import org.apache.gluten.extension.columnar.validator.Validator
 import org.apache.gluten.ras.path.Pattern._
 import org.apache.gluten.ras.path.Pattern.Matchers._
 import org.apache.gluten.ras.rule.{RasRule, Shape}
 import org.apache.gluten.ras.rule.Shapes._
 
-import org.apache.spark.sql.execution.{ColumnarToRowExec, 
ColumnarToRowTransition, FileSourceScanExec, FilterExec, SparkPlan}
+import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 
-object PushFilterToScan extends RasRule[SparkPlan] {
+// TODO: Match on Vanilla filter + Gluten scan.
+class PushFilterToScan(validator: Validator) extends RasRule[SparkPlan] {
   override def shift(node: SparkPlan): Iterable[SparkPlan] = node match {
     case FilterAndScan(filter, scan) =>
-      if (!TransformHints.isTransformable(scan)) {
-        return List.empty
-      }
-      val newScan =
-        FilterHandler.pushFilterToScan(filter.condition, scan)
-      newScan match {
-        case ts: TransformSupport if ts.doValidate().isValid =>
-          List(filter.withNewChildren(List(ts)))
-        case _ =>
+      validator.validate(scan) match {
+        case Validator.Failed(reason) =>
           List.empty
+        case Validator.Passed =>
+          val newScan =
+            FilterHandler.pushFilterToScan(filter.condition, scan)
+          newScan match {
+            case ts: TransformSupport if ts.doValidate().isValid =>
+              List(filter.withNewChildren(List(ts)))
+            case _ =>
+              List.empty
+          }
       }
     case _ =>
       List.empty
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala
new file mode 100644
index 000000000..57e093bde
--- /dev/null
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.columnar.enumerated
+
+import org.apache.gluten.extension.GlutenPlan
+import org.apache.gluten.extension.columnar.rewrite.RewriteSingleNode
+import org.apache.gluten.extension.columnar.validator.{Validator, Validators}
+import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
+
+import org.apache.spark.sql.execution.SparkPlan
+
+trait RasOffload extends RasRule[SparkPlan] {
+  import RasOffload._
+
+  final override def shift(node: SparkPlan): Iterable[SparkPlan] = {
+    // 0. If the node is already offloaded, return fast.
+    if (node.isInstanceOf[GlutenPlan]) {
+      return List.empty
+    }
+
+    // 1. Rewrite the node to a form that the native library supports.
+    val rewritten = rewrites.foldLeft(node) {
+      case (node, rewrite) =>
+        node.transformUp {
+          case p =>
+            val out = rewrite.rewrite(p)
+            out
+        }
+    }
+
+    // 2. Walk the rewritten tree.
+    val offloaded = rewritten.transformUp {
+      case from =>
+        // 3. Validate current node. If passed, offload it.
+        validator.validate(from) match {
+          case Validator.Passed =>
+            offload(from) match {
+              case t: GlutenPlan if !t.doValidate().isValid =>
+                // 4. If native validation fails on the offloaded node, return 
the
+                // original one.
+                from
+              case other =>
+                other
+            }
+          case Validator.Failed(reason) =>
+            from
+        }
+    }
+
+    // 5. Return the final tree.
+    List(offloaded)
+  }
+
+  protected def offload(node: SparkPlan): SparkPlan
+
+  final override def shape(): Shape[SparkPlan] = Shapes.fixedHeight(1)
+}
+
+object RasOffload {
+  val validator = Validators
+    .builder()
+    .fallbackByHint()
+    .fallbackIfScanOnly()
+    .fallbackComplexExpressions()
+    .fallbackByBackendSettings()
+    .fallbackByUserOptions()
+    .build()
+
+  private val rewrites = RewriteSingleNode.allRules()
+}
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementAggregate.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadAggregate.scala
similarity index 50%
rename from 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementAggregate.scala
rename to 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadAggregate.scala
index 8c51ca4fd..e48545ae9 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementAggregate.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadAggregate.scala
@@ -16,39 +16,16 @@
  */
 package org.apache.gluten.extension.columnar.enumerated
 
-import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.execution.HashAggregateExecBaseTransformer
-import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
 
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 
-object ImplementAggregate extends RasRule[SparkPlan] {
-  override def shift(node: SparkPlan): Iterable[SparkPlan] = node match {
-    case agg: HashAggregateExec => shiftAgg(agg)
-    case _ => List.empty
+object RasOffloadAggregate extends RasOffload {
+  override protected def offload(node: SparkPlan): SparkPlan = node match {
+    case agg: HashAggregateExec =>
+      val out = HashAggregateExecBaseTransformer.from(agg)()
+      out
+    case other => other
   }
-
-  private def shiftAgg(agg: HashAggregateExec): Iterable[SparkPlan] = {
-    val transformer = implement(agg)
-    if (!transformer.doValidate().isValid) {
-      return List.empty
-    }
-    List(transformer)
-  }
-
-  private def implement(agg: HashAggregateExec): 
HashAggregateExecBaseTransformer = {
-    BackendsApiManager.getSparkPlanExecApiInstance
-      .genHashAggregateExecTransformer(
-        agg.requiredChildDistributionExpressions,
-        agg.groupingExpressions,
-        agg.aggregateExpressions,
-        agg.aggregateAttributes,
-        agg.initialInputBufferOffset,
-        agg.resultExpressions,
-        agg.child
-      )
-  }
-
-  override def shape(): Shape[SparkPlan] = Shapes.fixedHeight(1)
 }
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementFilter.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadFilter.scala
similarity index 75%
rename from 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementFilter.scala
rename to 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadFilter.scala
index 33121e7f1..030d05d47 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementFilter.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadFilter.scala
@@ -17,22 +17,16 @@
 package org.apache.gluten.extension.columnar.enumerated
 
 import org.apache.gluten.backendsapi.BackendsApiManager
-import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
 
 import org.apache.spark.sql.execution.{FilterExec, SparkPlan}
 
-object ImplementFilter extends RasRule[SparkPlan] {
-  override def shift(node: SparkPlan): Iterable[SparkPlan] = node match {
+object RasOffloadFilter extends RasOffload {
+  override protected def offload(node: SparkPlan): SparkPlan = node match {
     case FilterExec(condition, child) =>
       val out = BackendsApiManager.getSparkPlanExecApiInstance
         .genFilterExecTransformer(condition, child)
-      if (!out.doValidate().isValid) {
-        List.empty
-      } else {
-        List(out)
-      }
-    case _ =>
-      List.empty
+      out
+    case other =>
+      other
   }
-  override def shape(): Shape[SparkPlan] = Shapes.fixedHeight(1)
 }
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/FilterRemoveRule.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveFilter.scala
similarity index 97%
rename from 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/FilterRemoveRule.scala
rename to 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveFilter.scala
index 52b5be981..c9f4b27bf 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/FilterRemoveRule.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveFilter.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.SparkPlan
 // because the pushed filter is not considered in the model. Removing the 
filter will make
 // optimizer choose a single scan as the winner sub-plan since a single scan's 
cost is lower than
 // filter + scan.
-object FilterRemoveRule extends RasRule[SparkPlan] {
+object RemoveFilter extends RasRule[SparkPlan] {
   override def shift(node: SparkPlan): Iterable[SparkPlan] = {
     val filter = node.asInstanceOf[FilterExecTransformerBase]
     if (filter.isNoop()) {
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala
index d33cda2e6..0e905ced1 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala
@@ -20,6 +20,7 @@ import org.apache.gluten.GlutenConfig
 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.extension.columnar._
 import 
org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow,
 RemoveTopmostColumnarToRow, TransformPostOverrides, TransformPreOverrides}
+import 
org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager
 import org.apache.gluten.extension.columnar.util.AdaptiveContext
 import org.apache.gluten.metrics.GlutenTimeMetric
 import org.apache.gluten.utils.{LogLevelUtil, PhysicalPlanSelector}
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/PullOutPostProject.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPostProject.scala
similarity index 92%
rename from 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/PullOutPostProject.scala
rename to 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPostProject.scala
index dc2e6423c..1b5467144 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/PullOutPostProject.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPostProject.scala
@@ -14,13 +14,12 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.gluten.extension.columnar
+package org.apache.gluten.extension.columnar.rewrite
 
 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.utils.PullOutProjectHelper
 
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, 
NamedExpression, WindowExpression}
-import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{GenerateExec, ProjectExec, SparkPlan}
 import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
 import org.apache.spark.sql.execution.window.WindowExec
@@ -33,17 +32,17 @@ import scala.collection.mutable.ArrayBuffer
  * the output of Spark, ensuring that the output data of the native plan can 
match the Spark plan
  * when a fallback occurs.
  */
-object PullOutPostProject extends Rule[SparkPlan] with PullOutProjectHelper {
+object PullOutPostProject extends RewriteSingleNode with PullOutProjectHelper {
 
   private def needsPostProjection(plan: SparkPlan): Boolean = {
     plan match {
       case agg: BaseAggregateExec =>
         val pullOutHelper =
           
BackendsApiManager.getSparkPlanExecApiInstance.genHashAggregateExecPullOutHelper(
-            agg.groupingExpressions,
             agg.aggregateExpressions,
             agg.aggregateAttributes)
-        val allAggregateResultAttributes = 
pullOutHelper.allAggregateResultAttributes
+        val allAggregateResultAttributes =
+          pullOutHelper.allAggregateResultAttributes(agg.groupingExpressions)
         // If the result expressions has different size with output attribute,
         // post-projection is needed.
         agg.resultExpressions.size != allAggregateResultAttributes.size ||
@@ -72,14 +71,13 @@ object PullOutPostProject extends Rule[SparkPlan] with 
PullOutProjectHelper {
     }
   }
 
-  override def apply(plan: SparkPlan): SparkPlan = plan match {
+  override def rewrite(plan: SparkPlan): SparkPlan = plan match {
     case agg: BaseAggregateExec if supportedAggregate(agg) && 
needsPostProjection(agg) =>
       val pullOutHelper =
         
BackendsApiManager.getSparkPlanExecApiInstance.genHashAggregateExecPullOutHelper(
-          agg.groupingExpressions,
           agg.aggregateExpressions,
           agg.aggregateAttributes)
-      val newResultExpressions = pullOutHelper.allAggregateResultAttributes
+      val newResultExpressions = 
pullOutHelper.allAggregateResultAttributes(agg.groupingExpressions)
       val newAgg = copyBaseAggregateExec(agg)(newResultExpressions = 
newResultExpressions)
       ProjectExec(agg.resultExpressions, newAgg)
 
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/PullOutPreProject.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala
similarity index 96%
rename from 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/PullOutPreProject.scala
rename to 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala
index 48a9a7687..64d4f2736 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/PullOutPreProject.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/PullOutPreProject.scala
@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.gluten.extension.columnar
+package org.apache.gluten.extension.columnar.rewrite
 
 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.sql.shims.SparkShimLoader
@@ -22,8 +22,7 @@ import org.apache.gluten.utils.PullOutProjectHelper
 
 import org.apache.spark.sql.catalyst.expressions._
 import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Complete, Partial}
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{ExpandExec, GenerateExec, ProjectExec, 
SortExec, SparkPlan, TakeOrderedAndProjectExec}
+import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, 
TypedAggregateExpression}
 import org.apache.spark.sql.execution.window.{WindowExec, 
WindowGroupLimitExecShim}
 
@@ -36,7 +35,7 @@ import scala.collection.mutable
  * to transform the SparkPlan at the physical plan level, constructing a 
SparkPlan that supports
  * execution by the native engine.
  */
-object PullOutPreProject extends Rule[SparkPlan] with PullOutProjectHelper {
+object PullOutPreProject extends RewriteSingleNode with PullOutProjectHelper {
 
   private def needsPreProject(plan: SparkPlan): Boolean = {
     plan match {
@@ -118,7 +117,7 @@ object PullOutPreProject extends Rule[SparkPlan] with 
PullOutProjectHelper {
     }
   }
 
-  override def apply(plan: SparkPlan): SparkPlan = plan match {
+  override def rewrite(plan: SparkPlan): SparkPlan = plan match {
     case sort: SortExec if needsPreProject(sort) =>
       val expressionMap = new mutable.HashMap[Expression, NamedExpression]()
       val newSortOrder = getNewSortOrder(sort.sortOrder, expressionMap)
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/RewriteCollect.scala 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteCollect.scala
similarity index 93%
rename from 
gluten-core/src/main/scala/org/apache/gluten/extension/RewriteCollect.scala
rename to 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteCollect.scala
index 3b6710857..74d493de5 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/RewriteCollect.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteCollect.scala
@@ -14,14 +14,13 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.gluten.extension
+package org.apache.gluten.extension.columnar.rewrite
 
 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.utils.PullOutProjectHelper
 
 import org.apache.spark.sql.catalyst.expressions.{And, Attribute, 
AttributeSet, If, IsNotNull, IsNull, Literal, NamedExpression}
-import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
CollectSet, Complete, Final, Partial}
-import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
 import org.apache.spark.sql.types.ArrayType
@@ -36,7 +35,7 @@ import scala.collection.mutable.ArrayBuffer
  *
  * TODO: remove this rule once Velox compatible with vanilla Spark.
  */
-object RewriteCollect extends Rule[SparkPlan] with PullOutProjectHelper {
+object RewriteCollect extends RewriteSingleNode with PullOutProjectHelper {
   private lazy val shouldRewriteCollect =
     BackendsApiManager.getSettings.shouldRewriteCollect()
 
@@ -121,7 +120,7 @@ object RewriteCollect extends Rule[SparkPlan] with 
PullOutProjectHelper {
     (newAggregateAttributes, newResultExpressions)
   }
 
-  override def apply(plan: SparkPlan): SparkPlan = {
+  override def rewrite(plan: SparkPlan): SparkPlan = {
     if (!shouldRewriteCollect) {
       return plan
     }
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RewriteMultiChildrenCount.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteMultiChildrenCount.scala
similarity index 93%
rename from 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RewriteMultiChildrenCount.scala
rename to 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteMultiChildrenCount.scala
index 9657c127d..b395d961a 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RewriteMultiChildrenCount.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteMultiChildrenCount.scala
@@ -14,14 +14,13 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.gluten.extension.columnar
+package org.apache.gluten.extension.columnar.rewrite
 
 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.utils.PullOutProjectHelper
 
 import org.apache.spark.sql.catalyst.expressions.{If, IsNull, Literal, Or}
 import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Count, Partial}
-import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
 import org.apache.spark.sql.types.IntegerType
@@ -46,7 +45,7 @@ import org.apache.spark.sql.types.IntegerType
  *
  * TODO: Remove this rule when Velox support multi-children Count
  */
-object RewriteMultiChildrenCount extends Rule[SparkPlan] with 
PullOutProjectHelper {
+object RewriteMultiChildrenCount extends RewriteSingleNode with 
PullOutProjectHelper {
   private lazy val shouldRewriteCount = 
BackendsApiManager.getSettings.shouldRewriteCount()
 
   private def extractCountForRewrite(aggExpr: AggregateExpression): 
Option[Count] = {
@@ -92,7 +91,7 @@ object RewriteMultiChildrenCount extends Rule[SparkPlan] with 
PullOutProjectHelp
     }
   }
 
-  override def apply(plan: SparkPlan): SparkPlan = {
+  override def rewrite(plan: SparkPlan): SparkPlan = {
     if (!shouldRewriteCount) {
       return plan
     }
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSingleNode.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSingleNode.scala
new file mode 100644
index 000000000..73bc8b967
--- /dev/null
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSingleNode.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.columnar.rewrite
+
+import org.apache.gluten.extension.RewriteIn
+
+import org.apache.spark.sql.execution.SparkPlan
+
+/**
+ * Rewrites a plan node from vanilla Spark into its alternative representation.
+ *
+ * Gluten's planner will pick one that is considered the best executable plan 
between input plan and
+ * the output plan.
+ *
+ * Note: Only the current plan node is supposed to be open to modification. Do 
not access or modify
+ * the child nodes. Tree-walking is done by the caller of this trait.
+ *
+ * TODO: Ideally, such an API should allow multiple alternative 
outputs.
+ */
+trait RewriteSingleNode {
+  def rewrite(plan: SparkPlan): SparkPlan
+}
+
+object RewriteSingleNode {
+  def allRules(): Seq[RewriteSingleNode] = {
+    Seq(
+      RewriteIn,
+      RewriteMultiChildrenCount,
+      RewriteCollect,
+      RewriteTypedImperativeAggregate,
+      PullOutPreProject,
+      PullOutPostProject)
+  }
+}
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RewriteSparkPlanRulesManager.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala
similarity index 91%
rename from 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RewriteSparkPlanRulesManager.scala
rename to 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala
index 6070613c1..5fd728eca 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RewriteSparkPlanRulesManager.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSparkPlanRulesManager.scala
@@ -14,9 +14,9 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.gluten.extension.columnar
+package org.apache.gluten.extension.columnar.rewrite
 
-import org.apache.gluten.extension.{RewriteCollect, RewriteIn}
+import org.apache.gluten.extension.columnar.{AddTransformHintRule, 
TransformHint, TransformHints}
 import org.apache.gluten.sql.shims.SparkShimLoader
 
 import org.apache.spark.rdd.RDD
@@ -44,7 +44,7 @@ case class RewrittenNodeWall(originalChild: SparkPlan) 
extends LeafExecNode {
  *
  * Note that, this rule does not touch and tag these operators who does not 
need to rewrite.
  */
-class RewriteSparkPlanRulesManager private (rewriteRules: Seq[Rule[SparkPlan]])
+class RewriteSparkPlanRulesManager private (rewriteRules: 
Seq[RewriteSingleNode])
   extends Rule[SparkPlan] {
 
   private def mayNeedRewrite(plan: SparkPlan): Boolean = {
@@ -83,7 +83,7 @@ class RewriteSparkPlanRulesManager private (rewriteRules: 
Seq[Rule[SparkPlan]])
           // Some rewrite rules may generate new parent plan node, we should 
use transform to
           // rewrite the original plan. For example, PullOutPreProject and 
PullOutPostProject
           // will generate post-project plan node.
-          plan.transformUp { case p => rule.apply(p) }
+          plan.transformUp { case p => rule.rewrite(p) }
       }
       (rewrittenPlan, None)
     } catch {
@@ -133,13 +133,6 @@ class RewriteSparkPlanRulesManager private (rewriteRules: 
Seq[Rule[SparkPlan]])
 
 object RewriteSparkPlanRulesManager {
   def apply(): Rule[SparkPlan] = {
-    val rewriteRules = Seq(
-      RewriteIn,
-      RewriteMultiChildrenCount,
-      RewriteCollect,
-      RewriteTypedImperativeAggregate,
-      PullOutPreProject,
-      PullOutPostProject)
-    new RewriteSparkPlanRulesManager(rewriteRules)
+    new RewriteSparkPlanRulesManager(RewriteSingleNode.allRules())
   }
 }
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RewriteTypedImperativeAggregate.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteTypedImperativeAggregate.scala
similarity index 91%
rename from 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RewriteTypedImperativeAggregate.scala
rename to 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteTypedImperativeAggregate.scala
index df5341373..971a87923 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/RewriteTypedImperativeAggregate.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteTypedImperativeAggregate.scala
@@ -14,18 +14,17 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.gluten.extension.columnar
+package org.apache.gluten.extension.columnar.rewrite
 
 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.utils.PullOutProjectHelper
 
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
 
-object RewriteTypedImperativeAggregate extends Rule[SparkPlan] with 
PullOutProjectHelper {
+object RewriteTypedImperativeAggregate extends RewriteSingleNode with 
PullOutProjectHelper {
   private lazy val shouldRewriteTypedImperativeAggregate =
     BackendsApiManager.getSettings.shouldRewriteTypedImperativeAggregate()
 
@@ -40,7 +39,7 @@ object RewriteTypedImperativeAggregate extends 
Rule[SparkPlan] with PullOutProje
     }
   }
 
-  override def apply(plan: SparkPlan): SparkPlan = {
+  override def rewrite(plan: SparkPlan): SparkPlan = {
     if (!shouldRewriteTypedImperativeAggregate) {
       return plan
     }
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/planner/cost/GlutenCostModel.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/planner/cost/GlutenCostModel.scala
index a5b66df46..2920c0a39 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/planner/cost/GlutenCostModel.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/planner/cost/GlutenCostModel.scala
@@ -16,7 +16,7 @@
  */
 package org.apache.gluten.planner.cost
 
-import org.apache.gluten.extension.columnar.{ColumnarTransitions, 
TransformJoin}
+import org.apache.gluten.extension.columnar.{ColumnarTransitions, OffloadJoin}
 import org.apache.gluten.planner.plan.GlutenPlanModel.GroupLeafExec
 import org.apache.gluten.ras.{Cost, CostModel}
 import org.apache.gluten.utils.PlanUtil
@@ -57,7 +57,9 @@ object GlutenCostModel {
     // A very rough estimation as of now.
     private def selfLongCostOf(node: SparkPlan): Long = {
       node match {
-        case p: ShuffledHashJoinExec if !TransformJoin.isLegal(p) =>
+        case p: ShuffledHashJoinExec if !OffloadJoin.isLegal(p) =>
+          // To exclude the rewritten intermediate plan that is not executable
+          // by vanilla Spark and was generated by strategy 
"JoinSelectionOverrides"
           infLongCost
         case ColumnarToRowExec(child) => 3L
         case RowToColumnarExec(child) => 3L
diff --git 
a/gluten-core/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenFormatWriterInjectsBase.scala
 
b/gluten-core/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenFormatWriterInjectsBase.scala
index 7308703e7..fbdbeadba 100644
--- 
a/gluten-core/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenFormatWriterInjectsBase.scala
+++ 
b/gluten-core/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenFormatWriterInjectsBase.scala
@@ -18,8 +18,9 @@ package org.apache.spark.sql.execution.datasources
 
 import org.apache.gluten.execution.{ProjectExecTransformer, 
SortExecTransformer, TransformSupport, WholeStageTransformer}
 import org.apache.gluten.execution.datasource.GlutenFormatWriterInjects
-import org.apache.gluten.extension.columnar.{AddTransformHintRule, 
RewriteSparkPlanRulesManager}
+import org.apache.gluten.extension.columnar.AddTransformHintRule
 import 
org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverrides
+import 
org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SparkSession


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to