vanzin commented on a change in pull request #25706: [SPARK-26989][CORE][TEST] 
DAGSchedulerSuite: ensure listeners are fully processed before checking 
recorded values
URL: https://github.com/apache/spark/pull/25706#discussion_r321923311
 
 

 ##########
 File path: 
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
 ##########
 @@ -174,31 +174,72 @@ class DAGSchedulerSuite extends SparkFunSuite with 
LocalSparkContext with TimeLi
   /** Length of time to wait while draining listener events. */
   val WAIT_TIMEOUT_MILLIS = 10000
 
-  val submittedStageInfos = new HashSet[StageInfo]
-  val successfulStages = new HashSet[Int]
-  val failedStages = new ArrayBuffer[Int]
-  val stageByOrderOfExecution = new ArrayBuffer[Int]
-  val endedTasks = new HashSet[Long]
-  val sparkListener = new SparkListener() {
+  /**
+   * Listener which records some information to verify in unit tests. Getter-kind
+   * methods in this class ensure the value is returned only after there are no
+   * more events left to process, and that the returned value is immutable,
+   * preventing odd results caused by race conditions.
+   */
+  class EventInfoRecordingListener extends SparkListener {
+    private val _submittedStageInfos = new HashSet[StageInfo]
+    private val _successfulStages = new HashSet[Int]
+    private val _failedStages = new ArrayBuffer[Int]
+    private val _stageByOrderOfExecution = new ArrayBuffer[Int]
+    private val _endedTasks = new HashSet[Long]
+
     override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) 
{
-      submittedStageInfos += stageSubmitted.stageInfo
+      _submittedStageInfos += stageSubmitted.stageInfo
     }
 
     override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) 
{
       val stageInfo = stageCompleted.stageInfo
-      stageByOrderOfExecution += stageInfo.stageId
+      _stageByOrderOfExecution += stageInfo.stageId
       if (stageInfo.failureReason.isEmpty) {
-        successfulStages += stageInfo.stageId
+        _successfulStages += stageInfo.stageId
       } else {
-        failedStages += stageInfo.stageId
+        _failedStages += stageInfo.stageId
       }
     }
 
     override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
-      endedTasks += taskEnd.taskInfo.taskId
+      _endedTasks += taskEnd.taskInfo.taskId
+    }
+
+    def submittedStageInfos: Set[StageInfo] = withWaitingListenerUntilEmpty {
+      _submittedStageInfos.toSet
+    }
+
+    def successfulStages: Set[Int] = withWaitingListenerUntilEmpty {
+      _successfulStages.toSet
+    }
+
+    def failedStages: List[Int] = withWaitingListenerUntilEmpty {
+      _failedStages.toList
+    }
+
+    def stageByOrderOfExecution: List[Int] = withWaitingListenerUntilEmpty {
+      _stageByOrderOfExecution.toList
+    }
+
+    def endedTask: Set[Long] = withWaitingListenerUntilEmpty {
+      _endedTasks.toSet
+    }
+
+    def clear(): Unit = {
+      _submittedStageInfos.clear()
+      _successfulStages.clear()
+      _failedStages.clear()
+      _stageByOrderOfExecution.clear()
+      _endedTasks.clear()
+    }
+
+    private def withWaitingListenerUntilEmpty[T](fn: => T): T = {
+      sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
 
 Review comment:
   On a side note, the timeout for this method is hardcoded to a bunch of 
different arbitrary values in so many different places that it may be good at 
some point to just have a default value in `LiveListenerBus`. I doubt any test 
code actually depends on a specific timeout here.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to