This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 7ad2086fd [GLUTEN-6882][CORE] Move Spark / columnar rule list to
backend code (#6931)
7ad2086fd is described below
commit 7ad2086fdf90d354edc9d7008fd7ea3bc78b0e06
Author: Hongze Zhang <[email protected]>
AuthorDate: Tue Aug 20 17:03:15 2024 +0800
[GLUTEN-6882][CORE] Move Spark / columnar rule list to backend code (#6931)
Closes #6882
---
.../gluten/backendsapi/clickhouse/CHBackend.scala | 1 +
.../gluten/backendsapi/clickhouse/CHRuleApi.scala | 111 +++++++++++++
.../clickhouse/CHSparkPlanExecApi.scala | 83 ----------
.../gluten/backendsapi/velox/VeloxBackend.scala | 1 +
.../gluten/backendsapi/velox/VeloxRuleApi.scala | 133 ++++++++++++++++
.../backendsapi/velox/VeloxSparkPlanExecApi.scala | 84 +---------
.../extension/FlushableHashAggregateRule.scala | 46 +++---
.../scala/org/apache/gluten/GlutenPlugin.scala | 31 +---
.../org/apache/gluten/backendsapi/Backend.scala | 2 +
.../gluten/backendsapi/BackendsApiManager.scala | 4 +
.../backendsapi/{Backend.scala => RuleApi.scala} | 27 +---
.../gluten/backendsapi/SparkPlanExecApi.scala | 73 +--------
.../gluten/extension/ColumnarOverrides.scala | 21 +--
.../gluten/extension/GlutenSessionExtensions.scala | 39 +++++
.../extension/OthersExtensionOverrides.scala | 48 ------
.../gluten/extension/QueryStagePrepOverrides.scala | 50 ------
.../extension/columnar/ColumnarRuleApplier.scala | 13 ++
.../columnar/enumerated/EnumeratedApplier.scala | 85 ++--------
.../columnar/heuristic/HeuristicApplier.scala | 87 ++++-------
.../extension/columnar/util/AdaptiveContext.scala | 1 +
.../gluten/extension/injector/GlutenInjector.scala | 94 +++++++++++
.../injector/RuleInjector.scala} | 33 ++--
.../gluten/extension/injector/SparkInjector.scala | 83 ++++++++++
.../{SparkRuleUtil.scala => SparkPlanRules.scala} | 52 ++++---
.../sql/execution/FallbackStrategiesSuite.scala | 167 +++++++++++---------
.../extension/GlutenSessionExtensionSuite.scala | 3 +-
.../sql/execution/FallbackStrategiesSuite.scala | 171 +++++++++++---------
.../extension/GlutenSessionExtensionSuite.scala | 3 +-
.../sql/execution/FallbackStrategiesSuite.scala | 173 ++++++++++++---------
.../extension/GlutenSessionExtensionSuite.scala | 3 +-
.../sql/execution/FallbackStrategiesSuite.scala | 168 +++++++++++---------
.../extension/GlutenSessionExtensionSuite.scala | 3 +-
32 files changed, 994 insertions(+), 899 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index 9884a0c6e..41ffbdb58 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -53,6 +53,7 @@ class CHBackend extends Backend {
override def validatorApi(): ValidatorApi = new CHValidatorApi
override def metricsApi(): MetricsApi = new CHMetricsApi
override def listenerApi(): ListenerApi = new CHListenerApi
+ override def ruleApi(): RuleApi = new CHRuleApi
override def settings(): BackendSettingsApi = CHBackendSettings
}
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
new file mode 100644
index 000000000..177d6a6f0
--- /dev/null
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.backendsapi.clickhouse
+
+import org.apache.gluten.backendsapi.RuleApi
+import org.apache.gluten.extension._
+import org.apache.gluten.extension.columnar._
+import
org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow,
RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast, TransformPreOverrides}
+import
org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager
+import org.apache.gluten.extension.columnar.transition.{InsertTransitions,
RemoveTransitions}
+import org.apache.gluten.extension.injector.{RuleInjector, SparkInjector}
+import org.apache.gluten.extension.injector.GlutenInjector.{LegacyInjector,
RasInjector}
+import org.apache.gluten.parser.GlutenClickhouseSqlParser
+import org.apache.gluten.sql.shims.SparkShimLoader
+
+import org.apache.spark.sql.catalyst.{CHAggregateFunctionRewriteRule,
EqualToRewrite}
+import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages,
GlutenFallbackReporter}
+import org.apache.spark.util.SparkPlanRules
+
+class CHRuleApi extends RuleApi {
+ import CHRuleApi._
+ override def injectRules(injector: RuleInjector): Unit = {
+ injectSpark(injector.spark)
+ injectLegacy(injector.gluten.legacy)
+ injectRas(injector.gluten.ras)
+ }
+}
+
+private object CHRuleApi {
+ def injectSpark(injector: SparkInjector): Unit = {
+ // Regular Spark rules.
+
injector.injectQueryStagePrepRule(FallbackBroadcastHashJoinPrepQueryStage.apply)
+ injector.injectParser(
+ (spark, parserInterface) => new GlutenClickhouseSqlParser(spark,
parserInterface))
+ injector.injectResolutionRule(
+ spark => new RewriteToDateExpresstionRule(spark,
spark.sessionState.conf))
+ injector.injectResolutionRule(
+ spark => new RewriteDateTimestampComparisonRule(spark,
spark.sessionState.conf))
+ injector.injectOptimizerRule(
+ spark => new CommonSubexpressionEliminateRule(spark,
spark.sessionState.conf))
+ injector.injectOptimizerRule(spark =>
CHAggregateFunctionRewriteRule(spark))
+ injector.injectOptimizerRule(_ => CountDistinctWithoutExpand)
+ injector.injectOptimizerRule(_ => EqualToRewrite)
+ }
+
+ def injectLegacy(injector: LegacyInjector): Unit = {
+ // Gluten columnar: Transform rules.
+ injector.injectTransform(_ => RemoveTransitions)
+ injector.injectTransform(c => FallbackOnANSIMode.apply(c.session))
+ injector.injectTransform(c => FallbackMultiCodegens.apply(c.session))
+ injector.injectTransform(c => PlanOneRowRelation.apply(c.session))
+ injector.injectTransform(_ => RewriteSubqueryBroadcast())
+ injector.injectTransform(c => FallbackBroadcastHashJoin.apply(c.session))
+ injector.injectTransform(_ => FallbackEmptySchemaRelation())
+ injector.injectTransform(c =>
MergeTwoPhasesHashBaseAggregate.apply(c.session))
+ injector.injectTransform(_ => RewriteSparkPlanRulesManager())
+ injector.injectTransform(_ => AddFallbackTagRule())
+ injector.injectTransform(_ => TransformPreOverrides())
+ injector.injectTransform(_ => RemoveNativeWriteFilesSortAndProject())
+ injector.injectTransform(c => RewriteTransformer.apply(c.session))
+ injector.injectTransform(_ => EnsureLocalSortRequirements)
+ injector.injectTransform(_ => EliminateLocalSort)
+ injector.injectTransform(_ => CollapseProjectExecTransformer)
+ injector.injectTransform(c =>
RewriteSortMergeJoinToHashJoinRule.apply(c.session))
+ injector.injectTransform(
+ c =>
SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarTransformRules)(c.session))
+ injector.injectTransform(c => InsertTransitions(c.outputsColumnar))
+
+ // Gluten columnar: Fallback policies.
+ injector.injectFallbackPolicy(
+ c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan()))
+
+ // Gluten columnar: Post rules.
+ injector.injectPost(c => RemoveTopmostColumnarToRow(c.session,
c.ac.isAdaptiveContext()))
+ SparkShimLoader.getSparkShims
+ .getExtendedColumnarPostRules()
+ .foreach(each => injector.injectPost(c => each(c.session)))
+ injector.injectPost(c => ColumnarCollapseTransformStages(c.conf))
+ injector.injectTransform(
+ c =>
SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarPostRules)(c.session))
+
+ // Gluten columnar: Final rules.
+ injector.injectFinal(c => RemoveGlutenTableCacheColumnarToRow(c.session))
+ injector.injectFinal(c => GlutenFallbackReporter(c.conf, c.session))
+ injector.injectFinal(_ => RemoveFallbackTagRule())
+ }
+
+ def injectRas(injector: RasInjector): Unit = {
+ // CH backend doesn't work with RAS at the moment. Inject a rule that
aborts any
+ // execution calls.
+ injector.inject(
+ _ =>
+ new SparkPlanRules.AbortRule(
+ "Clickhouse backend doesn't yet have RAS support, please try
disabling RAS and" +
+ " rerunning the application"))
+ }
+}
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 8fdc2645a..02b4777e7 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -21,11 +21,9 @@ import org.apache.gluten.backendsapi.{BackendsApiManager,
SparkPlanExecApi}
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.execution._
import org.apache.gluten.expression._
-import org.apache.gluten.extension.{CommonSubexpressionEliminateRule,
CountDistinctWithoutExpand, FallbackBroadcastHashJoin,
FallbackBroadcastHashJoinPrepQueryStage, RewriteDateTimestampComparisonRule,
RewriteSortMergeJoinToHashJoinRule, RewriteToDateExpresstionRule}
import org.apache.gluten.extension.columnar.AddFallbackTagRule
import
org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverrides
import org.apache.gluten.extension.columnar.transition.Convention
-import org.apache.gluten.parser.GlutenClickhouseSqlParser
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode, WindowFunctionNode}
import org.apache.gluten.utils.{CHJoinValidateUtil, UnknownJoinStrategy}
@@ -36,18 +34,13 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{GenShuffleWriterParameters,
GlutenShuffleWriterWrapper, HashPartitioningWrapper}
import org.apache.spark.shuffle.utils.CHShuffleUtil
-import org.apache.spark.sql.{SparkSession, Strategy}
-import org.apache.spark.sql.catalyst.{CHAggregateFunctionRewriteRule,
EqualToRewrite}
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
CollectList, CollectSet}
import org.apache.spark.sql.catalyst.optimizer.BuildSide
-import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.JoinType
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode,
HashPartitioning, Partitioning, RangePartitioning}
-import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.delta.files.TahoeFileIndex
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec
@@ -549,82 +542,6 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
ClickHouseBuildSideRelation(mode, newOutput, batches.flatten, rowCount,
newBuildKeys)
}
- /**
- * Generate extended DataSourceV2 Strategies. Currently only for ClickHouse
backend.
- *
- * @return
- */
- override def genExtendedDataSourceV2Strategies(): List[SparkSession =>
Strategy] = {
- List.empty
- }
-
- /**
- * Generate extended query stage preparation rules.
- *
- * @return
- */
- override def genExtendedQueryStagePrepRules(): List[SparkSession =>
Rule[SparkPlan]] = {
- List(spark => FallbackBroadcastHashJoinPrepQueryStage(spark))
- }
-
- /**
- * Generate extended Analyzers. Currently only for ClickHouse backend.
- *
- * @return
- */
- override def genExtendedAnalyzers(): List[SparkSession => Rule[LogicalPlan]]
= {
- List(
- spark => new RewriteToDateExpresstionRule(spark,
spark.sessionState.conf),
- spark => new RewriteDateTimestampComparisonRule(spark,
spark.sessionState.conf))
- }
-
- /**
- * Generate extended Optimizers.
- *
- * @return
- */
- override def genExtendedOptimizers(): List[SparkSession =>
Rule[LogicalPlan]] = {
- List(
- spark => new CommonSubexpressionEliminateRule(spark,
spark.sessionState.conf),
- spark => CHAggregateFunctionRewriteRule(spark),
- _ => CountDistinctWithoutExpand,
- _ => EqualToRewrite
- )
- }
-
- /**
- * Generate extended columnar pre-rules, in the validation phase.
- *
- * @return
- */
- override def genExtendedColumnarValidationRules(): List[SparkSession =>
Rule[SparkPlan]] =
- List(spark => FallbackBroadcastHashJoin(spark))
-
- /**
- * Generate extended columnar pre-rules.
- *
- * @return
- */
- override def genExtendedColumnarTransformRules(): List[SparkSession =>
Rule[SparkPlan]] =
- List(spark => RewriteSortMergeJoinToHashJoinRule(spark))
-
- override def genInjectPostHocResolutionRules(): List[SparkSession =>
Rule[LogicalPlan]] = {
- List()
- }
-
- /**
- * Generate extended Strategies.
- *
- * @return
- */
- override def genExtendedStrategies(): List[SparkSession => Strategy] =
- List()
-
- override def genInjectExtendedParser()
- : List[(SparkSession, ParserInterface) => ParserInterface] = {
- List((spark, parserInterface) => new GlutenClickhouseSqlParser(spark,
parserInterface))
- }
-
/** Define backend specfic expression mappings. */
override def extraExpressionMappings: Seq[Sig] = {
List(
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
index d32911f4a..21175f20e 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
@@ -55,6 +55,7 @@ class VeloxBackend extends Backend {
override def validatorApi(): ValidatorApi = new VeloxValidatorApi
override def metricsApi(): MetricsApi = new VeloxMetricsApi
override def listenerApi(): ListenerApi = new VeloxListenerApi
+ override def ruleApi(): RuleApi = new VeloxRuleApi
override def settings(): BackendSettingsApi = VeloxBackendSettings
}
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
new file mode 100644
index 000000000..645407be8
--- /dev/null
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.backendsapi.velox
+
+import org.apache.gluten.backendsapi.RuleApi
+import org.apache.gluten.datasource.ArrowConvertorRule
+import org.apache.gluten.extension._
+import org.apache.gluten.extension.columnar._
+import
org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow,
RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast, TransformPreOverrides}
+import org.apache.gluten.extension.columnar.enumerated.EnumeratedTransform
+import
org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager
+import org.apache.gluten.extension.columnar.transition.{InsertTransitions,
RemoveTransitions}
+import org.apache.gluten.extension.injector.{RuleInjector, SparkInjector}
+import org.apache.gluten.extension.injector.GlutenInjector.{LegacyInjector,
RasInjector}
+import org.apache.gluten.sql.shims.SparkShimLoader
+
+import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages,
GlutenFallbackReporter}
+import org.apache.spark.sql.expression.UDFResolver
+import org.apache.spark.util.SparkPlanRules
+
+class VeloxRuleApi extends RuleApi {
+ import VeloxRuleApi._
+
+ override def injectRules(injector: RuleInjector): Unit = {
+ injectSpark(injector.spark)
+ injectLegacy(injector.gluten.legacy)
+ injectRas(injector.gluten.ras)
+ }
+}
+
+private object VeloxRuleApi {
+ def injectSpark(injector: SparkInjector): Unit = {
+ // Regular Spark rules.
+ injector.injectOptimizerRule(CollectRewriteRule.apply)
+ injector.injectOptimizerRule(HLLRewriteRule.apply)
+ UDFResolver.getFunctionSignatures.foreach(injector.injectFunction)
+ injector.injectPostHocResolutionRule(ArrowConvertorRule.apply)
+ }
+
+ def injectLegacy(injector: LegacyInjector): Unit = {
+ // Gluten columnar: Transform rules.
+ injector.injectTransform(_ => RemoveTransitions)
+ injector.injectTransform(c => FallbackOnANSIMode.apply(c.session))
+ injector.injectTransform(c => FallbackMultiCodegens.apply(c.session))
+ injector.injectTransform(c => PlanOneRowRelation.apply(c.session))
+ injector.injectTransform(_ => RewriteSubqueryBroadcast())
+ injector.injectTransform(c =>
BloomFilterMightContainJointRewriteRule.apply(c.session))
+ injector.injectTransform(c => ArrowScanReplaceRule.apply(c.session))
+ injector.injectTransform(_ => FallbackEmptySchemaRelation())
+ injector.injectTransform(c =>
MergeTwoPhasesHashBaseAggregate.apply(c.session))
+ injector.injectTransform(_ => RewriteSparkPlanRulesManager())
+ injector.injectTransform(_ => AddFallbackTagRule())
+ injector.injectTransform(_ => TransformPreOverrides())
+ injector.injectTransform(_ => RemoveNativeWriteFilesSortAndProject())
+ injector.injectTransform(c => RewriteTransformer.apply(c.session))
+ injector.injectTransform(_ => EnsureLocalSortRequirements)
+ injector.injectTransform(_ => EliminateLocalSort)
+ injector.injectTransform(_ => CollapseProjectExecTransformer)
+ injector.injectTransform(c => FlushableHashAggregateRule.apply(c.session))
+ injector.injectTransform(
+ c =>
SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarTransformRules)(c.session))
+ injector.injectTransform(c => InsertTransitions(c.outputsColumnar))
+
+ // Gluten columnar: Fallback policies.
+ injector.injectFallbackPolicy(
+ c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan()))
+
+ // Gluten columnar: Post rules.
+ injector.injectPost(c => RemoveTopmostColumnarToRow(c.session,
c.ac.isAdaptiveContext()))
+ SparkShimLoader.getSparkShims
+ .getExtendedColumnarPostRules()
+ .foreach(each => injector.injectPost(c => each(c.session)))
+ injector.injectPost(c => ColumnarCollapseTransformStages(c.conf))
+ injector.injectTransform(
+ c =>
SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarPostRules)(c.session))
+
+ // Gluten columnar: Final rules.
+ injector.injectFinal(c => RemoveGlutenTableCacheColumnarToRow(c.session))
+ injector.injectFinal(c => GlutenFallbackReporter(c.conf, c.session))
+ injector.injectFinal(_ => RemoveFallbackTagRule())
+ }
+
+ def injectRas(injector: RasInjector): Unit = {
+ // Gluten RAS: Pre rules.
+ injector.inject(_ => RemoveTransitions)
+ injector.inject(c => FallbackOnANSIMode.apply(c.session))
+ injector.inject(c => PlanOneRowRelation.apply(c.session))
+ injector.inject(_ => FallbackEmptySchemaRelation())
+ injector.inject(_ => RewriteSubqueryBroadcast())
+ injector.inject(c =>
BloomFilterMightContainJointRewriteRule.apply(c.session))
+ injector.inject(c => ArrowScanReplaceRule.apply(c.session))
+ injector.inject(c => MergeTwoPhasesHashBaseAggregate.apply(c.session))
+
+ // Gluten RAS: The RAS rule.
+ injector.inject(c => EnumeratedTransform(c.session, c.outputsColumnar))
+
+ // Gluten RAS: Post rules.
+ injector.inject(_ => RemoveTransitions)
+ injector.inject(_ => RemoveNativeWriteFilesSortAndProject())
+ injector.inject(c => RewriteTransformer.apply(c.session))
+ injector.inject(_ => EnsureLocalSortRequirements)
+ injector.inject(_ => EliminateLocalSort)
+ injector.inject(_ => CollapseProjectExecTransformer)
+ injector.inject(c => FlushableHashAggregateRule.apply(c.session))
+ injector.inject(
+ c =>
SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarTransformRules)(c.session))
+ injector.inject(c => InsertTransitions(c.outputsColumnar))
+ injector.inject(c => RemoveTopmostColumnarToRow(c.session,
c.ac.isAdaptiveContext()))
+ SparkShimLoader.getSparkShims
+ .getExtendedColumnarPostRules()
+ .foreach(each => injector.inject(c => each(c.session)))
+ injector.inject(c => ColumnarCollapseTransformStages(c.conf))
+ injector.inject(
+ c =>
SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarPostRules)(c.session))
+ injector.inject(c => RemoveGlutenTableCacheColumnarToRow(c.session))
+ injector.inject(c => GlutenFallbackReporter(c.conf, c.session))
+ injector.inject(_ => RemoveFallbackTagRule())
+ }
+}
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index fd0fc62dc..bd390004f 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -18,12 +18,10 @@ package org.apache.gluten.backendsapi.velox
import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.SparkPlanExecApi
-import org.apache.gluten.datasource.ArrowConvertorRule
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.execution._
import org.apache.gluten.expression._
import org.apache.gluten.expression.aggregate.{HLLAdapter,
VeloxBloomFilterAggregate, VeloxCollectList, VeloxCollectSet}
-import org.apache.gluten.extension._
import org.apache.gluten.extension.columnar.FallbackTags
import org.apache.gluten.extension.columnar.transition.Convention
import
org.apache.gluten.extension.columnar.transition.ConventionFunc.BatchOverride
@@ -36,18 +34,13 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{GenShuffleWriterParameters,
GlutenShuffleWriterWrapper}
import org.apache.spark.shuffle.utils.ShuffleUtil
-import org.apache.spark.sql.{SparkSession, Strategy}
-import org.apache.spark.sql.catalyst.FunctionIdentifier
-import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.plans.JoinType
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.datasources.FileFormat
@@ -56,7 +49,7 @@ import
org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelationBr
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
import org.apache.spark.sql.execution.utils.ExecUtil
-import org.apache.spark.sql.expression.{UDFExpression, UDFResolver,
UserDefinedAggregateFunction}
+import org.apache.spark.sql.expression.{UDFExpression,
UserDefinedAggregateFunction}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -65,8 +58,6 @@ import org.apache.commons.lang3.ClassUtils
import javax.ws.rs.core.UriBuilder
-import scala.collection.mutable.ListBuffer
-
class VeloxSparkPlanExecApi extends SparkPlanExecApi {
/** The columnar-batch type this backend is using. */
@@ -760,74 +751,6 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
}
}
- /**
- * * Rules and strategies.
- */
-
- /**
- * Generate extended DataSourceV2 Strategy.
- *
- * @return
- */
- override def genExtendedDataSourceV2Strategies(): List[SparkSession =>
Strategy] = List()
-
- /**
- * Generate extended query stage preparation rules.
- *
- * @return
- */
- override def genExtendedQueryStagePrepRules(): List[SparkSession =>
Rule[SparkPlan]] = List()
-
- /**
- * Generate extended Analyzer.
- *
- * @return
- */
- override def genExtendedAnalyzers(): List[SparkSession => Rule[LogicalPlan]]
= List()
-
- /**
- * Generate extended Optimizer. Currently only for Velox backend.
- *
- * @return
- */
- override def genExtendedOptimizers(): List[SparkSession =>
Rule[LogicalPlan]] =
- List(CollectRewriteRule.apply, HLLRewriteRule.apply)
-
- /**
- * Generate extended columnar pre-rules, in the validation phase.
- *
- * @return
- */
- override def genExtendedColumnarValidationRules(): List[SparkSession =>
Rule[SparkPlan]] = {
- List(BloomFilterMightContainJointRewriteRule.apply,
ArrowScanReplaceRule.apply)
- }
-
- /**
- * Generate extended columnar pre-rules.
- *
- * @return
- */
- override def genExtendedColumnarTransformRules(): List[SparkSession =>
Rule[SparkPlan]] = {
- val buf: ListBuffer[SparkSession => Rule[SparkPlan]] = ListBuffer()
- if (GlutenConfig.getConf.enableVeloxFlushablePartialAggregation) {
- buf += FlushableHashAggregateRule.apply
- }
- buf.result
- }
-
- override def genInjectPostHocResolutionRules(): List[SparkSession =>
Rule[LogicalPlan]] = {
- List(ArrowConvertorRule)
- }
-
- /**
- * Generate extended Strategy.
- *
- * @return
- */
- override def genExtendedStrategies(): List[SparkSession => Strategy] = {
- List()
- }
-
/** Define backend specfic expression mappings. */
override def extraExpressionMappings: Seq[Sig] = {
Seq(
@@ -844,11 +767,6 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
)
}
- override def genInjectedFunctions()
- : Seq[(FunctionIdentifier, ExpressionInfo, FunctionBuilder)] = {
- UDFResolver.getFunctionSignatures
- }
-
override def rewriteSpillPath(path: String): String = {
val fs = GlutenConfig.getConf.veloxSpillFileSystem
fs match {
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala
b/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala
index 3137d6e6a..04bdbe1ef 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala
@@ -16,6 +16,7 @@
*/
package org.apache.gluten.extension
+import org.apache.gluten.GlutenConfig
import org.apache.gluten.execution._
import org.apache.spark.sql.SparkSession
@@ -31,27 +32,32 @@ import
org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
*/
case class FlushableHashAggregateRule(session: SparkSession) extends
Rule[SparkPlan] {
import FlushableHashAggregateRule._
- override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
- case s: ShuffleExchangeLike =>
- // If an exchange follows a hash aggregate in which all functions are in
partial mode,
- // then it's safe to convert the hash aggregate to flushable hash
aggregate.
- val out = s.withNewChildren(
- List(
- replaceEligibleAggregates(s.child) {
- agg =>
- FlushableHashAggregateExecTransformer(
- agg.requiredChildDistributionExpressions,
- agg.groupingExpressions,
- agg.aggregateExpressions,
- agg.aggregateAttributes,
- agg.initialInputBufferOffset,
- agg.resultExpressions,
- agg.child
- )
- }
+ override def apply(plan: SparkPlan): SparkPlan = {
+ if (!GlutenConfig.getConf.enableVeloxFlushablePartialAggregation) {
+ return plan
+ }
+ plan.transformUp {
+ case s: ShuffleExchangeLike =>
+ // If an exchange follows a hash aggregate in which all functions are
in partial mode,
+ // then it's safe to convert the hash aggregate to flushable hash
aggregate.
+ val out = s.withNewChildren(
+ List(
+ replaceEligibleAggregates(s.child) {
+ agg =>
+ FlushableHashAggregateExecTransformer(
+ agg.requiredChildDistributionExpressions,
+ agg.groupingExpressions,
+ agg.aggregateExpressions,
+ agg.aggregateAttributes,
+ agg.initialInputBufferOffset,
+ agg.resultExpressions,
+ agg.child
+ )
+ }
+ )
)
- )
- out
+ out
+ }
}
private def replaceEligibleAggregates(plan: SparkPlan)(
diff --git a/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala
b/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala
index dbf927909..6e3484dfa 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala
@@ -17,12 +17,11 @@
package org.apache.gluten
import org.apache.gluten.GlutenConfig.GLUTEN_DEFAULT_SESSION_TIMEZONE_KEY
-import org.apache.gluten.GlutenPlugin.{GLUTEN_SESSION_EXTENSION_NAME,
SPARK_SESSION_EXTS_KEY}
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.events.GlutenBuildInfoEvent
import org.apache.gluten.exception.GlutenException
import org.apache.gluten.expression.ExpressionMappings
-import org.apache.gluten.extension.{ColumnarOverrides,
OthersExtensionOverrides, QueryStagePrepOverrides}
+import
org.apache.gluten.extension.GlutenSessionExtensions.{GLUTEN_SESSION_EXTENSION_NAME,
SPARK_SESSION_EXTS_KEY}
import org.apache.gluten.test.TestStats
import org.apache.gluten.utils.TaskListener
@@ -31,14 +30,13 @@ import org.apache.spark.api.plugin.{DriverPlugin,
ExecutorPlugin, PluginContext,
import org.apache.spark.internal.Logging
import org.apache.spark.listener.GlutenListenerFactory
import org.apache.spark.network.util.JavaUtils
-import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.execution.ui.GlutenEventUtils
-import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.utils.ExpressionUtil
import org.apache.spark.util.{SparkResourceUtil, TaskResources}
import java.util
-import java.util.{Collections, Objects}
+import java.util.Collections
import scala.collection.mutable
@@ -298,25 +296,4 @@ private[gluten] class GlutenExecutorPlugin extends
ExecutorPlugin {
}
}
-private[gluten] class GlutenSessionExtensions extends (SparkSessionExtensions
=> Unit) {
- override def apply(exts: SparkSessionExtensions): Unit = {
- GlutenPlugin.DEFAULT_INJECTORS.foreach(injector => injector.inject(exts))
- }
-}
-
-private[gluten] trait GlutenSparkExtensionsInjector {
- def inject(extensions: SparkSessionExtensions): Unit
-}
-
-private[gluten] object GlutenPlugin {
- val SPARK_SESSION_EXTS_KEY: String =
StaticSQLConf.SPARK_SESSION_EXTENSIONS.key
- val GLUTEN_SESSION_EXTENSION_NAME: String =
- Objects.requireNonNull(classOf[GlutenSessionExtensions].getCanonicalName)
-
- /** Specify all injectors that Gluten is using in following list. */
- val DEFAULT_INJECTORS: List[GlutenSparkExtensionsInjector] = List(
- QueryStagePrepOverrides,
- ColumnarOverrides,
- OthersExtensionOverrides
- )
-}
+private object GlutenPlugin {}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/Backend.scala
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/Backend.scala
index 2c465ac61..3a5975522 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/Backend.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/Backend.scala
@@ -33,6 +33,8 @@ trait Backend {
def listenerApi(): ListenerApi
+ def ruleApi(): RuleApi
+
def settings(): BackendSettingsApi
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendsApiManager.scala
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendsApiManager.scala
index f2c93d8c7..16aa9161e 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendsApiManager.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendsApiManager.scala
@@ -83,6 +83,10 @@ object BackendsApiManager {
backend.metricsApi()
}
+ def getRuleApiInstance: RuleApi = {
+ backend.ruleApi()
+ }
+
def getSettings: BackendSettingsApi = {
backend.settings
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/Backend.scala
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala
similarity index 64%
copy from gluten-core/src/main/scala/org/apache/gluten/backendsapi/Backend.scala
copy to gluten-core/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala
index 2c465ac61..f8669a6fe 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/Backend.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/RuleApi.scala
@@ -16,28 +16,9 @@
*/
package org.apache.gluten.backendsapi
-trait Backend {
- def name(): String
+import org.apache.gluten.extension.injector.RuleInjector
- def buildInfo(): BackendBuildInfo
-
- def iteratorApi(): IteratorApi
-
- def sparkPlanExecApi(): SparkPlanExecApi
-
- def transformerApi(): TransformerApi
-
- def validatorApi(): ValidatorApi
-
- def metricsApi(): MetricsApi
-
- def listenerApi(): ListenerApi
-
- def settings(): BackendSettingsApi
+trait RuleApi {
+ // Injects all Gluten / Spark query planner rules used by the backend.
+ def injectRules(injector: RuleInjector): Unit
}
-
-case class BackendBuildInfo(
- backend: String,
- backendBranch: String,
- backendRevision: String,
- backendRevisionTime: String)
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index 3b9e87a20..0227ed5da 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -27,20 +27,14 @@ import org.apache.spark.ShuffleDependency
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{GenShuffleWriterParameters,
GlutenShuffleWriterWrapper}
-import org.apache.spark.sql.{SparkSession, Strategy}
-import org.apache.spark.sql.catalyst.FunctionIdentifier
-import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.BuildSide
-import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.JoinType
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode,
Partitioning}
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{ColumnarWriteFilesExec,
FileSourceScanExec, GenerateExec, LeafExecNode, SparkPlan}
+import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
@@ -392,69 +386,6 @@ trait SparkPlanExecApi {
child: SparkPlan,
evalType: Int): SparkPlan
- /**
- * Generate extended DataSourceV2 Strategies. Currently only for ClickHouse
backend.
- *
- * @return
- */
- def genExtendedDataSourceV2Strategies(): List[SparkSession => Strategy]
-
- /**
- * Generate extended query stage preparation rules.
- *
- * @return
- */
- def genExtendedQueryStagePrepRules(): List[SparkSession => Rule[SparkPlan]]
-
- /**
- * Generate extended Analyzers. Currently only for ClickHouse backend.
- *
- * @return
- */
- def genExtendedAnalyzers(): List[SparkSession => Rule[LogicalPlan]]
-
- /**
- * Generate extended Optimizers. Currently only for Velox backend.
- *
- * @return
- */
- def genExtendedOptimizers(): List[SparkSession => Rule[LogicalPlan]]
-
- /**
- * Generate extended Strategies
- *
- * @return
- */
- def genExtendedStrategies(): List[SparkSession => Strategy]
-
- /**
- * Generate extended columnar pre-rules, in the validation phase.
- *
- * @return
- */
- def genExtendedColumnarValidationRules(): List[SparkSession =>
Rule[SparkPlan]]
-
- /**
- * Generate extended columnar transform-rules.
- *
- * @return
- */
- def genExtendedColumnarTransformRules(): List[SparkSession =>
Rule[SparkPlan]]
-
- /**
- * Generate extended columnar post-rules.
- *
- * @return
- */
- def genExtendedColumnarPostRules(): List[SparkSession => Rule[SparkPlan]] = {
- SparkShimLoader.getSparkShims.getExtendedColumnarPostRules() ::: List()
- }
-
- def genInjectPostHocResolutionRules(): List[SparkSession =>
Rule[LogicalPlan]]
-
- def genInjectExtendedParser(): List[(SparkSession, ParserInterface) =>
ParserInterface] =
- List.empty
-
def genGetStructFieldTransformer(
substraitExprName: String,
childTransformer: ExpressionTransformer,
@@ -665,8 +596,6 @@ trait SparkPlanExecApi {
}
}
- def genInjectedFunctions(): Seq[(FunctionIdentifier, ExpressionInfo,
FunctionBuilder)] = Seq.empty
-
def rewriteSpillPath(path: String): String = path
/**
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/ColumnarOverrides.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/ColumnarOverrides.scala
index 067976b63..c5a9afec3 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/ColumnarOverrides.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/ColumnarOverrides.scala
@@ -16,17 +16,14 @@
*/
package org.apache.gluten.extension
-import org.apache.gluten.{GlutenConfig, GlutenSparkExtensionsInjector}
import org.apache.gluten.extension.columnar._
-import org.apache.gluten.extension.columnar.enumerated.EnumeratedApplier
-import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier
import org.apache.gluten.extension.columnar.transition.Transitions
import org.apache.gluten.utils.LogLevelUtil
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.rules.Rule
@@ -95,7 +92,9 @@ object ColumnarOverrideRules {
}
}
-case class ColumnarOverrideRules(session: SparkSession)
+case class ColumnarOverrideRules(
+ session: SparkSession,
+ applierBuilder: SparkSession => ColumnarRuleApplier)
extends ColumnarRule
with Logging
with LogLevelUtil {
@@ -117,19 +116,11 @@ case class ColumnarOverrideRules(session: SparkSession)
val outputsColumnar = OutputsColumnarTester.inferOutputsColumnar(plan)
val unwrapped = OutputsColumnarTester.unwrap(plan)
val vanillaPlan = Transitions.insertTransitions(unwrapped, outputsColumnar)
- val applier: ColumnarRuleApplier = if (GlutenConfig.getConf.enableRas) {
- new EnumeratedApplier(session)
- } else {
- new HeuristicApplier(session)
- }
+ val applier = applierBuilder.apply(session)
val out = applier.apply(vanillaPlan, outputsColumnar)
out
}
}
-object ColumnarOverrides extends GlutenSparkExtensionsInjector {
- override def inject(extensions: SparkSessionExtensions): Unit = {
- extensions.injectColumnar(spark => ColumnarOverrideRules(spark))
- }
-}
+object ColumnarOverrides {}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala
new file mode 100644
index 000000000..4456dda61
--- /dev/null
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenSessionExtensions.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension
+
+import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.extension.injector.RuleInjector
+
+import org.apache.spark.sql.SparkSessionExtensions
+import org.apache.spark.sql.internal.StaticSQLConf
+
+import java.util.Objects
+
+private[gluten] class GlutenSessionExtensions extends (SparkSessionExtensions => Unit) {
+ override def apply(exts: SparkSessionExtensions): Unit = {
+ val injector = new RuleInjector()
+ BackendsApiManager.getRuleApiInstance.injectRules(injector)
+ injector.inject(exts)
+ }
+}
+
+private[gluten] object GlutenSessionExtensions {
+  val SPARK_SESSION_EXTS_KEY: String = StaticSQLConf.SPARK_SESSION_EXTENSIONS.key
+ val GLUTEN_SESSION_EXTENSION_NAME: String =
+ Objects.requireNonNull(classOf[GlutenSessionExtensions].getCanonicalName)
+}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/OthersExtensionOverrides.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/OthersExtensionOverrides.scala
deleted file mode 100644
index f2ccf6e81..000000000
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/OthersExtensionOverrides.scala
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.extension
-
-import org.apache.gluten.GlutenSparkExtensionsInjector
-import org.apache.gluten.backendsapi.BackendsApiManager
-
-import org.apache.spark.sql.SparkSessionExtensions
-
-object OthersExtensionOverrides extends GlutenSparkExtensionsInjector {
- override def inject(extensions: SparkSessionExtensions): Unit = {
- BackendsApiManager.getSparkPlanExecApiInstance
- .genInjectExtendedParser()
- .foreach(extensions.injectParser)
- BackendsApiManager.getSparkPlanExecApiInstance
- .genExtendedAnalyzers()
- .foreach(extensions.injectResolutionRule)
- BackendsApiManager.getSparkPlanExecApiInstance
- .genExtendedOptimizers()
- .foreach(extensions.injectOptimizerRule)
- BackendsApiManager.getSparkPlanExecApiInstance
- .genExtendedDataSourceV2Strategies()
- .foreach(extensions.injectPlannerStrategy)
- BackendsApiManager.getSparkPlanExecApiInstance
- .genExtendedStrategies()
- .foreach(extensions.injectPlannerStrategy)
- BackendsApiManager.getSparkPlanExecApiInstance
- .genInjectedFunctions()
- .foreach(extensions.injectFunction)
- BackendsApiManager.getSparkPlanExecApiInstance
- .genInjectPostHocResolutionRules()
- .foreach(extensions.injectPostHocResolutionRule)
- }
-}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/QueryStagePrepOverrides.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/QueryStagePrepOverrides.scala
deleted file mode 100644
index 8f9e2326c..000000000
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/QueryStagePrepOverrides.scala
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.extension
-
-import org.apache.gluten.GlutenSparkExtensionsInjector
-import org.apache.gluten.backendsapi.BackendsApiManager
-
-import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.SparkPlan
-
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-object QueryStagePrepOverrides extends GlutenSparkExtensionsInjector {
- private val RULES: Seq[SparkSession => Rule[SparkPlan]] =
-
BackendsApiManager.getSparkPlanExecApiInstance.genExtendedQueryStagePrepRules()
-
- override def inject(extensions: SparkSessionExtensions): Unit = {
- RULES.foreach(extensions.injectQueryStagePrepRule)
- }
-}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala
index 27213698b..9b78ccd11 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala
@@ -17,10 +17,12 @@
package org.apache.gluten.extension.columnar
import org.apache.gluten.GlutenConfig
+import org.apache.gluten.extension.columnar.util.AdaptiveContext
import org.apache.gluten.metrics.GlutenTimeMetric
import org.apache.gluten.utils.LogLevelUtil
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.execution.SparkPlan
@@ -30,6 +32,17 @@ trait ColumnarRuleApplier {
}
object ColumnarRuleApplier {
+ type ColumnarRuleBuilder = ColumnarRuleCall => Rule[SparkPlan]
+
+ class ColumnarRuleCall(
+ val session: SparkSession,
+ val ac: AdaptiveContext,
+ val outputsColumnar: Boolean) {
+ val conf: GlutenConfig = {
+ new GlutenConfig(session.sessionState.conf)
+ }
+ }
+
class Executor(phase: String, rules: Seq[Rule[SparkPlan]]) extends
RuleExecutor[SparkPlan] {
private val batch: Batch =
Batch(s"Columnar (Phase [$phase])", Once, rules.map(r => new
LoggedRule(r)): _*)
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala
index 5cf3961c5..bebce3a61 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedApplier.scala
@@ -16,11 +16,8 @@
*/
package org.apache.gluten.extension.columnar.enumerated
-import org.apache.gluten.GlutenConfig
-import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.extension.columnar._
-import
org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow,
RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast}
-import org.apache.gluten.extension.columnar.transition.{InsertTransitions,
RemoveTransitions}
+import
org.apache.gluten.extension.columnar.ColumnarRuleApplier.{ColumnarRuleBuilder,
ColumnarRuleCall}
import org.apache.gluten.extension.columnar.util.AdaptiveContext
import org.apache.gluten.utils.{LogLevelUtil, PhysicalPlanSelector}
@@ -28,8 +25,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages,
GlutenFallbackReporter, SparkPlan}
-import org.apache.spark.util.SparkRuleUtil
+import org.apache.spark.sql.execution.SparkPlan
/**
* Columnar rule applier that optimizes, implements Spark plan into Gluten
plan by enumerating on
@@ -40,7 +36,7 @@ import org.apache.spark.util.SparkRuleUtil
* implementing them in EnumeratedTransform.
*/
@Experimental
-class EnumeratedApplier(session: SparkSession)
+class EnumeratedApplier(session: SparkSession, ruleBuilders:
Seq[ColumnarRuleBuilder])
extends ColumnarRuleApplier
with Logging
with LogLevelUtil {
@@ -53,22 +49,18 @@ class EnumeratedApplier(session: SparkSession)
}
private val adaptiveContext = AdaptiveContext(session, aqeStackTraceIndex)
- override def apply(plan: SparkPlan, outputsColumnar: Boolean): SparkPlan =
+ override def apply(plan: SparkPlan, outputsColumnar: Boolean): SparkPlan = {
+ val call = new ColumnarRuleCall(session, adaptiveContext, outputsColumnar)
PhysicalPlanSelector.maybe(session, plan) {
- val transformed =
- transformPlan("transform",
transformRules(outputsColumnar).map(_(session)), plan)
- val postPlan = maybeAqe {
- transformPlan("post", postRules().map(_(session)), transformed)
+ val finalPlan = maybeAqe {
+ apply0(ruleBuilders.map(b => b(call)), plan)
}
- val finalPlan = transformPlan("final", finalRules().map(_(session)),
postPlan)
finalPlan
}
+ }
- private def transformPlan(
- phase: String,
- rules: Seq[Rule[SparkPlan]],
- plan: SparkPlan): SparkPlan = {
- val executor = new ColumnarRuleApplier.Executor(phase, rules)
+ private def apply0(rules: Seq[Rule[SparkPlan]], plan: SparkPlan): SparkPlan
= {
+ val executor = new ColumnarRuleApplier.Executor("ras", rules)
executor.execute(plan)
}
@@ -80,61 +72,4 @@ class EnumeratedApplier(session: SparkSession)
adaptiveContext.resetAdaptiveContext()
}
}
-
- /**
- * Rules to let planner create a suggested Gluten plan being sent to
`fallbackPolicies` in which
- * the plan will be breakdown and decided to be fallen back or not.
- */
- private def transformRules(outputsColumnar: Boolean): Seq[SparkSession =>
Rule[SparkPlan]] = {
- List(
- (_: SparkSession) => RemoveTransitions,
- (spark: SparkSession) => FallbackOnANSIMode(spark),
- (spark: SparkSession) => PlanOneRowRelation(spark),
- (_: SparkSession) => FallbackEmptySchemaRelation(),
- (_: SparkSession) => RewriteSubqueryBroadcast()
- ) :::
-
BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarValidationRules()
:::
- List((spark: SparkSession) => MergeTwoPhasesHashBaseAggregate(spark)) :::
- List(
- (session: SparkSession) => EnumeratedTransform(session,
outputsColumnar),
- (_: SparkSession) => RemoveTransitions
- ) :::
- List(
- (_: SparkSession) => RemoveNativeWriteFilesSortAndProject(),
- (spark: SparkSession) => RewriteTransformer(spark),
- (_: SparkSession) => EnsureLocalSortRequirements,
- (_: SparkSession) => EliminateLocalSort,
- (_: SparkSession) => CollapseProjectExecTransformer
- ) :::
-
BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarTransformRules()
:::
- SparkRuleUtil
- .extendedColumnarRules(session,
GlutenConfig.getConf.extendedColumnarTransformRules) :::
- List((_: SparkSession) => InsertTransitions(outputsColumnar))
- }
-
- /**
- * Rules applying to non-fallen-back Gluten plans. To do some post cleanup
works on the plan to
- * make sure it be able to run and be compatible with Spark's execution
engine.
- */
- private def postRules(): Seq[SparkSession => Rule[SparkPlan]] =
- List(
- (s: SparkSession) => RemoveTopmostColumnarToRow(s,
adaptiveContext.isAdaptiveContext())) :::
-
BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarPostRules()
:::
- List((_: SparkSession) =>
ColumnarCollapseTransformStages(GlutenConfig.getConf)) :::
- SparkRuleUtil.extendedColumnarRules(session,
GlutenConfig.getConf.extendedColumnarPostRules)
-
- /*
- * Rules consistently applying to all input plans after all other rules have
been applied, despite
- * whether the input plan is fallen back or not.
- */
- private def finalRules(): Seq[SparkSession => Rule[SparkPlan]] = {
- List(
- // The rule is required despite whether the stage is fallen back or not.
Since
- // ColumnarCachedBatchSerializer is statically registered to Spark
without a columnar rule
- // when columnar table cache is enabled.
- (s: SparkSession) => RemoveGlutenTableCacheColumnarToRow(s),
- (s: SparkSession) => GlutenFallbackReporter(GlutenConfig.getConf, s),
- (_: SparkSession) => RemoveFallbackTagRule()
- )
- }
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala
index f776a1dcc..dea9f01df 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala
@@ -16,26 +16,26 @@
*/
package org.apache.gluten.extension.columnar.heuristic
-import org.apache.gluten.GlutenConfig
-import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.extension.columnar._
-import
org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow,
RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast, TransformPreOverrides}
-import
org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager
-import org.apache.gluten.extension.columnar.transition.{InsertTransitions,
RemoveTransitions}
+import
org.apache.gluten.extension.columnar.ColumnarRuleApplier.{ColumnarRuleBuilder,
ColumnarRuleCall}
import org.apache.gluten.extension.columnar.util.AdaptiveContext
import org.apache.gluten.utils.{LogLevelUtil, PhysicalPlanSelector}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages,
GlutenFallbackReporter, SparkPlan}
-import org.apache.spark.util.SparkRuleUtil
+import org.apache.spark.sql.execution.SparkPlan
/**
* Columnar rule applier that optimizes, implements Spark plan into Gluten
plan by heuristically
* applying columnar rules in fixed order.
*/
-class HeuristicApplier(session: SparkSession)
+class HeuristicApplier(
+ session: SparkSession,
+ transformBuilders: Seq[ColumnarRuleBuilder],
+ fallbackPolicyBuilders: Seq[ColumnarRuleBuilder],
+ postBuilders: Seq[ColumnarRuleBuilder],
+ finalBuilders: Seq[ColumnarRuleBuilder])
extends ColumnarRuleApplier
with Logging
with LogLevelUtil {
@@ -49,27 +49,27 @@ class HeuristicApplier(session: SparkSession)
private val adaptiveContext = AdaptiveContext(session, aqeStackTraceIndex)
override def apply(plan: SparkPlan, outputsColumnar: Boolean): SparkPlan = {
- withTransformRules(transformRules(outputsColumnar)).apply(plan)
+ val call = new ColumnarRuleCall(session, adaptiveContext, outputsColumnar)
+ makeRule(call).apply(plan)
}
- // Visible for testing.
- def withTransformRules(transformRules: Seq[SparkSession =>
Rule[SparkPlan]]): Rule[SparkPlan] =
+ private def makeRule(call: ColumnarRuleCall): Rule[SparkPlan] =
plan =>
PhysicalPlanSelector.maybe(session, plan) {
val finalPlan = prepareFallback(plan) {
p =>
- val suggestedPlan = transformPlan("transform",
transformRules.map(_(session)), p)
- transformPlan("fallback", fallbackPolicies().map(_(session)),
suggestedPlan) match {
+ val suggestedPlan = transformPlan("transform",
transformRules(call), p)
+ transformPlan("fallback", fallbackPolicies(call), suggestedPlan)
match {
case FallbackNode(fallbackPlan) =>
// we should use vanilla c2r rather than native c2r,
// and there should be no `GlutenPlan` any more,
// so skip the `postRules()`.
fallbackPlan
case plan =>
- transformPlan("post", postRules().map(_(session)), plan)
+ transformPlan("post", postRules(call), plan)
}
}
- transformPlan("final", finalRules().map(_(session)), finalPlan)
+ transformPlan("final", finalRules(call), finalPlan)
}
private def transformPlan(
@@ -95,69 +95,32 @@ class HeuristicApplier(session: SparkSession)
* Rules to let planner create a suggested Gluten plan being sent to
`fallbackPolicies` in which
* the plan will be breakdown and decided to be fallen back or not.
*/
- private def transformRules(outputsColumnar: Boolean): Seq[SparkSession =>
Rule[SparkPlan]] = {
- List(
- (_: SparkSession) => RemoveTransitions,
- (spark: SparkSession) => FallbackOnANSIMode(spark),
- (spark: SparkSession) => FallbackMultiCodegens(spark),
- (spark: SparkSession) => PlanOneRowRelation(spark),
- (_: SparkSession) => RewriteSubqueryBroadcast()
- ) :::
-
BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarValidationRules()
:::
- List(
- (_: SparkSession) => FallbackEmptySchemaRelation(),
- (spark: SparkSession) => MergeTwoPhasesHashBaseAggregate(spark),
- (_: SparkSession) => RewriteSparkPlanRulesManager(),
- (_: SparkSession) => AddFallbackTagRule()
- ) :::
- List((_: SparkSession) => TransformPreOverrides()) :::
- List(
- (_: SparkSession) => RemoveNativeWriteFilesSortAndProject(),
- (spark: SparkSession) => RewriteTransformer(spark),
- (_: SparkSession) => EnsureLocalSortRequirements,
- (_: SparkSession) => EliminateLocalSort,
- (_: SparkSession) => CollapseProjectExecTransformer
- ) :::
-
BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarTransformRules()
:::
- SparkRuleUtil
- .extendedColumnarRules(session,
GlutenConfig.getConf.extendedColumnarTransformRules) :::
- List((_: SparkSession) => InsertTransitions(outputsColumnar))
+ private def transformRules(call: ColumnarRuleCall): Seq[Rule[SparkPlan]] = {
+ transformBuilders.map(b => b.apply(call))
}
/**
* Rules to add wrapper `FallbackNode`s on top of the input plan, as hints
to make planner fall
* back the whole input plan to the original vanilla Spark plan.
*/
- private def fallbackPolicies(): Seq[SparkSession => Rule[SparkPlan]] = {
- List(
- (_: SparkSession) =>
- ExpandFallbackPolicy(adaptiveContext.isAdaptiveContext(),
adaptiveContext.originalPlan()))
+ private def fallbackPolicies(call: ColumnarRuleCall): Seq[Rule[SparkPlan]] =
{
+ fallbackPolicyBuilders.map(b => b.apply(call))
}
/**
* Rules applying to non-fallen-back Gluten plans. To do some post cleanup
works on the plan to
* make sure it be able to run and be compatible with Spark's execution
engine.
*/
- private def postRules(): Seq[SparkSession => Rule[SparkPlan]] =
- List(
- (s: SparkSession) => RemoveTopmostColumnarToRow(s,
adaptiveContext.isAdaptiveContext())) :::
-
BackendsApiManager.getSparkPlanExecApiInstance.genExtendedColumnarPostRules()
:::
- List((_: SparkSession) =>
ColumnarCollapseTransformStages(GlutenConfig.getConf)) :::
- SparkRuleUtil.extendedColumnarRules(session,
GlutenConfig.getConf.extendedColumnarPostRules)
+ private def postRules(call: ColumnarRuleCall): Seq[Rule[SparkPlan]] = {
+ postBuilders.map(b => b.apply(call))
+ }
/*
* Rules consistently applying to all input plans after all other rules have
been applied, despite
* whether the input plan is fallen back or not.
*/
- private def finalRules(): Seq[SparkSession => Rule[SparkPlan]] = {
- List(
- // The rule is required despite whether the stage is fallen back or not.
Since
- // ColumnarCachedBatchSerializer is statically registered to Spark
without a columnar rule
- // when columnar table cache is enabled.
- (s: SparkSession) => RemoveGlutenTableCacheColumnarToRow(s),
- (s: SparkSession) => GlutenFallbackReporter(GlutenConfig.getConf, s),
- (_: SparkSession) => RemoveFallbackTagRule()
- )
+ private def finalRules(call: ColumnarRuleCall): Seq[Rule[SparkPlan]] = {
+ finalBuilders.map(b => b.apply(call))
}
// Just for test use.
@@ -166,3 +129,5 @@ class HeuristicApplier(session: SparkSession)
this
}
}
+
+object HeuristicApplier {}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/util/AdaptiveContext.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/util/AdaptiveContext.scala
index 4a9d69f8f..e1f594fd3 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/util/AdaptiveContext.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/util/AdaptiveContext.scala
@@ -22,6 +22,7 @@ import
org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import scala.collection.mutable.ListBuffer
+// Since: https://github.com/apache/incubator-gluten/pull/3294.
sealed trait AdaptiveContext {
def enableAdaptiveContext(): Unit
def isAdaptiveContext(): Boolean
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala
new file mode 100644
index 000000000..728e569cc
--- /dev/null
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.injector
+
+import org.apache.gluten.GlutenConfig
+import org.apache.gluten.extension.ColumnarOverrideRules
+import org.apache.gluten.extension.columnar.ColumnarRuleApplier
+import
org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleBuilder
+import org.apache.gluten.extension.columnar.enumerated.EnumeratedApplier
+import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier
+
+import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
+
+import scala.collection.mutable
+
+/** Injector used to inject query planner rules into Gluten. */
+class GlutenInjector private[injector] {
+ import GlutenInjector._
+ val legacy: LegacyInjector = new LegacyInjector()
+ val ras: RasInjector = new RasInjector()
+
+ private[injector] def inject(extensions: SparkSessionExtensions): Unit = {
+ val ruleBuilder = (session: SparkSession) => new
ColumnarOverrideRules(session, applier)
+ extensions.injectColumnar(session => ruleBuilder(session))
+ }
+
+ private def applier(session: SparkSession): ColumnarRuleApplier = {
+ val conf = new GlutenConfig(session.sessionState.conf)
+ if (conf.enableRas) {
+ return ras.createApplier(session)
+ }
+ legacy.createApplier(session)
+ }
+}
+
+object GlutenInjector {
+ class LegacyInjector {
+ private val transformBuilders = mutable.Buffer.empty[ColumnarRuleBuilder]
+ private val fallbackPolicyBuilders =
mutable.Buffer.empty[ColumnarRuleBuilder]
+ private val postBuilders = mutable.Buffer.empty[ColumnarRuleBuilder]
+ private val finalBuilders = mutable.Buffer.empty[ColumnarRuleBuilder]
+
+ def injectTransform(builder: ColumnarRuleBuilder): Unit = {
+ transformBuilders += builder
+ }
+
+ def injectFallbackPolicy(builder: ColumnarRuleBuilder): Unit = {
+ fallbackPolicyBuilders += builder
+ }
+
+ def injectPost(builder: ColumnarRuleBuilder): Unit = {
+ postBuilders += builder
+ }
+
+ def injectFinal(builder: ColumnarRuleBuilder): Unit = {
+ finalBuilders += builder
+ }
+
+ private[injector] def createApplier(session: SparkSession):
ColumnarRuleApplier = {
+ new HeuristicApplier(
+ session,
+ transformBuilders.toSeq,
+ fallbackPolicyBuilders.toSeq,
+ postBuilders.toSeq,
+ finalBuilders.toSeq)
+ }
+ }
+
+ class RasInjector {
+ private val ruleBuilders = mutable.Buffer.empty[ColumnarRuleBuilder]
+
+ def inject(builder: ColumnarRuleBuilder): Unit = {
+ ruleBuilders += builder
+ }
+
+ private[injector] def createApplier(session: SparkSession):
ColumnarRuleApplier = {
+ new EnumeratedApplier(session, ruleBuilders.toSeq)
+ }
+ }
+}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/Backend.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala
similarity index 61%
copy from gluten-core/src/main/scala/org/apache/gluten/backendsapi/Backend.scala
copy to
gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala
index 2c465ac61..bccbd38b2 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/Backend.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/RuleInjector.scala
@@ -14,30 +14,19 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.gluten.backendsapi
+package org.apache.gluten.extension.injector
-trait Backend {
- def name(): String
+import org.apache.spark.sql.SparkSessionExtensions
- def buildInfo(): BackendBuildInfo
+/** Injector used to inject query planner rules into Spark and Gluten. */
+class RuleInjector {
+ val spark: SparkInjector = new SparkInjector()
+ val gluten: GlutenInjector = new GlutenInjector()
- def iteratorApi(): IteratorApi
-
- def sparkPlanExecApi(): SparkPlanExecApi
-
- def transformerApi(): TransformerApi
-
- def validatorApi(): ValidatorApi
-
- def metricsApi(): MetricsApi
-
- def listenerApi(): ListenerApi
-
- def settings(): BackendSettingsApi
+ private[extension] def inject(extensions: SparkSessionExtensions): Unit = {
+ spark.inject(extensions)
+ gluten.inject(extensions)
+ }
}
-case class BackendBuildInfo(
- backend: String,
- backendBranch: String,
- backendRevision: String,
- backendRevisionTime: String)
+object RuleInjector {}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala
new file mode 100644
index 000000000..6935e61bd
--- /dev/null
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/SparkInjector.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.injector
+
+import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, Strategy}
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
+import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+
+import scala.collection.mutable
+
+/** Injector used to inject query planner rules into Spark. */
+class SparkInjector private[injector] {
+ private type RuleBuilder = SparkSession => Rule[LogicalPlan]
+ private type StrategyBuilder = SparkSession => Strategy
+ private type ParserBuilder = (SparkSession, ParserInterface) =>
ParserInterface
+ private type FunctionDescription = (FunctionIdentifier, ExpressionInfo,
FunctionBuilder)
+ private type QueryStagePrepRuleBuilder = SparkSession => Rule[SparkPlan]
+
+ private val queryStagePrepRuleBuilders =
mutable.Buffer.empty[QueryStagePrepRuleBuilder]
+ private val parserBuilders = mutable.Buffer.empty[ParserBuilder]
+ private val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]
+ private val optimizerRules = mutable.Buffer.empty[RuleBuilder]
+ private val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder]
+ private val injectedFunctions = mutable.Buffer.empty[FunctionDescription]
+ private val postHocResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]
+
+ def injectQueryStagePrepRule(builder: QueryStagePrepRuleBuilder): Unit = {
+ queryStagePrepRuleBuilders += builder
+ }
+
+ def injectParser(builder: ParserBuilder): Unit = {
+ parserBuilders += builder
+ }
+
+ def injectResolutionRule(builder: RuleBuilder): Unit = {
+ resolutionRuleBuilders += builder
+ }
+
+ def injectOptimizerRule(builder: RuleBuilder): Unit = {
+ optimizerRules += builder
+ }
+
+ def injectPlannerStrategy(builder: StrategyBuilder): Unit = {
+ plannerStrategyBuilders += builder
+ }
+
+ def injectFunction(functionDescription: FunctionDescription): Unit = {
+ injectedFunctions += functionDescription
+ }
+
+ def injectPostHocResolutionRule(builder: RuleBuilder): Unit = {
+ postHocResolutionRuleBuilders += builder
+ }
+
+ private[injector] def inject(extensions: SparkSessionExtensions): Unit = {
+ queryStagePrepRuleBuilders.foreach(extensions.injectQueryStagePrepRule)
+ parserBuilders.foreach(extensions.injectParser)
+ resolutionRuleBuilders.foreach(extensions.injectResolutionRule)
+ optimizerRules.foreach(extensions.injectOptimizerRule)
+ plannerStrategyBuilders.foreach(extensions.injectPlannerStrategy)
+ injectedFunctions.foreach(extensions.injectFunction)
+
postHocResolutionRuleBuilders.foreach(extensions.injectPostHocResolutionRule)
+ }
+}
diff --git
a/gluten-core/src/main/scala/org/apache/spark/util/SparkRuleUtil.scala
b/gluten-core/src/main/scala/org/apache/spark/util/SparkPlanRules.scala
similarity index 55%
rename from gluten-core/src/main/scala/org/apache/spark/util/SparkRuleUtil.scala
rename to gluten-core/src/main/scala/org/apache/spark/util/SparkPlanRules.scala
index 100ec36d2..bbaee81a5 100644
--- a/gluten-core/src/main/scala/org/apache/spark/util/SparkRuleUtil.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/util/SparkPlanRules.scala
@@ -21,36 +21,48 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
-object SparkRuleUtil extends Logging {
-
- /** Add the extended pre/post column rules */
- def extendedColumnarRules(
- session: SparkSession,
- conf: String
- ): List[SparkSession => Rule[SparkPlan]] = {
- val extendedRules = conf.split(",").filter(_.nonEmpty)
- extendedRules
- .map {
- ruleStr =>
+object SparkPlanRules extends Logging {
+ // Since https://github.com/apache/incubator-gluten/pull/1523
+ def extendedColumnarRule(ruleNamesStr: String): SparkSession =>
Rule[SparkPlan] =
+ (session: SparkSession) => {
+ val ruleNames = ruleNamesStr.split(",").filter(_.nonEmpty)
+ val rules = ruleNames.flatMap {
+ ruleName =>
try {
- val extensionConfClass = Utils.classForName(ruleStr)
- val extensionConf =
- extensionConfClass
+ val ruleClass = Utils.classForName(ruleName)
+ val rule =
+ ruleClass
.getConstructor(classOf[SparkSession])
.newInstance(session)
.asInstanceOf[Rule[SparkPlan]]
-
- Some((sparkSession: SparkSession) => extensionConf)
+ Some(rule)
} catch {
// Ignore the error if we cannot find the class or when the class
has the wrong type.
case e @ (_: ClassCastException | _: ClassNotFoundException |
_: NoClassDefFoundError) =>
- logWarning(s"Cannot create extended rule $ruleStr", e)
+ logWarning(s"Cannot create extended rule $ruleName", e)
None
}
}
- .filter(_.isDefined)
- .map(_.get)
- .toList
+ new OrderedRules(rules)
+ }
+
+ object EmptyRule extends Rule[SparkPlan] {
+ override def apply(plan: SparkPlan): SparkPlan = plan
+ }
+
+ class AbortRule(message: String) extends Rule[SparkPlan] {
+ override def apply(plan: SparkPlan): SparkPlan =
+ throw new IllegalStateException(
+ "AbortRule is being executed, this should not happen. Reason: " +
message)
+ }
+
+ class OrderedRules(rules: Seq[Rule[SparkPlan]]) extends Rule[SparkPlan] {
+ override def apply(plan: SparkPlan): SparkPlan = {
+ rules.foldLeft(plan) {
+ case (plan, rule) =>
+ rule.apply(plan)
+ }
+ }
}
}
diff --git
a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
index 7c7aa0879..5d171a36b 100644
---
a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
+++
b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
@@ -16,8 +16,12 @@
*/
package org.apache.spark.sql.execution
+import org.apache.gluten.GlutenConfig
import org.apache.gluten.execution.BasicScanExecTransformer
import org.apache.gluten.extension.GlutenPlan
+import org.apache.gluten.extension.columnar.{ExpandFallbackPolicy,
RemoveFallbackTagRule}
+import
org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleBuilder
+import
org.apache.gluten.extension.columnar.MiscColumnarRules.RemoveTopmostColumnarToRow
import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier
import org.apache.gluten.extension.columnar.transition.InsertTransitions
import org.apache.gluten.utils.QueryPlanSelector
@@ -28,18 +32,20 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
class FallbackStrategiesSuite extends GlutenSQLTestsTrait {
+ import FallbackStrategiesSuite._
testGluten("Fall back the whole query if one unsupported") {
withSQLConf(("spark.gluten.sql.columnar.query.fallback.threshold", "1")) {
val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
- val rule = new HeuristicApplier(spark).withTransformRules(
+ val rule = newRuleApplier(
+ spark,
List(
_ =>
_ => {
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp()))))
},
- (_: SparkSession) => InsertTransitions(outputsColumnar = false)))
- val outputPlan = rule.apply(originalPlan)
+ c => InsertTransitions(c.outputsColumnar)))
+ val outputPlan = rule.apply(originalPlan, false)
// Expect to fall back the entire plan.
assert(outputPlan == originalPlan)
}
@@ -48,16 +54,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait {
testGluten("Fall back the whole plan if meeting the configured threshold") {
withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold",
"1")) {
val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
- val rule = new HeuristicApplier(spark)
+ val rule = newRuleApplier(
+ spark,
+ List(
+ _ =>
+ _ => {
+
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp()))))
+ },
+ c => InsertTransitions(c.outputsColumnar)))
.enableAdaptiveContext()
- .withTransformRules(
- List(
- _ =>
- _ => {
-
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp()))))
- },
- (_: SparkSession) => InsertTransitions(outputsColumnar = false)))
- val outputPlan = rule.apply(originalPlan)
+ val outputPlan = rule.apply(originalPlan, false)
// Expect to fall back the entire plan.
assert(outputPlan == originalPlan)
}
@@ -66,16 +72,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait {
testGluten("Don't fall back the whole plan if NOT meeting the configured
threshold") {
withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold",
"4")) {
val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
- val rule = new HeuristicApplier(spark)
+ val rule = newRuleApplier(
+ spark,
+ List(
+ _ =>
+ _ => {
+
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp()))))
+ },
+ c => InsertTransitions(c.outputsColumnar)))
.enableAdaptiveContext()
- .withTransformRules(
- List(
- _ =>
- _ => {
-
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp()))))
- },
- (_: SparkSession) => InsertTransitions(outputsColumnar = false)))
- val outputPlan = rule.apply(originalPlan)
+ val outputPlan = rule.apply(originalPlan, false)
// Expect to get the plan with columnar rule applied.
assert(outputPlan != originalPlan)
}
@@ -86,16 +92,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait {
" transformable)") {
withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold",
"2")) {
val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
- val rule = new HeuristicApplier(spark)
+ val rule = newRuleApplier(
+ spark,
+ List(
+ _ =>
+ _ => {
+
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer()))))
+ },
+ c => InsertTransitions(c.outputsColumnar)))
.enableAdaptiveContext()
- .withTransformRules(
- List(
- _ =>
- _ => {
-
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer()))))
- },
- (_: SparkSession) => InsertTransitions(outputsColumnar = false)))
- val outputPlan = rule.apply(originalPlan)
+ val outputPlan = rule.apply(originalPlan, false)
// Expect to fall back the entire plan.
assert(outputPlan == originalPlan)
}
@@ -106,16 +112,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait
{
"leaf node is transformable)") {
withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold",
"3")) {
val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
- val rule = new HeuristicApplier(spark)
+ val rule = newRuleApplier(
+ spark,
+ List(
+ _ =>
+ _ => {
+
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer()))))
+ },
+ c => InsertTransitions(c.outputsColumnar)))
.enableAdaptiveContext()
- .withTransformRules(
- List(
- _ =>
- _ => {
-
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer()))))
- },
- (_: SparkSession) => InsertTransitions(outputsColumnar = false)))
- val outputPlan = rule.apply(originalPlan)
+ val outputPlan = rule.apply(originalPlan, false)
// Expect to get the plan with columnar rule applied.
assert(outputPlan != originalPlan)
}
@@ -153,43 +159,60 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait
{
}
}
-case class LeafOp(override val supportsColumnar: Boolean = false) extends
LeafExecNode {
- override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
- override def output: Seq[Attribute] = Seq.empty
-}
+private object FallbackStrategiesSuite {
+ def newRuleApplier(
+ spark: SparkSession,
+ transformBuilders: Seq[ColumnarRuleBuilder]): HeuristicApplier = {
+ new HeuristicApplier(
+ spark,
+ transformBuilders,
+ List(c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(),
c.ac.originalPlan())),
+ List(
+ c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()),
+ _ => ColumnarCollapseTransformStages(GlutenConfig.getConf)
+ ),
+ List(_ => RemoveFallbackTagRule())
+ )
+ }
-case class UnaryOp1(child: SparkPlan, override val supportsColumnar: Boolean =
false)
- extends UnaryExecNode {
- override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
- override def output: Seq[Attribute] = child.output
- override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1 =
- copy(child = newChild)
-}
+ case class LeafOp(override val supportsColumnar: Boolean = false) extends
LeafExecNode {
+ override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
+ override def output: Seq[Attribute] = Seq.empty
+ }
-case class UnaryOp2(child: SparkPlan, override val supportsColumnar: Boolean =
false)
- extends UnaryExecNode {
- override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
- override def output: Seq[Attribute] = child.output
- override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp2 =
- copy(child = newChild)
-}
+ case class UnaryOp1(child: SparkPlan, override val supportsColumnar: Boolean
= false)
+ extends UnaryExecNode {
+ override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
+ override def output: Seq[Attribute] = child.output
+ override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1
=
+ copy(child = newChild)
+ }
+
+ case class UnaryOp2(child: SparkPlan, override val supportsColumnar: Boolean
= false)
+ extends UnaryExecNode {
+ override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
+ override def output: Seq[Attribute] = child.output
+ override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp2
=
+ copy(child = newChild)
+ }
// For replacing LeafOp.
-case class LeafOpTransformer(override val supportsColumnar: Boolean = true)
- extends LeafExecNode
- with GlutenPlan {
- override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
- override def output: Seq[Attribute] = Seq.empty
-}
+ case class LeafOpTransformer(override val supportsColumnar: Boolean = true)
+ extends LeafExecNode
+ with GlutenPlan {
+ override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
+ override def output: Seq[Attribute] = Seq.empty
+ }
// For replacing UnaryOp1.
-case class UnaryOp1Transformer(
- override val child: SparkPlan,
- override val supportsColumnar: Boolean = true)
- extends UnaryExecNode
- with GlutenPlan {
- override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
- override def output: Seq[Attribute] = child.output
- override protected def withNewChildInternal(newChild: SparkPlan):
UnaryOp1Transformer =
- copy(child = newChild)
+ case class UnaryOp1Transformer(
+ override val child: SparkPlan,
+ override val supportsColumnar: Boolean = true)
+ extends UnaryExecNode
+ with GlutenPlan {
+ override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
+ override def output: Seq[Attribute] = child.output
+ override protected def withNewChildInternal(newChild: SparkPlan):
UnaryOp1Transformer =
+ copy(child = newChild)
+ }
}
diff --git
a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala
b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala
index 681653409..2ca7429f1 100644
---
a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala
+++
b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala
@@ -31,7 +31,8 @@ class GlutenSessionExtensionSuite extends GlutenSQLTestsTrait
{
}
testGluten("test gluten extensions") {
-
assert(spark.sessionState.columnarRules.contains(ColumnarOverrideRules(spark)))
+ assert(
+
spark.sessionState.columnarRules.map(_.getClass).contains(classOf[ColumnarOverrideRules]))
assert(spark.sessionState.planner.strategies.contains(MySparkStrategy(spark)))
assert(spark.sessionState.analyzer.extendedResolutionRules.contains(MyRule(spark)))
diff --git
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
index 54d7596b6..1ce0025f2 100644
---
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
+++
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
@@ -16,10 +16,13 @@
*/
package org.apache.spark.sql.execution
+import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.execution.BasicScanExecTransformer
import org.apache.gluten.extension.GlutenPlan
-import org.apache.gluten.extension.columnar.{FallbackEmptySchemaRelation,
FallbackTags}
+import org.apache.gluten.extension.columnar.{ExpandFallbackPolicy,
FallbackEmptySchemaRelation, FallbackTags, RemoveFallbackTagRule}
+import
org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleBuilder
+import
org.apache.gluten.extension.columnar.MiscColumnarRules.RemoveTopmostColumnarToRow
import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier
import org.apache.gluten.extension.columnar.transition.InsertTransitions
import org.apache.gluten.utils.QueryPlanSelector
@@ -30,17 +33,19 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
class FallbackStrategiesSuite extends GlutenSQLTestsTrait {
+ import FallbackStrategiesSuite._
testGluten("Fall back the whole query if one unsupported") {
withSQLConf(("spark.gluten.sql.columnar.query.fallback.threshold", "1")) {
val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
- val rule = new HeuristicApplier(spark).withTransformRules(
+ val rule = newRuleApplier(
+ spark,
List(
_ =>
_ => {
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp()))))
},
- (_: SparkSession) => InsertTransitions(outputsColumnar = false)))
- val outputPlan = rule.apply(originalPlan)
+ c => InsertTransitions(c.outputsColumnar)))
+ val outputPlan = rule.apply(originalPlan, false)
// Expect to fall back the entire plan.
assert(outputPlan == originalPlan)
}
@@ -49,16 +54,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait {
testGluten("Fall back the whole plan if meeting the configured threshold") {
withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold",
"1")) {
val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
- val rule = new HeuristicApplier(spark)
+ val rule = newRuleApplier(
+ spark,
+ List(
+ _ =>
+ _ => {
+
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp()))))
+ },
+ c => InsertTransitions(c.outputsColumnar)))
.enableAdaptiveContext()
- .withTransformRules(
- List(
- _ =>
- _ => {
-
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp()))))
- },
- (_: SparkSession) => InsertTransitions(outputsColumnar = false)))
- val outputPlan = rule.apply(originalPlan)
+ val outputPlan = rule.apply(originalPlan, false)
// Expect to fall back the entire plan.
assert(outputPlan == originalPlan)
}
@@ -67,16 +72,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait {
testGluten("Don't fall back the whole plan if NOT meeting the configured
threshold") {
withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold",
"4")) {
val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
- val rule = new HeuristicApplier(spark)
+ val rule = newRuleApplier(
+ spark,
+ List(
+ _ =>
+ _ => {
+
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp()))))
+ },
+ c => InsertTransitions(c.outputsColumnar)))
.enableAdaptiveContext()
- .withTransformRules(
- List(
- _ =>
- _ => {
-
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp()))))
- },
- (_: SparkSession) => InsertTransitions(outputsColumnar = false)))
- val outputPlan = rule.apply(originalPlan)
+ val outputPlan = rule.apply(originalPlan, false)
// Expect to get the plan with columnar rule applied.
assert(outputPlan != originalPlan)
}
@@ -87,16 +92,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait {
" transformable)") {
withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold",
"2")) {
val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
- val rule = new HeuristicApplier(spark)
+ val rule = newRuleApplier(
+ spark,
+ List(
+ _ =>
+ _ => {
+
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer()))))
+ },
+ c => InsertTransitions(c.outputsColumnar)))
.enableAdaptiveContext()
- .withTransformRules(
- List(
- _ =>
- _ => {
-
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer()))))
- },
- (_: SparkSession) => InsertTransitions(outputsColumnar = false)))
- val outputPlan = rule.apply(originalPlan)
+ val outputPlan = rule.apply(originalPlan, false)
// Expect to fall back the entire plan.
assert(outputPlan == originalPlan)
}
@@ -107,16 +112,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait
{
"leaf node is transformable)") {
withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold",
"3")) {
val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
- val rule = new HeuristicApplier(spark)
+ val rule = newRuleApplier(
+ spark,
+ List(
+ _ =>
+ _ => {
+
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer()))))
+ },
+ c => InsertTransitions(c.outputsColumnar)))
.enableAdaptiveContext()
- .withTransformRules(
- List(
- _ =>
- _ => {
-
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer()))))
- },
- (_: SparkSession) => InsertTransitions(outputsColumnar = false)))
- val outputPlan = rule.apply(originalPlan)
+ val outputPlan = rule.apply(originalPlan, false)
// Expect to get the plan with columnar rule applied.
assert(outputPlan != originalPlan)
}
@@ -168,44 +173,60 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait
{
thread.join(10000)
}
}
+private object FallbackStrategiesSuite {
+ def newRuleApplier(
+ spark: SparkSession,
+ transformBuilders: Seq[ColumnarRuleBuilder]): HeuristicApplier = {
+ new HeuristicApplier(
+ spark,
+ transformBuilders,
+ List(c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(),
c.ac.originalPlan())),
+ List(
+ c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()),
+ _ => ColumnarCollapseTransformStages(GlutenConfig.getConf)
+ ),
+ List(_ => RemoveFallbackTagRule())
+ )
+ }
-case class LeafOp(override val supportsColumnar: Boolean = false) extends
LeafExecNode {
- override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
- override def output: Seq[Attribute] = Seq.empty
-}
+ case class LeafOp(override val supportsColumnar: Boolean = false) extends
LeafExecNode {
+ override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
+ override def output: Seq[Attribute] = Seq.empty
+ }
-case class UnaryOp1(child: SparkPlan, override val supportsColumnar: Boolean =
false)
- extends UnaryExecNode {
- override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
- override def output: Seq[Attribute] = child.output
- override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1 =
- copy(child = newChild)
-}
+ case class UnaryOp1(child: SparkPlan, override val supportsColumnar: Boolean
= false)
+ extends UnaryExecNode {
+ override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
+ override def output: Seq[Attribute] = child.output
+ override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1
=
+ copy(child = newChild)
+ }
-case class UnaryOp2(child: SparkPlan, override val supportsColumnar: Boolean =
false)
- extends UnaryExecNode {
- override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
- override def output: Seq[Attribute] = child.output
- override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp2 =
- copy(child = newChild)
-}
+ case class UnaryOp2(child: SparkPlan, override val supportsColumnar: Boolean
= false)
+ extends UnaryExecNode {
+ override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
+ override def output: Seq[Attribute] = child.output
+ override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp2
=
+ copy(child = newChild)
+ }
-// For replacing LeafOp.
-case class LeafOpTransformer(override val supportsColumnar: Boolean = true)
- extends LeafExecNode
- with GlutenPlan {
- override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
- override def output: Seq[Attribute] = Seq.empty
-}
+ // For replacing LeafOp.
+ case class LeafOpTransformer(override val supportsColumnar: Boolean = true)
+ extends LeafExecNode
+ with GlutenPlan {
+ override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
+ override def output: Seq[Attribute] = Seq.empty
+ }
-// For replacing UnaryOp1.
-case class UnaryOp1Transformer(
- override val child: SparkPlan,
- override val supportsColumnar: Boolean = true)
- extends UnaryExecNode
- with GlutenPlan {
- override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
- override def output: Seq[Attribute] = child.output
- override protected def withNewChildInternal(newChild: SparkPlan):
UnaryOp1Transformer =
- copy(child = newChild)
+ // For replacing UnaryOp1.
+ case class UnaryOp1Transformer(
+ override val child: SparkPlan,
+ override val supportsColumnar: Boolean = true)
+ extends UnaryExecNode
+ with GlutenPlan {
+ override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
+ override def output: Seq[Attribute] = child.output
+ override protected def withNewChildInternal(newChild: SparkPlan):
UnaryOp1Transformer =
+ copy(child = newChild)
+ }
}
diff --git
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala
index 681653409..2ca7429f1 100644
---
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala
+++
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala
@@ -31,7 +31,8 @@ class GlutenSessionExtensionSuite extends GlutenSQLTestsTrait
{
}
testGluten("test gluten extensions") {
-
assert(spark.sessionState.columnarRules.contains(ColumnarOverrideRules(spark)))
+ assert(
+
spark.sessionState.columnarRules.map(_.getClass).contains(classOf[ColumnarOverrideRules]))
assert(spark.sessionState.planner.strategies.contains(MySparkStrategy(spark)))
assert(spark.sessionState.analyzer.extendedResolutionRules.contains(MyRule(spark)))
diff --git
a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
index 5150a4768..3acc9c4b3 100644
---
a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
+++
b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
@@ -16,10 +16,13 @@
*/
package org.apache.spark.sql.execution
+import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.execution.BasicScanExecTransformer
import org.apache.gluten.extension.GlutenPlan
-import org.apache.gluten.extension.columnar.{FallbackEmptySchemaRelation,
FallbackTags}
+import org.apache.gluten.extension.columnar.{ExpandFallbackPolicy,
FallbackEmptySchemaRelation, FallbackTags, RemoveFallbackTagRule}
+import
org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleBuilder
+import
org.apache.gluten.extension.columnar.MiscColumnarRules.RemoveTopmostColumnarToRow
import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier
import org.apache.gluten.extension.columnar.transition.InsertTransitions
import org.apache.gluten.utils.QueryPlanSelector
@@ -30,18 +33,19 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
class FallbackStrategiesSuite extends GlutenSQLTestsTrait {
-
+ import FallbackStrategiesSuite._
testGluten("Fall back the whole query if one unsupported") {
withSQLConf(("spark.gluten.sql.columnar.query.fallback.threshold", "1")) {
val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
- val rule = new HeuristicApplier(spark).withTransformRules(
+ val rule = newRuleApplier(
+ spark,
List(
_ =>
_ => {
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp()))))
},
- (_: SparkSession) => InsertTransitions(outputsColumnar = false)))
- val outputPlan = rule.apply(originalPlan)
+ c => InsertTransitions(c.outputsColumnar)))
+ val outputPlan = rule.apply(originalPlan, false)
// Expect to fall back the entire plan.
assert(outputPlan == originalPlan)
}
@@ -50,16 +54,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait {
testGluten("Fall back the whole plan if meeting the configured threshold") {
withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold",
"1")) {
val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
- val rule = new HeuristicApplier(spark)
+ val rule = newRuleApplier(
+ spark,
+ List(
+ _ =>
+ _ => {
+
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp()))))
+ },
+ c => InsertTransitions(c.outputsColumnar)))
.enableAdaptiveContext()
- .withTransformRules(
- List(
- _ =>
- _ => {
-
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp()))))
- },
- (_: SparkSession) => InsertTransitions(outputsColumnar = false)))
- val outputPlan = rule.apply(originalPlan)
+ val outputPlan = rule.apply(originalPlan, false)
// Expect to fall back the entire plan.
assert(outputPlan == originalPlan)
}
@@ -68,16 +72,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait {
testGluten("Don't fall back the whole plan if NOT meeting the configured
threshold") {
withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold",
"4")) {
val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
- val rule = new HeuristicApplier(spark)
+ val rule = newRuleApplier(
+ spark,
+ List(
+ _ =>
+ _ => {
+
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp()))))
+ },
+ c => InsertTransitions(c.outputsColumnar)))
.enableAdaptiveContext()
- .withTransformRules(
- List(
- _ =>
- _ => {
-
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp()))))
- },
- (_: SparkSession) => InsertTransitions(outputsColumnar = false)))
- val outputPlan = rule.apply(originalPlan)
+ val outputPlan = rule.apply(originalPlan, false)
// Expect to get the plan with columnar rule applied.
assert(outputPlan != originalPlan)
}
@@ -88,16 +92,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait {
" transformable)") {
withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold",
"2")) {
val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
- val rule = new HeuristicApplier(spark)
+ val rule = newRuleApplier(
+ spark,
+ List(
+ _ =>
+ _ => {
+
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer()))))
+ },
+ c => InsertTransitions(c.outputsColumnar)))
.enableAdaptiveContext()
- .withTransformRules(
- List(
- _ =>
- _ => {
-
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer()))))
- },
- (_: SparkSession) => InsertTransitions(outputsColumnar = false)))
- val outputPlan = rule.apply(originalPlan)
+ val outputPlan = rule.apply(originalPlan, false)
// Expect to fall back the entire plan.
assert(outputPlan == originalPlan)
}
@@ -108,16 +112,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait
{
"leaf node is transformable)") {
withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold",
"3")) {
val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
- val rule = new HeuristicApplier(spark)
+ val rule = newRuleApplier(
+ spark,
+ List(
+ _ =>
+ _ => {
+
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer()))))
+ },
+ c => InsertTransitions(c.outputsColumnar)))
.enableAdaptiveContext()
- .withTransformRules(
- List(
- _ =>
- _ => {
-
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer()))))
- },
- (_: SparkSession) => InsertTransitions(outputsColumnar = false)))
- val outputPlan = rule.apply(originalPlan)
+ val outputPlan = rule.apply(originalPlan, false)
// Expect to get the plan with columnar rule applied.
assert(outputPlan != originalPlan)
}
@@ -170,43 +174,60 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait
{
}
}
-case class LeafOp(override val supportsColumnar: Boolean = false) extends
LeafExecNode {
- override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
- override def output: Seq[Attribute] = Seq.empty
-}
+private object FallbackStrategiesSuite {
+ def newRuleApplier(
+ spark: SparkSession,
+ transformBuilders: Seq[ColumnarRuleBuilder]): HeuristicApplier = {
+ new HeuristicApplier(
+ spark,
+ transformBuilders,
+ List(c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(),
c.ac.originalPlan())),
+ List(
+ c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()),
+ _ => ColumnarCollapseTransformStages(GlutenConfig.getConf)
+ ),
+ List(_ => RemoveFallbackTagRule())
+ )
+ }
-case class UnaryOp1(child: SparkPlan, override val supportsColumnar: Boolean =
false)
- extends UnaryExecNode {
- override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
- override def output: Seq[Attribute] = child.output
- override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1 =
- copy(child = newChild)
-}
+ case class LeafOp(override val supportsColumnar: Boolean = false) extends
LeafExecNode {
+ override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
+ override def output: Seq[Attribute] = Seq.empty
+ }
-case class UnaryOp2(child: SparkPlan, override val supportsColumnar: Boolean =
false)
- extends UnaryExecNode {
- override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
- override def output: Seq[Attribute] = child.output
- override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp2 =
- copy(child = newChild)
-}
+ case class UnaryOp1(child: SparkPlan, override val supportsColumnar: Boolean
= false)
+ extends UnaryExecNode {
+ override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
+ override def output: Seq[Attribute] = child.output
+ override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1
=
+ copy(child = newChild)
+ }
-// For replacing LeafOp.
-case class LeafOpTransformer(override val supportsColumnar: Boolean = true)
- extends LeafExecNode
- with GlutenPlan {
- override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
- override def output: Seq[Attribute] = Seq.empty
-}
+ case class UnaryOp2(child: SparkPlan, override val supportsColumnar: Boolean
= false)
+ extends UnaryExecNode {
+ override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
+ override def output: Seq[Attribute] = child.output
+ override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp2
=
+ copy(child = newChild)
+ }
-// For replacing UnaryOp1.
-case class UnaryOp1Transformer(
- override val child: SparkPlan,
- override val supportsColumnar: Boolean = true)
- extends UnaryExecNode
- with GlutenPlan {
- override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
- override def output: Seq[Attribute] = child.output
- override protected def withNewChildInternal(newChild: SparkPlan):
UnaryOp1Transformer =
- copy(child = newChild)
+ // For replacing LeafOp.
+ case class LeafOpTransformer(override val supportsColumnar: Boolean = true)
+ extends LeafExecNode
+ with GlutenPlan {
+ override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
+ override def output: Seq[Attribute] = Seq.empty
+ }
+
+ // For replacing UnaryOp1.
+ case class UnaryOp1Transformer(
+ override val child: SparkPlan,
+ override val supportsColumnar: Boolean = true)
+ extends UnaryExecNode
+ with GlutenPlan {
+ override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
+ override def output: Seq[Attribute] = child.output
+ override protected def withNewChildInternal(newChild: SparkPlan):
UnaryOp1Transformer =
+ copy(child = newChild)
+ }
}
diff --git
a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala
b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala
index 681653409..2ca7429f1 100644
---
a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala
+++
b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala
@@ -31,7 +31,8 @@ class GlutenSessionExtensionSuite extends GlutenSQLTestsTrait
{
}
testGluten("test gluten extensions") {
-
assert(spark.sessionState.columnarRules.contains(ColumnarOverrideRules(spark)))
+ assert(
+
spark.sessionState.columnarRules.map(_.getClass).contains(classOf[ColumnarOverrideRules]))
assert(spark.sessionState.planner.strategies.contains(MySparkStrategy(spark)))
assert(spark.sessionState.analyzer.extendedResolutionRules.contains(MyRule(spark)))
diff --git
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
index 5150a4768..bcc4e829b 100644
---
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
+++
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/execution/FallbackStrategiesSuite.scala
@@ -16,10 +16,13 @@
*/
package org.apache.spark.sql.execution
+import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.execution.BasicScanExecTransformer
import org.apache.gluten.extension.GlutenPlan
-import org.apache.gluten.extension.columnar.{FallbackEmptySchemaRelation,
FallbackTags}
+import org.apache.gluten.extension.columnar.{ExpandFallbackPolicy,
FallbackEmptySchemaRelation, FallbackTags, RemoveFallbackTagRule}
+import
org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleBuilder
+import
org.apache.gluten.extension.columnar.MiscColumnarRules.RemoveTopmostColumnarToRow
import org.apache.gluten.extension.columnar.heuristic.HeuristicApplier
import org.apache.gluten.extension.columnar.transition.InsertTransitions
import org.apache.gluten.utils.QueryPlanSelector
@@ -30,18 +33,20 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
class FallbackStrategiesSuite extends GlutenSQLTestsTrait {
+ import FallbackStrategiesSuite._
testGluten("Fall back the whole query if one unsupported") {
withSQLConf(("spark.gluten.sql.columnar.query.fallback.threshold", "1")) {
val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
- val rule = new HeuristicApplier(spark).withTransformRules(
+ val rule = newRuleApplier(
+ spark,
List(
_ =>
_ => {
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp()))))
},
- (_: SparkSession) => InsertTransitions(outputsColumnar = false)))
- val outputPlan = rule.apply(originalPlan)
+ c => InsertTransitions(c.outputsColumnar)))
+ val outputPlan = rule.apply(originalPlan, false)
// Expect to fall back the entire plan.
assert(outputPlan == originalPlan)
}
@@ -50,16 +55,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait {
testGluten("Fall back the whole plan if meeting the configured threshold") {
withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold",
"1")) {
val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
- val rule = new HeuristicApplier(spark)
+ val rule = newRuleApplier(
+ spark,
+ List(
+ _ =>
+ _ => {
+
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp()))))
+ },
+ c => InsertTransitions(c.outputsColumnar)))
.enableAdaptiveContext()
- .withTransformRules(
- List(
- _ =>
- _ => {
-
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp()))))
- },
- (_: SparkSession) => InsertTransitions(outputsColumnar = false)))
- val outputPlan = rule.apply(originalPlan)
+ val outputPlan = rule.apply(originalPlan, false)
// Expect to fall back the entire plan.
assert(outputPlan == originalPlan)
}
@@ -68,16 +73,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait {
testGluten("Don't fall back the whole plan if NOT meeting the configured
threshold") {
withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold",
"4")) {
val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
- val rule = new HeuristicApplier(spark)
+ val rule = newRuleApplier(
+ spark,
+ List(
+ _ =>
+ _ => {
+
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp()))))
+ },
+ c => InsertTransitions(c.outputsColumnar)))
.enableAdaptiveContext()
- .withTransformRules(
- List(
- _ =>
- _ => {
-
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOp()))))
- },
- (_: SparkSession) => InsertTransitions(outputsColumnar = false)))
- val outputPlan = rule.apply(originalPlan)
+ val outputPlan = rule.apply(originalPlan, false)
// Expect to get the plan with columnar rule applied.
assert(outputPlan != originalPlan)
}
@@ -88,16 +93,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait {
" transformable)") {
withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold",
"2")) {
val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
- val rule = new HeuristicApplier(spark)
+ val rule = newRuleApplier(
+ spark,
+ List(
+ _ =>
+ _ => {
+
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer()))))
+ },
+ c => InsertTransitions(c.outputsColumnar)))
.enableAdaptiveContext()
- .withTransformRules(
- List(
- _ =>
- _ => {
-
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer()))))
- },
- (_: SparkSession) => InsertTransitions(outputsColumnar = false)))
- val outputPlan = rule.apply(originalPlan)
+ val outputPlan = rule.apply(originalPlan, false)
// Expect to fall back the entire plan.
assert(outputPlan == originalPlan)
}
@@ -108,16 +113,16 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait
{
"leaf node is transformable)") {
withSQLConf(("spark.gluten.sql.columnar.wholeStage.fallback.threshold",
"3")) {
val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
- val rule = new HeuristicApplier(spark)
+ val rule = newRuleApplier(
+ spark,
+ List(
+ _ =>
+ _ => {
+
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer()))))
+ },
+ c => InsertTransitions(c.outputsColumnar)))
.enableAdaptiveContext()
- .withTransformRules(
- List(
- _ =>
- _ => {
-
UnaryOp2(UnaryOp1Transformer(UnaryOp2(UnaryOp1Transformer(LeafOpTransformer()))))
- },
- (_: SparkSession) => InsertTransitions(outputsColumnar = false)))
- val outputPlan = rule.apply(originalPlan)
+ val outputPlan = rule.apply(originalPlan, false)
// Expect to get the plan with columnar rule applied.
assert(outputPlan != originalPlan)
}
@@ -170,43 +175,60 @@ class FallbackStrategiesSuite extends GlutenSQLTestsTrait
{
}
}
-case class LeafOp(override val supportsColumnar: Boolean = false) extends
LeafExecNode {
- override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
- override def output: Seq[Attribute] = Seq.empty
-}
+private object FallbackStrategiesSuite {
+ def newRuleApplier(
+ spark: SparkSession,
+ transformBuilders: Seq[ColumnarRuleBuilder]): HeuristicApplier = {
+ new HeuristicApplier(
+ spark,
+ transformBuilders,
+ List(c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(),
c.ac.originalPlan())),
+ List(
+ c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()),
+ _ => ColumnarCollapseTransformStages(GlutenConfig.getConf)
+ ),
+ List(_ => RemoveFallbackTagRule())
+ )
+ }
-case class UnaryOp1(child: SparkPlan, override val supportsColumnar: Boolean =
false)
- extends UnaryExecNode {
- override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
- override def output: Seq[Attribute] = child.output
- override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1 =
- copy(child = newChild)
-}
+ case class LeafOp(override val supportsColumnar: Boolean = false) extends
LeafExecNode {
+ override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
+ override def output: Seq[Attribute] = Seq.empty
+ }
-case class UnaryOp2(child: SparkPlan, override val supportsColumnar: Boolean =
false)
- extends UnaryExecNode {
- override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
- override def output: Seq[Attribute] = child.output
- override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp2 =
- copy(child = newChild)
-}
+ case class UnaryOp1(child: SparkPlan, override val supportsColumnar: Boolean
= false)
+ extends UnaryExecNode {
+ override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
+ override def output: Seq[Attribute] = child.output
+ override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1
=
+ copy(child = newChild)
+ }
+
+ case class UnaryOp2(child: SparkPlan, override val supportsColumnar: Boolean
= false)
+ extends UnaryExecNode {
+ override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
+ override def output: Seq[Attribute] = child.output
+ override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp2
=
+ copy(child = newChild)
+ }
// For replacing LeafOp.
-case class LeafOpTransformer(override val supportsColumnar: Boolean = true)
- extends LeafExecNode
- with GlutenPlan {
- override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
- override def output: Seq[Attribute] = Seq.empty
-}
+ case class LeafOpTransformer(override val supportsColumnar: Boolean = true)
+ extends LeafExecNode
+ with GlutenPlan {
+ override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
+ override def output: Seq[Attribute] = Seq.empty
+ }
// For replacing UnaryOp1.
-case class UnaryOp1Transformer(
- override val child: SparkPlan,
- override val supportsColumnar: Boolean = true)
- extends UnaryExecNode
- with GlutenPlan {
- override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
- override def output: Seq[Attribute] = child.output
- override protected def withNewChildInternal(newChild: SparkPlan):
UnaryOp1Transformer =
- copy(child = newChild)
+ case class UnaryOp1Transformer(
+ override val child: SparkPlan,
+ override val supportsColumnar: Boolean = true)
+ extends UnaryExecNode
+ with GlutenPlan {
+ override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
+ override def output: Seq[Attribute] = child.output
+ override protected def withNewChildInternal(newChild: SparkPlan):
UnaryOp1Transformer =
+ copy(child = newChild)
+ }
}
diff --git
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala
index 681653409..2ca7429f1 100644
---
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala
+++
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/extension/GlutenSessionExtensionSuite.scala
@@ -31,7 +31,8 @@ class GlutenSessionExtensionSuite extends GlutenSQLTestsTrait
{
}
testGluten("test gluten extensions") {
-
assert(spark.sessionState.columnarRules.contains(ColumnarOverrideRules(spark)))
+ assert(
+
spark.sessionState.columnarRules.map(_.getClass).contains(classOf[ColumnarOverrideRules]))
assert(spark.sessionState.planner.strategies.contains(MySparkStrategy(spark)))
assert(spark.sessionState.analyzer.extendedResolutionRules.contains(MyRule(spark)))
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]