Github user JoshRosen commented on a diff in the pull request:
https://github.com/apache/spark/pull/5868#discussion_r30180807
--- Diff:
core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java ---
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.*;
+import java.nio.channels.FileChannel;
+import java.util.Iterator;
+import javax.annotation.Nullable;
+
+import scala.Option;
+import scala.Product2;
+import scala.collection.JavaConversions;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.io.ByteStreams;
+import com.google.common.io.Closeables;
+import com.google.common.io.Files;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.*;
+import org.apache.spark.io.CompressionCodec;
+import org.apache.spark.io.CompressionCodec$;
+import org.apache.spark.io.LZFCompressionCodec;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.network.util.LimitedInputStream;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.scheduler.MapStatus$;
+import org.apache.spark.serializer.SerializationStream;
+import org.apache.spark.serializer.Serializer;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.shuffle.ShuffleWriter;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
+
+ private final Logger logger =
LoggerFactory.getLogger(UnsafeShuffleWriter.class);
+
+ private static final ClassTag<Object> OBJECT_CLASS_TAG =
ClassTag$.MODULE$.Object();
+
+ @VisibleForTesting
+ static final int INITIAL_SORT_BUFFER_SIZE = 4096;
+
+ private final BlockManager blockManager;
+ private final IndexShuffleBlockResolver shuffleBlockResolver;
+ private final TaskMemoryManager memoryManager;
+ private final ShuffleMemoryManager shuffleMemoryManager;
+ private final SerializerInstance serializer;
+ private final Partitioner partitioner;
+ private final ShuffleWriteMetrics writeMetrics;
+ private final int shuffleId;
+ private final int mapId;
+ private final TaskContext taskContext;
+ private final SparkConf sparkConf;
+ private final boolean transferToEnabled;
+
+ private MapStatus mapStatus = null;
+ private UnsafeShuffleExternalSorter sorter = null;
+
+ /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
+ private static final class MyByteArrayOutputStream extends
ByteArrayOutputStream {
+ public MyByteArrayOutputStream(int size) { super(size); }
+ public byte[] getBuf() { return buf; }
+ }
+
+ private MyByteArrayOutputStream serBuffer;
+ private SerializationStream serOutputStream;
+
+ /**
+ * Are we in the process of stopping? Because map tasks can call stop()
with success = true
+ * and then call stop() with success = false if they get an exception,
we want to make sure
+ * we don't try deleting files, etc twice.
+ */
+ private boolean stopping = false;
+
/**
 * Creates a writer for one map task of an unsafe shuffle.
 *
 * @throws IllegalArgumentException if the shuffle has more reduce partitions than
 *         {@link PackedRecordPointer#MAXIMUM_PARTITION_ID} supports, since partition ids
 *         are packed into record pointers.
 */
public UnsafeShuffleWriter(
    BlockManager blockManager,
    IndexShuffleBlockResolver shuffleBlockResolver,
    TaskMemoryManager memoryManager,
    ShuffleMemoryManager shuffleMemoryManager,
    UnsafeShuffleHandle<K, V> handle,
    int mapId,
    TaskContext taskContext,
    SparkConf sparkConf) {
  final ShuffleDependency<K, V, V> dep = handle.dependency();
  final int reducePartitionCount = dep.partitioner().numPartitions();
  if (reducePartitionCount > PackedRecordPointer.MAXIMUM_PARTITION_ID) {
    throw new IllegalArgumentException(
      "UnsafeShuffleWriter can only be used for shuffles with at most " +
      PackedRecordPointer.MAXIMUM_PARTITION_ID + " reduce partitions");
  }
  this.blockManager = blockManager;
  this.shuffleBlockResolver = shuffleBlockResolver;
  this.memoryManager = memoryManager;
  this.shuffleMemoryManager = shuffleMemoryManager;
  this.mapId = mapId;
  this.shuffleId = dep.shuffleId();
  this.serializer = Serializer.getSerializer(dep.serializer()).newInstance();
  this.partitioner = dep.partitioner();
  this.writeMetrics = new ShuffleWriteMetrics();
  // Publish our metrics object on the task so the final write stats are reported.
  taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
  this.taskContext = taskContext;
  this.sparkConf = sparkConf;
  this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
}
+
+ public void write(Iterator<Product2<K, V>> records) throws IOException {
+ write(JavaConversions.asScalaIterator(records));
+ }
+
+ @Override
+ public void write(scala.collection.Iterator<Product2<K, V>> records)
throws IOException {
+ try {
+ while (records.hasNext()) {
+ insertRecordIntoSorter(records.next());
+ }
+ closeAndWriteOutput();
+ } catch (Exception e) {
+ // Unfortunately, we have to catch Exception here in order to ensure
proper cleanup after
+ // errors because Spark's Scala code, or users' custom Serializers,
might throw arbitrary
+ // unchecked exceptions.
+ try {
+ sorter.cleanupAfterError();
+ } finally {
+ throw new IOException("Error during shuffle write", e);
+ }
+ }
+ }
+
+ private void open() throws IOException {
+ assert (sorter == null);
+ sorter = new UnsafeShuffleExternalSorter(
+ memoryManager,
+ shuffleMemoryManager,
+ blockManager,
+ taskContext,
+ INITIAL_SORT_BUFFER_SIZE,
+ partitioner.numPartitions(),
+ sparkConf,
+ writeMetrics);
+ serBuffer = new MyByteArrayOutputStream(1024 * 1024);
+ serOutputStream = serializer.serializeStream(serBuffer);
+ }
+
+ @VisibleForTesting
+ void closeAndWriteOutput() throws IOException {
+ if (sorter == null) {
+ open();
+ }
+ serBuffer = null;
+ serOutputStream = null;
+ final SpillInfo[] spills = sorter.closeAndGetSpills();
+ sorter = null;
+ final long[] partitionLengths;
+ try {
+ partitionLengths = mergeSpills(spills);
+ } finally {
+ for (SpillInfo spill : spills) {
+ if (spill.file.exists() && ! spill.file.delete()) {
+ logger.error("Error while deleting spill file {}",
spill.file.getPath());
+ }
+ }
+ }
+ shuffleBlockResolver.writeIndexFile(shuffleId, mapId,
partitionLengths);
+ mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(),
partitionLengths);
+ }
+
+ @VisibleForTesting
+ void insertRecordIntoSorter(Product2<K, V> record) throws IOException{
+ if (sorter == null) {
+ open();
+ }
+ final K key = record._1();
+ final int partitionId = partitioner.getPartition(key);
+ serBuffer.reset();
+ serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
+ serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
+ serOutputStream.flush();
+
+ final int serializedRecordSize = serBuffer.size();
+ assert (serializedRecordSize > 0);
+
+ sorter.insertRecord(
+ serBuffer.getBuf(), PlatformDependent.BYTE_ARRAY_OFFSET,
serializedRecordSize, partitionId);
+ }
+
+ @VisibleForTesting
+ void forceSorterToSpill() throws IOException {
+ assert (sorter != null);
+ sorter.spill();
+ }
+
+ /**
+ * Merge zero or more spill files together, choosing the fastest merging
strategy based on the
+ * number of spills and the IO compression codec.
+ *
+ * @return the partition lengths in the merged file.
+ */
+ private long[] mergeSpills(SpillInfo[] spills) throws IOException {
+ final File outputFile = shuffleBlockResolver.getDataFile(shuffleId,
mapId);
+ final boolean compressionEnabled =
sparkConf.getBoolean("spark.shuffle.compress", true);
+ final CompressionCodec compressionCodec =
CompressionCodec$.MODULE$.createCodec(sparkConf);
+ final boolean fastMergeEnabled =
+ sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true);
+ final boolean fastMergeIsSupported =
+ !compressionEnabled || compressionCodec instanceof
LZFCompressionCodec;
+ try {
+ if (spills.length == 0) {
+ new FileOutputStream(outputFile).close(); // Create an empty file
+ return new long[partitioner.numPartitions()];
+ } else if (spills.length == 1) {
+ // Here, we don't need to perform any metrics updates because the
bytes written to this
+ // output file would have already been counted as shuffle bytes
written.
+ Files.move(spills[0].file, outputFile);
+ return spills[0].partitionLengths;
+ } else {
+ final long[] partitionLengths;
+ // There are multiple spills to merge, so none of these spill
files' lengths were counted
+ // towards our shuffle write count or shuffle write time. If we
use the slow merge path,
+ // then the final output file's size won't necessarily be equal to
the sum of the spill
+ // files' sizes. To guard against this case, we look at the output
file's actual size when
+ // computing shuffle bytes written.
+ //
+ // We allow the individual merge methods to report their own IO
times since different merge
+ // strategies use different IO techniques. We count IO during
merge towards the shuffle
+ // shuffle write time, which appears to be consistent with the
"not bypassing merge-sort"
+ // branch in ExternalSorter.
+ if (fastMergeEnabled && fastMergeIsSupported) {
+ // Compression is disabled or we are using an IO compression
codec that supports
+ // decompression of concatenated compressed streams, so we can
perform a fast spill merge
+ // that doesn't need to interpret the spilled bytes.
+ if (transferToEnabled) {
+ logger.debug("Using transferTo-based fast merge");
+ partitionLengths = mergeSpillsWithTransferTo(spills,
outputFile);
+ } else {
+ logger.debug("Using fileStream-based fast merge");
+ partitionLengths = mergeSpillsWithFileStream(spills,
outputFile, null);
+ }
+ } else {
+ logger.debug("Using slow merge");
+ partitionLengths = mergeSpillsWithFileStream(spills, outputFile,
compressionCodec);
+ }
+ // The final shuffle spill's write would have directly updated
shuffleBytesWritten, so
+ // we need to decrement to avoid double-counting this write.
+ writeMetrics.decShuffleBytesWritten(spills[spills.length -
1].file.length());
+ writeMetrics.incShuffleBytesWritten(outputFile.length());
+ return partitionLengths;
+ }
+ } catch (IOException e) {
+ if (outputFile.exists() && !outputFile.delete()) {
+ logger.error("Unable to delete output file {}",
outputFile.getPath());
+ }
+ throw e;
+ }
+ }
+
+ private long[] mergeSpillsWithFileStream(
+ SpillInfo[] spills,
+ File outputFile,
+ @Nullable CompressionCodec compressionCodec) throws IOException {
+ final int numPartitions = partitioner.numPartitions();
+ final long[] partitionLengths = new long[numPartitions];
+ final InputStream[] spillInputStreams = new
FileInputStream[spills.length];
+ OutputStream mergedFileOutputStream = null;
+
+ try {
+ for (int i = 0; i < spills.length; i++) {
+ spillInputStreams[i] = new FileInputStream(spills[i].file);
+ }
+ for (int partition = 0; partition < numPartitions; partition++) {
+ final long initialFileLength = outputFile.length();
+ mergedFileOutputStream =
+ new TimeTrackingFileOutputStream(writeMetrics, new
FileOutputStream(outputFile, true));
+ if (compressionCodec != null) {
+ mergedFileOutputStream =
compressionCodec.compressedOutputStream(mergedFileOutputStream);
+ }
+
+ for (int i = 0; i < spills.length; i++) {
+ final long partitionLengthInSpill =
spills[i].partitionLengths[partition];
+ if (partitionLengthInSpill > 0) {
+ InputStream partitionInputStream =
+ new LimitedInputStream(spillInputStreams[i],
partitionLengthInSpill);
+ if (compressionCodec != null) {
+ partitionInputStream =
compressionCodec.compressedInputStream(partitionInputStream);
+ }
+ ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
+ }
+ }
+ mergedFileOutputStream.flush();
+ mergedFileOutputStream.close();
+ partitionLengths[partition] = (outputFile.length() -
initialFileLength);
+ }
+ } finally {
+ for (InputStream stream : spillInputStreams) {
+ Closeables.close(stream, false);
+ }
+ Closeables.close(mergedFileOutputStream, false);
+ }
+ return partitionLengths;
+ }
+
+ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File
outputFile) throws IOException {
+ final int numPartitions = partitioner.numPartitions();
+ final long[] partitionLengths = new long[numPartitions];
+ final FileChannel[] spillInputChannels = new
FileChannel[spills.length];
+ final long[] spillInputChannelPositions = new long[spills.length];
+ FileChannel mergedFileOutputChannel = null;
+
+ try {
+ for (int i = 0; i < spills.length; i++) {
+ spillInputChannels[i] = new
FileInputStream(spills[i].file).getChannel();
+ }
+ // This file needs to opened in append mode in order to work around
a Linux kernel bug that
+ // affects transferTo; see SPARK-3948 for more details.
+ mergedFileOutputChannel = new FileOutputStream(outputFile,
true).getChannel();
+
+ long bytesWrittenToMergedFile = 0;
+ for (int partition = 0; partition < numPartitions; partition++) {
+ for (int i = 0; i < spills.length; i++) {
+ final long partitionLengthInSpill =
spills[i].partitionLengths[partition];
+ long bytesToTransfer = partitionLengthInSpill;
+ final FileChannel spillInputChannel = spillInputChannels[i];
+ final long writeStartTime = System.nanoTime();
+ while (bytesToTransfer > 0) {
+ final long actualBytesTransferred =
spillInputChannel.transferTo(
+ spillInputChannelPositions[i],
+ bytesToTransfer,
+ mergedFileOutputChannel);
+ spillInputChannelPositions[i] += actualBytesTransferred;
+ bytesToTransfer -= actualBytesTransferred;
+ }
+ writeMetrics.incShuffleWriteTime(System.nanoTime() -
writeStartTime);
+ bytesWrittenToMergedFile += partitionLengthInSpill;
+ partitionLengths[partition] += partitionLengthInSpill;
+ }
+ }
+ // Check the position after transferTo loop to see if it is in the
right position and raise an
+ // exception if it is incorrect. The position will not be increased
to the expected length
+ // after calling transferTo in kernel version 2.6.32. This issue is
described at
+ // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948.
+ if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) {
+ throw new IOException(
+ "Current position " + mergedFileOutputChannel.position() + "
does not equal expected " +
+ "position " + bytesWrittenToMergedFile + " after transferTo.
Please check your kernel" +
+ " version to see if it is 2.6.32, as there is a kernel bug
which will lead to " +
+ "unexpected behavior when using transferTo. You can set
spark.file.transferTo=false " +
+ "to disable this NIO feature."
+ );
+ }
+ } finally {
+ for (int i = 0; i < spills.length; i++) {
+ assert(spillInputChannelPositions[i] == spills[i].file.length());
+ Closeables.close(spillInputChannels[i], false);
--- End diff --
I think that my concern was that throwing exceptions from the finally block
would mask other exceptions, but it looks like there's a nice idiom for
handling this that's shown in the Closeables.close docs:
```
<p>Example: <pre> {@code
*
* public void useStreamNicely() throws IOException {
* SomeStream stream = new SomeStream("foo");
* boolean threw = true;
* try {
* // ... code which does something with the stream ...
* threw = false;
* } finally {
* // If an exception occurs, rethrow it only if threw==false:
* Closeables.close(stream, threw);
* }
* }}</pre>
```
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]