This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new f9779db1e [VL] RAS: Remove AddTransformHintRule route from EnumeratedApplier (#5552)
f9779db1e is described below
commit f9779db1e97873b5befe6a42c6237aa0e352731f
Author: Hongze Zhang <[email protected]>
AuthorDate: Sun Apr 28 08:07:35 2024 +0800
[VL] RAS: Remove AddTransformHintRule route from EnumeratedApplier (#5552)
---
.../extension/columnar/TransformSingleNode.scala | 33 ++++++++---------
.../columnar/enumerated/ConditionedRule.scala | 23 ++----------
.../columnar/enumerated/EnumeratedApplier.scala | 3 +-
.../columnar/enumerated/EnumeratedTransform.scala | 42 ++++++++++++++++------
.../columnar/enumerated/ImplementAggregate.scala | 11 +++---
.../columnar/enumerated/ImplementFilter.scala | 12 ++++---
.../gluten/planner/cost/GlutenCostModel.scala | 29 +++++++++------
7 files changed, 83 insertions(+), 70 deletions(-)
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformSingleNode.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformSingleNode.scala
index b8f99330e..760929bbd 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformSingleNode.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformSingleNode.scala
@@ -132,22 +132,7 @@ case class TransformExchange() extends TransformSingleNode with LogLevelUtil {
// Join transformation.
case class TransformJoin() extends TransformSingleNode with LogLevelUtil {
-
- /**
- * Get the build side supported by the execution of vanilla Spark.
- *
- * @param plan
- * : shuffled hash join plan
- * @return
- * the supported build side
- */
- private def getSparkSupportedBuildSide(plan: ShuffledHashJoinExec): BuildSide = {
- plan.joinType match {
- case LeftOuter | LeftSemi => BuildRight
- case RightOuter => BuildLeft
- case _ => plan.buildSide
- }
- }
+ import TransformJoin._
override def impl(plan: SparkPlan): SparkPlan = {
if (TransformHints.isNotTransformable(plan)) {
@@ -155,6 +140,7 @@ case class TransformJoin() extends TransformSingleNode with LogLevelUtil {
plan match {
case shj: ShuffledHashJoinExec =>
if (BackendsApiManager.getSettings.recreateJoinExecOnFallback()) {
+ // Since https://github.com/apache/incubator-gluten/pull/408
// Because we manually removed the build side limitation for LeftOuter, LeftSemi and
// RightOuter, need to change the build side back if this join fallback into vanilla
// Spark for execution.
@@ -237,6 +223,20 @@ case class TransformJoin() extends TransformSingleNode with LogLevelUtil {
}
+object TransformJoin {
+ private def getSparkSupportedBuildSide(plan: ShuffledHashJoinExec): BuildSide = {
+ plan.joinType match {
+ case LeftOuter | LeftSemi => BuildRight
+ case RightOuter => BuildLeft
+ case _ => plan.buildSide
+ }
+ }
+
+ def isLegal(plan: ShuffledHashJoinExec): Boolean = {
+ plan.buildSide == getSparkSupportedBuildSide(plan)
+ }
+}
+
// Filter transformation.
case class TransformFilter() extends TransformSingleNode with LogLevelUtil {
import TransformOthers._
@@ -465,6 +465,7 @@ object TransformOthers {
}
}
+ // Since https://github.com/apache/incubator-gluten/pull/2701
private def applyScanNotTransformable(plan: SparkPlan): SparkPlan = plan match {
case plan: FileSourceScanExec =>
val newPartitionFilters =
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ConditionedRule.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ConditionedRule.scala
index 092d67efc..33d99f5f7 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ConditionedRule.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ConditionedRule.scala
@@ -37,31 +37,12 @@ object ConditionedRule {
}
}
- trait PostCondition {
- def apply(node: SparkPlan): Boolean
- }
-
- object PostCondition {
- implicit class FromValidator(validator: Validator) extends PostCondition {
- override def apply(node: SparkPlan): Boolean = {
- validator.validate(node) match {
- case Validator.Passed => true
- case Validator.Failed(reason) => false
- }
- }
- }
- }
-
- def wrap(
- rule: RasRule[SparkPlan],
- pre: ConditionedRule.PreCondition,
- post: ConditionedRule.PostCondition): RasRule[SparkPlan] = {
+ def wrap(rule: RasRule[SparkPlan], cond: ConditionedRule.PreCondition): RasRule[SparkPlan] = {
new RasRule[SparkPlan] {
override def shift(node: SparkPlan): Iterable[SparkPlan] = {
val out = List(node)
- .filter(pre.apply)
+ .filter(cond.apply)
.flatMap(rule.shift)
- .filter(post.apply)
out
}
override def shape(): Shape[SparkPlan] = rule.shape()
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala
index 091761e6e..dfc2d474f 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala
@@ -120,8 +120,7 @@ class EnumeratedApplier(session: SparkSession)
BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarValidationRules()
:::
List(
(spark: SparkSession) => MergeTwoPhasesHashBaseAggregate(spark),
- (_: SparkSession) => RewriteSparkPlanRulesManager(),
- (_: SparkSession) => AddTransformHintRule()
+ (_: SparkSession) => RewriteSparkPlanRulesManager()
) :::
List(
(session: SparkSession) => EnumeratedTransform(session, outputsColumnar),
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
index 27dc1be3d..973020438 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
@@ -16,8 +16,9 @@
*/
package org.apache.gluten.extension.columnar.enumerated
+import org.apache.gluten.extension.GlutenPlan
import org.apache.gluten.extension.columnar.{TransformExchange, TransformJoin, TransformOthers, TransformSingleNode}
-import org.apache.gluten.extension.columnar.validator.Validator
+import org.apache.gluten.extension.columnar.validator.{Validator, Validators}
import org.apache.gluten.planner.GlutenOptimization
import org.apache.gluten.planner.property.Conventions
import org.apache.gluten.ras.property.PropertySet
@@ -33,17 +34,31 @@ case class EnumeratedTransform(session: SparkSession, outputsColumnar: Boolean)
with LogLevelUtil {
import EnumeratedTransform._
- private val rasRules = List(
+ private val validator = Validators
+ .builder()
+ .fallbackByHint()
+ .fallbackIfScanOnly()
+ .fallbackComplexExpressions()
+ .fallbackByBackendSettings()
+ .fallbackByUserOptions()
+ .build()
+
+ private val rules = List(
+ PushFilterToScan,
+ FilterRemoveRule
+ )
+
+ // TODO: Should obey ReplaceSingleNode#applyScanNotTransformable to select
+ // (vanilla) scan with cheaper sub-query plan through cost model.
+ private val implRules = List(
AsRasImplement(TransformOthers()),
AsRasImplement(TransformExchange()),
AsRasImplement(TransformJoin()),
ImplementAggregate,
- ImplementFilter,
- PushFilterToScan,
- FilterRemoveRule
- )
+ ImplementFilter
+ ).map(_.withValidator(validator))
- private val optimization = GlutenOptimization(rasRules)
+ private val optimization = GlutenOptimization(rules ++ implRules)
private val reqConvention = Conventions.ANY
private val altConventions =
@@ -62,8 +77,13 @@ case class EnumeratedTransform(session: SparkSession, outputsColumnar: Boolean)
object EnumeratedTransform {
private case class AsRasImplement(delegate: TransformSingleNode) extends RasRule[SparkPlan] {
override def shift(node: SparkPlan): Iterable[SparkPlan] = {
- val out = List(delegate.impl(node))
- out
+ val out = delegate.impl(node)
+ out match {
+ case t: GlutenPlan if !t.doValidate().isValid =>
+ List.empty
+ case other =>
+ List(other)
+ }
}
override def shape(): Shape[SparkPlan] = Shapes.fixedHeight(1)
@@ -71,8 +91,8 @@ object EnumeratedTransform {
// TODO: Currently not in use. Prepared for future development.
implicit private class RasRuleImplicits(rasRule: RasRule[SparkPlan]) {
- def withValidator(pre: Validator, post: Validator): RasRule[SparkPlan] = {
- ConditionedRule.wrap(rasRule, pre, post)
+ def withValidator(v: Validator): RasRule[SparkPlan] = {
+ ConditionedRule.wrap(rasRule, v)
}
}
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementAggregate.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementAggregate.scala
index 818d22568..8c51ca4fd 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementAggregate.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementAggregate.scala
@@ -17,7 +17,7 @@
package org.apache.gluten.extension.columnar.enumerated
import org.apache.gluten.backendsapi.BackendsApiManager
-import org.apache.gluten.extension.columnar.TransformHints
+import org.apache.gluten.execution.HashAggregateExecBaseTransformer
import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
import org.apache.spark.sql.execution.SparkPlan
@@ -25,16 +25,19 @@ import org.apache.spark.sql.execution.aggregate.HashAggregateExec
object ImplementAggregate extends RasRule[SparkPlan] {
override def shift(node: SparkPlan): Iterable[SparkPlan] = node match {
- case plan if TransformHints.isNotTransformable(plan) => List.empty
case agg: HashAggregateExec => shiftAgg(agg)
case _ => List.empty
}
private def shiftAgg(agg: HashAggregateExec): Iterable[SparkPlan] = {
- List(implement(agg))
+ val transformer = implement(agg)
+ if (!transformer.doValidate().isValid) {
+ return List.empty
+ }
+ List(transformer)
}
- private def implement(agg: HashAggregateExec): SparkPlan = {
+ private def implement(agg: HashAggregateExec): HashAggregateExecBaseTransformer = {
BackendsApiManager.getSparkPlanExecApiInstance
.genHashAggregateExecTransformer(
agg.requiredChildDistributionExpressions,
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementFilter.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementFilter.scala
index 6ec384bd3..33121e7f1 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementFilter.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/ImplementFilter.scala
@@ -17,18 +17,20 @@
package org.apache.gluten.extension.columnar.enumerated
import org.apache.gluten.backendsapi.BackendsApiManager
-import org.apache.gluten.extension.columnar.TransformHints
import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
import org.apache.spark.sql.execution.{FilterExec, SparkPlan}
object ImplementFilter extends RasRule[SparkPlan] {
override def shift(node: SparkPlan): Iterable[SparkPlan] = node match {
- case plan if TransformHints.isNotTransformable(plan) => List.empty
case FilterExec(condition, child) =>
- List(
- BackendsApiManager.getSparkPlanExecApiInstance
- .genFilterExecTransformer(condition, child))
+ val out = BackendsApiManager.getSparkPlanExecApiInstance
+ .genFilterExecTransformer(condition, child)
+ if (!out.doValidate().isValid) {
+ List.empty
+ } else {
+ List(out)
+ }
case _ =>
List.empty
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/planner/cost/GlutenCostModel.scala
b/gluten-core/src/main/scala/org/apache/gluten/planner/cost/GlutenCostModel.scala
index e1295480c..a5b66df46 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/planner/cost/GlutenCostModel.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/planner/cost/GlutenCostModel.scala
@@ -16,12 +16,13 @@
*/
package org.apache.gluten.planner.cost
-import org.apache.gluten.extension.columnar.ColumnarTransitions
+import org.apache.gluten.extension.columnar.{ColumnarTransitions, TransformJoin}
import org.apache.gluten.planner.plan.GlutenPlanModel.GroupLeafExec
import org.apache.gluten.ras.{Cost, CostModel}
import org.apache.gluten.utils.PlanUtil
import org.apache.spark.sql.execution.{ColumnarToRowExec, RowToColumnarExec, SparkPlan}
+import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
class GlutenCostModel {}
@@ -31,6 +32,8 @@ object GlutenCostModel {
}
private object RoughCostModel extends CostModel[SparkPlan] {
+ private val infLongCost = Long.MaxValue
+
override def costOf(node: SparkPlan): GlutenCost = node match {
case _: GroupLeafExec => throw new IllegalStateException()
case _ => GlutenCost(longCostOf(node))
@@ -52,15 +55,19 @@ object GlutenCostModel {
}
// A very rough estimation as of now.
- private def selfLongCostOf(node: SparkPlan): Long = node match {
- case ColumnarToRowExec(child) => 3L
- case RowToColumnarExec(child) => 3L
- case ColumnarTransitions.ColumnarToRowLike(child) => 3L
- case ColumnarTransitions.RowToColumnarLike(child) => 3L
- case p if PlanUtil.isGlutenColumnarOp(p) => 2L
- case p if PlanUtil.isVanillaColumnarOp(p) => 3L
- // Other row ops. Usually a vanilla row op.
- case _ => 5L
+ private def selfLongCostOf(node: SparkPlan): Long = {
+ node match {
+ case p: ShuffledHashJoinExec if !TransformJoin.isLegal(p) => infLongCost
+ case ColumnarToRowExec(child) => 3L
+ case RowToColumnarExec(child) => 3L
+ case ColumnarTransitions.ColumnarToRowLike(child) => 3L
+ case ColumnarTransitions.RowToColumnarLike(child) => 3L
+ case p if PlanUtil.isGlutenColumnarOp(p) => 2L
+ case p if PlanUtil.isVanillaColumnarOp(p) => 3L
+ // Other row ops. Usually a vanilla row op.
+ case _ => 5L
+ }
}
override def costComparator(): Ordering[Cost] = Ordering.Long.on {
@@ -68,6 +75,6 @@ object GlutenCostModel {
case _ => throw new IllegalStateException("Unexpected cost type")
}
- override def makeInfCost(): Cost = GlutenCost(Long.MaxValue)
+ override def makeInfCost(): Cost = GlutenCost(infLongCost)
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]