This is an automated email from the ASF dual-hosted git repository.
lixiao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 44d370d [SPARK-31475][SQL] Broadcast stage in AQE did not timeout
44d370d is described below
commit 44d370dd4501f0a4abb7194f7cff0d346aac0992
Author: Maryann Xue <[email protected]>
AuthorDate: Mon Apr 20 11:55:48 2020 -0700
[SPARK-31475][SQL] Broadcast stage in AQE did not timeout
### What changes were proposed in this pull request?
This PR adds a timeout to the Future of a BroadcastQueryStageExec so that it
has the same timeout behavior as a non-AQE broadcast exchange.
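The implementation races the broadcast's completion Future against a Promise
that a scheduled timer fails once spark.sql.broadcastTimeout elapses; if the
broadcast finishes first, the pending timer task is cancelled. A minimal
standalone sketch of that race pattern, using a plain JDK scheduler in place of
Spark's ThreadUtils helpers (TimeoutRaceSketch and withTimeout are illustrative
names, not part of this patch):

    import java.util.concurrent.{Executors, TimeUnit}

    import scala.concurrent.{Await, Future, Promise}
    import scala.concurrent.ExecutionContext.Implicits.global
    import scala.concurrent.duration.Duration

    object TimeoutRaceSketch {
      // Single-threaded timer, standing in for the daemon "BroadcastStageTimeout"
      // executor this patch adds to the BroadcastQueryStageExec companion object.
      private val scheduler = Executors.newSingleThreadScheduledExecutor()

      // Fails the returned future if `work` does not finish within `timeoutSecs`.
      def withTimeout[T](work: Future[T], timeoutSecs: Long): Future[T] = {
        val promise = Promise[T]()
        // Schedule the failure; tryFailure is a no-op if the race is already decided.
        val timeoutTask = scheduler.schedule(new Runnable {
          override def run(): Unit = promise.tryFailure(
            new RuntimeException(s"Could not complete in $timeoutSecs secs."))
        }, timeoutSecs, TimeUnit.SECONDS)
        // If the work finishes first, cancel the pending timer task.
        work.onComplete(_ => timeoutTask.cancel(false))
        // Whichever side completes first decides the outcome.
        Future.firstCompletedOf(Seq(work, promise.future))
      }

      def main(args: Array[String]): Unit = {
        val slow = Future { Thread.sleep(5000); 42 }
        // Prints Some(Failure(java.lang.RuntimeException: ...)) after about 1 second.
        println(Await.ready(withTimeout(slow, 1), Duration.Inf).value)
        scheduler.shutdown()
      }
    }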
### Why are the changes needed?
This is to make the broadcast timeout behavior in AQE consistent with that
in non-AQE.
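For reference, the two knobs named in the timeout error message are the usual
ones (example values; set them via spark.conf or --conf):

    // Raise the broadcast timeout (in seconds; the default is 300).
    spark.conf.set("spark.sql.broadcastTimeout", 600)
    // Or avoid broadcast joins entirely by disabling the auto-broadcast threshold.
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)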
### Does this PR introduce any user-facing change?
No.
### How was this patch tested?
Added UT.
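For anyone who wants to see the behavior interactively, a rough spark-shell
repro along the lines of the new test (values and names illustrative, not from
this patch):

    import org.apache.spark.sql.functions.{broadcast, udf}

    spark.conf.set("spark.sql.adaptive.enabled", true)
    spark.conf.set("spark.sql.broadcastTimeout", 5)
    // A deliberately slow UDF keeps the broadcast side from materializing in time.
    val slow = udf { x: Int => Thread.sleep(60000); x }
    val small = spark.range(5).select(slow($"id".cast("int")) as "a")
    val big = spark.range(10).select($"id" as "a")
    // With this patch, fails after ~5 seconds with "Could not execute broadcast in 5 secs."
    big.join(broadcast(small), "a").collect()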
Closes #28250 from maryannxue/aqe-broadcast-timeout.
Authored-by: Maryann Xue <[email protected]>
Signed-off-by: gatorsmile <[email protected]>
---
.../execution/adaptive/AdaptiveSparkPlanExec.scala | 2 +-
.../sql/execution/adaptive/QueryStageExec.scala | 35 ++++++++++++++++++----
.../execution/exchange/BroadcastExchangeExec.scala | 8 ++---
.../sql/execution/joins/BroadcastJoinSuite.scala | 23 ++++++++++++--
4 files changed, 56 insertions(+), 12 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 3ac4ea5..f819937 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -547,7 +547,7 @@ case class AdaptiveSparkPlanExec(
}
object AdaptiveSparkPlanExec {
- private val executionContext = ExecutionContext.fromExecutorService(
+  private[adaptive] val executionContext = ExecutionContext.fromExecutorService(
ThreadUtils.newDaemonCachedThreadPool("QueryStageCreator", 16))
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
index beaa972..f414f85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala
@@ -17,9 +17,11 @@
package org.apache.spark.sql.execution.adaptive
-import scala.concurrent.Future
+import java.util.concurrent.TimeUnit
-import org.apache.spark.{FutureAction, MapOutputStatistics}
+import scala.concurrent.{Future, Promise}
+
+import org.apache.spark.{FutureAction, MapOutputStatistics, SparkException}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -28,6 +30,8 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.ThreadUtils
/**
* A query stage is an independent subgraph of the query plan. Query stage materializes its output
@@ -100,8 +104,8 @@ abstract class QueryStageExec extends LeafExecNode {
override def executeTail(n: Int): Array[InternalRow] = plan.executeTail(n)
override def executeToIterator(): Iterator[InternalRow] = plan.executeToIterator()
- override def doPrepare(): Unit = plan.prepare()
- override def doExecute(): RDD[InternalRow] = plan.execute()
+ protected override def doPrepare(): Unit = plan.prepare()
+ protected override def doExecute(): RDD[InternalRow] = plan.execute()
override def doExecuteBroadcast[T](): Broadcast[T] = plan.executeBroadcast()
override def doCanonicalize(): SparkPlan = plan.canonicalized
@@ -187,8 +191,24 @@ case class BroadcastQueryStageExec(
throw new IllegalStateException("wrong plan for broadcast stage:\n " + plan.treeString)
}
+  @transient private lazy val materializeWithTimeout = {
+    val broadcastFuture = broadcast.completionFuture
+    val timeout = SQLConf.get.broadcastTimeout
+    val promise = Promise[Any]()
+    val fail = BroadcastQueryStageExec.scheduledExecutor.schedule(new Runnable() {
+      override def run(): Unit = {
+        promise.tryFailure(new SparkException(s"Could not execute broadcast in $timeout secs. " +
+          s"You can increase the timeout for broadcasts via ${SQLConf.BROADCAST_TIMEOUT.key} or " +
+          s"disable broadcast join by setting ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1"))
+      }
+    }, timeout, TimeUnit.SECONDS)
+    broadcastFuture.onComplete(_ => fail.cancel(false))(AdaptiveSparkPlanExec.executionContext)
+    Future.firstCompletedOf(
+      Seq(broadcastFuture, promise.future))(AdaptiveSparkPlanExec.executionContext)
+  }
+
override def doMaterialize(): Future[Any] = {
- broadcast.completionFuture
+ materializeWithTimeout
}
override def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec = {
@@ -204,3 +224,8 @@ case class BroadcastQueryStageExec(
}
}
}
+
+object BroadcastQueryStageExec {
+  private val scheduledExecutor =
+    ThreadUtils.newDaemonSingleThreadScheduledExecutor("BroadcastStageTimeout")
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index f69da86..d35bbe9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -120,7 +120,7 @@ case class BroadcastExchangeExec(
System.nanoTime() - beforeBroadcast)
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
- promise.success(broadcasted)
+ promise.trySuccess(broadcasted)
broadcasted
} catch {
// SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw
@@ -133,14 +133,14 @@ case class BroadcastExchangeExec(
s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or
increase the spark " +
s"driver memory by setting ${SparkLauncher.DRIVER_MEMORY} to
a higher value.")
.initCause(oe.getCause))
- promise.failure(ex)
+ promise.tryFailure(ex)
throw ex
case e if !NonFatal(e) =>
val ex = new SparkFatalException(e)
- promise.failure(ex)
+ promise.tryFailure(ex)
throw ex
case e: Throwable =>
- promise.failure(e)
+ promise.tryFailure(e)
throw e
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 5ce758e..64ecf5d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft}
import org.apache.spark.sql.catalyst.plans.logical.BROADCAST
import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec}
-import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AdaptiveTestUtils, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.functions._
@@ -39,7 +39,8 @@ import org.apache.spark.sql.types.{LongType, ShortType}
* unsafe map in [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered
* without serializing the hashed relation, which does not happen in local mode.
*/
-class BroadcastJoinSuite extends QueryTest with SQLTestUtils with AdaptiveSparkPlanHelper {
+abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
+  with AdaptiveSparkPlanHelper {
import testImplicits._
protected var spark: SparkSession = null
@@ -398,4 +399,22 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils with AdaptiveSparkP
}
}
}
+
+ test("Broadcast timeout") {
+ val timeout = 5
+ val slowUDF = udf({ x: Int => Thread.sleep(timeout * 10 * 1000); x })
+ val df1 = spark.range(10).select($"id" as 'a)
+ val df2 = spark.range(5).select(slowUDF($"id") as 'a)
+ val testDf = df1.join(broadcast(df2), "a")
+ withSQLConf(SQLConf.BROADCAST_TIMEOUT.key -> timeout.toString) {
+ val e = intercept[Exception] {
+ testDf.collect()
+ }
+      AdaptiveTestUtils.assertExceptionMessage(e, s"Could not execute broadcast in $timeout secs.")
+ }
+ }
}
+
+class BroadcastJoinSuite extends BroadcastJoinSuiteBase with DisableAdaptiveExecutionSuite
+
+class BroadcastJoinSuiteAE extends BroadcastJoinSuiteBase with EnableAdaptiveExecutionSuite