Github user squito commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20244#discussion_r165764166
  
    --- Diff: 
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala ---
    @@ -2399,6 +2424,115 @@ class DAGSchedulerSuite extends SparkFunSuite with 
LocalSparkContext with TimeLi
         }
       }
     
    +  /**
    +   * This test simulates a scenario where concurrent jobs use the same rdd
    +   * that is marked for checkpointing:
    +   * Job one has already finished its spark job and starts doCheckpoint;
    +   * Job two is then submitted, and submitMissingTasks is called.
    +   * In submitMissingTasks, if task serialization happens before doCheckpoint
    +   * is done, while the partitions from stage.rdd.partitions are computed
    +   * after doCheckpoint is done, we may get a ClassCastException when the
    +   * task executes, because some rdds cast their Partition instances.
    +   *
    +   * This test case demonstrates that task serialization and the partition
    +   * computation in submitMissingTasks must see the same rdd checkpoint
    +   * state.
    +   */
    +  test("SPARK-23053: avoid ClassCastException in concurrent execution with 
checkpoint") {
    +    // set checkpointDir.
    +    val tempDir = Utils.createTempDir()
    +    val checkpointDir = File.createTempFile("temp", "", tempDir)
    +    checkpointDir.delete()
    +    sc.setCheckpointDir(checkpointDir.toString)
    +
    +    // Semaphores to control the process sequence for the two threads 
below.
    +    val semaphore1 = new Semaphore(0)
    +    val semaphore2 = new Semaphore(0)
    +
    +    val rdd = new WrappedRDD(sc.makeRDD(1 to 100, 4))
    +    rdd.checkpoint()
    +
    +    val checkpointRunnable = new Runnable {
    +      override def run() = {
    +        // Simply simulate what RDD.doCheckpoint() does here.
    +        rdd.doCheckpointCalled = true
    +        val checkpointData = rdd.checkpointData.get
    +        RDDCheckpointData.synchronized {
    +          if (checkpointData.cpState == CheckpointState.Initialized) {
    +            checkpointData.cpState = 
CheckpointState.CheckpointingInProgress
    +          }
    +        }
    +
    +        val newRDD = checkpointData.doCheckpoint()
    +
    +        // Release semaphore1 after the job triggered by the checkpoint has
    +        // finished, so that taskBinary serialization can start.
    +        semaphore1.release()
    +        // Wait until taskBinary serialization has finished in
    +        // submitMissingTasksThread.
    +        semaphore2.acquire()
    +
    +        // Update our state and truncate the RDD lineage.
    +        RDDCheckpointData.synchronized {
    +          checkpointData.cpRDD = Some(newRDD)
    +          checkpointData.cpState = CheckpointState.Checkpointed
    +          rdd.markCheckpointed()
    +        }
    +        semaphore1.release()
    +      }
    +    }
    +
    +    val submitMissingTasksRunnable = new Runnable {
    +      override def run() = {
    +        // Simply simulate the process of submitMissingTasks.
    +        // Wait until the doCheckpoint job has finished, but the checkpoint
    +        // state has not yet changed.
    +        semaphore1.acquire()
    +
    +        val ser = SparkEnv.get.closureSerializer.newInstance()
    +
    +        // Simply simulate task serialization in submitMissingTasks.
    +        // The task is serialized while the rdd checkpoint is not yet finished.
    +        val cleanedFunc = sc.clean(Utils.getIteratorSize _)
    +        val func = (ctx: TaskContext, it: Iterator[Int]) => cleanedFunc(it)
    +        val taskBinaryBytes = JavaUtils.bufferToArray(
    +          ser.serialize((rdd, func): AnyRef))
    +        // Because the partition computation is inside the same synchronized
    +        // block in the fixed code, the partition is computed here.
    +        val correctPart = rdd.partitions(0)
    +
    +        // Release semaphore2 so that checkpointThread can change the
    +        // checkpoint state to Checkpointed.
    +        semaphore2.release()
    +        // Wait until the checkpoint state has changed to Checkpointed in
    +        // checkpointThread.
    +        semaphore1.acquire()
    +
    +        // The partition computed after the rdd checkpoint has already finished.
    +        val errPart = rdd.partitions(0)
    +
    +        // The taskBinary will be deserialized when the task runs in the executor.
    +        val (taskRdd, taskFunc) = ser.deserialize[(RDD[Int], (TaskContext, 
Iterator[Int]) => Unit)](
    +          ByteBuffer.wrap(taskBinaryBytes), 
Thread.currentThread.getContextClassLoader)
    +
    +        val taskContext = mock(classOf[TaskContext])
    +        doNothing().when(taskContext).killTaskIfInterrupted()
    +
    +        // ClassCastException is expected with errPart.
    --- End diff --
    
    I think this is a bit easier to follow if you say
    
    Make sure our test case is set up correctly -- we expect a
    ClassCastException here if we use the `rdd.partitions` *after* checkpointing
    was done, but our binary bytes are from before it finished.
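    For anyone following along, the coordination pattern the test relies on -- two semaphores forcing a fixed interleaving between the checkpoint thread and the task-submission thread -- can be sketched in plain Java. The class and event names below are illustrative only, not Spark APIs:

    ```java
    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.List;
    import java.util.concurrent.Semaphore;

    public class HandshakeDemo {
        // Shared event log; entries describe the steps of the Spark test above.
        static final List<String> EVENTS =
            Collections.synchronizedList(new ArrayList<>());
        static final Semaphore SEM1 = new Semaphore(0);
        static final Semaphore SEM2 = new Semaphore(0);

        static void acquire(Semaphore s) {
            try {
                s.acquire();
            } catch (InterruptedException e) {
                throw new RuntimeException(e);
            }
        }

        public static void main(String[] args) throws InterruptedException {
            Thread checkpointThread = new Thread(() -> {
                EVENTS.add("checkpoint: job done, state = CheckpointingInProgress");
                SEM1.release();   // allow taskBinary serialization to start
                acquire(SEM2);    // wait for serialization to finish
                EVENTS.add("checkpoint: state = Checkpointed, lineage truncated");
                SEM1.release();   // allow the post-checkpoint partition read
            });

            Thread submitTasksThread = new Thread(() -> {
                acquire(SEM1);    // wait: checkpointing started but not finished
                EVENTS.add("submit: serialize taskBinary (pre-checkpoint partitions)");
                SEM2.release();   // let the checkpoint complete
                acquire(SEM1);    // wait: state is now Checkpointed
                EVENTS.add("submit: read rdd.partitions (post-checkpoint)");
            });

            checkpointThread.start();
            submitTasksThread.start();
            checkpointThread.join();
            submitTasksThread.join();

            EVENTS.forEach(System.out::println);
        }
    }
    ```

    Because each `acquire` waits on a `release` from the other thread, the four events always happen in the same order, which is exactly how the test deterministically reproduces the race between serialization and partition computation.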

