beliefer commented on a change in pull request #29747:
URL: https://github.com/apache/spark/pull/29747#discussion_r488407757



##########
File path: 
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerTestHelper.scala
##########
@@ -0,0 +1,614 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Properties
+import java.util.concurrent.atomic.AtomicBoolean
+
+import scala.annotation.meta.param
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
+import scala.util.control.NonFatal
+
+import org.mockito.Mockito.spy
+import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
+
+import org.apache.spark._
+import org.apache.spark.broadcast.BroadcastManager
+import org.apache.spark.executor.ExecutorMetrics
+import org.apache.spark.rdd.{DeterministicLevel, RDD}
+import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
+import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
+import org.apache.spark.util.{AccumulatorV2, CallSite}
+
+/**
+ * Event loop for tests: every posted event is handled immediately by calling `onReceive`
+ * from `post`, instead of being processed asynchronously.
+ */
+class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler)
+  extends DAGSchedulerEventProcessLoop(dagScheduler) {
+
+  /** Hands `event` straight to `onReceive`; non-fatal failures are routed to `onError`. */
+  override def post(event: DAGSchedulerEvent): Unit = {
+    try onReceive(event) catch {
+      case NonFatal(cause) => onError(cause)
+    }
+  }
+
+  /** Logs the error, stops the scheduler and then rethrows the error to the caller. */
+  override def onError(e: Throwable): Unit = {
+    logError("Error in DAGSchedulerEventLoop: ", e)
+    dagScheduler.stop()
+    throw e
+  }
+}
+
+/** A [[MyRDD]] whose `compute` yields an empty iterator instead of throwing. */
+class MyCheckpointRDD(
+    sc: SparkContext,
+    numPartitions: Int,
+    dependencies: List[Dependency[_]],
+    locations: Seq[Seq[String]] = Nil,
+    @(transient @param) tracker: MapOutputTrackerMaster = null,
+    indeterminate: Boolean = false)
+  extends MyRDD(sc, numPartitions, dependencies, locations, tracker, indeterminate) {
+
+  // Returning an empty iterator (rather than throwing) allows doCheckpoint() on this RDD.
+  override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = {
+    Iterator.empty
+  }
+}
+
+/**
+ * An RDD for passing to DAGScheduler. It uses exactly the dependencies and preferred locations
+ * (if any) that are passed in. It is deliberately not executable, so we can test that
+ * DAGScheduler does not try to execute RDDs locally.
+ *
+ * Optionally, one can pass in a list of locations to use as preferred locations for each task,
+ * and a MapOutputTrackerMaster to enable reduce task locality. We pass the tracker separately
+ * because, in this test suite, it won't be the same as sc.env.mapOutputTracker.
+ */
+class MyRDD(
+    sc: SparkContext,
+    numPartitions: Int,
+    dependencies: List[Dependency[_]],
+    locations: Seq[Seq[String]] = Nil,
+    @(transient @param) tracker: MapOutputTrackerMaster = null,
+    indeterminate: Boolean = false)
+  extends RDD[(Int, Int)](sc, dependencies) with Serializable {
+
+  // This RDD must never be computed; reaching this method is an error.
+  override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = {
+    throw new RuntimeException("should not be reached")
+  }
+
+  // One anonymous Partition per index in [0, numPartitions).
+  override def getPartitions: Array[Partition] = {
+    Array.tabulate[Partition](numPartitions) { i =>
+      new Partition {
+        override def index: Int = i
+      }
+    }
+  }
+
+  override protected def getOutputDeterministicLevel = {
+    if (!indeterminate) super.getOutputDeterministicLevel else DeterministicLevel.INDETERMINATE
+  }
+
+  override def getPreferredLocations(partition: Partition): Seq[String] = {
+    val idx = partition.index
+    if (locations.isDefinedAt(idx)) {
+      // Explicitly supplied locations take priority.
+      locations(idx)
+    } else {
+      dependencies match {
+        // If we have only one shuffle dependency, use the same code path as ShuffledRDD
+        // for locality.
+        case Seq(shuffleDep: ShuffleDependency[_, _, _]) if tracker != null =>
+          tracker.getPreferredLocationsForShuffle(shuffleDep, idx)
+        case _ =>
+          Nil
+      }
+    }
+  }
+
+  override def toString: String = s"DAGSchedulerSuiteRDD $id"
+}
+
+/** Plain exception type with no message or cause, declared for the scheduler tests. */
+class DAGSchedulerSuiteDummyException extends Exception
+
+class DAGSchedulerTestHelper extends SparkFunSuite with TempLocalSparkContext 
with TimeLimits {
+
+  import DAGSchedulerTestHelper._
+
+  // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like 
ScalaTest 2.2.x
+  implicit val defaultSignaler: Signaler = ThreadSignaler
+
+  // True from `beforeEach` until the first access to `sc` in a test; that first access runs
+  // `init` and flips the flag back to false.
+  private var firstInit: Boolean = _
+  /** Set of TaskSets the DAGScheduler has requested executed. */
+  val taskSets = scala.collection.mutable.Buffer[TaskSet]()
+
+  /** Stages for which the DAGScheduler has called TaskScheduler.cancelTasks(). */
+  val cancelledStages = new HashSet[Int]()
+
+  /** Tasks reported through TaskScheduler.notifyPartitionCompletion(). */
+  val tasksMarkedAsCompleted = new ArrayBuffer[Task[_]]()
+
+  /**
+   * Stub TaskScheduler that is passed to the DAGScheduler built in `init`. It runs nothing:
+   * submitted TaskSets, cancelled stage ids and partitions reported as completed are only
+   * recorded in `taskSets`, `cancelledStages` and `tasksMarkedAsCompleted`.
+   */
+  val taskScheduler = new TaskScheduler() {
+    // Fixed configuration values.
+    override def schedulingMode: SchedulingMode = SchedulingMode.FIFO
+    override def rootPool: Pool = new Pool("", schedulingMode, 0, 0)
+    override def defaultParallelism(): Int = 2
+    override def applicationAttemptId(): Option[String] = None
+
+    // Lifecycle hooks and notifications that need no behaviour here.
+    override def start(): Unit = {}
+    override def stop(): Unit = {}
+    override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {}
+    override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
+    override def workerRemoved(workerId: String, host: String, message: String): Unit = {}
+    override def executorDecommission(
+      executorId: String,
+      decommissionInfo: ExecutorDecommissionInfo): Unit = {}
+    override def getExecutorDecommissionState(
+      executorId: String): Option[ExecutorDecommissionState] = None
+    override def killTaskAttempt(
+      taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
+    override def killAllTaskAttempts(
+      stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
+    override def executorHeartbeatReceived(
+        execId: String,
+        accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])],
+        blockManagerId: BlockManagerId,
+        executorUpdates: Map[(Int, Int), ExecutorMetrics]): Boolean = true
+
+    // Calls whose arguments are recorded for later assertions.
+    override def submitTasks(taskSet: TaskSet): Unit = {
+      // normally done by TaskSetManager
+      for (task <- taskSet.tasks) {
+        task.epoch = mapOutputTracker.getEpoch
+      }
+      taskSets += taskSet
+    }
+    override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = {
+      cancelledStages += stageId
+    }
+    override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
+      // Only the most recently submitted TaskSet of the stage is considered.
+      val latestForStage = taskSets.filter(_.stageId == stageId).lastOption
+      latestForStage.foreach { ts =>
+        val matching = ts.tasks.filter(_.partitionId == partitionId)
+        assert(matching.length == 1)
+        tasksMarkedAsCompleted += matching.head
+      }
+    }
+  }
+
+  /**
+   * Listener that records stage and task events for tests to verify. Every getter first waits
+   * until the listener bus has no events left to process and then returns an immutable copy of
+   * the recorded data, which prevents odd results caused by race conditions.
+   */
+  class EventInfoRecordingListener extends SparkListener {
+    private val _submittedStageInfos = new HashSet[StageInfo]
+    private val _successfulStages = new HashSet[Int]
+    private val _failedStages = new ArrayBuffer[Int]
+    private val _stageByOrderOfExecution = new ArrayBuffer[Int]
+    private val _endedTasks = new HashSet[Long]
+
+    override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
+      _submittedStageInfos += stageSubmitted.stageInfo
+    }
+
+    override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
+      val info = stageCompleted.stageInfo
+      _stageByOrderOfExecution += info.stageId
+      // A stage without a failure reason counts as successful.
+      if (info.failureReason.isEmpty) {
+        _successfulStages += info.stageId
+      } else {
+        _failedStages += info.stageId
+      }
+    }
+
+    override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+      _endedTasks += taskEnd.taskInfo.taskId
+    }
+
+    def submittedStageInfos: Set[StageInfo] = snapshotAfterDrain(_submittedStageInfos.toSet)
+
+    def successfulStages: Set[Int] = snapshotAfterDrain(_successfulStages.toSet)
+
+    def failedStages: List[Int] = snapshotAfterDrain(_failedStages.toList)
+
+    def stageByOrderOfExecution: List[Int] = snapshotAfterDrain(_stageByOrderOfExecution.toList)
+
+    def endedTasks: Set[Long] = snapshotAfterDrain(_endedTasks.toSet)
+
+    // Waits for the listener bus to drain, then evaluates `copy` (passed by name so the
+    // snapshot is taken only after the wait).
+    private def snapshotAfterDrain[T](copy: => T): T = {
+      sc.listenerBus.waitUntilEmpty()
+      copy
+    }
+  }
+
+  // Components under test. All of them start as null and are assigned by `init`, which runs on
+  // the first access to `sc` after `beforeEach`.
+  var sparkListener: EventInfoRecordingListener = null
+
+  var blockManagerMaster: BlockManagerMaster = null
+  var mapOutputTracker: MapOutputTrackerMaster = null
+  var broadcastManager: BroadcastManager = null
+  var securityMgr: SecurityManager = null
+  var scheduler: DAGScheduler = null
+  var dagEventProcessLoopTester: DAGSchedulerEventProcessLoop = null
+
+  /**
+   * Set of cache locations to return from our mock BlockManagerMaster.
+   * Keys are (rdd ID, partition ID). Anything not present will return an empty
+   * list of cache locations silently.
+   */
+  val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]
+
+  /** BlockManagerMaster stub whose getLocations is answered from `cacheLocations`. */
+  class MyBlockManagerMaster(conf: SparkConf) extends BlockManagerMaster(null, null, conf, true) {
+    override def getLocations(blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = {
+      blockIds.map { blockId =>
+        // Only RDD blocks with an entry in cacheLocations have locations; all others get Seq().
+        val cached = for {
+          rddBlock <- blockId.asRDDId
+          locs <- cacheLocations.get(rddBlock.rddId -> rddBlock.splitIndex)
+        } yield locs
+        cached.getOrElse(Seq())
+      }.toIndexedSeq
+    }
+
+    override def removeExecutor(execId: String): Unit = {
+      // don't need to propagate to the driver, which we don't have
+    }
+  }
+
+  /** The list of results that DAGScheduler has collected. */
+  val results = new HashMap[Int, Any]()
+  /** The exception passed to the most recent `jobFailed` call on `jobListener`, if any. */
+  var failure: Exception = _
+  /** Default JobListener: stores task results in `results` and a job failure in `failure`. */
+  val jobListener = new JobListener() {
+    override def taskSucceeded(index: Int, result: Any): Unit = {
+      results.put(index, result)
+    }
+    override def jobFailed(exception: Exception): Unit = {
+      failure = exception
+    }
+  }
+
+  /**
+   * A simple helper class for creating custom JobListeners. Each instance keeps its own
+   * `results` map and `failure` field.
+   */
+  class SimpleListener extends JobListener {
+    val results = new HashMap[Int, Any]
+    var failure: Exception = null
+
+    override def taskSucceeded(index: Int, result: Any): Unit = {
+      results(index) = result
+    }
+
+    override def jobFailed(exception: Exception): Unit = {
+      failure = exception
+    }
+  }
+
+  /** MapOutputTrackerMaster whose `sendTracker` does nothing. */
+  class MyMapOutputTrackerMaster(
+      conf: SparkConf,
+      broadcastManager: BroadcastManager)
+    extends MapOutputTrackerMaster(conf, broadcastManager, true) {
+
+    // no-op, just so we can stop this to avoid leaking threads
+    override def sendTracker(message: Any): Unit = {}
+  }
+
+  /** Runs the parent setup, then arms `firstInit` so the next `sc` access runs `init`. */
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    // Re-arm the lazy initialisation for the upcoming test.
+    this.firstInit = true
+  }
+
+  /**
+   * Returns the parent's SparkContext. The first call after `beforeEach` also runs `init` on
+   * that context before returning it.
+   */
+  override def sc: SparkContext = {
+    val context = super.sc
+    if (firstInit) {
+      init(context)
+      // Cleared only after init returns.
+      firstInit = false
+    }
+    context
+  }
+
+  /**
+   * Clears the recorded state and builds a fresh listener, trackers, DAGScheduler and event
+   * loop around the given context.
+   */
+  private def init(sc: SparkContext): Unit = {
+    // Forget everything recorded so far.
+    failure = null
+    taskSets.clear()
+    tasksMarkedAsCompleted.clear()
+    cancelledStages.clear()
+    cacheLocations.clear()
+    results.clear()
+
+    // Register a new recording listener on the context.
+    sparkListener = new EventInfoRecordingListener
+    sc.addSparkListener(sparkListener)
+
+    // Build the scheduler's collaborators; the tracker and block manager master are spies.
+    securityMgr = new SecurityManager(sc.getConf)
+    broadcastManager = new BroadcastManager(true, sc.getConf, securityMgr)
+    mapOutputTracker = spy(new MyMapOutputTrackerMaster(sc.getConf, broadcastManager))
+    blockManagerMaster = spy(new MyBlockManagerMaster(sc.getConf))
+
+    scheduler = new DAGScheduler(
+      sc,
+      taskScheduler,
+      sc.listenerBus,
+      mapOutputTracker,
+      blockManagerMaster,
+      sc.env)
+    dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler)
+  }
+
+  /**
+   * Stops the components created by `init` and then runs the parent cleanup.
+   *
+   * The components are created lazily on the first access to `sc`, so they are still null if
+   * no test has touched `sc` yet. Each one is therefore stopped only when it is non-null,
+   * rather than failing the teardown with a NullPointerException. `super.afterEach()` runs
+   * even if stopping a component throws.
+   */
+  override def afterEach(): Unit = {
+    try {
+      Option(scheduler).foreach(_.stop())
+      Option(dagEventProcessLoopTester).foreach(_.stop())
+      Option(mapOutputTracker).foreach(_.stop())
+      Option(broadcastManager).foreach(_.stop())
+    } finally {
+      super.afterEach()
+    }
+  }
+
+  /** Delegates to the parent implementation; this helper adds no suite-level cleanup. */
+  override def afterAll(): Unit = super.afterAll()
+
+  /**
+   * Type of RDD we use for testing. Note that we should never call the real 
RDD compute methods.
+   * This is a pair RDD type so it can always be used in ShuffleDependencies.
+   */
+  type PairOfIntsRDD = RDD[(Int, Int)]
+
+  /**
+   * Posts the supplied event to the test event loop (`dagEventProcessLoopTester`).
+   *
+   * `sc` is evaluated first so that the lazily created components, including the event loop
+   * itself, exist before the event is posted.
+   */
+  protected def runEvent(event: DAGSchedulerEvent): Unit = {
+    // Ensure the initialization of various components
+    sc
+    dagEventProcessLoopTester.post(event)
+  }
+
+  /**
+   * When we submit dummy Jobs, this is the compute function we supply. Except in a local test
+   * below, we do not expect this function to ever be executed; instead, we will return results
+   * directly through CompletionEvents.
+   *
+   * If it does run, it returns the first component of the iterator's first element.
+   */
+  private val jobComputeFunc = (context: TaskContext, it: Iterator[(_)]) => {
+    val head = it.next.asInstanceOf[Tuple2[_, _]]
+    head._1
+  }
+
+  /**
+   * Send the given CompletionEvent messages for the tasks in the TaskSet. The i-th entry of
+   * `taskEndInfos` is applied to the i-th task; there must not be more entries than tasks.
+   */
+  protected def complete(taskSet: TaskSet, taskEndInfos: Seq[(TaskEndReason, Any)]): Unit = {
+    assert(taskSet.tasks.size >= taskEndInfos.size)
+    // zip pairs each end-info with the task at the same position.
+    taskEndInfos.zip(taskSet.tasks).foreach { case ((reason, value), task) =>
+      runEvent(makeCompletionEvent(task, reason, value))
+    }
+  }
+
+  /**
+   * Same as `complete`, but every completion event also carries a newly created long
+   * accumulator with the given id and initial value 1.
+   */
+  protected def completeWithAccumulator(
+      accumId: Long,
+      taskSet: TaskSet,
+      results: Seq[(TaskEndReason, Any)]): Unit = {
+    assert(taskSet.tasks.size >= results.size)
+    results.zip(taskSet.tasks).foreach { case ((reason, value), task) =>
+      // A new accumulator instance is created for each event.
+      val accumUpdates = Seq(AccumulatorSuite.createLongAccum("", initValue = 1, id = accumId))
+      runEvent(makeCompletionEvent(task, reason, value, accumUpdates))
+    }
+  }
+
+  /**
+   * Submits a job to the scheduler and returns the job id. The id is taken from the
+   * scheduler's `nextJobId` counter before the JobSubmitted event is posted.
+   */
+  protected def submit(
+      rdd: RDD[_],
+      partitions: Array[Int],
+      func: (TaskContext, Iterator[_]) => _ = jobComputeFunc,
+      listener: JobListener = jobListener,
+      properties: Properties = null): Int = {
+    val id = scheduler.nextJobId.getAndIncrement()
+    val submission =
+      JobSubmitted(id, rdd, func, partitions, CallSite("", ""), listener, properties)
+    runEvent(submission)
+    id
+  }
+
+  /**
+   * Submits a map stage to the scheduler and returns the job id, which is taken from the
+   * scheduler's `nextJobId` counter.
+   */
+  protected def submitMapStage(
+      shuffleDep: ShuffleDependency[_, _, _],
+      listener: JobListener = jobListener): Int = {
+    val id = scheduler.nextJobId.getAndIncrement()
+    val submission = MapStageSubmitted(id, shuffleDep, CallSite("", ""), listener)
+    runEvent(submission)
+    id
+  }
+
+  /** Sends TaskSetFailed for `taskSet` with the given message and no exception. */
+  protected def failed(taskSet: TaskSet, message: String): Unit = {
+    val failure = TaskSetFailed(taskSet, message, None)
+    runEvent(failure)
+  }
+
+  /** Sends JobCancelled for `jobId`, without a reason, to the DAG scheduler. */
+  protected def cancel(jobId: Int): Unit = {
+    val cancellation = JobCancelled(jobId, None)
+    runEvent(cancellation)
+  }
+
+  /**
+   * Completes tasks of `taskSet` with the given end infos, then asserts that the helper's
+   * `results` map equals `expected`.
+   */
+  protected def completeAndCheckAnswer(
+      taskSet: TaskSet,
+      taskEndInfos: Seq[(TaskEndReason, Any)],
+      expected: Map[Int, Any]): Unit = {
+    complete(taskSet, taskEndInfos)
+    val collected = this.results
+    assert(collected === expected)
+  }
+
+  // Helper for task-failure tests: asserts that `stageAttempt` belongs to the expected stage
+  // and carries the expected stage attempt number.
+  private def checkStageId(stageId: Int, attempt: Int, stageAttempt: TaskSet): Unit = {
+    assert(stageAttempt.stageId === stageId)
+    assert(stageAttempt.stageAttemptId === attempt)
+  }
+
+  // Helper functions to extract commonly used code in Fetch Failure test cases
+  protected def setupStageAbortTest(sc: SparkContext): Unit = {

Review comment:
       OK




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to