beliefer commented on a change in pull request #29747:
URL: https://github.com/apache/spark/pull/29747#discussion_r488407225
##########
File path: core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerTestHelper.scala
##########
@@ -0,0 +1,614 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Properties
+import java.util.concurrent.atomic.AtomicBoolean
+
+import scala.annotation.meta.param
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
+import scala.util.control.NonFatal
+
+import org.mockito.Mockito.spy
+import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
+
+import org.apache.spark._
+import org.apache.spark.broadcast.BroadcastManager
+import org.apache.spark.executor.ExecutorMetrics
+import org.apache.spark.rdd.{DeterministicLevel, RDD}
+import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
+import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
+import org.apache.spark.util.{AccumulatorV2, CallSite}
+
+class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler)
+  extends DAGSchedulerEventProcessLoop(dagScheduler) {
+
+  override def post(event: DAGSchedulerEvent): Unit = {
+    try {
+      // Forward event to `onReceive` directly to avoid processing event asynchronously.
+      onReceive(event)
+    } catch {
+      case NonFatal(e) => onError(e)
+    }
+  }
+
+  override def onError(e: Throwable): Unit = {
+    logError("Error in DAGSchedulerEventLoop: ", e)
+    dagScheduler.stop()
+    throw e
+  }
+
+}
+
+class MyCheckpointRDD(
+    sc: SparkContext,
+    numPartitions: Int,
+    dependencies: List[Dependency[_]],
+    locations: Seq[Seq[String]] = Nil,
+    @(transient @param) tracker: MapOutputTrackerMaster = null,
+    indeterminate: Boolean = false)
+  extends MyRDD(sc, numPartitions, dependencies, locations, tracker, indeterminate) {
+
+  // Allow doCheckpoint() on this RDD.
+  override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+    Iterator.empty
+}
+
+/**
+ * An RDD for passing to DAGScheduler. These RDDs will use the dependencies and
+ * preferredLocations (if any) that are passed to them. They are deliberately not executable
+ * so we can test that DAGScheduler does not try to execute RDDs locally.
+ *
+ * Optionally, one can pass in a list of locations to use as preferred locations for each task,
+ * and a MapOutputTrackerMaster to enable reduce task locality. We pass the tracker separately
+ * because, in this test suite, it won't be the same as sc.env.mapOutputTracker.
+ */
+class MyRDD(
+    sc: SparkContext,
+    numPartitions: Int,
+    dependencies: List[Dependency[_]],
+    locations: Seq[Seq[String]] = Nil,
+    @(transient @param) tracker: MapOutputTrackerMaster = null,
+    indeterminate: Boolean = false)
+  extends RDD[(Int, Int)](sc, dependencies) with Serializable {
+
+  override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+    throw new RuntimeException("should not be reached")
+
+  override def getPartitions: Array[Partition] = (0 until numPartitions).map(i => new Partition {
+    override def index: Int = i
+  }).toArray
+
+  override protected def getOutputDeterministicLevel = {
+    if (indeterminate) DeterministicLevel.INDETERMINATE else super.getOutputDeterministicLevel
+  }
+
+  override def getPreferredLocations(partition: Partition): Seq[String] = {
+    if (locations.isDefinedAt(partition.index)) {
+      locations(partition.index)
+    } else if (tracker != null && dependencies.size == 1 &&
+        dependencies(0).isInstanceOf[ShuffleDependency[_, _, _]]) {
+      // If we have only one shuffle dependency, use the same code path as ShuffledRDD for locality
+      val dep = dependencies(0).asInstanceOf[ShuffleDependency[_, _, _]]
+      tracker.getPreferredLocationsForShuffle(dep, partition.index)
+    } else {
+      Nil
+    }
+  }
+
+  override def toString: String = "DAGSchedulerSuiteRDD " + id
+}
+
+class DAGSchedulerSuiteDummyException extends Exception
+
+class DAGSchedulerTestHelper extends SparkFunSuite with TempLocalSparkContext with TimeLimits {
+
+  import DAGSchedulerTestHelper._
+
+  // Necessary to make ScalaTest 3.x interrupt a thread on the JVM like ScalaTest 2.2.x
+  implicit val defaultSignaler: Signaler = ThreadSignaler
+
+  private var firstInit: Boolean = _
+  /** Set of TaskSets the DAGScheduler has requested executed. */
+  val taskSets = scala.collection.mutable.Buffer[TaskSet]()
+
+  /** Stages for which the DAGScheduler has called TaskScheduler.cancelTasks(). */
+  val cancelledStages = new HashSet[Int]()
+
+  val tasksMarkedAsCompleted = new ArrayBuffer[Task[_]]()
+
+  val taskScheduler = new TaskScheduler() {
+    override def schedulingMode: SchedulingMode = SchedulingMode.FIFO
+    override def rootPool: Pool = new Pool("", schedulingMode, 0, 0)
+    override def start() = {}
+    override def stop() = {}
+    override def executorHeartbeatReceived(
+        execId: String,
+        accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])],
+        blockManagerId: BlockManagerId,
+        executorUpdates: Map[(Int, Int), ExecutorMetrics]): Boolean = true
+    override def submitTasks(taskSet: TaskSet) = {
+      // normally done by TaskSetManager
+      taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
+      taskSets += taskSet
+    }
+    override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = {
+      cancelledStages += stageId
+    }
+    override def killTaskAttempt(
+      taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
+    override def killAllTaskAttempts(
+      stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
+    override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
+      taskSets.filter(_.stageId == stageId).lastOption.foreach { ts =>
+        val tasks = ts.tasks.filter(_.partitionId == partitionId)
+        assert(tasks.length == 1)
+        tasksMarkedAsCompleted += tasks.head
+      }
+    }
+    override def setDAGScheduler(dagScheduler: DAGScheduler) = {}
+    override def defaultParallelism() = 2
+    override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
+    override def workerRemoved(workerId: String, host: String, message: String): Unit = {}
+    override def applicationAttemptId(): Option[String] = None
+    override def executorDecommission(
+      executorId: String,
+      decommissionInfo: ExecutorDecommissionInfo): Unit = {}
+    override def getExecutorDecommissionState(
+      executorId: String): Option[ExecutorDecommissionState] = None
+  }
+
+  /**
+   * Listener which records information to verify in UTs. The getter-style methods in this class
+   * return a value only after ensuring there is no event left to process, and return an
+   * immutable copy, so race conditions cannot produce odd results.
+   */
+  class EventInfoRecordingListener extends SparkListener {
+    private val _submittedStageInfos = new HashSet[StageInfo]
+    private val _successfulStages = new HashSet[Int]
+    private val _failedStages = new ArrayBuffer[Int]
+    private val _stageByOrderOfExecution = new ArrayBuffer[Int]
+    private val _endedTasks = new HashSet[Long]
+
+    override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
+      _submittedStageInfos += stageSubmitted.stageInfo
+    }
+
+    override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
+      val stageInfo = stageCompleted.stageInfo
+      _stageByOrderOfExecution += stageInfo.stageId
+      if (stageInfo.failureReason.isEmpty) {
+        _successfulStages += stageInfo.stageId
+      } else {
+        _failedStages += stageInfo.stageId
+      }
+    }
+
+    override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+      _endedTasks += taskEnd.taskInfo.taskId
+    }
+
+    def submittedStageInfos: Set[StageInfo] = {
+      waitForListeners()
+      _submittedStageInfos.toSet
+    }
+
+    def successfulStages: Set[Int] = {
+      waitForListeners()
+      _successfulStages.toSet
+    }
+
+    def failedStages: List[Int] = {
+      waitForListeners()
+      _failedStages.toList
+    }
+
+    def stageByOrderOfExecution: List[Int] = {
+      waitForListeners()
+      _stageByOrderOfExecution.toList
+    }
+
+    def endedTasks: Set[Long] = {
+      waitForListeners()
+      _endedTasks.toSet
+    }
+
+    private def waitForListeners(): Unit = sc.listenerBus.waitUntilEmpty()
+  }
+
+  var sparkListener: EventInfoRecordingListener = null
+
+  var blockManagerMaster: BlockManagerMaster = null
+  var mapOutputTracker: MapOutputTrackerMaster = null
+  var broadcastManager: BroadcastManager = null
+  var securityMgr: SecurityManager = null
+  var scheduler: DAGScheduler = null
+  var dagEventProcessLoopTester: DAGSchedulerEventProcessLoop = null
+
+  /**
+   * Set of cache locations to return from our mock BlockManagerMaster.
+   * Keys are (rdd ID, partition ID). Anything not present will return an empty
+   * list of cache locations silently.
+   */
+  val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]
+  // stub out BlockManagerMaster.getLocations to use our cacheLocations
+  class MyBlockManagerMaster(conf: SparkConf) extends BlockManagerMaster(null, null, conf, true) {
+    override def getLocations(blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = {
+      blockIds.map {
+        _.asRDDId.map { id => (id.rddId -> id.splitIndex)
+        }.flatMap { key => cacheLocations.get(key)
+        }.getOrElse(Seq())
+      }.toIndexedSeq
+    }
+    override def removeExecutor(execId: String): Unit = {
+      // don't need to propagate to the driver, which we don't have
+    }
+  }
+
+  /** The list of results that DAGScheduler has collected. */
+  val results = new HashMap[Int, Any]()
+  var failure: Exception = _
+  val jobListener = new JobListener() {
+    override def taskSucceeded(index: Int, result: Any) = results.put(index, result)
+    override def jobFailed(exception: Exception) = { failure = exception }
+  }
+
+  /** A simple helper class for creating custom JobListeners */
+  class SimpleListener extends JobListener {

Review comment:
   OK
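For context on how the pieces in this hunk fit together, here is a rough sketch of a test body
driving MyRDD, the stub taskScheduler and the results/jobListener fields above, assuming
`submit`, `complete` and `makeMapStatus` helpers in the style of the existing DAGSchedulerSuite
(none of those three names appear in this hunk, so treat them as assumptions):

    // A two-stage job: one shuffle map stage feeding one result stage.
    // HashPartitioner, ShuffleDependency and Success come from the org.apache.spark._ import.
    val shuffleMapRdd = new MyRDD(sc, 2, Nil)
    // MyRDD is an RDD[(Int, Int)], so it can back a ShuffleDependency directly.
    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
    // Passing the tracker enables the shuffle-locality branch of getPreferredLocations.
    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)

    submit(reduceRdd, Array(0, 1))             // assumed helper: submits a job via DAGScheduler
    complete(taskSets(0), Seq(                 // assumed helper: completes the map stage tasks
      (Success, makeMapStatus("hostA", 2)),    // assumed helper: fake MapStatus on hostA
      (Success, makeMapStatus("hostB", 2))))
    complete(taskSets(1), Seq((Success, 42), (Success, 43)))
    assert(results === Map(0 -> 42, 1 -> 43))  // results is filled in by jobListener above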
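Similarly, the cacheLocations map is how a test fakes cached partitions for the stubbed
MyBlockManagerMaster; a minimal sketch, constructing BlockManagerId values directly (the suite
may well use a dedicated helper for this instead), where `rdd` is e.g. the shuffleMapRdd from
the sketch above:

    // Pretend partitions 0 and 1 of `rdd` are cached on hostA and hostB, so that
    // DAGScheduler's cache-location lookups see them via MyBlockManagerMaster.getLocations.
    cacheLocations(rdd.id -> 0) = Seq(BlockManagerId("exec-hostA", "hostA", 12345))
    cacheLocations(rdd.id -> 1) = Seq(BlockManagerId("exec-hostB", "hostB", 12345))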
