Github user ivoson commented on a diff in the pull request: https://github.com/apache/spark/pull/20244#discussion_r161145542 --- Diff: core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala --- @@ -2399,6 +2417,93 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } } + /** + * In this test, we simply simulate the scene in concurrent jobs using the same + * rdd which is marked to do checkpoint: + * Job one has already finished the spark job, and start the process of doCheckpoint; + * Job two is submitted, and submitMissingTasks is called. + * In submitMissingTasks, if taskSerialization is called before doCheckpoint is done, + * while part calculates from stage.rdd.partitions is called after doCheckpoint is done, + * we may get a ClassCastException when execute the task because of some rdd will do + * Partition cast. + * + * With this test case, just want to indicate that we should do taskSerialization and + * part calculate in submitMissingTasks with the same rdd checkpoint status. + */ + test("task part misType with checkpoint rdd in concurrent execution scenes") { + // set checkpointDir. + val tempDir = Utils.createTempDir() + val checkpointDir = File.createTempFile("temp", "", tempDir) + checkpointDir.delete() + sc.setCheckpointDir(checkpointDir.toString) + + val latch = new CountDownLatch(2) + val semaphore1 = new Semaphore(0) + val semaphore2 = new Semaphore(0) + + val rdd = new WrappedRDD(sc.makeRDD(1 to 100, 4)) + rdd.checkpoint() + + val checkpointRunnable = new Runnable { + override def run() = { + // Simply simulate what RDD.doCheckpoint() do here. 
+ rdd.doCheckpointCalled = true + val checkpointData = rdd.checkpointData.get + RDDCheckpointData.synchronized { + if (checkpointData.cpState == CheckpointState.Initialized) { + checkpointData.cpState = CheckpointState.CheckpointingInProgress + } + } + + val newRDD = checkpointData.doCheckpoint() + + // Release semaphore1 after job triggered in checkpoint finished. + semaphore1.release() + semaphore2.acquire() + // Update our state and truncate the RDD lineage. + RDDCheckpointData.synchronized { + checkpointData.cpRDD = Some(newRDD) + checkpointData.cpState = CheckpointState.Checkpointed + rdd.markCheckpointed() + } + semaphore1.release() + + latch.countDown() + } + } + + val submitMissingTasksRunnable = new Runnable { + override def run() = { + // Simply simulate the process of submitMissingTasks. + val ser = SparkEnv.get.closureSerializer.newInstance() + semaphore1.acquire() + // Simulate task serialization while submitMissingTasks. + // Task serialized with rdd checkpoint not finished. + val cleanedFunc = sc.clean(Utils.getIteratorSize _) + val func = (ctx: TaskContext, it: Iterator[Int]) => cleanedFunc(it) + val taskBinaryBytes = JavaUtils.bufferToArray( + ser.serialize((rdd, func): AnyRef)) + semaphore2.release() + semaphore1.acquire() + // Part calculated with rdd checkpoint already finished. + val (taskRdd, taskFunc) = ser.deserialize[(RDD[Int], (TaskContext, Iterator[Int]) => Unit)]( + ByteBuffer.wrap(taskBinaryBytes), Thread.currentThread.getContextClassLoader) + val part = rdd.partitions(0) + intercept[ClassCastException] { --- End diff -- It is a reproduction case; I will fix this.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org