Repository: spark
Updated Branches:
  refs/heads/branch-1.2 f225b3cc1 -> 46654b066


[SPARK-4029][Streaming] Update streaming driver to reliably save and recover received block metadata on driver failures

As part of the initiative to prevent data loss on driver failure, this JIRA
tracks the subtask of modifying the streaming driver to reliably save received
block metadata and recover it on driver restart.

This is solved by introducing a `ReceivedBlockTracker` that takes full
responsibility for managing the metadata of received blocks (i.e.,
`ReceivedBlockInfo`) and any actions on them (e.g., allocating blocks to
batches). Every action on the block info (`BlockAdditionEvent`,
`BatchAllocationEvent`, `BatchCleanupEvent`) is written to a write ahead log
(using `WriteAheadLogManager`) when a checkpoint directory is set. On recovery,
all the actions are replayed to recreate the pre-failure state of the
`ReceivedBlockTracker`, which includes the batch-to-block allocations and the
unallocated blocks.
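
To make the replay idea concrete, here is a minimal, self-contained sketch of an
event-sourced tracker. It is not the code in this patch: `BlockInfo`,
`TrackerEvent`, `MiniTracker`, and the in-memory `ArrayBuffer` standing in for
the write ahead log are hypothetical simplifications (batch times are plain
`Long`s, block ids are strings, and nothing is serialized). Each action is
appended to the log before the in-memory state changes, and a new instance built
over the same log replays it to recover the allocated batches and the
unallocated blocks.

```scala
import scala.collection.mutable

// Hypothetical, simplified stand-ins for ReceivedBlockInfo and the tracker's log events.
final case class BlockInfo(streamId: Int, blockId: String)

sealed trait TrackerEvent
final case class BlockAdded(info: BlockInfo) extends TrackerEvent
final case class BatchAllocated(time: Long, blocks: Map[Int, Seq[BlockInfo]]) extends TrackerEvent
final case class BatchesCleaned(times: Seq[Long]) extends TrackerEvent

/**
 * Sketch of an event-sourced block tracker. `log` stands in for the write ahead log;
 * `streamIds` is assumed to list every stream that can add blocks.
 */
class MiniTracker(streamIds: Seq[Int], log: mutable.ArrayBuffer[TrackerEvent]) {
  private val unallocated = mutable.HashMap.empty[Int, mutable.Queue[BlockInfo]]
  private val allocated = mutable.HashMap.empty[Long, Map[Int, Seq[BlockInfo]]]
  private var lastAllocated: Option[Long] = None

  // Recovery: replay every event already in the log, without logging it again.
  log.toList.foreach(applyEvent)

  private def queue(streamId: Int): mutable.Queue[BlockInfo] =
    unallocated.getOrElseUpdate(streamId, mutable.Queue.empty[BlockInfo])

  // Applies one event to the in-memory state; shared by live updates and replay.
  private def applyEvent(event: TrackerEvent): Unit = event match {
    case BlockAdded(info) => queue(info.streamId) += info
    case BatchAllocated(time, blocks) =>
      unallocated.values.foreach(_.clear()) // the allocation drained all queues
      allocated(time) = blocks
      lastAllocated = Some(time)
    case BatchesCleaned(times) => allocated --= times
  }

  // Write-ahead: persist the event first, then update the in-memory state.
  private def record(event: TrackerEvent): Unit = {
    log += event
    applyEvent(event)
  }

  def addBlock(info: BlockInfo): Unit = record(BlockAdded(info))

  // Batch times must strictly increase, as in the patch's allocateBlocksToBatch.
  def allocateBlocksToBatch(time: Long): Unit = {
    require(lastAllocated.forall(time > _), s"batch $time is not after $lastAllocated")
    record(BatchAllocated(time, streamIds.map(id => id -> queue(id).toList).toMap))
  }

  def getBlocksOfBatch(time: Long): Map[Int, Seq[BlockInfo]] =
    allocated.getOrElse(time, Map.empty[Int, Seq[BlockInfo]])

  def unallocatedBlocks(streamId: Int): Seq[BlockInfo] = queue(streamId).toList

  def cleanupOldBatches(threshold: Long): Unit =
    record(BatchesCleaned(allocated.keys.filter(_ < threshold).toList))
}
```

Because replaying an allocation clears every unallocated queue, the sketch
relies on the same assumption as the patch's recovery code: when a batch was
originally allocated, all the queued blocks went into it. For example, logging
two block additions, an allocation at time 1000, and one more addition, then
constructing a second `MiniTracker` over that log, yields the first two blocks
under batch 1000 and only the third block as unallocated.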

Furthermore, the `ReceiverInputDStream` was modified to create
`WriteAheadLogBackedBlockRDD`s when all the blocks of a batch carry write ahead
log file segment info in their `ReceivedBlockInfo` (i.e., every store result is a
`WriteAheadLogBasedStoreResult`). After all the block info is recovered (through
the recovery of the `ReceivedBlockTracker`), the `WriteAheadLogBackedBlockRDD`s
are recreated with the recovered info and the jobs are submitted. The block data
is then read back from the write ahead logs using the segment info stored in the
`ReceivedBlockInfo`.
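
The RDD selection described above roughly follows the branch in
`ReceiverInputDStream.compute`. The sketch below models only that decision;
`StoreResult`, `BlockManagerResult`, `WalResult`, `PlannedRDD`, and `planRDD` are
hypothetical names, and the function returns a description of the RDD instead of
constructing a real `BlockRDD` or `WriteAheadLogBackedBlockRDD`.

```scala
// Hypothetical, simplified stand-ins for the receiver's block store results.
sealed trait StoreResult { def blockId: String }
final case class BlockManagerResult(blockId: String) extends StoreResult
final case class WalResult(blockId: String, segment: String) extends StoreResult

// A description of which RDD compute() would build (no Spark dependency).
sealed trait PlannedRDD
final case class PlainBlockRDD(blockIds: Seq[String]) extends PlannedRDD
final case class WalBackedRDD(blockIds: Seq[String], segments: Seq[String]) extends PlannedRDD

/** Use the WAL-backed RDD only when every block of the batch has a WAL segment. */
def planRDD(results: Seq[StoreResult]): PlannedRDD = {
  val blockIds = results.map(_.blockId)
  val resultTypes = results.map(_.getClass).distinct
  if (resultTypes.size == 1 && resultTypes.head == classOf[WalResult]) {
    WalBackedRDD(blockIds, results.collect { case w: WalResult => w.segment })
  } else {
    // Mixed result types (the patch logs a warning here) or no blocks at all.
    PlainBlockRDD(blockIds)
  }
}
```

Requiring a single result type keeps the two storage paths from being mixed in
one RDD: if even one block lacks a segment, the batch falls back to reading every
block from the block manager, which is why the patch warns that WAL information
will be ignored in that case.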

This is still a WIP. The following pieces are missing:

- *End-to-end integration tests:* Unit tests that test driver recovery by
killing and restarting the streaming context and verifying that all the input
data gets processed. These have been implemented but are not yet included in
this PR. A sneak peek of that DriverFailureSuite can be found in this PR on my
personal repo: https://github.com/tdas/spark/pull/25
I can either include it in this PR or submit it as a separate PR after this one
gets in.

- *WAL cleanup:* Cleaning up the write ahead log of received data by calling
`ReceivedBlockHandler.cleanupOldBlocks`. This is being worked on.

Author: Tathagata Das <[email protected]>

Closes #3026 from tdas/driver-ha-rbt and squashes the following commits:

a8009ed [Tathagata Das] Added comment
1d704bb [Tathagata Das] Enabled storing recovered WAL-backed blocks to BM
2ee2484 [Tathagata Das] More minor changes based on PR
47fc1e3 [Tathagata Das] Addressed PR comments.
9a7e3e4 [Tathagata Das] Refactored ReceivedBlockTracker API a bit to make things a little cleaner for users of the tracker.
af63655 [Tathagata Das] Minor changes.
fce2b21 [Tathagata Das] Removed commented lines
59496d3 [Tathagata Das] Changed class names, made allocation more explicit and added cleanup
19aec7d [Tathagata Das] Fixed casting bug.
f66d277 [Tathagata Das] Fix line lengths.
cda62ee [Tathagata Das] Added license
25611d6 [Tathagata Das] Minor changes before submitting PR
7ae0a7fb [Tathagata Das] Transferred changes from driver-ha-working branch

(cherry picked from commit 5f13759d3642ea5b58c12a756e7125ac19aff10e)
Signed-off-by: Tathagata Das <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/46654b06
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/46654b06
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/46654b06

Branch: refs/heads/branch-1.2
Commit: 46654b0661257f432932c6efc09c4c0983521834
Parents: f225b3c
Author: Tathagata Das <[email protected]>
Authored: Wed Nov 5 01:21:53 2014 -0800
Committer: Tathagata Das <[email protected]>
Committed: Wed Nov 5 01:22:16 2014 -0800

----------------------------------------------------------------------
 .../dstream/ReceiverInputDStream.scala          |  69 +++---
 .../rdd/WriteAheadLogBackedBlockRDD.scala       |   3 +-
 .../streaming/scheduler/JobGenerator.scala      |  21 +-
 .../scheduler/ReceivedBlockTracker.scala        | 230 ++++++++++++++++++
 .../streaming/scheduler/ReceiverTracker.scala   |  98 +++++---
 .../spark/streaming/BasicOperationsSuite.scala  |  19 +-
 .../streaming/ReceivedBlockTrackerSuite.scala   | 242 +++++++++++++++++++
 .../rdd/WriteAheadLogBackedBlockRDDSuite.scala  |   4 +-
 8 files changed, 597 insertions(+), 89 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/46654b06/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
index bb47d37..3e67161 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
@@ -17,15 +17,14 @@
 
 package org.apache.spark.streaming.dstream
 
-import scala.collection.mutable.HashMap
 import scala.reflect.ClassTag
 
 import org.apache.spark.rdd.{BlockRDD, RDD}
-import org.apache.spark.storage.BlockId
+import org.apache.spark.storage.{BlockId, StorageLevel}
 import org.apache.spark.streaming._
-import org.apache.spark.streaming.receiver.{WriteAheadLogBasedStoreResult, BlockManagerBasedStoreResult, Receiver}
+import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD
+import org.apache.spark.streaming.receiver.{Receiver, WriteAheadLogBasedStoreResult}
 import org.apache.spark.streaming.scheduler.ReceivedBlockInfo
-import org.apache.spark.SparkException
 
 /**
  * Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]]
@@ -40,9 +39,6 @@ import org.apache.spark.SparkException
 abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext)
   extends InputDStream[T](ssc_) {
 
-  /** Keeps all received blocks information */
-  private lazy val receivedBlockInfo = new HashMap[Time, Array[ReceivedBlockInfo]]
-
   /** This is an unique identifier for the network input stream. */
   val id = ssc.getNewReceiverStreamId()
 
@@ -58,24 +54,45 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont
 
   def stop() {}
 
-  /** Ask ReceiverInputTracker for received data blocks and generates RDDs with them. */
+  /**
+   * Generates RDDs with blocks received by the receiver of this stream. */
   override def compute(validTime: Time): Option[RDD[T]] = {
-    // If this is called for any time before the start time of the context,
-    // then this returns an empty RDD. This may happen when recovering from a
-    // master failure
-    if (validTime >= graph.startTime) {
-      val blockInfo = ssc.scheduler.receiverTracker.getReceivedBlockInfo(id)
-      receivedBlockInfo(validTime) = blockInfo
-      val blockIds = blockInfo.map { _.blockStoreResult.blockId.asInstanceOf[BlockId] }
-      Some(new BlockRDD[T](ssc.sc, blockIds))
-    } else {
-      Some(new BlockRDD[T](ssc.sc, Array.empty))
-    }
-  }
+    val blockRDD = {
 
-  /** Get information on received blocks. */
-  private[streaming] def getReceivedBlockInfo(time: Time) = {
-    receivedBlockInfo.get(time).getOrElse(Array.empty[ReceivedBlockInfo])
+      if (validTime < graph.startTime) {
+        // If this is called for any time before the start time of the context,
+        // then this returns an empty RDD. This may happen when recovering from a
+        // driver failure without any write ahead log to recover pre-failure data.
+        new BlockRDD[T](ssc.sc, Array.empty)
+      } else {
+        // Otherwise, ask the tracker for all the blocks that have been allocated to this stream
+        // for this batch
+        val blockInfos =
+          ssc.scheduler.receiverTracker.getBlocksOfBatch(validTime).get(id).getOrElse(Seq.empty)
+        val blockStoreResults = blockInfos.map { _.blockStoreResult }
+        val blockIds = blockStoreResults.map { _.blockId.asInstanceOf[BlockId] }.toArray
+
+        // Check whether all the results are of the same type
+        val resultTypes = blockStoreResults.map { _.getClass }.distinct
+        if (resultTypes.size > 1) {
+          logWarning("Multiple result types in block information, WAL information will be ignored.")
+        }
+
+        // If all the results are of type WriteAheadLogBasedStoreResult, then create
+        // WriteAheadLogBackedBlockRDD else create simple BlockRDD.
+        if (resultTypes.size == 1 && resultTypes.head == classOf[WriteAheadLogBasedStoreResult]) {
+          val logSegments = blockStoreResults.map {
+            _.asInstanceOf[WriteAheadLogBasedStoreResult].segment
+          }.toArray
+          // Since storeInBlockManager = true, blocks read from the log are put into the BM at this level.
+          new WriteAheadLogBackedBlockRDD[T](ssc.sparkContext,
+            blockIds, logSegments, storeInBlockManager = true, StorageLevel.MEMORY_ONLY_SER)
+        } else {
+          new BlockRDD[T](ssc.sc, blockIds)
+        }
+      }
+    }
+    Some(blockRDD)
   }
 
   /**
@@ -86,10 +103,6 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont
    */
   private[streaming] override def clearMetadata(time: Time) {
     super.clearMetadata(time)
-    val oldReceivedBlocks = receivedBlockInfo.filter(_._1 <= (time - rememberDuration))
-    receivedBlockInfo --= oldReceivedBlocks.keys
-    logDebug("Cleared " + oldReceivedBlocks.size + " RDDs that were older than " +
-      (time - rememberDuration) + ": " + oldReceivedBlocks.keys.mkString(", "))
+    ssc.scheduler.receiverTracker.cleanupOldMetadata(time - rememberDuration)
   }
 }
-

http://git-wip-us.apache.org/repos/asf/spark/blob/46654b06/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
index 23295bf..dd1e963 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
@@ -48,7 +48,6 @@ class WriteAheadLogBackedBlockRDDPartition(
  * If it does not find them, it looks up the corresponding file segment.
  *
  * @param sc SparkContext
- * @param hadoopConfig Hadoop configuration
  * @param blockIds Ids of the blocks that contains this RDD's data
  * @param segments Segments in write ahead logs that contain this RDD's data
  * @param storeInBlockManager Whether to store in the block manager after reading from the segment
@@ -58,7 +57,6 @@ class WriteAheadLogBackedBlockRDDPartition(
 private[streaming]
 class WriteAheadLogBackedBlockRDD[T: ClassTag](
     @transient sc: SparkContext,
-    @transient hadoopConfig: Configuration,
     @transient blockIds: Array[BlockId],
     @transient segments: Array[WriteAheadLogFileSegment],
     storeInBlockManager: Boolean,
@@ -71,6 +69,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
       s"the same as number of segments (${segments.length}})!")
 
   // Hadoop configuration is not serializable, so broadcast it as a serializable.
+  @transient private val hadoopConfig = sc.hadoopConfiguration
   private val broadcastedHadoopConf = new SerializableWritable(hadoopConfig)
 
   override def getPartitions: Array[Partition] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/46654b06/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 7d73ada..39b66e1 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -112,7 +112,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
       // Wait until all the received blocks in the network input tracker has
       // been consumed by network input DStreams, and jobs have been generated with them
       logInfo("Waiting for all received blocks to be consumed for job generation")
-      while(!hasTimedOut && jobScheduler.receiverTracker.hasMoreReceivedBlockIds) {
+      while(!hasTimedOut && jobScheduler.receiverTracker.hasUnallocatedBlocks) {
         Thread.sleep(pollTime)
       }
       logInfo("Waited for all received blocks to be consumed for job generation")
@@ -217,14 +217,18 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
 
   /** Generate jobs and perform checkpoint for the given `time`.  */
   private def generateJobs(time: Time) {
-    Try(graph.generateJobs(time)) match {
+    // Set the SparkEnv in this thread, so that job generation code can access the environment
+    // Example: BlockRDDs are created in this thread, and it needs to access BlockManager
+    // Update: This is probably redundant after threadlocal stuff in SparkEnv has been removed.
+    SparkEnv.set(ssc.env)
+    Try {
+      jobScheduler.receiverTracker.allocateBlocksToBatch(time) // allocate received blocks to batch
+      graph.generateJobs(time) // generate jobs using allocated blocks
+    } match {
       case Success(jobs) =>
-        val receivedBlockInfo = graph.getReceiverInputStreams.map { stream =>
-          val streamId = stream.id
-          val receivedBlockInfo = stream.getReceivedBlockInfo(time)
-          (streamId, receivedBlockInfo)
-        }.toMap
-        jobScheduler.submitJobSet(JobSet(time, jobs, receivedBlockInfo))
+        val receivedBlockInfos =
+          jobScheduler.receiverTracker.getBlocksOfBatch(time).mapValues { _.toArray }
+        jobScheduler.submitJobSet(JobSet(time, jobs, receivedBlockInfos))
       case Failure(e) =>
         jobScheduler.reportError("Error generating jobs for time " + time, e)
     }
@@ -234,6 +238,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
   /** Clear DStream metadata for the given `time`. */
   private def clearMetadata(time: Time) {
     ssc.graph.clearMetadata(time)
+    jobScheduler.receiverTracker.cleanupOldMetadata(time - graph.batchDuration)
 
     // If checkpointing is enabled, then checkpoint,
     // else mark batch to be fully processed

http://git-wip-us.apache.org/repos/asf/spark/blob/46654b06/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
new file mode 100644
index 0000000..5f5e190
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable
+import scala.language.implicitConversions
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.{SparkException, Logging, SparkConf}
+import org.apache.spark.streaming.Time
+import org.apache.spark.streaming.util.{Clock, WriteAheadLogManager}
+import org.apache.spark.util.Utils
+
+/** Trait representing any event in the ReceivedBlockTracker that updates its state. */
+private[streaming] sealed trait ReceivedBlockTrackerLogEvent
+
+private[streaming] case class BlockAdditionEvent(receivedBlockInfo: ReceivedBlockInfo)
+  extends ReceivedBlockTrackerLogEvent
+private[streaming] case class BatchAllocationEvent(time: Time, allocatedBlocks: AllocatedBlocks)
+  extends ReceivedBlockTrackerLogEvent
+private[streaming] case class BatchCleanupEvent(times: Seq[Time])
+  extends ReceivedBlockTrackerLogEvent
+
+
+/** Class representing the blocks of all the streams allocated to a batch */
+private[streaming]
+case class AllocatedBlocks(streamIdToAllocatedBlocks: Map[Int, Seq[ReceivedBlockInfo]]) {
+  def getBlocksOfStream(streamId: Int): Seq[ReceivedBlockInfo] = {
+    streamIdToAllocatedBlocks.get(streamId).getOrElse(Seq.empty)
+  }
+}
+
+/**
+ * Class that keeps track of all the received blocks, and allocates them to batches
+ * when required. All actions taken by this class can be saved to a write ahead log
+ * (if a checkpoint directory has been provided), so that the state of the tracker
+ * (received blocks and block-to-batch allocations) can be recovered after driver failure.
+ *
+ * Note that when any instance of this class is created with a checkpoint directory,
+ * it will try reading events from logs in the directory.
+ */
+private[streaming] class ReceivedBlockTracker(
+    conf: SparkConf,
+    hadoopConf: Configuration,
+    streamIds: Seq[Int],
+    clock: Clock,
+    checkpointDirOption: Option[String])
+  extends Logging {
+
+  private type ReceivedBlockQueue = mutable.Queue[ReceivedBlockInfo]
+
+  private val streamIdToUnallocatedBlockQueues = new mutable.HashMap[Int, ReceivedBlockQueue]
+  private val timeToAllocatedBlocks = new mutable.HashMap[Time, AllocatedBlocks]
+
+  private val logManagerRollingIntervalSecs = conf.getInt(
+    "spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", 60)
+  private val logManagerOption = checkpointDirOption.map { checkpointDir =>
+    new WriteAheadLogManager(
+      ReceivedBlockTracker.checkpointDirToLogDir(checkpointDir),
+      hadoopConf,
+      rollingIntervalSecs = logManagerRollingIntervalSecs,
+      callerName = "ReceivedBlockHandlerMaster",
+      clock = clock
+    )
+  }
+
+  private var lastAllocatedBatchTime: Time = null
+
+  // Recover block information from write ahead logs
+  recoverFromWriteAheadLogs()
+
+  /** Add received block. This event will get written to the write ahead log (if enabled). */
+  def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = synchronized {
+    try {
+      writeToLog(BlockAdditionEvent(receivedBlockInfo))
+      getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo
+      logDebug(s"Stream ${receivedBlockInfo.streamId} received " +
+        s"block ${receivedBlockInfo.blockStoreResult.blockId}")
+      true
+    } catch {
+      case e: Exception =>
+        logError(s"Error adding block $receivedBlockInfo", e)
+        false
+    }
+  }
+
+  /**
+   * Allocate all unallocated blocks to the given batch.
+   * This event will get written to the write ahead log (if enabled).
+   */
+  def allocateBlocksToBatch(batchTime: Time): Unit = synchronized {
+    if (lastAllocatedBatchTime == null || batchTime > lastAllocatedBatchTime) {
+      val streamIdToBlocks = streamIds.map { streamId =>
+        (streamId, getReceivedBlockQueue(streamId).dequeueAll(x => true))
+      }.toMap
+      val allocatedBlocks = AllocatedBlocks(streamIdToBlocks)
+      writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks))
+      timeToAllocatedBlocks(batchTime) = allocatedBlocks
+      lastAllocatedBatchTime = batchTime
+      allocatedBlocks
+    } else {
+      throw new SparkException(s"Unexpected allocation of blocks, " +
+        s"last batch = $lastAllocatedBatchTime, batch time to allocate = $batchTime")
+    }
+  }
+
+  /** Get the blocks allocated to the given batch. */
+  def getBlocksOfBatch(batchTime: Time): Map[Int, Seq[ReceivedBlockInfo]] = synchronized {
+    timeToAllocatedBlocks.get(batchTime).map { _.streamIdToAllocatedBlocks }.getOrElse(Map.empty)
+  }
+
+  /** Get the blocks allocated to the given batch and stream. */
+  def getBlocksOfBatchAndStream(batchTime: Time, streamId: Int): Seq[ReceivedBlockInfo] = {
+    synchronized {
+      timeToAllocatedBlocks.get(batchTime).map {
+        _.getBlocksOfStream(streamId)
+      }.getOrElse(Seq.empty)
+    }
+  }
+
+  /** Check if any blocks are left to be allocated to batches. */
+  def hasUnallocatedReceivedBlocks: Boolean = synchronized {
+    !streamIdToUnallocatedBlockQueues.values.forall(_.isEmpty)
+  }
+
+  /**
+   * Get blocks that have been added but not yet allocated to any batch. This method
+   * is primarily used for testing.
+   */
+  def getUnallocatedBlocks(streamId: Int): Seq[ReceivedBlockInfo] = synchronized {
+    getReceivedBlockQueue(streamId).toSeq
+  }
+
+  /** Clean up block information of old batches. */
+  def cleanupOldBatches(cleanupThreshTime: Time): Unit = synchronized {
+    assert(cleanupThreshTime.milliseconds < clock.currentTime())
+    val timesToCleanup = timeToAllocatedBlocks.keys.filter { _ < cleanupThreshTime }.toSeq
+    logInfo("Deleting batches " + timesToCleanup)
+    writeToLog(BatchCleanupEvent(timesToCleanup))
+    timeToAllocatedBlocks --= timesToCleanup
+    logManagerOption.foreach(_.cleanupOldLogs(cleanupThreshTime.milliseconds))
+    log
+  }
+
+  /** Stop the block tracker. */
+  def stop() {
+    logManagerOption.foreach { _.stop() }
+  }
+
+  /**
+   * Recover all the tracker actions from the write ahead logs to recover the state (unallocated
+   * and allocated block info) prior to failure.
+   */
+  private def recoverFromWriteAheadLogs(): Unit = synchronized {
+    // Insert the recovered block information
+    def insertAddedBlock(receivedBlockInfo: ReceivedBlockInfo) {
+      logTrace(s"Recovery: Inserting added block $receivedBlockInfo")
+      getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo
+    }
+
+    // Insert the recovered block-to-batch allocations and clear the queue of received blocks
+    // (when the blocks were originally allocated to the batch, the queue must have been cleared).
+    def insertAllocatedBatch(batchTime: Time, allocatedBlocks: AllocatedBlocks) {
+      logTrace(s"Recovery: Inserting allocated batch for time $batchTime to " +
+        s"${allocatedBlocks.streamIdToAllocatedBlocks}")
+      streamIdToUnallocatedBlockQueues.values.foreach { _.clear() }
+      lastAllocatedBatchTime = batchTime
+      timeToAllocatedBlocks.put(batchTime, allocatedBlocks)
+    }
+
+    // Cleanup the batch allocations
+    def cleanupBatches(batchTimes: Seq[Time]) {
+      logTrace(s"Recovery: Cleaning up batches $batchTimes")
+      timeToAllocatedBlocks --= batchTimes
+    }
+
+    logManagerOption.foreach { logManager =>
+      logInfo(s"Recovering from write ahead logs in ${checkpointDirOption.get}")
+      logManager.readFromLog().foreach { byteBuffer =>
+        logTrace("Recovering record " + byteBuffer)
+        Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) match {
+          case BlockAdditionEvent(receivedBlockInfo) =>
+            insertAddedBlock(receivedBlockInfo)
+          case BatchAllocationEvent(time, allocatedBlocks) =>
+            insertAllocatedBatch(time, allocatedBlocks)
+          case BatchCleanupEvent(batchTimes) =>
+            cleanupBatches(batchTimes)
+        }
+      }
+    }
+  }
+
+  /** Write an update to the tracker to the write ahead log */
+  private def writeToLog(record: ReceivedBlockTrackerLogEvent) {
+    logDebug(s"Writing to log $record")
+    logManagerOption.foreach { logManager =>
+      logManager.writeToLog(ByteBuffer.wrap(Utils.serialize(record)))
+    }
+  }
+
+  /** Get the queue of received blocks belonging to a particular stream */
+  private def getReceivedBlockQueue(streamId: Int): ReceivedBlockQueue = {
+    streamIdToUnallocatedBlockQueues.getOrElseUpdate(streamId, new ReceivedBlockQueue)
+  }
+}
+
+private[streaming] object ReceivedBlockTracker {
+  def checkpointDirToLogDir(checkpointDir: String): String = {
+    new Path(checkpointDir, "receivedBlockMetadata").toString
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/46654b06/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
index d696563..1c3984d 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
@@ -17,15 +17,16 @@
 
 package org.apache.spark.streaming.scheduler
 
-import scala.collection.mutable.{HashMap, SynchronizedMap, SynchronizedQueue}
+
+import scala.collection.mutable.{HashMap, SynchronizedMap}
 import scala.language.existentials
 
 import akka.actor._
-import org.apache.spark.{SerializableWritable, Logging, SparkEnv, SparkException}
+
+import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException}
 import org.apache.spark.SparkContext._
 import org.apache.spark.streaming.{StreamingContext, Time}
 import org.apache.spark.streaming.receiver.{Receiver, ReceiverSupervisorImpl, 
StopReceiver}
-import org.apache.spark.util.AkkaUtils
 
 /**
  * Messages used by the NetworkReceiver and the ReceiverTracker to communicate
@@ -48,23 +49,28 @@ private[streaming] case class DeregisterReceiver(streamId: Int, msg: String, err
  * This class manages the execution of the receivers of NetworkInputDStreams. Instance of
  * this class must be created after all input streams have been added and StreamingContext.start()
  * has been called because it needs the final set of input streams at the time of instantiation.
+ *
+ * @param skipReceiverLaunch Do not launch the receiver. This is useful for testing.
  */
 private[streaming]
-class ReceiverTracker(ssc: StreamingContext) extends Logging {
+class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false) extends Logging {
 
-  val receiverInputStreams = ssc.graph.getReceiverInputStreams()
-  val receiverInputStreamMap = Map(receiverInputStreams.map(x => (x.id, x)): _*)
-  val receiverExecutor = new ReceiverLauncher()
-  val receiverInfo = new HashMap[Int, ReceiverInfo] with SynchronizedMap[Int, ReceiverInfo]
-  val receivedBlockInfo = new HashMap[Int, SynchronizedQueue[ReceivedBlockInfo]]
-    with SynchronizedMap[Int, SynchronizedQueue[ReceivedBlockInfo]]
-  val timeout = AkkaUtils.askTimeout(ssc.conf)
-  val listenerBus = ssc.scheduler.listenerBus
+  private val receiverInputStreams = ssc.graph.getReceiverInputStreams()
+  private val receiverInputStreamIds = receiverInputStreams.map { _.id }
+  private val receiverExecutor = new ReceiverLauncher()
+  private val receiverInfo = new HashMap[Int, ReceiverInfo] with SynchronizedMap[Int, ReceiverInfo]
+  private val receivedBlockTracker = new ReceivedBlockTracker(
+    ssc.sparkContext.conf,
+    ssc.sparkContext.hadoopConfiguration,
+    receiverInputStreamIds,
+    ssc.scheduler.clock,
+    Option(ssc.checkpointDir)
+  )
+  private val listenerBus = ssc.scheduler.listenerBus
 
   // actor is created when generator starts.
   // This not being null means the tracker has been started and not stopped
-  var actor: ActorRef = null
-  var currentTime: Time = null
+  private var actor: ActorRef = null
 
   /** Start the actor and receiver execution thread. */
   def start() = synchronized {
@@ -75,7 +81,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
     if (!receiverInputStreams.isEmpty) {
       actor = ssc.env.actorSystem.actorOf(Props(new ReceiverTrackerActor),
         "ReceiverTracker")
-      receiverExecutor.start()
+      if (!skipReceiverLaunch) receiverExecutor.start()
       logInfo("ReceiverTracker started")
     }
   }
@@ -84,45 +90,59 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
   def stop() = synchronized {
     if (!receiverInputStreams.isEmpty && actor != null) {
       // First, stop the receivers
-      receiverExecutor.stop()
+      if (!skipReceiverLaunch) receiverExecutor.stop()
 
       // Finally, stop the actor
       ssc.env.actorSystem.stop(actor)
       actor = null
+      receivedBlockTracker.stop()
       logInfo("ReceiverTracker stopped")
     }
   }
 
-  /** Return all the blocks received from a receiver. */
-  def getReceivedBlockInfo(streamId: Int): Array[ReceivedBlockInfo] = {
-    val receivedBlockInfo = getReceivedBlockInfoQueue(streamId).dequeueAll(x => true)
-    logInfo("Stream " + streamId + " received " + receivedBlockInfo.size + " blocks")
-    receivedBlockInfo.toArray
+  /** Allocate all unallocated blocks to the given batch. */
+  def allocateBlocksToBatch(batchTime: Time): Unit = {
+    if (receiverInputStreams.nonEmpty) {
+      receivedBlockTracker.allocateBlocksToBatch(batchTime)
+    }
+  }
+
+  /** Get the blocks for the given batch and all input streams. */
+  def getBlocksOfBatch(batchTime: Time): Map[Int, Seq[ReceivedBlockInfo]] = {
+    receivedBlockTracker.getBlocksOfBatch(batchTime)
   }
 
-  private def getReceivedBlockInfoQueue(streamId: Int) = {
-    receivedBlockInfo.getOrElseUpdate(streamId, new SynchronizedQueue[ReceivedBlockInfo])
+  /** Get the blocks allocated to the given batch and stream. */
+  def getBlocksOfBatchAndStream(batchTime: Time, streamId: Int): Seq[ReceivedBlockInfo] = {
+    synchronized {
+      receivedBlockTracker.getBlocksOfBatchAndStream(batchTime, streamId)
+    }
+  }
+
+  /** Clean up metadata older than the given threshold time */
+  def cleanupOldMetadata(cleanupThreshTime: Time) {
+    receivedBlockTracker.cleanupOldBatches(cleanupThreshTime)
   }
 
   /** Register a receiver */
-  def registerReceiver(
+  private def registerReceiver(
       streamId: Int,
       typ: String,
       host: String,
       receiverActor: ActorRef,
       sender: ActorRef
     ) {
-    if (!receiverInputStreamMap.contains(streamId)) {
-      throw new Exception("Register received for unexpected id " + streamId)
+    if (!receiverInputStreamIds.contains(streamId)) {
+      throw new SparkException("Register received for unexpected id " + streamId)
     }
     receiverInfo(streamId) = ReceiverInfo(
       streamId, s"${typ}-${streamId}", receiverActor, true, host)
-    ssc.scheduler.listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId)))
+    listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId)))
     logInfo("Registered receiver for stream " + streamId + " from " + sender.path.address)
   }
 
   /** Deregister a receiver */
-  def deregisterReceiver(streamId: Int, message: String, error: String) {
+  private def deregisterReceiver(streamId: Int, message: String, error: String) {
     val newReceiverInfo = receiverInfo.get(streamId) match {
       case Some(oldInfo) =>
         oldInfo.copy(actor = null, active = false, lastErrorMessage = message, lastError = error)
@@ -131,7 +151,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
         ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, lastError = error)
     }
     receiverInfo(streamId) = newReceiverInfo
-    ssc.scheduler.listenerBus.post(StreamingListenerReceiverStopped(receiverInfo(streamId)))
+    listenerBus.post(StreamingListenerReceiverStopped(receiverInfo(streamId)))
     val messageWithError = if (error != null && !error.isEmpty) {
       s"$message - $error"
     } else {
@@ -141,14 +161,12 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
   }
 
   /** Add new blocks for the given stream */
-  def addBlocks(receivedBlockInfo: ReceivedBlockInfo) {
-    getReceivedBlockInfoQueue(receivedBlockInfo.streamId) += receivedBlockInfo
-    logDebug("Stream " + receivedBlockInfo.streamId + " received new blocks: " +
-      receivedBlockInfo.blockStoreResult.blockId)
+  private def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = {
+    receivedBlockTracker.addBlock(receivedBlockInfo)
   }
 
   /** Report error sent by a receiver */
-  def reportError(streamId: Int, message: String, error: String) {
+  private def reportError(streamId: Int, message: String, error: String) {
     val newReceiverInfo = receiverInfo.get(streamId) match {
       case Some(oldInfo) =>
         oldInfo.copy(lastErrorMessage = message, lastError = error)
@@ -157,7 +175,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
         ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, lastError = error)
     }
     receiverInfo(streamId) = newReceiverInfo
-    ssc.scheduler.listenerBus.post(StreamingListenerReceiverError(receiverInfo(streamId)))
+    listenerBus.post(StreamingListenerReceiverError(receiverInfo(streamId)))
     val messageWithError = if (error != null && !error.isEmpty) {
       s"$message - $error"
     } else {
@@ -167,8 +185,8 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
   }
 
   /** Check if any blocks are left to be processed */
-  def hasMoreReceivedBlockIds: Boolean = {
-    !receivedBlockInfo.values.forall(_.isEmpty)
+  def hasUnallocatedBlocks: Boolean = {
+    receivedBlockTracker.hasUnallocatedReceivedBlocks
   }
 
   /** Actor to receive messages from the receivers. */
@@ -178,8 +196,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
         registerReceiver(streamId, typ, host, receiverActor, sender)
         sender ! true
       case AddBlock(receivedBlockInfo) =>
-        addBlocks(receivedBlockInfo)
-        sender ! true
+        sender ! addBlock(receivedBlockInfo)
       case ReportError(streamId, message, error) =>
         reportError(streamId, message, error)
       case DeregisterReceiver(streamId, message, error) =>
@@ -194,6 +211,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
     @transient val thread  = new Thread() {
       override def run() {
         try {
+          SparkEnv.set(env)
           startReceivers()
         } catch {
           case ie: InterruptedException => logInfo("ReceiverLauncher interrupted")
@@ -267,7 +285,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
 
       // Distribute the receivers and start them
       logInfo("Starting " + receivers.length + " receivers")
-      ssc.sparkContext.runJob(tempRDD, startReceiver)
+      ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver))
       logInfo("All of the receivers have been terminated")
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/46654b06/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index 6c8bb50..dbab685 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -17,18 +17,19 @@
 
 package org.apache.spark.streaming
 
-import org.apache.spark.streaming.StreamingContext._
-
-import org.apache.spark.rdd.{BlockRDD, RDD}
-import org.apache.spark.SparkContext._
+import scala.collection.mutable
+import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
+import scala.language.existentials
+import scala.reflect.ClassTag
 
 import util.ManualClock
-import org.apache.spark.{SparkException, SparkConf}
-import org.apache.spark.streaming.dstream.{WindowedDStream, DStream}
-import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer}
-import scala.reflect.ClassTag
+
+import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd.{BlockRDD, RDD}
 import org.apache.spark.storage.StorageLevel
-import scala.collection.mutable
+import org.apache.spark.streaming.StreamingContext._
+import org.apache.spark.streaming.dstream.{DStream, WindowedDStream}
 
 class BasicOperationsSuite extends TestSuiteBase {
   test("map") {

http://git-wip-us.apache.org/repos/asf/spark/blob/46654b06/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
new file mode 100644
index 0000000..fd9c97f
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
@@ -0,0 +1,242 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import java.io.File
+
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.duration._
+import scala.language.{implicitConversions, postfixOps}
+import scala.util.Random
+
+import com.google.common.io.Files
+import org.apache.commons.io.FileUtils
+import org.apache.hadoop.conf.Configuration
+import org.scalatest.{BeforeAndAfter, FunSuite, Matchers}
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.spark.{Logging, SparkConf, SparkException}
+import org.apache.spark.storage.StreamBlockId
+import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult
+import org.apache.spark.streaming.scheduler._
+import org.apache.spark.streaming.util.{Clock, ManualClock, SystemClock, WriteAheadLogReader}
+import org.apache.spark.streaming.util.WriteAheadLogSuite._
+import org.apache.spark.util.Utils
+
+class ReceivedBlockTrackerSuite
+  extends FunSuite with BeforeAndAfter with Matchers with Logging {
+
+  val conf = new SparkConf().setMaster("local[2]").setAppName("ReceivedBlockTrackerSuite")
+  conf.set("spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", "1")
+
+  val hadoopConf = new Configuration()
+  val akkaTimeout = 10 seconds
+  val streamId = 1
+
+  var allReceivedBlockTrackers = new ArrayBuffer[ReceivedBlockTracker]()
+  var checkpointDirectory: File = null
+
+  before {
+    checkpointDirectory = Files.createTempDir()
+  }
+
+  after {
+    allReceivedBlockTrackers.foreach { _.stop() }
+    if (checkpointDirectory != null && checkpointDirectory.exists()) {
+      FileUtils.deleteDirectory(checkpointDirectory)
+      checkpointDirectory = null
+    }
+  }
+
+  test("block addition, and block to batch allocation") {
+    val receivedBlockTracker = createTracker(enableCheckpoint = false)
+    receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual Seq.empty
+
+    val blockInfos = generateBlockInfos()
+    blockInfos.map(receivedBlockTracker.addBlock)
+
+    // Verify added blocks are unallocated blocks
+    receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos
+
+    // Allocate the blocks to a batch and verify that all of them have been allocated
+    receivedBlockTracker.allocateBlocksToBatch(1)
+    receivedBlockTracker.getBlocksOfBatchAndStream(1, streamId) shouldEqual blockInfos
+    receivedBlockTracker.getUnallocatedBlocks(streamId) shouldBe empty
+
+    // Allocate no blocks to another batch
+    receivedBlockTracker.allocateBlocksToBatch(2)
+    receivedBlockTracker.getBlocksOfBatchAndStream(2, streamId) shouldBe empty
+
+    // Verify that batch 2 cannot be allocated again
+    intercept[SparkException] {
+      receivedBlockTracker.allocateBlocksToBatch(2)
+    }
+
+    // Verify that older batches cannot be allocated again
+    intercept[SparkException] {
+      receivedBlockTracker.allocateBlocksToBatch(1)
+    }
+  }
+
+  test("block addition, block to batch allocation and cleanup with write ahead log") {
+    val manualClock = new ManualClock
+    conf.getInt(
+      "spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", -1) should be (1)
+
+    // Set the time increment level to twice the rotation interval so that every increment creates
+    // a new log file
+    val timeIncrementMillis = 2000L
+    def incrementTime() {
+      manualClock.addToTime(timeIncrementMillis)
+    }
+
+    // Generate and add blocks to the given tracker
+    def addBlockInfos(tracker: ReceivedBlockTracker): Seq[ReceivedBlockInfo] = {
+      val blockInfos = generateBlockInfos()
+      blockInfos.map(tracker.addBlock)
+      blockInfos
+    }
+
+    // Print the data present in the write ahead log files in the log directory
+    def printLogFiles(message: String) {
+      val fileContents = getWriteAheadLogFiles().map { file =>
+        (s"\n>>>>> $file: <<<<<\n${getWrittenLogData(file).mkString("\n")}")
+      }.mkString("\n")
+      logInfo(s"\n\n=====================\n$message\n$fileContents\n=====================\n")
+    }
+
+    // Start tracker and add blocks
+    val tracker1 = createTracker(enableCheckpoint = true, clock = manualClock)
+    val blockInfos1 = addBlockInfos(tracker1)
+    tracker1.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1
+
+    // Verify whether write ahead log has correct contents
+    val expectedWrittenData1 = blockInfos1.map(BlockAdditionEvent)
+    getWrittenLogData() shouldEqual expectedWrittenData1
+    getWriteAheadLogFiles() should have size 1
+
+    // Restart tracker and verify recovered list of unallocated blocks
+    incrementTime()
+    val tracker2 = createTracker(enableCheckpoint = true, clock = manualClock)
+    tracker2.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1
+
+    // Allocate blocks to batch and verify whether the unallocated blocks got allocated
+    val batchTime1 = manualClock.currentTime
+    tracker2.allocateBlocksToBatch(batchTime1)
+    tracker2.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual blockInfos1
+
+    // Add more blocks and allocate to another batch
+    incrementTime()
+    val batchTime2 = manualClock.currentTime
+    val blockInfos2 = addBlockInfos(tracker2)
+    tracker2.allocateBlocksToBatch(batchTime2)
+    tracker2.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2
+
+    // Verify whether log has correct contents
+    val expectedWrittenData2 = expectedWrittenData1 ++
+      Seq(createBatchAllocation(batchTime1, blockInfos1)) ++
+      blockInfos2.map(BlockAdditionEvent) ++
+      Seq(createBatchAllocation(batchTime2, blockInfos2))
+    getWrittenLogData() shouldEqual expectedWrittenData2
+
+    // Restart tracker and verify recovered state
+    incrementTime()
+    val tracker3 = createTracker(enableCheckpoint = true, clock = manualClock)
+    tracker3.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual blockInfos1
+    tracker3.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2
+    tracker3.getUnallocatedBlocks(streamId) shouldBe empty
+
+    // Clean up the first batch but not the second batch
+    val oldestLogFile = getWriteAheadLogFiles().head
+    incrementTime()
+    tracker3.cleanupOldBatches(batchTime2)
+
+    // Verify that the batch allocations have been cleaned up, and the cleanup has been written to the log
+    tracker3.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual Seq.empty
+    getWrittenLogData(getWriteAheadLogFiles().last) should contain(createBatchCleanup(batchTime1))
+
+    // Verify that at least one log file gets deleted
+    eventually(timeout(10 seconds), interval(10 millisecond)) {
+      getWriteAheadLogFiles() should not contain oldestLogFile
+    }
+    printLogFiles("After cleanup")
+
+    // Restart tracker and verify recovered state, specifically whether info about the first
+    // batch has been removed, but not the second batch
+    incrementTime()
+    val tracker4 = createTracker(enableCheckpoint = true, clock = manualClock)
+    tracker4.getUnallocatedBlocks(streamId) shouldBe empty
+    tracker4.getBlocksOfBatchAndStream(batchTime1, streamId) shouldBe empty  // should be cleaned
+    tracker4.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2
+  }
+
+  /**
+   * Create tracker object with the optionally provided clock. Use a fake clock if you
+   * want to control time by manually incrementing it to test log cleanup.
+   */
+  def createTracker(enableCheckpoint: Boolean, clock: Clock = new SystemClock): ReceivedBlockTracker = {
+    val cpDirOption = if (enableCheckpoint) Some(checkpointDirectory.toString) else None
+    val tracker = new ReceivedBlockTracker(conf, hadoopConf, Seq(streamId), clock, cpDirOption)
+    allReceivedBlockTrackers += tracker
+    tracker
+  }
+
+  /** Generate block infos using random IDs */
+  def generateBlockInfos(): Seq[ReceivedBlockInfo] = {
+    List.fill(5)(ReceivedBlockInfo(streamId, 0,
+      BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt)))))
+  }
+
+  /** Get all the data written in the given write ahead log file. */
+  def getWrittenLogData(logFile: String): Seq[ReceivedBlockTrackerLogEvent] = {
+    getWrittenLogData(Seq(logFile))
+  }
+
+  /**
+   * Get all the data written in the given write ahead log files. By default, it will read all
+   * files in the test log directory.
+   */
+  def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles): Seq[ReceivedBlockTrackerLogEvent] = {
+    logFiles.flatMap {
+      file => new WriteAheadLogReader(file, hadoopConf).toSeq
+    }.map { byteBuffer =>
+      Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array)
+    }.toList
+  }
+
+  /** Get all the write ahead log files in the test directory */
+  def getWriteAheadLogFiles(): Seq[String] = {
+    import ReceivedBlockTracker._
+    val logDir = checkpointDirToLogDir(checkpointDirectory.toString)
+    getLogFilesInDirectory(logDir).map { _.toString }
+  }
+
+  /** Create batch allocation object from the given info */
+  def createBatchAllocation(time: Long, blockInfos: Seq[ReceivedBlockInfo]): BatchAllocationEvent = {
+    BatchAllocationEvent(time, AllocatedBlocks(Map((streamId -> blockInfos))))
+  }
+
+  /** Create batch cleanup object from the given info */
+  def createBatchCleanup(time: Long, moreTimes: Long*): BatchCleanupEvent = {
+    BatchCleanupEvent((Seq(time) ++ moreTimes).map(Time.apply))
+  }
+
+  implicit def millisToTime(milliseconds: Long): Time = Time(milliseconds)
+
+  implicit def timeToMillis(time: Time): Long = time.milliseconds
+}
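The suite above exercises the core idea behind `ReceivedBlockTracker`: every mutation (block addition, batch allocation, batch cleanup) is first recorded as an event, and on restart the recorded events are replayed to rebuild the unallocated blocks and the batch-to-block allocations. The following is a minimal, self-contained sketch of that replay pattern, not Spark's implementation. `ToyBlockTracker`, the `LogEvent` case classes and the in-memory `ArrayBuffer` standing in for the write ahead log are hypothetical simplifications (no HDFS files, serialization or log rotation), and `require` (which throws `IllegalArgumentException`) stands in for the `SparkException` the real tracker throws when a batch is allocated twice.

```scala
import scala.collection.mutable

// Hypothetical, simplified log events; block IDs are plain Longs here.
sealed trait LogEvent
final case class BlockAddition(streamId: Int, blockId: Long) extends LogEvent
final case class BatchAllocation(batchTime: Long, blocks: Map[Int, Seq[Long]]) extends LogEvent
final case class BatchCleanup(batchTimes: Seq[Long]) extends LogEvent

/**
 * Toy tracker (hypothetical): each mutation is appended to `log` before it is applied,
 * and constructing a tracker over an existing log replays it to recover the state.
 */
class ToyBlockTracker(log: mutable.ArrayBuffer[LogEvent]) {
  private val unallocated = mutable.Map.empty[Int, mutable.Queue[Long]]
  private val allocated = mutable.Map.empty[Long, Map[Int, Seq[Long]]]
  private var lastAllocatedBatchTime = Long.MinValue

  // Recovery: re-apply every recorded event without writing it to the log again.
  log.toList.foreach(applyEvent)

  def addBlock(streamId: Int, blockId: Long): Unit =
    writeAndApply(BlockAddition(streamId, blockId))

  /** Move all currently unallocated blocks into `batchTime`; a batch can be allocated once. */
  def allocateBlocksToBatch(batchTime: Long): Unit = {
    require(batchTime > lastAllocatedBatchTime, s"batch $batchTime is not newer than the last one")
    val blocks = unallocated.map { case (id, queue) => id -> queue.toList }.toMap
    writeAndApply(BatchAllocation(batchTime, blocks))
  }

  /** Drop allocations of batches strictly older than `threshTime`. */
  def cleanupOldBatches(threshTime: Long): Unit =
    writeAndApply(BatchCleanup(allocated.keys.filter(_ < threshTime).toSeq.sorted))

  def unallocatedBlocks(streamId: Int): Seq[Long] =
    unallocated.get(streamId).map(_.toList).getOrElse(Nil)

  def blocksOfBatch(batchTime: Long, streamId: Int): Seq[Long] =
    allocated.get(batchTime).flatMap(_.get(streamId)).getOrElse(Nil)

  // Write-ahead: record the event first, then mutate in-memory state.
  private def writeAndApply(event: LogEvent): Unit = {
    log += event
    applyEvent(event)
  }

  private def applyEvent(event: LogEvent): Unit = event match {
    case BlockAddition(streamId, blockId) =>
      unallocated.getOrElseUpdate(streamId, mutable.Queue.empty[Long]) += blockId
    case BatchAllocation(batchTime, blocks) =>
      allocated(batchTime) = blocks
      // The allocation took every block that was unallocated at that moment.
      blocks.keys.foreach(streamId => unallocated.get(streamId).foreach(_.clear()))
      lastAllocatedBatchTime = batchTime
    case BatchCleanup(batchTimes) =>
      allocated --= batchTimes
  }
}
```

Creating a second `ToyBlockTracker` over the same buffer plays the role of the driver restart in the suite: the new instance sees the same allocations and unallocated blocks as the old one, and batches removed by `cleanupOldBatches` stay removed after replay.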

http://git-wip-us.apache.org/repos/asf/spark/blob/46654b06/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
index 1016024..d2b983c 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
@@ -117,12 +117,12 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll {
     )
 
     // Create the RDD and verify whether the returned data is correct
-    val rdd = new WriteAheadLogBackedBlockRDD[String](sparkContext, hadoopConf, blockIds.toArray,
+    val rdd = new WriteAheadLogBackedBlockRDD[String](sparkContext, blockIds.toArray,
       segments.toArray, storeInBlockManager = false, StorageLevel.MEMORY_ONLY)
     assert(rdd.collect() === data.flatten)
 
     if (testStoreInBM) {
-      val rdd2 = new WriteAheadLogBackedBlockRDD[String](sparkContext, hadoopConf, blockIds.toArray,
+      val rdd2 = new WriteAheadLogBackedBlockRDD[String](sparkContext, blockIds.toArray,
         segments.toArray, storeInBlockManager = true, StorageLevel.MEMORY_ONLY)
       assert(rdd2.collect() === data.flatten)
       assert(

