This is an automated email from the ASF dual-hosted git repository. lixiao pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 44d370d [SPARK-31475][SQL] Broadcast stage in AQE did not timeout 44d370d is described below commit 44d370dd4501f0a4abb7194f7cff0d346aac0992 Author: Maryann Xue <maryann....@gmail.com> AuthorDate: Mon Apr 20 11:55:48 2020 -0700 [SPARK-31475][SQL] Broadcast stage in AQE did not timeout ### What changes were proposed in this pull request? This PR adds a timeout for the Future of a BroadcastQueryStageExec to make sure it can have the same timeout behavior as a non-AQE broadcast exchange. ### Why are the changes needed? This is to make the broadcast timeout behavior in AQE consistent with that in non-AQE. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Added UT. Closes #28250 from maryannxue/aqe-broadcast-timeout. Authored-by: Maryann Xue <maryann....@gmail.com> Signed-off-by: gatorsmile <gatorsm...@gmail.com> --- .../execution/adaptive/AdaptiveSparkPlanExec.scala | 2 +- .../sql/execution/adaptive/QueryStageExec.scala | 35 ++++++++++++++++++---- .../execution/exchange/BroadcastExchangeExec.scala | 8 ++--- .../sql/execution/joins/BroadcastJoinSuite.scala | 23 ++++++++++++-- 4 files changed, 56 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 3ac4ea5..f819937 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -547,7 +547,7 @@ case class AdaptiveSparkPlanExec( } object AdaptiveSparkPlanExec { - private val executionContext = ExecutionContext.fromExecutorService( + private[adaptive] val executionContext = ExecutionContext.fromExecutorService( ThreadUtils.newDaemonCachedThreadPool("QueryStageCreator", 16)) /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala index beaa972..f414f85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.execution.adaptive -import scala.concurrent.Future +import java.util.concurrent.TimeUnit -import org.apache.spark.{FutureAction, MapOutputStatistics} +import scala.concurrent.{Future, Promise} + +import org.apache.spark.{FutureAction, MapOutputStatistics, SparkException} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -28,6 +30,8 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.exchange._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.ThreadUtils /** * A query stage is an independent subgraph of the query plan. Query stage materializes its output @@ -100,8 +104,8 @@ abstract class QueryStageExec extends LeafExecNode { override def executeTail(n: Int): Array[InternalRow] = plan.executeTail(n) override def executeToIterator(): Iterator[InternalRow] = plan.executeToIterator() - override def doPrepare(): Unit = plan.prepare() - override def doExecute(): RDD[InternalRow] = plan.execute() + protected override def doPrepare(): Unit = plan.prepare() + protected override def doExecute(): RDD[InternalRow] = plan.execute() override def doExecuteBroadcast[T](): Broadcast[T] = plan.executeBroadcast() override def doCanonicalize(): SparkPlan = plan.canonicalized @@ -187,8 +191,24 @@ case class BroadcastQueryStageExec( throw new IllegalStateException("wrong plan for broadcast stage:\n " + plan.treeString) } + @transient private lazy val materializeWithTimeout = { + val broadcastFuture = broadcast.completionFuture + val timeout = SQLConf.get.broadcastTimeout + val promise = Promise[Any]() + val fail = BroadcastQueryStageExec.scheduledExecutor.schedule(new Runnable() { + override def run(): Unit = { + promise.tryFailure(new SparkException(s"Could not execute broadcast in $timeout secs. " + + s"You can increase the timeout for broadcasts via ${SQLConf.BROADCAST_TIMEOUT.key} or " + + s"disable broadcast join by setting ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1")) + } + }, timeout, TimeUnit.SECONDS) + broadcastFuture.onComplete(_ => fail.cancel(false))(AdaptiveSparkPlanExec.executionContext) + Future.firstCompletedOf( + Seq(broadcastFuture, promise.future))(AdaptiveSparkPlanExec.executionContext) + } + override def doMaterialize(): Future[Any] = { - broadcast.completionFuture + materializeWithTimeout } override def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec = { @@ -204,3 +224,8 @@ case class BroadcastQueryStageExec( } } } + +object BroadcastQueryStageExec { + private val scheduledExecutor = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("BroadcastStageTimeout") +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index f69da86..d35bbe9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -120,7 +120,7 @@ case class BroadcastExchangeExec( System.nanoTime() - beforeBroadcast) val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) - promise.success(broadcasted) + promise.trySuccess(broadcasted) broadcasted } catch { // SPARK-24294: To bypass scala bug: https://github.com/scala/bug/issues/9554, we throw @@ -133,14 +133,14 @@ case class BroadcastExchangeExec( s"${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key} to -1 or increase the spark " + s"driver memory by setting ${SparkLauncher.DRIVER_MEMORY} to a higher value.") .initCause(oe.getCause)) - promise.failure(ex) + promise.tryFailure(ex) throw ex case e if !NonFatal(e) => val ex = new SparkFatalException(e) - promise.failure(ex) + promise.tryFailure(ex) throw ex case e: Throwable => - promise.failure(e) + promise.tryFailure(e) throw e } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 5ce758e..64ecf5d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft} import org.apache.spark.sql.catalyst.plans.logical.BROADCAST import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec} -import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AdaptiveTestUtils, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.functions._ @@ -39,7 +39,8 @@ import org.apache.spark.sql.types.{LongType, ShortType} * unsafe map in [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered * without serializing the hashed relation, which does not happen in local mode. */ -class BroadcastJoinSuite extends QueryTest with SQLTestUtils with AdaptiveSparkPlanHelper { +abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils + with AdaptiveSparkPlanHelper { import testImplicits._ protected var spark: SparkSession = null @@ -398,4 +399,22 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils with AdaptiveSparkP } } } + + test("Broadcast timeout") { + val timeout = 5 + val slowUDF = udf({ x: Int => Thread.sleep(timeout * 10 * 1000); x }) + val df1 = spark.range(10).select($"id" as 'a) + val df2 = spark.range(5).select(slowUDF($"id") as 'a) + val testDf = df1.join(broadcast(df2), "a") + withSQLConf(SQLConf.BROADCAST_TIMEOUT.key -> timeout.toString) { + val e = intercept[Exception] { + testDf.collect() + } + AdaptiveTestUtils.assertExceptionMessage(e, s"Could not execute broadcast in $timeout secs.") + } + } } + +class BroadcastJoinSuite extends BroadcastJoinSuiteBase with DisableAdaptiveExecutionSuite + +class BroadcastJoinSuiteAE extends BroadcastJoinSuiteBase with EnableAdaptiveExecutionSuite --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org