Github user cloud-fan commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20449#discussion_r170845905
  
    --- Diff: core/src/test/scala/org/apache/spark/JobCancellationSuite.scala 
---
    @@ -320,6 +321,58 @@ class JobCancellationSuite extends SparkFunSuite with 
Matchers with BeforeAndAft
         f2.get()
       }
     
    +  test("Interruptible iterator of shuffle reader") {
    +    // In this test case, we create a Spark job of two stages. The second 
stage is cancelled during
    +    // execution and a counter is used to make sure that the corresponding 
tasks are indeed
    +    // cancelled.
    +    import JobCancellationSuite._
    +    val numSlice = 2
    +    sc = new SparkContext(s"local[$numSlice]", "test")
    +
    +    val f = sc.parallelize(1 to 1000, numSlice).map { i => (i, i) }
    +      .repartitionAndSortWithinPartitions(new HashPartitioner(2))
    +      .mapPartitions { iter =>
    +        taskStartedSemaphore.release()
    +        iter
    +      }.foreachAsync { x =>
    +        if (x._1 >= 10) {
    +          // This block of code is partially executed. It will be blocked 
when x._1 >= 10 and the
    +          // next iteration will be cancelled if the source iterator is 
interruptible. Then in this
    +          // case, the maximum num of increment would be 11(|1...10| + 
|N|) where N is the first
    +          // element in another partition(assuming no ordering guarantee).
    +          taskCancelledSemaphore.acquire()
    +        }
    +        executionOfInterruptibleCounter.getAndIncrement()
    +    }
    +
    +    val taskCompletedSem = new Semaphore(0)
    +
    +    sc.addSparkListener(new SparkListener {
    +      override def onStageCompleted(stageCompleted: 
SparkListenerStageCompleted): Unit = {
    +        // release taskCancelledSemaphore when cancelTasks event has been 
posted
    +        if (stageCompleted.stageInfo.stageId == 1) {
    +          taskCancelledSemaphore.release(1000)
    +        }
    +      }
    +
    +      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    +        if (taskEnd.stageId == 1) { // make sure tasks are completed
    +          taskCompletedSem.release()
    +        }
    +      }
    +    })
    +
    +    taskStartedSemaphore.acquire()
    +    f.cancel()
    +
    +    val e = intercept[SparkException] { f.get() }.getCause
    +    assert(e.getMessage.contains("cancelled") || 
e.getMessage.contains("killed"))
    +
    +    // Make sure tasks are indeed completed.
    +    taskCompletedSem.acquire(numSlice)
    +    assert(executionOfInterruptibleCounter.get() <= 11)
    --- End diff --
    
    For simplicity, can we just test with 1 partition/task?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to