Github user ivoson commented on a diff in the pull request: https://github.com/apache/spark/pull/20244#discussion_r166387716 --- Diff: core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala --- @@ -2399,6 +2424,115 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } } + /** + * In this test, we simply simulate the scene in concurrent jobs using the same + * rdd which is marked to do checkpoint: + * Job one has already finished the spark job, and start the process of doCheckpoint; + * Job two is submitted, and submitMissingTasks is called. + * In submitMissingTasks, if taskSerialization is called before doCheckpoint is done, + * while part calculates from stage.rdd.partitions is called after doCheckpoint is done, + * we may get a ClassCastException when execute the task because of some rdd will do + * Partition cast. + * + * With this test case, just want to indicate that we should do taskSerialization and + * part calculate in submitMissingTasks with the same rdd checkpoint status. + */ + test("SPARK-23053: avoid ClassCastException in concurrent execution with checkpoint") { + // set checkpointDir. + val tempDir = Utils.createTempDir() + val checkpointDir = File.createTempFile("temp", "", tempDir) + checkpointDir.delete() + sc.setCheckpointDir(checkpointDir.toString) + + // Semaphores to control the process sequence for the two threads below. + val semaphore1 = new Semaphore(0) + val semaphore2 = new Semaphore(0) + + val rdd = new WrappedRDD(sc.makeRDD(1 to 100, 4)) + rdd.checkpoint() + + val checkpointRunnable = new Runnable { + override def run() = { + // Simply simulate what RDD.doCheckpoint() do here. 
+ rdd.doCheckpointCalled = true + val checkpointData = rdd.checkpointData.get + RDDCheckpointData.synchronized { + if (checkpointData.cpState == CheckpointState.Initialized) { + checkpointData.cpState = CheckpointState.CheckpointingInProgress + } + } + + val newRDD = checkpointData.doCheckpoint() + + // Release semaphore1 after job triggered in checkpoint finished, so that taskBinary + // serialization can start. + semaphore1.release() + // Wait until taskBinary serialization finished in submitMissingTasksThread. + semaphore2.acquire() + + // Update our state and truncate the RDD lineage. + RDDCheckpointData.synchronized { + checkpointData.cpRDD = Some(newRDD) + checkpointData.cpState = CheckpointState.Checkpointed + rdd.markCheckpointed() + } + semaphore1.release() + } + } + + val submitMissingTasksRunnable = new Runnable { + override def run() = { + // Simply simulate the process of submitMissingTasks. + // Wait until doCheckpoint job running finished, but checkpoint status not changed. + semaphore1.acquire() + + val ser = SparkEnv.get.closureSerializer.newInstance() + + // Simply simulate task serialization while submitMissingTasks. + // Task serialized with rdd checkpoint not finished. + val cleanedFunc = sc.clean(Utils.getIteratorSize _) + val func = (ctx: TaskContext, it: Iterator[Int]) => cleanedFunc(it) + val taskBinaryBytes = JavaUtils.bufferToArray( + ser.serialize((rdd, func): AnyRef)) + // Because partition calculate is in a synchronized block, so in the fixed code + // partition is calculated here. + val correctPart = rdd.partitions(0) + + // Release semaphore2 so changing checkpoint status to Checkpointed will be done in + // checkpointThread. + semaphore2.release() + // Wait until checkpoint status changed to Checkpointed in checkpointThread. + semaphore1.acquire() + + // Part calculated with rdd checkpoint already finished. --- End diff -- Thanks for the advice; it is really helpful for understanding. I will update this.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org