Repository: spark
Updated Branches:
  refs/heads/branch-2.1 9e96ac5a9 -> c2c2fdcb7


[SPARK-18546][CORE] Fix merging shuffle spills when using encryption.

The problem is that encrypted partition data from different spill files cannot
simply be concatenated: each spilled partition currently gets its own
initialization vector when encryption is set up, while the final merged file
needs a single initialization vector per merged partition; otherwise iterating
over the records becomes very hard.

To fix that, UnsafeShuffleWriter now decrypts the partitions while merging, so
that the merged file contains a single initialization vector at the start of
each partition's data.
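
For illustration only, here is a minimal sketch of that idea (it is not the code
in this patch; the helper name, the offset bookkeeping, and the omission of
compression wrapping are simplifications made for brevity):

    // Decrypt each spill's bytes and re-encrypt them through a single stream per
    // partition, so the merged partition carries one initialization vector.
    // Assumes: com.google.common.io.ByteStreams, org.apache.commons.io.output.CloseShieldOutputStream,
    // org.apache.spark.network.util.LimitedInputStream, org.apache.spark.serializer.SerializerManager.
    static void mergeOnePartition(SerializerManager sm, SpillInfo[] spills, int partition,
        OutputStream mergedOut) throws IOException {
      // One encrypting wrapper per merged partition; shield mergedOut so that closing
      // the crypto stream (to flush its state) does not close the merged file itself.
      OutputStream partitionOut = sm.wrapForEncryption(new CloseShieldOutputStream(mergedOut));
      for (SpillInfo spill : spills) {
        long length = spill.partitionLengths[partition];
        if (length == 0) continue;
        // Partitions are written back to back in each spill file, so this partition's
        // offset is the sum of the lengths that precede it.
        long offset = 0;
        for (int p = 0; p < partition; p++) {
          offset += spill.partitionLengths[p];
        }
        FileInputStream fin = new FileInputStream(spill.file);
        try {
          fin.getChannel().position(offset);
          // Decrypting wrapper: each spill file was written with its own IV.
          InputStream in = sm.wrapForEncryption(new LimitedInputStream(fin, length));
          ByteStreams.copy(in, partitionOut);
        } finally {
          fin.close();
        }
      }
      partitionOut.close();
    }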

Because that cannot be done through the fast transferTo path, UnsafeShuffleWriter
falls back to stream-based merging when encryption is enabled. A hybrid approach
may be possible, using an intermediate direct buffer when reading from the spill
files and encrypting the data, but that is better left for a separate patch.
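
Paraphrasing the change below, the merge path selection now boils down to a
condition along these lines (variable names mirror UnsafeShuffleWriter; this is
a restatement of the patch, not new behavior):

    // The transferTo-based fast merge only applies when spill bytes can be
    // concatenated verbatim, i.e. neither a codec nor a cipher has to
    // re-interpret them on the way through.
    final boolean fastMergeIsSupported = !compressionEnabled ||
      CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec);
    final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled();
    final boolean useTransferTo =
      fastMergeEnabled && fastMergeIsSupported && transferToEnabled && !encryptionEnabled;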

As part of the change, DiskBlockObjectWriter now takes a SerializerManager
instead of a "wrap stream" closure, since that makes it easier to test the code
without having to mock SerializerManager functionality.
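
As a rough sketch of what that buys tests (the temp file, block id, and argument
values here are made up; only the parameter order follows this patch), a writer
can now be built against a real SerializerManager instead of a mocked wrapper:

    // Assumes: java.io.File, java.util.UUID, org.apache.spark.SparkConf,
    // org.apache.spark.executor.ShuffleWriteMetrics, org.apache.spark.serializer.*,
    // org.apache.spark.storage.*; "tempDir" is a hypothetical test directory.
    SparkConf conf = new SparkConf();
    SerializerManager serializerManager = new SerializerManager(new KryoSerializer(conf), conf);
    DiskBlockObjectWriter writer = new DiskBlockObjectWriter(
      new File(tempDir, "test-block"),
      serializerManager,                         // wraps streams for compression/encryption
      new KryoSerializer(conf).newInstance(),
      32 * 1024,                                 // buffer size
      false,                                     // syncWrites
      new ShuffleWriteMetrics(),
      new TempShuffleBlockId(UUID.randomUUID()));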

Tested with newly added unit tests (UnsafeShuffleWriterSuite for the write
side and ExternalAppendOnlyMapSuite for integration), and by running some
apps that failed without the fix.

Author: Marcelo Vanzin <van...@cloudera.com>

Closes #15982 from vanzin/SPARK-18546.

(cherry picked from commit 93e9d880bf8a144112d74a6897af4e36fcfa5807)
Signed-off-by: Marcelo Vanzin <van...@cloudera.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c2c2fdcb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c2c2fdcb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c2c2fdcb

Branch: refs/heads/branch-2.1
Commit: c2c2fdcb71e9bc82f0e88567148d1bae283f256a
Parents: 9e96ac5
Author: Marcelo Vanzin <van...@cloudera.com>
Authored: Wed Nov 30 14:10:32 2016 -0800
Committer: Marcelo Vanzin <van...@cloudera.com>
Committed: Wed Nov 30 14:10:44 2016 -0800

----------------------------------------------------------------------
 .../spark/shuffle/sort/UnsafeShuffleWriter.java |  48 +++++----
 .../spark/serializer/SerializerManager.scala    |   6 +-
 .../org/apache/spark/storage/BlockManager.scala |   5 +-
 .../spark/storage/DiskBlockObjectWriter.scala   |   6 +-
 .../shuffle/sort/UnsafeShuffleWriterSuite.java  | 100 +++++++++++++------
 .../map/AbstractBytesToBytesMapSuite.java       |  11 +-
 .../unsafe/sort/UnsafeExternalSorterSuite.java  |  21 ++--
 .../BypassMergeSortShuffleWriterSuite.scala     |   5 +-
 .../storage/DiskBlockObjectWriterSuite.scala    |  54 ++++------
 .../collection/ExternalAppendOnlyMapSuite.scala |   8 +-
 10 files changed, 145 insertions(+), 119 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c2c2fdcb/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index f235c43..8a17718 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -40,6 +40,8 @@ import org.apache.spark.annotation.Private;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.io.CompressionCodec;
 import org.apache.spark.io.CompressionCodec$;
+import org.apache.commons.io.output.CloseShieldOutputStream;
+import org.apache.commons.io.output.CountingOutputStream;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.network.util.LimitedInputStream;
 import org.apache.spark.scheduler.MapStatus;
@@ -264,6 +266,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
       sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true);
     final boolean fastMergeIsSupported = !compressionEnabled ||
       CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec);
+    final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled();
     try {
       if (spills.length == 0) {
         new FileOutputStream(outputFile).close(); // Create an empty file
@@ -289,7 +292,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
           // Compression is disabled or we are using an IO compression codec that supports
           // decompression of concatenated compressed streams, so we can perform a fast spill merge
           // that doesn't need to interpret the spilled bytes.
-          if (transferToEnabled) {
+          if (transferToEnabled && !encryptionEnabled) {
             logger.debug("Using transferTo-based fast merge");
             partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
           } else {
@@ -320,9 +323,9 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   /**
    * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge,
    * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in
-   * cases where the IO compression codec does not support concatenation of compressed data, or in
-   * cases where users have explicitly disabled use of {@code transferTo} in order to work around
-   * kernel bugs.
+   * cases where the IO compression codec does not support concatenation of compressed data, when
+   * encryption is enabled, or when users have explicitly disabled use of {@code transferTo} in
+   * order to work around kernel bugs.
    *
    * @param spills the spills to merge.
    * @param outputFile the file to write the merged data to.
@@ -337,7 +340,11 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     final int numPartitions = partitioner.numPartitions();
     final long[] partitionLengths = new long[numPartitions];
     final InputStream[] spillInputStreams = new FileInputStream[spills.length];
-    OutputStream mergedFileOutputStream = null;
+
+    // Use a counting output stream to avoid having to close the underlying file and ask
+    // the file system for its size after each partition is written.
+    final CountingOutputStream mergedFileOutputStream = new CountingOutputStream(
+      new FileOutputStream(outputFile));
 
     boolean threwException = true;
     try {
@@ -345,34 +352,35 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
         spillInputStreams[i] = new FileInputStream(spills[i].file);
       }
       for (int partition = 0; partition < numPartitions; partition++) {
-        final long initialFileLength = outputFile.length();
-        mergedFileOutputStream =
-          new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true));
+        final long initialFileLength = mergedFileOutputStream.getByteCount();
+        // Shield the underlying output stream from close() calls, so that we can close the higher
+        // level streams to make sure all data is really flushed and internal state is cleaned.
+        OutputStream partitionOutput = new CloseShieldOutputStream(
+          new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream));
+        partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput);
         if (compressionCodec != null) {
-          mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream);
+          partitionOutput = compressionCodec.compressedOutputStream(partitionOutput);
         }
-
         for (int i = 0; i < spills.length; i++) {
           final long partitionLengthInSpill = spills[i].partitionLengths[partition];
           if (partitionLengthInSpill > 0) {
-            InputStream partitionInputStream = null;
-            boolean innerThrewException = true;
+            InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i],
+              partitionLengthInSpill, false);
             try {
-              partitionInputStream =
-                  new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill, false);
+              partitionInputStream = blockManager.serializerManager().wrapForEncryption(
+                partitionInputStream);
               if (compressionCodec != null) {
                 partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
               }
-              ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
-              innerThrewException = false;
+              ByteStreams.copy(partitionInputStream, partitionOutput);
             } finally {
-              Closeables.close(partitionInputStream, innerThrewException);
+              partitionInputStream.close();
             }
           }
         }
-        mergedFileOutputStream.flush();
-        mergedFileOutputStream.close();
-        partitionLengths[partition] = (outputFile.length() - initialFileLength);
+        partitionOutput.flush();
+        partitionOutput.close();
+        partitionLengths[partition] = (mergedFileOutputStream.getByteCount() - initialFileLength);
       }
       threwException = false;
     } finally {

http://git-wip-us.apache.org/repos/asf/spark/blob/c2c2fdcb/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
index 7371f88..686305e 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -75,6 +75,8 @@ private[spark] class SerializerManager(
    * loaded yet. */
   private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf)
 
+  def encryptionEnabled: Boolean = encryptionKey.isDefined
+
   def canUseKryo(ct: ClassTag[_]): Boolean = {
     primitiveAndPrimitiveArrayClassTags.contains(ct) || ct == stringClassTag
   }
@@ -129,7 +131,7 @@ private[spark] class SerializerManager(
   /**
    * Wrap an input stream for encryption if shuffle encryption is enabled
    */
-  private[this] def wrapForEncryption(s: InputStream): InputStream = {
+  def wrapForEncryption(s: InputStream): InputStream = {
     encryptionKey
       .map { key => CryptoStreamUtils.createCryptoInputStream(s, conf, key) }
       .getOrElse(s)
@@ -138,7 +140,7 @@ private[spark] class SerializerManager(
   /**
    * Wrap an output stream for encryption if shuffle encryption is enabled
    */
-  private[this] def wrapForEncryption(s: OutputStream): OutputStream = {
+  def wrapForEncryption(s: OutputStream): OutputStream = {
     encryptionKey
       .map { key => CryptoStreamUtils.createCryptoOutputStream(s, conf, key) }
       .getOrElse(s)

http://git-wip-us.apache.org/repos/asf/spark/blob/c2c2fdcb/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 982b833..04521c9 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -62,7 +62,7 @@ private[spark] class BlockManager(
     executorId: String,
     rpcEnv: RpcEnv,
     val master: BlockManagerMaster,
-    serializerManager: SerializerManager,
+    val serializerManager: SerializerManager,
     val conf: SparkConf,
     memoryManager: MemoryManager,
     mapOutputTracker: MapOutputTracker,
@@ -745,9 +745,8 @@ private[spark] class BlockManager(
       serializerInstance: SerializerInstance,
       bufferSize: Int,
       writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = {
-    val wrapStream: OutputStream => OutputStream = serializerManager.wrapStream(blockId, _)
     val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
-    new DiskBlockObjectWriter(file, serializerInstance, bufferSize, wrapStream,
+    new DiskBlockObjectWriter(file, serializerManager, serializerInstance, bufferSize,
       syncWrites, writeMetrics, blockId)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c2c2fdcb/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index a499827..3cb12fc 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -22,7 +22,7 @@ import java.nio.channels.FileChannel
 
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.internal.Logging
-import org.apache.spark.serializer.{SerializationStream, SerializerInstance}
+import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
 import org.apache.spark.util.Utils
 
 /**
@@ -37,9 +37,9 @@ import org.apache.spark.util.Utils
  */
 private[spark] class DiskBlockObjectWriter(
     val file: File,
+    serializerManager: SerializerManager,
     serializerInstance: SerializerInstance,
     bufferSize: Int,
-    wrapStream: OutputStream => OutputStream,
     syncWrites: Boolean,
     // These write metrics concurrently shared with other active DiskBlockObjectWriters who
     // are themselves performing writes. All updates must be relative.
@@ -116,7 +116,7 @@ private[spark] class DiskBlockObjectWriter(
       initialized = true
     }
 
-    bs = wrapStream(mcs)
+    bs = serializerManager.wrapStream(blockId, mcs)
     objOut = serializerInstance.serializeStream(bs)
     streamOpen = true
     this

http://git-wip-us.apache.org/repos/asf/spark/blob/c2c2fdcb/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index a96cd82..088b681 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -26,11 +26,9 @@ import scala.Product2;
 import scala.Tuple2;
 import scala.Tuple2$;
 import scala.collection.Iterator;
-import scala.runtime.AbstractFunction1;
 
 import com.google.common.collect.HashMultiset;
 import com.google.common.collect.Iterators;
-import com.google.common.io.ByteStreams;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -53,6 +51,7 @@ import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.memory.TestMemoryManager;
 import org.apache.spark.network.util.LimitedInputStream;
 import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.security.CryptoStreamUtils;
 import org.apache.spark.serializer.*;
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
 import org.apache.spark.storage.*;
@@ -77,7 +76,6 @@ public class UnsafeShuffleWriterSuite {
   final LinkedList<File> spillFilesCreated = new LinkedList<>();
   SparkConf conf;
   final Serializer serializer = new KryoSerializer(new SparkConf());
-  final SerializerManager serializerManager = new SerializerManager(serializer, new SparkConf());
   TaskMetrics taskMetrics;
 
   @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
@@ -86,17 +84,6 @@ public class UnsafeShuffleWriterSuite {
   @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
   @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep;
 
-  private final class WrapStream extends AbstractFunction1<OutputStream, OutputStream> {
-    @Override
-    public OutputStream apply(OutputStream stream) {
-      if (conf.getBoolean("spark.shuffle.compress", true)) {
-        return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream);
-      } else {
-        return stream;
-      }
-    }
-  }
-
   @After
   public void tearDown() {
     Utils.deleteRecursively(tempDir);
@@ -121,6 +108,11 @@ public class UnsafeShuffleWriterSuite {
     memoryManager = new TestMemoryManager(conf);
     taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
 
+    // Some tests will override this manager because they change the configuration. This is a
+    // default for tests that don't need a specific one.
+    SerializerManager manager = new SerializerManager(serializer, conf);
+    when(blockManager.serializerManager()).thenReturn(manager);
+
     when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
     when(blockManager.getDiskWriter(
       any(BlockId.class),
@@ -131,12 +123,11 @@ public class UnsafeShuffleWriterSuite {
       @Override
       public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
         Object[] args = invocationOnMock.getArguments();
-
         return new DiskBlockObjectWriter(
           (File) args[1],
+          blockManager.serializerManager(),
           (SerializerInstance) args[2],
           (Integer) args[3],
-          new WrapStream(),
           false,
           (ShuffleWriteMetrics) args[4],
           (BlockId) args[0]
@@ -201,9 +192,10 @@ public class UnsafeShuffleWriterSuite {
     for (int i = 0; i < NUM_PARTITITONS; i++) {
       final long partitionSize = partitionSizesInMergedFile[i];
       if (partitionSize > 0) {
-        InputStream in = new FileInputStream(mergedOutputFile);
-        ByteStreams.skipFully(in, startOffset);
-        in = new LimitedInputStream(in, partitionSize);
+        FileInputStream fin = new FileInputStream(mergedOutputFile);
+        fin.getChannel().position(startOffset);
+        InputStream in = new LimitedInputStream(fin, partitionSize);
+        in = blockManager.serializerManager().wrapForEncryption(in);
         if (conf.getBoolean("spark.shuffle.compress", true)) {
           in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in);
         }
@@ -294,14 +286,32 @@ public class UnsafeShuffleWriterSuite {
   }
 
   private void testMergingSpills(
-      boolean transferToEnabled,
-      String compressionCodecName) throws IOException {
+      final boolean transferToEnabled,
+      String compressionCodecName,
+      boolean encrypt) throws Exception {
     if (compressionCodecName != null) {
       conf.set("spark.shuffle.compress", "true");
       conf.set("spark.io.compression.codec", compressionCodecName);
     } else {
       conf.set("spark.shuffle.compress", "false");
     }
+    conf.set(org.apache.spark.internal.config.package$.MODULE$.IO_ENCRYPTION_ENABLED(), encrypt);
+
+    SerializerManager manager;
+    if (encrypt) {
+      manager = new SerializerManager(serializer, conf,
+        Option.apply(CryptoStreamUtils.createKey(conf)));
+    } else {
+      manager = new SerializerManager(serializer, conf);
+    }
+
+    when(blockManager.serializerManager()).thenReturn(manager);
+    testMergingSpills(transferToEnabled, encrypt);
+  }
+
+  private void testMergingSpills(
+      boolean transferToEnabled,
+      boolean encrypted) throws IOException {
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled);
     final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
     for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) {
@@ -324,6 +334,7 @@ public class UnsafeShuffleWriterSuite {
     for (long size: partitionSizesInMergedFile) {
       sumOfPartitionSizes += size;
     }
+
     assertEquals(sumOfPartitionSizes, mergedOutputFile.length());
 
     assertEquals(HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile()));
@@ -338,42 +349,72 @@ public class UnsafeShuffleWriterSuite {
 
   @Test
   public void mergeSpillsWithTransferToAndLZF() throws Exception {
-    testMergingSpills(true, LZFCompressionCodec.class.getName());
+    testMergingSpills(true, LZFCompressionCodec.class.getName(), false);
   }
 
   @Test
   public void mergeSpillsWithFileStreamAndLZF() throws Exception {
-    testMergingSpills(false, LZFCompressionCodec.class.getName());
+    testMergingSpills(false, LZFCompressionCodec.class.getName(), false);
   }
 
   @Test
   public void mergeSpillsWithTransferToAndLZ4() throws Exception {
-    testMergingSpills(true, LZ4CompressionCodec.class.getName());
+    testMergingSpills(true, LZ4CompressionCodec.class.getName(), false);
   }
 
   @Test
   public void mergeSpillsWithFileStreamAndLZ4() throws Exception {
-    testMergingSpills(false, LZ4CompressionCodec.class.getName());
+    testMergingSpills(false, LZ4CompressionCodec.class.getName(), false);
   }
 
   @Test
   public void mergeSpillsWithTransferToAndSnappy() throws Exception {
-    testMergingSpills(true, SnappyCompressionCodec.class.getName());
+    testMergingSpills(true, SnappyCompressionCodec.class.getName(), false);
   }
 
   @Test
   public void mergeSpillsWithFileStreamAndSnappy() throws Exception {
-    testMergingSpills(false, SnappyCompressionCodec.class.getName());
+    testMergingSpills(false, SnappyCompressionCodec.class.getName(), false);
   }
 
   @Test
   public void mergeSpillsWithTransferToAndNoCompression() throws Exception {
-    testMergingSpills(true, null);
+    testMergingSpills(true, null, false);
   }
 
   @Test
   public void mergeSpillsWithFileStreamAndNoCompression() throws Exception {
-    testMergingSpills(false, null);
+    testMergingSpills(false, null, false);
+  }
+
+  @Test
+  public void mergeSpillsWithCompressionAndEncryption() throws Exception {
+    // This should actually be translated to a "file stream merge" internally, just have the
+    // test to make sure that it's the case.
+    testMergingSpills(true, LZ4CompressionCodec.class.getName(), true);
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndCompressionAndEncryption() throws Exception {
+    testMergingSpills(false, LZ4CompressionCodec.class.getName(), true);
+  }
+
+  @Test
+  public void mergeSpillsWithCompressionAndEncryptionSlowPath() throws Exception {
+    conf.set("spark.shuffle.unsafe.fastMergeEnabled", "false");
+    testMergingSpills(false, LZ4CompressionCodec.class.getName(), true);
+  }
+
+  @Test
+  public void mergeSpillsWithEncryptionAndNoCompression() throws Exception {
+    // This should actually be translated to a "file stream merge" internally, just have the
+    // test to make sure that it's the case.
+    testMergingSpills(true, null, true);
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndEncryptionAndNoCompression() throws Exception {
+    testMergingSpills(false, null, true);
   }
 
   @Test
@@ -531,4 +572,5 @@ public class UnsafeShuffleWriterSuite {
       writer.stop(false);
     }
   }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c2c2fdcb/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 33709b4..2656814 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -19,13 +19,11 @@ package org.apache.spark.unsafe.map;
 
 import java.io.File;
 import java.io.IOException;
-import java.io.OutputStream;
 import java.nio.ByteBuffer;
 import java.util.*;
 
 import scala.Tuple2;
 import scala.Tuple2$;
-import scala.runtime.AbstractFunction1;
 
 import org.junit.After;
 import org.junit.Assert;
@@ -75,13 +73,6 @@ public abstract class AbstractBytesToBytesMapSuite {
   @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
   @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
 
-  private static final class WrapStream extends AbstractFunction1<OutputStream, OutputStream> {
-    @Override
-    public OutputStream apply(OutputStream stream) {
-      return stream;
-    }
-  }
-
   @Before
   public void setup() {
     memoryManager =
@@ -120,9 +111,9 @@ public abstract class AbstractBytesToBytesMapSuite {
 
         return new DiskBlockObjectWriter(
           (File) args[1],
+          serializerManager,
           (SerializerInstance) args[2],
           (Integer) args[3],
-          new WrapStream(),
           false,
           (ShuffleWriteMetrics) args[4],
           (BlockId) args[0]

http://git-wip-us.apache.org/repos/asf/spark/blob/c2c2fdcb/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index a9cf8ff..fbbe530 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -19,14 +19,12 @@ package org.apache.spark.util.collection.unsafe.sort;
 
 import java.io.File;
 import java.io.IOException;
-import java.io.OutputStream;
 import java.util.Arrays;
 import java.util.LinkedList;
 import java.util.UUID;
 
 import scala.Tuple2;
 import scala.Tuple2$;
-import scala.runtime.AbstractFunction1;
 
 import org.junit.After;
 import org.junit.Before;
@@ -57,13 +55,15 @@ import static org.mockito.Mockito.*;
 
 public class UnsafeExternalSorterSuite {
 
+  private final SparkConf conf = new SparkConf();
+
   final LinkedList<File> spillFilesCreated = new LinkedList<>();
   final TestMemoryManager memoryManager =
-    new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false"));
+    new TestMemoryManager(conf.clone().set("spark.memory.offHeap.enabled", "false"));
   final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
   final SerializerManager serializerManager = new SerializerManager(
-    new JavaSerializer(new SparkConf()),
-    new SparkConf().set("spark.shuffle.spill.compress", "false"));
+    new JavaSerializer(conf),
+    conf.clone().set("spark.shuffle.spill.compress", "false"));
   // Use integer comparison for comparing prefixes (which are partition ids, in this case)
   final PrefixComparator prefixComparator = PrefixComparators.LONG;
   // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
@@ -86,14 +86,7 @@ public class UnsafeExternalSorterSuite {
 
   protected boolean shouldUseRadixSort() { return false; }
 
-  private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "4m");
-
-  private static final class WrapStream extends AbstractFunction1<OutputStream, OutputStream> {
-    @Override
-    public OutputStream apply(OutputStream stream) {
-      return stream;
-    }
-  }
+  private final long pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "4m");
 
   @Before
   public void setUp() {
@@ -126,9 +119,9 @@ public class UnsafeExternalSorterSuite {
 
         return new DiskBlockObjectWriter(
           (File) args[1],
+          serializerManager,
           (SerializerInstance) args[2],
           (Integer) args[3],
-          new WrapStream(),
           false,
           (ShuffleWriteMetrics) args[4],
           (BlockId) args[0]

http://git-wip-us.apache.org/repos/asf/spark/blob/c2c2fdcb/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index 4429416..85ccb33 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -33,7 +33,7 @@ import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark._
 import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics}
-import org.apache.spark.serializer.{JavaSerializer, SerializerInstance}
+import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager}
 import org.apache.spark.shuffle.IndexShuffleBlockResolver
 import org.apache.spark.storage._
 import org.apache.spark.util.Utils
@@ -90,11 +90,12 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
     )).thenAnswer(new Answer[DiskBlockObjectWriter] {
       override def answer(invocation: InvocationOnMock): DiskBlockObjectWriter = {
         val args = invocation.getArguments
+        val manager = new SerializerManager(new JavaSerializer(conf), conf)
         new DiskBlockObjectWriter(
           args(1).asInstanceOf[File],
+          manager,
           args(2).asInstanceOf[SerializerInstance],
           args(3).asInstanceOf[Int],
-          wrapStream = identity,
           syncWrites = false,
           args(4).asInstanceOf[ShuffleWriteMetrics],
           blockId = args(0).asInstanceOf[BlockId]

http://git-wip-us.apache.org/repos/asf/spark/blob/c2c2fdcb/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
index 684e978..bfb3ac4 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
@@ -22,7 +22,7 @@ import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
 import org.apache.spark.util.Utils
 
 class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
@@ -42,11 +42,19 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
     }
   }
 
-  test("verify write metrics") {
+  private def createWriter(): (DiskBlockObjectWriter, File, ShuffleWriteMetrics) = {
     val file = new File(tempDir, "somefile")
+    val conf = new SparkConf()
+    val serializerManager = new SerializerManager(new JavaSerializer(conf), conf)
     val writeMetrics = new ShuffleWriteMetrics()
     val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+      file, serializerManager, new JavaSerializer(new SparkConf()).newInstance(), 1024, true,
+      writeMetrics)
+    (writer, file, writeMetrics)
+  }
+
+  test("verify write metrics") {
+    val (writer, file, writeMetrics) = createWriter()
 
     writer.write(Long.box(20), Long.box(30))
     // Record metrics update on every write
@@ -66,10 +74,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("verify write metrics on revert") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, _, writeMetrics) = createWriter()
 
     writer.write(Long.box(20), Long.box(30))
     // Record metrics update on every write
@@ -89,10 +94,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("Reopening a closed block writer") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, _, _) = createWriter()
 
     writer.open()
     writer.close()
@@ -102,10 +104,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("calling revertPartialWritesAndClose() on a partial write should 
truncate up to commit") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, file, writeMetrics) = createWriter()
 
     writer.write(Long.box(20), Long.box(30))
     val firstSegment = writer.commitAndGet()
@@ -120,10 +119,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("calling revertPartialWritesAndClose() after commit() should have no 
effect") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, file, writeMetrics) = createWriter()
 
     writer.write(Long.box(20), Long.box(30))
     val firstSegment = writer.commitAndGet()
@@ -136,10 +132,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("calling revertPartialWritesAndClose() on a closed block writer should 
have no effect") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, file, writeMetrics) = createWriter()
     for (i <- 1 to 1000) {
       writer.write(i, i)
     }
@@ -153,10 +146,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("commit() and close() should be idempotent") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, file, writeMetrics) = createWriter()
     for (i <- 1 to 1000) {
       writer.write(i, i)
     }
@@ -173,10 +163,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("revertPartialWritesAndClose() should be idempotent") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, file, writeMetrics) = createWriter()
     for (i <- 1 to 1000) {
       writer.write(i, i)
     }
@@ -191,10 +178,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("commit() and close() without ever opening or writing") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, _, _) = createWriter()
     val segment = writer.commitAndGet()
     writer.close()
     assert(segment.length === 0)

http://git-wip-us.apache.org/repos/asf/spark/blob/c2c2fdcb/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
index 5141e36..7f08382 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.util.collection
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark._
+import org.apache.spark.internal.config._
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.memory.MemoryTestingUtils
 
@@ -230,14 +231,19 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
     }
   }
 
+  test("spilling with compression and encryption") {
+    testSimpleSpilling(Some(CompressionCodec.DEFAULT_COMPRESSION_CODEC), encrypt = true)
+  }
+
   /**
    * Test spilling through simple aggregations and cogroups.
    * If a compression codec is provided, use it. Otherwise, do not compress spills.
    */
-  private def testSimpleSpilling(codec: Option[String] = None): Unit = {
+  private def testSimpleSpilling(codec: Option[String] = None, encrypt: Boolean = false): Unit = {
     val size = 1000
     val conf = createSparkConf(loadDefaults = true, codec)  // Load defaults for Spark home
     conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 
4).toString)
+    conf.set(IO_ENCRYPTION_ENABLED, encrypt)
     sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
 
     assertSpilled(sc, "reduceByKey") {

