jiangxb1987 commented on code in PR #56055:
URL: https://github.com/apache/spark/pull/56055#discussion_r3308494687
##########
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala:
##########
@@ -1569,6 +1581,27 @@ private[spark] class DAGScheduler(
}
}
+ /**
+ * An experimental API to submit child stages even while the parents are
running. This is only
+ * used in [[ConcurrentStageDAGScheduler]]. It defined here since it depends
two private APIs in
+ * this class (namely submitMissingTasks() and activeJobForStage()).
Review Comment:
Would it be cleaner to relax submitMissingTasks and activeJobForStage to
protected and move both submitConcurrentStage and postSchedulerEvent into the
subclass? Right now the base class is carrying two helpers whose only purpose
is to back a subclass.
##########
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala:
##########
@@ -1569,6 +1581,27 @@ private[spark] class DAGScheduler(
}
}
+ /**
+ * An experimental API to submit child stages even while the parents are
running. This is only
+ * used in [[ConcurrentStageDAGScheduler]]. It defined here since it depends
two private APIs in
+ * this class (namely submitMissingTasks() and activeJobForStage()).
+ */
+ protected def submitConcurrentStage(stage: Stage): Unit = {
+ assert(waitingStages.contains(stage))
+ activeJobForStage(stage) match {
+ case Some(job) =>
+ waitingStages -= stage
+ submitMissingTasks(stage, job)
+ case None => // Not expected.
+ new IllegalStateException(s"No active job for stage $stage")
Review Comment:
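Should this be `throw new IllegalStateException(...)`? As written, the exception is constructed but never thrown, so the `None` case is silently ignored.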
##########
core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Properties
+
+import scala.collection.mutable
+
+import org.apache.spark.{MapOutputTrackerMaster, SparkContext, SparkEnv,
SparkException, SparkRuntimeException, Success}
+import org.apache.spark.internal.LogKeys
+import org.apache.spark.internal.config.{SPECULATION_ENABLED,
STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED}
+import org.apache.spark.resource.ResourceProfile
+import org.apache.spark.storage.BlockManagerMaster
+import org.apache.spark.util.Clock
+import org.apache.spark.util.SystemClock
+
+/**
+ * A [[DAGScheduler]] that runs all the stages in a job without waiting for
its parents
+ * complete. This combined with streaming shuffle between the stages, allows
for low latency
+ * execution of streaming queries in real-time mode.
+ */
+class ConcurrentStageDAGScheduler(
+ sc: SparkContext,
+ taskScheduler: TaskScheduler,
+ listenerBus: LiveListenerBus,
+ mapOutputTracker: MapOutputTrackerMaster,
+ blockManagerMaster: BlockManagerMaster,
+ env: SparkEnv,
+ clock: Clock = new SystemClock())
+ extends DAGScheduler(
+ sc, taskScheduler, listenerBus, mapOutputTracker, blockManagerMaster, env,
clock) {
+
+ import ConcurrentStageDAGScheduler._
+
+ def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
+ this(
+ sc,
+ taskScheduler,
+ sc.listenerBus,
+ sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
+ sc.env.blockManager.master,
+ sc.env
+ )
+ }
+
+ def this(sc: SparkContext) = this(sc, sc.taskScheduler)
+
+ // This contains all the concurrent states that are yet to be scheduled
across all the jobs.
+ private[spark] val concurrentStages = new mutable.HashSet[Stage]
+
+ private[scheduler] case class DependentStageInfo(
+ parents: mutable.HashSet[Stage] = mutable.HashSet.empty,
+ delayedTaskCompletionEvents: mutable.ListBuffer[CompletionEvent] =
mutable.ListBuffer.empty)
+
+ // This map holds parents of concurrently scheduled stages. When tasks for
such a stage complete,
+ // and if any of the parents are still running, we delay processing of such
events until parent
+ // stages are complete. We save these events in this map until then.
+ private[spark] val dependentStageMap = new mutable.HashMap[Stage,
DependentStageInfo]
+
+ private def totalNumCoreForStage(stage: Stage): Int = {
+ val numTask = stage match {
+ case r: ResultStage => r.partitions.length
+ case m: ShuffleMapStage => m.numPartitions
+ }
+ val resourceProfile =
sc.resourceProfileManager.resourceProfileFromId(stage.resourceProfileId)
+ val taskCpus =
ResourceProfile.getTaskCpusOrDefaultForProfile(resourceProfile, sc.conf)
+ taskCpus * numTask
+ }
+
+ /**
+ * Hook invoked after the final stage is created. Registers stages reachable
from
+ * the final stage as concurrent so they can be submitted in parallel.
+ */
+ override def onFinalStageCreated(finalStage: Stage, properties: Properties):
Unit = {
+
+ val queryBatchId = getStreamingBatchIdFromProperties(properties)
+
+ if (queryBatchId.nonEmpty && isConcurrentStagesEnabled(properties)) {
+ if (properties.getProperty(SPECULATION_ENABLED.key) == "true") {
Review Comment:
This check only fires when speculation is set as a per-job local property.
Every other consumer in core (TaskSchedulerImpl/TaskSetManager/PairRDDFunctions
etc.) reads it via `conf.get(SPECULATION_ENABLED)`, which is the documented way
to enable speculation. Users with cluster-wide spark.speculation=true (the
common case) will silently bypass this guard.
Suggest also checking sc.conf.get(SPECULATION_ENABLED):
```
if (properties.getProperty(SPECULATION_ENABLED.key) == "true" ||
sc.conf.get(SPECULATION_ENABLED)) {
throw new SparkException(...)
}
```
##########
core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Properties
+
+import scala.collection.mutable
+
+import org.apache.spark.{MapOutputTrackerMaster, SparkContext, SparkEnv,
SparkException, SparkRuntimeException, Success}
+import org.apache.spark.internal.LogKeys
+import org.apache.spark.internal.config.{SPECULATION_ENABLED,
STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED}
+import org.apache.spark.resource.ResourceProfile
+import org.apache.spark.storage.BlockManagerMaster
+import org.apache.spark.util.Clock
+import org.apache.spark.util.SystemClock
+
+/**
+ * A [[DAGScheduler]] that runs all the stages in a job without waiting for
its parents
+ * complete. This combined with streaming shuffle between the stages, allows
for low latency
+ * execution of streaming queries in real-time mode.
+ */
+class ConcurrentStageDAGScheduler(
+ sc: SparkContext,
+ taskScheduler: TaskScheduler,
+ listenerBus: LiveListenerBus,
+ mapOutputTracker: MapOutputTrackerMaster,
+ blockManagerMaster: BlockManagerMaster,
+ env: SparkEnv,
+ clock: Clock = new SystemClock())
+ extends DAGScheduler(
+ sc, taskScheduler, listenerBus, mapOutputTracker, blockManagerMaster, env,
clock) {
+
+ import ConcurrentStageDAGScheduler._
+
+ def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
+ this(
+ sc,
+ taskScheduler,
+ sc.listenerBus,
+ sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
+ sc.env.blockManager.master,
+ sc.env
+ )
+ }
+
+ def this(sc: SparkContext) = this(sc, sc.taskScheduler)
+
+ // This contains all the concurrent states that are yet to be scheduled
across all the jobs.
+ private[spark] val concurrentStages = new mutable.HashSet[Stage]
+
+ private[scheduler] case class DependentStageInfo(
+ parents: mutable.HashSet[Stage] = mutable.HashSet.empty,
+ delayedTaskCompletionEvents: mutable.ListBuffer[CompletionEvent] =
mutable.ListBuffer.empty)
+
+ // This map holds parents of concurrently scheduled stages. When tasks for
such a stage complete,
+ // and if any of the parents are still running, we delay processing of such
events until parent
+ // stages are complete. We save these events in this map until then.
+ private[spark] val dependentStageMap = new mutable.HashMap[Stage,
DependentStageInfo]
+
+ private def totalNumCoreForStage(stage: Stage): Int = {
+ val numTask = stage match {
+ case r: ResultStage => r.partitions.length
+ case m: ShuffleMapStage => m.numPartitions
+ }
+ val resourceProfile =
sc.resourceProfileManager.resourceProfileFromId(stage.resourceProfileId)
+ val taskCpus =
ResourceProfile.getTaskCpusOrDefaultForProfile(resourceProfile, sc.conf)
+ taskCpus * numTask
+ }
+
+ /**
+ * Hook invoked after the final stage is created. Registers stages reachable
from
+ * the final stage as concurrent so they can be submitted in parallel.
+ */
+ override def onFinalStageCreated(finalStage: Stage, properties: Properties):
Unit = {
+
+ val queryBatchId = getStreamingBatchIdFromProperties(properties)
+
+ if (queryBatchId.nonEmpty && isConcurrentStagesEnabled(properties)) {
+ if (properties.getProperty(SPECULATION_ENABLED.key) == "true") {
+ // Speculation is not supported with concurrent stages.
+ throw new SparkException(
+ "Speculative execution is not supported with concurrent stages " +
+ s"(streaming query: $queryBatchId). Please disable
${SPECULATION_ENABLED.key} config."
+ )
+ }
+
+ logInfo(log"Concurrent stages is enabled for [query
${MDC(LogKeys.STREAMING_QUERY_ID,
+ queryBatchId.get.queryId)} batch ${MDC(LogKeys.BATCH_ID,
queryBatchId.get.batchId)}]")
+
+ // Mark current stage and all its ancestors as concurrent
+ var totalCoresNeeded = 0
+ def visit(stage: Stage): Unit = {
+ if (!concurrentStages.contains(stage)) {
+ logInfo(log"Marking stage '${MDC(LogKeys.STAGE, stage)}' concurrent
for [query ${MDC(
+ LogKeys.STREAMING_QUERY_ID, queryBatchId.get.queryId)} batch ${MDC(
+ LogKeys.BATCH_ID, queryBatchId.get.batchId)}]")
+ concurrentStages += stage
+ totalCoresNeeded += totalNumCoreForStage(stage)
+ stage.parents.foreach(visit)
+ }
+ }
+ visit(finalStage)
+
+ if (!sc.conf.get(STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED)) {
+ try {
+ val totalSlots = sc.schedulerBackend.defaultParallelism()
+ val coresInUse =
runningStages.toArray.map(totalNumCoreForStage(_)).sum
+ if (totalSlots - coresInUse < totalCoresNeeded) {
+ throw new SparkRuntimeException(
Review Comment:
When this throws, the stages added to concurrentStages above are leaked —
handleJobSubmitted catches the exception and fails the job, but nothing ever
clears those entries. A subsequent job whose stages share IDs (e.g. retries
from the same RDD chain) would inherit them. Either remove the stages just
visited from concurrentStages before throwing, or collect them in a local set
and only commit them to concurrentStages once the slot check passes.
##########
core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala:
##########
@@ -1069,7 +1076,12 @@ private[spark] class TaskSetManager(
emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid, tasks(index), reason,
null,
accumUpdates, metricPeaks)
- if (!isZombie && reason.countTowardsTaskFailures) {
+ val countTowardsTaskFailures = reason.countTowardsTaskFailures ||
+ // if the query is running in real time mode, any failures should
contribute the task failures
Review Comment:
"contribute the task failures" -> "count toward the task failures"
##########
core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Properties
+
+import scala.collection.mutable
+
+import org.apache.spark.{MapOutputTrackerMaster, SparkContext, SparkEnv,
SparkException, SparkRuntimeException, Success}
+import org.apache.spark.internal.LogKeys
+import org.apache.spark.internal.config.{SPECULATION_ENABLED,
STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED}
+import org.apache.spark.resource.ResourceProfile
+import org.apache.spark.storage.BlockManagerMaster
+import org.apache.spark.util.Clock
+import org.apache.spark.util.SystemClock
+
+/**
+ * A [[DAGScheduler]] that runs all the stages in a job without waiting for
its parents
+ * complete. This combined with streaming shuffle between the stages, allows
for low latency
+ * execution of streaming queries in real-time mode.
+ */
+class ConcurrentStageDAGScheduler(
+ sc: SparkContext,
+ taskScheduler: TaskScheduler,
+ listenerBus: LiveListenerBus,
+ mapOutputTracker: MapOutputTrackerMaster,
+ blockManagerMaster: BlockManagerMaster,
+ env: SparkEnv,
+ clock: Clock = new SystemClock())
+ extends DAGScheduler(
+ sc, taskScheduler, listenerBus, mapOutputTracker, blockManagerMaster, env,
clock) {
+
+ import ConcurrentStageDAGScheduler._
+
+ def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
+ this(
+ sc,
+ taskScheduler,
+ sc.listenerBus,
+ sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
+ sc.env.blockManager.master,
+ sc.env
+ )
+ }
+
+ def this(sc: SparkContext) = this(sc, sc.taskScheduler)
+
+ // This contains all the concurrent states that are yet to be scheduled
across all the jobs.
+ private[spark] val concurrentStages = new mutable.HashSet[Stage]
+
+ private[scheduler] case class DependentStageInfo(
+ parents: mutable.HashSet[Stage] = mutable.HashSet.empty,
+ delayedTaskCompletionEvents: mutable.ListBuffer[CompletionEvent] =
mutable.ListBuffer.empty)
+
+ // This map holds parents of concurrently scheduled stages. When tasks for
such a stage complete,
+ // and if any of the parents are still running, we delay processing of such
events until parent
+ // stages are complete. We save these events in this map until then.
+ private[spark] val dependentStageMap = new mutable.HashMap[Stage,
DependentStageInfo]
+
+ private def totalNumCoreForStage(stage: Stage): Int = {
+ val numTask = stage match {
+ case r: ResultStage => r.partitions.length
+ case m: ShuffleMapStage => m.numPartitions
+ }
+ val resourceProfile =
sc.resourceProfileManager.resourceProfileFromId(stage.resourceProfileId)
+ val taskCpus =
ResourceProfile.getTaskCpusOrDefaultForProfile(resourceProfile, sc.conf)
+ taskCpus * numTask
+ }
+
+ /**
+ * Hook invoked after the final stage is created. Registers stages reachable
from
+ * the final stage as concurrent so they can be submitted in parallel.
+ */
+ override def onFinalStageCreated(finalStage: Stage, properties: Properties):
Unit = {
+
+ val queryBatchId = getStreamingBatchIdFromProperties(properties)
+
+ if (queryBatchId.nonEmpty && isConcurrentStagesEnabled(properties)) {
+ if (properties.getProperty(SPECULATION_ENABLED.key) == "true") {
+ // Speculation is not supported with concurrent stages.
+ throw new SparkException(
+ "Speculative execution is not supported with concurrent stages " +
+ s"(streaming query: $queryBatchId). Please disable
${SPECULATION_ENABLED.key} config."
+ )
+ }
+
+ logInfo(log"Concurrent stages is enabled for [query
${MDC(LogKeys.STREAMING_QUERY_ID,
+ queryBatchId.get.queryId)} batch ${MDC(LogKeys.BATCH_ID,
queryBatchId.get.batchId)}]")
+
+ // Mark current stage and all its ancestors as concurrent
+ var totalCoresNeeded = 0
+ def visit(stage: Stage): Unit = {
+ if (!concurrentStages.contains(stage)) {
+ logInfo(log"Marking stage '${MDC(LogKeys.STAGE, stage)}' concurrent
for [query ${MDC(
+ LogKeys.STREAMING_QUERY_ID, queryBatchId.get.queryId)} batch ${MDC(
+ LogKeys.BATCH_ID, queryBatchId.get.batchId)}]")
+ concurrentStages += stage
+ totalCoresNeeded += totalNumCoreForStage(stage)
+ stage.parents.foreach(visit)
+ }
+ }
+ visit(finalStage)
+
+ if (!sc.conf.get(STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED)) {
+ try {
+ val totalSlots = sc.schedulerBackend.defaultParallelism()
+ val coresInUse =
runningStages.toArray.map(totalNumCoreForStage(_)).sum
+ if (totalSlots - coresInUse < totalCoresNeeded) {
+ throw new SparkRuntimeException(
+ errorClass = "CONCURRENT_SCHEDULER_INSUFFICIENT_SLOT",
+ messageParameters = Map(
+ "numSlots" -> (totalSlots - coresInUse).toString,
+ "numTasks" -> totalCoresNeeded.toString))
+ }
+ } catch {
+ case e: UnsupportedOperationException =>
+ logWarning(log"${MDC(LogKeys.ERROR, e)}. Skipping slot check for
RTM.")
+ }
+ }
+ } else {
+ super.onFinalStageCreated(finalStage, properties)
+ }
+ }
+
+ override def submitStage(stage: Stage): Unit = {
+ super.submitStage(stage)
+
+ if (!waitingStages.contains(stage) && concurrentStages.contains(stage)) {
+ // The current stage is not registered in waitingStages, which means it
has
+ // no parents. This case we should remove it from concurrentStages since
it is already
+ // running.
+ assert(runningStages.contains(stage), "stage should be running if not in
waitingStages")
+ logInfo(log"Removing stage ${MDC(LogKeys.STAGE, stage)} from
concurrentStages")
+ concurrentStages -= stage
+ }
+
+ // Find the stages that should be submitted concurrently with this stage.
+ waitingStages.intersect(concurrentStages).foreach { stage =>
+ logInfo(log"Submitting stage concurrently: ${MDC(LogKeys.STAGE, stage)}")
+ concurrentStages -= stage // Don't submit this stage concurrently for
subsequent attempts.
+ stage.parents.foreach { parent =>
+ if (isRunningStage(parent)) {
+ logInfo(log"Updating dependent map for stage ${MDC(LogKeys.STAGE,
stage)} with parent ${
+ MDC(LogKeys.PARENT_STAGE, parent)}")
+ dependentStageMap.getOrElseUpdate(stage,
DependentStageInfo()).parents += parent
+ }
+ }
+ // Remove stage and its parents from concurrentStages
+ def removeFromConcurrentStages(stage: Stage): Unit = {
+ if (concurrentStages.contains(stage)) {
+ logInfo(log"Removing stage ${MDC(LogKeys.STAGE, stage)} from
concurrentStages")
+ concurrentStages -= stage
+ }
+ stage.parents.foreach { parent =>
+ assert(!waitingStages.contains(parent), "Parent stage should not
still be waiting")
+ removeFromConcurrentStages(parent)
+ }
+ }
+ removeFromConcurrentStages(stage)
+ submitConcurrentStage(stage)
+ }
+ }
+
+ // This is overridden to check if the task completion event should be
delayed a parent stage
+ // till has running tasks. See comment for `dependentStageMap` for more
details.
+ override private[scheduler] def handleTaskCompletion(event:
CompletionEvent): Unit = {
+ val stageId = event.task.stageId
+ val taskId = event.taskInfo.taskId
+
+ getStage(stageId) match {
+ case Some(stage) if event.reason == Success &&
dependentStageMap.contains(stage) =>
+ val dependentStageInfo = dependentStageMap(stage)
+ logInfo(log"Delaying completion event for task ${MDC(LogKeys.TASK_ID,
taskId)} in stage ${
+ MDC(LogKeys.STAGE, stage)}. Active parent(s):
${MDC(LogKeys.PARENT_STAGES,
+ dependentStageInfo.parents.mkString(", "))}")
+ dependentStageInfo.delayedTaskCompletionEvents += event
+
+ case _ => // Otherwise handle the event as usual.
+ super.handleTaskCompletion(event)
+ }
+ }
+
+ // This is overridden to handle any delayed task completion events for
dependent stages.
+ override def markStageAsFinished(
Review Comment:
The dependentStageMap cleanup path only fires when a stage in the map is
named as a parent via markStageAsFinished(parent). If a dependent stage itself
aborts mid-job (e.g. its single allowed failure under maxTaskFailures=1), its
own entry — including any buffered delayedTaskCompletionEvents — is never
removed from dependentStageMap. With concurrent jobs sharing a long-lived
scheduler instance, that's a slow leak across queries. Consider clearing the
entry for stage itself inside markStageAsFinished (especially when
errorMessage.isDefined).
##########
common/utils/src/main/resources/error/error-conditions.json:
##########
@@ -908,6 +890,12 @@
],
"sqlState" : "0A000"
},
+ "CONCURRENT_SCHEDULER_INSUFFICIENT_SLOT" : {
+ "message" : [
+ "The minimum number of free slots required in the cluster is <numTasks>,
however, the cluster has only has <numSlots> slots free. Query will stall or
fail. Increase cluster size to proceed."
Review Comment:
nit: "has only has" -> "has only"
##########
core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Properties
+
+import scala.collection.mutable
+
+import org.apache.spark.{MapOutputTrackerMaster, SparkContext, SparkEnv,
SparkException, SparkRuntimeException, Success}
+import org.apache.spark.internal.LogKeys
+import org.apache.spark.internal.config.{SPECULATION_ENABLED,
STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED}
+import org.apache.spark.resource.ResourceProfile
+import org.apache.spark.storage.BlockManagerMaster
+import org.apache.spark.util.Clock
+import org.apache.spark.util.SystemClock
+
+/**
+ * A [[DAGScheduler]] that runs all the stages in a job without waiting for
its parents
+ * complete. This combined with streaming shuffle between the stages, allows
for low latency
+ * execution of streaming queries in real-time mode.
+ */
+class ConcurrentStageDAGScheduler(
+ sc: SparkContext,
+ taskScheduler: TaskScheduler,
+ listenerBus: LiveListenerBus,
+ mapOutputTracker: MapOutputTrackerMaster,
+ blockManagerMaster: BlockManagerMaster,
+ env: SparkEnv,
+ clock: Clock = new SystemClock())
+ extends DAGScheduler(
+ sc, taskScheduler, listenerBus, mapOutputTracker, blockManagerMaster, env,
clock) {
+
+ import ConcurrentStageDAGScheduler._
+
+ def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
+ this(
+ sc,
+ taskScheduler,
+ sc.listenerBus,
+ sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
+ sc.env.blockManager.master,
+ sc.env
+ )
+ }
+
+ def this(sc: SparkContext) = this(sc, sc.taskScheduler)
+
+ // This contains all the concurrent states that are yet to be scheduled
across all the jobs.
Review Comment:
Typo: "states" → "stages".
##########
common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java:
##########
@@ -792,6 +793,7 @@ public enum LogKeys implements LogKey {
STREAMING_DATA_SOURCE_NAME,
STREAMING_OFFSETS_END,
STREAMING_OFFSETS_START,
+ STREAMING_QUERY_ID,
Review Comment:
QUERY_ID already exists and is what
StructuredStreamingIdAwareSchedulerLogging uses to log streaming query IDs.
Adding STREAMING_QUERY_ID creates a parallel key for the same concept. Suggest
either dropping this addition and using LogKeys.QUERY_ID at all the new callsites, or
switching the callsites in StructuredStreamingIdAwareSchedulerLogging to
STREAMING_QUERY_ID so that only one key remains.
##########
core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Properties
+
+import scala.collection.mutable
+
+import org.apache.spark.{MapOutputTrackerMaster, SparkContext, SparkEnv,
SparkException, SparkRuntimeException, Success}
+import org.apache.spark.internal.LogKeys
+import org.apache.spark.internal.config.{SPECULATION_ENABLED,
STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED}
+import org.apache.spark.resource.ResourceProfile
+import org.apache.spark.storage.BlockManagerMaster
+import org.apache.spark.util.Clock
+import org.apache.spark.util.SystemClock
+
+/**
+ * A [[DAGScheduler]] that runs all the stages in a job without waiting for
its parents
+ * complete. This combined with streaming shuffle between the stages, allows
for low latency
+ * execution of streaming queries in real-time mode.
+ */
+class ConcurrentStageDAGScheduler(
+ sc: SparkContext,
+ taskScheduler: TaskScheduler,
+ listenerBus: LiveListenerBus,
+ mapOutputTracker: MapOutputTrackerMaster,
+ blockManagerMaster: BlockManagerMaster,
+ env: SparkEnv,
+ clock: Clock = new SystemClock())
+ extends DAGScheduler(
+ sc, taskScheduler, listenerBus, mapOutputTracker, blockManagerMaster, env,
clock) {
+
+ import ConcurrentStageDAGScheduler._
+
+ def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
+ this(
+ sc,
+ taskScheduler,
+ sc.listenerBus,
+ sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
+ sc.env.blockManager.master,
+ sc.env
+ )
+ }
+
+ def this(sc: SparkContext) = this(sc, sc.taskScheduler)
+
+ // This contains all the concurrent states that are yet to be scheduled
across all the jobs.
+ private[spark] val concurrentStages = new mutable.HashSet[Stage]
+
+ private[scheduler] case class DependentStageInfo(
+ parents: mutable.HashSet[Stage] = mutable.HashSet.empty,
+ delayedTaskCompletionEvents: mutable.ListBuffer[CompletionEvent] =
mutable.ListBuffer.empty)
+
+ // This map holds parents of concurrently scheduled stages. When tasks for
such a stage complete,
+ // and if any of the parents are still running, we delay processing of such
events until parent
+ // stages are complete. We save these events in this map until then.
+ private[spark] val dependentStageMap = new mutable.HashMap[Stage,
DependentStageInfo]
+
+ private def totalNumCoreForStage(stage: Stage): Int = {
+ val numTask = stage match {
+ case r: ResultStage => r.partitions.length
+ case m: ShuffleMapStage => m.numPartitions
+ }
+ val resourceProfile =
sc.resourceProfileManager.resourceProfileFromId(stage.resourceProfileId)
+ val taskCpus =
ResourceProfile.getTaskCpusOrDefaultForProfile(resourceProfile, sc.conf)
+ taskCpus * numTask
+ }
+
+ /**
+ * Hook invoked after the final stage is created. Registers stages reachable
from
+ * the final stage as concurrent so they can be submitted in parallel.
+ */
+ override def onFinalStageCreated(finalStage: Stage, properties: Properties):
Unit = {
+
+ val queryBatchId = getStreamingBatchIdFromProperties(properties)
+
+ if (queryBatchId.nonEmpty && isConcurrentStagesEnabled(properties)) {
+ if (properties.getProperty(SPECULATION_ENABLED.key) == "true") {
+ // Speculation is not supported with concurrent stages.
+ throw new SparkException(
+ "Speculative execution is not supported with concurrent stages " +
+ s"(streaming query: $queryBatchId). Please disable
${SPECULATION_ENABLED.key} config."
+ )
+ }
+
+ logInfo(log"Concurrent stages is enabled for [query
${MDC(LogKeys.STREAMING_QUERY_ID,
+ queryBatchId.get.queryId)} batch ${MDC(LogKeys.BATCH_ID,
queryBatchId.get.batchId)}]")
+
+ // Mark current stage and all its ancestors as concurrent
+ var totalCoresNeeded = 0
+ def visit(stage: Stage): Unit = {
+ if (!concurrentStages.contains(stage)) {
+ logInfo(log"Marking stage '${MDC(LogKeys.STAGE, stage)}' concurrent
for [query ${MDC(
+ LogKeys.STREAMING_QUERY_ID, queryBatchId.get.queryId)} batch ${MDC(
+ LogKeys.BATCH_ID, queryBatchId.get.batchId)}]")
+ concurrentStages += stage
+ totalCoresNeeded += totalNumCoreForStage(stage)
+ stage.parents.foreach(visit)
+ }
+ }
+ visit(finalStage)
+
+ if (!sc.conf.get(STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED)) {
+ try {
+ val totalSlots = sc.schedulerBackend.defaultParallelism()
+ val coresInUse =
runningStages.toArray.map(totalNumCoreForStage(_)).sum
+ if (totalSlots - coresInUse < totalCoresNeeded) {
+ throw new SparkRuntimeException(
+ errorClass = "CONCURRENT_SCHEDULER_INSUFFICIENT_SLOT",
+ messageParameters = Map(
+ "numSlots" -> (totalSlots - coresInUse).toString,
+ "numTasks" -> totalCoresNeeded.toString))
+ }
+ } catch {
+ case e: UnsupportedOperationException =>
+ logWarning(log"${MDC(LogKeys.ERROR, e)}. Skipping slot check for
RTM.")
+ }
+ }
+ } else {
+ super.onFinalStageCreated(finalStage, properties)
+ }
+ }
+
+ override def submitStage(stage: Stage): Unit = {
+ super.submitStage(stage)
+
+ if (!waitingStages.contains(stage) && concurrentStages.contains(stage)) {
+ // The current stage is not registered in waitingStages, which means it
has
+ // no parents. This case we should remove it from concurrentStages since
it is already
+ // running.
+ assert(runningStages.contains(stage), "stage should be running if not in
waitingStages")
+ logInfo(log"Removing stage ${MDC(LogKeys.STAGE, stage)} from
concurrentStages")
+ concurrentStages -= stage
+ }
+
+ // Find the stages that should be submitted concurrently with this stage.
+ waitingStages.intersect(concurrentStages).foreach { stage =>
+ logInfo(log"Submitting stage concurrently: ${MDC(LogKeys.STAGE, stage)}")
+ concurrentStages -= stage // Don't submit this stage concurrently for
subsequent attempts.
+ stage.parents.foreach { parent =>
+ if (isRunningStage(parent)) {
+ logInfo(log"Updating dependent map for stage ${MDC(LogKeys.STAGE,
stage)} with parent ${
+ MDC(LogKeys.PARENT_STAGE, parent)}")
+ dependentStageMap.getOrElseUpdate(stage,
DependentStageInfo()).parents += parent
+ }
+ }
+ // Remove stage and its parents from concurrentStages
+ def removeFromConcurrentStages(stage: Stage): Unit = {
+ if (concurrentStages.contains(stage)) {
+ logInfo(log"Removing stage ${MDC(LogKeys.STAGE, stage)} from
concurrentStages")
+ concurrentStages -= stage
+ }
+ stage.parents.foreach { parent =>
+ assert(!waitingStages.contains(parent), "Parent stage should not
still be waiting")
+ removeFromConcurrentStages(parent)
+ }
+ }
+ removeFromConcurrentStages(stage)
+ submitConcurrentStage(stage)
+ }
+ }
+
+ // This is overridden to check if the task completion event should be
delayed a parent stage
+ // till has running tasks. See comment for `dependentStageMap` for more
details.
+ override private[scheduler] def handleTaskCompletion(event:
CompletionEvent): Unit = {
+ val stageId = event.task.stageId
+ val taskId = event.taskInfo.taskId
+
+ getStage(stageId) match {
+ case Some(stage) if event.reason == Success &&
dependentStageMap.contains(stage) =>
+ val dependentStageInfo = dependentStageMap(stage)
+ logInfo(log"Delaying completion event for task ${MDC(LogKeys.TASK_ID,
taskId)} in stage ${
+ MDC(LogKeys.STAGE, stage)}. Active parent(s):
${MDC(LogKeys.PARENT_STAGES,
+ dependentStageInfo.parents.mkString(", "))}")
+ dependentStageInfo.delayedTaskCompletionEvents += event
+
+ case _ => // Otherwise handle the event as usual.
+ super.handleTaskCompletion(event)
+ }
+ }
+
+ // This is overridden to handle any delayed task completion events for
dependent stages.
+ override def markStageAsFinished(
+ stage: Stage,
+ errorMessage: Option[String] = None,
+ willRetry: Boolean = false): Unit = {
+
+ super.markStageAsFinished(stage, errorMessage, willRetry)
+
+ // If this is a parent of a stage in dependentStageMap, remove it from
parents.
+ val dependentStages = dependentStageMap
+ .filter(_._2.parents.contains(stage))
+ .keys
+
+ dependentStages.foreach { dependent =>
+ if (errorMessage.isEmpty) {
+ assert(
+ isRunningStage(dependent),
+ s"Parent stages $stage's dependent stage $dependent should be
running")
+ }
+ logInfo(log"Removing parent stage ${MDC(LogKeys.PARENT_STAGE, stage)}
from dependent map " +
+ log"for stage ${MDC(LogKeys.STAGE, dependent)}")
+ dependentStageMap(dependent).parents -= stage
+ checkDependentStageTasks(dependent)
+ }
+ }
+
+ // Checks if the dependent stage's parents are all done. If all the parents
are done,
+ // enqueues any saved task completion event (if any).
+ private def checkDependentStageTasks(stage: Stage): Unit = {
+ val dependentStageInfo = dependentStageMap.getOrElse(
+ stage, throw new RuntimeException(s"Stage $stage is not in
dependentStageMap")
+ )
+
+ if (dependentStageInfo.parents.isEmpty) {
+ val delayedEvents = dependentStageInfo.delayedTaskCompletionEvents
+ logInfo(log"All the parents are done for ${MDC(LogKeys.STAGE, stage)}.
Removing it from " +
+ log"the map. It has ${MDC(LogKeys.NUM_EVENTS,
delayedEvents.size.toLong)} " +
+ log"task completion events")
+ dependentStageMap -= stage
+ delayedEvents.foreach { event =>
+ logInfo(log"Posting delayed task ${MDC(LogKeys.TASK_ID,
event.taskInfo.taskId)} " +
+ log"completion event for stage ${MDC(LogKeys.STAGE, stage)}")
+ postSchedulerEvent(event)
+ }
+ }
+ }
+}
+
+object ConcurrentStageDAGScheduler {
+
+ val CONCURRENT_STAGES_ENABLED_PROPERTY: String =
"streaming.concurrent.stages.enabled"
+
+ def isConcurrentStagesEnabled(properties: Properties): Boolean = {
+ properties != null &&
+ properties.getProperty(CONCURRENT_STAGES_ENABLED_PROPERTY) == "true"
+ }
+
+ /**
+ * Extracts the [[StreamingBatchId]] from the given properties if all three
of the streaming
+ * query id, run id and batch id are present.
+ */
+ def getStreamingBatchIdFromProperties(properties: Properties):
Option[StreamingBatchId] = {
+ if (properties == null) {
+ return None
+ }
+
+ val queryId = Option(properties.getProperty("sql.streaming.queryId"))
Review Comment:
These property keys are already defined as constants in the same package:
StructuredStreamingIdAwareSchedulerLogging.QUERY_ID_KEY and BATCH_ID_KEY.
Suggest reusing them (and adding a RUN_ID_KEY there) so all consumers of
streaming-job properties go through one source.
Also: StreamingBatchId.runId is set here but never read anywhere — if it's
not intended to be consumed, drop the field from the case class; if it is, the
consuming code is missing.
##########
core/src/main/scala/org/apache/spark/internal/config/package.scala:
##########
@@ -2396,6 +2396,26 @@ package object config {
.booleanConf
.createWithDefault(true)
+ private[spark] val STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED =
Review Comment:
Normally we would use _ENABLED instead of _DISABLED, to avoid a
double negative.
##########
core/src/test/scala/org/apache/spark/scheduler/ConcurrentStageDAGSchedulerSuite.scala:
##########
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Properties
+
+import org.apache.spark.HashPartitioner
+import org.apache.spark.ShuffleDependency
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkContext
+import org.apache.spark.internal.config.SPECULATION_ENABLED
+import
org.apache.spark.internal.config.STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED
+
+class ConcurrentStageDAGSchedulerSuite extends DAGSchedulerSuiteBase {
+
+ // The unit-test SparkContext runs in local[2] mode, but the concurrent
pipelines exercised
+ // here often need more slots than that. Disable the slot check so the tests
aren't gated by
+ // executor capacity.
+ override def conf: SparkConf =
+ super.conf.set(STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED, true)
+
+ class TestConcurrentStageDAGScheduler(sc: SparkContext)
+ extends ConcurrentStageDAGScheduler(
+ sc,
+ taskScheduler,
+ sc.listenerBus,
+ mapOutputTracker,
+ blockManagerMaster,
+ sc.env)
+ with TestDAGScheduler
+
+ override def createInitialScheduler(sc: SparkContext): DAGScheduler = {
+ new TestConcurrentStageDAGScheduler(sc)
+ }
+
+ // Catch the job failure exception with a listener.
+ private class TestJobListener extends JobListener {
+ private var failureException: Option[Exception] = None
+
+ override def jobFailed(exception: Exception): Unit = {
+ failureException = Some(exception)
+ }
+
+ override def taskSucceeded(index: Int, result: Any): Unit = { }
+
+ def expectFailure(): Exception = {
+ assert(failureException.nonEmpty, "Job was expected to fail with an
exception, but didn't")
+ failureException.get
+ }
+ }
+
+
+ /** Default job properties with query settings and concurrent stages
enabled. */
+ private val testProperties: Properties = {
+ val properties = new Properties()
+ properties.setProperty("sql.streaming.queryId", "test_query_id")
+ properties.setProperty("sql.streaming.runId", "test_run_id")
+ properties.setProperty("streaming.sql.batchId", "5")
+
properties.setProperty(ConcurrentStageDAGScheduler.CONCURRENT_STAGES_ENABLED_PROPERTY,
"true")
+ new Properties(properties) {
+ // Make it read-only.
+ override def setProperty(key: String, value: String): AnyRef = {
+ throw new UnsupportedOperationException("Default properties are
read-only.")
+ }
+ }
+ }
+
+ test("Simple job with two concurrent stages") {
+ // Run a simple job with two stages. Both stages should be running
concurrently.
+
+ val mapStage = new MyRDD(sc, 1, Nil) // stage_0
+ val shuffleDep = new ShuffleDependency(mapStage, new HashPartitioner(1))
+ val resultStage = new MyRDD(sc, 3, List(shuffleDep)) // stage_1
+
+ // Shape: [stage_0, map stage, parent] <--- [stage_1, result stage]
+
+ submit(resultStage, Array(0), properties = testProperties)
+
+ assert(scheduler.waitingStages.isEmpty) // Both are submitted.
+ assert(scheduler.runningStages.map(_.id) === Set(0, 1)) // Both stages are
running.
+
+ // Verify concurrent scheduler specific state.
+ val concurrentScheduler =
scheduler.asInstanceOf[TestConcurrentStageDAGScheduler]
+
+ assert(concurrentScheduler.concurrentStages.isEmpty) // All are already
scheduled
+
+ val depStageMap = concurrentScheduler.dependentStageMap
+ assert(depStageMap.keys.map(_.id) == Set(1)) // Result stage is the key.
+ assert(depStageMap.values.flatMap(_.parents.map(_.id)) == Seq(0)) // Map
stage is the parent.
+ assert(depStageMap.values.flatMap(_.delayedTaskCompletionEvents).isEmpty)
// No completed tasks.
+
+ // First complete the result stage. Its tasks will complete, but the
actual stage would still
+ // be running since its parent (map stage) hasn't completed yet.
+
+ completeNextResultStageWithSuccess(1, 0)
+ assert(scheduler.runningStages.map(_.id) === Set(0, 1)) // Both stages are
still running.
+ // dependentStageMap should have the completed task from result stage
enqueued.
+ assert(depStageMap.values.flatMap(_.delayedTaskCompletionEvents).size ===
1)
+
+ // Now complete the map stage. This should complete the result stage as
well.
+ completeShuffleMapStageSuccessfully(0, 0, 1)
+
+ assert(scheduler.runningStages.map(_.id) === Set()) // Both stages are
complete.
+ assert(depStageMap.isEmpty) // No more dependent stages.
+
+ assertDataStructuresEmpty()
+ }
+
+ test("Default scheduler using a simple job with concurrent stages disabled")
{
+ // This is opposite of the previous test. Concurrent stages are disabled,
so the stages should
+ // be submitted one after the other.
+
+ val mapStage = new MyRDD(sc, 1, Nil) // stage_0
+ val shuffleDep = new ShuffleDependency(mapStage, new HashPartitioner(1))
+ val resultStage = new MyRDD(sc, 3, List(shuffleDep)) // stage_1
+
+ // Shape: [stage_0, map stage, parent] <--- [stage_1, result stage]
+
+ submit(resultStage, Array(0), properties = new Properties())
+
+ assert(scheduler.runningStages.map(_.id) == Set(0)) // Only the map stage
is running.
+ assert(scheduler.waitingStages.map(_.id) == Set(1)) // Result stage is
waiting.
+
+ val concurrentScheduler =
scheduler.asInstanceOf[TestConcurrentStageDAGScheduler]
+ assert(concurrentScheduler.concurrentStages.isEmpty) // No concurrent
stages.
+ assert(concurrentScheduler.dependentStageMap.isEmpty) // No dependent
stages.
+
+ // Complete the map stage. This should submit the result stage.
+ completeShuffleMapStageSuccessfully(0, 0, 1)
+
+ assert(scheduler.runningStages.map(_.id) == Set(1)) // Only the result
stage is running.
+ assert(scheduler.waitingStages.map(_.id) == Set()) // No waiting stages
+
+ completeNextResultStageWithSuccess(1, 0)
+ assertDataStructuresEmpty()
+ }
+
+ test("Complex pipeline with many stages") {
+ // Run a complex pipeline with multiple stages with multiple branches.
Such a pipeline not
+ // common, but useful to ensure scheduler works as expected.
+
+ // Shape:
+ // /<-------------------- stage_D
+ // stage_A <--- stage_B <--- stage_C <---\ ^
+ // \ \<---------/ \ |
+ // \ <----------------/ \ |
+ // stage_E <--------------------------- stage_F
+
+ // All of these should be running concurrently.
+
+ val rddA = new MyRDD(sc, 2, Nil).setName("rddA")
+ val shuffleDepA = new ShuffleDependency(rddA, new HashPartitioner(1))
+
+ val rddB = new MyRDD(sc, 1, List(shuffleDepA)).setName("rddB")
+ val shuffleDepB = new ShuffleDependency(rddB, new HashPartitioner(1))
+
+ val rddC = new MyRDD(sc, 1, List(shuffleDepA, shuffleDepB)).setName("rddC")
+ val shuffleDepC = new ShuffleDependency(rddC, new HashPartitioner(5))
+
+ val rddD = new MyRDD(sc, 4, List(shuffleDepB)).setName("rddD")
+ val shuffleDepD = new ShuffleDependency(rddD, new HashPartitioner(5))
+
+ val rddE = new MyRDD(sc, 3, Nil).setName("rddE")
+ val shuffleDepE = new ShuffleDependency(rddE, new HashPartitioner(5))
+
+ val rddF = new MyRDD(sc, 5, List(shuffleDepC, shuffleDepD,
shuffleDepE)).setName("rddF")
+
+ submit(rddF, Array(0, 1, 2, 3, 4), properties = testProperties)
+
+ assert(scheduler.waitingStages.isEmpty) // All the stages are submitted.
+ assert(scheduler.runningStages.map(_.id) == Set(0, 1, 2, 3, 4, 5)) // All
6 are running.
+
+ // Assign stage ids corresponding to the RDDs A, B, etc
+ def stageFor(rddName: String): Stage =
scheduler.runningStages.find(_.rdd.name == rddName).get
+
+ val sA = stageFor("rddA")
+ val sB = stageFor("rddB")
+ val sC = stageFor("rddC")
+ val sD = stageFor("rddD")
+ val sE = stageFor("rddE")
+ val sF = stageFor("rddF")
+
+ // log stage id mapping for debugging:
+ for (name <- List("A", "B", "C", "D", "E", "F")) {
+ logInfo(s"Stage id for stage $name is ${stageFor("rdd" + name).id}")
+ }
+
+ // Verify concurrent scheduler specific state.
+ val concurrentScheduler =
scheduler.asInstanceOf[TestConcurrentStageDAGScheduler]
+
+ assert(concurrentScheduler.concurrentStages.isEmpty) // All are already
scheduled
+
+ val depStageMap = concurrentScheduler.dependentStageMap
+ assert(depStageMap.keys === Set(sB, sC, sD, sF)) // All non-root stages
are keys.
+ assert(depStageMap.values.flatMap(_.parents).toSet === Set(
+ sA, sB, sC, sD, sE)) // All except the results stage.
+ assert(depStageMap.values.flatMap(_.delayedTaskCompletionEvents).isEmpty)
// No completed tasks.
+
+ // Complete stages in order-of-order and verify the state.
+
+ // First complete C. Entry for C would be updated with the completed task.
+ assert(depStageMap(sC).delayedTaskCompletionEvents.size === 0)
+ completeShuffleMapStageSuccessfully(sC.id, 0, 5) // Complete stage C.
+ assert(depStageMap(sC).delayedTaskCompletionEvents.size === 1)
+
+ // All the 6 stages are still 'running' since C's completion events are
delayed.
+ assert(scheduler.runningStages.map(_.id) == Set(0, 1, 2, 3, 4, 5))
+
+ // Now complete stage D. This is similar to completing C. The tasks are
enqueued.
+ assert(depStageMap(sD).delayedTaskCompletionEvents.size === 0)
+ completeShuffleMapStageSuccessfully(sD.id, 0, 5) // Complete stage D
+ assert(depStageMap(sD).delayedTaskCompletionEvents.size === 4) // 4 tasks
in stage C.
Review Comment:
"stage C" -> "stage D"
##########
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala:
##########
@@ -6107,7 +6123,7 @@ class DAGSchedulerSuite extends SparkFunSuite with
TempLocalSparkContext with Ti
}
}
- private def assertDataStructuresEmpty(): Unit = {
+ protected def assertDataStructuresEmpty(): Unit = {
Review Comment:
The new suite inherits all the tests against ConcurrentStageDAGScheduler,
but this helper doesn't validate the new state. A leak in concurrentStages or
dependentStageMap (see the slot-check and aborted-stage concerns above)
wouldn't be caught by any inherited test. Consider overriding it in
ConcurrentStageDAGSchedulerSuite (or adding a protected def extraEmptyChecks:
Unit = () hook here) so subclasses can extend the empty-state assertions.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]