Github user JoshRosen commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5868#discussion_r30286052
  
    --- Diff: core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java ---
    @@ -0,0 +1,438 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.shuffle.unsafe;
    +
    +import java.io.*;
    +import java.nio.channels.FileChannel;
    +import java.util.Iterator;
    +import javax.annotation.Nullable;
    +
    +import scala.Option;
    +import scala.Product2;
    +import scala.collection.JavaConversions;
    +import scala.reflect.ClassTag;
    +import scala.reflect.ClassTag$;
    +
    +import com.google.common.annotations.VisibleForTesting;
    +import com.google.common.io.ByteStreams;
    +import com.google.common.io.Closeables;
    +import com.google.common.io.Files;
    +import org.slf4j.Logger;
    +import org.slf4j.LoggerFactory;
    +
    +import org.apache.spark.*;
    +import org.apache.spark.annotation.Private;
    +import org.apache.spark.io.CompressionCodec;
    +import org.apache.spark.io.CompressionCodec$;
    +import org.apache.spark.io.LZFCompressionCodec;
    +import org.apache.spark.executor.ShuffleWriteMetrics;
    +import org.apache.spark.network.util.LimitedInputStream;
    +import org.apache.spark.scheduler.MapStatus;
    +import org.apache.spark.scheduler.MapStatus$;
    +import org.apache.spark.serializer.SerializationStream;
    +import org.apache.spark.serializer.Serializer;
    +import org.apache.spark.serializer.SerializerInstance;
    +import org.apache.spark.shuffle.IndexShuffleBlockResolver;
    +import org.apache.spark.shuffle.ShuffleMemoryManager;
    +import org.apache.spark.shuffle.ShuffleWriter;
    +import org.apache.spark.storage.BlockManager;
    +import org.apache.spark.storage.TimeTrackingOutputStream;
    +import org.apache.spark.unsafe.PlatformDependent;
    +import org.apache.spark.unsafe.memory.TaskMemoryManager;
    +
    +@Private
    +public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
    +
    +  private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class);
    +
    +  private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
    +
    +  @VisibleForTesting
    +  static final int INITIAL_SORT_BUFFER_SIZE = 4096;
    +
    +  private final BlockManager blockManager;
    +  private final IndexShuffleBlockResolver shuffleBlockResolver;
    +  private final TaskMemoryManager memoryManager;
    +  private final ShuffleMemoryManager shuffleMemoryManager;
    +  private final SerializerInstance serializer;
    +  private final Partitioner partitioner;
    +  private final ShuffleWriteMetrics writeMetrics;
    +  private final int shuffleId;
    +  private final int mapId;
    +  private final TaskContext taskContext;
    +  private final SparkConf sparkConf;
    +  private final boolean transferToEnabled;
    +
    +  private MapStatus mapStatus = null;
    +  private UnsafeShuffleExternalSorter sorter = null;
    +
    +  /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
    +  private static final class MyByteArrayOutputStream extends ByteArrayOutputStream {
    +    public MyByteArrayOutputStream(int size) { super(size); }
    +    public byte[] getBuf() { return buf; }
    +  }
    +
    +  private MyByteArrayOutputStream serBuffer;
    +  private SerializationStream serOutputStream;
    +
    +  /**
    +   * Are we in the process of stopping? Because map tasks can call stop() with success = true
    +   * and then call stop() with success = false if they get an exception, we want to make sure
    +   * we don't try deleting files, etc. twice.
    +   */
    +  private boolean stopping = false;
    +
    +  public UnsafeShuffleWriter(
    +      BlockManager blockManager,
    +      IndexShuffleBlockResolver shuffleBlockResolver,
    +      TaskMemoryManager memoryManager,
    +      ShuffleMemoryManager shuffleMemoryManager,
    +      UnsafeShuffleHandle<K, V> handle,
    +      int mapId,
    +      TaskContext taskContext,
    +      SparkConf sparkConf) throws IOException {
    +    final int numPartitions = handle.dependency().partitioner().numPartitions();
    +    if (numPartitions > UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS()) {
    +      throw new IllegalArgumentException(
    +        "UnsafeShuffleWriter can only be used for shuffles with at most " +
    +          UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS() + " reduce partitions");
    +    }
    +    this.blockManager = blockManager;
    +    this.shuffleBlockResolver = shuffleBlockResolver;
    +    this.memoryManager = memoryManager;
    +    this.shuffleMemoryManager = shuffleMemoryManager;
    +    this.mapId = mapId;
    +    final ShuffleDependency<K, V, V> dep = handle.dependency();
    +    this.shuffleId = dep.shuffleId();
    +    this.serializer = Serializer.getSerializer(dep.serializer()).newInstance();
    +    this.partitioner = dep.partitioner();
    +    this.writeMetrics = new ShuffleWriteMetrics();
    +    taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
    +    this.taskContext = taskContext;
    +    this.sparkConf = sparkConf;
    +    this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
    +    open();
    +  }
    +
    +  /**
    +   * This convenience method should only be called in test code.
    +   */
    +  @VisibleForTesting
    +  public void write(Iterator<Product2<K, V>> records) throws IOException {
    +    write(JavaConversions.asScalaIterator(records));
    +  }
    +
    +  @Override
    +  public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
    +    boolean success = false;
    +    try {
    +      while (records.hasNext()) {
    +        insertRecordIntoSorter(records.next());
    +      }
    +      closeAndWriteOutput();
    +      success = true;
    +    } finally {
    +      if (!success) {
    +        sorter.cleanupAfterError();
    +      }
    +    }
    +  }
    +
    +  private void open() throws IOException {
    +    assert (sorter == null);
    +    sorter = new UnsafeShuffleExternalSorter(
    +      memoryManager,
    +      shuffleMemoryManager,
    +      blockManager,
    +      taskContext,
    +      INITIAL_SORT_BUFFER_SIZE,
    +      partitioner.numPartitions(),
    +      sparkConf,
    +      writeMetrics);
    +    serBuffer = new MyByteArrayOutputStream(1024 * 1024);
    +    serOutputStream = serializer.serializeStream(serBuffer);
    +  }
    +
    +  @VisibleForTesting
    +  void closeAndWriteOutput() throws IOException {
    +    serBuffer = null;
    +    serOutputStream = null;
    +    final SpillInfo[] spills = sorter.closeAndGetSpills();
    +    sorter = null;
    +    final long[] partitionLengths;
    +    try {
    +      partitionLengths = mergeSpills(spills);
    +    } finally {
    +      for (SpillInfo spill : spills) {
    +        if (spill.file.exists() && !spill.file.delete()) {
    +          logger.error("Error while deleting spill file {}", spill.file.getPath());
    +        }
    +      }
    +    }
    +    shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
    +    mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
    +  }
    +
    +  @VisibleForTesting
    +  void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
    +    final K key = record._1();
    +    final int partitionId = partitioner.getPartition(key);
    +    serBuffer.reset();
    +    serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
    +    serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
    +    serOutputStream.flush();
    +
    +    final int serializedRecordSize = serBuffer.size();
    +    assert (serializedRecordSize > 0);
    +
    +    sorter.insertRecord(
    +      serBuffer.getBuf(), PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
    +  }
    +
    +  @VisibleForTesting
    +  void forceSorterToSpill() throws IOException {
    +    assert (sorter != null);
    +    sorter.spill();
    +  }
    +
    +  /**
    +   * Merge zero or more spill files together, choosing the fastest merging strategy based on the
    +   * number of spills and the IO compression codec.
    +   *
    +   * @return the partition lengths in the merged file.
    +   */
    +  private long[] mergeSpills(SpillInfo[] spills) throws IOException {
    +    final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId);
    +    final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true);
    +    final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
    +    final boolean fastMergeEnabled =
    +      sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true);
    +    final boolean fastMergeIsSupported =
    +      !compressionEnabled || compressionCodec instanceof LZFCompressionCodec;
    +    try {
    +      if (spills.length == 0) {
    +        new FileOutputStream(outputFile).close(); // Create an empty file
    +        return new long[partitioner.numPartitions()];
    +      } else if (spills.length == 1) {
    +        // Here, we don't need to perform any metrics updates because the bytes written to this
    +        // output file would have already been counted as shuffle bytes written.
    +        Files.move(spills[0].file, outputFile);
    +        return spills[0].partitionLengths;
    +      } else {
    +        final long[] partitionLengths;
    +        // There are multiple spills to merge, so none of these spill files' lengths were counted
    +        // towards our shuffle write count or shuffle write time. If we use the slow merge path,
    +        // then the final output file's size won't necessarily be equal to the sum of the spill
    +        // files' sizes. To guard against this case, we look at the output file's actual size when
    +        // computing shuffle bytes written.
    +        //
    +        // We allow the individual merge methods to report their own IO times since different merge
    +        // strategies use different IO techniques. We count IO during merge towards the shuffle
    +        // write time, which appears to be consistent with the "not bypassing merge-sort" branch
    +        // in ExternalSorter.
    +        if (fastMergeEnabled && fastMergeIsSupported) {
    +          // Compression is disabled or we are using an IO compression codec that supports
    +          // decompression of concatenated compressed streams, so we can perform a fast spill
    +          // merge that doesn't need to interpret the spilled bytes.
    +          if (transferToEnabled) {
    +            logger.debug("Using transferTo-based fast merge");
    +            partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
    +          } else {
    +            logger.debug("Using fileStream-based fast merge");
    +            partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null);
    +          }
    +        } else {
    +          logger.debug("Using slow merge");
    +          partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec);
    +        }
    +        // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has
    +        // in-memory records, we write out the in-memory records to a file but do not count that
    +        // final write as bytes spilled (instead, it's accounted as shuffle write). The merge
    +        // needs to be counted as shuffle write, but this will lead to double-counting of the
    +        // final SpillInfo's bytes.
    +        writeMetrics.decShuffleBytesWritten(spills[spills.length - 1].file.length());
    --- End diff --
    
    If control reaches this point, then we have written multiple files to disk in our external sorter, and the last file should not be counted as bytes spilled, in order to remain consistent with the existing code. In the existing sort-merge code, the final in-memory partition is merged against the on-disk partitions and thus does not get counted as an extra disk write. Here, we end up dumping the in-memory partition to disk, incurring a bit of extra write but massively simplifying the merge path code. It's true that we're performing some disk writes that aren't accounted for anywhere in the final metrics, but I'm not sure how this should be reported or handled. I don't think it's correct to treat these bytes as spilled bytes, but treating them as shuffle write is also confusing, because then we'd end up reporting more shuffle bytes written than are actually transferred over the network.
    
    The reason that I decrement the count here is so that the final count of shuffle bytes written equals the output file's size. Alternatively, I could have avoided incrementing shuffleBytesWritten when writing the last file, but that would cause complications when updating record counts.
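    
    To make the arithmetic concrete, here is a minimal, self-contained sketch of the accounting described above. This is purely illustrative (not the PR's code) and assumes a fast merge, where the merged output is a byte-for-byte concatenation of the spill files:
    
    ```java
    /**
     * Illustrative sketch of the shuffle-write byte accounting (not the PR's
     * code). Assumes a fast merge, so the merged output's size equals the sum
     * of the spill files' sizes.
     */
    public class ShuffleWriteAccounting {
      public static long shuffleBytesWritten(long[] spillLengths) {
        // Closing the sorter flushed the remaining in-memory records to the
        // final spill file and counted that write as shuffle write, not as
        // bytes spilled.
        long finalSpill = spillLengths[spillLengths.length - 1];
        long written = finalSpill;
        // The merge then copies every spill file into the output, counting
        // each copied byte as shuffle write again...
        for (long len : spillLengths) {
          written += len;
        }
        // ...so the final spill's bytes were counted twice; decrementing once
        // makes the metric equal the merged output file's actual size.
        written -= finalSpill;
        return written;
      }
    
      public static void main(String[] args) {
        long[] spills = {100, 200, 50};  // hypothetical spill sizes in bytes
        // 100 + 200 + 50 = 350, the merged output file's size
        System.out.println(shuffleBytesWritten(spills));
      }
    }
    ```
    
    With hypothetical spill sizes of 100, 200, and 50 bytes, the metric comes out to 350 bytes, exactly the size of the merged output.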
    
    I'm happy to work on fixes for the metrics, but I'd rather defer that to a 
separate followup patch that does a broader refactoring of our metrics code.
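    
    As an aside, for readers wondering what the transferTo-based fast merge amounts to: because the codec can decompress concatenated streams, spill files can simply be appended at the byte level. A hypothetical sketch, with illustrative names (this is not the PR's mergeSpillsWithTransferTo):
    
    ```java
    import java.io.File;
    import java.io.FileInputStream;
    import java.io.FileOutputStream;
    import java.io.IOException;
    import java.nio.channels.FileChannel;
    
    /**
     * Hypothetical sketch of a transferTo-based concatenation of spill files
     * (illustrative only; not the PR's implementation).
     */
    public class TransferToConcat {
      public static long concat(File[] inputs, File output) throws IOException {
        long bytesCopied = 0;
        try (FileChannel out = new FileOutputStream(output).getChannel()) {
          for (File input : inputs) {
            try (FileChannel in = new FileInputStream(input).getChannel()) {
              long size = in.size();
              long pos = 0;
              // transferTo may copy fewer bytes than requested, so loop
              // until the whole file has been transferred.
              while (pos < size) {
                pos += in.transferTo(pos, size - pos, out);
              }
              bytesCopied += size;
            }
          }
        }
        return bytesCopied;
      }
    }
    ```
    
    The loop around transferTo matters because a single call is allowed to copy fewer bytes than requested.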

