Repository: spark
Updated Branches:
  refs/heads/branch-1.6 4df1dd403 -> 9177ea383


[SPARK-11749][STREAMING] Fix duplicate RDD creation in file stream when recovering from checkpoint data

Add a transient flag `DStream.restoredFromCheckpointData` to guard the restore 
processing in DStream and avoid duplicate work: `DStream.restoreCheckpointData` 
now checks this flag first and executes the restore only when it is `false`.
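
As a minimal standalone sketch of the guard pattern (a toy `Node` class stands
in for `DStream`; the names here are illustrative, not the actual API): when
several output streams share an upstream parent, walking dependencies from each
output would otherwise restore the shared parent once per path.

    // Toy stand-in for DStream: restore each node in the graph at most once.
    class Node(val dependencies: Seq[Node]) {
      @transient private var restored = false  // mirrors restoredFromCheckpointData
      var restoreCount = 0

      def restoreCheckpointData(): Unit = {
        if (!restored) {
          restoreCount += 1  // stand-in for checkpointData.restore()
          dependencies.foreach(_.restoreCheckpointData())
          restored = true
        }
      }
    }

    object GuardDemo extends App {
      val shared  = new Node(Nil)                      // one shared upstream node
      val outputs = Seq.fill(3)(new Node(Seq(shared))) // three outputs depending on it
      outputs.foreach(_.restoreCheckpointData())
      println(shared.restoreCount)                     // 1 with the guard; 3 without it
    }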

Author: jhu-chang <gt.hu.ch...@gmail.com>

Closes #9765 from jhu-chang/SPARK-11749.

(cherry picked from commit f4346f612b6798517153a786f9172cf41618d34d)
Signed-off-by: Shixiong Zhu <shixi...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9177ea38
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9177ea38
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9177ea38

Branch: refs/heads/branch-1.6
Commit: 9177ea383a29653f0591a59e1ee2dff6b87d5a1c
Parents: 4df1dd4
Author: jhu-chang <gt.hu.ch...@gmail.com>
Authored: Thu Dec 17 17:53:15 2015 -0800
Committer: Shixiong Zhu <shixi...@databricks.com>
Committed: Thu Dec 17 17:54:14 2015 -0800

----------------------------------------------------------------------
 .../spark/streaming/dstream/DStream.scala       | 15 ++++--
 .../spark/streaming/CheckpointSuite.scala       | 56 ++++++++++++++++++--
 2 files changed, 62 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9177ea38/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
index 1a6edf9..91a43e1 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -97,6 +97,8 @@ abstract class DStream[T: ClassTag] (
   private[streaming] val mustCheckpoint = false
   private[streaming] var checkpointDuration: Duration = null
   private[streaming] val checkpointData = new DStreamCheckpointData(this)
+  @transient
+  private var restoredFromCheckpointData = false
 
   // Reference to whole DStream graph
   private[streaming] var graph: DStreamGraph = null
@@ -507,11 +509,14 @@ abstract class DStream[T: ClassTag] (
    * override the updateCheckpointData() method would also need to override this method.
    */
   private[streaming] def restoreCheckpointData() {
-    // Create RDDs from the checkpoint data
-    logInfo("Restoring checkpoint data")
-    checkpointData.restore()
-    dependencies.foreach(_.restoreCheckpointData())
-    logInfo("Restored checkpoint data")
+    if (!restoredFromCheckpointData) {
+      // Create RDDs from the checkpoint data
+      logInfo("Restoring checkpoint data")
+      checkpointData.restore()
+      dependencies.foreach(_.restoreCheckpointData())
+      restoredFromCheckpointData = true
+      logInfo("Restored checkpoint data")
+    }
   }
 
   @throws(classOf[IOException])

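A note on the `@transient` choice above: transient fields are skipped by Java
serialization, so when the DStream graph is read back from a checkpoint the
flag reverts to the JVM default (`false`), re-arming the guard for exactly one
restore per recovery. A small self-contained sketch of that behavior, using
plain JDK serialization and a hypothetical `Flagged` class (nothing
Spark-specific):

    import java.io._

    class Flagged extends Serializable {
      @transient var restored = true  // transient: not written out by serialization
    }

    object TransientDemo extends App {
      val buf = new ByteArrayOutputStream()
      val out = new ObjectOutputStream(buf)
      out.writeObject(new Flagged)
      out.close()
      val in   = new ObjectInputStream(new ByteArrayInputStream(buf.toByteArray))
      val copy = in.readObject().asInstanceOf[Flagged]
      // The transient var comes back as the Boolean default, not its old value:
      println(copy.restored)  // prints: false
    }
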
http://git-wip-us.apache.org/repos/asf/spark/blob/9177ea38/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index cd28d3c..f5f446f 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.streaming
 
-import java.io.{ObjectOutputStream, ByteArrayOutputStream, ByteArrayInputStream, File}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, ObjectOutputStream}
 
 import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
 import scala.reflect.ClassTag
@@ -34,9 +34,30 @@ import org.scalatest.concurrent.Eventually._
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite, TestUtils}
-import org.apache.spark.streaming.dstream.{DStream, FileInputDStream}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.dstream._
 import org.apache.spark.streaming.scheduler._
-import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils}
+import org.apache.spark.util.{Clock, ManualClock, MutableURLClassLoader, Utils}
+
+/**
+ * An input stream that records the number of times restore() is invoked.
+ */
+private[streaming]
+class CheckpointInputDStream(ssc_ : StreamingContext) extends InputDStream[Int](ssc_) {
+  protected[streaming] override val checkpointData = new FileInputDStreamCheckpointData
+  override def start(): Unit = { }
+  override def stop(): Unit = { }
+  override def compute(time: Time): Option[RDD[Int]] = Some(ssc.sc.makeRDD(Seq(1)))
+  private[streaming]
+  class FileInputDStreamCheckpointData extends DStreamCheckpointData(this) {
+    @transient
+    var restoredTimes = 0
+    override def restore() {
+      restoredTimes += 1
+      super.restore()
+    }
+  }
+}
 
 /**
  * A trait that can be mixed in to get methods for testing DStream operations under
@@ -110,7 +131,7 @@ trait DStreamCheckpointTester { self: SparkFunSuite =>
     new StreamingContext(SparkContext.getOrCreate(conf), batchDuration)
   }
 
-  private def generateOutput[V: ClassTag](
+  protected def generateOutput[V: ClassTag](
       ssc: StreamingContext,
       targetBatchTime: Time,
       checkpointDir: String,
@@ -715,6 +736,33 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester {
     }
   }
 
+  test("DStreamCheckpointData.restore invoking times") {
+    withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+      ssc.checkpoint(checkpointDir)
+      val inputDStream = new CheckpointInputDStream(ssc)
+      val checkpointData = inputDStream.checkpointData
+      val mappedDStream = inputDStream.map(_ + 100)
+      val outputStream = new TestOutputStreamWithPartitions(mappedDStream)
+      outputStream.register()
+      // register two more output operations on the same mapped stream
+      mappedDStream.foreachRDD(rdd => rdd.count())
+      mappedDStream.foreachRDD(rdd => rdd.count())
+      assert(checkpointData.restoredTimes === 0)
+      val batchDurationMillis = ssc.progressListener.batchDuration
+      generateOutput(ssc, Time(batchDurationMillis * 3), checkpointDir, stopSparkContext = true)
+      assert(checkpointData.restoredTimes === 0)
+    }
+    logInfo("*********** RESTARTING ************")
+    withStreamingContext(new StreamingContext(checkpointDir)) { ssc =>
+      val checkpointData =
+        ssc.graph.getInputStreams().head.asInstanceOf[CheckpointInputDStream].checkpointData
+      assert(checkpointData.restoredTimes === 1)
+      ssc.start()
+      ssc.stop()
+      assert(checkpointData.restoredTimes === 1)
+    }
+  }
+
   // This tests whether spark can deserialize array object
   // refer to SPARK-5569
   test("recovery from checkpoint contains array object") {

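The test above exercises exactly the duplicate path: the
`TestOutputStreamWithPartitions` plus the two `foreachRDD` calls register three
output streams that all depend on the same `CheckpointInputDStream`, so on
recovery the unguarded code would presumably have called `restore()` once per
output path. The assertions pin the count at one, both immediately after
recovering from the checkpoint and again after a start/stop cycle.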
