Github user ivoson commented on a diff in the pull request:
https://github.com/apache/spark/pull/20244#discussion_r161145542
--- Diff: core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala ---
@@ -2399,6 +2417,93 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
}
}
+ /**
+ * In this test, we simulate a scene where concurrent jobs use the same RDD that is
+ * marked for checkpointing:
+ * Job one has already finished its Spark job and starts the doCheckpoint process;
+ * Job two is submitted, and submitMissingTasks is called.
+ * In submitMissingTasks, if the task serialization happens before doCheckpoint is done,
+ * while the partitions read from stage.rdd.partitions are obtained after doCheckpoint is
+ * done, we may get a ClassCastException when executing the task, because some RDDs cast
+ * the Partition to their own partition subtype.
+ *
+ * This test case is meant to show that the task serialization and the partition
+ * calculation in submitMissingTasks should see the same RDD checkpoint state.
+ */
+ test("task part misType with checkpoint rdd in concurrent execution
scenes") {
+ // set checkpointDir.
+ val tempDir = Utils.createTempDir()
+ val checkpointDir = File.createTempFile("temp", "", tempDir)
+ checkpointDir.delete()
+ sc.setCheckpointDir(checkpointDir.toString)
+
+ val latch = new CountDownLatch(2)
+ val semaphore1 = new Semaphore(0)
+ val semaphore2 = new Semaphore(0)
+
+ val rdd = new WrappedRDD(sc.makeRDD(1 to 100, 4))
+ rdd.checkpoint()
+
+ val checkpointRunnable = new Runnable {
+ override def run() = {
+ // Simply simulate what RDD.doCheckpoint() does here.
+ rdd.doCheckpointCalled = true
+ val checkpointData = rdd.checkpointData.get
+ RDDCheckpointData.synchronized {
+ if (checkpointData.cpState == CheckpointState.Initialized) {
+ checkpointData.cpState = CheckpointState.CheckpointingInProgress
+ }
+ }
+
+ val newRDD = checkpointData.doCheckpoint()
+
+ // Release semaphore1 after the job triggered by the checkpoint has finished.
+ semaphore1.release()
+ semaphore2.acquire()
+ // Update our state and truncate the RDD lineage.
+ RDDCheckpointData.synchronized {
+ checkpointData.cpRDD = Some(newRDD)
+ checkpointData.cpState = CheckpointState.Checkpointed
+ rdd.markCheckpointed()
+ }
+ semaphore1.release()
+
+ latch.countDown()
+ }
+ }
+
+ val submitMissingTasksRunnable = new Runnable {
+ override def run() = {
+ // Simply simulate the process of submitMissingTasks.
+ val ser = SparkEnv.get.closureSerializer.newInstance()
+ semaphore1.acquire()
+ // Simulate the task serialization in submitMissingTasks.
+ // The task is serialized while the rdd checkpoint has not finished yet.
+ val cleanedFunc = sc.clean(Utils.getIteratorSize _)
+ val func = (ctx: TaskContext, it: Iterator[Int]) => cleanedFunc(it)
+ val taskBinaryBytes = JavaUtils.bufferToArray(
+ ser.serialize((rdd, func): AnyRef))
+ semaphore2.release()
+ semaphore1.acquire()
+ // The partition is calculated after the rdd checkpoint has already finished.
+ val (taskRdd, taskFunc) = ser.deserialize[(RDD[Int], (TaskContext, Iterator[Int]) => Unit)](
+ ByteBuffer.wrap(taskBinaryBytes), Thread.currentThread.getContextClassLoader)
+ val part = rdd.partitions(0)
+ intercept[ClassCastException] {
--- End diff ---
It is a reproduction case, I will fix this.
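
For the consistent view the test comment asks for, one option is to take both the task serialization and the stage.rdd.partitions read under the same RDDCheckpointData.synchronized lock that the checkpoint path above already uses. A minimal sketch of that idea, reusing the names from the test (illustrative only, not the actual DAGScheduler change):

    // Sketch: serialize the task binary and read the partitions under the same
    // lock the checkpoint path takes, so both observe the same checkpoint state.
    var taskBinaryBytes: Array[Byte] = null
    var parts: Array[Partition] = null
    RDDCheckpointData.synchronized {
      taskBinaryBytes = JavaUtils.bufferToArray(ser.serialize((rdd, func): AnyRef))
      parts = rdd.partitions
    }

With both reads in one synchronized block, a concurrent doCheckpoint() can only complete before or after the block, never between the serialization and the partition calculation.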