jiangxb1987 commented on code in PR #56055:
URL: https://github.com/apache/spark/pull/56055#discussion_r3308494687


##########
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala:
##########
@@ -1569,6 +1581,27 @@ private[spark] class DAGScheduler(
     }
   }
 
+  /**
+   * An experimental API to submit child stages even while the parents are 
running. This is only
+   * used in [[ConcurrentStageDAGScheduler]]. It defined here since it depends 
two private APIs in
+   * this class (namely submitMissingTasks() and activeJobForStage()).

Review Comment:
   Would it be cleaner to relax submitMissingTasks and activeJobForStage to 
protected and move both submitConcurrentStage and postSchedulerEvent into the 
subclass? Right now the base class is carrying two helpers whose only purpose 
is to back a subclass.



##########
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala:
##########
@@ -1569,6 +1581,27 @@ private[spark] class DAGScheduler(
     }
   }
 
+  /**
+   * An experimental API to submit child stages even while the parents are 
running. This is only
+   * used in [[ConcurrentStageDAGScheduler]]. It defined here since it depends 
two private APIs in
+   * this class (namely submitMissingTasks() and activeJobForStage()).
+   */
+  protected def submitConcurrentStage(stage: Stage): Unit = {
+    assert(waitingStages.contains(stage))
+    activeJobForStage(stage) match {
+      case Some(job) =>
+        waitingStages -= stage
+        submitMissingTasks(stage, job)
+      case None => // Not expected.
+        new IllegalStateException(s"No active job for stage $stage")

Review Comment:
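   Should this be `throw new IllegalStateException(...)`? As written, the exception is only constructed, never thrown, so the `None` case is silently ignored.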



##########
core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Properties
+
+import scala.collection.mutable
+
+import org.apache.spark.{MapOutputTrackerMaster, SparkContext, SparkEnv, 
SparkException, SparkRuntimeException, Success}
+import org.apache.spark.internal.LogKeys
+import org.apache.spark.internal.config.{SPECULATION_ENABLED, 
STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED}
+import org.apache.spark.resource.ResourceProfile
+import org.apache.spark.storage.BlockManagerMaster
+import org.apache.spark.util.Clock
+import org.apache.spark.util.SystemClock
+
+/**
+ *  A [[DAGScheduler]] that runs all the stages in a job without waiting for 
its parents
+ *  complete. This combined with streaming shuffle between the stages, allows 
for low latency
+ *  execution of streaming queries in real-time mode.
+ */
+class ConcurrentStageDAGScheduler(
+    sc: SparkContext,
+    taskScheduler: TaskScheduler,
+    listenerBus: LiveListenerBus,
+    mapOutputTracker: MapOutputTrackerMaster,
+    blockManagerMaster: BlockManagerMaster,
+    env: SparkEnv,
+    clock: Clock = new SystemClock())
+  extends DAGScheduler(
+    sc, taskScheduler, listenerBus, mapOutputTracker, blockManagerMaster, env, 
clock) {
+
+  import ConcurrentStageDAGScheduler._
+
+  def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
+    this(
+      sc,
+      taskScheduler,
+      sc.listenerBus,
+      sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
+      sc.env.blockManager.master,
+      sc.env
+    )
+  }
+
+  def this(sc: SparkContext) = this(sc, sc.taskScheduler)
+
+  // This contains all the concurrent states that are yet to be scheduled 
across all the jobs.
+  private[spark] val concurrentStages = new mutable.HashSet[Stage]
+
+  private[scheduler] case class DependentStageInfo(
+    parents: mutable.HashSet[Stage] = mutable.HashSet.empty,
+    delayedTaskCompletionEvents: mutable.ListBuffer[CompletionEvent] = 
mutable.ListBuffer.empty)
+
+  // This map holds parents of concurrently scheduled stages. When tasks for 
such a stage complete,
+  // and if any of the parents are still running, we delay processing of such 
events until parent
+  // stages are complete. We save these events in this map until then.
+  private[spark] val dependentStageMap = new mutable.HashMap[Stage, 
DependentStageInfo]
+
+  private def totalNumCoreForStage(stage: Stage): Int = {
+    val numTask = stage match {
+      case r: ResultStage => r.partitions.length
+      case m: ShuffleMapStage => m.numPartitions
+    }
+    val resourceProfile = 
sc.resourceProfileManager.resourceProfileFromId(stage.resourceProfileId)
+    val taskCpus = 
ResourceProfile.getTaskCpusOrDefaultForProfile(resourceProfile, sc.conf)
+    taskCpus * numTask
+  }
+
+  /**
+   * Hook invoked after the final stage is created. Registers stages reachable 
from
+   * the final stage as concurrent so they can be submitted in parallel.
+   */
+  override def onFinalStageCreated(finalStage: Stage, properties: Properties): 
Unit = {
+
+    val queryBatchId = getStreamingBatchIdFromProperties(properties)
+
+    if (queryBatchId.nonEmpty && isConcurrentStagesEnabled(properties)) {
+      if (properties.getProperty(SPECULATION_ENABLED.key) == "true") {

Review Comment:
   This check only fires when speculation is set as a per-job local property. 
Every other consumer in core (TaskSchedulerImpl/TaskSetManager/PairRDDFunctions 
etc.) reads it via `conf.get(SPECULATION_ENABLED)`, which is the documented way 
to enable speculation. Users with cluster-wide spark.speculation=true (the 
common case) will silently bypass this guard.
   
   Suggest also checking sc.conf.get(SPECULATION_ENABLED):
   ```
      if (properties.getProperty(SPECULATION_ENABLED.key) == "true" ||
          sc.conf.get(SPECULATION_ENABLED)) {
        throw new SparkException(...)
      }
   ```



##########
core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Properties
+
+import scala.collection.mutable
+
+import org.apache.spark.{MapOutputTrackerMaster, SparkContext, SparkEnv, 
SparkException, SparkRuntimeException, Success}
+import org.apache.spark.internal.LogKeys
+import org.apache.spark.internal.config.{SPECULATION_ENABLED, 
STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED}
+import org.apache.spark.resource.ResourceProfile
+import org.apache.spark.storage.BlockManagerMaster
+import org.apache.spark.util.Clock
+import org.apache.spark.util.SystemClock
+
+/**
+ *  A [[DAGScheduler]] that runs all the stages in a job without waiting for 
its parents
+ *  complete. This combined with streaming shuffle between the stages, allows 
for low latency
+ *  execution of streaming queries in real-time mode.
+ */
+class ConcurrentStageDAGScheduler(
+    sc: SparkContext,
+    taskScheduler: TaskScheduler,
+    listenerBus: LiveListenerBus,
+    mapOutputTracker: MapOutputTrackerMaster,
+    blockManagerMaster: BlockManagerMaster,
+    env: SparkEnv,
+    clock: Clock = new SystemClock())
+  extends DAGScheduler(
+    sc, taskScheduler, listenerBus, mapOutputTracker, blockManagerMaster, env, 
clock) {
+
+  import ConcurrentStageDAGScheduler._
+
+  def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
+    this(
+      sc,
+      taskScheduler,
+      sc.listenerBus,
+      sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
+      sc.env.blockManager.master,
+      sc.env
+    )
+  }
+
+  def this(sc: SparkContext) = this(sc, sc.taskScheduler)
+
+  // This contains all the concurrent states that are yet to be scheduled 
across all the jobs.
+  private[spark] val concurrentStages = new mutable.HashSet[Stage]
+
+  private[scheduler] case class DependentStageInfo(
+    parents: mutable.HashSet[Stage] = mutable.HashSet.empty,
+    delayedTaskCompletionEvents: mutable.ListBuffer[CompletionEvent] = 
mutable.ListBuffer.empty)
+
+  // This map holds parents of concurrently scheduled stages. When tasks for 
such a stage complete,
+  // and if any of the parents are still running, we delay processing of such 
events until parent
+  // stages are complete. We save these events in this map until then.
+  private[spark] val dependentStageMap = new mutable.HashMap[Stage, 
DependentStageInfo]
+
+  private def totalNumCoreForStage(stage: Stage): Int = {
+    val numTask = stage match {
+      case r: ResultStage => r.partitions.length
+      case m: ShuffleMapStage => m.numPartitions
+    }
+    val resourceProfile = 
sc.resourceProfileManager.resourceProfileFromId(stage.resourceProfileId)
+    val taskCpus = 
ResourceProfile.getTaskCpusOrDefaultForProfile(resourceProfile, sc.conf)
+    taskCpus * numTask
+  }
+
+  /**
+   * Hook invoked after the final stage is created. Registers stages reachable 
from
+   * the final stage as concurrent so they can be submitted in parallel.
+   */
+  override def onFinalStageCreated(finalStage: Stage, properties: Properties): 
Unit = {
+
+    val queryBatchId = getStreamingBatchIdFromProperties(properties)
+
+    if (queryBatchId.nonEmpty && isConcurrentStagesEnabled(properties)) {
+      if (properties.getProperty(SPECULATION_ENABLED.key) == "true") {
+        // Speculation is not supported with concurrent stages.
+        throw new SparkException(
+          "Speculative execution is not supported with concurrent stages " +
+          s"(streaming query: $queryBatchId). Please disable 
${SPECULATION_ENABLED.key} config."
+        )
+      }
+
+      logInfo(log"Concurrent stages is enabled for [query 
${MDC(LogKeys.STREAMING_QUERY_ID,
+        queryBatchId.get.queryId)} batch ${MDC(LogKeys.BATCH_ID, 
queryBatchId.get.batchId)}]")
+
+      // Mark current stage and all its ancestors as concurrent
+      var totalCoresNeeded = 0
+      def visit(stage: Stage): Unit = {
+        if (!concurrentStages.contains(stage)) {
+          logInfo(log"Marking stage '${MDC(LogKeys.STAGE, stage)}' concurrent 
for [query ${MDC(
+            LogKeys.STREAMING_QUERY_ID, queryBatchId.get.queryId)} batch ${MDC(
+            LogKeys.BATCH_ID, queryBatchId.get.batchId)}]")
+          concurrentStages += stage
+          totalCoresNeeded += totalNumCoreForStage(stage)
+          stage.parents.foreach(visit)
+        }
+      }
+      visit(finalStage)
+
+      if (!sc.conf.get(STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED)) {
+        try {
+          val totalSlots = sc.schedulerBackend.defaultParallelism()
+          val coresInUse = 
runningStages.toArray.map(totalNumCoreForStage(_)).sum
+          if (totalSlots - coresInUse < totalCoresNeeded) {
+            throw new SparkRuntimeException(

Review Comment:
   When this throws, the stages added to concurrentStages above are leaked — 
handleJobSubmitted catches the exception and fails the job, but nothing ever 
clears those entries. A subsequent job whose stages share IDs (e.g. retries 
from the same RDDChain) would inherit them. Either clear concurrentStages of 
the stages just visited before throwing, or capture them in a local set and 
only commit to concurrentStages once the slot check passes.



##########
core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala:
##########
@@ -1069,7 +1076,12 @@ private[spark] class TaskSetManager(
     emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid, tasks(index), reason, 
null,
       accumUpdates, metricPeaks)
 
-    if (!isZombie && reason.countTowardsTaskFailures) {
+    val countTowardsTaskFailures = reason.countTowardsTaskFailures ||
+      // if the query is running in real time mode, any failures should 
contribute the task failures

Review Comment:
   "contribute the task failures" -> "count toward the task failures"



##########
core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Properties
+
+import scala.collection.mutable
+
+import org.apache.spark.{MapOutputTrackerMaster, SparkContext, SparkEnv, 
SparkException, SparkRuntimeException, Success}
+import org.apache.spark.internal.LogKeys
+import org.apache.spark.internal.config.{SPECULATION_ENABLED, 
STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED}
+import org.apache.spark.resource.ResourceProfile
+import org.apache.spark.storage.BlockManagerMaster
+import org.apache.spark.util.Clock
+import org.apache.spark.util.SystemClock
+
+/**
+ *  A [[DAGScheduler]] that runs all the stages in a job without waiting for 
its parents
+ *  complete. This combined with streaming shuffle between the stages, allows 
for low latency
+ *  execution of streaming queries in real-time mode.
+ */
+class ConcurrentStageDAGScheduler(
+    sc: SparkContext,
+    taskScheduler: TaskScheduler,
+    listenerBus: LiveListenerBus,
+    mapOutputTracker: MapOutputTrackerMaster,
+    blockManagerMaster: BlockManagerMaster,
+    env: SparkEnv,
+    clock: Clock = new SystemClock())
+  extends DAGScheduler(
+    sc, taskScheduler, listenerBus, mapOutputTracker, blockManagerMaster, env, 
clock) {
+
+  import ConcurrentStageDAGScheduler._
+
+  def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
+    this(
+      sc,
+      taskScheduler,
+      sc.listenerBus,
+      sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
+      sc.env.blockManager.master,
+      sc.env
+    )
+  }
+
+  def this(sc: SparkContext) = this(sc, sc.taskScheduler)
+
+  // This contains all the concurrent states that are yet to be scheduled 
across all the jobs.
+  private[spark] val concurrentStages = new mutable.HashSet[Stage]
+
+  private[scheduler] case class DependentStageInfo(
+    parents: mutable.HashSet[Stage] = mutable.HashSet.empty,
+    delayedTaskCompletionEvents: mutable.ListBuffer[CompletionEvent] = 
mutable.ListBuffer.empty)
+
+  // This map holds parents of concurrently scheduled stages. When tasks for 
such a stage complete,
+  // and if any of the parents are still running, we delay processing of such 
events until parent
+  // stages are complete. We save these events in this map until then.
+  private[spark] val dependentStageMap = new mutable.HashMap[Stage, 
DependentStageInfo]
+
+  private def totalNumCoreForStage(stage: Stage): Int = {
+    val numTask = stage match {
+      case r: ResultStage => r.partitions.length
+      case m: ShuffleMapStage => m.numPartitions
+    }
+    val resourceProfile = 
sc.resourceProfileManager.resourceProfileFromId(stage.resourceProfileId)
+    val taskCpus = 
ResourceProfile.getTaskCpusOrDefaultForProfile(resourceProfile, sc.conf)
+    taskCpus * numTask
+  }
+
+  /**
+   * Hook invoked after the final stage is created. Registers stages reachable 
from
+   * the final stage as concurrent so they can be submitted in parallel.
+   */
+  override def onFinalStageCreated(finalStage: Stage, properties: Properties): 
Unit = {
+
+    val queryBatchId = getStreamingBatchIdFromProperties(properties)
+
+    if (queryBatchId.nonEmpty && isConcurrentStagesEnabled(properties)) {
+      if (properties.getProperty(SPECULATION_ENABLED.key) == "true") {
+        // Speculation is not supported with concurrent stages.
+        throw new SparkException(
+          "Speculative execution is not supported with concurrent stages " +
+          s"(streaming query: $queryBatchId). Please disable 
${SPECULATION_ENABLED.key} config."
+        )
+      }
+
+      logInfo(log"Concurrent stages is enabled for [query 
${MDC(LogKeys.STREAMING_QUERY_ID,
+        queryBatchId.get.queryId)} batch ${MDC(LogKeys.BATCH_ID, 
queryBatchId.get.batchId)}]")
+
+      // Mark current stage and all its ancestors as concurrent
+      var totalCoresNeeded = 0
+      def visit(stage: Stage): Unit = {
+        if (!concurrentStages.contains(stage)) {
+          logInfo(log"Marking stage '${MDC(LogKeys.STAGE, stage)}' concurrent 
for [query ${MDC(
+            LogKeys.STREAMING_QUERY_ID, queryBatchId.get.queryId)} batch ${MDC(
+            LogKeys.BATCH_ID, queryBatchId.get.batchId)}]")
+          concurrentStages += stage
+          totalCoresNeeded += totalNumCoreForStage(stage)
+          stage.parents.foreach(visit)
+        }
+      }
+      visit(finalStage)
+
+      if (!sc.conf.get(STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED)) {
+        try {
+          val totalSlots = sc.schedulerBackend.defaultParallelism()
+          val coresInUse = 
runningStages.toArray.map(totalNumCoreForStage(_)).sum
+          if (totalSlots - coresInUse < totalCoresNeeded) {
+            throw new SparkRuntimeException(
+              errorClass = "CONCURRENT_SCHEDULER_INSUFFICIENT_SLOT",
+              messageParameters = Map(
+                "numSlots" -> (totalSlots - coresInUse).toString,
+                "numTasks" -> totalCoresNeeded.toString))
+          }
+        } catch {
+          case e: UnsupportedOperationException =>
+            logWarning(log"${MDC(LogKeys.ERROR, e)}. Skipping slot check for 
RTM.")
+        }
+      }
+    } else {
+      super.onFinalStageCreated(finalStage, properties)
+    }
+  }
+
+  override def submitStage(stage: Stage): Unit = {
+    super.submitStage(stage)
+
+    if (!waitingStages.contains(stage) && concurrentStages.contains(stage)) {
+      // The current stage is not registered in waitingStages, which means it 
has
+      // no parents. This case we should remove it from concurrentStages since 
it is already
+      // running.
+      assert(runningStages.contains(stage), "stage should be running if not in 
waitingStages")
+      logInfo(log"Removing stage ${MDC(LogKeys.STAGE, stage)} from 
concurrentStages")
+      concurrentStages -= stage
+    }
+
+    // Find the stages that should be submitted concurrently with this stage.
+    waitingStages.intersect(concurrentStages).foreach { stage =>
+      logInfo(log"Submitting stage concurrently: ${MDC(LogKeys.STAGE, stage)}")
+      concurrentStages -= stage // Don't submit this stage concurrently for 
subsequent attempts.
+      stage.parents.foreach { parent =>
+        if (isRunningStage(parent)) {
+          logInfo(log"Updating dependent map for stage ${MDC(LogKeys.STAGE, 
stage)} with parent ${
+            MDC(LogKeys.PARENT_STAGE, parent)}")
+          dependentStageMap.getOrElseUpdate(stage, 
DependentStageInfo()).parents += parent
+        }
+      }
+      // Remove stage and its parents from concurrentStages
+      def removeFromConcurrentStages(stage: Stage): Unit = {
+        if (concurrentStages.contains(stage)) {
+          logInfo(log"Removing stage ${MDC(LogKeys.STAGE, stage)} from 
concurrentStages")
+          concurrentStages -= stage
+        }
+        stage.parents.foreach { parent =>
+          assert(!waitingStages.contains(parent), "Parent stage should not 
still be waiting")
+          removeFromConcurrentStages(parent)
+        }
+      }
+      removeFromConcurrentStages(stage)
+      submitConcurrentStage(stage)
+    }
+  }
+
+  // This is overridden to check if the task completion event should be 
delayed a parent stage
+  // till has running tasks. See comment for `dependentStageMap` for more 
details.
+  override private[scheduler] def handleTaskCompletion(event: 
CompletionEvent): Unit = {
+    val stageId = event.task.stageId
+    val taskId = event.taskInfo.taskId
+
+    getStage(stageId) match {
+      case Some(stage) if event.reason == Success && 
dependentStageMap.contains(stage) =>
+        val dependentStageInfo = dependentStageMap(stage)
+        logInfo(log"Delaying completion event for task ${MDC(LogKeys.TASK_ID, 
taskId)} in stage ${
+          MDC(LogKeys.STAGE, stage)}. Active parent(s): 
${MDC(LogKeys.PARENT_STAGES,
+          dependentStageInfo.parents.mkString(", "))}")
+        dependentStageInfo.delayedTaskCompletionEvents += event
+
+      case _ =>  // Otherwise handle the event as usual.
+        super.handleTaskCompletion(event)
+    }
+  }
+
+  // This is overridden to handle any delayed task completion events for 
dependent stages.
+  override def markStageAsFinished(

Review Comment:
   The dependentStageMap cleanup path only fires when a stage in the map is 
named as a parent via markStageAsFinished(parent). If a dependent stage itself 
aborts mid-job (e.g. its single allowed failure under maxTaskFailures=1), its 
own entry — including any buffered delayedTaskCompletionEvents — is never 
removed from dependentStageMap. With concurrent jobs sharing a long-lived 
scheduler instance, that's a slow leak across queries. Consider clearing the 
entry for the stage itself inside markStageAsFinished (especially when 
errorMessage.isDefined).



##########
common/utils/src/main/resources/error/error-conditions.json:
##########
@@ -908,6 +890,12 @@
     ],
     "sqlState" : "0A000"
   },
+  "CONCURRENT_SCHEDULER_INSUFFICIENT_SLOT" : {
+    "message" : [
+      "The minimum number of free slots required in the cluster is <numTasks>, 
however, the cluster has only has <numSlots> slots free. Query will stall or 
fail. Increase cluster size to proceed."

Review Comment:
   nit: "the cluster has only has <numSlots> slots free" -> "the cluster has only <numSlots> slots free"



##########
core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Properties
+
+import scala.collection.mutable
+
+import org.apache.spark.{MapOutputTrackerMaster, SparkContext, SparkEnv, 
SparkException, SparkRuntimeException, Success}
+import org.apache.spark.internal.LogKeys
+import org.apache.spark.internal.config.{SPECULATION_ENABLED, 
STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED}
+import org.apache.spark.resource.ResourceProfile
+import org.apache.spark.storage.BlockManagerMaster
+import org.apache.spark.util.Clock
+import org.apache.spark.util.SystemClock
+
+/**
+ *  A [[DAGScheduler]] that runs all the stages in a job without waiting for 
its parents
+ *  complete. This combined with streaming shuffle between the stages, allows 
for low latency
+ *  execution of streaming queries in real-time mode.
+ */
+class ConcurrentStageDAGScheduler(
+    sc: SparkContext,
+    taskScheduler: TaskScheduler,
+    listenerBus: LiveListenerBus,
+    mapOutputTracker: MapOutputTrackerMaster,
+    blockManagerMaster: BlockManagerMaster,
+    env: SparkEnv,
+    clock: Clock = new SystemClock())
+  extends DAGScheduler(
+    sc, taskScheduler, listenerBus, mapOutputTracker, blockManagerMaster, env, 
clock) {
+
+  import ConcurrentStageDAGScheduler._
+
+  def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
+    this(
+      sc,
+      taskScheduler,
+      sc.listenerBus,
+      sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
+      sc.env.blockManager.master,
+      sc.env
+    )
+  }
+
+  def this(sc: SparkContext) = this(sc, sc.taskScheduler)
+
+  // This contains all the concurrent states that are yet to be scheduled 
across all the jobs.

Review Comment:
   Typo: "states" → "stages".



##########
common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java:
##########
@@ -792,6 +793,7 @@ public enum LogKeys implements LogKey {
   STREAMING_DATA_SOURCE_NAME,
   STREAMING_OFFSETS_END,
   STREAMING_OFFSETS_START,
+  STREAMING_QUERY_ID,

Review Comment:
   QUERY_ID already exists and is what 
StructuredStreamingIdAwareSchedulerLogging uses to log streaming query IDs. 
Adding STREAMING_QUERY_ID creates a parallel key for the same concept. Suggest 
dropping this addition and using LogKeys.QUERY_ID at all the call sites, or 
alternatively updating the call sites in StructuredStreamingIdAwareSchedulerLogging to use the new key, so that there is a single key for this concept.



##########
core/src/main/scala/org/apache/spark/scheduler/ConcurrentStageDAGScheduler.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Properties
+
+import scala.collection.mutable
+
+import org.apache.spark.{MapOutputTrackerMaster, SparkContext, SparkEnv, 
SparkException, SparkRuntimeException, Success}
+import org.apache.spark.internal.LogKeys
+import org.apache.spark.internal.config.{SPECULATION_ENABLED, 
STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED}
+import org.apache.spark.resource.ResourceProfile
+import org.apache.spark.storage.BlockManagerMaster
+import org.apache.spark.util.Clock
+import org.apache.spark.util.SystemClock
+
+/**
+ *  A [[DAGScheduler]] that runs all the stages in a job without waiting for 
its parents
+ *  complete. This combined with streaming shuffle between the stages, allows 
for low latency
+ *  execution of streaming queries in real-time mode.
+ */
+class ConcurrentStageDAGScheduler(
+    sc: SparkContext,
+    taskScheduler: TaskScheduler,
+    listenerBus: LiveListenerBus,
+    mapOutputTracker: MapOutputTrackerMaster,
+    blockManagerMaster: BlockManagerMaster,
+    env: SparkEnv,
+    clock: Clock = new SystemClock())
+  extends DAGScheduler(
+    sc, taskScheduler, listenerBus, mapOutputTracker, blockManagerMaster, env, 
clock) {
+
+  import ConcurrentStageDAGScheduler._
+
+  def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
+    this(
+      sc,
+      taskScheduler,
+      sc.listenerBus,
+      sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
+      sc.env.blockManager.master,
+      sc.env
+    )
+  }
+
+  def this(sc: SparkContext) = this(sc, sc.taskScheduler)
+
+  // This contains all the concurrent states that are yet to be scheduled 
across all the jobs.
+  private[spark] val concurrentStages = new mutable.HashSet[Stage]
+
+  private[scheduler] case class DependentStageInfo(
+    parents: mutable.HashSet[Stage] = mutable.HashSet.empty,
+    delayedTaskCompletionEvents: mutable.ListBuffer[CompletionEvent] = 
mutable.ListBuffer.empty)
+
+  // This map holds parents of concurrently scheduled stages. When tasks for 
such a stage complete,
+  // and if any of the parents are still running, we delay processing of such 
events until parent
+  // stages are complete. We save these events in this map until then.
+  private[spark] val dependentStageMap = new mutable.HashMap[Stage, 
DependentStageInfo]
+
+  private def totalNumCoreForStage(stage: Stage): Int = {
+    val numTask = stage match {
+      case r: ResultStage => r.partitions.length
+      case m: ShuffleMapStage => m.numPartitions
+    }
+    val resourceProfile = 
sc.resourceProfileManager.resourceProfileFromId(stage.resourceProfileId)
+    val taskCpus = 
ResourceProfile.getTaskCpusOrDefaultForProfile(resourceProfile, sc.conf)
+    taskCpus * numTask
+  }
+
+  /**
+   * Hook invoked after the final stage of a job is created.
+   *
+   * If the job's properties carry a streaming batch id and enable concurrent stages, the final
+   * stage and every ancestor reachable from it (that is not already registered) are added to
+   * `concurrentStages`, so that [[submitStage]] can submit them without waiting for their
+   * parents to finish. Otherwise the base-class implementation is used.
+   *
+   * Unless `STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED` is set, the cores needed by the newly
+   * marked stages are compared against the slots not used by currently running stages. Stages
+   * are registered only after this check passes, so a rejected job does not leave stale entries
+   * in `concurrentStages` that a later job could pick up.
+   *
+   * @param finalStage the final stage of the job being submitted
+   * @param properties the job's properties (may be null)
+   * @throws SparkException if the job's properties also enable speculation
+   * @throws SparkRuntimeException with error class CONCURRENT_SCHEDULER_INSUFFICIENT_SLOT if
+   *                               there are not enough free slots for the marked stages
+   */
+  override def onFinalStageCreated(finalStage: Stage, properties: Properties): Unit = {
+    val queryBatchId = getStreamingBatchIdFromProperties(properties)
+
+    if (queryBatchId.nonEmpty && isConcurrentStagesEnabled(properties)) {
+      val batch = queryBatchId.get
+
+      if (properties.getProperty(SPECULATION_ENABLED.key) == "true") {
+        // Speculation is not supported with concurrent stages.
+        // NOTE(review): only the job's properties are consulted here, not sc.conf — confirm
+        // that a cluster-wide spark.speculation setting is meant to be ignored.
+        throw new SparkException(
+          "Speculative execution is not supported with concurrent stages " +
+          s"(streaming query: $queryBatchId). Please disable ${SPECULATION_ENABLED.key} config."
+        )
+      }
+
+      logInfo(log"Concurrent stages is enabled for [query ${MDC(LogKeys.STREAMING_QUERY_ID,
+        batch.queryId)} batch ${MDC(LogKeys.BATCH_ID, batch.batchId)}]")
+
+      // Collect the final stage and its not-yet-registered ancestors (depth-first, each stage
+      // once). Nothing is written to `concurrentStages` until the slot check below succeeds.
+      val stagesToMark = scala.collection.mutable.LinkedHashSet.empty[Stage]
+      def collect(stage: Stage): Unit = {
+        if (!concurrentStages.contains(stage) && stagesToMark.add(stage)) {
+          stage.parents.foreach(collect)
+        }
+      }
+      collect(finalStage)
+      // Sum through an iterator: mapping the Set itself would de-duplicate equal core counts.
+      val totalCoresNeeded = stagesToMark.iterator.map(totalNumCoreForStage(_)).sum
+
+      if (!sc.conf.get(STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED)) {
+        try {
+          val totalSlots = sc.schedulerBackend.defaultParallelism()
+          val coresInUse = runningStages.iterator.map(totalNumCoreForStage(_)).sum
+          if (totalSlots - coresInUse < totalCoresNeeded) {
+            throw new SparkRuntimeException(
+              errorClass = "CONCURRENT_SCHEDULER_INSUFFICIENT_SLOT",
+              messageParameters = Map(
+                "numSlots" -> (totalSlots - coresInUse).toString,
+                "numTasks" -> totalCoresNeeded.toString))
+          }
+        } catch {
+          case e: UnsupportedOperationException =>
+            // Some scheduler backends cannot report their parallelism; skip the check then.
+            logWarning(log"${MDC(LogKeys.ERROR, e)}. Skipping slot check for RTM.")
+        }
+      }
+
+      // Slot check passed (or was skipped): register the collected stages as concurrent.
+      stagesToMark.foreach { stage =>
+        logInfo(log"Marking stage '${MDC(LogKeys.STAGE, stage)}' concurrent for [query ${MDC(
+          LogKeys.STREAMING_QUERY_ID, batch.queryId)} batch ${MDC(
+          LogKeys.BATCH_ID, batch.batchId)}]")
+        concurrentStages += stage
+      }
+    } else {
+      super.onFinalStageCreated(finalStage, properties)
+    }
+  }
+
+  /**
+   * Submits `stage` through the base scheduler, then eagerly submits every waiting stage that
+   * is registered in `concurrentStages`, without waiting for its parents to finish.
+   *
+   * For each eagerly submitted stage, the parents that are still running are recorded in
+   * `dependentStageMap`, so its successful task completion events are held back until those
+   * parents finish (see `handleTaskCompletion` / `markStageAsFinished`).
+   */
+  override def submitStage(stage: Stage): Unit = {
+    super.submitStage(stage)
+
+    if (!waitingStages.contains(stage) && concurrentStages.contains(stage)) {
+      // The current stage is not registered in waitingStages, which means it has no missing
+      // parents and was submitted directly. Remove it from concurrentStages since it is
+      // already running.
+      assert(runningStages.contains(stage), "stage should be running if not in waitingStages")
+      logInfo(log"Removing stage ${MDC(LogKeys.STAGE, stage)} from concurrentStages")
+      concurrentStages -= stage
+    }
+
+    // Removes `target` and, transitively, all of its ancestors from concurrentStages. Every
+    // ancestor is expected to have been submitted already, i.e. none may still be waiting.
+    def removeFromConcurrentStages(target: Stage): Unit = {
+      if (concurrentStages.contains(target)) {
+        logInfo(log"Removing stage ${MDC(LogKeys.STAGE, target)} from concurrentStages")
+        concurrentStages -= target
+      }
+      target.parents.foreach { parent =>
+        assert(!waitingStages.contains(parent), "Parent stage should not still be waiting")
+        removeFromConcurrentStages(parent)
+      }
+    }
+
+    // Find the stages that should be submitted concurrently with this stage. `intersect`
+    // yields a snapshot, so the loop may safely mutate both sets.
+    waitingStages.intersect(concurrentStages).foreach { candidate =>
+      // Re-check against the live sets: submitting an earlier candidate may already have
+      // submitted this one (e.g. through a recursive submitStage call), in which case it is
+      // no longer waiting and must not be submitted twice.
+      if (waitingStages.contains(candidate) && concurrentStages.contains(candidate)) {
+        logInfo(log"Submitting stage concurrently: ${MDC(LogKeys.STAGE, candidate)}")
+        // Don't submit this stage concurrently for subsequent attempts.
+        concurrentStages -= candidate
+        candidate.parents.foreach { parent =>
+          if (isRunningStage(parent)) {
+            logInfo(log"Updating dependent map for stage ${MDC(LogKeys.STAGE, candidate)} " +
+              log"with parent ${MDC(LogKeys.PARENT_STAGE, parent)}")
+            dependentStageMap.getOrElseUpdate(candidate, DependentStageInfo()).parents += parent
+          }
+        }
+        removeFromConcurrentStages(candidate)
+        submitConcurrentStage(candidate)
+      }
+    }
+  }
+
+  /**
+   * Holds back a successful task completion event while its stage still has running parent
+   * stages, i.e. while the stage is a key in `dependentStageMap`. Held-back events are stored
+   * in that entry's `delayedTaskCompletionEvents`; `checkDependentStageTasks` re-posts them once
+   * all the parents are done. Every other event goes straight to the base implementation.
+   */
+  override private[scheduler] def handleTaskCompletion(event: CompletionEvent): Unit = {
+    val heldBy = getStage(event.task.stageId)
+      .filter(_ => event.reason == Success)
+      .flatMap(owner => dependentStageMap.get(owner).map(info => (owner, info)))
+
+    heldBy match {
+      case None =>
+        // Not a dependent stage (or not a successful completion): handle as usual.
+        super.handleTaskCompletion(event)
+      case Some((owner, info)) =>
+        logInfo(log"Delaying completion event for task ${MDC(LogKeys.TASK_ID,
+          event.taskInfo.taskId)} in stage ${MDC(LogKeys.STAGE, owner)}. Active parent(s): ${
+          MDC(LogKeys.PARENT_STAGES, info.parents.mkString(", "))}")
+        info.delayedTaskCompletionEvents += event
+    }
+  }
+
+  /**
+   * Marks `stage` finished via the base implementation, then detaches it from every entry of
+   * `dependentStageMap` that lists it as a parent. Each such dependent is passed to
+   * `checkDependentStageTasks`, which re-posts its delayed completion events once no parents
+   * remain.
+   *
+   * NOTE(review): this only cleans up entries where `stage` is a parent. An entry keyed by
+   * `stage` itself (a dependent stage that fails or is aborted) is not removed here — confirm
+   * whether it can leak.
+   */
+  override def markStageAsFinished(
+      stage: Stage,
+      errorMessage: Option[String] = None,
+      willRetry: Boolean = false): Unit = {
+    super.markStageAsFinished(stage, errorMessage, willRetry)
+
+    val finishedCleanly = errorMessage.isEmpty
+    // `filter` builds a separate map, so the loop below can mutate dependentStageMap freely.
+    val waitingOnThis = dependentStageMap
+      .filter { case (_, info) => info.parents.contains(stage) }
+      .keys
+
+    for (dependent <- waitingOnThis) {
+      if (finishedCleanly) {
+        assert(
+          isRunningStage(dependent),
+          s"Parent stages $stage's dependent stage $dependent should be running")
+      }
+      logInfo(log"Removing parent stage ${MDC(LogKeys.PARENT_STAGE, stage)} from dependent map " +
+        log"for stage ${MDC(LogKeys.STAGE, dependent)}")
+      dependentStageMap(dependent).parents -= stage
+      checkDependentStageTasks(dependent)
+    }
+  }
+
+  /**
+   * If `stage` has no remaining running parents, removes it from `dependentStageMap` and posts
+   * every task completion event that was held back for it, via `postSchedulerEvent`, in the
+   * order stored in `delayedTaskCompletionEvents`. Does nothing while parents remain.
+   *
+   * @throws RuntimeException if `stage` has no entry in `dependentStageMap`
+   */
+  private def checkDependentStageTasks(stage: Stage): Unit = {
+    val info = dependentStageMap.get(stage) match {
+      case Some(found) => found
+      case None => throw new RuntimeException(s"Stage $stage is not in dependentStageMap")
+    }
+
+    if (info.parents.isEmpty) {
+      val held = info.delayedTaskCompletionEvents
+      logInfo(log"All the parents are done for ${MDC(LogKeys.STAGE, stage)}. Removing it from " +
+        log"the map. It has ${MDC(LogKeys.NUM_EVENTS, held.size.toLong)} " +
+        log"task completion events")
+      dependentStageMap -= stage
+      for (pending <- held) {
+        logInfo(log"Posting delayed task ${MDC(LogKeys.TASK_ID, pending.taskInfo.taskId)} " +
+          log"completion event for stage ${MDC(LogKeys.STAGE, stage)}")
+        postSchedulerEvent(pending)
+      }
+    }
+  }
+}
+
+object ConcurrentStageDAGScheduler {
+
+  /** Job property key that opts a streaming job into concurrent stage submission. */
+  val CONCURRENT_STAGES_ENABLED_PROPERTY: String = "streaming.concurrent.stages.enabled"
+
+  /**
+   * Returns true only when `properties` is non-null and sets
+   * [[CONCURRENT_STAGES_ENABLED_PROPERTY]] to exactly "true" (case-sensitive).
+   */
+  def isConcurrentStagesEnabled(properties: Properties): Boolean =
+    Option(properties)
+      .flatMap(props => Option(props.getProperty(CONCURRENT_STAGES_ENABLED_PROPERTY)))
+      .contains("true")
+
+  /**
+   * Extracts the [[StreamingBatchId]] from the given properties if all three 
of the streaming
+   * query id, run id and batch id are present.
+   */
+  def getStreamingBatchIdFromProperties(properties: Properties): 
Option[StreamingBatchId] = {
+    if (properties == null) {
+      return None
+    }
+
+    val queryId = Option(properties.getProperty("sql.streaming.queryId"))

Review Comment:
   These property keys are already defined as constants in the same package:
StructuredStreamingIdAwareSchedulerLogging.QUERY_ID_KEY and BATCH_ID_KEY.
Suggest reusing them (and adding a RUN_ID_KEY there) so that all consumers of
streaming-job properties go through a single source.
   
   Also: StreamingBatchId.runId is set here but never read anywhere — if it's 
not intended to be consumed, drop the field from the case class; if it is, the 
consuming code is missing.



##########
core/src/main/scala/org/apache/spark/internal/config/package.scala:
##########
@@ -2396,6 +2396,26 @@ package object config {
       .booleanConf
       .createWithDefault(true)
 
+  private[spark] val STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED =

Review Comment:
   Normally we would name the flag with _ENABLED instead of _DISABLED, to avoid
a double negative.



##########
core/src/test/scala/org/apache/spark/scheduler/ConcurrentStageDAGSchedulerSuite.scala:
##########
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Properties
+
+import org.apache.spark.HashPartitioner
+import org.apache.spark.ShuffleDependency
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkContext
+import org.apache.spark.internal.config.SPECULATION_ENABLED
+import 
org.apache.spark.internal.config.STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED
+
+class ConcurrentStageDAGSchedulerSuite extends DAGSchedulerSuiteBase {
+
+  // The unit-test SparkContext only offers two local slots (local[2]), while the concurrent
+  // pipelines exercised below frequently need more. Turning the slot check off keeps these
+  // tests independent of executor capacity.
+  override def conf: SparkConf = {
+    val baseConf = super.conf
+    baseConf.set(STREAMING_REALTIME_MODE_SLOTS_CHECK_DISABLED, true)
+  }
+
+  /**
+   * The [[ConcurrentStageDAGScheduler]] under test, built on this suite's `taskScheduler`,
+   * `mapOutputTracker` and `blockManagerMaster`, with the [[TestDAGScheduler]] trait mixed in.
+   */
+  class TestConcurrentStageDAGScheduler(sc: SparkContext)
+    extends ConcurrentStageDAGScheduler(
+      sc,
+      taskScheduler,
+      sc.listenerBus,
+      mapOutputTracker,
+      blockManagerMaster,
+      sc.env)
+    with TestDAGScheduler
+
+  /** Overrides the base suite's factory so the scheduler under test is the concurrent one. */
+  override def createInitialScheduler(sc: SparkContext): DAGScheduler =
+    new TestConcurrentStageDAGScheduler(sc)
+
+  /**
+   * A [[JobListener]] that ignores task results and remembers the exception passed to
+   * `jobFailed`, so a test can assert that the job failed and inspect why.
+   */
+  private class TestJobListener extends JobListener {
+    private var recordedFailure: Option[Exception] = None
+
+    override def taskSucceeded(index: Int, result: Any): Unit = ()
+
+    override def jobFailed(exception: Exception): Unit = {
+      recordedFailure = Some(exception)
+    }
+
+    /** Returns the recorded failure; fails the calling test if the job never failed. */
+    def expectFailure(): Exception = {
+      assert(recordedFailure.isDefined, "Job was expected to fail with an exception, but didn't")
+      recordedFailure.get
+    }
+  }
+
+
+  /**
+   * Shared job properties that carry streaming query/run/batch ids and enable concurrent
+   * stages. The values are stored as the defaults of a wrapper whose `setProperty` throws,
+   * so a test cannot accidentally modify this shared fixture through `setProperty`.
+   */
+  private val testProperties: Properties = {
+    val defaults = new Properties()
+    Seq(
+      "sql.streaming.queryId" -> "test_query_id",
+      "sql.streaming.runId" -> "test_run_id",
+      "streaming.sql.batchId" -> "5",
+      ConcurrentStageDAGScheduler.CONCURRENT_STAGES_ENABLED_PROPERTY -> "true"
+    ).foreach { case (key, value) => defaults.setProperty(key, value) }
+
+    new Properties(defaults) {
+      override def setProperty(key: String, value: String): AnyRef =
+        throw new UnsupportedOperationException("Default properties are read-only.")
+    }
+  }
+
+  test("Simple job with two concurrent stages") {
+    // A single map stage feeding a result stage. With concurrent stages enabled, both stages
+    // should be running right after submission.
+    //
+    // Shape: [stage_0, map stage, parent] <--- [stage_1, result stage]
+    val mapRdd = new MyRDD(sc, 1, Nil) // stage_0
+    val dep = new ShuffleDependency(mapRdd, new HashPartitioner(1))
+    val resultRdd = new MyRDD(sc, 3, List(dep)) // stage_1
+
+    submit(resultRdd, Array(0), properties = testProperties)
+
+    // Nothing is left waiting: both stages were submitted and are running.
+    assert(scheduler.waitingStages.isEmpty)
+    assert(scheduler.runningStages.map(_.id) === Set(0, 1))
+
+    // Concurrent-scheduler bookkeeping.
+    val concurrent = scheduler.asInstanceOf[TestConcurrentStageDAGScheduler]
+    assert(concurrent.concurrentStages.isEmpty) // everything has already been submitted
+
+    val dependents = concurrent.dependentStageMap
+    assert(dependents.keys.map(_.id) === Set(1)) // the result stage is the only dependent...
+    assert(dependents.values.flatMap(_.parents.map(_.id)) === Seq(0)) // ...of the map stage
+    assert(dependents.values.flatMap(_.delayedTaskCompletionEvents).isEmpty) // none held yet
+
+    // Finish the result stage's task first. Its completion is held back, so the stage keeps
+    // running until its parent (the map stage) is done.
+    completeNextResultStageWithSuccess(1, 0)
+    assert(scheduler.runningStages.map(_.id) === Set(0, 1))
+    assert(dependents.values.flatMap(_.delayedTaskCompletionEvents).size === 1)
+
+    // Finishing the map stage releases the held-back event and completes the result stage.
+    completeShuffleMapStageSuccessfully(0, 0, 1)
+    assert(scheduler.runningStages.map(_.id) === Set())
+    assert(dependents.isEmpty)
+
+    assertDataStructuresEmpty()
+  }
+
+  test("Default scheduler using a simple job with concurrent stages disabled") {
+    // Counterpart of the previous test: without the concurrent-stages properties the stages
+    // must run one after the other.
+    //
+    // Shape: [stage_0, map stage, parent] <--- [stage_1, result stage]
+    val mapRdd = new MyRDD(sc, 1, Nil) // stage_0
+    val dep = new ShuffleDependency(mapRdd, new HashPartitioner(1))
+    val resultRdd = new MyRDD(sc, 3, List(dep)) // stage_1
+
+    submit(resultRdd, Array(0), properties = new Properties())
+
+    // Only the map stage runs; the result stage waits for it.
+    assert(scheduler.runningStages.map(_.id) === Set(0))
+    assert(scheduler.waitingStages.map(_.id) === Set(1))
+
+    // No concurrent-scheduler bookkeeping is created.
+    val concurrent = scheduler.asInstanceOf[TestConcurrentStageDAGScheduler]
+    assert(concurrent.concurrentStages.isEmpty)
+    assert(concurrent.dependentStageMap.isEmpty)
+
+    // Completing the map stage submits the result stage; nothing remains waiting.
+    completeShuffleMapStageSuccessfully(0, 0, 1)
+    assert(scheduler.runningStages.map(_.id) === Set(1))
+    assert(scheduler.waitingStages.map(_.id) === Set())
+
+    completeNextResultStageWithSuccess(1, 0)
+    assertDataStructuresEmpty()
+  }
+
+  test("Complex pipeline with many stages") {
+    // Run a complex pipeline with multiple stages with multiple branches. 
Such a pipeline not
+    // common, but useful to ensure scheduler works as expected.
+
+    // Shape:
+    //                  /<-------------------- stage_D
+    // stage_A <--- stage_B <--- stage_C <---\   ^
+    //        \         \<---------/          \  |
+    //         \ <----------------/            \ |
+    // stage_E <---------------------------  stage_F
+
+    // All of these should be running concurrently.
+
+    val rddA = new MyRDD(sc, 2, Nil).setName("rddA")
+    val shuffleDepA = new ShuffleDependency(rddA, new HashPartitioner(1))
+
+    val rddB = new MyRDD(sc, 1, List(shuffleDepA)).setName("rddB")
+    val shuffleDepB = new ShuffleDependency(rddB, new HashPartitioner(1))
+
+    val rddC = new MyRDD(sc, 1, List(shuffleDepA, shuffleDepB)).setName("rddC")
+    val shuffleDepC = new ShuffleDependency(rddC, new HashPartitioner(5))
+
+    val rddD = new MyRDD(sc, 4, List(shuffleDepB)).setName("rddD")
+    val shuffleDepD = new ShuffleDependency(rddD, new HashPartitioner(5))
+
+    val rddE = new MyRDD(sc, 3, Nil).setName("rddE")
+    val shuffleDepE = new ShuffleDependency(rddE, new HashPartitioner(5))
+
+    val rddF = new MyRDD(sc, 5, List(shuffleDepC, shuffleDepD, 
shuffleDepE)).setName("rddF")
+
+    submit(rddF, Array(0, 1, 2, 3, 4), properties = testProperties)
+
+    assert(scheduler.waitingStages.isEmpty) // All the stages are submitted.
+    assert(scheduler.runningStages.map(_.id) == Set(0, 1, 2, 3, 4, 5)) // All 
6 are running.
+
+    // Assign stage ids corresponding to the RDDs A, B, etc
+    def stageFor(rddName: String): Stage = 
scheduler.runningStages.find(_.rdd.name == rddName).get
+
+    val sA = stageFor("rddA")
+    val sB = stageFor("rddB")
+    val sC = stageFor("rddC")
+    val sD = stageFor("rddD")
+    val sE = stageFor("rddE")
+    val sF = stageFor("rddF")
+
+    // log stage id mapping for debugging:
+    for (name <- List("A", "B", "C", "D", "E", "F")) {
+      logInfo(s"Stage id for stage $name is ${stageFor("rdd" + name).id}")
+    }
+
+    // Verify concurrent scheduler specific state.
+    val concurrentScheduler = 
scheduler.asInstanceOf[TestConcurrentStageDAGScheduler]
+
+    assert(concurrentScheduler.concurrentStages.isEmpty) // All are already 
scheduled
+
+    val depStageMap = concurrentScheduler.dependentStageMap
+    assert(depStageMap.keys === Set(sB, sC, sD, sF)) // All non-root stages 
are keys.
+    assert(depStageMap.values.flatMap(_.parents).toSet === Set(
+      sA, sB, sC, sD, sE)) // All except the results stage.
+    assert(depStageMap.values.flatMap(_.delayedTaskCompletionEvents).isEmpty) 
// No completed tasks.
+
+    // Complete stages in order-of-order and verify the state.
+
+    // First complete C. Entry for C would be updated with the completed task.
+    assert(depStageMap(sC).delayedTaskCompletionEvents.size === 0)
+    completeShuffleMapStageSuccessfully(sC.id, 0, 5) // Complete stage C.
+    assert(depStageMap(sC).delayedTaskCompletionEvents.size === 1)
+
+    // All the 6 stages are still 'running' since C's completion events are 
delayed.
+    assert(scheduler.runningStages.map(_.id) == Set(0, 1, 2, 3, 4, 5))
+
+    // Now complete stage D. This is similar to completing C. The tasks are 
enqueued.
+    assert(depStageMap(sD).delayedTaskCompletionEvents.size === 0)
+    completeShuffleMapStageSuccessfully(sD.id, 0, 5) // Complete stage D
+    assert(depStageMap(sD).delayedTaskCompletionEvents.size === 4) // 4 tasks 
in stage C.

Review Comment:
   "stage C" -> "stage D"



##########
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala:
##########
@@ -6107,7 +6123,7 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
     }
   }
 
-  private def assertDataStructuresEmpty(): Unit = {
+  protected def assertDataStructuresEmpty(): Unit = {

Review Comment:
   The new suite inherits all the tests against ConcurrentStageDAGScheduler,
but this helper doesn't validate the new state. A leak in concurrentStages or
dependentStageMap (see the slot-check and aborted-stage concerns above)
wouldn't be caught by any inherited test. Consider overriding it in
ConcurrentStageDAGSchedulerSuite (or adding a
`protected def extraEmptyChecks: Unit = ()` hook here) so subclasses can extend
the empty-state assertions.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to