This is an automated email from the ASF dual-hosted git repository.

vanzin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ea90ea6  [SPARK-28571][CORE][SHUFFLE] Use the shuffle writer plugin for the SortShuffleWriter
ea90ea6 is described below

commit ea90ea6ce7e1e07064279b2f78a0301fd2048f11
Author: mcheah <mch...@palantir.com>
AuthorDate: Fri Aug 30 09:43:07 2019 -0700

    [SPARK-28571][CORE][SHUFFLE] Use the shuffle writer plugin for the SortShuffleWriter
    
    ## What changes were proposed in this pull request?
    
    Use the shuffle writer APIs introduced in SPARK-28209 in the sort shuffle writer.
    
    ## How was this patch tested?
    
    Existing unit tests were changed to use the plugin instead; they use the local-disk implementation of the plugin to ensure that there were no regressions.
    
    Closes #25342 from mccheah/shuffle-writer-refactor-sort-shuffle-writer.
    
    Lead-authored-by: mcheah <mch...@palantir.com>
    Co-authored-by: mccheah <mch...@palantir.com>
    Signed-off-by: Marcelo Vanzin <van...@cloudera.com>
---
 .../shuffle/ShufflePartitionPairsWriter.scala      | 126 +++++++++++++++++++++
 .../spark/shuffle/sort/SortShuffleManager.scala    |   3 +-
 .../spark/shuffle/sort/SortShuffleWriter.scala     |  23 ++--
 .../spark/storage/DiskBlockObjectWriter.scala      |   6 +-
 .../spark/util/collection/ExternalSorter.scala     |  88 ++++++++++++--
 .../apache/spark/util/collection/PairsWriter.scala |  28 +++++
 .../WritablePartitionedPairCollection.scala        |   4 +-
 .../shuffle/sort/SortShuffleWriterSuite.scala      |  18 ++-
 8 files changed, 265 insertions(+), 31 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala
new file mode 100644
index 0000000..a988c5e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.io.{Closeable, IOException, OutputStream}
+
+import org.apache.spark.serializer.{SerializationStream, SerializerInstance, 
SerializerManager}
+import org.apache.spark.shuffle.api.ShufflePartitionWriter
+import org.apache.spark.storage.BlockId
+import org.apache.spark.util.Utils
+import org.apache.spark.util.collection.PairsWriter
+
+/**
+ * A key-value writer inspired by [[org.apache.spark.storage.DiskBlockObjectWriter]] that pushes
+ * the bytes to an arbitrary partition writer instead of writing to local disk through the block
+ * manager.
+ *
+ * The underlying streams are opened lazily on the first [[write]] call, so a writer that never
+ * receives a record never opens a stream on `partitionWriter`. After [[close]] has been called,
+ * further writes fail with an `IOException`.
+ *
+ * @param partitionWriter    destination of the serialized bytes; also queried for the number of
+ *                           bytes written when updating `writeMetrics`
+ * @param serializerManager  wraps the partition stream (via `wrapStream`) for `blockId`
+ * @param serializerInstance serializes each key and value
+ * @param blockId            block ID passed to `serializerManager.wrapStream`
+ * @param writeMetrics       receives record counts and bytes-written deltas
+ */
+private[spark] class ShufflePartitionPairsWriter(
+    partitionWriter: ShufflePartitionWriter,
+    serializerManager: SerializerManager,
+    serializerInstance: SerializerInstance,
+    blockId: BlockId,
+    writeMetrics: ShuffleWriteMetricsReporter)
+  extends PairsWriter with Closeable {
+
+  private var isClosed = false
+  private var partitionStream: OutputStream = _
+  private var wrappedStream: OutputStream = _
+  private var objOut: SerializationStream = _
+  private var numRecordsWritten = 0
+  // partitionWriter.getNumBytesWritten as of the last metrics update; used to report deltas.
+  private var curNumBytesWritten = 0L
+
+  /**
+   * Serializes one key-value pair, opening the underlying streams on first use.
+   *
+   * @throws IOException if this writer has already been closed
+   */
+  override def write(key: Any, value: Any): Unit = {
+    if (isClosed) {
+      throw new IOException("Partition pairs writer is already closed.")
+    }
+    if (objOut == null) {
+      open()
+    }
+    objOut.writeKey(key)
+    objOut.writeValue(value)
+    recordWritten()
+  }
+
+  /**
+   * Opens the partition stream, wraps it, and layers a serialization stream on top. If any step
+   * fails, whatever was already opened is closed before the exception is rethrown.
+   */
+  private def open(): Unit = {
+    try {
+      partitionStream = partitionWriter.openStream()
+      wrappedStream = serializerManager.wrapStream(blockId, partitionStream)
+      objOut = serializerInstance.serializeStream(wrappedStream)
+    } catch {
+      case e: Exception =>
+        Utils.tryLogNonFatalError {
+          close()
+        }
+        throw e
+    }
+  }
+
+  /**
+   * Closes the serialization stream and, defensively, the streams beneath it. If closing
+   * succeeds, it then reports the final bytes-written delta. Only the first call has any effect;
+   * the writer is marked closed even if closing fails.
+   */
+  override def close(): Unit = {
+    if (!isClosed) {
+      Utils.tryWithSafeFinally {
+        Utils.tryWithSafeFinally {
+          objOut = closeIfNonNull(objOut)
+          // Setting these to null will prevent the underlying streams from being closed twice
+          // just in case any stream's close() implementation is not idempotent.
+          wrappedStream = null
+          partitionStream = null
+        } {
+          // Normally closing objOut would close the inner streams as well, but just in case there
+          // was an error in initialization etc. we make sure we clean the other streams up too.
+          Utils.tryWithSafeFinally {
+            wrappedStream = closeIfNonNull(wrappedStream)
+            // Same as above - if wrappedStream closes then assume it closes the underlying
+            // partitionStream and don't close it again in the finally.
+            partitionStream = null
+          } {
+            partitionStream = closeIfNonNull(partitionStream)
+          }
+        }
+        updateBytesWritten()
+      } {
+        isClosed = true
+      }
+    }
+  }
+
+  /**
+   * Closes `closeable` if it is non-null and always returns null, so the caller can clear its
+   * reference in the same assignment.
+   */
+  private def closeIfNonNull[T <: Closeable](closeable: T): T = {
+    if (closeable != null) {
+      closeable.close()
+    }
+    null.asInstanceOf[T]
+  }
+
+  /**
+   * Counts one serialized key-value pair in the record metrics and, every 16384 records, refreshes
+   * the bytes-written metric (presumably batched to avoid querying the partition writer on every
+   * record).
+   */
+  private def recordWritten(): Unit = {
+    numRecordsWritten += 1
+    writeMetrics.incRecordsWritten(1)
+
+    if (numRecordsWritten % 16384 == 0) {
+      updateBytesWritten()
+    }
+  }
+
+  /**
+   * Reports to `writeMetrics` the bytes written since the previous update, as measured by
+   * `partitionWriter.getNumBytesWritten`.
+   */
+  private def updateBytesWritten(): Unit = {
+    val numBytesWritten = partitionWriter.getNumBytesWritten
+    val bytesWrittenDiff = numBytesWritten - curNumBytesWritten
+    writeMetrics.incBytesWritten(bytesWrittenDiff)
+    curNumBytesWritten = numBytesWritten
+  }
+}
diff --git 
a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala 
b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 17719f5..2a99c93 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -157,7 +157,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) 
extends ShuffleManager
           metrics,
           shuffleExecutorComponents)
       case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
-        new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
+        new SortShuffleWriter(
+          shuffleBlockResolver, other, mapId, context, 
shuffleExecutorComponents)
     }
   }
 
diff --git 
a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala 
b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 16058de..a781b16 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -21,15 +21,15 @@ import org.apache.spark._
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, 
ShuffleWriter}
-import org.apache.spark.storage.ShuffleBlockId
-import org.apache.spark.util.Utils
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents
 import org.apache.spark.util.collection.ExternalSorter
 
 private[spark] class SortShuffleWriter[K, V, C](
     shuffleBlockResolver: IndexShuffleBlockResolver,
     handle: BaseShuffleHandle[K, V, C],
     mapId: Int,
-    context: TaskContext)
+    context: TaskContext,
+    shuffleExecutorComponents: ShuffleExecutorComponents)
   extends ShuffleWriter[K, V] with Logging {
 
   private val dep = handle.dependency
@@ -64,18 +64,11 @@ private[spark] class SortShuffleWriter[K, V, C](
     // Don't bother including the time to open the merged output file in the 
shuffle write time,
     // because it just opens a single file, so is typically too fast to 
measure accurately
     // (see SPARK-3570).
-    val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
-    val tmp = Utils.tempFileWith(output)
-    try {
-      val blockId = ShuffleBlockId(dep.shuffleId, mapId, 
IndexShuffleBlockResolver.NOOP_REDUCE_ID)
-      val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
-      shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, 
partitionLengths, tmp)
-      mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
-    } finally {
-      if (tmp.exists() && !tmp.delete()) {
-        logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
-      }
-    }
+    val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
+      dep.shuffleId, mapId, context.taskAttemptId(), 
dep.partitioner.numPartitions)
+    sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
+    val partitionLengths = mapOutputWriter.commitAllPartitions()
+    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
   }
 
   /** Close this writer, passing along whether the map completed */
diff --git 
a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala 
b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index 17390f9..758621c 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -24,6 +24,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.serializer.{SerializationStream, SerializerInstance, 
SerializerManager}
 import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
 import org.apache.spark.util.Utils
+import org.apache.spark.util.collection.PairsWriter
 
 /**
  * A class for writing JVM objects directly to a file on disk. This class 
allows data to be appended
@@ -46,7 +47,8 @@ private[spark] class DiskBlockObjectWriter(
     writeMetrics: ShuffleWriteMetricsReporter,
     val blockId: BlockId = null)
   extends OutputStream
-  with Logging {
+  with Logging
+  with PairsWriter {
 
   /**
    * Guards against close calls, e.g. from a wrapping stream.
@@ -232,7 +234,7 @@ private[spark] class DiskBlockObjectWriter(
   /**
    * Writes a key-value pair.
    */
-  def write(key: Any, value: Any) {
+  override def write(key: Any, value: Any) {
     if (!streamOpen) {
       open()
     }
diff --git 
a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala 
b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 3f3b7d2..7a822e1 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -23,13 +23,16 @@ import java.util.Comparator
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
-import com.google.common.io.ByteStreams
+import com.google.common.io.{ByteStreams, Closeables}
 
 import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.serializer._
-import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
+import org.apache.spark.shuffle.ShufflePartitionPairsWriter
+import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, 
ShufflePartitionWriter}
+import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, 
ShuffleBlockId}
+import org.apache.spark.util.{Utils => TryUtils}
 
 /**
  * Sorts and potentially merges a number of key-value pairs of type (K, V) to 
produce key-combiner
@@ -670,11 +673,9 @@ private[spark] class ExternalSorter[K, V, C](
   }
 
   /**
-   * Write all the data added into this ExternalSorter into a file in the disk 
store. This is
-   * called by the SortShuffleWriter.
-   *
-   * @param blockId block ID to write to. The index file will be blockId.name 
+ ".index".
-   * @return array of lengths, in bytes, of each partition of the file (used 
by map output tracker)
+   * TODO(SPARK-28764): remove this, as this is only used by 
UnsafeRowSerializerSuite in the SQL
+   * project. We should figure out an alternative way to test that so that we 
can remove this
+   * otherwise unused code path.
    */
   def writePartitionedFile(
       blockId: BlockId,
@@ -718,6 +719,77 @@ private[spark] class ExternalSorter[K, V, C](
     lengths
   }
 
+  /**
+   * Write all the data added into this ExternalSorter into a map output writer that pushes bytes
+   * to some arbitrary backing store. This is called by the SortShuffleWriter.
+   *
+   * Each partition's records are written through a dedicated [[ShufflePartitionPairsWriter]],
+   * which is always closed, even when writing fails. This method does not return partition
+   * lengths. Callers obtain them from the map output writer once all data has been written
+   * (e.g. via `mapOutputWriter.commitAllPartitions()`).
+   *
+   * @param shuffleId ID of the shuffle being written; part of each partition's block ID
+   * @param mapId ID of the map task producing this output; part of each partition's block ID
+   * @param mapOutputWriter source of the per-partition writers (via `getPartitionWriter`)
+   */
+  def writePartitionedMapOutput(
+      shuffleId: Int,
+      mapId: Int,
+      mapOutputWriter: ShuffleMapOutputWriter): Unit = {
+    if (spills.isEmpty) {
+      // Case where we only have in-memory data
+      val collection = if (aggregator.isDefined) map else buffer
+      val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
+      while (it.hasNext()) {
+        val partitionId = it.nextPartition()
+        writeMapOutputPartition(shuffleId, mapId, partitionId, mapOutputWriter) { writer =>
+          // Drain every consecutive record that belongs to the current partition.
+          while (it.hasNext() && it.nextPartition() == partitionId) {
+            it.writeNext(writer)
+          }
+        }
+      }
+    } else {
+      // We must perform merge-sort; get an iterator by partition and write everything directly.
+      for ((id, elements) <- this.partitionedIterator) {
+        writeMapOutputPartition(shuffleId, mapId, id, mapOutputWriter) { writer =>
+          elements.foreach(elem => writer.write(elem._1, elem._2))
+        }
+      }
+    }
+
+    context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
+    context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
+    context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
+  }
+
+  /**
+   * Opens a [[ShufflePartitionPairsWriter]] for `partitionId`, passes it to `writeRecords`, and
+   * closes it afterwards, even if obtaining the partition writer or `writeRecords` throws.
+   */
+  private def writeMapOutputPartition(
+      shuffleId: Int,
+      mapId: Int,
+      partitionId: Int,
+      mapOutputWriter: ShuffleMapOutputWriter)(writeRecords: PairsWriter => Unit): Unit = {
+    var partitionPairsWriter: ShufflePartitionPairsWriter = null
+    TryUtils.tryWithSafeFinally {
+      val partitionWriter = mapOutputWriter.getPartitionWriter(partitionId)
+      val blockId = ShuffleBlockId(shuffleId, mapId, partitionId)
+      partitionPairsWriter = new ShufflePartitionPairsWriter(
+        partitionWriter,
+        serializerManager,
+        serInstance,
+        blockId,
+        context.taskMetrics().shuffleWriteMetrics)
+      writeRecords(partitionPairsWriter)
+    } {
+      if (partitionPairsWriter != null) {
+        partitionPairsWriter.close()
+      }
+    }
+  }
+
   def stop(): Unit = {
     spills.foreach(s => s.file.delete())
     spills.clear()
@@ -781,7 +853,7 @@ private[spark] class ExternalSorter[K, V, C](
         val inMemoryIterator = new WritablePartitionedIterator {
           private[this] var cur = if (upstream.hasNext) upstream.next() else 
null
 
-          def writeNext(writer: DiskBlockObjectWriter): Unit = {
+          def writeNext(writer: PairsWriter): Unit = {
             writer.write(cur._1._2, cur._2)
             cur = if (upstream.hasNext) upstream.next() else null
           }
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala b/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala
new file mode 100644
index 0000000..05ed72c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+/**
+ * Sink for key-value pairs. It is used when persisting partitioned data, either through the
+ * shuffle writer plugins or via DiskBlockObjectWriter.
+ */
+private[spark] trait PairsWriter {
+  /** Writes a single key-value pair. */
+  def write(key: Any, value: Any): Unit
+}
diff --git 
a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
 
b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
index dd7f68f..da8d58d 100644
--- 
a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
+++ 
b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
@@ -52,7 +52,7 @@ private[spark] trait WritablePartitionedPairCollection[K, V] {
     new WritablePartitionedIterator {
       private[this] var cur = if (it.hasNext) it.next() else null
 
-      def writeNext(writer: DiskBlockObjectWriter): Unit = {
+      def writeNext(writer: PairsWriter): Unit = {
         writer.write(cur._1._2, cur._2)
         cur = if (it.hasNext) it.next() else null
       }
@@ -89,7 +89,7 @@ private[spark] object WritablePartitionedPairCollection {
  * has an associated partition.
  */
 private[spark] trait WritablePartitionedIterator {
-  def writeNext(writer: DiskBlockObjectWriter): Unit
+  def writeNext(writer: PairsWriter): Unit
 
   def hasNext(): Boolean
 
diff --git 
a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
 
b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
index 690bcd9..0dd6040 100644
--- 
a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
+++ 
b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
@@ -17,24 +17,32 @@
 
 package org.apache.spark.shuffle.sort
 
+import org.mockito.{Mock, MockitoAnnotations}
+import org.mockito.Answers.RETURNS_SMART_NULLS
 import org.mockito.Mockito._
-import org.mockito.MockitoAnnotations
 import org.scalatest.Matchers
 
 import org.apache.spark.{Partitioner, SharedSparkContext, ShuffleDependency, 
SparkFunSuite}
 import org.apache.spark.memory.MemoryTestingUtils
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver}
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents
+import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents
+import org.apache.spark.storage.BlockManager
 import org.apache.spark.util.Utils
 
 
 class SortShuffleWriterSuite extends SparkFunSuite with SharedSparkContext 
with Matchers {
 
+  @Mock(answer = RETURNS_SMART_NULLS)
+  private var blockManager: BlockManager = _
+
   private val shuffleId = 0
   private val numMaps = 5
   private var shuffleHandle: BaseShuffleHandle[Int, Int, Int] = _
   private val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
   private val serializer = new JavaSerializer(conf)
+  private var shuffleExecutorComponents: ShuffleExecutorComponents = _
 
   override def beforeEach(): Unit = {
     super.beforeEach()
@@ -51,6 +59,8 @@ class SortShuffleWriterSuite extends SparkFunSuite with 
SharedSparkContext with
       when(dependency.keyOrdering).thenReturn(None)
       new BaseShuffleHandle(shuffleId, numMaps = numMaps, dependency)
     }
+    shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents(
+      conf, blockManager, shuffleBlockResolver)
   }
 
   override def afterAll(): Unit = {
@@ -67,7 +77,8 @@ class SortShuffleWriterSuite extends SparkFunSuite with 
SharedSparkContext with
       shuffleBlockResolver,
       shuffleHandle,
       mapId = 1,
-      context)
+      context,
+      shuffleExecutorComponents)
     writer.write(Iterator.empty)
     writer.stop(success = true)
     val dataFile = shuffleBlockResolver.getDataFile(shuffleId, 1)
@@ -84,7 +95,8 @@ class SortShuffleWriterSuite extends SparkFunSuite with 
SharedSparkContext with
       shuffleBlockResolver,
       shuffleHandle,
       mapId = 2,
-      context)
+      context,
+      shuffleExecutorComponents)
     writer.write(records.toIterator)
     writer.stop(success = true)
     val dataFile = shuffleBlockResolver.getDataFile(shuffleId, 2)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to