Repository: spark
Updated Branches:
  refs/heads/master e3bf37fa3 -> c7ac027d5


[SPARK-17839][CORE] Use NIO's direct buffer instead of BufferedInputStream in 
order to avoid an additional copy from the OS buffer cache to the user buffer

## What changes were proposed in this pull request?

Currently we use BufferedInputStream to read the shuffle file, which copies the 
file content from the OS buffer cache to the user buffer. This adds additional 
latency when reading the spill files. We changed this to use Java NIO's direct 
buffer to read the spill files; for certain pipelines that spill a significant 
amount of data, we see up to a 7% speedup for the entire pipeline.

## How was this patch tested?
Tested by running the job on a cluster; observed up to a 7% speedup.

Author: Sital Kedia <ske...@fb.com>

Closes #15408 from sitalkedia/skedia/nio_spill_read.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c7ac027d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c7ac027d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c7ac027d

Branch: refs/heads/master
Commit: c7ac027d5fd7a80d3122a9269b2bb9c28c6a57db
Parents: e3bf37f
Author: Sital Kedia <ske...@fb.com>
Authored: Mon Oct 17 11:03:04 2016 -0700
Committer: Shixiong Zhu <shixi...@databricks.com>
Committed: Mon Oct 17 11:03:04 2016 -0700

----------------------------------------------------------------------
 .../spark/io/NioBufferedFileInputStream.java    | 137 +++++++++++++++++++
 .../unsafe/sort/UnsafeSorterSpillReader.java    |   5 +-
 .../shuffle/IndexShuffleBlockResolver.scala     |   3 +-
 .../io/NioBufferedFileInputStreamSuite.java     | 135 ++++++++++++++++++
 .../spark/sql/execution/python/RowQueue.scala   |   3 +-
 5 files changed, 279 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c7ac027d/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java
----------------------------------------------------------------------
diff --git 
a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java 
b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java
new file mode 100644
index 0000000..f6d1288
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.io;
+
+import org.apache.spark.storage.StorageUtils;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.file.StandardOpenOption;
+
+/**
+ * {@link InputStream} implementation which uses direct buffer
+ * to read a file to avoid extra copy of data between Java and
+ * native memory which happens when using {@link java.io.BufferedInputStream}.
+ * Unfortunately, this is not something already available in JDK,
+ * {@link sun.nio.ch.ChannelInputStream} supports reading a file using nio,
+ * but does not support buffering.
+ */
+public final class NioBufferedFileInputStream extends InputStream {
+
+  private static final int DEFAULT_BUFFER_SIZE_BYTES = 8192;
+
+  private final ByteBuffer byteBuffer;
+
+  private final FileChannel fileChannel;
+
+  public NioBufferedFileInputStream(File file, int bufferSizeInBytes) throws 
IOException {
+    byteBuffer = ByteBuffer.allocateDirect(bufferSizeInBytes);
+    fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ);
+    byteBuffer.flip();
+  }
+
+  public NioBufferedFileInputStream(File file) throws IOException {
+    this(file, DEFAULT_BUFFER_SIZE_BYTES);
+  }
+
+  /**
+   * Checks weather data is left to be read from the input stream.
+   * @return true if data is left, false otherwise
+   * @throws IOException
+   */
+  private boolean refill() throws IOException {
+    if (!byteBuffer.hasRemaining()) {
+      byteBuffer.clear();
+      int nRead = 0;
+      while (nRead == 0) {
+        nRead = fileChannel.read(byteBuffer);
+      }
+      if (nRead < 0) {
+        return false;
+      }
+      byteBuffer.flip();
+    }
+    return true;
+  }
+
+  @Override
+  public synchronized int read() throws IOException {
+    if (!refill()) {
+      return -1;
+    }
+    return byteBuffer.get() & 0xFF;
+  }
+
+  @Override
+  public synchronized int read(byte[] b, int offset, int len) throws 
IOException {
+    if (offset < 0 || len < 0 || offset + len < 0 || offset + len > b.length) {
+      throw new IndexOutOfBoundsException();
+    }
+    if (!refill()) {
+      return -1;
+    }
+    len = Math.min(len, byteBuffer.remaining());
+    byteBuffer.get(b, offset, len);
+    return len;
+  }
+
+  @Override
+  public synchronized int available() throws IOException {
+    return byteBuffer.remaining();
+  }
+
+  @Override
+  public synchronized long skip(long n) throws IOException {
+    if (n <= 0L) {
+      return 0L;
+    }
+    if (byteBuffer.remaining() >= n) {
+      // The buffered content is enough to skip
+      byteBuffer.position(byteBuffer.position() + (int) n);
+      return n;
+    }
+    long skippedFromBuffer = byteBuffer.remaining();
+    long toSkipFromFileChannel = n - skippedFromBuffer;
+    // Discard everything we have read in the buffer.
+    byteBuffer.position(0);
+    byteBuffer.flip();
+    return skippedFromBuffer + skipFromFileChannel(toSkipFromFileChannel);
+  }
+
+  private long skipFromFileChannel(long n) throws IOException {
+    long currentFilePosition = fileChannel.position();
+    long size = fileChannel.size();
+    if (n > size - currentFilePosition) {
+      fileChannel.position(size);
+      return size - currentFilePosition;
+    } else {
+      fileChannel.position(currentFilePosition + n);
+      return n;
+    }
+  }
+
+  @Override
+  public synchronized void close() throws IOException {
+    fileChannel.close();
+    StorageUtils.dispose(byteBuffer);
+  }
+
+  @Override
+  protected void finalize() throws IOException {
+    close();
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/c7ac027d/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
----------------------------------------------------------------------
diff --git 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index e6d9766..a658e5e 100644
--- 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -23,6 +23,7 @@ import com.google.common.io.ByteStreams;
 import com.google.common.io.Closeables;
 
 import org.apache.spark.SparkEnv;
+import org.apache.spark.io.NioBufferedFileInputStream;
 import org.apache.spark.serializer.SerializerManager;
 import org.apache.spark.storage.BlockId;
 import org.apache.spark.unsafe.Platform;
@@ -69,8 +70,8 @@ public final class UnsafeSorterSpillReader extends 
UnsafeSorterIterator implemen
       bufferSizeBytes = DEFAULT_BUFFER_SIZE_BYTES;
     }
 
-    final BufferedInputStream bs =
-        new BufferedInputStream(new FileInputStream(file), (int) 
bufferSizeBytes);
+    final InputStream bs =
+        new NioBufferedFileInputStream(file, (int) bufferSizeBytes);
     try {
       this.in = serializerManager.wrapStream(blockId, bs);
       this.din = new DataInputStream(this.in);

http://git-wip-us.apache.org/repos/asf/spark/blob/c7ac027d/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala 
b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index 8d6396b..91858f0 100644
--- 
a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ 
b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -23,6 +23,7 @@ import com.google.common.io.ByteStreams
 
 import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.internal.Logging
+import org.apache.spark.io.NioBufferedFileInputStream
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, 
ManagedBuffer}
 import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
@@ -89,7 +90,7 @@ private[spark] class IndexShuffleBlockResolver(
     val lengths = new Array[Long](blocks)
     // Read the lengths of blocks
     val in = try {
-      new DataInputStream(new BufferedInputStream(new FileInputStream(index)))
+      new DataInputStream(new NioBufferedFileInputStream(index))
     } catch {
       case e: IOException =>
         return null

http://git-wip-us.apache.org/repos/asf/spark/blob/c7ac027d/core/src/test/java/org/apache/spark/io/NioBufferedFileInputStreamSuite.java
----------------------------------------------------------------------
diff --git 
a/core/src/test/java/org/apache/spark/io/NioBufferedFileInputStreamSuite.java 
b/core/src/test/java/org/apache/spark/io/NioBufferedFileInputStreamSuite.java
new file mode 100644
index 0000000..2c1a34a
--- /dev/null
+++ 
b/core/src/test/java/org/apache/spark/io/NioBufferedFileInputStreamSuite.java
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.io;
+
+import org.apache.commons.io.FileUtils;
+import org.apache.commons.lang3.RandomUtils;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests functionality of {@link NioBufferedFileInputStream}
+ */
+public class NioBufferedFileInputStreamSuite {
+
+  private byte[] randomBytes;
+
+  private File inputFile;
+
+  @Before
+  public void setUp() throws IOException {
+    // Create a byte array of size 2 MB with random bytes
+    randomBytes = RandomUtils.nextBytes(2 * 1024 * 1024);
+    inputFile = File.createTempFile("temp-file", ".tmp");
+    FileUtils.writeByteArrayToFile(inputFile, randomBytes);
+  }
+
+  @After
+  public void tearDown() {
+    inputFile.delete();
+  }
+
+  // Each test closes its stream (try-with-resources) so the file handle and direct
+  // buffer are released before tearDown deletes the file.
+
+  @Test
+  public void testReadOneByte() throws IOException {
+    try (InputStream inputStream = new NioBufferedFileInputStream(inputFile)) {
+      for (int i = 0; i < randomBytes.length; i++) {
+        assertEquals(randomBytes[i], (byte) inputStream.read());
+      }
+    }
+  }
+
+  @Test
+  public void testReadMultipleBytes() throws IOException {
+    try (InputStream inputStream = new NioBufferedFileInputStream(inputFile)) {
+      byte[] readBytes = new byte[8 * 1024];
+      int i = 0;
+      while (i < randomBytes.length) {
+        int read = inputStream.read(readBytes, 0, 8 * 1024);
+        // Without this check, a premature EOF (-1) would leave i unchanged and hang the test.
+        org.junit.Assert.assertTrue("Unexpected end of stream at offset " + i, read > 0);
+        for (int j = 0; j < read; j++) {
+          assertEquals(randomBytes[i], readBytes[j]);
+          i++;
+        }
+      }
+    }
+  }
+
+  @Test
+  public void testBytesSkipped() throws IOException {
+    try (InputStream inputStream = new NioBufferedFileInputStream(inputFile)) {
+      assertEquals(1024, inputStream.skip(1024));
+      for (int i = 1024; i < randomBytes.length; i++) {
+        assertEquals(randomBytes[i], (byte) inputStream.read());
+      }
+    }
+  }
+
+  @Test
+  public void testBytesSkippedAfterRead() throws IOException {
+    try (InputStream inputStream = new NioBufferedFileInputStream(inputFile)) {
+      for (int i = 0; i < 1024; i++) {
+        assertEquals(randomBytes[i], (byte) inputStream.read());
+      }
+      assertEquals(1024, inputStream.skip(1024));
+      for (int i = 2048; i < randomBytes.length; i++) {
+        assertEquals(randomBytes[i], (byte) inputStream.read());
+      }
+    }
+  }
+
+  @Test
+  public void testNegativeBytesSkippedAfterRead() throws IOException {
+    try (InputStream inputStream = new NioBufferedFileInputStream(inputFile)) {
+      for (int i = 0; i < 1024; i++) {
+        assertEquals(randomBytes[i], (byte) inputStream.read());
+      }
+      // Skipping negative bytes should essentially be a no-op
+      assertEquals(0, inputStream.skip(-1));
+      assertEquals(0, inputStream.skip(-1024));
+      assertEquals(0, inputStream.skip(Long.MIN_VALUE));
+      assertEquals(1024, inputStream.skip(1024));
+      for (int i = 2048; i < randomBytes.length; i++) {
+        assertEquals(randomBytes[i], (byte) inputStream.read());
+      }
+    }
+  }
+
+  @Test
+  public void testSkipFromFileChannel() throws IOException {
+    try (InputStream inputStream = new NioBufferedFileInputStream(inputFile, 10)) {
+      // Since the buffer is smaller than the skipped bytes, this will guarantee
+      // we skip from underlying file channel.
+      assertEquals(1024, inputStream.skip(1024));
+      for (int i = 1024; i < 2048; i++) {
+        assertEquals(randomBytes[i], (byte) inputStream.read());
+      }
+      assertEquals(256, inputStream.skip(256));
+      assertEquals(256, inputStream.skip(256));
+      assertEquals(512, inputStream.skip(512));
+      for (int i = 3072; i < randomBytes.length; i++) {
+        assertEquals(randomBytes[i], (byte) inputStream.read());
+      }
+    }
+  }
+
+  @Test
+  public void testBytesSkippedAfterEOF() throws IOException {
+    try (InputStream inputStream = new NioBufferedFileInputStream(inputFile)) {
+      assertEquals(randomBytes.length, inputStream.skip(randomBytes.length + 1));
+      assertEquals(-1, inputStream.read());
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/c7ac027d/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
index 422a3f8..cd1e77f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
@@ -22,6 +22,7 @@ import java.io._
 import com.google.common.io.Closeables
 
 import org.apache.spark.SparkException
+import org.apache.spark.io.NioBufferedFileInputStream
 import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.unsafe.Platform
@@ -130,7 +131,7 @@ private[python] case class DiskRowQueue(file: File, fields: 
Int) extends RowQueu
     if (out != null) {
       out.close()
       out = null
-      in = new DataInputStream(new BufferedInputStream(new 
FileInputStream(file.toString)))
+      in = new DataInputStream(new NioBufferedFileInputStream(file))
     }
 
     if (unreadBytes > 0) {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to