Repository: spark
Updated Branches:
  refs/heads/master deb9588b2 -> 3eee9e024


[SPARK-25535][CORE] Work around bad error handling in commons-crypto.

The commons-crypto library does some questionable error handling internally,
which can lead to JVM crashes if some call into native code fails and cleans
up state it should not.

Until the library is fixed, this change adds some workarounds in Spark code
so that when an error is detected on the commons-crypto side, Spark avoids
calling into the library further.

Tested with existing and added unit tests.

Closes #22557 from vanzin/SPARK-25535.

Authored-by: Marcelo Vanzin <van...@cloudera.com>
Signed-off-by: Imran Rashid <iras...@cloudera.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3eee9e02
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3eee9e02
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3eee9e02

Branch: refs/heads/master
Commit: 3eee9e02463e10570a29fad00823c953debd945e
Parents: deb9588
Author: Marcelo Vanzin <van...@cloudera.com>
Authored: Tue Oct 9 09:27:08 2018 -0500
Committer: Imran Rashid <iras...@cloudera.com>
Committed: Tue Oct 9 09:27:08 2018 -0500

----------------------------------------------------------------------
 .../apache/spark/network/crypto/AuthEngine.java |  95 ++++++++-----
 .../spark/network/crypto/TransportCipher.java   |  60 ++++++--
 .../spark/network/crypto/AuthEngineSuite.java   |  17 +++
 .../spark/security/CryptoStreamUtils.scala      | 137 +++++++++++++++++--
 .../spark/security/CryptoStreamUtilsSuite.scala |  37 ++++-
 5 files changed, 295 insertions(+), 51 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3eee9e02/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java
----------------------------------------------------------------------
diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java
 
b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java
index 056505e..64fdb32 100644
--- 
a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java
+++ 
b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java
@@ -159,15 +159,21 @@ class AuthEngine implements Closeable {
     // accurately report the errors when they happen.
     RuntimeException error = null;
     byte[] dummy = new byte[8];
-    try {
-      doCipherOp(encryptor, dummy, true);
-    } catch (Exception e) {
-      error = new RuntimeException(e);
+    if (encryptor != null) {
+      try {
+        doCipherOp(Cipher.ENCRYPT_MODE, dummy, true);
+      } catch (Exception e) {
+        error = new RuntimeException(e);
+      }
+      encryptor = null;
     }
-    try {
-      doCipherOp(decryptor, dummy, true);
-    } catch (Exception e) {
-      error = new RuntimeException(e);
+    if (decryptor != null) {
+      try {
+        doCipherOp(Cipher.DECRYPT_MODE, dummy, true);
+      } catch (Exception e) {
+        error = new RuntimeException(e);
+      }
+      decryptor = null;
     }
     random.close();
 
@@ -189,11 +195,11 @@ class AuthEngine implements Closeable {
   }
 
   private byte[] decrypt(byte[] in) throws GeneralSecurityException {
-    return doCipherOp(decryptor, in, false);
+    return doCipherOp(Cipher.DECRYPT_MODE, in, false);
   }
 
   private byte[] encrypt(byte[] in) throws GeneralSecurityException {
-    return doCipherOp(encryptor, in, false);
+    return doCipherOp(Cipher.ENCRYPT_MODE, in, false);
   }
 
   private void initializeForAuth(String cipher, byte[] nonce, SecretKeySpec 
key)
@@ -205,11 +211,13 @@ class AuthEngine implements Closeable {
     byte[] iv = new byte[conf.ivLength()];
     System.arraycopy(nonce, 0, iv, 0, Math.min(nonce.length, iv.length));
 
-    encryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf);
-    encryptor.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(iv));
+    CryptoCipher _encryptor = CryptoCipherFactory.getCryptoCipher(cipher, 
cryptoConf);
+    _encryptor.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(iv));
+    this.encryptor = _encryptor;
 
-    decryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf);
-    decryptor.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(iv));
+    CryptoCipher _decryptor = CryptoCipherFactory.getCryptoCipher(cipher, 
cryptoConf);
+    _decryptor.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(iv));
+    this.decryptor = _decryptor;
   }
 
   /**
@@ -241,29 +249,52 @@ class AuthEngine implements Closeable {
     return new SecretKeySpec(key.getEncoded(), conf.keyAlgorithm());
   }
 
-  private byte[] doCipherOp(CryptoCipher cipher, byte[] in, boolean isFinal)
+  private byte[] doCipherOp(int mode, byte[] in, boolean isFinal)
     throws GeneralSecurityException {
 
-    Preconditions.checkState(cipher != null);
+    CryptoCipher cipher;
+    switch (mode) {
+      case Cipher.ENCRYPT_MODE:
+        cipher = encryptor;
+        break;
+      case Cipher.DECRYPT_MODE:
+        cipher = decryptor;
+        break;
+      default:
+        throw new IllegalArgumentException(String.valueOf(mode));
+    }
 
-    int scale = 1;
-    while (true) {
-      int size = in.length * scale;
-      byte[] buffer = new byte[size];
-      try {
-        int outSize = isFinal ? cipher.doFinal(in, 0, in.length, buffer, 0)
-          : cipher.update(in, 0, in.length, buffer, 0);
-        if (outSize != buffer.length) {
-          byte[] output = new byte[outSize];
-          System.arraycopy(buffer, 0, output, 0, output.length);
-          return output;
-        } else {
-          return buffer;
+    Preconditions.checkState(cipher != null, "Cipher is invalid because of 
previous error.");
+
+    try {
+      int scale = 1;
+      while (true) {
+        int size = in.length * scale;
+        byte[] buffer = new byte[size];
+        try {
+          int outSize = isFinal ? cipher.doFinal(in, 0, in.length, buffer, 0)
+            : cipher.update(in, 0, in.length, buffer, 0);
+          if (outSize != buffer.length) {
+            byte[] output = new byte[outSize];
+            System.arraycopy(buffer, 0, output, 0, output.length);
+            return output;
+          } else {
+            return buffer;
+          }
+        } catch (ShortBufferException e) {
+          // Try again with a bigger buffer.
+          scale *= 2;
         }
-      } catch (ShortBufferException e) {
-        // Try again with a bigger buffer.
-        scale *= 2;
       }
+    } catch (InternalError ie) {
+      // SPARK-25535. The commons-crypto library will throw InternalError if something goes wrong,
+      // and leave bad state behind in the Java wrappers, so it's not safe to use them afterwards.
+      if (mode == Cipher.ENCRYPT_MODE) {
+        this.encryptor = null;
+      } else {
+        this.decryptor = null;
+      }
+      throw ie;
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/3eee9e02/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
----------------------------------------------------------------------
diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
 
b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
index b64e4b7..2745052 100644
--- 
a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
+++ 
b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
@@ -107,45 +107,72 @@ public class TransportCipher {
   private static class EncryptionHandler extends ChannelOutboundHandlerAdapter 
{
     private final ByteArrayWritableChannel byteChannel;
     private final CryptoOutputStream cos;
+    private boolean isCipherValid;
 
     EncryptionHandler(TransportCipher cipher) throws IOException {
       byteChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE);
       cos = cipher.createOutputStream(byteChannel);
+      isCipherValid = true;
     }
 
     @Override
     public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise 
promise)
       throws Exception {
-      ctx.write(new EncryptedMessage(cos, msg, byteChannel), promise);
+      ctx.write(new EncryptedMessage(this, cos, msg, byteChannel), promise);
     }
 
     @Override
     public void close(ChannelHandlerContext ctx, ChannelPromise promise) 
throws Exception {
       try {
-        cos.close();
+        if (isCipherValid) {
+          cos.close();
+        }
       } finally {
         super.close(ctx, promise);
       }
     }
+
+    /**
+     * SPARK-25535. Workaround for CRYPTO-141. Avoid further interaction with the underlying cipher
+     * after an error occurs.
+     */
+    void reportError() {
+      this.isCipherValid = false;
+    }
+
+    boolean isCipherValid() {
+      return isCipherValid;
+    }
   }
 
   private static class DecryptionHandler extends ChannelInboundHandlerAdapter {
     private final CryptoInputStream cis;
     private final ByteArrayReadableChannel byteChannel;
+    private boolean isCipherValid;
 
     DecryptionHandler(TransportCipher cipher) throws IOException {
       byteChannel = new ByteArrayReadableChannel();
       cis = cipher.createInputStream(byteChannel);
+      isCipherValid = true;
     }
 
     @Override
     public void channelRead(ChannelHandlerContext ctx, Object data) throws 
Exception {
+      if (!isCipherValid) {
+        throw new IOException("Cipher is in invalid state.");
+      }
       byteChannel.feedData((ByteBuf) data);
 
       byte[] decryptedData = new byte[byteChannel.readableBytes()];
       int offset = 0;
       while (offset < decryptedData.length) {
-        offset += cis.read(decryptedData, offset, decryptedData.length - 
offset);
+        // SPARK-25535: workaround for CRYPTO-141.
+        try {
+          offset += cis.read(decryptedData, offset, decryptedData.length - 
offset);
+        } catch (InternalError ie) {
+          isCipherValid = false;
+          throw ie;
+        }
       }
 
       ctx.fireChannelRead(Unpooled.wrappedBuffer(decryptedData, 0, 
decryptedData.length));
@@ -154,7 +181,9 @@ public class TransportCipher {
     @Override
     public void channelInactive(ChannelHandlerContext ctx) throws Exception {
       try {
-        cis.close();
+        if (isCipherValid) {
+          cis.close();
+        }
       } finally {
         super.channelInactive(ctx);
       }
@@ -165,8 +194,9 @@ public class TransportCipher {
     private final boolean isByteBuf;
     private final ByteBuf buf;
     private final FileRegion region;
+    private final CryptoOutputStream cos;
+    private final EncryptionHandler handler;
     private long transferred;
-    private CryptoOutputStream cos;
 
     // Due to streaming issue CRYPTO-125: 
https://issues.apache.org/jira/browse/CRYPTO-125, it has
     // to utilize two helper ByteArrayWritableChannel for streaming. One is 
used to receive raw data
@@ -176,9 +206,14 @@ public class TransportCipher {
 
     private ByteBuffer currentEncrypted;
 
-    EncryptedMessage(CryptoOutputStream cos, Object msg, 
ByteArrayWritableChannel ch) {
+    EncryptedMessage(
+        EncryptionHandler handler,
+        CryptoOutputStream cos,
+        Object msg,
+        ByteArrayWritableChannel ch) {
       Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof 
FileRegion,
         "Unrecognized message type: %s", msg.getClass().getName());
+      this.handler = handler;
       this.isByteBuf = msg instanceof ByteBuf;
       this.buf = isByteBuf ? (ByteBuf) msg : null;
       this.region = isByteBuf ? null : (FileRegion) msg;
@@ -261,6 +296,9 @@ public class TransportCipher {
     }
 
     private void encryptMore() throws IOException {
+      if (!handler.isCipherValid()) {
+        throw new IOException("Cipher is in invalid state.");
+      }
       byteRawChannel.reset();
 
       if (isByteBuf) {
@@ -269,8 +307,14 @@ public class TransportCipher {
       } else {
         region.transferTo(byteRawChannel, region.transferred());
       }
-      cos.write(byteRawChannel.getData(), 0, byteRawChannel.length());
-      cos.flush();
+
+      try {
+        cos.write(byteRawChannel.getData(), 0, byteRawChannel.length());
+        cos.flush();
+      } catch (InternalError ie) {
+        handler.reportError();
+        throw ie;
+      }
 
       currentEncrypted = ByteBuffer.wrap(byteEncChannel.getData(),
         0, byteEncChannel.length());

http://git-wip-us.apache.org/repos/asf/spark/blob/3eee9e02/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java
----------------------------------------------------------------------
diff --git 
a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java
 
b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java
index a3519fe..c0aa298 100644
--- 
a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java
+++ 
b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java
@@ -18,8 +18,11 @@
 package org.apache.spark.network.crypto;
 
 import java.util.Arrays;
+import java.util.Map;
+import java.security.InvalidKeyException;
 import static java.nio.charset.StandardCharsets.UTF_8;
 
+import com.google.common.collect.ImmutableMap;
 import org.junit.BeforeClass;
 import org.junit.Test;
 import static org.junit.Assert.*;
@@ -104,4 +107,18 @@ public class AuthEngineSuite {
       challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge));
   }
 
+  @Test(expected = InvalidKeyException.class)
+  public void testBadKeySize() throws Exception {
+    Map<String, String> mconf = 
ImmutableMap.of("spark.network.crypto.keyLength", "42");
+    TransportConf conf = new TransportConf("rpc", new 
MapConfigProvider(mconf));
+
+    try (AuthEngine engine = new AuthEngine("appId", "secret", conf)) {
+      engine.challenge();
+      fail("Should have failed to create challenge message.");
+
+      // Call close explicitly to make sure it's idempotent.
+      engine.close();
+    }
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3eee9e02/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala 
b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
index 0062197..18b735b 100644
--- a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
+++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
@@ -16,7 +16,7 @@
  */
 package org.apache.spark.security
 
-import java.io.{InputStream, OutputStream}
+import java.io.{Closeable, InputStream, IOException, OutputStream}
 import java.nio.ByteBuffer
 import java.nio.channels.{ReadableByteChannel, WritableByteChannel}
 import java.util.Properties
@@ -54,8 +54,10 @@ private[spark] object CryptoStreamUtils extends Logging {
     val params = new CryptoParams(key, sparkConf)
     val iv = createInitializationVector(params.conf)
     os.write(iv)
-    new CryptoOutputStream(params.transformation, params.conf, os, 
params.keySpec,
-      new IvParameterSpec(iv))
+    new ErrorHandlingOutputStream(
+      new CryptoOutputStream(params.transformation, params.conf, os, 
params.keySpec,
+        new IvParameterSpec(iv)),
+      os)
   }
 
   /**
@@ -70,8 +72,10 @@ private[spark] object CryptoStreamUtils extends Logging {
     val helper = new CryptoHelperChannel(channel)
 
     helper.write(ByteBuffer.wrap(iv))
-    new CryptoOutputStream(params.transformation, params.conf, helper, 
params.keySpec,
-      new IvParameterSpec(iv))
+    new ErrorHandlingWritableChannel(
+      new CryptoOutputStream(params.transformation, params.conf, helper, 
params.keySpec,
+        new IvParameterSpec(iv)),
+      helper)
   }
 
   /**
@@ -84,8 +88,10 @@ private[spark] object CryptoStreamUtils extends Logging {
     val iv = new Array[Byte](IV_LENGTH_IN_BYTES)
     ByteStreams.readFully(is, iv)
     val params = new CryptoParams(key, sparkConf)
-    new CryptoInputStream(params.transformation, params.conf, is, 
params.keySpec,
-      new IvParameterSpec(iv))
+    new ErrorHandlingInputStream(
+      new CryptoInputStream(params.transformation, params.conf, is, 
params.keySpec,
+        new IvParameterSpec(iv)),
+      is)
   }
 
   /**
@@ -100,8 +106,10 @@ private[spark] object CryptoStreamUtils extends Logging {
     JavaUtils.readFully(channel, buf)
 
     val params = new CryptoParams(key, sparkConf)
-    new CryptoInputStream(params.transformation, params.conf, channel, 
params.keySpec,
-      new IvParameterSpec(iv))
+    new ErrorHandlingReadableChannel(
+      new CryptoInputStream(params.transformation, params.conf, channel, 
params.keySpec,
+        new IvParameterSpec(iv)),
+      channel)
   }
 
   def toCryptoConf(conf: SparkConf): Properties = {
@@ -157,6 +165,117 @@ private[spark] object CryptoStreamUtils extends Logging {
 
   }
 
+  /**
+   * SPARK-25535. The commons-crypto library will throw InternalError if something goes
+   * wrong, and leave bad state behind in the Java wrappers, so it's not safe to use them
+   * afterwards. This wrapper detects that situation and avoids further calls into the
+   * commons-crypto code, while still allowing the underlying streams to be closed.
+   *
+   * This should be removed once CRYPTO-141 is fixed (and Spark upgrades its commons-crypto
+   * dependency).
+   */
+  trait BaseErrorHandler extends Closeable {
+
+    private var closed = false
+
+    /** The encrypted stream that may get into an unhealthy state. */
+    protected def cipherStream: Closeable
+
+    /**
+     * The underlying stream that is being wrapped by the encrypted stream, so 
that it can be
+     * closed even if there's an error in the crypto layer.
+     */
+    protected def original: Closeable
+
+    protected def safeCall[T](fn: => T): T = {
+      if (closed) {
+        throw new IOException("Cipher stream is closed.")
+      }
+      try {
+        fn
+      } catch {
+        case ie: InternalError =>
+          closed = true
+          original.close()
+          throw ie
+      }
+    }
+
+    override def close(): Unit = {
+      if (!closed) {
+        cipherStream.close()
+      }
+    }
+
+  }
+
+  // Visible for testing.
+  class ErrorHandlingReadableChannel(
+      protected val cipherStream: ReadableByteChannel,
+      protected val original: ReadableByteChannel)
+    extends ReadableByteChannel with BaseErrorHandler {
+
+    override def read(src: ByteBuffer): Int = safeCall {
+      cipherStream.read(src)
+    }
+
+    override def isOpen(): Boolean = cipherStream.isOpen()
+
+  }
+
+  private class ErrorHandlingInputStream(
+      protected val cipherStream: InputStream,
+      protected val original: InputStream)
+    extends InputStream with BaseErrorHandler {
+
+    override def read(b: Array[Byte]): Int = safeCall {
+      cipherStream.read(b)
+    }
+
+    override def read(b: Array[Byte], off: Int, len: Int): Int = safeCall {
+      cipherStream.read(b, off, len)
+    }
+
+    override def read(): Int = safeCall {
+      cipherStream.read()
+    }
+  }
+
+  private class ErrorHandlingWritableChannel(
+      protected val cipherStream: WritableByteChannel,
+      protected val original: WritableByteChannel)
+    extends WritableByteChannel with BaseErrorHandler {
+
+    override def write(src: ByteBuffer): Int = safeCall {
+      cipherStream.write(src)
+    }
+
+    override def isOpen(): Boolean = cipherStream.isOpen()
+
+  }
+
+  private class ErrorHandlingOutputStream(
+      protected val cipherStream: OutputStream,
+      protected val original: OutputStream)
+    extends OutputStream with BaseErrorHandler {
+
+    override def flush(): Unit = safeCall {
+      cipherStream.flush()
+    }
+
+    override def write(b: Array[Byte]): Unit = safeCall {
+      cipherStream.write(b)
+    }
+
+    override def write(b: Array[Byte], off: Int, len: Int): Unit = safeCall {
+      cipherStream.write(b, off, len)
+    }
+
+    override def write(b: Int): Unit = safeCall {
+      cipherStream.write(b)
+    }
+  }
+
   private class CryptoParams(key: Array[Byte], sparkConf: SparkConf) {
 
     val keySpec = new SecretKeySpec(key, "AES")

http://git-wip-us.apache.org/repos/asf/spark/blob/3eee9e02/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala 
b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
index 78f618f..0d3611c 100644
--- a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
@@ -16,13 +16,16 @@
  */
 package org.apache.spark.security
 
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, 
FileOutputStream}
-import java.nio.channels.Channels
+import java.io._
+import java.nio.ByteBuffer
+import java.nio.channels.{Channels, ReadableByteChannel}
 import java.nio.charset.StandardCharsets.UTF_8
 import java.nio.file.Files
 import java.util.{Arrays, Random, UUID}
 
 import com.google.common.io.ByteStreams
+import org.mockito.Matchers.any
+import org.mockito.Mockito._
 
 import org.apache.spark._
 import org.apache.spark.internal.config._
@@ -164,6 +167,36 @@ class CryptoStreamUtilsSuite extends SparkFunSuite {
     }
   }
 
+  test("error handling wrapper") {
+    val wrapped = mock(classOf[ReadableByteChannel])
+    val decrypted = mock(classOf[ReadableByteChannel])
+    val errorHandler = new 
CryptoStreamUtils.ErrorHandlingReadableChannel(decrypted, wrapped)
+
+    when(decrypted.read(any(classOf[ByteBuffer])))
+      .thenThrow(new IOException())
+      .thenThrow(new InternalError())
+      .thenReturn(1)
+
+    val out = ByteBuffer.allocate(1)
+    intercept[IOException] {
+      errorHandler.read(out)
+    }
+    intercept[InternalError] {
+      errorHandler.read(out)
+    }
+
+    val e = intercept[IOException] {
+      errorHandler.read(out)
+    }
+    assert(e.getMessage().contains("is closed"))
+    errorHandler.close()
+
+    verify(decrypted, times(2)).read(any(classOf[ByteBuffer]))
+    verify(wrapped, never()).read(any(classOf[ByteBuffer]))
+    verify(decrypted, never()).close()
+    verify(wrapped, times(1)).close()
+  }
+
   private def createConf(extra: (String, String)*): SparkConf = {
     val conf = new SparkConf()
     extra.foreach { case (k, v) => conf.set(k, v) }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to