vanzin commented on a change in pull request #25706: [SPARK-26989][CORE][TEST]
DAGSchedulerSuite: ensure listeners are fully processed before checking
recorded values
URL: https://github.com/apache/spark/pull/25706#discussion_r321923031
##########
File path:
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
##########
@@ -174,31 +174,72 @@ class DAGSchedulerSuite extends SparkFunSuite with
LocalSparkContext with TimeLi
/** Length of time to wait while draining listener events. */
val WAIT_TIMEOUT_MILLIS = 10000
- val submittedStageInfos = new HashSet[StageInfo]
- val successfulStages = new HashSet[Int]
- val failedStages = new ArrayBuffer[Int]
- val stageByOrderOfExecution = new ArrayBuffer[Int]
- val endedTasks = new HashSet[Long]
- val sparkListener = new SparkListener() {
+ /**
+ * Listeners which records some information to verify in UTs. Getter-kind
methods in this class
+ * ensures the value is returned after ensuring there's no event to process,
as well as the
+ * value is immutable: prevent showing odd result by race condition.
+ */
+ class EventInfoRecordingListener extends SparkListener {
+ private val _submittedStageInfos = new HashSet[StageInfo]
+ private val _successfulStages = new HashSet[Int]
+ private val _failedStages = new ArrayBuffer[Int]
+ private val _stageByOrderOfExecution = new ArrayBuffer[Int]
+ private val _endedTasks = new HashSet[Long]
+
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted)
{
- submittedStageInfos += stageSubmitted.stageInfo
+ _submittedStageInfos += stageSubmitted.stageInfo
}
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted)
{
val stageInfo = stageCompleted.stageInfo
- stageByOrderOfExecution += stageInfo.stageId
+ _stageByOrderOfExecution += stageInfo.stageId
if (stageInfo.failureReason.isEmpty) {
- successfulStages += stageInfo.stageId
+ _successfulStages += stageInfo.stageId
} else {
- failedStages += stageInfo.stageId
+ _failedStages += stageInfo.stageId
}
}
override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
- endedTasks += taskEnd.taskInfo.taskId
+ _endedTasks += taskEnd.taskInfo.taskId
+ }
+
+ def submittedStageInfos: Set[StageInfo] = withWaitingListenerUntilEmpty {
+ _submittedStageInfos.toSet
+ }
+
+ def successfulStages: Set[Int] = withWaitingListenerUntilEmpty {
+ _successfulStages.toSet
+ }
+
+ def failedStages: List[Int] = withWaitingListenerUntilEmpty {
+ _failedStages.toList
+ }
+
+ def stageByOrderOfExecution: List[Int] = withWaitingListenerUntilEmpty {
+ _stageByOrderOfExecution.toList
+ }
+
+ def endedTask: Set[Long] = withWaitingListenerUntilEmpty {
+ _endedTasks.toSet
+ }
+
+ def clear(): Unit = {
+ _submittedStageInfos.clear()
+ _successfulStages.clear()
+ _failedStages.clear()
+ _stageByOrderOfExecution.clear()
+ _endedTasks.clear()
+ }
+
+ private def withWaitingListenerUntilEmpty[T](fn: => T): T = {
Review comment:
This is kind of a weird method because `fn` doesn't need to run in the
context of this method, just after all the events have been processed. So in
the getters you could just:
```
def endedTask: Set[Long] = {
waitForListeners()
_endedTasks.toSet
}
def waitForListeners(): Unit =
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
```
That makes each getter a little longer but looks less weird.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]