beliefer commented on a change in pull request #29747: URL: https://github.com/apache/spark/pull/29747#discussion_r488451628
########## File path: core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerTestHelper.scala ########## @@ -0,0 +1,614 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.util.Properties +import java.util.concurrent.atomic.AtomicBoolean + +import scala.annotation.meta.param +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} +import scala.util.control.NonFatal + +import org.mockito.Mockito.spy +import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} + +import org.apache.spark._ +import org.apache.spark.broadcast.BroadcastManager +import org.apache.spark.executor.ExecutorMetrics +import org.apache.spark.rdd.{DeterministicLevel, RDD} +import org.apache.spark.scheduler.SchedulingMode.SchedulingMode +import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} +import org.apache.spark.util.{AccumulatorV2, CallSite} + +class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler) + extends DAGSchedulerEventProcessLoop(dagScheduler) { + + override def post(event: DAGSchedulerEvent): Unit = { + try { + // Forward event to `onReceive` directly to avoid processing event asynchronously. 
+ onReceive(event) + } catch { + case NonFatal(e) => onError(e) + } + } + + override def onError(e: Throwable): Unit = { + logError("Error in DAGSchedulerEventLoop: ", e) + dagScheduler.stop() + throw e + } + +} + +class MyCheckpointRDD( + sc: SparkContext, + numPartitions: Int, + dependencies: List[Dependency[_]], + locations: Seq[Seq[String]] = Nil, + @(transient @param) tracker: MapOutputTrackerMaster = null, + indeterminate: Boolean = false) + extends MyRDD(sc, numPartitions, dependencies, locations, tracker, indeterminate) { + + // Allow doCheckpoint() on this RDD. + override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = + Iterator.empty +} + +/** + * An RDD for passing to DAGScheduler. These RDDs will use the dependencies and + * preferredLocations (if any) that are passed to them. They are deliberately not executable + * so we can test that DAGScheduler does not try to execute RDDs locally. + * + * Optionally, one can pass in a list of locations to use as preferred locations for each task, + * and a MapOutputTrackerMaster to enable reduce task locality. We pass the tracker separately + * because, in this test suite, it won't be the same as sc.env.mapOutputTracker. 
+ */ +class MyRDD( + sc: SparkContext, + numPartitions: Int, + dependencies: List[Dependency[_]], + locations: Seq[Seq[String]] = Nil, + @(transient @param) tracker: MapOutputTrackerMaster = null, + indeterminate: Boolean = false) + extends RDD[(Int, Int)](sc, dependencies) with Serializable { + + override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = + throw new RuntimeException("should not be reached") + + override def getPartitions: Array[Partition] = (0 until numPartitions).map(i => new Partition { + override def index: Int = i + }).toArray + + override protected def getOutputDeterministicLevel = { + if (indeterminate) DeterministicLevel.INDETERMINATE else super.getOutputDeterministicLevel + } + + override def getPreferredLocations(partition: Partition): Seq[String] = { + if (locations.isDefinedAt(partition.index)) { + locations(partition.index) + } else if (tracker != null && dependencies.size == 1 && + dependencies(0).isInstanceOf[ShuffleDependency[_, _, _]]) { + // If we have only one shuffle dependency, use the same code path as ShuffledRDD for locality + val dep = dependencies(0).asInstanceOf[ShuffleDependency[_, _, _]] + tracker.getPreferredLocationsForShuffle(dep, partition.index) + } else { + Nil + } + } + + override def toString: String = "DAGSchedulerSuiteRDD " + id +} + +class DAGSchedulerSuiteDummyException extends Exception + +class DAGSchedulerTestHelper extends SparkFunSuite with TempLocalSparkContext with TimeLimits { Review comment: After an offline discussion, @Ngone51 and I decided to use `DAGSchedulerTestBase`. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
For queries about this service, please contact Infrastructure at: users@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org
