Repository: spark
Updated Branches:
  refs/heads/branch-1.6 68bcb9b33 -> 7f030aa42


[SPARK-11979][STREAMING] Empty TrackStateRDD cannot be checkpointed and recovered from checkpoint file

This fixes the following exception, which is thrown when an empty state RDD is checkpointed and then recovered. The root cause is that an empty OpenHashMapBasedStateMap cannot be deserialized, because its initialCapacity is set to zero.
```
Job aborted due to stage failure: Task 0 in stage 6.0 failed 1 times, most recent failure: Lost task 0.0 in stage 6.0 (TID 20, localhost): java.lang.IllegalArgumentException: requirement failed: Invalid initial capacity
        at scala.Predef$.require(Predef.scala:233)
        at org.apache.spark.streaming.util.OpenHashMapBasedStateMap.<init>(StateMap.scala:96)
        at org.apache.spark.streaming.util.OpenHashMapBasedStateMap.<init>(StateMap.scala:86)
        at org.apache.spark.streaming.util.OpenHashMapBasedStateMap.readObject(StateMap.scala:291)
        at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
        at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57)
        at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        at java.lang.reflect.Method.invoke(Method.java:606)
        at java.io.ObjectStreamClass.invokeReadObject(ObjectStreamClass.java:1017)
        at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1893)
        at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1798)
        at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1350)
        at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:1990)
        at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1915)
        at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1798)
        at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1350)
        at java.io.ObjectInputStream.readObject(ObjectInputStream.java:370)
        at org.apache.spark.serializer.JavaDeserializationStream.readObject(JavaSerializer.scala:76)
        at org.apache.spark.serializer.DeserializationStream$$anon$1.getNext(Serializer.scala:181)
        at org.apache.spark.util.NextIterator.hasNext(NextIterator.scala:73)
        at scala.collection.Iterator$$anon$13.hasNext(Iterator.scala:371)
        at scala.collection.Iterator$class.foreach(Iterator.scala:727)
        at scala.collection.AbstractIterator.foreach(Iterator.scala:1157)
        at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:48)
        at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:103)
        at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:47)
        at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:273)
        at scala.collection.AbstractIterator.to(Iterator.scala:1157)
        at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:265)
        at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1157)
        at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:252)
        at scala.collection.AbstractIterator.toArray(Iterator.scala:1157)
        at org.apache.spark.rdd.RDD$$anonfun$collect$1$$anonfun$12.apply(RDD.scala:921)
        at org.apache.spark.rdd.RDD$$anonfun$collect$1$$anonfun$12.apply(RDD.scala:921)
        at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1858)
        at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1858)
        at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:66)
        at org.apache.spark.scheduler.Task.run(Task.scala:88)
        at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:214)
        at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145)
        at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615)
        at java.lang.Thread.run(Thread.java:744)
```
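
To make the root cause concrete: the readObject path reads back a size hint for the parent state map and previously passed it straight to OpenHashMapBasedStateMap as the initialCapacity, so an empty map (hint of 0) tripped OpenHashMap's "Invalid initial capacity" requirement. The patch clamps the hint to DEFAULT_INITIAL_CAPACITY before allocating. A minimal, self-contained sketch of that idea (illustrative only; the object and method names below are hypothetical, not Spark APIs):
```
// Illustrative sketch only: SizeHintSketch and its methods are hypothetical
// names, not Spark APIs.
object SizeHintSketch {

  // Mirrors the DEFAULT_INITIAL_CAPACITY constant introduced by this patch.
  val DefaultInitialCapacity = 64

  // Pre-fix behaviour: the deserialized size hint was used as-is, so an
  // empty parent map (hint == 0) failed the capacity requirement.
  def capacityBeforeFix(sizeHint: Int): Int = {
    require(sizeHint >= 1, "Invalid initial capacity")
    sizeHint
  }

  // Post-fix behaviour: clamp the hint to the default minimum, so an empty
  // map still gets a valid, non-zero capacity.
  def capacityAfterFix(sizeHint: Int): Int =
    math.max(sizeHint, DefaultInitialCapacity)

  def main(args: Array[String]): Unit = {
    println(capacityAfterFix(0))    // 64: an empty map now deserializes fine
    println(capacityAfterFix(1000)) // 1000: larger hints are still honored
    // capacityBeforeFix(0) would throw
    // java.lang.IllegalArgumentException: requirement failed: Invalid initial capacity
  }
}
```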

Author: Tathagata Das <tathagata.das1...@gmail.com>

Closes #9958 from tdas/SPARK-11979.

(cherry picked from commit 2169886883d33b33acf378ac42a626576b342df1)
Signed-off-by: Shixiong Zhu <shixi...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7f030aa4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7f030aa4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7f030aa4

Branch: refs/heads/branch-1.6
Commit: 7f030aa422802a8e7077e1c74a59ab9a5fe54488
Parents: 68bcb9b
Author: Tathagata Das <tathagata.das1...@gmail.com>
Authored: Tue Nov 24 23:13:01 2015 -0800
Committer: Shixiong Zhu <shixi...@databricks.com>
Committed: Tue Nov 24 23:13:29 2015 -0800

----------------------------------------------------------------------
 .../apache/spark/streaming/util/StateMap.scala  | 19 ++++++++-----
 .../apache/spark/streaming/StateMapSuite.scala  | 30 +++++++++++++-------
 .../streaming/rdd/TrackStateRDDSuite.scala      | 10 +++++++
 3 files changed, 42 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7f030aa4/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
index 34287c3..3f139ad 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
@@ -59,7 +59,7 @@ private[streaming] object StateMap {
   def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = {
     val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold",
       DELTA_CHAIN_LENGTH_THRESHOLD)
-    new OpenHashMapBasedStateMap[K, S](64, deltaChainThreshold)
+    new OpenHashMapBasedStateMap[K, S](deltaChainThreshold)
   }
 }
 
@@ -79,7 +79,7 @@ private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMa
 /** Implementation of StateMap based on Spark's [[org.apache.spark.util.collection.OpenHashMap]] */
 private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
     @transient @volatile var parentStateMap: StateMap[K, S],
-    initialCapacity: Int = 64,
+    initialCapacity: Int = DEFAULT_INITIAL_CAPACITY,
     deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD
   ) extends StateMap[K, S] { self =>
 
@@ -89,12 +89,14 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
     deltaChainThreshold = deltaChainThreshold)
 
   def this(deltaChainThreshold: Int) = this(
-    initialCapacity = 64, deltaChainThreshold = deltaChainThreshold)
+    initialCapacity = DEFAULT_INITIAL_CAPACITY, deltaChainThreshold = deltaChainThreshold)
 
   def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD)
 
-  @transient @volatile private var deltaMap =
-    new OpenHashMap[K, StateInfo[S]](initialCapacity)
+  require(initialCapacity >= 1, "Invalid initial capacity")
+  require(deltaChainThreshold >= 1, "Invalid delta chain threshold")
+
+  @transient @volatile private var deltaMap = new OpenHashMap[K, StateInfo[S]](initialCapacity)
 
   /** Get the session data if it exists */
   override def get(key: K): Option[S] = {
@@ -284,9 +286,10 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
     // Read the data of the parent map. Keep reading records, until the limiter is reached
     // First read the approximate number of records to expect and allocate properly size
     // OpenHashMap
-    val parentSessionStoreSizeHint = inputStream.readInt()
+    val parentStateMapSizeHint = inputStream.readInt()
+    val newStateMapInitialCapacity = math.max(parentStateMapSizeHint, DEFAULT_INITIAL_CAPACITY)
     val newParentSessionStore = new OpenHashMapBasedStateMap[K, S](
-      initialCapacity = parentSessionStoreSizeHint, deltaChainThreshold)
+      initialCapacity = newStateMapInitialCapacity, deltaChainThreshold)
 
     // Read the records until the limit marking object has been reached
     var parentSessionLoopDone = false
@@ -338,4 +341,6 @@ private[streaming] object OpenHashMapBasedStateMap {
   class LimitMarker(val num: Int) extends Serializable
 
   val DELTA_CHAIN_LENGTH_THRESHOLD = 20
+
+  val DEFAULT_INITIAL_CAPACITY = 64
 }
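
As a quick way to exercise the fixed path, here is a sketch (not part of the patch; it mirrors the new StateMapSuite test below and assumes access to Spark's private[streaming]/private[spark] members, e.g. from within the streaming test sources):
```
import org.apache.spark.streaming.util.OpenHashMapBasedStateMap
import org.apache.spark.util.Utils

// Round-trip an empty state map through Java serialization; before this
// patch, deserializing it failed with
// "requirement failed: Invalid initial capacity".
val emptyMap = new OpenHashMapBasedStateMap[Int, Int]()
val recovered = Utils.deserialize[OpenHashMapBasedStateMap[Int, Int]](
  Utils.serialize(emptyMap), Thread.currentThread().getContextClassLoader)
assert(recovered.getAll().isEmpty)  // the recovered map is still empty
```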

http://git-wip-us.apache.org/repos/asf/spark/blob/7f030aa4/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
index 48d3b41..c4a01ea 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
@@ -122,23 +122,27 @@ class StateMapSuite extends SparkFunSuite {
 
   test("OpenHashMapBasedStateMap - serializing and deserializing") {
     val map1 = new OpenHashMapBasedStateMap[Int, Int]()
+    testSerialization(map1, "error deserializing and serialized empty map")
+
     map1.put(1, 100, 1)
     map1.put(2, 200, 2)
+    testSerialization(map1, "error deserializing and serialized map with data + no delta")
 
     val map2 = map1.copy()
+    // Do not test compaction
+    assert(map2.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false)
+    testSerialization(map2, "error deserializing and serialized map with 1 delta + no new data")
+
     map2.put(3, 300, 3)
     map2.put(4, 400, 4)
+    testSerialization(map2, "error deserializing and serialized map with 1 delta + new data")
 
     val map3 = map2.copy()
+    assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false)
+    testSerialization(map3, "error deserializing and serialized map with 2 delta + no new data")
     map3.put(3, 600, 3)
     map3.remove(2)
-
-    // Do not test compaction
-    assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false)
-
-    val deser_map3 = Utils.deserialize[StateMap[Int, Int]](
-      Utils.serialize(map3), Thread.currentThread().getContextClassLoader)
-    assertMap(deser_map3, map3, 1, "Deserialized map not same as original map")
+    testSerialization(map3, "error deserializing and serialized map with 2 delta + new data")
   }
 
   test("OpenHashMapBasedStateMap - serializing and deserializing with 
compaction") {
@@ -156,11 +160,9 @@ class StateMapSuite extends SparkFunSuite {
     assert(map.deltaChainLength > deltaChainThreshold)
     assert(map.shouldCompact === true)
 
-    val deser_map = Utils.deserialize[OpenHashMapBasedStateMap[Int, Int]](
-      Utils.serialize(map), Thread.currentThread().getContextClassLoader)
+    val deser_map = testSerialization(map, "Deserialized + compacted map not same as original map")
     assert(deser_map.deltaChainLength < deltaChainThreshold)
     assert(deser_map.shouldCompact === false)
-    assertMap(deser_map, map, 1, "Deserialized + compacted map not same as original map")
   }
 
   test("OpenHashMapBasedStateMap - all possible sequences of operations with 
copies ") {
@@ -265,6 +267,14 @@ class StateMapSuite extends SparkFunSuite {
     assertMap(stateMap, refMap.toMap, time, "Final state map does not match reference map")
   }
 
+  private def testSerialization[MapType <: StateMap[Int, Int]](
+    map: MapType, msg: String): MapType = {
+    val deserMap = Utils.deserialize[MapType](
+      Utils.serialize(map), Thread.currentThread().getContextClassLoader)
+    assertMap(deserMap, map, 1, msg)
+    deserMap
+  }
+
   // Assert whether all the data and operations on a state map matches that of a reference state map
   private def assertMap(
       mapToTest: StateMap[Int, Int],

http://git-wip-us.apache.org/repos/asf/spark/blob/7f030aa4/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
index 0feb3af..3b2d43f 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
@@ -332,6 +332,16 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
       makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _)
   }
 
+  test("checkpointing empty state RDD") {
+    val emptyStateRDD = TrackStateRDD.createFromPairRDD[Int, Int, Int, Int](
+      sc.emptyRDD[(Int, Int)], new HashPartitioner(10), Time(0))
+    emptyStateRDD.checkpoint()
+    assert(emptyStateRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty)
+    val cpRDD = sc.checkpointFile[TrackStateRDDRecord[Int, Int, Int]](
+      emptyStateRDD.getCheckpointFile.get)
+    assert(cpRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty)
+  }
+
   /** Assert whether the `trackStateByKey` operation generates expected results */
   private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
       testStateRDD: TrackStateRDD[K, V, S, T],

