This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new f9779db1e [VL] RAS: Remove AddTransformHintRule route from EnumeratedApplier (#5552)
f9779db1e is described below

commit f9779db1e97873b5befe6a42c6237aa0e352731f
Author: Hongze Zhang <[email protected]>
AuthorDate: Sun Apr 28 08:07:35 2024 +0800

    [VL] RAS: Remove AddTransformHintRule route from EnumeratedApplier (#5552)
---
 .../extension/columnar/TransformSingleNode.scala   | 33 ++++++++---------
 .../columnar/enumerated/ConditionedRule.scala      | 23 ++----------
 .../columnar/enumerated/EnumeratedApplier.scala    |  3 +-
 .../columnar/enumerated/EnumeratedTransform.scala  | 42 ++++++++++++++++------
 .../columnar/enumerated/ImplementAggregate.scala   | 11 +++---
 .../columnar/enumerated/ImplementFilter.scala      | 12 ++++---
 .../gluten/planner/cost/GlutenCostModel.scala      | 29 +++++++++------
 7 files changed, 83 insertions(+), 70 deletions(-)

diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformSingleNode.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformSingleNode.scala
index b8f99330e..760929bbd 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformSingleNode.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformSingleNode.scala
@@ -132,22 +132,7 @@ case class TransformExchange() extends TransformSingleNode 
with LogLevelUtil {
 
 // Join transformation.
 case class TransformJoin() extends TransformSingleNode with LogLevelUtil {
-
-  /**
-   * Get the build side supported by the execution of vanilla Spark.
-   *
-   * @param plan
-   *   : shuffled hash join plan
-   * @return
-   *   the supported build side
-   */
-  private def getSparkSupportedBuildSide(plan: ShuffledHashJoinExec): 
BuildSide = {
-    plan.joinType match {
-      case LeftOuter | LeftSemi => BuildRight
-      case RightOuter => BuildLeft
-      case _ => plan.buildSide
-    }
-  }
+  import TransformJoin._
 
   override def impl(plan: SparkPlan): SparkPlan = {
     if (TransformHints.isNotTransformable(plan)) {
@@ -155,6 +140,7 @@ case class TransformJoin() extends TransformSingleNode with 
LogLevelUtil {
       plan match {
         case shj: ShuffledHashJoinExec =>
           if (BackendsApiManager.getSettings.recreateJoinExecOnFallback()) {
+            // Since https://github.com/apache/incubator-gluten/pull/408
             // Because we manually removed the build side limitation for 
LeftOuter, LeftSemi and
             // RightOuter, need to change the build side back if this join 
fallback into vanilla
             // Spark for execution.
@@ -237,6 +223,20 @@ case class TransformJoin() extends TransformSingleNode 
with LogLevelUtil {
 
 }
 
+object TransformJoin {
+  private def getSparkSupportedBuildSide(plan: ShuffledHashJoinExec): 
BuildSide = {
+    plan.joinType match {
+      case LeftOuter | LeftSemi => BuildRight
+      case RightOuter => BuildLeft
+      case _ => plan.buildSide
+    }
+  }
+
+  def isLegal(plan: ShuffledHashJoinExec): Boolean = {
+    plan.buildSide == getSparkSupportedBuildSide(plan)
+  }
+}
+
 // Filter transformation.
 case class TransformFilter() extends TransformSingleNode with LogLevelUtil {
   import TransformOthers._
@@ -465,6 +465,7 @@ object TransformOthers {
       }
     }
 
+    // Since https://github.com/apache/incubator-gluten/pull/2701
     private def applyScanNotTransformable(plan: SparkPlan): SparkPlan = plan 
match {
       case plan: FileSourceScanExec =>
         val newPartitionFilters =
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ConditionedRule.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ConditionedRule.scala
index 092d67efc..33d99f5f7 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ConditionedRule.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ConditionedRule.scala
@@ -37,31 +37,12 @@ object ConditionedRule {
     }
   }
 
-  trait PostCondition {
-    def apply(node: SparkPlan): Boolean
-  }
-
-  object PostCondition {
-    implicit class FromValidator(validator: Validator) extends PostCondition {
-      override def apply(node: SparkPlan): Boolean = {
-        validator.validate(node) match {
-          case Validator.Passed => true
-          case Validator.Failed(reason) => false
-        }
-      }
-    }
-  }
-
-  def wrap(
-      rule: RasRule[SparkPlan],
-      pre: ConditionedRule.PreCondition,
-      post: ConditionedRule.PostCondition): RasRule[SparkPlan] = {
+  def wrap(rule: RasRule[SparkPlan], cond: ConditionedRule.PreCondition): 
RasRule[SparkPlan] = {
     new RasRule[SparkPlan] {
       override def shift(node: SparkPlan): Iterable[SparkPlan] = {
         val out = List(node)
-          .filter(pre.apply)
+          .filter(cond.apply)
           .flatMap(rule.shift)
-          .filter(post.apply)
         out
       }
       override def shape(): Shape[SparkPlan] = rule.shape()
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala
index 091761e6e..dfc2d474f 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala
@@ -120,8 +120,7 @@ class EnumeratedApplier(session: SparkSession)
       
BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarValidationRules()
 :::
       List(
         (spark: SparkSession) => MergeTwoPhasesHashBaseAggregate(spark),
-        (_: SparkSession) => RewriteSparkPlanRulesManager(),
-        (_: SparkSession) => AddTransformHintRule()
+        (_: SparkSession) => RewriteSparkPlanRulesManager()
       ) :::
       List(
         (session: SparkSession) => EnumeratedTransform(session, 
outputsColumnar),
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
index 27dc1be3d..973020438 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
@@ -16,8 +16,9 @@
  */
 package org.apache.gluten.extension.columnar.enumerated
 
+import org.apache.gluten.extension.GlutenPlan
 import org.apache.gluten.extension.columnar.{TransformExchange, TransformJoin, 
TransformOthers, TransformSingleNode}
-import org.apache.gluten.extension.columnar.validator.Validator
+import org.apache.gluten.extension.columnar.validator.{Validator, Validators}
 import org.apache.gluten.planner.GlutenOptimization
 import org.apache.gluten.planner.property.Conventions
 import org.apache.gluten.ras.property.PropertySet
@@ -33,17 +34,31 @@ case class EnumeratedTransform(session: SparkSession, 
outputsColumnar: Boolean)
   with LogLevelUtil {
   import EnumeratedTransform._
 
-  private val rasRules = List(
+  private val validator = Validators
+    .builder()
+    .fallbackByHint()
+    .fallbackIfScanOnly()
+    .fallbackComplexExpressions()
+    .fallbackByBackendSettings()
+    .fallbackByUserOptions()
+    .build()
+
+  private val rules = List(
+    PushFilterToScan,
+    FilterRemoveRule
+  )
+
+  // TODO: Should obey ReplaceSingleNode#applyScanNotTransformable to select
+  //  (vanilla) scan with cheaper sub-query plan through cost model.
+  private val implRules = List(
     AsRasImplement(TransformOthers()),
     AsRasImplement(TransformExchange()),
     AsRasImplement(TransformJoin()),
     ImplementAggregate,
-    ImplementFilter,
-    PushFilterToScan,
-    FilterRemoveRule
-  )
+    ImplementFilter
+  ).map(_.withValidator(validator))
 
-  private val optimization = GlutenOptimization(rasRules)
+  private val optimization = GlutenOptimization(rules ++ implRules)
 
   private val reqConvention = Conventions.ANY
   private val altConventions =
@@ -62,8 +77,13 @@ case class EnumeratedTransform(session: SparkSession, 
outputsColumnar: Boolean)
 object EnumeratedTransform {
   private case class AsRasImplement(delegate: TransformSingleNode) extends 
RasRule[SparkPlan] {
     override def shift(node: SparkPlan): Iterable[SparkPlan] = {
-      val out = List(delegate.impl(node))
-      out
+      val out = delegate.impl(node)
+      out match {
+        case t: GlutenPlan if !t.doValidate().isValid =>
+          List.empty
+        case other =>
+          List(other)
+      }
     }
 
     override def shape(): Shape[SparkPlan] = Shapes.fixedHeight(1)
@@ -71,8 +91,8 @@ object EnumeratedTransform {
 
   // TODO: Currently not in use. Prepared for future development.
   implicit private class RasRuleImplicits(rasRule: RasRule[SparkPlan]) {
-    def withValidator(pre: Validator, post: Validator): RasRule[SparkPlan] = {
-      ConditionedRule.wrap(rasRule, pre, post)
+    def withValidator(v: Validator): RasRule[SparkPlan] = {
+      ConditionedRule.wrap(rasRule, v)
     }
   }
 }
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementAggregate.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementAggregate.scala
index 818d22568..8c51ca4fd 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementAggregate.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementAggregate.scala
@@ -17,7 +17,7 @@
 package org.apache.gluten.extension.columnar.enumerated
 
 import org.apache.gluten.backendsapi.BackendsApiManager
-import org.apache.gluten.extension.columnar.TransformHints
+import org.apache.gluten.execution.HashAggregateExecBaseTransformer
 import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
 
 import org.apache.spark.sql.execution.SparkPlan
@@ -25,16 +25,19 @@ import 
org.apache.spark.sql.execution.aggregate.HashAggregateExec
 
 object ImplementAggregate extends RasRule[SparkPlan] {
   override def shift(node: SparkPlan): Iterable[SparkPlan] = node match {
-    case plan if TransformHints.isNotTransformable(plan) => List.empty
     case agg: HashAggregateExec => shiftAgg(agg)
     case _ => List.empty
   }
 
   private def shiftAgg(agg: HashAggregateExec): Iterable[SparkPlan] = {
-    List(implement(agg))
+    val transformer = implement(agg)
+    if (!transformer.doValidate().isValid) {
+      return List.empty
+    }
+    List(transformer)
   }
 
-  private def implement(agg: HashAggregateExec): SparkPlan = {
+  private def implement(agg: HashAggregateExec): 
HashAggregateExecBaseTransformer = {
     BackendsApiManager.getSparkPlanExecApiInstance
       .genHashAggregateExecTransformer(
         agg.requiredChildDistributionExpressions,
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementFilter.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementFilter.scala
index 6ec384bd3..33121e7f1 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementFilter.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementFilter.scala
@@ -17,18 +17,20 @@
 package org.apache.gluten.extension.columnar.enumerated
 
 import org.apache.gluten.backendsapi.BackendsApiManager
-import org.apache.gluten.extension.columnar.TransformHints
 import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
 
 import org.apache.spark.sql.execution.{FilterExec, SparkPlan}
 
 object ImplementFilter extends RasRule[SparkPlan] {
   override def shift(node: SparkPlan): Iterable[SparkPlan] = node match {
-    case plan if TransformHints.isNotTransformable(plan) => List.empty
     case FilterExec(condition, child) =>
-      List(
-        BackendsApiManager.getSparkPlanExecApiInstance
-          .genFilterExecTransformer(condition, child))
+      val out = BackendsApiManager.getSparkPlanExecApiInstance
+        .genFilterExecTransformer(condition, child)
+      if (!out.doValidate().isValid) {
+        List.empty
+      } else {
+        List(out)
+      }
     case _ =>
       List.empty
   }
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/planner/cost/GlutenCostModel.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/planner/cost/GlutenCostModel.scala
index e1295480c..a5b66df46 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/planner/cost/GlutenCostModel.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/planner/cost/GlutenCostModel.scala
@@ -16,12 +16,13 @@
  */
 package org.apache.gluten.planner.cost
 
-import org.apache.gluten.extension.columnar.ColumnarTransitions
+import org.apache.gluten.extension.columnar.{ColumnarTransitions, 
TransformJoin}
 import org.apache.gluten.planner.plan.GlutenPlanModel.GroupLeafExec
 import org.apache.gluten.ras.{Cost, CostModel}
 import org.apache.gluten.utils.PlanUtil
 
 import org.apache.spark.sql.execution.{ColumnarToRowExec, RowToColumnarExec, 
SparkPlan}
+import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
 
 class GlutenCostModel {}
 
@@ -31,6 +32,8 @@ object GlutenCostModel {
   }
 
   private object RoughCostModel extends CostModel[SparkPlan] {
+    private val infLongCost = Long.MaxValue
+
     override def costOf(node: SparkPlan): GlutenCost = node match {
       case _: GroupLeafExec => throw new IllegalStateException()
       case _ => GlutenCost(longCostOf(node))
@@ -52,15 +55,19 @@ object GlutenCostModel {
     }
 
     // A very rough estimation as of now.
-    private def selfLongCostOf(node: SparkPlan): Long = node match {
-      case ColumnarToRowExec(child) => 3L
-      case RowToColumnarExec(child) => 3L
-      case ColumnarTransitions.ColumnarToRowLike(child) => 3L
-      case ColumnarTransitions.RowToColumnarLike(child) => 3L
-      case p if PlanUtil.isGlutenColumnarOp(p) => 2L
-      case p if PlanUtil.isVanillaColumnarOp(p) => 3L
-      // Other row ops. Usually a vanilla row op.
-      case _ => 5L
+    private def selfLongCostOf(node: SparkPlan): Long = {
+      node match {
+        case p: ShuffledHashJoinExec if !TransformJoin.isLegal(p) =>
+          infLongCost
+        case ColumnarToRowExec(child) => 3L
+        case RowToColumnarExec(child) => 3L
+        case ColumnarTransitions.ColumnarToRowLike(child) => 3L
+        case ColumnarTransitions.RowToColumnarLike(child) => 3L
+        case p if PlanUtil.isGlutenColumnarOp(p) => 2L
+        case p if PlanUtil.isVanillaColumnarOp(p) => 3L
+        // Other row ops. Usually a vanilla row op.
+        case _ => 5L
+      }
     }
 
     override def costComparator(): Ordering[Cost] = Ordering.Long.on {
@@ -68,6 +75,6 @@ object GlutenCostModel {
       case _ => throw new IllegalStateException("Unexpected cost type")
     }
 
-    override def makeInfCost(): Cost = GlutenCost(Long.MaxValue)
+    override def makeInfCost(): Cost = GlutenCost(infLongCost)
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to