This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 883596a4bab [SPARK-38697][SQL] Extend SparkSessionExtensions to inject
rules into AQE Optimizer
883596a4bab is described below
commit 883596a4bab36ddf0e1a5af0ba98325ca8582550
Author: ulysses-you <[email protected]>
AuthorDate: Fri Apr 15 16:02:00 2022 +0800
[SPARK-38697][SQL] Extend SparkSessionExtensions to inject rules into AQE
Optimizer
### What changes were proposed in this pull request?
Add `injectRuntimeOptimizerRule` public method in `SparkSessionExtensions`
### Why are the changes needed?
Provide an entrance for users to tune their logical plan with the runtime
optimizer in the adaptive query execution framework.
We should follow the existing Spark session extension mechanism to allow users
to inject the rule.
So developers can improve the logical plan by leveraging accurate statistics
from shuffle.
### Does this PR introduce _any_ user-facing change?
yes, a new entrance for Spark session extension
### How was this patch tested?
Add test
Closes #36011 from ulysses-you/aqe-optimizer.
Authored-by: ulysses-you <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../apache/spark/sql/SparkSessionExtensions.scala | 22 ++++++++++++++
.../sql/execution/adaptive/AQEOptimizer.scala | 10 ++++---
.../execution/adaptive/AdaptiveRulesHolder.scala | 30 +++++++++++++++++++
.../execution/adaptive/AdaptiveSparkPlanExec.scala | 5 ++--
.../sql/internal/BaseSessionStateBuilder.scala | 11 ++++---
.../apache/spark/sql/internal/SessionState.scala | 4 +--
.../spark/sql/SparkSessionExtensionSuite.scala | 35 ++++++++++++++++++++--
7 files changed, 102 insertions(+), 15 deletions(-)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
index a4ec48142cf..a8ccc39ac47 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -47,6 +47,7 @@ import org.apache.spark.sql.execution.{ColumnarRule,
SparkPlan}
* <li>(External) Catalog listeners.</li>
* <li>Columnar Rules.</li>
* <li>Adaptive Query Stage Preparation Rules.</li>
+ * <li>Adaptive Query Execution Runtime Optimizer Rules.</li>
* </ul>
*
* The extensions can be used by calling `withExtensions` on the
[[SparkSession.Builder]], for
@@ -113,6 +114,7 @@ class SparkSessionExtensions {
private[this] val columnarRuleBuilders =
mutable.Buffer.empty[ColumnarRuleBuilder]
private[this] val queryStagePrepRuleBuilders =
mutable.Buffer.empty[QueryStagePrepRuleBuilder]
+ private[this] val runtimeOptimizerRules = mutable.Buffer.empty[RuleBuilder]
/**
* Build the override rules for columnar execution.
@@ -128,6 +130,13 @@ class SparkSessionExtensions {
queryStagePrepRuleBuilders.map(_.apply(session)).toSeq
}
+ /**
+ * Build the override rules for the optimizer of adaptive query execution.
+ */
+ private[sql] def buildRuntimeOptimizerRules(session: SparkSession):
Seq[Rule[LogicalPlan]] = {
+ runtimeOptimizerRules.map(_.apply(session)).toSeq
+ }
+
/**
* Inject a rule that can override the columnar execution of an executor.
*/
@@ -143,6 +152,19 @@ class SparkSessionExtensions {
queryStagePrepRuleBuilders += builder
}
+ /**
+ * Inject a runtime `Rule` builder into the [[SparkSession]].
+ * The injected rules will be executed after built-in
+ * [[org.apache.spark.sql.execution.adaptive.AQEOptimizer]] rules are
applied.
+ * A runtime optimizer rule is used to improve the quality of a logical plan
during execution
+ * which can leverage accurate statistics from shuffle.
+ *
+ * Note that it does not work if adaptive query execution is disabled.
+ */
+ def injectRuntimeOptimizerRule(builder: RuleBuilder): Unit = {
+ runtimeOptimizerRules += builder
+ }
+
private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]
/**
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala
index 5533bb1cd79..93fde72993e 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.adaptive
import org.apache.spark.sql.catalyst.analysis.UpdateAttributeNullability
import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation,
EliminateLimits, OptimizeOneRowPlan}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan,
LogicalPlanIntegrity, PlanHelper}
-import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.Utils
@@ -28,7 +28,9 @@ import org.apache.spark.util.Utils
/**
* The optimizer for re-optimizing the logical plan used by
AdaptiveSparkPlanExec.
*/
-class AQEOptimizer(conf: SQLConf) extends RuleExecutor[LogicalPlan] {
+class AQEOptimizer(conf: SQLConf, extendedRuntimeOptimizerRules:
Seq[Rule[LogicalPlan]])
+ extends RuleExecutor[LogicalPlan] {
+
private def fixedPoint =
FixedPoint(
conf.optimizerMaxIterations,
@@ -41,8 +43,8 @@ class AQEOptimizer(conf: SQLConf) extends
RuleExecutor[LogicalPlan] {
UpdateAttributeNullability),
Batch("Dynamic Join Selection", Once, DynamicJoinSelection),
Batch("Eliminate Limits", fixedPoint, EliminateLimits),
- Batch("Optimize One Row Plan", fixedPoint, OptimizeOneRowPlan)
- )
+ Batch("Optimize One Row Plan", fixedPoint, OptimizeOneRowPlan)) :+
+ Batch("User Provided Runtime Optimizers", fixedPoint,
extendedRuntimeOptimizerRules: _*)
final override protected def batches: Seq[Batch] = {
val excludedRules = conf.getConf(SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala
new file mode 100644
index 00000000000..11cae824568
--- /dev/null
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+
+/**
+ * A holder to wrap the SQL extension rules of adaptive query execution
+ */
+class AdaptiveRulesHolder(
+ val queryStagePrepRules: Seq[Rule[SparkPlan]],
+ val runtimeOptimizerRules: Seq[Rule[LogicalPlan]]) {
+}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 25380bc1d89..808959363ac 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -84,7 +84,8 @@ case class AdaptiveSparkPlanExec(
@transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]()
// The logical plan optimizer for re-optimizing the current logical plan.
- @transient private val optimizer = new AQEOptimizer(conf)
+ @transient private val optimizer = new AQEOptimizer(conf,
+ session.sessionState.adaptiveRulesHolder.runtimeOptimizerRules)
// `EnsureRequirements` may remove user-specified repartition and assume the
query plan won't
// change its output partitioning. This assumption is not true in AQE. Here
we check the
@@ -121,7 +122,7 @@ case class AdaptiveSparkPlanExec(
RemoveRedundantSorts,
DisableUnnecessaryBucketedScan,
OptimizeSkewedJoin(ensureRequirements)
- ) ++ context.session.sessionState.queryStagePrepRules
+ ) ++ context.session.sessionState.adaptiveRulesHolder.queryStagePrepRules
}
// A list of physical optimizer rules to be applied to a new stage before
its execution. These
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 0655b946cc8..f3cbb789a94 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -27,7 +27,8 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.{ColumnarRule, CommandExecutionMode,
QueryExecution, SparkOptimizer, SparkPlan, SparkPlanner, SparkSqlParser}
+import org.apache.spark.sql.execution.{ColumnarRule, CommandExecutionMode,
QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser}
+import org.apache.spark.sql.execution.adaptive.AdaptiveRulesHolder
import org.apache.spark.sql.execution.aggregate.{ResolveEncodersInScalaAgg,
ScalaUDAF}
import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin
import org.apache.spark.sql.execution.command.CommandCheck
@@ -308,8 +309,10 @@ abstract class BaseSessionStateBuilder(
extensions.buildColumnarRules(session)
}
- protected def queryStagePrepRules: Seq[Rule[SparkPlan]] = {
- extensions.buildQueryStagePrepRules(session)
+ protected def adaptiveRulesHolder: AdaptiveRulesHolder = {
+ new AdaptiveRulesHolder(
+ extensions.buildQueryStagePrepRules(session),
+ extensions.buildRuntimeOptimizerRules(session))
}
/**
@@ -366,7 +369,7 @@ abstract class BaseSessionStateBuilder(
createQueryExecution,
createClone,
columnarRules,
- queryStagePrepRules)
+ adaptiveRulesHolder)
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index cdf764a7317..1d5e61aab26 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -30,9 +30,9 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.adaptive.AdaptiveRulesHolder
import org.apache.spark.sql.streaming.StreamingQueryManager
import org.apache.spark.sql.util.ExecutionListenerManager
import org.apache.spark.util.{DependencyUtils, Utils}
@@ -79,7 +79,7 @@ private[sql] class SessionState(
createQueryExecution: (LogicalPlan, CommandExecutionMode.Value) =>
QueryExecution,
createClone: (SparkSession, SessionState) => SessionState,
val columnarRules: Seq[ColumnarRule],
- val queryStagePrepRules: Seq[Rule[SparkPlan]]) {
+ val adaptiveRulesHolder: AdaptiveRulesHolder) {
// The following fields are lazy to avoid creating the Hive client when
creating SessionState.
lazy val catalog: SessionCatalog = catalogBuilder()
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 17124cc2e4c..1aef458a352 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -27,7 +27,8 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow,
TableIdentifier}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser,
ParserInterface}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation,
LogicalPlan, Statistics, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.catalyst.plans.logical.{Limit, LocalRelation,
LogicalPlan, Statistics, UnresolvedHint}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
@@ -45,7 +46,7 @@ import org.apache.spark.unsafe.types.UTF8String
/**
* Test cases for the [[SparkSessionExtensions]].
*/
-class SparkSessionExtensionSuite extends SparkFunSuite {
+class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper {
private def create(
builder: SparkSessionExtensionsProvider):
Seq[SparkSessionExtensionsProvider] = Seq(builder)
@@ -171,7 +172,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
}
withSession(extensions) { session =>
session.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED, true)
-
assert(session.sessionState.queryStagePrepRules.contains(MyQueryStagePrepRule()))
+ assert(session.sessionState.adaptiveRulesHolder.queryStagePrepRules
+ .contains(MyQueryStagePrepRule()))
assert(session.sessionState.columnarRules.contains(
MyColumnarRule(MyNewQueryStageRule(), MyNewQueryStageRule())))
import session.sqlContext.implicits._
@@ -406,6 +408,26 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
session.sql("SELECT * FROM v")
}
}
+
+ test("SPARK-38697: Extend SparkSessionExtensions to inject rules into AQE
Optimizer") {
+ def executedPlan(df: Dataset[java.lang.Long]): SparkPlan = {
+
assert(df.queryExecution.executedPlan.isInstanceOf[AdaptiveSparkPlanExec])
+
df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+ }
+ val extensions = create { extensions =>
+ extensions.injectRuntimeOptimizerRule(_ => AddLimit)
+ }
+ withSession(extensions) { session =>
+
assert(session.sessionState.adaptiveRulesHolder.runtimeOptimizerRules.contains(AddLimit))
+
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+ val df = session.range(2).repartition()
+ assert(!executedPlan(df).isInstanceOf[CollectLimitExec])
+ df.collect()
+ assert(executedPlan(df).isInstanceOf[CollectLimitExec])
+ }
+ }
+ }
}
case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
@@ -1024,3 +1046,10 @@ class YourExtensions extends
SparkSessionExtensionsProvider {
v1.injectFunction(getAppName)
}
}
+
+object AddLimit extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan match {
+ case Limit(_, _) => plan
+ case _ => Limit(Literal(1), plan)
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]