Repository: spark
Updated Branches:
  refs/heads/branch-2.0 a3bba372a -> bec077069


[SPARK-17491] Close serialization stream to fix wrong answer bug in putIteratorAsBytes()

## What changes were proposed in this pull request?

`MemoryStore.putIteratorAsBytes()` may silently lose values when used with 
`KryoSerializer` because it does not close the serialization stream before 
attempting to deserialize the already-serialized values. As a result, values 
still sitting in Kryo's internal buffers may never be flushed to the underlying 
output and are silently dropped.
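
The failure mode can be sketched as follows (illustrative only, not the actual 
`MemoryStore` code; the buffer sizes and value names are made up):

```scala
import java.nio.ByteBuffer

import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.util.io.ChunkedByteBufferOutputStream

val serializer = new KryoSerializer(new SparkConf()).newInstance()
val bbos = new ChunkedByteBufferOutputStream(4096, ByteBuffer.allocate)
val serStream = serializer.serializeStream(bbos)
(1 to 1000).foreach(i => serStream.writeObject(i))

// BUG: deserializing before close(). Kryo may still be holding the tail of the
// serialized data in its internal buffer, so the read-back iterator can contain
// fewer than 1000 values -- with no error raised.

// FIX: close the serialization stream first, which flushes Kryo's buffers and
// writes any end-of-stream markers, then deserialize.
serStream.close()
bbos.close()
val readBack = serializer
  .deserializeStream(bbos.toChunkedByteBuffer.toInputStream())
  .asIterator
  .toSeq
assert(readBack.size == 1000)
```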

This is the root cause of a user-reported "wrong answer" bug in PySpark 
caching, reported by bennoleslie on the Spark user mailing list in a thread 
titled "pyspark persist MEMORY_ONLY vs MEMORY_AND_DISK". Because Spark 2.0 
automatically uses `KryoSerializer` for "safe" types (such as byte arrays, 
primitives, etc.), this misuse of the serializer manifested as silent data 
corruption rather than a `StreamCorruptedException` (which you might get from 
`JavaSerializer`).

The minimal fix, implemented here, is to close the serialization stream before 
attempting to deserialize the written values. In addition, this patch adds 
assertions and precondition checks to prevent misuse of 
`PartiallySerializedBlock` and `ChunkedByteBufferOutputStream`.
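
For illustration, the contract enforced by the new checks on 
`ChunkedByteBufferOutputStream` looks roughly like this (a sketch, not taken 
verbatim from the patch):

```scala
import java.nio.ByteBuffer

import org.apache.spark.util.io.ChunkedByteBufferOutputStream

val out = new ChunkedByteBufferOutputStream(1024, ByteBuffer.allocate)
out.write(Array[Byte](1, 2, 3))

// out.toChunkedByteBuffer  // now fails: the stream has not been closed yet
out.close()
val buf = out.toChunkedByteBuffer  // OK after close(); may only be called once
// out.write(4)             // now fails: cannot write to a closed stream
```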

## How was this patch tested?

The original bug was masked by an invalid assertion in the memory store test 
cases: the old assertion compared two result collections record-by-record with 
`zip` but did not first check that the two collections had equal lengths, so 
missing records went unnoticed. The updated test case reproduces this bug.
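
A toy example of why the `zip`-based check passes even when records are dropped 
(illustrative values only):

```scala
val expected = Seq(1, 2, 3, 4, 5)
val actual   = Seq(1, 2, 3)          // two records were silently lost

// Old-style check: zip truncates to the shorter collection, so this passes.
expected.zip(actual).foreach { case (e, a) => assert(e == a) }

// Fixed check (as in the new assertSameContents helper): compare lengths first.
assert(actual.length == expected.length)  // fails, exposing the truncation
```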

In addition, I added a new `PartiallySerializedBlockSuite` to unit test that 
component.

Author: Josh Rosen <joshro...@databricks.com>

Closes #15043 from JoshRosen/partially-serialized-block-values-iterator-bugfix.

(cherry picked from commit 8faa5217b44e8d52eab7eb2d53d0652abaaf43cd)
Signed-off-by: Josh Rosen <joshro...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bec07706
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bec07706
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bec07706

Branch: refs/heads/branch-2.0
Commit: bec077069af0b3bc22092a0552baf855dfb344ad
Parents: a3bba37
Author: Josh Rosen <joshro...@databricks.com>
Authored: Sat Sep 17 11:46:15 2016 -0700
Committer: Josh Rosen <joshro...@databricks.com>
Committed: Sat Sep 17 11:46:39 2016 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/scheduler/Task.scala |   1 +
 .../spark/storage/memory/MemoryStore.scala      |  89 ++++++--
 .../spark/util/ByteBufferOutputStream.scala     |  27 ++-
 .../util/io/ChunkedByteBufferOutputStream.scala |  12 +-
 .../apache/spark/storage/MemoryStoreSuite.scala |  34 ++-
 .../storage/PartiallySerializedBlockSuite.scala | 215 +++++++++++++++++++
 .../PartiallyUnrolledIteratorSuite.scala        |   2 +-
 .../io/ChunkedByteBufferOutputStreamSuite.scala |   8 +
 8 files changed, 344 insertions(+), 44 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bec07706/core/src/main/scala/org/apache/spark/scheduler/Task.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 35c4daf..1ed36bf 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -230,6 +230,7 @@ private[spark] object Task {
     dataOut.flush()
     val taskBytes = serializer.serialize(task)
     Utils.writeByteBuffer(taskBytes, out)
+    out.close()
     out.toByteBuffer
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/bec07706/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
index 1230128..161434c 100644
--- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
@@ -33,7 +33,7 @@ import org.apache.spark.memory.{MemoryManager, MemoryMode}
 import org.apache.spark.serializer.{SerializationStream, SerializerManager}
 import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel}
 import org.apache.spark.unsafe.Platform
-import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils}
+import org.apache.spark.util.{SizeEstimator, Utils}
 import org.apache.spark.util.collection.SizeTrackingVector
 import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
 
@@ -275,6 +275,7 @@ private[spark] class MemoryStore(
           "released too much unroll memory")
         Left(new PartiallyUnrolledIterator(
           this,
+          MemoryMode.ON_HEAP,
           unrollMemoryUsedByThisBlock,
           unrolled = arrayValues.toIterator,
           rest = Iterator.empty))
@@ -283,7 +284,11 @@ private[spark] class MemoryStore(
       // We ran out of space while unrolling the values for this block
       logUnrollFailureMessage(blockId, vector.estimateSize())
       Left(new PartiallyUnrolledIterator(
-        this, unrollMemoryUsedByThisBlock, unrolled = vector.iterator, rest = values))
+        this,
+        MemoryMode.ON_HEAP,
+        unrollMemoryUsedByThisBlock,
+        unrolled = vector.iterator,
+        rest = values))
     }
   }
 
@@ -392,7 +397,7 @@ private[spark] class MemoryStore(
           redirectableStream,
           unrollMemoryUsedByThisBlock,
           memoryMode,
-          bbos.toChunkedByteBuffer,
+          bbos,
           values,
           classTag))
     }
@@ -653,6 +658,7 @@ private[spark] class MemoryStore(
  * The result of a failed [[MemoryStore.putIteratorAsValues()]] call.
  *
  * @param memoryStore  the memoryStore, used for freeing memory.
+ * @param memoryMode   the memory mode (on- or off-heap).
 * @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
  * @param unrolled     an iterator for the partially-unrolled values.
  * @param rest         the rest of the original iterator passed to
@@ -660,13 +666,14 @@ private[spark] class MemoryStore(
  */
 private[storage] class PartiallyUnrolledIterator[T](
     memoryStore: MemoryStore,
+    memoryMode: MemoryMode,
     unrollMemory: Long,
     private[this] var unrolled: Iterator[T],
     rest: Iterator[T])
   extends Iterator[T] {
 
   private def releaseUnrollMemory(): Unit = {
-    memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory)
+    memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
     // SPARK-17503: Garbage collects the unrolling memory before the life end of
     // PartiallyUnrolledIterator.
     unrolled = null
@@ -704,7 +711,7 @@ private[storage] class PartiallyUnrolledIterator[T](
 /**
 * A wrapper which allows an open [[OutputStream]] to be redirected to a different sink.
  */
-private class RedirectableOutputStream extends OutputStream {
+private[storage] class RedirectableOutputStream extends OutputStream {
   private[this] var os: OutputStream = _
   def setOutputStream(s: OutputStream): Unit = { os = s }
   override def write(b: Int): Unit = os.write(b)
@@ -724,7 +731,8 @@ private class RedirectableOutputStream extends OutputStream {
 * @param redirectableOutputStream an OutputStream which can be redirected to a different sink.
 * @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
 * @param memoryMode whether the unroll memory is on- or off-heap
- * @param unrolled a byte buffer containing the partially-serialized values.
+ * @param bbos byte buffer output stream containing the partially-serialized values.
+ *                     [[redirectableOutputStream]] initially points to this output stream.
  * @param rest         the rest of the original iterator passed to
  *                     [[MemoryStore.putIteratorAsValues()]].
  * @param classTag the [[ClassTag]] for the block.
@@ -733,14 +741,19 @@ private[storage] class PartiallySerializedBlock[T](
     memoryStore: MemoryStore,
     serializerManager: SerializerManager,
     blockId: BlockId,
-    serializationStream: SerializationStream,
-    redirectableOutputStream: RedirectableOutputStream,
-    unrollMemory: Long,
+    private val serializationStream: SerializationStream,
+    private val redirectableOutputStream: RedirectableOutputStream,
+    val unrollMemory: Long,
     memoryMode: MemoryMode,
-    unrolled: ChunkedByteBuffer,
+    bbos: ChunkedByteBufferOutputStream,
     rest: Iterator[T],
     classTag: ClassTag[T]) {
 
+  private lazy val unrolledBuffer: ChunkedByteBuffer = {
+    bbos.close()
+    bbos.toChunkedByteBuffer
+  }
+
   // If the task does not fully consume `valuesIterator` or otherwise fails to consume or dispose of
   // this PartiallySerializedBlock then we risk leaking of direct buffers, so we use a task
   // completion listener here in order to ensure that `unrolled.dispose()` is called at least once.
@@ -749,7 +762,23 @@ private[storage] class PartiallySerializedBlock[T](
     taskContext.addTaskCompletionListener { _ =>
       // When a task completes, its unroll memory will automatically be freed. Thus we do not call
       // releaseUnrollMemoryForThisTask() here because we want to avoid double-freeing.
-      unrolled.dispose()
+      unrolledBuffer.dispose()
+    }
+  }
+
+  // Exposed for testing
+  private[storage] def getUnrolledChunkedByteBuffer: ChunkedByteBuffer = unrolledBuffer
+
+  private[this] var discarded = false
+  private[this] var consumed = false
+
+  private def verifyNotConsumedAndNotDiscarded(): Unit = {
+    if (consumed) {
+      throw new IllegalStateException(
+        "Can only call one of finishWritingToStream() or valuesIterator() and 
can only call once.")
+    }
+    if (discarded) {
+      throw new IllegalStateException("Cannot call methods on a discarded 
PartiallySerializedBlock")
     }
   }
 
@@ -757,15 +786,18 @@ private[storage] class PartiallySerializedBlock[T](
    * Called to dispose of this block and free its memory.
    */
   def discard(): Unit = {
-    try {
-      // We want to close the output stream in order to free any resources associated with the
-      // serializer itself (such as Kryo's internal buffers). close() might cause data to be
-      // written, so redirect the output stream to discard that data.
-      redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream())
-      serializationStream.close()
-    } finally {
-      unrolled.dispose()
-      memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
+    if (!discarded) {
+      try {
+        // We want to close the output stream in order to free any resources associated with the
+        // serializer itself (such as Kryo's internal buffers). close() might cause data to be
+        // written, so redirect the output stream to discard that data.
+        redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream())
+        serializationStream.close()
+      } finally {
+        discarded = true
+        unrolledBuffer.dispose()
+        memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
+      }
     }
   }
 
@@ -774,8 +806,10 @@ private[storage] class PartiallySerializedBlock[T](
    * and then serializing the values from the original input iterator.
    */
   def finishWritingToStream(os: OutputStream): Unit = {
+    verifyNotConsumedAndNotDiscarded()
+    consumed = true
     // `unrolled`'s underlying buffers will be freed once this input stream is fully read:
-    ByteStreams.copy(unrolled.toInputStream(dispose = true), os)
+    ByteStreams.copy(unrolledBuffer.toInputStream(dispose = true), os)
     memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
     redirectableOutputStream.setOutputStream(os)
     while (rest.hasNext) {
@@ -792,13 +826,22 @@ private[storage] class PartiallySerializedBlock[T](
    * `close()` on it to free its resources.
    */
   def valuesIterator: PartiallyUnrolledIterator[T] = {
+    verifyNotConsumedAndNotDiscarded()
+    consumed = true
+    // Close the serialization stream so that the serializer's internal buffers are freed and any
+    // "end-of-stream" markers can be written out so that `unrolled` is a valid serialized stream.
+    serializationStream.close()
     // `unrolled`'s underlying buffers will be freed once this input stream is fully read:
     val unrolledIter = serializerManager.dataDeserializeStream(
-      blockId, unrolled.toInputStream(dispose = true))(classTag)
+      blockId, unrolledBuffer.toInputStream(dispose = true))(classTag)
+    // The unroll memory will be freed once `unrolledIter` is fully consumed in
+    // PartiallyUnrolledIterator. If the iterator is not consumed by the end of the task then any
+    // extra unroll memory will automatically be freed by a `finally` block in `Task`.
     new PartiallyUnrolledIterator(
       memoryStore,
+      memoryMode,
       unrollMemory,
-      unrolled = CompletionIterator[T, Iterator[T]](unrolledIter, discard()),
+      unrolled = unrolledIter,
       rest = rest)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/bec07706/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala
index 09e7579..9077b86 100644
--- a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala
+++ b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala
@@ -29,7 +29,32 @@ private[spark] class ByteBufferOutputStream(capacity: Int) extends ByteArrayOutp
 
   def getCount(): Int = count
 
+  private[this] var closed: Boolean = false
+
+  override def write(b: Int): Unit = {
+    require(!closed, "cannot write to a closed ByteBufferOutputStream")
+    super.write(b)
+  }
+
+  override def write(b: Array[Byte], off: Int, len: Int): Unit = {
+    require(!closed, "cannot write to a closed ByteBufferOutputStream")
+    super.write(b, off, len)
+  }
+
+  override def reset(): Unit = {
+    require(!closed, "cannot reset a closed ByteBufferOutputStream")
+    super.reset()
+  }
+
+  override def close(): Unit = {
+    if (!closed) {
+      super.close()
+      closed = true
+    }
+  }
+
   def toByteBuffer: ByteBuffer = {
-    return ByteBuffer.wrap(buf, 0, count)
+    require(closed, "can only call toByteBuffer() after ByteBufferOutputStream 
has been closed")
+    ByteBuffer.wrap(buf, 0, count)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/bec07706/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala
index 67b50d1..a625b32 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala
@@ -49,10 +49,19 @@ private[spark] class ChunkedByteBufferOutputStream(
    */
   private[this] var position = chunkSize
   private[this] var _size = 0
+  private[this] var closed: Boolean = false
 
   def size: Long = _size
 
+  override def close(): Unit = {
+    if (!closed) {
+      super.close()
+      closed = true
+    }
+  }
+
   override def write(b: Int): Unit = {
+    require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream")
     allocateNewChunkIfNeeded()
     chunks(lastChunkIndex).put(b.toByte)
     position += 1
@@ -60,6 +69,7 @@ private[spark] class ChunkedByteBufferOutputStream(
   }
 
   override def write(bytes: Array[Byte], off: Int, len: Int): Unit = {
+    require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream")
     var written = 0
     while (written < len) {
       allocateNewChunkIfNeeded()
@@ -73,7 +83,6 @@ private[spark] class ChunkedByteBufferOutputStream(
 
   @inline
   private def allocateNewChunkIfNeeded(): Unit = {
-    require(!toChunkedByteBufferWasCalled, "cannot write after toChunkedByteBuffer() is called")
     if (position == chunkSize) {
       chunks += allocator(chunkSize)
       lastChunkIndex += 1
@@ -82,6 +91,7 @@ private[spark] class ChunkedByteBufferOutputStream(
   }
 
   def toChunkedByteBuffer: ChunkedByteBuffer = {
+    require(closed, "cannot call toChunkedByteBuffer() unless close() has been 
called")
     require(!toChunkedByteBufferWasCalled, "toChunkedByteBuffer() can only be 
called once")
     toChunkedByteBufferWasCalled = true
     if (lastChunkIndex == -1) {

http://git-wip-us.apache.org/repos/asf/spark/blob/bec07706/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala
index 145d432..9e10ee5 100644
--- a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala
@@ -80,6 +80,13 @@ class MemoryStoreSuite
     (memoryStore, blockInfoManager)
   }
 
+  private def assertSameContents[T](expected: Seq[T], actual: Seq[T], hint: String): Unit = {
+    assert(actual.length === expected.length, s"wrong number of values returned in $hint")
+    expected.iterator.zip(actual.iterator).foreach { case (e, a) =>
+      assert(e === a, s"$hint did not return original values!")
+    }
+  }
+
   test("reserve/release unroll memory") {
     val (memoryStore, _) = makeMemoryStore(12000)
     assert(memoryStore.currentUnrollMemory === 0)
@@ -138,9 +145,7 @@ class MemoryStoreSuite
     var putResult = putIteratorAsValues("unroll", smallList.iterator, ClassTag.Any)
     assert(putResult.isRight)
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
-    smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) =>
-      assert(e === a, "getValues() did not return original values!")
-    }
+    assertSameContents(smallList, memoryStore.getValues("unroll").get.toSeq, "getValues")
     blockInfoManager.lockForWriting("unroll")
     assert(memoryStore.remove("unroll"))
     blockInfoManager.removeBlock("unroll")
@@ -153,9 +158,7 @@ class MemoryStoreSuite
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
     assert(memoryStore.contains("someBlock2"))
     assert(!memoryStore.contains("someBlock1"))
-    smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) =>
-      assert(e === a, "getValues() did not return original values!")
-    }
+    assertSameContents(smallList, memoryStore.getValues("unroll").get.toSeq, "getValues")
     blockInfoManager.lockForWriting("unroll")
     assert(memoryStore.remove("unroll"))
     blockInfoManager.removeBlock("unroll")
@@ -168,9 +171,7 @@ class MemoryStoreSuite
     assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator
     assert(!memoryStore.contains("someBlock2"))
     assert(putResult.isLeft)
-    bigList.iterator.zip(putResult.left.get).foreach { case (e, a) =>
-      assert(e === a, "putIterator() did not return original values!")
-    }
+    assertSameContents(bigList, putResult.left.get.toSeq, "putIterator")
     // The unroll memory was freed once the iterator returned by putIterator() was fully traversed.
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
   }
@@ -317,9 +318,8 @@ class MemoryStoreSuite
     assert(res.isLeft)
     assert(memoryStore.currentUnrollMemoryForThisTask > 0)
     val valuesReturnedFromFailedPut = res.left.get.valuesIterator.toSeq // force materialization
-    valuesReturnedFromFailedPut.zip(bigList).foreach { case (e, a) =>
-      assert(e === a, "PartiallySerializedBlock.valuesIterator() did not 
return original values!")
-    }
+    assertSameContents(
+      bigList, valuesReturnedFromFailedPut, 
"PartiallySerializedBlock.valuesIterator()")
     // The unroll memory was freed once the iterator was fully traversed.
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
   }
@@ -341,12 +341,10 @@ class MemoryStoreSuite
     res.left.get.finishWritingToStream(bos)
     // The unroll memory was freed once the block was fully written.
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
-    val deserializationStream = serializerManager.dataDeserializeStream[Any](
-      "b1", new ByteBufferInputStream(bos.toByteBuffer))(ClassTag.Any)
-    deserializationStream.zip(bigList.iterator).foreach { case (e, a) =>
-      assert(e === a,
-        "PartiallySerializedBlock.finishWritingtoStream() did not write 
original values!")
-    }
+    val deserializedValues = serializerManager.dataDeserializeStream[Any](
+      "b1", new ByteBufferInputStream(bos.toByteBuffer))(ClassTag.Any).toSeq
+    assertSameContents(
+      bigList, deserializedValues, "PartiallySerializedBlock.finishWritingToStream()")
   }
 
   test("multiple unrolls by the same thread") {

http://git-wip-us.apache.org/repos/asf/spark/blob/bec07706/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala
new file mode 100644
index 0000000..ec4f263
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.nio.ByteBuffer
+
+import scala.reflect.ClassTag
+
+import org.mockito.Mockito
+import org.mockito.Mockito.atLeastOnce
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester}
+
+import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl}
+import org.apache.spark.memory.MemoryMode
+import org.apache.spark.serializer.{JavaSerializer, SerializationStream, SerializerManager}
+import org.apache.spark.storage.memory.{MemoryStore, PartiallySerializedBlock, RedirectableOutputStream}
+import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream}
+import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
+
+class PartiallySerializedBlockSuite
+    extends SparkFunSuite
+    with BeforeAndAfterEach
+    with PrivateMethodTester {
+
+  private val blockId = new TestBlockId("test")
+  private val conf = new SparkConf()
+  private val memoryStore = Mockito.mock(classOf[MemoryStore], Mockito.RETURNS_SMART_NULLS)
+  private val serializerManager = new SerializerManager(new JavaSerializer(conf), conf)
+
+  private val getSerializationStream = PrivateMethod[SerializationStream]('serializationStream)
+  private val getRedirectableOutputStream =
+    PrivateMethod[RedirectableOutputStream]('redirectableOutputStream)
+
+  override protected def beforeEach(): Unit = {
+    super.beforeEach()
+    Mockito.reset(memoryStore)
+  }
+
+  private def partiallyUnroll[T: ClassTag](
+      iter: Iterator[T],
+      numItemsToBuffer: Int): PartiallySerializedBlock[T] = {
+
+    val bbos: ChunkedByteBufferOutputStream = {
+      val spy = Mockito.spy(new ChunkedByteBufferOutputStream(128, ByteBuffer.allocate))
+      Mockito.doAnswer(new Answer[ChunkedByteBuffer] {
+        override def answer(invocationOnMock: InvocationOnMock): ChunkedByteBuffer = {
+          Mockito.spy(invocationOnMock.callRealMethod().asInstanceOf[ChunkedByteBuffer])
+        }
+      }).when(spy).toChunkedByteBuffer
+      spy
+    }
+
+    val serializer = serializerManager.getSerializer(implicitly[ClassTag[T]]).newInstance()
+    val redirectableOutputStream = Mockito.spy(new RedirectableOutputStream)
+    redirectableOutputStream.setOutputStream(bbos)
+    val serializationStream = Mockito.spy(serializer.serializeStream(redirectableOutputStream))
+
+    (1 to numItemsToBuffer).foreach { _ =>
+      assert(iter.hasNext)
+      serializationStream.writeObject[T](iter.next())
+    }
+
+    val unrollMemory = bbos.size
+    new PartiallySerializedBlock[T](
+      memoryStore,
+      serializerManager,
+      blockId,
+      serializationStream = serializationStream,
+      redirectableOutputStream,
+      unrollMemory = unrollMemory,
+      memoryMode = MemoryMode.ON_HEAP,
+      bbos,
+      rest = iter,
+      classTag = implicitly[ClassTag[T]])
+  }
+
+  test("valuesIterator() and finishWritingToStream() cannot be called after 
discard() is called") {
+    val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+    partiallySerializedBlock.discard()
+    intercept[IllegalStateException] {
+      partiallySerializedBlock.finishWritingToStream(null)
+    }
+    intercept[IllegalStateException] {
+      partiallySerializedBlock.valuesIterator
+    }
+  }
+
+  test("discard() can be called more than once") {
+    val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+    partiallySerializedBlock.discard()
+    partiallySerializedBlock.discard()
+  }
+
+  test("cannot call valuesIterator() more than once") {
+    val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+    partiallySerializedBlock.valuesIterator
+    intercept[IllegalStateException] {
+      partiallySerializedBlock.valuesIterator
+    }
+  }
+
+  test("cannot call finishWritingToStream() more than once") {
+    val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+    partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream())
+    intercept[IllegalStateException] {
+      partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream())
+    }
+  }
+
+  test("cannot call finishWritingToStream() after valuesIterator()") {
+    val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+    partiallySerializedBlock.valuesIterator
+    intercept[IllegalStateException] {
+      partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream())
+    }
+  }
+
+  test("cannot call valuesIterator() after finishWritingToStream()") {
+    val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+    partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream())
+    intercept[IllegalStateException] {
+      partiallySerializedBlock.valuesIterator
+    }
+  }
+
+  test("buffers are deallocated in a TaskCompletionListener") {
+    try {
+      TaskContext.setTaskContext(TaskContext.empty())
+      val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
+      TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted()
+      Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer).dispose()
+      Mockito.verifyNoMoreInteractions(memoryStore)
+    } finally {
+      TaskContext.unset()
+    }
+  }
+
+  private def testUnroll[T: ClassTag](
+      testCaseName: String,
+      items: Seq[T],
+      numItemsToBuffer: Int): Unit = {
+
+    test(s"$testCaseName with discard() and numBuffered = $numItemsToBuffer") {
+      val partiallySerializedBlock = partiallyUnroll(items.iterator, numItemsToBuffer)
+      partiallySerializedBlock.discard()
+
+      Mockito.verify(memoryStore).releaseUnrollMemoryForThisTask(
+        MemoryMode.ON_HEAP, partiallySerializedBlock.unrollMemory)
+      Mockito.verify(partiallySerializedBlock.invokePrivate(getSerializationStream())).close()
+      Mockito.verify(partiallySerializedBlock.invokePrivate(getRedirectableOutputStream())).close()
+      Mockito.verifyNoMoreInteractions(memoryStore)
+      Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer, atLeastOnce).dispose()
+    }
+
+    test(s"$testCaseName with finishWritingToStream() and numBuffered = 
$numItemsToBuffer") {
+      val partiallySerializedBlock = partiallyUnroll(items.iterator, 
numItemsToBuffer)
+      val bbos = Mockito.spy(new ByteBufferOutputStream())
+      partiallySerializedBlock.finishWritingToStream(bbos)
+
+      Mockito.verify(memoryStore).releaseUnrollMemoryForThisTask(
+        MemoryMode.ON_HEAP, partiallySerializedBlock.unrollMemory)
+      Mockito.verify(partiallySerializedBlock.invokePrivate(getSerializationStream())).close()
+      Mockito.verify(partiallySerializedBlock.invokePrivate(getRedirectableOutputStream())).close()
+      Mockito.verify(bbos).close()
+      Mockito.verifyNoMoreInteractions(memoryStore)
+      Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer, atLeastOnce).dispose()
+
+      val serializer = serializerManager.getSerializer(implicitly[ClassTag[T]]).newInstance()
+      val deserialized =
+        serializer.deserializeStream(new ByteBufferInputStream(bbos.toByteBuffer)).asIterator.toSeq
+      assert(deserialized === items)
+    }
+
+    test(s"$testCaseName with valuesIterator() and numBuffered = 
$numItemsToBuffer") {
+      val partiallySerializedBlock = partiallyUnroll(items.iterator, 
numItemsToBuffer)
+      val valuesIterator = partiallySerializedBlock.valuesIterator
+      
Mockito.verify(partiallySerializedBlock.invokePrivate(getSerializationStream())).close()
+      
Mockito.verify(partiallySerializedBlock.invokePrivate(getRedirectableOutputStream())).close()
+
+      val deserializedItems = valuesIterator.toArray.toSeq
+      Mockito.verify(memoryStore).releaseUnrollMemoryForThisTask(
+        MemoryMode.ON_HEAP, partiallySerializedBlock.unrollMemory)
+      Mockito.verifyNoMoreInteractions(memoryStore)
+      Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer, atLeastOnce).dispose()
+      assert(deserializedItems === items)
+    }
+  }
+
+  testUnroll("basic numbers", 1 to 1000, numItemsToBuffer = 50)
+  testUnroll("basic numbers", 1 to 1000, numItemsToBuffer = 0)
+  testUnroll("basic numbers", 1 to 1000, numItemsToBuffer = 1000)
+  testUnroll("case classes", (1 to 1000).map(x => MyCaseClass(x.toString)), 
numItemsToBuffer = 50)
+  testUnroll("case classes", (1 to 1000).map(x => MyCaseClass(x.toString)), 
numItemsToBuffer = 0)
+  testUnroll("case classes", (1 to 1000).map(x => MyCaseClass(x.toString)), 
numItemsToBuffer = 1000)
+  testUnroll("empty iterator", Seq.empty[String], numItemsToBuffer = 0)
+}
+
+private case class MyCaseClass(str: String)

http://git-wip-us.apache.org/repos/asf/spark/blob/bec07706/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala
index 02c2331..4253cc8 100644
--- a/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala
@@ -33,7 +33,7 @@ class PartiallyUnrolledIteratorSuite extends SparkFunSuite with MockitoSugar {
     val rest = (unrollSize until restSize + unrollSize).iterator
 
     val memoryStore = mock[MemoryStore]
-    val joinIterator = new PartiallyUnrolledIterator(memoryStore, unrollSize, unroll, rest)
+    val joinIterator = new PartiallyUnrolledIterator(memoryStore, ON_HEAP, unrollSize, unroll, rest)
 
     // Firstly iterate over unrolling memory iterator
     (0 until unrollSize).foreach { value =>

http://git-wip-us.apache.org/repos/asf/spark/blob/bec07706/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala
index 2266220..8696174 100644
--- a/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala
@@ -28,12 +28,14 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
 
   test("empty output") {
     val o = new ChunkedByteBufferOutputStream(1024, ByteBuffer.allocate)
+    o.close()
     assert(o.toChunkedByteBuffer.size === 0)
   }
 
   test("write a single byte") {
     val o = new ChunkedByteBufferOutputStream(1024, ByteBuffer.allocate)
     o.write(10)
+    o.close()
     val chunkedByteBuffer = o.toChunkedByteBuffer
     assert(chunkedByteBuffer.getChunks().length === 1)
     assert(chunkedByteBuffer.getChunks().head.array().toSeq === Seq(10.toByte))
@@ -43,6 +45,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
     val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
     o.write(new Array[Byte](9))
     o.write(99)
+    o.close()
     val chunkedByteBuffer = o.toChunkedByteBuffer
     assert(chunkedByteBuffer.getChunks().length === 1)
     assert(chunkedByteBuffer.getChunks().head.array()(9) === 99.toByte)
@@ -52,6 +55,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
     val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
     o.write(new Array[Byte](10))
     o.write(99)
+    o.close()
     val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
     assert(arrays.length === 2)
     assert(arrays(1).length === 1)
@@ -63,6 +67,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
     Random.nextBytes(ref)
     val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
     o.write(ref)
+    o.close()
     val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
     assert(arrays.length === 1)
     assert(arrays.head.length === ref.length)
@@ -74,6 +79,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
     Random.nextBytes(ref)
     val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
     o.write(ref)
+    o.close()
     val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
     assert(arrays.length === 1)
     assert(arrays.head.length === ref.length)
@@ -85,6 +91,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
     Random.nextBytes(ref)
     val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
     o.write(ref)
+    o.close()
     val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
     assert(arrays.length === 3)
     assert(arrays(0).length === 10)
@@ -101,6 +108,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
     Random.nextBytes(ref)
     val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
     o.write(ref)
+    o.close()
     val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
     assert(arrays.length === 3)
     assert(arrays(0).length === 10)

