Repository: spark
Updated Branches:
  refs/heads/master 32096c2ae -> 6906b69cf


SPARK-2787: Make sort-based shuffle write files directly when there's no 
sorting/aggregation and # partitions is small

As described in https://issues.apache.org/jira/browse/SPARK-2787, sort-based shuffle is currently
more expensive than hash-based shuffle for map operations that do no partial aggregation or
sorting, such as groupByKey. This is because it has to serialize each data item twice (once when
spilling to intermediate files, and then again when merging these files object-by-object). This
patch adds a code path that writes one file per output partition directly when the number of
output partitions is small, and concatenates those files at the end to produce a single file
sorted by partition ID.

On the unit test side, I added tests that either force or prevent the use of this bypass path,
and checked that our tests for other features (e.g. all the operations) cover both cases.

Author: Matei Zaharia <[email protected]>

Closes #1799 from mateiz/SPARK-2787 and squashes the following commits:

88cf26a [Matei Zaharia] Fix rebase
10233af [Matei Zaharia] Review comments
398cb95 [Matei Zaharia] Fix looking up shuffle manager in conf
ca3efd9 [Matei Zaharia] Add docs for shuffle manager properties, and allow 
short names for them
d0ae3c5 [Matei Zaharia] Fix some comments
90d084f [Matei Zaharia] Add code path to bypass merge-sort in ExternalSorter, 
and tests
31e5d7c [Matei Zaharia] Move existing logic for writing partitioned files into 
ExternalSorter


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6906b69c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6906b69c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6906b69c

Branch: refs/heads/master
Commit: 6906b69cf568015f20c7d7c77cbcba650e5431a9
Parents: 32096c2
Author: Matei Zaharia <[email protected]>
Authored: Thu Aug 7 18:04:49 2014 -0700
Committer: Reynold Xin <[email protected]>
Committed: Thu Aug 7 18:04:49 2014 -0700

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/SparkEnv.scala  |  27 ++-
 .../spark/shuffle/hash/HashShuffleReader.scala  |   2 +-
 .../spark/shuffle/sort/SortShuffleWriter.scala  |  80 ++-----
 .../spark/util/collection/ExternalSorter.scala  | 233 ++++++++++++++++---
 .../util/collection/ExternalSorterSuite.scala   | 165 +++++++++++--
 docs/configuration.md                           |  18 ++
 6 files changed, 407 insertions(+), 118 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6906b69c/core/src/main/scala/org/apache/spark/SparkEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala 
b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 9d4edeb..22d8d1c 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -156,11 +156,9 @@ object SparkEnv extends Logging {
       conf.set("spark.driver.port", boundPort.toString)
     }
 
-    // Create an instance of the class named by the given Java system 
property, or by
-    // defaultClassName if the property is not set, and return it as a T
-    def instantiateClass[T](propertyName: String, defaultClassName: String): T 
= {
-      val name = conf.get(propertyName,  defaultClassName)
-      val cls = Class.forName(name, true, Utils.getContextOrSparkClassLoader)
+    // Create an instance of the class with the given name, possibly 
initializing it with our conf
+    def instantiateClass[T](className: String): T = {
+      val cls = Class.forName(className, true, 
Utils.getContextOrSparkClassLoader)
       // Look for a constructor taking a SparkConf and a boolean isDriver, 
then one taking just
       // SparkConf, then one taking no arguments
       try {
@@ -178,11 +176,17 @@ object SparkEnv extends Logging {
       }
     }
 
-    val serializer = instantiateClass[Serializer](
+    // Create an instance of the class named by the given SparkConf property, 
or defaultClassName
+    // if the property is not set, possibly initializing it with our conf
+    def instantiateClassFromConf[T](propertyName: String, defaultClassName: 
String): T = {
+      instantiateClass[T](conf.get(propertyName, defaultClassName))
+    }
+
+    val serializer = instantiateClassFromConf[Serializer](
       "spark.serializer", "org.apache.spark.serializer.JavaSerializer")
     logDebug(s"Using serializer: ${serializer.getClass}")
 
-    val closureSerializer = instantiateClass[Serializer](
+    val closureSerializer = instantiateClassFromConf[Serializer](
       "spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer")
 
     def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
@@ -246,8 +250,13 @@ object SparkEnv extends Logging {
       "."
     }
 
-    val shuffleManager = instantiateClass[ShuffleManager](
-      "spark.shuffle.manager", 
"org.apache.spark.shuffle.hash.HashShuffleManager")
+    // Let the user specify short names for shuffle managers
+    val shortShuffleMgrNames = Map(
+      "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager",
+      "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager")
+    val shuffleMgrName = conf.get("spark.shuffle.manager", "hash")
+    val shuffleMgrClass = 
shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName)
+    val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass)
 
     val shuffleMemoryManager = new ShuffleMemoryManager(conf)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6906b69c/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala 
b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index 7c9dc8e..88a5f1e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -58,7 +58,7 @@ private[spark] class HashShuffleReader[K, C](
         // Create an ExternalSorter to sort the data. Note that if 
spark.shuffle.spill is disabled,
         // the ExternalSorter won't spill to disk.
         val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), 
serializer = Some(ser))
-        sorter.write(aggregatedIter)
+        sorter.insertAll(aggregatedIter)
         context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled
         context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled
         sorter.iterator

http://git-wip-us.apache.org/repos/asf/spark/blob/6906b69c/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala 
b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index e54e638..22f656f 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -44,6 +44,7 @@ private[spark] class SortShuffleWriter[K, V, C](
 
   private var sorter: ExternalSorter[K, V, _] = null
   private var outputFile: File = null
+  private var indexFile: File = null
 
   // Are we in the process of stopping? Because map tasks can call stop() with 
success = true
   // and then call stop() with success = false if they get an exception, we 
want to make sure
@@ -57,78 +58,36 @@ private[spark] class SortShuffleWriter[K, V, C](
 
   /** Write a bunch of records to this task's output */
   override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
-    // Get an iterator with the elements for each partition ID
-    val partitions: Iterator[(Int, Iterator[Product2[K, _]])] = {
-      if (dep.mapSideCombine) {
-        if (!dep.aggregator.isDefined) {
-          throw new IllegalStateException("Aggregator is empty for map-side 
combine")
-        }
-        sorter = new ExternalSorter[K, V, C](
-          dep.aggregator, Some(dep.partitioner), dep.keyOrdering, 
dep.serializer)
-        sorter.write(records)
-        sorter.partitionedIterator
-      } else {
-        // In this case we pass neither an aggregator nor an ordering to the 
sorter, because we
-        // don't care whether the keys get sorted in each partition; that will 
be done on the
-        // reduce side if the operation being run is sortByKey.
-        sorter = new ExternalSorter[K, V, V](
-          None, Some(dep.partitioner), None, dep.serializer)
-        sorter.write(records)
-        sorter.partitionedIterator
+    if (dep.mapSideCombine) {
+      if (!dep.aggregator.isDefined) {
+        throw new IllegalStateException("Aggregator is empty for map-side 
combine")
       }
+      sorter = new ExternalSorter[K, V, C](
+        dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
+      sorter.insertAll(records)
+    } else {
+      // In this case we pass neither an aggregator nor an ordering to the 
sorter, because we don't
+      // care whether the keys get sorted in each partition; that will be done 
on the reduce side
+      // if the operation being run is sortByKey.
+      sorter = new ExternalSorter[K, V, V](
+        None, Some(dep.partitioner), None, dep.serializer)
+      sorter.insertAll(records)
     }
 
     // Create a single shuffle file with reduce ID 0 that we'll write all 
results to. We'll later
     // serve different ranges of this file using an index file that we create 
at the end.
     val blockId = ShuffleBlockId(dep.shuffleId, mapId, 0)
-    outputFile = blockManager.diskBlockManager.getFile(blockId)
-
-    // Track location of each range in the output file
-    val offsets = new Array[Long](numPartitions + 1)
-    val lengths = new Array[Long](numPartitions)
-
-    for ((id, elements) <- partitions) {
-      if (elements.hasNext) {
-        val writer = blockManager.getDiskWriter(blockId, outputFile, ser, 
fileBufferSize,
-          writeMetrics)
-        for (elem <- elements) {
-          writer.write(elem)
-        }
-        writer.commitAndClose()
-        val segment = writer.fileSegment()
-        offsets(id + 1) = segment.offset + segment.length
-        lengths(id) = segment.length
-      } else {
-        // The partition is empty; don't create a new writer to avoid writing 
headers, etc
-        offsets(id + 1) = offsets(id)
-      }
-    }
-
-    context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled
-    context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled
 
-    // Write an index file with the offsets of each block, plus a final offset 
at the end for the
-    // end of the output file. This will be used by 
SortShuffleManager.getBlockLocation to figure
-    // out where each block begins and ends.
+    outputFile = blockManager.diskBlockManager.getFile(blockId)
+    indexFile = blockManager.diskBlockManager.getFile(blockId.name + ".index")
 
-    val diskBlockManager = blockManager.diskBlockManager
-    val indexFile = diskBlockManager.getFile(blockId.name + ".index")
-    val out = new DataOutputStream(new BufferedOutputStream(new 
FileOutputStream(indexFile)))
-    try {
-      var i = 0
-      while (i < numPartitions + 1) {
-        out.writeLong(offsets(i))
-        i += 1
-      }
-    } finally {
-      out.close()
-    }
+    val partitionLengths = sorter.writePartitionedFile(blockId, context)
 
     // Register our map output with the ShuffleBlockManager, which handles 
cleaning it over time
     blockManager.shuffleBlockManager.addCompletedMap(dep.shuffleId, mapId, 
numPartitions)
 
     mapStatus = new MapStatus(blockManager.blockManagerId,
-      lengths.map(MapOutputTracker.compressSize))
+      partitionLengths.map(MapOutputTracker.compressSize))
   }
 
   /** Close this writer, passing along whether the map completed */
@@ -145,6 +104,9 @@ private[spark] class SortShuffleWriter[K, V, C](
         if (outputFile != null) {
           outputFile.delete()
         }
+        if (indexFile != null) {
+          indexFile.delete()
+        }
         return None
       }
     } finally {

http://git-wip-us.apache.org/repos/asf/spark/blob/6906b69c/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala 
b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index eb4849e..b73d5e0 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -25,10 +25,10 @@ import scala.collection.mutable
 
 import com.google.common.io.ByteStreams
 
-import org.apache.spark.{Aggregator, SparkEnv, Logging, Partitioner}
+import org.apache.spark._
 import org.apache.spark.serializer.{DeserializationStream, Serializer}
-import org.apache.spark.storage.BlockId
 import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.storage.{BlockObjectWriter, BlockId}
 
 /**
  * Sorts and potentially merges a number of key-value pairs of type (K, V) to 
produce key-combiner
@@ -67,6 +67,13 @@ import org.apache.spark.executor.ShuffleWriteMetrics
  *   for equality to merge values.
  *
  * - Users are expected to call stop() at the end to delete all the 
intermediate files.
+ *
+ * As a special case, if no Ordering and no Aggregator is given, and the number of partitions is
+ * at most spark.shuffle.sort.bypassMergeThreshold, we bypass the merge-sort and just write to
+ * separate files for each partition each time we spill, similar to the HashShuffleWriter. We can
+ * then concatenate these files to produce a single file sorted by partition ID, without having
+ * to serialize and de-serialize each item twice (as is needed during the merge). This speeds up
+ * the map side of groupBy, sort, etc. operations since they do no partial aggregation.
  */
 private[spark] class ExternalSorter[K, V, C](
     aggregator: Option[Aggregator[K, V, C]] = None,
@@ -124,6 +131,18 @@ private[spark] class ExternalSorter[K, V, C](
   // How much of the shared memory pool this collection has claimed
   private var myMemoryThreshold = 0L
 
+  // If there are at most spark.shuffle.sort.bypassMergeThreshold partitions and we need neither
+  // local aggregation nor sorting, write numPartitions files directly and just concatenate them
+  // at the end. This avoids doing serialization and deserialization twice to merge together the
+  // spilled files, which would happen with the normal code path. The downside is having multiple
+  // files open at a time and thus more memory allocated to buffers.
+  private val bypassMergeThreshold = 
conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+  private val bypassMergeSort =
+    (numPartitions <= bypassMergeThreshold && aggregator.isEmpty && 
ordering.isEmpty)
+
+  // Array of file writers for each partition, used if bypassMergeSort is true 
and we've spilled
+  private var partitionWriters: Array[BlockObjectWriter] = null
+
   // A comparator for keys K that orders them within a partition to allow 
aggregation or sorting.
   // Can be a partial ordering by hash code if a total ordering is not 
provided through by the
   // user. (A partial ordering means that equal keys have 
comparator.compare(k, k) = 0, but some
@@ -137,7 +156,14 @@ private[spark] class ExternalSorter[K, V, C](
     }
   })
 
-  // A comparator for (Int, K) elements that orders them by partition and then 
possibly by key
+  // A comparator for (Int, K) pairs that orders them by only their partition 
ID
+  private val partitionComparator: Comparator[(Int, K)] = new Comparator[(Int, 
K)] {
+    override def compare(a: (Int, K), b: (Int, K)): Int = {
+      a._1 - b._1
+    }
+  }
+
+  // A comparator that orders (Int, K) pairs by partition ID and then possibly 
by key
   private val partitionKeyComparator: Comparator[(Int, K)] = {
     if (ordering.isDefined || aggregator.isDefined) {
       // Sort by partition ID then key comparator
@@ -153,11 +179,7 @@ private[spark] class ExternalSorter[K, V, C](
       }
     } else {
       // Just sort it by partition ID
-      new Comparator[(Int, K)] {
-        override def compare(a: (Int, K), b: (Int, K)): Int = {
-          a._1 - b._1
-        }
-      }
+      partitionComparator
     }
   }
 
@@ -171,7 +193,7 @@ private[spark] class ExternalSorter[K, V, C](
     elementsPerPartition: Array[Long])
   private val spills = new ArrayBuffer[SpilledFile]
 
-  def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
+  def insertAll(records: Iterator[_ <: Product2[K, V]]): Unit = {
     // TODO: stop combining if we find that the reduction factor isn't high
     val shouldCombine = aggregator.isDefined
 
@@ -242,6 +264,38 @@ private[spark] class ExternalSorter[K, V, C](
     val threadId = Thread.currentThread().getId
     logInfo("Thread %d spilling in-memory batch of %d MB to disk (%d spill%s 
so far)"
       .format(threadId, memorySize / (1024 * 1024), spillCount, if (spillCount 
> 1) "s" else ""))
+
+    if (bypassMergeSort) {
+      spillToPartitionFiles(collection)
+    } else {
+      spillToMergeableFile(collection)
+    }
+
+    if (usingMap) {
+      map = new SizeTrackingAppendOnlyMap[(Int, K), C]
+    } else {
+      buffer = new SizeTrackingPairBuffer[(Int, K), C]
+    }
+
+    // Release our memory back to the shuffle pool so that other threads can 
grab it
+    shuffleMemoryManager.release(myMemoryThreshold)
+    myMemoryThreshold = 0
+
+    _memoryBytesSpilled += memorySize
+  }
+
+  /**
+   * Spill our in-memory collection to a sorted file that we can merge later (normal code path).
+   * We add this file into `spills` to find it later.
+   *
+   * Alternatively, if bypassMergeSort is true, we spill to separate files for each partition.
+   * See spillToPartitionFiles() for that code path.
+   *
+   * @param collection whichever collection we're using (map or buffer)
+   */
+  private def spillToMergeableFile(collection: 
SizeTrackingPairCollection[(Int, K), C]): Unit = {
+    assert(!bypassMergeSort)
+
     val (blockId, file) = diskBlockManager.createTempBlock()
     curWriteMetrics = new ShuffleWriteMetrics()
     var writer = blockManager.getDiskWriter(blockId, file, ser, 
fileBufferSize, curWriteMetrics)
@@ -304,18 +358,36 @@ private[spark] class ExternalSorter[K, V, C](
       }
     }
 
-    if (usingMap) {
-      map = new SizeTrackingAppendOnlyMap[(Int, K), C]
-    } else {
-      buffer = new SizeTrackingPairBuffer[(Int, K), C]
-    }
+    spills.append(SpilledFile(file, blockId, batchSizes.toArray, 
elementsPerPartition))
+  }
 
-    // Release our memory back to the shuffle pool so that other threads can 
grab it
-    shuffleMemoryManager.release(myMemoryThreshold)
-    myMemoryThreshold = 0
+  /**
+   * Spill our in-memory collection to separate files, one for each partition. This is used when
+   * there's no aggregator or ordering and the number of partitions is small, because it allows
+   * writePartitionedFile to just concatenate files without deserializing data.
+   *
+   * @param collection whichever collection we're using (map or buffer)
+   */
+  private def spillToPartitionFiles(collection: 
SizeTrackingPairCollection[(Int, K), C]): Unit = {
+    assert(bypassMergeSort)
+
+    // Create our file writers if we haven't done so yet
+    if (partitionWriters == null) {
+      curWriteMetrics = new ShuffleWriteMetrics()
+      partitionWriters = Array.fill(numPartitions) {
+        val (blockId, file) = diskBlockManager.createTempBlock()
+        blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, 
curWriteMetrics).open()
+      }
+    }
 
-    spills.append(SpilledFile(file, blockId, batchSizes.toArray, 
elementsPerPartition))
-    _memoryBytesSpilled += memorySize
+    val it = collection.iterator  // No need to sort stuff, just write each 
element out
+    while (it.hasNext) {
+      val elem = it.next()
+      val partitionId = elem._1._1
+      val key = elem._1._2
+      val value = elem._2
+      partitionWriters(partitionId).write((key, value))
+    }
   }
 
   /**
@@ -479,7 +551,6 @@ private[spark] class ExternalSorter[K, V, C](
 
     skipToNextPartition()
 
-
     // Intermediate file and deserializer streams that read from exactly one 
batch
     // This guards against pre-fetching and other arbitrary behavior of higher 
level streams
     var fileStream: FileInputStream = null
@@ -619,23 +690,25 @@ private[spark] class ExternalSorter[K, V, C](
   def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
     val usingMap = aggregator.isDefined
     val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) 
map else buffer
-    if (spills.isEmpty) {
+    if (spills.isEmpty && partitionWriters == null) {
       // Special case: if we have only in-memory data, we don't need to merge 
streams, and perhaps
       // we don't even need to sort by anything other than partition ID
       if (!ordering.isDefined) {
-        // The user isn't requested sorted keys, so only sort by partition ID, 
not key
-        val partitionComparator = new Comparator[(Int, K)] {
-          override def compare(a: (Int, K), b: (Int, K)): Int = {
-            a._1 - b._1
-          }
-        }
+        // The user hasn't requested sorted keys, so only sort by partition 
ID, not key
         
groupByPartition(collection.destructiveSortedIterator(partitionComparator))
       } else {
         // We do need to sort by both partition ID and key
         
groupByPartition(collection.destructiveSortedIterator(partitionKeyComparator))
       }
+    } else if (bypassMergeSort) {
+      // Read data from each partition file and merge it together with the 
data in memory;
+      // note that there's no ordering or aggregator in this case -- we just 
partition objects
+      val collIter = 
groupByPartition(collection.destructiveSortedIterator(partitionComparator))
+      collIter.map { case (partitionId, values) =>
+        (partitionId, values ++ 
readPartitionFile(partitionWriters(partitionId)))
+      }
     } else {
-      // General case: merge spilled and in-memory data
+      // Merge spilled and in-memory data
       merge(spills, 
collection.destructiveSortedIterator(partitionKeyComparator))
     }
   }
@@ -645,9 +718,113 @@ private[spark] class ExternalSorter[K, V, C](
    */
   def iterator: Iterator[Product2[K, C]] = partitionedIterator.flatMap(pair => 
pair._2)
 
+  /**
+   * Write all the data added into this ExternalSorter into a file in the disk 
store, creating
+   * an .index file for it as well with the offsets of each partition. This is 
called by the
+   * SortShuffleWriter and can go through an efficient path of just 
concatenating binary files
+   * if we decided to avoid merge-sorting.
+   *
+   * @param blockId block ID to write to. The index file will be blockId.name 
+ ".index".
+   * @param context a TaskContext for a running Spark task, for us to update 
shuffle metrics.
+   * @return array of lengths, in bytes, of each partition of the file (used 
by map output tracker)
+   */
+  def writePartitionedFile(blockId: BlockId, context: TaskContext): 
Array[Long] = {
+    val outputFile = blockManager.diskBlockManager.getFile(blockId)
+
+    // Track location of each range in the output file
+    val offsets = new Array[Long](numPartitions + 1)
+    val lengths = new Array[Long](numPartitions)
+
+    if (bypassMergeSort && partitionWriters != null) {
+      // We decided to write separate files for each partition, so just 
concatenate them. To keep
+      // this simple we spill out the current in-memory collection so that 
everything is in files.
+      spillToPartitionFiles(if (aggregator.isDefined) map else buffer)
+      partitionWriters.foreach(_.commitAndClose())
+      var out: FileOutputStream = null
+      var in: FileInputStream = null
+      try {
+        out = new FileOutputStream(outputFile)
+        for (i <- 0 until numPartitions) {
+          val file = partitionWriters(i).fileSegment().file
+          in = new FileInputStream(file)
+          org.apache.spark.util.Utils.copyStream(in, out)
+          in.close()
+          in = null
+          lengths(i) = file.length()
+          offsets(i + 1) = offsets(i) + lengths(i)
+        }
+      } finally {
+        if (out != null) {
+          out.close()
+        }
+        if (in != null) {
+          in.close()
+        }
+      }
+    } else {
+      // Either we're not bypassing merge-sort or we have only in-memory data; 
get an iterator by
+      // partition and just write everything directly.
+      for ((id, elements) <- this.partitionedIterator) {
+        if (elements.hasNext) {
+          val writer = blockManager.getDiskWriter(
+            blockId, outputFile, ser, fileBufferSize, 
context.taskMetrics.shuffleWriteMetrics.get)
+          for (elem <- elements) {
+            writer.write(elem)
+          }
+          writer.commitAndClose()
+          val segment = writer.fileSegment()
+          offsets(id + 1) = segment.offset + segment.length
+          lengths(id) = segment.length
+        } else {
+          // The partition is empty; don't create a new writer to avoid 
writing headers, etc
+          offsets(id + 1) = offsets(id)
+        }
+      }
+    }
+
+    context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled
+    context.taskMetrics.diskBytesSpilled += diskBytesSpilled
+
+    // Write an index file with the offsets of each block, plus a final offset 
at the end for the
+    // end of the output file. This will be used by 
SortShuffleManager.getBlockLocation to figure
+    // out where each block begins and ends.
+
+    val diskBlockManager = blockManager.diskBlockManager
+    val indexFile = diskBlockManager.getFile(blockId.name + ".index")
+    val out = new DataOutputStream(new BufferedOutputStream(new 
FileOutputStream(indexFile)))
+    try {
+      var i = 0
+      while (i < numPartitions + 1) {
+        out.writeLong(offsets(i))
+        i += 1
+      }
+    } finally {
+      out.close()
+    }
+
+    lengths
+  }
+
+  /**
+   * Read a partition file back as an iterator (used by partitionedIterator)
+   */
+  def readPartitionFile(writer: BlockObjectWriter): Iterator[Product2[K, C]] = 
{
+    if (writer.isOpen) {
+      writer.commitAndClose()
+    }
+    blockManager.getLocalFromDisk(writer.blockId, 
ser).get.asInstanceOf[Iterator[Product2[K, C]]]
+  }
+
   def stop(): Unit = {
     spills.foreach(s => s.file.delete())
     spills.clear()
+    if (partitionWriters != null) {
+      partitionWriters.foreach { w =>
+        w.revertPartialWritesAndClose()
+        diskBlockManager.getFile(w.blockId).delete()
+      }
+      partitionWriters = null
+    }
   }
 
   def memoryBytesSpilled: Long = _memoryBytesSpilled

http://git-wip-us.apache.org/repos/asf/spark/blob/6906b69c/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
 
b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index 57dcb4f..706faed 100644
--- 
a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ 
b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -19,12 +19,12 @@ package org.apache.spark.util.collection
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.scalatest.FunSuite
+import org.scalatest.{PrivateMethodTester, FunSuite}
 
 import org.apache.spark._
 import org.apache.spark.SparkContext._
 
-class ExternalSorterSuite extends FunSuite with LocalSparkContext {
+class ExternalSorterSuite extends FunSuite with LocalSparkContext with 
PrivateMethodTester {
   private def createSparkConf(loadDefaults: Boolean): SparkConf = {
     val conf = new SparkConf(loadDefaults)
     // Make the Java serializer write a reset instruction (TC_RESET) after 
each object to test
@@ -36,6 +36,16 @@ class ExternalSorterSuite extends FunSuite with 
LocalSparkContext {
     conf
   }
 
+  private def assertBypassedMergeSort(sorter: ExternalSorter[_, _, _]): Unit = 
{
+    val bypassMergeSort = PrivateMethod[Boolean]('bypassMergeSort)
+    assert(sorter.invokePrivate(bypassMergeSort()), "sorter did not bypass 
merge-sort")
+  }
+
+  private def assertDidNotBypassMergeSort(sorter: ExternalSorter[_, _, _]): 
Unit = {
+    val bypassMergeSort = PrivateMethod[Boolean]('bypassMergeSort)
+    assert(!sorter.invokePrivate(bypassMergeSort()), "sorter bypassed 
merge-sort")
+  }
+
   test("empty data stream") {
     val conf = new SparkConf(false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
@@ -86,28 +96,28 @@ class ExternalSorterSuite extends FunSuite with 
LocalSparkContext {
     // Both aggregator and ordering
     val sorter = new ExternalSorter[Int, Int, Int](
       Some(agg), Some(new HashPartitioner(7)), Some(ord), None)
-    sorter.write(elements.iterator)
+    sorter.insertAll(elements.iterator)
     assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === 
expected)
     sorter.stop()
 
     // Only aggregator
     val sorter2 = new ExternalSorter[Int, Int, Int](
       Some(agg), Some(new HashPartitioner(7)), None, None)
-    sorter2.write(elements.iterator)
+    sorter2.insertAll(elements.iterator)
     assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === 
expected)
     sorter2.stop()
 
     // Only ordering
     val sorter3 = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(7)), Some(ord), None)
-    sorter3.write(elements.iterator)
+    sorter3.insertAll(elements.iterator)
     assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === 
expected)
     sorter3.stop()
 
     // Neither aggregator nor ordering
     val sorter4 = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(7)), None, None)
-    sorter4.write(elements.iterator)
+    sorter4.insertAll(elements.iterator)
     assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === 
expected)
     sorter4.stop()
   }
@@ -118,13 +128,37 @@ class ExternalSorterSuite extends FunSuite with 
LocalSparkContext {
     conf.set("spark.shuffle.manager", 
"org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
 
-    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => 
i + j)
     val ord = implicitly[Ordering[Int]]
     val elements = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x 
=> (2, 2))
 
     val sorter = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(7)), Some(ord), None)
+    assertDidNotBypassMergeSort(sorter)
+    sorter.insertAll(elements)
+    assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // 
Make sure it spilled
+    val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList))
+    assert(iter.next() === (0, Nil))
+    assert(iter.next() === (1, List((1, 1))))
+    assert(iter.next() === (2, (0 until 100000).map(x => (2, 2)).toList))
+    assert(iter.next() === (3, Nil))
+    assert(iter.next() === (4, Nil))
+    assert(iter.next() === (5, List((5, 5))))
+    assert(iter.next() === (6, Nil))
+    sorter.stop()
+  }
+
+  test("empty partitions with spilling, bypass merge-sort") {
+    val conf = createSparkConf(false)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", 
"org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+
+    val elements = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x 
=> (2, 2))
+
+    val sorter = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(7)), None, None)
-    sorter.write(elements)
+    assertBypassedMergeSort(sorter)
+    sorter.insertAll(elements)
     assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // 
Make sure it spilled
     val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList))
     assert(iter.next() === (0, Nil))
@@ -286,14 +320,43 @@ class ExternalSorterSuite extends FunSuite with 
LocalSparkContext {
     sc = new SparkContext("local", "test", conf)
     val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
 
+    val ord = implicitly[Ordering[Int]]
+
+    val sorter = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(3)), Some(ord), None)
+    assertDidNotBypassMergeSort(sorter)
+    sorter.insertAll((0 until 100000).iterator.map(i => (i, i)))
+    assert(diskBlockManager.getAllFiles().length > 0)
+    sorter.stop()
+    assert(diskBlockManager.getAllBlocks().length === 0)
+
+    val sorter2 = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(3)), Some(ord), None)
+    assertDidNotBypassMergeSort(sorter2)
+    sorter2.insertAll((0 until 100000).iterator.map(i => (i, i)))
+    assert(diskBlockManager.getAllFiles().length > 0)
+    assert(sorter2.iterator.toSet === (0 until 100000).map(i => (i, i)).toSet)
+    sorter2.stop()
+    assert(diskBlockManager.getAllBlocks().length === 0)
+  }
+
+  test("cleanup of intermediate files in sorter, bypass merge-sort") {
+    val conf = createSparkConf(true)  // Load defaults, otherwise SPARK_HOME 
is not found
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", 
"org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+    val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
+
     val sorter = new ExternalSorter[Int, Int, Int](None, Some(new 
HashPartitioner(3)), None, None)
-    sorter.write((0 until 100000).iterator.map(i => (i, i)))
+    assertBypassedMergeSort(sorter)
+    sorter.insertAll((0 until 100000).iterator.map(i => (i, i)))
     assert(diskBlockManager.getAllFiles().length > 0)
     sorter.stop()
     assert(diskBlockManager.getAllBlocks().length === 0)
 
     val sorter2 = new ExternalSorter[Int, Int, Int](None, Some(new 
HashPartitioner(3)), None, None)
-    sorter2.write((0 until 100000).iterator.map(i => (i, i)))
+    assertBypassedMergeSort(sorter2)
+    sorter2.insertAll((0 until 100000).iterator.map(i => (i, i)))
     assert(diskBlockManager.getAllFiles().length > 0)
     assert(sorter2.iterator.toSet === (0 until 100000).map(i => (i, i)).toSet)
     sorter2.stop()
@@ -307,9 +370,35 @@ class ExternalSorterSuite extends FunSuite with 
LocalSparkContext {
     sc = new SparkContext("local", "test", conf)
     val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
 
+    val ord = implicitly[Ordering[Int]]
+
+    val sorter = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(3)), Some(ord), None)
+    assertDidNotBypassMergeSort(sorter)
+    intercept[SparkException] {
+      sorter.insertAll((0 until 100000).iterator.map(i => {
+        if (i == 99990) {
+          throw new SparkException("Intentional failure")
+        }
+        (i, i)
+      }))
+    }
+    assert(diskBlockManager.getAllFiles().length > 0)
+    sorter.stop()
+    assert(diskBlockManager.getAllBlocks().length === 0)
+  }
+
+  test("cleanup of intermediate files in sorter if there are errors, bypass 
merge-sort") {
+    val conf = createSparkConf(true)  // Load defaults, otherwise SPARK_HOME 
is not found
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", 
"org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+    val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
+
     val sorter = new ExternalSorter[Int, Int, Int](None, Some(new 
HashPartitioner(3)), None, None)
+    assertBypassedMergeSort(sorter)
     intercept[SparkException] {
-      sorter.write((0 until 100000).iterator.map(i => {
+      sorter.insertAll((0 until 100000).iterator.map(i => {
         if (i == 99990) {
           throw new SparkException("Intentional failure")
         }
@@ -365,7 +454,7 @@ class ExternalSorterSuite extends FunSuite with 
LocalSparkContext {
     sc = new SparkContext("local", "test", conf)
 
     val sorter = new ExternalSorter[Int, Int, Int](None, Some(new 
HashPartitioner(3)), None, None)
-    sorter.write((0 until 100000).iterator.map(i => (i / 4, i)))
+    sorter.insertAll((0 until 100000).iterator.map(i => (i / 4, i)))
     val results = sorter.partitionedIterator.map{case (p, vs) => (p, 
vs.toSet)}.toSet
     val expected = (0 until 3).map(p => {
       (p, (0 until 100000).map(i => (i / 4, i)).filter(_._1 % 3 == p).toSet)
@@ -381,7 +470,7 @@ class ExternalSorterSuite extends FunSuite with 
LocalSparkContext {
 
     val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => 
i + j)
     val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), 
None, None)
-    sorter.write((0 until 100).iterator.map(i => (i / 2, i)))
+    sorter.insertAll((0 until 100).iterator.map(i => (i / 2, i)))
     val results = sorter.partitionedIterator.map{case (p, vs) => (p, 
vs.toSet)}.toSet
     val expected = (0 until 3).map(p => {
       (p, (0 until 50).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet)
@@ -397,7 +486,7 @@ class ExternalSorterSuite extends FunSuite with 
LocalSparkContext {
 
     val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => 
i + j)
     val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), 
None, None)
-    sorter.write((0 until 100000).iterator.map(i => (i / 2, i)))
+    sorter.insertAll((0 until 100000).iterator.map(i => (i / 2, i)))
     val results = sorter.partitionedIterator.map{case (p, vs) => (p, 
vs.toSet)}.toSet
     val expected = (0 until 3).map(p => {
       (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet)
@@ -414,7 +503,7 @@ class ExternalSorterSuite extends FunSuite with 
LocalSparkContext {
     val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => 
i + j)
     val ord = implicitly[Ordering[Int]]
     val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), 
Some(ord), None)
-    sorter.write((0 until 100000).iterator.map(i => (i / 2, i)))
+    sorter.insertAll((0 until 100000).iterator.map(i => (i / 2, i)))
     val results = sorter.partitionedIterator.map{case (p, vs) => (p, 
vs.toSet)}.toSet
     val expected = (0 until 3).map(p => {
       (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet)
@@ -431,7 +520,7 @@ class ExternalSorterSuite extends FunSuite with 
LocalSparkContext {
     val ord = implicitly[Ordering[Int]]
     val sorter = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(3)), Some(ord), None)
-    sorter.write((0 until 100).iterator.map(i => (i, i)))
+    sorter.insertAll((0 until 100).iterator.map(i => (i, i)))
     val results = sorter.partitionedIterator.map{case (p, vs) => (p, 
vs.toSeq)}.toSeq
     val expected = (0 until 3).map(p => {
       (p, (0 until 100).map(i => (i, i)).filter(_._1 % 3 == p).toSeq)
@@ -448,7 +537,7 @@ class ExternalSorterSuite extends FunSuite with 
LocalSparkContext {
     val ord = implicitly[Ordering[Int]]
     val sorter = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(3)), Some(ord), None)
-    sorter.write((0 until 100000).iterator.map(i => (i, i)))
+    sorter.insertAll((0 until 100000).iterator.map(i => (i, i)))
     val results = sorter.partitionedIterator.map{case (p, vs) => (p, 
vs.toSeq)}.toSeq
     val expected = (0 until 3).map(p => {
       (p, (0 until 100000).map(i => (i, i)).filter(_._1 % 3 == p).toSeq)
@@ -495,7 +584,7 @@ class ExternalSorterSuite extends FunSuite with 
LocalSparkContext {
     val toInsert = (1 to 100000).iterator.map(_.toString).map(s => (s, s)) ++
       collisionPairs.iterator ++ collisionPairs.iterator.map(_.swap)
 
-    sorter.write(toInsert)
+    sorter.insertAll(toInsert)
 
     // A map of collision pairs in both directions
     val collisionPairsMap = (collisionPairs ++ 
collisionPairs.map(_.swap)).toMap
@@ -524,7 +613,7 @@ class ExternalSorterSuite extends FunSuite with 
LocalSparkContext {
     // Insert 10 copies each of lots of objects whose hash codes are either 0 
or 1. This causes
     // problems if the map fails to group together the objects with the same 
code (SPARK-2043).
     val toInsert = for (i <- 1 to 10; j <- 1 to 10000) yield 
(FixedHashObject(j, j % 2), 1)
-    sorter.write(toInsert.iterator)
+    sorter.insertAll(toInsert.iterator)
 
     val it = sorter.iterator
     var count = 0
@@ -548,7 +637,7 @@ class ExternalSorterSuite extends FunSuite with 
LocalSparkContext {
     val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, 
mergeValue, mergeCombiners)
     val sorter = new ExternalSorter[Int, Int, ArrayBuffer[Int]](Some(agg), 
None, None, None)
 
-    sorter.write((1 to 100000).iterator.map(i => (i, i)) ++ 
Iterator((Int.MaxValue, Int.MaxValue)))
+    sorter.insertAll((1 to 100000).iterator.map(i => (i, i)) ++ 
Iterator((Int.MaxValue, Int.MaxValue)))
 
     val it = sorter.iterator
     while (it.hasNext) {
@@ -572,7 +661,7 @@ class ExternalSorterSuite extends FunSuite with 
LocalSparkContext {
     val sorter = new ExternalSorter[String, String, ArrayBuffer[String]](
       Some(agg), None, None, None)
 
-    sorter.write((1 to 100000).iterator.map(i => (i.toString, i.toString)) ++ 
Iterator(
+    sorter.insertAll((1 to 100000).iterator.map(i => (i.toString, i.toString)) 
++ Iterator(
       (null.asInstanceOf[String], "1"),
       ("1", null.asInstanceOf[String]),
       (null.asInstanceOf[String], null.asInstanceOf[String])
@@ -584,4 +673,38 @@ class ExternalSorterSuite extends FunSuite with 
LocalSparkContext {
       it.next()
     }
   }
+
+  test("conditions for bypassing merge-sort") {
+    val conf = createSparkConf(false)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+
+    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+    val ord = implicitly[Ordering[Int]]
+
+    // Numbers of partitions below and above the default bypassMergeThreshold (200)
+    val FEW_PARTITIONS = 50
+    val MANY_PARTITIONS = 10000
+
+    // Sorters with no ordering or aggregator: should bypass unless # of partitions is high
+
+    val sorter1 = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(FEW_PARTITIONS)), None, None)
+    assertBypassedMergeSort(sorter1)
+
+    val sorter2 = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(MANY_PARTITIONS)), None, None)
+    assertDidNotBypassMergeSort(sorter2)
+
+    // Sorters with an ordering or aggregator: should not bypass even if they have few partitions
+
+    val sorter3 = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(FEW_PARTITIONS)), Some(ord), None)
+    assertDidNotBypassMergeSort(sorter3)
+
+    val sorter4 = new ExternalSorter[Int, Int, Int](
+      Some(agg), Some(new HashPartitioner(FEW_PARTITIONS)), None, None)
+    assertDidNotBypassMergeSort(sorter4)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/6906b69c/docs/configuration.md
----------------------------------------------------------------------
diff --git a/docs/configuration.md b/docs/configuration.md
index 5e3eb0f..4d27c5a 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -281,6 +281,24 @@ Apart from these, the following properties are also 
available, and may be useful
     overhead per reduce task, so keep it small unless you have a large amount of memory.
   </td>
 </tr>
+<tr>
+  <td><code>spark.shuffle.manager</code></td>
+  <td>HASH</td>
+  <td>
+    Implementation to use for shuffling data. A hash-based shuffle manager is the default, but
+    starting in Spark 1.1 there is an experimental sort-based shuffle manager that is more
+    memory-efficient in environments with small executors, such as YARN. To use it, change
+    this value to <code>SORT</code>.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.shuffle.sort.bypassMergeThreshold</code></td>
+  <td>200</td>
+  <td>
+    (Advanced) In the sort-based shuffle manager, avoid merge-sorting data if there is no
+    map-side aggregation and there are at most this many reduce partitions.
+  </td>
+</tr>
 </table>
 
 #### Spark UI


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to