This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 817767b52 [VL] RAS: Renew validator instance for each rule applier call (#6766)
817767b52 is described below

commit 817767b5292967b59294ea8f7e3e73f6b123589c
Author: Hongze Zhang <[email protected]>
AuthorDate: Fri Aug 9 14:30:03 2024 +0800

    [VL] RAS: Renew validator instance for each rule applier call (#6766)
---
 .../columnar/enumerated/EnumeratedTransform.scala  |  74 ++++++------
 .../extension/columnar/enumerated/RasOffload.scala | 124 ++++++++++-----------
 2 files changed, 99 insertions(+), 99 deletions(-)

diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
index 4b747eb70..007f18fca 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
@@ -18,6 +18,7 @@ package org.apache.gluten.extension.columnar.enumerated
 
 import org.apache.gluten.extension.columnar.{OffloadExchange, OffloadJoin, 
OffloadOthers}
 import org.apache.gluten.extension.columnar.transition.ConventionReq
+import org.apache.gluten.extension.columnar.validator.{Validator, Validators}
 import org.apache.gluten.planner.GlutenOptimization
 import org.apache.gluten.planner.cost.GlutenCostModel
 import org.apache.gluten.planner.property.Conv
@@ -41,45 +42,54 @@ case class EnumeratedTransform(session: SparkSession, outputsColumnar: Boolean)
   extends Rule[SparkPlan]
   with LogLevelUtil {
 
+  private val validator: Validator = Validators
+    .builder()
+    .fallbackByHint()
+    .fallbackIfScanOnly()
+    .fallbackComplexExpressions()
+    .fallbackByBackendSettings()
+    .fallbackByUserOptions()
+    .fallbackByTestInjects()
+    .build()
+
   private val rules = List(
-    new PushFilterToScan(RasOffload.validator),
+    new PushFilterToScan(validator),
     RemoveSort,
     RemoveFilter
   )
 
   // TODO: Should obey ReplaceSingleNode#applyScanNotTransformable to select
   //  (vanilla) scan with cheaper sub-query plan through cost model.
-  private val offloadRules = List(
-    RasOffload.from[Exchange](OffloadExchange()).toRule,
-    RasOffload.from[BaseJoinExec](OffloadJoin()).toRule,
-    RasOffloadHashAggregate.toRule,
-    RasOffloadFilter.toRule,
-    RasOffloadProject.toRule,
-    RasOffload.from[DataSourceV2ScanExecBase](OffloadOthers()).toRule,
-    RasOffload.from[DataSourceScanExec](OffloadOthers()).toRule,
-    RasOffload
-      .from(
-        (node: SparkPlan) => 
HiveTableScanExecTransformer.isHiveTableScan(node),
-        OffloadOthers())
-      .toRule,
-    RasOffload.from[CoalesceExec](OffloadOthers()).toRule,
-    RasOffload.from[SortAggregateExec](OffloadOthers()).toRule,
-    RasOffload.from[ObjectHashAggregateExec](OffloadOthers()).toRule,
-    RasOffload.from[UnionExec](OffloadOthers()).toRule,
-    RasOffload.from[ExpandExec](OffloadOthers()).toRule,
-    RasOffload.from[WriteFilesExec](OffloadOthers()).toRule,
-    RasOffload.from[SortExec](OffloadOthers()).toRule,
-    RasOffload.from[TakeOrderedAndProjectExec](OffloadOthers()).toRule,
-    RasOffload.from[WindowExec](OffloadOthers()).toRule,
-    RasOffload
-      .from(
-        (node: SparkPlan) => 
SparkShimLoader.getSparkShims.isWindowGroupLimitExec(node),
-        OffloadOthers())
-      .toRule,
-    RasOffload.from[LimitExec](OffloadOthers()).toRule,
-    RasOffload.from[GenerateExec](OffloadOthers()).toRule,
-    RasOffload.from[EvalPythonExec](OffloadOthers()).toRule
-  )
+  private val offloadRules =
+    Seq(
+      RasOffload.from[Exchange](OffloadExchange()),
+      RasOffload.from[BaseJoinExec](OffloadJoin()),
+      RasOffloadHashAggregate,
+      RasOffloadFilter,
+      RasOffloadProject,
+      RasOffload.from[DataSourceV2ScanExecBase](OffloadOthers()),
+      RasOffload.from[DataSourceScanExec](OffloadOthers()),
+      RasOffload
+        .from(
+          (node: SparkPlan) => 
HiveTableScanExecTransformer.isHiveTableScan(node),
+          OffloadOthers()),
+      RasOffload.from[CoalesceExec](OffloadOthers()),
+      RasOffload.from[SortAggregateExec](OffloadOthers()),
+      RasOffload.from[ObjectHashAggregateExec](OffloadOthers()),
+      RasOffload.from[UnionExec](OffloadOthers()),
+      RasOffload.from[ExpandExec](OffloadOthers()),
+      RasOffload.from[WriteFilesExec](OffloadOthers()),
+      RasOffload.from[SortExec](OffloadOthers()),
+      RasOffload.from[TakeOrderedAndProjectExec](OffloadOthers()),
+      RasOffload.from[WindowExec](OffloadOthers()),
+      RasOffload
+        .from(
+          (node: SparkPlan) => 
SparkShimLoader.getSparkShims.isWindowGroupLimitExec(node),
+          OffloadOthers()),
+      RasOffload.from[LimitExec](OffloadOthers()),
+      RasOffload.from[GenerateExec](OffloadOthers()),
+      RasOffload.from[EvalPythonExec](OffloadOthers())
+    ).map(RasOffload.Rule(_, validator))
 
   private val optimization = {
     GlutenOptimization
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala
index 43f52a9e4..fc29b36f0 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala
@@ -19,7 +19,7 @@ package org.apache.gluten.extension.columnar.enumerated
 import org.apache.gluten.extension.GlutenPlan
 import org.apache.gluten.extension.columnar.OffloadSingleNode
 import org.apache.gluten.extension.columnar.rewrite.RewriteSingleNode
-import org.apache.gluten.extension.columnar.validator.{Validator, Validators}
+import org.apache.gluten.extension.columnar.validator.Validator
 import org.apache.gluten.ras.path.Pattern
 import org.apache.gluten.ras.path.Pattern.node
 import org.apache.gluten.ras.rule.{RasRule, Shape}
@@ -49,16 +49,6 @@ object RasOffload {
     }
   }
 
-  val validator: Validator = Validators
-    .builder()
-    .fallbackByHint()
-    .fallbackIfScanOnly()
-    .fallbackComplexExpressions()
-    .fallbackByBackendSettings()
-    .fallbackByUserOptions()
-    .fallbackByTestInjects()
-    .build()
-
   private val rewrites = RewriteSingleNode.allRules()
 
   def from[T <: SparkPlan: ClassTag](base: OffloadSingleNode): RasOffload = {
@@ -75,70 +65,70 @@ object RasOffload {
     }
   }
 
-  implicit class RasOffloadOps(base: RasOffload) {
-    def toRule: RasRule[SparkPlan] = {
-      new RuleImpl(base)
+  object Rule {
+    def apply(base: RasOffload, validator: Validator): RasRule[SparkPlan] = {
+      new RuleImpl(base, validator)
     }
-  }
-
-  private class RuleImpl(base: RasOffload) extends RasRule[SparkPlan] {
-    private val typeIdentifier: TypeIdentifier = base.typeIdentifier()
 
-    final override def shift(node: SparkPlan): Iterable[SparkPlan] = {
-      // 0. If the node is already offloaded, fail fast.
-      assert(typeIdentifier.isInstance(node))
-
-      // 1. Rewrite the node to form that native library supports.
-      val rewritten = rewrites.foldLeft(node) {
-        case (node, rewrite) =>
-          node.transformUp {
-            case p =>
-              val out = rewrite.rewrite(p)
-              out
-          }
-      }
+    private class RuleImpl(base: RasOffload, validator: Validator) extends 
RasRule[SparkPlan] {
+      private val typeIdentifier: TypeIdentifier = base.typeIdentifier()
+
+      final override def shift(node: SparkPlan): Iterable[SparkPlan] = {
+        // 0. If the node is already offloaded, fail fast.
+        assert(typeIdentifier.isInstance(node))
+
+        // 1. Rewrite the node to form that native library supports.
+        val rewritten = rewrites.foldLeft(node) {
+          case (node, rewrite) =>
+            node.transformUp {
+              case p =>
+                val out = rewrite.rewrite(p)
+                out
+            }
+        }
 
-      // 2. Walk the rewritten tree.
-      val offloaded = rewritten.transformUp {
-        case from if typeIdentifier.isInstance(from) =>
-          // 3. Validate current node. If passed, offload it.
-          validator.validate(from) match {
-            case Validator.Passed =>
-              val offloaded = base.offload(from)
-              val offloadedNodes = offloaded.collect[GlutenPlan] { case t: 
GlutenPlan => t }
-              if (offloadedNodes.exists(!_.doValidate().ok())) {
-                // 4. If native validation fails on the offloaded node, return 
the
-                // original one.
+        // 2. Walk the rewritten tree.
+        val offloaded = rewritten.transformUp {
+          case from if typeIdentifier.isInstance(from) =>
+            // 3. Validate current node. If passed, offload it.
+            validator.validate(from) match {
+              case Validator.Passed =>
+                val offloaded = base.offload(from)
+                val offloadedNodes = offloaded.collect[GlutenPlan] { case t: 
GlutenPlan => t }
+                if (offloadedNodes.exists(!_.doValidate().ok())) {
+                  // 4. If native validation fails on the offloaded node, 
return the
+                  // original one.
+                  from
+                } else {
+                  offloaded
+                }
+              case Validator.Failed(reason) =>
                 from
-              } else {
-                offloaded
-              }
-            case Validator.Failed(reason) =>
-              from
-          }
-      }
+            }
+        }
 
-      // 5. If rewritten plan is not offload-able, discard it.
-      if (offloaded.fastEquals(rewritten)) {
-        return List.empty
-      }
+        // 5. If rewritten plan is not offload-able, discard it.
+        if (offloaded.fastEquals(rewritten)) {
+          return List.empty
+        }
 
-      // 6. Otherwise, return the final tree.
-      List(offloaded)
-    }
+        // 6. Otherwise, return the final tree.
+        List(offloaded)
+      }
 
-    override def shape(): Shape[SparkPlan] = {
-      pattern(node[SparkPlan](new Pattern.Matcher[SparkPlan] {
-        override def apply(plan: SparkPlan): Boolean = {
-          if (plan.isInstanceOf[GlutenPlan]) {
-            return false
-          }
-          if (typeIdentifier.isInstance(plan)) {
-            return true
+      override def shape(): Shape[SparkPlan] = {
+        pattern(node[SparkPlan](new Pattern.Matcher[SparkPlan] {
+          override def apply(plan: SparkPlan): Boolean = {
+            if (plan.isInstanceOf[GlutenPlan]) {
+              return false
+            }
+            if (typeIdentifier.isInstance(plan)) {
+              return true
+            }
+            false
           }
-          false
-        }
-      }).build())
+        }).build())
+      }
     }
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to