Repository: spark
Updated Branches:
  refs/heads/master f38c76063 -> 7539ae59d


[SPARK-23366] Improve hot reading path in ReadAheadInputStream

## What changes were proposed in this pull request?

`ReadAheadInputStream` was introduced in 
https://github.com/apache/spark/pull/18317/ to optimize reading spill files 
from disk.
However, profiles show that the hot path for reading small amounts of data 
(e.g. `readInt`) is inefficient: it involves taking locks and performing 
multiple checks.

Optimize locking: no lock is needed when simply accessing the active buffer. 
Lock only when swapping buffers, triggering an async read, or querying the 
async state.
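
A minimal sketch of this locking scheme, assuming a single consumer thread so that the active buffer is only ever touched by the caller. The class name `LockOnSwapSketch` and the `lockAcquisitions` counter are hypothetical and exist only for illustration; the async refill is left out entirely:

```java
import java.nio.ByteBuffer;
import java.util.concurrent.locks.ReentrantLock;

// Hypothetical, simplified illustration; not the actual ReadAheadInputStream.
class LockOnSwapSketch {
  private final ReentrantLock stateChangeLock = new ReentrantLock();
  private ByteBuffer activeBuffer;
  private ByteBuffer readAheadBuffer;
  // Counts how often the locked slow path is entered (demonstration only).
  int lockAcquisitions = 0;

  LockOnSwapSketch(byte[] first, byte[] second) {
    activeBuffer = ByteBuffer.wrap(first);
    readAheadBuffer = ByteBuffer.wrap(second);
  }

  int read(byte[] dst, int off, int len) {
    if (len == 0) {
      return 0;
    }
    if (!activeBuffer.hasRemaining()) {
      // Slow path: shared state changes, so take the lock.
      stateChangeLock.lock();
      try {
        lockAcquisitions++;
        if (!readAheadBuffer.hasRemaining()) {
          return -1; // nothing left in either buffer
        }
        ByteBuffer tmp = activeBuffer;
        activeBuffer = readAheadBuffer;
        readAheadBuffer = tmp;
      } finally {
        stateChangeLock.unlock();
      }
    }
    // Fast path: the active buffer is owned by the reader, no lock needed.
    int n = Math.min(len, activeBuffer.remaining());
    activeBuffer.get(dst, off, n);
    return n;
  }
}
```

Reading six bytes one at a time from buffers of four and two bytes enters the locked section only once, at the swap.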

Optimize the short path for single-byte reads, which are used e.g. by the 
Java library's `DataInputStream.readInt`.
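
The short path could look roughly like the following. `SingleByteShortPathSketch` is a hypothetical stand-in that holds only an active buffer; the real stream falls back to the general `read(byte[], int, int)` path when the buffer is empty, which is reduced to an end-of-stream return here. The `& 0xFF` mask keeps a byte such as 0xFF from being confused with the -1 end-of-stream marker:

```java
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;

// Hypothetical stand-in for a stream with a single in-memory active buffer.
class SingleByteShortPathSketch extends InputStream {
  private final ByteBuffer activeBuffer;

  SingleByteShortPathSketch(byte[] data) {
    activeBuffer = ByteBuffer.wrap(data);
  }

  @Override
  public int read() {
    if (activeBuffer.hasRemaining()) {
      // Short path: one bounds check, one get, no lock, no temporary array.
      return activeBuffer.get() & 0xFF;
    }
    // Assumption for this sketch: an empty buffer means end of stream.
    return -1;
  }

  // Reads one big-endian int through DataInputStream.
  static int readIntFrom(byte[] data) {
    try (DataInputStream in = new DataInputStream(new SingleByteShortPathSketch(data))) {
      return in.readInt();
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
  }
}
```

For the bytes `FF FF FF FE`, `readIntFrom` yields -2, which shows that the masked bytes are reassembled correctly.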

The asyncReader used to call `read` only once on the underlying stream, which 
never filled the buffer when it was wrapping an LZ4BlockInputStream. 
Because the buffer was returned unfilled, the async reader was triggered to 
refill the read-ahead buffer on each call: the reader always saw the active 
buffer below the refill threshold.
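
To illustrate why a single `read` call can leave the buffer mostly empty, here is a sketch with a hypothetical `chunked` stream that returns at most a fixed number of bytes per call, standing in for a block-oriented stream such as LZ4BlockInputStream. `readOnce` mimics the old behaviour and `fill` the loop-until-full behaviour; both names are illustrative:

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;

class FillBufferSketch {
  // Hypothetical stream that hands out at most maxChunk bytes per read call.
  static InputStream chunked(byte[] data, int maxChunk) {
    return new ByteArrayInputStream(data) {
      @Override
      public synchronized int read(byte[] b, int off, int len) {
        return super.read(b, off, Math.min(len, maxChunk));
      }
    };
  }

  // Old behaviour: a single read call, which may return far fewer bytes than requested.
  static int readOnce(InputStream in, byte[] arr) {
    try {
      return in.read(arr, 0, arr.length);
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
  }

  // New behaviour: keep reading until the array is full or the stream ends.
  static int fill(InputStream in, byte[] arr) {
    int off = 0;
    int len = arr.length;
    try {
      while (len > 0) {
        int read = in.read(arr, off, len);
        if (read <= 0) {
          break;
        }
        off += read;
        len -= read;
      }
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
    return off;
  }
}
```

With 100 bytes of input served 16 bytes at a time and a 64-byte buffer, `readOnce` returns 16 while `fill` returns 64.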

However, always filling the whole buffer could increase latency, so also add 
an `AtomicBoolean` flag that lets the async reader return early if a reader 
is waiting for data.
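
A sketch of the early-return idea, assuming the flag is set by the consumer while it waits. The helper names `fill` and `tenBytesPerRead` are hypothetical. The `do`/`while` shape guarantees at least one read attempt even when a reader is already waiting:

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.util.concurrent.atomic.AtomicBoolean;

class EarlyReturnFillSketch {
  // Fills arr, but stops after the current read if a reader is waiting.
  static int fill(InputStream in, byte[] arr, AtomicBoolean isWaiting) {
    int off = 0;
    int len = arr.length;
    try {
      int read;
      do {
        read = in.read(arr, off, len);
        if (read <= 0) {
          break;
        }
        off += read;
        len -= read;
      } while (len > 0 && !isWaiting.get());
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
    return off;
  }

  // Hypothetical stream of `total` zero bytes that returns at most 10 bytes per call.
  static InputStream tenBytesPerRead(int total) {
    return new ByteArrayInputStream(new byte[total]) {
      @Override
      public synchronized int read(byte[] b, int off, int len) {
        return super.read(b, off, Math.min(len, 10));
      }
    };
  }
}
```

With a 40-byte buffer, `fill` returns 40 when the flag is clear and 10 when the flag is already set, so a waiting reader gets data after a single underlying read.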

Remove `readAheadThresholdInBytes` and instead trigger an async read 
immediately when switching the buffers. This simplifies the code paths, 
especially the hot one, which then only has to check whether data is 
available in the active buffer, without worrying about whether it needs to 
retrigger an async read. It seems to have a positive effect on performance.
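
The swap-then-refill flow can be sketched as below. This is a simplified model and not the actual implementation: it uses a `Future` instead of a lock and condition, triggers the first async read in the constructor rather than on the first `read`, and supports only single-byte reads. All names are hypothetical:

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Hypothetical, simplified model of double buffering with an eager refill.
class SwapThenRefillSketch implements AutoCloseable {
  private final InputStream underlying;
  // Daemon thread so a forgotten close() cannot keep the JVM alive.
  private final ExecutorService executor = Executors.newSingleThreadExecutor(r -> {
    Thread t = new Thread(r, "read-ahead-sketch");
    t.setDaemon(true);
    return t;
  });
  private byte[] active;
  private byte[] readAhead;
  private int pos = 0;
  private int limit = 0;
  // Completes with the number of bytes placed in readAhead, or -1 at end of stream.
  private Future<Integer> pending;

  SwapThenRefillSketch(InputStream underlying, int bufferSize) {
    this.underlying = underlying;
    this.active = new byte[bufferSize];
    this.readAhead = new byte[bufferSize];
    this.pending = readAsync();
  }

  private Future<Integer> readAsync() {
    final byte[] target = readAhead;
    return executor.submit(() -> {
      int off = 0;
      while (off < target.length) {
        int n = underlying.read(target, off, target.length - off);
        if (n <= 0) break;
        off += n;
      }
      return off == 0 ? -1 : off;
    });
  }

  int read() {
    if (pos == limit) {
      // Slow path: wait for the read-ahead buffer, swap, and refill right away.
      int filled = await(pending);
      if (filled < 0) return -1;
      byte[] tmp = active;
      active = readAhead;
      readAhead = tmp;
      pos = 0;
      limit = filled;
      pending = readAsync();
    }
    // Hot path: only a bounds check on the active buffer, no threshold logic.
    return active[pos++] & 0xFF;
  }

  private static int await(Future<Integer> f) {
    try {
      return f.get();
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new IllegalStateException(e);
    } catch (ExecutionException e) {
      throw new IllegalStateException(e.getCause());
    }
  }

  @Override
  public void close() {
    executor.shutdownNow();
  }

  // Convenience helper: drains `data` through the sketch and returns what was read.
  static byte[] readAll(byte[] data, int bufferSize) {
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    try (SwapThenRefillSketch in =
        new SwapThenRefillSketch(new ByteArrayInputStream(data), bufferSize)) {
      for (int b = in.read(); b != -1; b = in.read()) {
        out.write(b);
      }
    }
    return out.toByteArray();
  }
}
```

Because the refill is always started at the swap, the hot path never has to decide whether a refill is due.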

## How was this patch tested?

The problem was noticed as a regression in some workloads after upgrading to 
Spark 2.3.

It was particularly visible on TPCDS Q95 running on instances with fast disks 
(i3 AWS instances).
Running with profiling:
* Spark 2.2: 5.2-5.3 minutes, 9.5% in LZ4BlockInputStream.read
* Spark 2.3: 6.4-6.6 minutes, 31.1% in ReadAheadInputStream.read
* Spark 2.3 + fix: 5.3-5.4 minutes, 13.3% in ReadAheadInputStream.read (very 
slightly slower than Spark 2.2, practically within noise)

We didn't see other regressions, and many workloads in general seem to be 
faster with Spark 2.3 (we have not investigated whether that is thanks to the 
async reader or unrelated).

Author: Juliusz Sompolski <[email protected]>

Closes #20555 from juliuszsompolski/SPARK-23366.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7539ae59
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7539ae59
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7539ae59

Branch: refs/heads/master
Commit: 7539ae59d6c354c95c50528abe9ddff6972e960f
Parents: f38c760
Author: Juliusz Sompolski <[email protected]>
Authored: Thu Feb 15 17:09:06 2018 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Thu Feb 15 17:09:06 2018 +0800

----------------------------------------------------------------------
 .../apache/spark/io/ReadAheadInputStream.java   | 119 +++++++++----------
 .../unsafe/sort/UnsafeSorterSpillReader.java    |  10 +-
 .../spark/io/GenericFileInputStreamSuite.java   |  98 ++++++++-------
 .../spark/io/NioBufferedInputStreamSuite.java   |   6 +-
 .../spark/io/ReadAheadInputStreamSuite.java     |  17 ++-
 5 files changed, 133 insertions(+), 117 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7539ae59/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java
index 5b45d26..0cced9e 100644
--- a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java
+++ b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java
@@ -27,6 +27,7 @@ import java.io.InterruptedIOException;
 import java.nio.ByteBuffer;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.locks.Condition;
 import java.util.concurrent.locks.ReentrantLock;
 
@@ -78,9 +79,8 @@ public class ReadAheadInputStream extends InputStream {
   // whether there is a read ahead task running,
   private boolean isReading;
 
-  // If the remaining data size in the current buffer is below this threshold,
-  // we issue an async read from the underlying input stream.
-  private final int readAheadThresholdInBytes;
+  // whether there is a reader waiting for data.
+  private AtomicBoolean isWaiting = new AtomicBoolean(false);
 
   private final InputStream underlyingInputStream;
 
@@ -97,20 +97,13 @@ public class ReadAheadInputStream extends InputStream {
    *
    * @param inputStream The underlying input stream.
    * @param bufferSizeInBytes The buffer size.
-   * @param readAheadThresholdInBytes If the active buffer has less data than the read-ahead
-   *                                  threshold, an async read is triggered.
    */
   public ReadAheadInputStream(
-      InputStream inputStream, int bufferSizeInBytes, int readAheadThresholdInBytes) {
+      InputStream inputStream, int bufferSizeInBytes) {
     Preconditions.checkArgument(bufferSizeInBytes > 0,
         "bufferSizeInBytes should be greater than 0, but the value is " + bufferSizeInBytes);
-    Preconditions.checkArgument(readAheadThresholdInBytes > 0 &&
-            readAheadThresholdInBytes < bufferSizeInBytes,
-        "readAheadThresholdInBytes should be greater than 0 and less than bufferSizeInBytes, " +
-            "but the value is " + readAheadThresholdInBytes);
     activeBuffer = ByteBuffer.allocate(bufferSizeInBytes);
     readAheadBuffer = ByteBuffer.allocate(bufferSizeInBytes);
-    this.readAheadThresholdInBytes = readAheadThresholdInBytes;
     this.underlyingInputStream = inputStream;
     activeBuffer.flip();
     readAheadBuffer.flip();
@@ -166,12 +159,17 @@ public class ReadAheadInputStream extends InputStream {
         // in that case the reader waits for this async read to complete.
         // So there is no race condition in both the situations.
         int read = 0;
+        int off = 0, len = arr.length;
         Throwable exception = null;
         try {
-          while (true) {
-            read = underlyingInputStream.read(arr);
-            if (0 != read) break;
-          }
+          // try to fill the read ahead buffer.
+          // if a reader is waiting, possibly return early.
+          do {
+            read = underlyingInputStream.read(arr, off, len);
+            if (read <= 0) break;
+            off += read;
+            len -= read;
+          } while (len > 0 && !isWaiting.get());
         } catch (Throwable ex) {
           exception = ex;
           if (ex instanceof Error) {
@@ -181,13 +179,12 @@ public class ReadAheadInputStream extends InputStream {
           }
         } finally {
           stateChangeLock.lock();
+          readAheadBuffer.limit(off);
           if (read < 0 || (exception instanceof EOFException)) {
             endOfStream = true;
           } else if (exception != null) {
             readAborted = true;
             readException = exception;
-          } else {
-            readAheadBuffer.limit(read);
           }
           readInProgress = false;
           signalAsyncReadComplete();
@@ -230,7 +227,10 @@ public class ReadAheadInputStream extends InputStream {
 
   private void waitForAsyncReadComplete() throws IOException {
     stateChangeLock.lock();
+    isWaiting.set(true);
     try {
+      // There is only one reader, and one writer, so the writer should signal only once,
+      // but a while loop checking the wake up condition is still needed to avoid spurious wakeups.
       while (readInProgress) {
         asyncReadComplete.await();
       }
@@ -239,6 +239,7 @@ public class ReadAheadInputStream extends InputStream {
       iio.initCause(e);
       throw iio;
     } finally {
+      isWaiting.set(false);
       stateChangeLock.unlock();
     }
     checkReadException();
@@ -246,8 +247,13 @@ public class ReadAheadInputStream extends InputStream {
 
   @Override
   public int read() throws IOException {
-    byte[] oneByteArray = oneByte.get();
-    return read(oneByteArray, 0, 1) == -1 ? -1 : oneByteArray[0] & 0xFF;
+    if (activeBuffer.hasRemaining()) {
+      // short path - just get one byte.
+      return activeBuffer.get() & 0xFF;
+    } else {
+      byte[] oneByteArray = oneByte.get();
+      return read(oneByteArray, 0, 1) == -1 ? -1 : oneByteArray[0] & 0xFF;
+    }
   }
 
   @Override
@@ -258,54 +264,43 @@ public class ReadAheadInputStream extends InputStream {
     if (len == 0) {
       return 0;
     }
-    stateChangeLock.lock();
-    try {
-      return readInternal(b, offset, len);
-    } finally {
-      stateChangeLock.unlock();
-    }
-  }
 
-  /**
-   * flip the active and read ahead buffer
-   */
-  private void swapBuffers() {
-    ByteBuffer temp = activeBuffer;
-    activeBuffer = readAheadBuffer;
-    readAheadBuffer = temp;
-  }
-
-  /**
-   * Internal read function which should be called only from read() api. The assumption is that
-   * the stateChangeLock is already acquired in the caller before calling this function.
-   */
-  private int readInternal(byte[] b, int offset, int len) throws IOException {
-    assert (stateChangeLock.isLocked());
     if (!activeBuffer.hasRemaining()) {
-      waitForAsyncReadComplete();
-      if (readAheadBuffer.hasRemaining()) {
-        swapBuffers();
-      } else {
-        // The first read or activeBuffer is skipped.
-        readAsync();
+      // No remaining in active buffer - lock and switch to write ahead buffer.
+      stateChangeLock.lock();
+      try {
         waitForAsyncReadComplete();
-        if (isEndOfStream()) {
-          return -1;
+        if (!readAheadBuffer.hasRemaining()) {
+          // The first read.
+          readAsync();
+          waitForAsyncReadComplete();
+          if (isEndOfStream()) {
+            return -1;
+          }
         }
+        // Swap the newly read read ahead buffer in place of empty active buffer.
         swapBuffers();
+        // After swapping buffers, trigger another async read for read ahead buffer.
+        readAsync();
+      } finally {
+        stateChangeLock.unlock();
       }
-    } else {
-      checkReadException();
     }
     len = Math.min(len, activeBuffer.remaining());
     activeBuffer.get(b, offset, len);
 
-    if (activeBuffer.remaining() <= readAheadThresholdInBytes && !readAheadBuffer.hasRemaining()) {
-      readAsync();
-    }
     return len;
   }
 
+  /**
+   * flip the active and read ahead buffer
+   */
+  private void swapBuffers() {
+    ByteBuffer temp = activeBuffer;
+    activeBuffer = readAheadBuffer;
+    readAheadBuffer = temp;
+  }
+
   @Override
   public int available() throws IOException {
     stateChangeLock.lock();
@@ -323,6 +318,11 @@ public class ReadAheadInputStream extends InputStream {
     if (n <= 0L) {
       return 0L;
     }
+    if (n <= activeBuffer.remaining()) {
+      // Only skipping from active buffer is sufficient
+      activeBuffer.position((int) n + activeBuffer.position());
+      return n;
+    }
     stateChangeLock.lock();
     long skipped;
     try {
@@ -346,21 +346,14 @@ public class ReadAheadInputStream extends InputStream {
     if (available() >= n) {
       // we can skip from the internal buffers
       int toSkip = (int) n;
-      if (toSkip <= activeBuffer.remaining()) {
-        // Only skipping from active buffer is sufficient
-        activeBuffer.position(toSkip + activeBuffer.position());
-        if (activeBuffer.remaining() <= readAheadThresholdInBytes
-            && !readAheadBuffer.hasRemaining()) {
-          readAsync();
-        }
-        return n;
-      }
       // We need to skip from both active buffer and read ahead buffer
       toSkip -= activeBuffer.remaining();
+      assert(toSkip > 0); // skipping from activeBuffer already handled.
       activeBuffer.position(0);
       activeBuffer.flip();
       readAheadBuffer.position(toSkip + readAheadBuffer.position());
       swapBuffers();
+      // Trigger async read to emptied read ahead buffer.
       readAsync();
       return n;
     } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/7539ae59/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index 2c53c8d..fb179d0 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -72,21 +72,15 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
       bufferSizeBytes = DEFAULT_BUFFER_SIZE_BYTES;
     }
 
-    final double readAheadFraction =
-        SparkEnv.get() == null ? 0.5 :
-             SparkEnv.get().conf().getDouble("spark.unsafe.sorter.spill.read.ahead.fraction", 0.5);
-
-    // SPARK-23310: Disable read-ahead input stream, because it is causing lock contention and perf
-    // regression for TPC-DS queries.
     final boolean readAheadEnabled = SparkEnv.get() != null &&
-        SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", false);
+        SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", true);
 
     final InputStream bs =
         new NioBufferedFileInputStream(file, (int) bufferSizeBytes);
     try {
       if (readAheadEnabled) {
         this.in = new ReadAheadInputStream(serializerManager.wrapStream(blockId, bs),
-                (int) bufferSizeBytes, (int) (bufferSizeBytes * readAheadFraction));
+                (int) bufferSizeBytes);
       } else {
         this.in = serializerManager.wrapStream(blockId, bs);
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/7539ae59/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java
index 3440e1a..22db359 100644
--- a/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java
+++ b/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java
@@ -37,7 +37,7 @@ public abstract class GenericFileInputStreamSuite {
 
   protected File inputFile;
 
-  protected InputStream inputStream;
+  protected InputStream[] inputStreams;
 
   @Before
   public void setUp() throws IOException {
@@ -54,77 +54,91 @@ public abstract class GenericFileInputStreamSuite {
 
   @Test
   public void testReadOneByte() throws IOException {
-    for (int i = 0; i < randomBytes.length; i++) {
-      assertEquals(randomBytes[i], (byte) inputStream.read());
+    for (InputStream inputStream: inputStreams) {
+      for (int i = 0; i < randomBytes.length; i++) {
+        assertEquals(randomBytes[i], (byte) inputStream.read());
+      }
     }
   }
 
   @Test
   public void testReadMultipleBytes() throws IOException {
-    byte[] readBytes = new byte[8 * 1024];
-    int i = 0;
-    while (i < randomBytes.length) {
-      int read = inputStream.read(readBytes, 0, 8 * 1024);
-      for (int j = 0; j < read; j++) {
-        assertEquals(randomBytes[i], readBytes[j]);
-        i++;
+    for (InputStream inputStream: inputStreams) {
+      byte[] readBytes = new byte[8 * 1024];
+      int i = 0;
+      while (i < randomBytes.length) {
+        int read = inputStream.read(readBytes, 0, 8 * 1024);
+        for (int j = 0; j < read; j++) {
+          assertEquals(randomBytes[i], readBytes[j]);
+          i++;
+        }
       }
     }
   }
 
   @Test
   public void testBytesSkipped() throws IOException {
-    assertEquals(1024, inputStream.skip(1024));
-    for (int i = 1024; i < randomBytes.length; i++) {
-      assertEquals(randomBytes[i], (byte) inputStream.read());
+    for (InputStream inputStream: inputStreams) {
+      assertEquals(1024, inputStream.skip(1024));
+      for (int i = 1024; i < randomBytes.length; i++) {
+        assertEquals(randomBytes[i], (byte) inputStream.read());
+      }
     }
   }
 
   @Test
   public void testBytesSkippedAfterRead() throws IOException {
-    for (int i = 0; i < 1024; i++) {
-      assertEquals(randomBytes[i], (byte) inputStream.read());
-    }
-    assertEquals(1024, inputStream.skip(1024));
-    for (int i = 2048; i < randomBytes.length; i++) {
-      assertEquals(randomBytes[i], (byte) inputStream.read());
+    for (InputStream inputStream: inputStreams) {
+      for (int i = 0; i < 1024; i++) {
+        assertEquals(randomBytes[i], (byte) inputStream.read());
+      }
+      assertEquals(1024, inputStream.skip(1024));
+      for (int i = 2048; i < randomBytes.length; i++) {
+        assertEquals(randomBytes[i], (byte) inputStream.read());
+      }
     }
   }
 
   @Test
   public void testNegativeBytesSkippedAfterRead() throws IOException {
-    for (int i = 0; i < 1024; i++) {
-      assertEquals(randomBytes[i], (byte) inputStream.read());
-    }
-    // Skipping negative bytes should essential be a no-op
-    assertEquals(0, inputStream.skip(-1));
-    assertEquals(0, inputStream.skip(-1024));
-    assertEquals(0, inputStream.skip(Long.MIN_VALUE));
-    assertEquals(1024, inputStream.skip(1024));
-    for (int i = 2048; i < randomBytes.length; i++) {
-      assertEquals(randomBytes[i], (byte) inputStream.read());
+    for (InputStream inputStream: inputStreams) {
+      for (int i = 0; i < 1024; i++) {
+        assertEquals(randomBytes[i], (byte) inputStream.read());
+      }
+      // Skipping negative bytes should essential be a no-op
+      assertEquals(0, inputStream.skip(-1));
+      assertEquals(0, inputStream.skip(-1024));
+      assertEquals(0, inputStream.skip(Long.MIN_VALUE));
+      assertEquals(1024, inputStream.skip(1024));
+      for (int i = 2048; i < randomBytes.length; i++) {
+        assertEquals(randomBytes[i], (byte) inputStream.read());
+      }
     }
   }
 
   @Test
   public void testSkipFromFileChannel() throws IOException {
-    // Since the buffer is smaller than the skipped bytes, this will guarantee
-    // we skip from underlying file channel.
-    assertEquals(1024, inputStream.skip(1024));
-    for (int i = 1024; i < 2048; i++) {
-      assertEquals(randomBytes[i], (byte) inputStream.read());
-    }
-    assertEquals(256, inputStream.skip(256));
-    assertEquals(256, inputStream.skip(256));
-    assertEquals(512, inputStream.skip(512));
-    for (int i = 3072; i < randomBytes.length; i++) {
-      assertEquals(randomBytes[i], (byte) inputStream.read());
+    for (InputStream inputStream: inputStreams) {
+      // Since the buffer is smaller than the skipped bytes, this will guarantee
+      // we skip from underlying file channel.
+      assertEquals(1024, inputStream.skip(1024));
+      for (int i = 1024; i < 2048; i++) {
+        assertEquals(randomBytes[i], (byte) inputStream.read());
+      }
+      assertEquals(256, inputStream.skip(256));
+      assertEquals(256, inputStream.skip(256));
+      assertEquals(512, inputStream.skip(512));
+      for (int i = 3072; i < randomBytes.length; i++) {
+        assertEquals(randomBytes[i], (byte) inputStream.read());
+      }
     }
   }
 
   @Test
   public void testBytesSkippedAfterEOF() throws IOException {
-    assertEquals(randomBytes.length, inputStream.skip(randomBytes.length + 1));
-    assertEquals(-1, inputStream.read());
+    for (InputStream inputStream: inputStreams) {
+      assertEquals(randomBytes.length, inputStream.skip(randomBytes.length + 1));
+      assertEquals(-1, inputStream.read());
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/7539ae59/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java
index 211b33a..a320f86 100644
--- a/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java
+++ b/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java
@@ -18,6 +18,7 @@ package org.apache.spark.io;
 
 import org.junit.Before;
 
+import java.io.InputStream;
 import java.io.IOException;
 
 /**
@@ -28,6 +29,9 @@ public class NioBufferedInputStreamSuite extends GenericFileInputStreamSuite {
   @Before
   public void setUp() throws IOException {
     super.setUp();
-    inputStream = new NioBufferedFileInputStream(inputFile);
+    inputStreams = new InputStream[] {
+      new NioBufferedFileInputStream(inputFile), // default
+      new NioBufferedFileInputStream(inputFile, 123) // small, unaligned buffer
+    };
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/7539ae59/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java
index 918ddc4..bfa1e0b 100644
--- a/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java
+++ b/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java
@@ -19,16 +19,27 @@ package org.apache.spark.io;
 import org.junit.Before;
 
 import java.io.IOException;
+import java.io.InputStream;
 
 /**
- * Tests functionality of {@link NioBufferedFileInputStream}
+ * Tests functionality of {@link ReadAheadInputStreamSuite}
  */
 public class ReadAheadInputStreamSuite extends GenericFileInputStreamSuite {
 
   @Before
   public void setUp() throws IOException {
     super.setUp();
-    inputStream = new ReadAheadInputStream(
-        new NioBufferedFileInputStream(inputFile), 8 * 1024, 4 * 1024);
+    inputStreams = new InputStream[] {
+      // Tests equal and aligned buffers of wrapped an outer stream.
+      new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 8 * 1024), 8 * 1024),
+      // Tests aligned buffers, wrapped bigger than outer.
+      new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 3 * 1024), 2 * 1024),
+      // Tests aligned buffers, wrapped smaller than outer.
+      new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 2 * 1024), 3 * 1024),
+      // Tests unaligned buffers, wrapped bigger than outer.
+      new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 321), 123),
+      // Tests unaligned buffers, wrapped smaller than outer.
+      new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile, 123), 321)
+    };
   }
 }

