This is an automated email from the ASF dual-hosted git repository.

wuyi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f0563ef64c7f [SPARK-47172][CORE] Add support for AES-GCM for RPC 
encryption
f0563ef64c7f is described below

commit f0563ef64c7f42df21b16ccaeb4cf1324ea720f9
Author: Steve Weis <[email protected]>
AuthorDate: Fri Jun 21 21:07:37 2024 +0800

    [SPARK-47172][CORE] Add support for AES-GCM for RPC encryption
    
    ### What changes were proposed in this pull request?
    
    This change adds AES-GCM as an optional AES cipher mode for RPC encryption.
    The current default uses AES-CTR without any authentication, which would
    allow someone on the network to easily modify RPC contents on the wire and
    impact Spark behavior. See
    [SPARK-47172](https://issues.apache.org/jira/browse/SPARK-47172) for more
    details.
    
    ### Why are the changes needed?
    
    The current default uses AES-CTR without any authentication, which would
    allow someone on the network to easily modify RPC contents on the wire and
    impact Spark behavior.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, it adds an additional configuration flag, which is reflected in the
    documentation.
    
    ### How was this patch tested?
    All existing unit tests pass. New unit tests explicitly test GCM support and
    verify that modifying the ciphertext content causes decryption to throw an
    exception and fail.
    
    `build/sbt "network-common/test:testOnly"`
    `build/sbt "network-common/test:testOnly org.apache.spark.network.crypto.AuthIntegrationSuite"`
    `build/sbt "network-common/test:testOnly org.apache.spark.network.crypto.AuthEngineSuite"`
    
    ### Was this patch authored or co-authored using generative AI tooling?
    Nope.
    
    Closes #46515 from sweisdb/SPARK-47172.
    
    Authored-by: Steve Weis <[email protected]>
    Signed-off-by: Yi Wu <[email protected]>
---
 .../apache/spark/network/crypto/AuthEngine.java    |  21 +-
 ...ransportCipher.java => CtrTransportCipher.java} |  29 +-
 .../spark/network/crypto/GcmTransportCipher.java   | 410 +++++++++++++++++++++
 .../spark/network/crypto/TransportCipher.java      | 374 ++-----------------
 .../network/util/ByteBufferWriteableChannel.java   |  59 +++
 .../spark/network/crypto/AuthEngineSuite.java      | 203 +++-------
 .../spark/network/crypto/AuthIntegrationSuite.java |  79 ++--
 .../spark/network/crypto/CtrAuthEngineSuite.java   | 177 +++++++++
 .../spark/network/crypto/GcmAuthEngineSuite.java   | 339 +++++++++++++++++
 .../spark/network/crypto/TransportCipherSuite.java |   4 +-
 docs/security.md                                   |   9 +
 11 files changed, 1150 insertions(+), 554 deletions(-)

diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java
 
b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java
index cb68cfb5a0e8..8449a774a404 100644
--- 
a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java
+++ 
b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java
@@ -45,6 +45,8 @@ class AuthEngine implements Closeable {
   public static final byte[] INPUT_IV_INFO = "inputIv".getBytes(UTF_8);
   public static final byte[] OUTPUT_IV_INFO = "outputIv".getBytes(UTF_8);
   private static final String MAC_ALGORITHM = "HMACSHA256";
+  private static final String LEGACY_CIPHER_ALGORITHM = "AES/CTR/NoPadding";
+  private static final String CIPHER_ALGORITHM = "AES/GCM/NoPadding";
   private static final int AES_GCM_KEY_SIZE_BYTES = 16;
   private static final byte[] EMPTY_TRANSCRIPT = new byte[0];
   private static final int UNSAFE_SKIP_HKDF_VERSION = 1;
@@ -227,12 +229,19 @@ class AuthEngine implements Closeable {
         OUTPUT_IV_INFO,  // This is the HKDF info field used to differentiate 
IV values
         AES_GCM_KEY_SIZE_BYTES);
     SecretKeySpec sessionKey = new SecretKeySpec(derivedKey, "AES");
-    return new TransportCipher(
-        cryptoConf,
-        conf.cipherTransformation(),
-        sessionKey,
-        isClient ? clientIv : serverIv,  // If it's the client, use the client 
IV first
-        isClient ? serverIv : clientIv);
+    if (LEGACY_CIPHER_ALGORITHM.equalsIgnoreCase(conf.cipherTransformation())) 
{
+      return new CtrTransportCipher(
+          cryptoConf,
+          sessionKey,
+          isClient ? clientIv : serverIv,  // If it's the client, use the 
client IV first
+          isClient ? serverIv : clientIv);
+    } else if (CIPHER_ALGORITHM.equalsIgnoreCase(conf.cipherTransformation())) 
{
+      return new GcmTransportCipher(sessionKey);
+    } else {
+      throw new IllegalArgumentException(
+              String.format("Unsupported cipher mode: %s. %s and %s are 
supported.",
+                      conf.cipherTransformation(), CIPHER_ALGORITHM, 
LEGACY_CIPHER_ALGORITHM));
+    }
   }
 
   private byte[] getTranscript(AuthMessage... encryptedPublicKeys) {
diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
 
b/common/network-common/src/main/java/org/apache/spark/network/crypto/CtrTransportCipher.java
similarity index 92%
copy from 
common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
copy to 
common/network-common/src/main/java/org/apache/spark/network/crypto/CtrTransportCipher.java
index b507f911fe11..85b893751b39 100644
--- 
a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
+++ 
b/common/network-common/src/main/java/org/apache/spark/network/crypto/CtrTransportCipher.java
@@ -21,6 +21,7 @@ import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.nio.channels.ReadableByteChannel;
 import java.nio.channels.WritableByteChannel;
+import java.security.GeneralSecurityException;
 import java.util.Properties;
 import javax.crypto.spec.SecretKeySpec;
 import javax.crypto.spec.IvParameterSpec;
@@ -40,34 +41,36 @@ import 
org.apache.spark.network.util.ByteArrayWritableChannel;
 /**
  * Cipher for encryption and decryption.
  */
-public class TransportCipher {
+public class CtrTransportCipher implements TransportCipher {
   @VisibleForTesting
-  static final String ENCRYPTION_HANDLER_NAME = "TransportEncryption";
-  private static final String DECRYPTION_HANDLER_NAME = "TransportDecryption";
+  static final String ENCRYPTION_HANDLER_NAME = "CtrTransportEncryption";
+  private static final String DECRYPTION_HANDLER_NAME = 
"CtrTransportDecryption";
   @VisibleForTesting
   static final int STREAM_BUFFER_SIZE = 1024 * 32;
 
   private final Properties conf;
-  private final String cipher;
+  private static final String CIPHER_ALGORITHM = "AES/CTR/NoPadding";
   private final SecretKeySpec key;
   private final byte[] inIv;
   private final byte[] outIv;
 
-  public TransportCipher(
+  public CtrTransportCipher(
       Properties conf,
-      String cipher,
       SecretKeySpec key,
       byte[] inIv,
       byte[] outIv) {
     this.conf = conf;
-    this.cipher = cipher;
     this.key = key;
     this.inIv = inIv;
     this.outIv = outIv;
   }
 
-  public String getCipherTransformation() {
-    return cipher;
+  /*
+   * This method is for testing purposes only.
+   */
+  @VisibleForTesting
+  public String getKeyId() throws GeneralSecurityException {
+    return TransportCipherUtil.getKeyId(key);
   }
 
   @VisibleForTesting
@@ -87,12 +90,12 @@ public class TransportCipher {
 
   @VisibleForTesting
   CryptoOutputStream createOutputStream(WritableByteChannel ch) throws 
IOException {
-    return new CryptoOutputStream(cipher, conf, ch, key, new 
IvParameterSpec(outIv));
+    return new CryptoOutputStream(CIPHER_ALGORITHM, conf, ch, key, new 
IvParameterSpec(outIv));
   }
 
   @VisibleForTesting
   CryptoInputStream createInputStream(ReadableByteChannel ch) throws 
IOException {
-    return new CryptoInputStream(cipher, conf, ch, key, new 
IvParameterSpec(inIv));
+    return new CryptoInputStream(CIPHER_ALGORITHM, conf, ch, key, new 
IvParameterSpec(inIv));
   }
 
   /**
@@ -114,7 +117,7 @@ public class TransportCipher {
     private final ByteArrayWritableChannel byteRawChannel;
     private boolean isCipherValid;
 
-    EncryptionHandler(TransportCipher cipher) throws IOException {
+    EncryptionHandler(CtrTransportCipher cipher) throws IOException {
       byteEncChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE);
       cos = cipher.createOutputStream(byteEncChannel);
       byteRawChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE);
@@ -161,7 +164,7 @@ public class TransportCipher {
     private final ByteArrayReadableChannel byteChannel;
     private boolean isCipherValid;
 
-    DecryptionHandler(TransportCipher cipher) throws IOException {
+    DecryptionHandler(CtrTransportCipher cipher) throws IOException {
       byteChannel = new ByteArrayReadableChannel();
       cis = cipher.createInputStream(byteChannel);
       isCipherValid = true;
diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java
 
b/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java
new file mode 100644
index 000000000000..c3540838bef0
--- /dev/null
+++ 
b/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java
@@ -0,0 +1,410 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import com.google.common.primitives.Longs;
+import com.google.crypto.tink.subtle.*;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.*;
+import io.netty.util.ReferenceCounted;
+import org.apache.spark.network.util.AbstractFileRegion;
+import org.apache.spark.network.util.ByteBufferWriteableChannel;
+
+import javax.crypto.spec.SecretKeySpec;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.WritableByteChannel;
+import java.security.GeneralSecurityException;
+import java.security.InvalidAlgorithmParameterException;
+
+public class GcmTransportCipher implements TransportCipher {
+    private static final String HKDF_ALG = "HmacSha256";
+    private static final int LENGTH_HEADER_BYTES = 8;
+    @VisibleForTesting
+    static final int CIPHERTEXT_BUFFER_SIZE = 32 * 1024; // 32KB
+    private final SecretKeySpec aesKey;
+
+    public GcmTransportCipher(SecretKeySpec aesKey)  {
+        this.aesKey = aesKey;
+    }
+
+    AesGcmHkdfStreaming getAesGcmHkdfStreaming() throws 
InvalidAlgorithmParameterException {
+        return new AesGcmHkdfStreaming(
+            aesKey.getEncoded(),
+            HKDF_ALG,
+            aesKey.getEncoded().length,
+            CIPHERTEXT_BUFFER_SIZE,
+            0);
+    }
+
+    /*
+     * This method is for testing purposes only.
+     */
+    @VisibleForTesting
+    public String getKeyId() throws GeneralSecurityException {
+        return TransportCipherUtil.getKeyId(aesKey);
+    }
+
+    @VisibleForTesting
+    EncryptionHandler getEncryptionHandler() throws GeneralSecurityException {
+        return new EncryptionHandler();
+    }
+
+    @VisibleForTesting
+    DecryptionHandler getDecryptionHandler() throws GeneralSecurityException {
+        return new DecryptionHandler();
+    }
+
+    public void addToChannel(Channel ch) throws GeneralSecurityException {
+        ch.pipeline()
+            .addFirst("GcmTransportEncryption", getEncryptionHandler())
+            .addFirst("GcmTransportDecryption", getDecryptionHandler());
+    }
+
+    @VisibleForTesting
+    class EncryptionHandler extends ChannelOutboundHandlerAdapter {
+        private final ByteBuffer plaintextBuffer;
+        private final ByteBuffer ciphertextBuffer;
+        private final AesGcmHkdfStreaming aesGcmHkdfStreaming;
+
+        EncryptionHandler() throws InvalidAlgorithmParameterException {
+            aesGcmHkdfStreaming = getAesGcmHkdfStreaming();
+            plaintextBuffer = 
ByteBuffer.allocate(aesGcmHkdfStreaming.getPlaintextSegmentSize());
+            ciphertextBuffer = 
ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize());
+        }
+
+        @Override
+        public void write(ChannelHandlerContext ctx, Object msg, 
ChannelPromise promise)
+                throws Exception {
+            GcmEncryptedMessage encryptedMessage = new GcmEncryptedMessage(
+                    aesGcmHkdfStreaming,
+                    msg,
+                    plaintextBuffer,
+                    ciphertextBuffer);
+            ctx.write(encryptedMessage, promise);
+        }
+    }
+
+    static class GcmEncryptedMessage extends AbstractFileRegion {
+        private final Object plaintextMessage;
+        private final ByteBuffer plaintextBuffer;
+        private final ByteBuffer ciphertextBuffer;
+        private final ByteBuffer headerByteBuffer;
+        private final long bytesToRead;
+        private long bytesRead = 0;
+        private final StreamSegmentEncrypter encrypter;
+        private long transferred = 0;
+        private final long encryptedCount;
+
+        GcmEncryptedMessage(AesGcmHkdfStreaming aesGcmHkdfStreaming,
+                            Object plaintextMessage,
+                            ByteBuffer plaintextBuffer,
+                            ByteBuffer ciphertextBuffer) throws 
GeneralSecurityException {
+            Preconditions.checkArgument(
+                    plaintextMessage instanceof ByteBuf || plaintextMessage 
instanceof FileRegion,
+                    "Unrecognized message type: %s", 
plaintextMessage.getClass().getName());
+            this.plaintextMessage = plaintextMessage;
+            this.plaintextBuffer = plaintextBuffer;
+            this.ciphertextBuffer = ciphertextBuffer;
+            // If the ciphertext buffer cannot be fully written the target, 
transferTo may
+            // return with it containing some unwritten data. The initial call 
we'll explicitly
+            // set its limit to 0 to indicate the first call to transferTo.
+            this.ciphertextBuffer.limit(0);
+
+            this.bytesToRead = getReadableBytes();
+            this.encryptedCount =
+                    LENGTH_HEADER_BYTES + 
aesGcmHkdfStreaming.expectedCiphertextSize(bytesToRead);
+            byte[] lengthAad = Longs.toByteArray(encryptedCount);
+            this.encrypter = 
aesGcmHkdfStreaming.newStreamSegmentEncrypter(lengthAad);
+            this.headerByteBuffer = createHeaderByteBuffer();
+        }
+
+        // The format of the output is:
+        // [8 byte length][Internal IV and header][Ciphertext][Auth Tag]
+        private ByteBuffer createHeaderByteBuffer() {
+            ByteBuffer encrypterHeader = encrypter.getHeader();
+            return ByteBuffer
+                    .allocate(encrypterHeader.remaining() + 
LENGTH_HEADER_BYTES)
+                    .putLong(encryptedCount)
+                    .put(encrypterHeader)
+                    .flip();
+        }
+
+        @Override
+        public long position() {
+            return 0;
+        }
+
+        @Override
+        public long transferred() {
+            return transferred;
+        }
+
+        @Override
+        public long count() {
+            return encryptedCount;
+        }
+
+        @Override
+        public GcmEncryptedMessage touch(Object o) {
+            super.touch(o);
+            if (plaintextMessage instanceof ByteBuf byteBuf) {
+                byteBuf.touch(o);
+            } else if (plaintextMessage instanceof FileRegion fileRegion) {
+                fileRegion.touch(o);
+            }
+            return this;
+        }
+
+        @Override
+        public GcmEncryptedMessage retain(int increment) {
+            super.retain(increment);
+            if (plaintextMessage instanceof ByteBuf byteBuf) {
+                byteBuf.retain(increment);
+            } else if (plaintextMessage instanceof FileRegion fileRegion) {
+                fileRegion.retain(increment);
+            }
+            return this;
+        }
+
+        @Override
+        public boolean release(int decrement) {
+            if (plaintextMessage instanceof ByteBuf byteBuf) {
+                byteBuf.release(decrement);
+            } else if (plaintextMessage instanceof FileRegion fileRegion) {
+                fileRegion.release(decrement);
+            }
+            return super.release(decrement);
+        }
+
+        @Override
+        public long transferTo(WritableByteChannel target, long position) 
throws IOException {
+            int transferredThisCall = 0;
+            // If the header has is not empty, try to write it out to the 
target.
+            if (headerByteBuffer.hasRemaining()) {
+                int written = target.write(headerByteBuffer);
+                transferredThisCall += written;
+                this.transferred += written;
+                if (headerByteBuffer.hasRemaining()) {
+                    return written;
+                }
+            }
+            // If the ciphertext buffer is not empty, try to write it to the 
target.
+            if (ciphertextBuffer.hasRemaining()) {
+                int written = target.write(ciphertextBuffer);
+                transferredThisCall += written;
+                this.transferred += written;
+                if (ciphertextBuffer.hasRemaining()) {
+                    return transferredThisCall;
+               }
+            }
+            while (bytesRead < bytesToRead) {
+                long readableBytes = getReadableBytes();
+                int readLimit =
+                        (int) Math.min(readableBytes, 
plaintextBuffer.remaining());
+                if (plaintextMessage instanceof ByteBuf byteBuf) {
+                    Preconditions.checkState(0 == plaintextBuffer.position());
+                    plaintextBuffer.limit(readLimit);
+                    byteBuf.readBytes(plaintextBuffer);
+                    Preconditions.checkState(readLimit == 
plaintextBuffer.position());
+                } else if (plaintextMessage instanceof FileRegion fileRegion) {
+                    ByteBufferWriteableChannel plaintextChannel =
+                            new ByteBufferWriteableChannel(plaintextBuffer);
+                    long plaintextRead =
+                            fileRegion.transferTo(plaintextChannel, 
fileRegion.transferred());
+                    if (plaintextRead < readLimit) {
+                        // If we do not read a full plaintext buffer or all 
the available
+                        // readable bytes, return what was transferred this 
call.
+                        return transferredThisCall;
+                    }
+                }
+                boolean lastSegment = getReadableBytes() == 0;
+                plaintextBuffer.flip();
+                bytesRead += plaintextBuffer.remaining();
+                ciphertextBuffer.clear();
+                try {
+                    encrypter.encryptSegment(plaintextBuffer, lastSegment, 
ciphertextBuffer);
+                } catch (GeneralSecurityException e) {
+                    throw new IllegalStateException("GeneralSecurityException 
from encrypter", e);
+                }
+                plaintextBuffer.clear();
+                ciphertextBuffer.flip();
+                int written = target.write(ciphertextBuffer);
+                transferredThisCall += written;
+                this.transferred += written;
+                if (ciphertextBuffer.hasRemaining()) {
+                    // In this case, upon calling transferTo again, it will 
try to write the
+                    // remaining ciphertext buffer in the conditional before 
this loop.
+                    return transferredThisCall;
+                }
+            }
+            return transferredThisCall;
+        }
+
+        private long getReadableBytes() {
+            if (plaintextMessage instanceof ByteBuf byteBuf) {
+                return byteBuf.readableBytes();
+            } else if (plaintextMessage instanceof FileRegion fileRegion) {
+                return fileRegion.count() - fileRegion.transferred();
+            } else {
+                throw new IllegalArgumentException("Unsupported message type: 
" +
+                        plaintextMessage.getClass().getName());
+            }
+        }
+
+        @Override
+        protected void deallocate() {
+            if (plaintextMessage instanceof ReferenceCounted referenceCounted) 
{
+                referenceCounted.release();
+            }
+            plaintextBuffer.clear();
+            ciphertextBuffer.clear();
+        }
+    }
+
+    @VisibleForTesting
+    class DecryptionHandler extends ChannelInboundHandlerAdapter {
+        private final ByteBuffer expectedLengthBuffer;
+        private final ByteBuffer headerBuffer;
+        private final ByteBuffer ciphertextBuffer;
+        private final AesGcmHkdfStreaming aesGcmHkdfStreaming;
+        private final StreamSegmentDecrypter decrypter;
+        private final int plaintextSegmentSize;
+        private boolean decrypterInit = false;
+        private boolean completed = false;
+        private int segmentNumber = 0;
+        private long expectedLength = -1;
+        private long ciphertextRead = 0;
+
+        DecryptionHandler() throws GeneralSecurityException {
+            aesGcmHkdfStreaming = getAesGcmHkdfStreaming();
+            expectedLengthBuffer = ByteBuffer.allocate(LENGTH_HEADER_BYTES);
+            headerBuffer = 
ByteBuffer.allocate(aesGcmHkdfStreaming.getHeaderLength());
+            ciphertextBuffer =
+                    
ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize());
+            decrypter = aesGcmHkdfStreaming.newStreamSegmentDecrypter();
+            plaintextSegmentSize = 
aesGcmHkdfStreaming.getPlaintextSegmentSize();
+        }
+
+        private boolean initalizeExpectedLength(ByteBuf ciphertextNettyBuf) {
+            if (expectedLength < 0) {
+                ciphertextNettyBuf.readBytes(expectedLengthBuffer);
+                if (expectedLengthBuffer.hasRemaining()) {
+                    // We did not read enough bytes to initialize the expected 
length.
+                    return false;
+                }
+                expectedLengthBuffer.flip();
+                expectedLength = expectedLengthBuffer.getLong();
+                if (expectedLength < 0) {
+                    throw new IllegalStateException("Invalid expected 
ciphertext length.");
+                }
+                ciphertextRead += LENGTH_HEADER_BYTES;
+            }
+            return true;
+        }
+
+        private boolean initalizeDecrypter(ByteBuf ciphertextNettyBuf)
+                throws GeneralSecurityException {
+            // Check if the ciphertext header has been read. This contains
+            // the IV and other internal metadata.
+            if (!decrypterInit) {
+                ciphertextNettyBuf.readBytes(headerBuffer);
+                if (headerBuffer.hasRemaining()) {
+                    // We did not read enough bytes to initialize the header.
+                    return false;
+                }
+                headerBuffer.flip();
+                byte[] lengthAad = Longs.toByteArray(expectedLength);
+                decrypter.init(headerBuffer, lengthAad);
+                decrypterInit = true;
+                ciphertextRead += aesGcmHkdfStreaming.getHeaderLength();
+                if (expectedLength == ciphertextRead) {
+                    // If the expected length is just the header, the 
ciphertext is 0 length.
+                    completed = true;
+                }
+            }
+            return true;
+        }
+
+        @Override
+        public void channelRead(ChannelHandlerContext ctx, Object 
ciphertextMessage)
+                throws GeneralSecurityException {
+            Preconditions.checkArgument(ciphertextMessage instanceof ByteBuf,
+                    "Unrecognized message type: %s",
+                    ciphertextMessage.getClass().getName());
+            ByteBuf ciphertextNettyBuf = (ByteBuf) ciphertextMessage;
+            // The format of the output is:
+            // [8 byte length][Internal IV and header][Ciphertext][Auth Tag]
+            try {
+                if (!initalizeExpectedLength(ciphertextNettyBuf)) {
+                    // We have not read enough bytes to initialize the 
expected length.
+                    return;
+                }
+                if (!initalizeDecrypter(ciphertextNettyBuf)) {
+                    // We have not read enough bytes to initialize a header, 
needed to
+                    // initialize a decrypter.
+                    return;
+                }
+                int nettyBufReadableBytes = ciphertextNettyBuf.readableBytes();
+                while (nettyBufReadableBytes > 0 && !completed) {
+                    // Read the ciphertext into the local buffer
+                    int readableBytes = Integer.min(
+                            nettyBufReadableBytes,
+                            ciphertextBuffer.remaining());
+                    int expectedRemaining = (int) (expectedLength - 
ciphertextRead);
+                    int bytesToRead = Integer.min(readableBytes, 
expectedRemaining);
+                    // The smallest ciphertext size is 16 bytes for the auth 
tag
+                    ciphertextBuffer.limit(ciphertextBuffer.position() + 
bytesToRead);
+                    ciphertextNettyBuf.readBytes(ciphertextBuffer);
+                    ciphertextRead += bytesToRead;
+                    // Check if this is the last segment
+                    if (ciphertextRead == expectedLength) {
+                        completed = true;
+                    } else if (ciphertextRead > expectedLength) {
+                        throw new IllegalStateException("Read more ciphertext 
than expected.");
+                    }
+                    // If the ciphertext buffer is full, or this is the last 
segment,
+                    // then decrypt it and fire a read.
+                    if (ciphertextBuffer.limit() == 
ciphertextBuffer.capacity() || completed) {
+                        ByteBuffer plaintextBuffer = 
ByteBuffer.allocate(plaintextSegmentSize);
+                        ciphertextBuffer.flip();
+                        decrypter.decryptSegment(
+                                ciphertextBuffer,
+                                segmentNumber,
+                                completed,
+                                plaintextBuffer);
+                        segmentNumber++;
+                        // Clear the ciphertext buffer because it's been read
+                        ciphertextBuffer.clear();
+                        plaintextBuffer.flip();
+                        
ctx.fireChannelRead(Unpooled.wrappedBuffer(plaintextBuffer));
+                    } else {
+                        // Set the ciphertext buffer up to read the next chunk
+                        ciphertextBuffer.limit(ciphertextBuffer.capacity());
+                    }
+                    nettyBufReadableBytes = ciphertextNettyBuf.readableBytes();
+                }
+            } finally {
+                ciphertextNettyBuf.release();
+            }
+        }
+    }
+}
diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
 
b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
index b507f911fe11..355c55272018 100644
--- 
a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
+++ 
b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java
@@ -17,362 +17,40 @@
 
 package org.apache.spark.network.crypto;
 
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.nio.channels.ReadableByteChannel;
-import java.nio.channels.WritableByteChannel;
-import java.util.Properties;
-import javax.crypto.spec.SecretKeySpec;
-import javax.crypto.spec.IvParameterSpec;
-
 import com.google.common.annotations.VisibleForTesting;
-import com.google.common.base.Preconditions;
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.Unpooled;
-import io.netty.channel.*;
-import org.apache.commons.crypto.stream.CryptoInputStream;
-import org.apache.commons.crypto.stream.CryptoOutputStream;
-
-import org.apache.spark.network.util.AbstractFileRegion;
-import org.apache.spark.network.util.ByteArrayReadableChannel;
-import org.apache.spark.network.util.ByteArrayWritableChannel;
-
-/**
- * Cipher for encryption and decryption.
- */
-public class TransportCipher {
-  @VisibleForTesting
-  static final String ENCRYPTION_HANDLER_NAME = "TransportEncryption";
-  private static final String DECRYPTION_HANDLER_NAME = "TransportDecryption";
-  @VisibleForTesting
-  static final int STREAM_BUFFER_SIZE = 1024 * 32;
-
-  private final Properties conf;
-  private final String cipher;
-  private final SecretKeySpec key;
-  private final byte[] inIv;
-  private final byte[] outIv;
-
-  public TransportCipher(
-      Properties conf,
-      String cipher,
-      SecretKeySpec key,
-      byte[] inIv,
-      byte[] outIv) {
-    this.conf = conf;
-    this.cipher = cipher;
-    this.key = key;
-    this.inIv = inIv;
-    this.outIv = outIv;
-  }
-
-  public String getCipherTransformation() {
-    return cipher;
-  }
-
-  @VisibleForTesting
-  SecretKeySpec getKey() {
-    return key;
-  }
-
-  /** The IV for the input channel (i.e. output channel of the remote side). */
-  public byte[] getInputIv() {
-    return inIv;
-  }
-
-  /** The IV for the output channel (i.e. input channel of the remote side). */
-  public byte[] getOutputIv() {
-    return outIv;
-  }
-
-  @VisibleForTesting
-  CryptoOutputStream createOutputStream(WritableByteChannel ch) throws 
IOException {
-    return new CryptoOutputStream(cipher, conf, ch, key, new 
IvParameterSpec(outIv));
-  }
-
-  @VisibleForTesting
-  CryptoInputStream createInputStream(ReadableByteChannel ch) throws 
IOException {
-    return new CryptoInputStream(cipher, conf, ch, key, new 
IvParameterSpec(inIv));
-  }
-
-  /**
-   * Add handlers to channel.
-   *
-   * @param ch the channel for adding handlers
-   * @throws IOException
-   */
-  public void addToChannel(Channel ch) throws IOException {
-    ch.pipeline()
-      .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(this))
-      .addFirst(DECRYPTION_HANDLER_NAME, new DecryptionHandler(this));
-  }
-
-  @VisibleForTesting
-  static class EncryptionHandler extends ChannelOutboundHandlerAdapter {
-    private final ByteArrayWritableChannel byteEncChannel;
-    private final CryptoOutputStream cos;
-    private final ByteArrayWritableChannel byteRawChannel;
-    private boolean isCipherValid;
-
-    EncryptionHandler(TransportCipher cipher) throws IOException {
-      byteEncChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE);
-      cos = cipher.createOutputStream(byteEncChannel);
-      byteRawChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE);
-      isCipherValid = true;
-    }
+import com.google.crypto.tink.subtle.Hex;
+import com.google.crypto.tink.subtle.Hkdf;
+import io.netty.channel.Channel;
 
-    @Override
-    public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise 
promise)
-      throws Exception {
-      ctx.write(createEncryptedMessage(msg), promise);
-    }
-
-    @VisibleForTesting
-    EncryptedMessage createEncryptedMessage(Object msg) {
-      return new EncryptedMessage(this, cos, msg, byteEncChannel, 
byteRawChannel);
-    }
+import javax.crypto.spec.SecretKeySpec;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.security.GeneralSecurityException;
 
-    @Override
-    public void close(ChannelHandlerContext ctx, ChannelPromise promise) 
throws Exception {
-      try {
-        if (isCipherValid) {
-          cos.close();
-        }
-      } finally {
-        super.close(ctx, promise);
-      }
-    }
+/**
+ * Session cipher obtained from {@link AuthEngine} once the challenge/response handshake completes.
+ */
+interface TransportCipher {
+    /** Returns an identifier derived from the session key, comparable across peers. */
+    String getKeyId() throws GeneralSecurityException;
+    /** Installs this cipher's encryption and decryption handlers on {@code channel}. */
+    void addToChannel(Channel channel) throws IOException, 
GeneralSecurityException;
+}
 
-    /**
-     * SPARK-25535. Workaround for CRYPTO-141. Avoid further interaction with 
the underlying cipher
-     * after an error occurs.
+final class TransportCipherUtil {
+    private TransportCipherUtil() {}
+
+    /**
+     * Returns the hex encoding of 32 bytes of HKDF-SHA256 output derived from {@code key}
+     * (no salt, info "keyID"). Used by tests to verify key derivation.
      */
-    void reportError() {
-      this.isCipherValid = false;
-    }
-
-    boolean isCipherValid() {
-      return isCipherValid;
-    }
-  }
-
-  private static class DecryptionHandler extends ChannelInboundHandlerAdapter {
-    private final CryptoInputStream cis;
-    private final ByteArrayReadableChannel byteChannel;
-    private boolean isCipherValid;
-
-    DecryptionHandler(TransportCipher cipher) throws IOException {
-      byteChannel = new ByteArrayReadableChannel();
-      cis = cipher.createInputStream(byteChannel);
-      isCipherValid = true;
-    }
-
-    @Override
-    public void channelRead(ChannelHandlerContext ctx, Object data) throws 
Exception {
-      ByteBuf buffer = (ByteBuf) data;
-
-      try {
-        if (!isCipherValid) {
-          throw new IOException("Cipher is in invalid state.");
-        }
-        byte[] decryptedData = new byte[buffer.readableBytes()];
-        byteChannel.feedData(buffer);
-
-        int offset = 0;
-        while (offset < decryptedData.length) {
-          // SPARK-25535: workaround for CRYPTO-141.
-          try {
-            offset += cis.read(decryptedData, offset, decryptedData.length - 
offset);
-          } catch (InternalError ie) {
-            isCipherValid = false;
-            throw ie;
-          }
-        }
-
-        ctx.fireChannelRead(Unpooled.wrappedBuffer(decryptedData, 0, 
decryptedData.length));
-      } finally {
-        buffer.release();
-      }
-    }
-
-    @Override
-    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
-      // We do the closing of the stream / channel in handlerRemoved(...) as
-      // this method will be called in all cases:
-      //
-      //     - when the Channel becomes inactive
-      //     - when the handler is removed from the ChannelPipeline
-      try {
-        if (isCipherValid) {
-          cis.close();
-        }
-      } finally {
-        super.handlerRemoved(ctx);
-      }
-    }
-  }
-
-  @VisibleForTesting
-  static class EncryptedMessage extends AbstractFileRegion {
-    private final boolean isByteBuf;
-    private final ByteBuf buf;
-    private final FileRegion region;
-    private final CryptoOutputStream cos;
-    private final EncryptionHandler handler;
-    private final long count;
-    private long transferred;
-
-    // Due to streaming issue CRYPTO-125: 
https://issues.apache.org/jira/browse/CRYPTO-125, it has
-    // to utilize two helper ByteArrayWritableChannel for streaming. One is 
used to receive raw data
-    // from upper handler, another is used to store encrypted data.
-    private final ByteArrayWritableChannel byteEncChannel;
-    private final ByteArrayWritableChannel byteRawChannel;
-
-    private ByteBuffer currentEncrypted;
-
-    EncryptedMessage(
-        EncryptionHandler handler,
-        CryptoOutputStream cos,
-        Object msg,
-        ByteArrayWritableChannel byteEncChannel,
-        ByteArrayWritableChannel byteRawChannel) {
-      Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof 
FileRegion,
-        "Unrecognized message type: %s", msg.getClass().getName());
-      this.handler = handler;
-      this.isByteBuf = msg instanceof ByteBuf;
-      this.buf = isByteBuf ? (ByteBuf) msg : null;
-      this.region = isByteBuf ? null : (FileRegion) msg;
-      this.transferred = 0;
-      this.cos = cos;
-      this.byteEncChannel = byteEncChannel;
-      this.byteRawChannel = byteRawChannel;
-      this.count = isByteBuf ? buf.readableBytes() : region.count();
-    }
-
-    @Override
-    public long count() {
-      return count;
-    }
-
-    @Override
-    public long position() {
-      return 0;
-    }
-
-    @Override
-    public long transferred() {
-      return transferred;
-    }
-
-    @Override
-    public EncryptedMessage touch(Object o) {
-      super.touch(o);
-      if (region != null) {
-        region.touch(o);
-      }
-      if (buf != null) {
-        buf.touch(o);
-      }
-      return this;
-    }
-
-    @Override
-    public EncryptedMessage retain(int increment) {
-      super.retain(increment);
-      if (region != null) {
-        region.retain(increment);
-      }
-      if (buf != null) {
-        buf.retain(increment);
-      }
-      return this;
-    }
-
-    @Override
-    public boolean release(int decrement) {
-      if (region != null) {
-        region.release(decrement);
-      }
-      if (buf != null) {
-        buf.release(decrement);
-      }
-      return super.release(decrement);
-    }
-
-    @Override
-    public long transferTo(WritableByteChannel target, long position) throws 
IOException {
-      Preconditions.checkArgument(position == transferred(), "Invalid 
position.");
-
-      if (transferred == count) {
-        return 0;
-      }
-
-      long totalBytesWritten = 0L;
-      do {
-        if (currentEncrypted == null) {
-          encryptMore();
-        }
-
-        long remaining = currentEncrypted.remaining();
-        if (remaining == 0)  {
-          // Just for safety to avoid endless loop. It usually won't happen, 
but since the
-          // underlying `region.transferTo` is allowed to transfer 0 bytes, we 
should handle it for
-          // safety.
-          currentEncrypted = null;
-          byteEncChannel.reset();
-          return totalBytesWritten;
-        }
-
-        long bytesWritten = target.write(currentEncrypted);
-        totalBytesWritten += bytesWritten;
-        transferred += bytesWritten;
-        if (bytesWritten < remaining) {
-          // break as the underlying buffer in "target" is full
-          break;
-        }
-        currentEncrypted = null;
-        byteEncChannel.reset();
-      } while (transferred < count);
-
-      return totalBytesWritten;
-    }
-
-    private void encryptMore() throws IOException {
-      if (!handler.isCipherValid()) {
-        throw new IOException("Cipher is in invalid state.");
-      }
-      byteRawChannel.reset();
-
-      if (isByteBuf) {
-        int copied = byteRawChannel.write(buf.nioBuffer());
-        buf.skipBytes(copied);
-      } else {
-        region.transferTo(byteRawChannel, region.transferred());
-      }
-
-      try {
-        cos.write(byteRawChannel.getData(), 0, byteRawChannel.length());
-        cos.flush();
-      } catch (InternalError ie) {
-        handler.reportError();
-        throw ie;
-      }
-
-      currentEncrypted = ByteBuffer.wrap(byteEncChannel.getData(),
-        0, byteEncChannel.length());
-    }
-
-    @Override
-    protected void deallocate() {
-      byteRawChannel.reset();
-      byteEncChannel.reset();
-      if (region != null) {
-        region.release();
-      }
-      if (buf != null) {
-        buf.release();
-      }
+    @VisibleForTesting
+    static String getKeyId(SecretKeySpec key) throws GeneralSecurityException {
+        byte[] keyIdBytes = Hkdf.computeHkdf("HmacSha256",
+                key.getEncoded(),
+                null,
+                "keyID".getBytes(StandardCharsets.UTF_8),
+                32);
+        return Hex.encode(keyIdBytes);
     }
-  }
-
 }
diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/util/ByteBufferWriteableChannel.java
 
b/common/network-common/src/main/java/org/apache/spark/network/util/ByteBufferWriteableChannel.java
new file mode 100644
index 000000000000..b20240cfcaa6
--- /dev/null
+++ 
b/common/network-common/src/main/java/org/apache/spark/network/util/ByteBufferWriteableChannel.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.ClosedChannelException;
+import java.nio.channels.WritableByteChannel;
+import java.util.Objects;
+
+/**
+ * A {@link WritableByteChannel} that copies written bytes into a caller-supplied
+ * {@link ByteBuffer}, starting at the buffer's current position.
+ *
+ * <p>Writes are truncated to the destination's remaining space, so writing into a full
+ * destination returns 0 rather than failing. The destination is never flipped or reset.
+ */
+public class ByteBufferWriteableChannel implements WritableByteChannel {
+    private final ByteBuffer destination;
+    private boolean open;
+
+    /**
+     * @param destination buffer that receives all written bytes; must not be null
+     * @throws NullPointerException if {@code destination} is null
+     */
+    public ByteBufferWriteableChannel(ByteBuffer destination) {
+        this.destination = Objects.requireNonNull(destination, "destination");
+        this.open = true;
+    }
+
+    /**
+     * Copies {@code min(src.remaining(), destination.remaining())} bytes from {@code src},
+     * advancing the positions of both buffers by that amount.
+     *
+     * @return the number of bytes copied, which may be 0
+     * @throws ClosedChannelException if this channel has been closed
+     */
+    @Override
+    public int write(ByteBuffer src) throws IOException {
+        if (!isOpen()) {
+            throw new ClosedChannelException();
+        }
+        int bytesToWrite = Math.min(src.remaining(), destination.remaining());
+        // Nothing to copy: the destination is full or src is empty.
+        if (bytesToWrite == 0) {
+            return 0;
+        }
+        // Copy through a bounded slice so src's own limit is left untouched.
+        ByteBuffer temp = src.slice().limit(bytesToWrite);
+        destination.put(temp);
+        src.position(src.position() + bytesToWrite);
+        return bytesToWrite;
+    }
+
+    @Override
+    public boolean isOpen() {
+        return open;
+    }
+
+    /** Marks this channel closed; the destination buffer is left as is. */
+    @Override
+    public void close() {
+        open = false;
+    }
+}
diff --git 
a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java
 
b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java
index e9846be20c9b..628de9e78033 100644
--- 
a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java
+++ 
b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java
@@ -18,75 +18,76 @@
 package org.apache.spark.network.crypto;
 
 import java.nio.ByteBuffer;
-import java.nio.channels.WritableByteChannel;
 import java.security.GeneralSecurityException;
-import java.util.Collections;
-import java.util.Random;
+import java.util.Map;
 
+import com.google.common.collect.ImmutableMap;
 import com.google.crypto.tink.subtle.Hex;
-import io.netty.buffer.ByteBuf;
-import io.netty.buffer.Unpooled;
-import io.netty.channel.FileRegion;
-import org.apache.spark.network.util.ByteArrayWritableChannel;
-import org.apache.spark.network.util.ConfigProvider;
-import org.apache.spark.network.util.MapConfigProvider;
-import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.network.util.*;
+
 import static org.junit.jupiter.api.Assertions.*;
-import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Test;
-import static org.mockito.Mockito.*;
-import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
-
-public class AuthEngineSuite {
 
-  private static final String clientPrivate =
-      "efe6b68b3fce92158e3637f6ef9d937e75558928dd4b401de04b43d300a73186";
-  private static final String clientChallengeHex =
-      
"fb00000005617070496400000010890b6e960f48e998777267a7e4e623220000003c48ad7dc7ec9466da9"
 +
-      
"3bda9f11488dc9404050e02c661d87d67c782444944c6e369b27e0a416c30845a2d9e64271511ca98b41d"
 +
-      "65f8c426e18ff380f6";
-  private static final String serverResponseHex =
-      
"fb00000005617070496400000010708451c9dd2792c97c1ca66e6df449ef0000003c64fe899ecdaf458d4"
 +
-      
"e25e9d5c5a380b8e6d1a184692fac065ed84f8592c18e9629f9c636809dca2ffc041f20346eb53db78738"
 +
-      "08ecad08b46b5ee3ff";
-  private static final String derivedKey = "2d6e7a9048c8265c33a8f3747bfcc84c";
+abstract class AuthEngineSuite {
+  static final String clientPrivate =
+          "efe6b68b3fce92158e3637f6ef9d937e75558928dd4b401de04b43d300a73186";
+  static final String clientChallengeHex =
+          
"fb00000005617070496400000010890b6e960f48e998777267a7e4e623220000003c48ad7dc7ec9466da9"
 +
+          
"3bda9f11488dc9404050e02c661d87d67c782444944c6e369b27e0a416c30845a2d9e64271511ca98b41d"
 +
+          "65f8c426e18ff380f6";
+  static final String serverResponseHex =
+          
"fb00000005617070496400000010708451c9dd2792c97c1ca66e6df449ef0000003c64fe899ecdaf458d4"
 +
+          
"e25e9d5c5a380b8e6d1a184692fac065ed84f8592c18e9629f9c636809dca2ffc041f20346eb53db78738"
 +
+          "08ecad08b46b5ee3ff";
+  static final String derivedKeyId =
+          "de04fd52d71040ed9d260579dacfdf4f5695f991ce8ddb1dde05a7335880906e";
   // This key would have been derived for version 1.0 protocol that did not 
run a final HKDF round.
-  private static final String unsafeDerivedKey =
-      "31963f15a320d5c90333f7ecf5cf3a31c7eaf151de07fef8494663a9f47cfd31";
-
-  private static final String inputIv = "fc6a5dc8b90a9dad8f54f08b51a59ed2";
-  private static final String outputIv = "a72709baf00785cad6329ce09f631f71";
-  private static TransportConf conf;
-
-  @BeforeAll
-  public static void setUp() {
-    ConfigProvider v2Provider = new MapConfigProvider(Collections.singletonMap(
-            "spark.network.crypto.authEngineVersion", "2"));
-    conf = new TransportConf("rpc", v2Provider);
+  static final String unsafeDerivedKey =
+          "31963f15a320d5c90333f7ecf5cf3a31c7eaf151de07fef8494663a9f47cfd31";
+  static TransportConf conf;
+
+  /** Builds an "rpc" conf with crypto enabled; engine version 1 if requested, else 2. */
+  static TransportConf getConf(int authEngineVersion, boolean useCtr) {
+    String version = (authEngineVersion == 1) ? "1" : "2";
+    String mode = useCtr ? "AES/CTR/NoPadding" : "AES/GCM/NoPadding";
+    Map<String, String> confMap = ImmutableMap.of(
+            "spark.network.crypto.enabled", "true",
+            "spark.network.crypto.authEngineVersion", version,
+            "spark.network.crypto.cipher", mode);
+    ConfigProvider provider = new MapConfigProvider(confMap);
+    return new TransportConf("rpc", provider);
   }
 
   @Test
   public void testAuthEngine() throws Exception {
-
     try (AuthEngine client = new AuthEngine("appId", "secret", conf);
          AuthEngine server = new AuthEngine("appId", "secret", conf)) {
       AuthMessage clientChallenge = client.challenge();
       AuthMessage serverResponse = server.response(clientChallenge);
       client.deriveSessionCipher(clientChallenge, serverResponse);
-
       TransportCipher serverCipher = server.sessionCipher();
       TransportCipher clientCipher = client.sessionCipher();
+      assertEquals(clientCipher.getKeyId(), serverCipher.getKeyId());
+    }
+  }
 
-      assertArrayEquals(serverCipher.getInputIv(), clientCipher.getOutputIv());
-      assertArrayEquals(serverCipher.getOutputIv(), clientCipher.getInputIv());
-      assertEquals(serverCipher.getKey(), clientCipher.getKey());
+  @Test
+  public void testFixedChallengeResponse() throws Exception {
+    try (AuthEngine client = new AuthEngine("appId", "secret", conf)) {
+      byte[] clientPrivateKey = Hex.decode(clientPrivate);
+      client.setClientPrivateKey(clientPrivateKey);
+      AuthMessage clientChallenge =
+              
AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex)));
+      AuthMessage serverResponse =
+              
AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(serverResponseHex)));
+      // Verify that a recorded transcript is accepted and yields the expected key ID.
+      client.deriveSessionCipher(clientChallenge, serverResponse);
+      assertEquals(derivedKeyId, client.sessionCipher().getKeyId());
     }
   }
 
   @Test
   public void testCorruptChallengeAppId() throws Exception {
-
     try (AuthEngine client = new AuthEngine("appId", "secret", conf);
          AuthEngine server = new AuthEngine("appId", "secret", conf)) {
       AuthMessage clientChallenge = client.challenge();
@@ -98,7 +99,6 @@ public class AuthEngineSuite {
 
   @Test
   public void testCorruptChallengeSalt() throws Exception {
-
     try (AuthEngine client = new AuthEngine("appId", "secret", conf);
          AuthEngine server = new AuthEngine("appId", "secret", conf)) {
       AuthMessage clientChallenge = client.challenge();
@@ -109,7 +109,6 @@ public class AuthEngineSuite {
 
   @Test
   public void testCorruptChallengeCiphertext() throws Exception {
-
     try (AuthEngine client = new AuthEngine("appId", "secret", conf);
          AuthEngine server = new AuthEngine("appId", "secret", conf)) {
       AuthMessage clientChallenge = client.challenge();
@@ -120,7 +119,6 @@ public class AuthEngineSuite {
 
   @Test
   public void testCorruptResponseAppId() throws Exception {
-
     try (AuthEngine client = new AuthEngine("appId", "secret", conf);
          AuthEngine server = new AuthEngine("appId", "secret", conf)) {
       AuthMessage clientChallenge = client.challenge();
@@ -134,20 +132,18 @@ public class AuthEngineSuite {
 
   @Test
   public void testCorruptResponseSalt() throws Exception {
-
     try (AuthEngine client = new AuthEngine("appId", "secret", conf);
          AuthEngine server = new AuthEngine("appId", "secret", conf)) {
       AuthMessage clientChallenge = client.challenge();
       AuthMessage serverResponse = server.response(clientChallenge);
       serverResponse.salt()[0] ^= 1;
       assertThrows(GeneralSecurityException.class,
-        () -> client.deriveSessionCipher(clientChallenge, serverResponse));
+              () -> client.deriveSessionCipher(clientChallenge, 
serverResponse));
     }
   }
 
   @Test
   public void testCorruptServerCiphertext() throws Exception {
-
     try (AuthEngine client = new AuthEngine("appId", "secret", conf);
          AuthEngine server = new AuthEngine("appId", "secret", conf)) {
       AuthMessage clientChallenge = client.challenge();
@@ -169,45 +165,6 @@ public class AuthEngineSuite {
     }
   }
 
-  @Test
-  public void testFixedChallengeResponse() throws Exception {
-    try (AuthEngine client = new AuthEngine("appId", "secret", conf)) {
-      byte[] clientPrivateKey = Hex.decode(clientPrivate);
-      client.setClientPrivateKey(clientPrivateKey);
-      AuthMessage clientChallenge =
-              
AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex)));
-      AuthMessage serverResponse =
-              
AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(serverResponseHex)));
-      // Verify that the client will accept an old transcript.
-      client.deriveSessionCipher(clientChallenge, serverResponse);
-      TransportCipher clientCipher = client.sessionCipher();
-      assertEquals(Hex.encode(clientCipher.getKey().getEncoded()), derivedKey);
-      assertEquals(Hex.encode(clientCipher.getInputIv()), inputIv);
-      assertEquals(Hex.encode(clientCipher.getOutputIv()), outputIv);
-    }
-  }
-
-  @Test
-  public void testFixedChallengeResponseUnsafeVersion() throws Exception {
-    ConfigProvider v1Provider = new MapConfigProvider(Collections.singletonMap(
-            "spark.network.crypto.authEngineVersion", "1"));
-    TransportConf v1Conf = new TransportConf("rpc", v1Provider);
-    try (AuthEngine client = new AuthEngine("appId", "secret", v1Conf)) {
-      byte[] clientPrivateKey = Hex.decode(clientPrivate);
-      client.setClientPrivateKey(clientPrivateKey);
-      AuthMessage clientChallenge =
-              
AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex)));
-      AuthMessage serverResponse =
-              
AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(serverResponseHex)));
-      // Verify that the client will accept an old transcript.
-      client.deriveSessionCipher(clientChallenge, serverResponse);
-      TransportCipher clientCipher = client.sessionCipher();
-      assertEquals(Hex.encode(clientCipher.getKey().getEncoded()), 
unsafeDerivedKey);
-      assertEquals(Hex.encode(clientCipher.getInputIv()), inputIv);
-      assertEquals(Hex.encode(clientCipher.getOutputIv()), outputIv);
-    }
-  }
-
   @Test
   public void testMismatchedSecret() throws Exception {
     try (AuthEngine client = new AuthEngine("appId", "secret", conf);
@@ -216,70 +173,4 @@ public class AuthEngineSuite {
       assertThrows(GeneralSecurityException.class, () -> 
server.response(clientChallenge));
     }
   }
-
-  @Test
-  public void testEncryptedMessage() throws Exception {
-    try (AuthEngine client = new AuthEngine("appId", "secret", conf);
-         AuthEngine server = new AuthEngine("appId", "secret", conf)) {
-      AuthMessage clientChallenge = client.challenge();
-      AuthMessage serverResponse = server.response(clientChallenge);
-      client.deriveSessionCipher(clientChallenge, serverResponse);
-
-      TransportCipher cipher = server.sessionCipher();
-      TransportCipher.EncryptionHandler handler = new 
TransportCipher.EncryptionHandler(cipher);
-
-      byte[] data = new byte[TransportCipher.STREAM_BUFFER_SIZE + 1];
-      new Random().nextBytes(data);
-      ByteBuf buf = Unpooled.wrappedBuffer(data);
-
-      ByteArrayWritableChannel channel = new 
ByteArrayWritableChannel(data.length);
-      TransportCipher.EncryptedMessage emsg = 
handler.createEncryptedMessage(buf);
-      while (emsg.transferred() < emsg.count()) {
-        emsg.transferTo(channel, emsg.transferred());
-      }
-      assertEquals(data.length, channel.length());
-    }
-  }
-
-  @Test
-  public void testEncryptedMessageWhenTransferringZeroBytes() throws Exception 
{
-    try (AuthEngine client = new AuthEngine("appId", "secret", conf);
-         AuthEngine server = new AuthEngine("appId", "secret", conf)) {
-      AuthMessage clientChallenge = client.challenge();
-      AuthMessage serverResponse = server.response(clientChallenge);
-      client.deriveSessionCipher(clientChallenge, serverResponse);
-
-      TransportCipher cipher = server.sessionCipher();
-      TransportCipher.EncryptionHandler handler = new 
TransportCipher.EncryptionHandler(cipher);
-
-      int testDataLength = 4;
-      FileRegion region = mock(FileRegion.class);
-      when(region.count()).thenReturn((long) testDataLength);
-      // Make `region.transferTo` do nothing in first call and transfer 4 
bytes in the second one.
-      when(region.transferTo(any(), anyLong())).thenAnswer(new Answer<Long>() {
-
-        private boolean firstTime = true;
-
-        @Override
-        public Long answer(InvocationOnMock invocationOnMock) throws Throwable 
{
-          if (firstTime) {
-            firstTime = false;
-            return 0L;
-          } else {
-            WritableByteChannel channel = invocationOnMock.getArgument(0);
-            channel.write(ByteBuffer.wrap(new byte[testDataLength]));
-            return (long) testDataLength;
-          }
-        }
-      });
-
-      TransportCipher.EncryptedMessage emsg = 
handler.createEncryptedMessage(region);
-      ByteArrayWritableChannel channel = new 
ByteArrayWritableChannel(testDataLength);
-      // "transferTo" should act correctly when the underlying FileRegion 
transfers 0 bytes.
-      assertEquals(0L, emsg.transferTo(channel, emsg.transferred()));
-      assertEquals(testDataLength, emsg.transferTo(channel, 
emsg.transferred()));
-      assertEquals(emsg.transferred(), emsg.count());
-      assertEquals(4, channel.length());
-    }
-  }
 }
diff --git 
a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java
 
b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java
index 90f6c874a6c8..cb5929f7c65b 100644
--- 
a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java
+++ 
b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java
@@ -49,7 +49,7 @@ public class AuthIntegrationSuite {
   private AuthTestCtx ctx;
 
   @AfterEach
-  public void cleanUp() throws Exception {
+  public void cleanUp() {
     if (ctx != null) {
       ctx.close();
     }
@@ -57,8 +57,8 @@ public class AuthIntegrationSuite {
   }
 
   @Test
-  public void testNewAuth() throws Exception {
-    ctx = new AuthTestCtx();
+  public void testNewCtrAuth() throws Exception {
+    ctx = new AuthTestCtx(new DummyRpcHandler(), "AES/CTR/NoPadding");
     ctx.createServer("secret");
     ctx.createClient("secret");
 
@@ -68,8 +68,28 @@ public class AuthIntegrationSuite {
   }
 
   @Test
-  public void testAuthFailure() throws Exception {
-    ctx = new AuthTestCtx();
+  public void testNewGcmAuth() throws Exception {
+    ctx = new AuthTestCtx(new DummyRpcHandler(), "AES/GCM/NoPadding");
+    ctx.createServer("secret");
+    ctx.createClient("secret");
+    ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 
5000);
+    assertEquals("Pong", JavaUtils.bytesToString(reply));
+    assertNull(ctx.authRpcHandler.saslHandler);
+  }
+
+  @Test
+  public void testCtrAuthFailure() throws Exception {
+    ctx = new AuthTestCtx(new DummyRpcHandler(), "AES/CTR/NoPadding");
+    ctx.createServer("server");
+
+    assertThrows(Exception.class, () -> ctx.createClient("client"));
+    assertFalse(ctx.authRpcHandler.isAuthenticated());
+    assertFalse(ctx.serverChannel.isActive());
+  }
+
+  @Test
+  public void testGcmAuthFailure() throws Exception {
+    ctx = new AuthTestCtx(new DummyRpcHandler(), "AES/GCM/NoPadding");
     ctx.createServer("server");
 
     assertThrows(Exception.class, () -> ctx.createClient("client"));
@@ -100,7 +120,7 @@ public class AuthIntegrationSuite {
   }
 
   @Test
-  public void testAuthReplay() throws Exception {
+  public void testCtrAuthReplay() throws Exception {
     // This test covers the case where an attacker replays a challenge message 
sniffed from the
     // network, but doesn't know the actual secret. The server should close 
the connection as
     // soon as a message is sent after authentication is performed. This is 
emulated by removing
@@ -110,16 +130,16 @@ public class AuthIntegrationSuite {
     ctx.createClient("secret");
 
     assertNotNull(ctx.client.getChannel().pipeline()
-      .remove(TransportCipher.ENCRYPTION_HANDLER_NAME));
+      .remove(CtrTransportCipher.ENCRYPTION_HANDLER_NAME));
     assertThrows(Exception.class,
       () -> ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000));
     assertTrue(ctx.authRpcHandler.isAuthenticated());
   }
 
   @Test
-  public void testLargeMessageEncryption() throws Exception {
+  public void testLargeCtrMessageEncryption() throws Exception {
     // Use a big length to create a message that cannot be put into the 
encryption buffer completely
-    final int testErrorMessageLength = TransportCipher.STREAM_BUFFER_SIZE;
+    final int testErrorMessageLength = CtrTransportCipher.STREAM_BUFFER_SIZE;
     ctx = new AuthTestCtx(new RpcHandler() {
       @Override
       public void receive(
@@ -157,6 +177,23 @@ public class AuthIntegrationSuite {
     assertNotNull(ctx.authRpcHandler.getMergedBlockMetaReqHandler());
   }
 
+  /** Replies "Pong" to every RPC after asserting that the request payload is "Ping". */
+  private static class DummyRpcHandler extends RpcHandler {
+    @Override
+    public void receive(
+        TransportClient client,
+        ByteBuffer message,
+        RpcResponseCallback callback) {
+      assertEquals("Ping", JavaUtils.bytesToString(message));
+      callback.onSuccess(JavaUtils.stringToBytes("Pong"));
+    }
+
+    @Override
+    public StreamManager getStreamManager() {
+      return null;
+    }
+  }
+
   private static class AuthTestCtx {
 
     private final String appId = "testAppId";
@@ -169,25 +206,17 @@ public class AuthIntegrationSuite {
     volatile AuthRpcHandler authRpcHandler;
 
     AuthTestCtx() throws Exception {
-      this(new RpcHandler() {
-        @Override
-        public void receive(
-            TransportClient client,
-            ByteBuffer message,
-            RpcResponseCallback callback) {
-          assertEquals("Ping", JavaUtils.bytesToString(message));
-          callback.onSuccess(JavaUtils.stringToBytes("Pong"));
-        }
-
-        @Override
-        public StreamManager getStreamManager() {
-          return null;
-        }
-      });
+      // Default server handler: expects a "Ping" RPC and replies with "Pong".
+      this(new DummyRpcHandler());
     }
 
     AuthTestCtx(RpcHandler rpcHandler) throws Exception {
-      Map<String, String> testConf = 
ImmutableMap.of("spark.network.crypto.enabled", "true");
+      // Keep the historical default cipher so existing callers still exercise AES-CTR.
+      this(rpcHandler, "AES/CTR/NoPadding");
+    }
+
+    /**
+     * Creates a test context with RPC encryption enabled and the given cipher transformation
+     * stored under "spark.network.crypto.cipher" (e.g. "AES/CTR/NoPadding" or
+     * "AES/GCM/NoPadding").
+     *
+     * @param rpcHandler server-side RPC handler for the transport context
+     * @param mode cipher transformation to configure
+     */
+    AuthTestCtx(RpcHandler rpcHandler, String mode) throws Exception {
+      Map<String, String> testConf = ImmutableMap.of(
+          "spark.network.crypto.enabled", "true",
+          "spark.network.crypto.cipher", mode);
       this.conf = new TransportConf("rpc", new MapConfigProvider(testConf));
       this.ctx = new TransportContext(conf, rpcHandler);
     }
diff --git 
a/common/network-common/src/test/java/org/apache/spark/network/crypto/CtrAuthEngineSuite.java
 
b/common/network-common/src/test/java/org/apache/spark/network/crypto/CtrAuthEngineSuite.java
new file mode 100644
index 000000000000..c353ee392ff4
--- /dev/null
+++ 
b/common/network-common/src/test/java/org/apache/spark/network/crypto/CtrAuthEngineSuite.java
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import com.google.crypto.tink.subtle.Hex;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.FileRegion;
+import org.apache.spark.network.util.ByteArrayWritableChannel;
+import org.apache.spark.network.util.TransportConf;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import java.nio.ByteBuffer;
+import java.nio.channels.WritableByteChannel;
+import java.util.Random;
+
+import static org.junit.jupiter.api.Assertions.*;
+import static org.mockito.Mockito.*;
+
+/**
+ * {@link AuthEngine} tests that run with the AES-CTR session cipher
+ * ({@link CtrTransportCipher}).
+ */
+public class CtrAuthEngineSuite extends AuthEngineSuite {
+  // Expected IVs derived from the fixed client challenge / server response transcript
+  // used by the fixed-transcript tests below.
+  private static final String inputIv = "fc6a5dc8b90a9dad8f54f08b51a59ed2";
+  private static final String outputIv = "a72709baf00785cad6329ce09f631f71";
+
+  @BeforeAll
+  public static void setUp() {
+    // Auth engine version 2 in CTR mode (the tests below assert a CtrTransportCipher).
+    conf = getConf(2, true);
+  }
+
+  @Test
+  public void testAuthEngine() throws Exception {
+    try (AuthEngine client = new AuthEngine("appId", "secret", conf);
+         AuthEngine server = new AuthEngine("appId", "secret", conf)) {
+      AuthMessage clientChallenge = client.challenge();
+      AuthMessage serverResponse = server.response(clientChallenge);
+      client.deriveSessionCipher(clientChallenge, serverResponse);
+
+      // assertInstanceOf fails even when JVM assertions (-ea) are disabled, unlike `assert`.
+      CtrTransportCipher ctrServer =
+          assertInstanceOf(CtrTransportCipher.class, server.sessionCipher());
+      CtrTransportCipher ctrClient =
+          assertInstanceOf(CtrTransportCipher.class, client.sessionCipher());
+      // Each side's input IV must be the peer's output IV, and both sides share one key.
+      assertArrayEquals(ctrServer.getInputIv(), ctrClient.getOutputIv());
+      assertArrayEquals(ctrServer.getOutputIv(), ctrClient.getInputIv());
+      assertEquals(ctrServer.getKey(), ctrClient.getKey());
+    }
+  }
+
+  @Test
+  public void testCtrFixedChallengeIvResponse() throws Exception {
+    try (AuthEngine client = new AuthEngine("appId", "secret", conf)) {
+      byte[] clientPrivateKey = Hex.decode(clientPrivate);
+      client.setClientPrivateKey(clientPrivateKey);
+      AuthMessage clientChallenge =
+          AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex)));
+      AuthMessage serverResponse =
+          AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(serverResponseHex)));
+      // Verify that the client will accept an old transcript.
+      client.deriveSessionCipher(clientChallenge, serverResponse);
+      TransportCipher clientCipher = client.sessionCipher();
+      assertEquals(derivedKeyId, clientCipher.getKeyId());
+      CtrTransportCipher ctrTransportCipher =
+          assertInstanceOf(CtrTransportCipher.class, clientCipher);
+      // A fixed transcript must always yield the same IVs.
+      assertEquals(inputIv, Hex.encode(ctrTransportCipher.getInputIv()));
+      assertEquals(outputIv, Hex.encode(ctrTransportCipher.getOutputIv()));
+    }
+  }
+
+  @Test
+  public void testFixedChallengeResponseUnsafeVersion() throws Exception {
+    // Auth engine version 1 derives the key differently; see unsafeDerivedKey.
+    TransportConf v1Conf = getConf(1, true);
+    try (AuthEngine client = new AuthEngine("appId", "secret", v1Conf)) {
+      byte[] clientPrivateKey = Hex.decode(clientPrivate);
+      client.setClientPrivateKey(clientPrivateKey);
+      AuthMessage clientChallenge =
+          AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex)));
+      AuthMessage serverResponse =
+          AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(serverResponseHex)));
+      // Verify that the client will accept an old transcript.
+      client.deriveSessionCipher(clientChallenge, serverResponse);
+      TransportCipher clientCipher = client.sessionCipher();
+      CtrTransportCipher ctrTransportCipher =
+          assertInstanceOf(CtrTransportCipher.class, clientCipher);
+      assertEquals(unsafeDerivedKey, Hex.encode(ctrTransportCipher.getKey().getEncoded()));
+      assertEquals(inputIv, Hex.encode(ctrTransportCipher.getInputIv()));
+      assertEquals(outputIv, Hex.encode(ctrTransportCipher.getOutputIv()));
+    }
+  }
+
+  @Test
+  public void testCtrEncryptedMessage() throws Exception {
+    try (AuthEngine client = new AuthEngine("appId", "secret", conf);
+         AuthEngine server = new AuthEngine("appId", "secret", conf)) {
+      AuthMessage clientChallenge = client.challenge();
+      AuthMessage serverResponse = server.response(clientChallenge);
+      client.deriveSessionCipher(clientChallenge, serverResponse);
+
+      CtrTransportCipher serverCipher =
+          assertInstanceOf(CtrTransportCipher.class, server.sessionCipher());
+      CtrTransportCipher.EncryptionHandler handler =
+          new CtrTransportCipher.EncryptionHandler(serverCipher);
+
+      // One byte larger than STREAM_BUFFER_SIZE, so the message cannot be put into the
+      // encryption buffer in a single pass.
+      byte[] data = new byte[CtrTransportCipher.STREAM_BUFFER_SIZE + 1];
+      new Random().nextBytes(data);
+      ByteBuf buf = Unpooled.wrappedBuffer(data);
+
+      ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length);
+      CtrTransportCipher.EncryptedMessage emsg = handler.createEncryptedMessage(buf);
+      while (emsg.transferred() < emsg.count()) {
+        emsg.transferTo(channel, emsg.transferred());
+      }
+      // CTR is a stream mode: the ciphertext is exactly as long as the plaintext.
+      assertEquals(data.length, channel.length());
+    }
+  }
+
+  @Test
+  public void testCtrEncryptedMessageWhenTransferringZeroBytes() throws Exception {
+    try (AuthEngine client = new AuthEngine("appId", "secret", conf);
+         AuthEngine server = new AuthEngine("appId", "secret", conf)) {
+      AuthMessage clientChallenge = client.challenge();
+      AuthMessage serverResponse = server.response(clientChallenge);
+      client.deriveSessionCipher(clientChallenge, serverResponse);
+      CtrTransportCipher serverCipher =
+          assertInstanceOf(CtrTransportCipher.class, server.sessionCipher());
+      CtrTransportCipher.EncryptionHandler handler =
+          new CtrTransportCipher.EncryptionHandler(serverCipher);
+      int testDataLength = 4;
+      FileRegion region = mock(FileRegion.class);
+      when(region.count()).thenReturn((long) testDataLength);
+      // Make `region.transferTo` do nothing in the first call and transfer
+      // `testDataLength` bytes in the second one. The answer is stateful, so it stays an
+      // anonymous class rather than a lambda.
+      when(region.transferTo(any(), anyLong())).thenAnswer(new Answer<Long>() {
+
+        private boolean firstTime = true;
+
+        @Override
+        public Long answer(InvocationOnMock invocationOnMock) throws Throwable {
+          if (firstTime) {
+            firstTime = false;
+            return 0L;
+          } else {
+            WritableByteChannel channel = invocationOnMock.getArgument(0);
+            channel.write(ByteBuffer.wrap(new byte[testDataLength]));
+            return (long) testDataLength;
+          }
+        }
+      });
+
+      CtrTransportCipher.EncryptedMessage emsg = handler.createEncryptedMessage(region);
+      ByteArrayWritableChannel channel = new ByteArrayWritableChannel(testDataLength);
+      // "transferTo" should act correctly when the underlying FileRegion transfers 0 bytes.
+      assertEquals(0L, emsg.transferTo(channel, emsg.transferred()));
+      assertEquals(testDataLength, emsg.transferTo(channel, emsg.transferred()));
+      assertEquals(emsg.count(), emsg.transferred());
+      assertEquals(testDataLength, channel.length());
+    }
+  }
+}
diff --git 
a/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java
 
b/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java
new file mode 100644
index 000000000000..20efb8d57dcb
--- /dev/null
+++ 
b/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java
@@ -0,0 +1,339 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.crypto;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelPromise;
+import org.apache.spark.network.util.*;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
+
+import javax.crypto.AEADBadTagException;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.WritableByteChannel;
+import java.util.Arrays;
+
+import static org.junit.jupiter.api.Assertions.*;
+import static org.mockito.Mockito.*;
+
+/**
+ * {@link AuthEngine} tests that run with the AES-GCM session cipher
+ * ({@link GcmTransportCipher}), including detection of tampered ciphertext.
+ */
+public class GcmAuthEngineSuite extends AuthEngineSuite {
+
+  @BeforeAll
+  public static void setUp() {
+    // Uses GCM mode
+    conf = getConf(2, false);
+  }
+
+  @Test
+  public void testGcmEncryptedMessage() throws Exception {
+    TransportConf gcmConf = getConf(2, false);
+    try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf);
+         AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) {
+      AuthMessage clientChallenge = client.challenge();
+      AuthMessage serverResponse = server.response(clientChallenge);
+      client.deriveSessionCipher(clientChallenge, serverResponse);
+      // Verify that it derives a GcmTransportCipher. assertInstanceOf also works when JVM
+      // assertions (-ea) are disabled, unlike `assert`.
+      GcmTransportCipher serverCipher =
+          assertInstanceOf(GcmTransportCipher.class, server.sessionCipher());
+      GcmTransportCipher.EncryptionHandler encryptionHandler =
+          serverCipher.getEncryptionHandler();
+      GcmTransportCipher.DecryptionHandler decryptionHandler =
+          serverCipher.getDecryptionHandler();
+      // Allocate 1.5x the plaintext segment size to test one full segment and a fractional
+      // one. The 16 bytes subtracted presumably account for the GCM authentication tag.
+      int plaintextSegmentSize = GcmTransportCipher.CIPHERTEXT_BUFFER_SIZE - 16;
+      byte[] data = new byte[plaintextSegmentSize + (plaintextSegmentSize / 2)];
+      // Just writing some bytes.
+      data[0] = 'a';
+      data[data.length / 2] = 'b';
+      data[data.length - 10] = 'c';
+      ByteBuf buf = Unpooled.wrappedBuffer(data);
+
+      // Mock the context and capture the arguments passed to it
+      ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
+      ChannelPromise promise = mock(ChannelPromise.class);
+      ArgumentCaptor<GcmTransportCipher.GcmEncryptedMessage> captorWrappedEncrypted =
+          ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class);
+      encryptionHandler.write(ctx, buf, promise);
+      verify(ctx).write(captorWrappedEncrypted.capture(), eq(promise));
+
+      // Get the encrypted value and pass it to the decryption handler
+      GcmTransportCipher.GcmEncryptedMessage encrypted = captorWrappedEncrypted.getValue();
+      ByteBuffer ciphertextBuffer = ByteBuffer.allocate((int) encrypted.count());
+      ByteBufferWriteableChannel channel = new ByteBufferWriteableChannel(ciphertextBuffer);
+      encrypted.transferTo(channel, 0);
+      ciphertextBuffer.flip();
+      ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer);
+
+      // Capture the decrypted values and verify them: one full segment plus a half segment.
+      ArgumentCaptor<ByteBuf> captorPlaintext = ArgumentCaptor.forClass(ByteBuf.class);
+      decryptionHandler.channelRead(ctx, ciphertext);
+      verify(ctx, times(2)).fireChannelRead(captorPlaintext.capture());
+      ByteBuf lastPlaintextSegment = captorPlaintext.getValue();
+      assertEquals(plaintextSegmentSize / 2, lastPlaintextSegment.readableBytes());
+      assertEquals('c', lastPlaintextSegment.getByte((plaintextSegmentSize / 2) - 10));
+    }
+  }
+
+  /**
+   * In-memory {@link AbstractFileRegion} backed by a sequence of buffers. Each
+   * {@code transferTo} call writes from at most one source buffer, which lets tests
+   * exercise fragmented file-region reads. The {@code position} argument is ignored.
+   */
+  static class FakeRegion extends AbstractFileRegion {
+    private final ByteBuffer[] source;
+    private int sourcePosition;
+    private final long count;
+
+    FakeRegion(ByteBuffer... source) {
+      this.source = source;
+      sourcePosition = 0;
+      count = remaining();
+    }
+
+    /** Total bytes still unread across all source buffers. */
+    private long remaining() {
+      long remaining = 0;
+      for (ByteBuffer buffer : source) {
+        remaining += buffer.remaining();
+      }
+      return remaining;
+    }
+
+    @Override
+    public long position() {
+      return 0;
+    }
+
+    @Override
+    public long transferred() {
+      return count - remaining();
+    }
+
+    @Override
+    public long count() {
+      return count;
+    }
+
+    @Override
+    public long transferTo(WritableByteChannel target, long position) throws IOException {
+      if (sourcePosition < source.length) {
+        ByteBuffer currentBuffer = source[sourcePosition];
+        long written = target.write(currentBuffer);
+        // Move to the next buffer only once the current one is fully drained.
+        if (!currentBuffer.hasRemaining()) {
+          sourcePosition++;
+        }
+        return written;
+      } else {
+        return 0;
+      }
+    }
+
+    @Override
+    protected void deallocate() {
+    }
+  }
+
+  /** Returns a heap buffer of {@code size} bytes, each set to {@code fill}. */
+  private static ByteBuffer getTestByteBuf(int size, byte fill) {
+    byte[] data = new byte[size];
+    Arrays.fill(data, fill);
+    return ByteBuffer.wrap(data);
+  }
+
+  @Test
+  public void testGcmEncryptedMessageFileRegion() throws Exception {
+    TransportConf gcmConf = getConf(2, false);
+    try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf);
+         AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) {
+      AuthMessage clientChallenge = client.challenge();
+      AuthMessage serverResponse = server.response(clientChallenge);
+      client.deriveSessionCipher(clientChallenge, serverResponse);
+      // Verify that it derives a GcmTransportCipher
+      GcmTransportCipher serverCipher =
+          assertInstanceOf(GcmTransportCipher.class, server.sessionCipher());
+      GcmTransportCipher.EncryptionHandler encryptionHandler =
+          serverCipher.getEncryptionHandler();
+      GcmTransportCipher.DecryptionHandler decryptionHandler =
+          serverCipher.getDecryptionHandler();
+      // Allocate 1.5x the plaintext segment size to test one full segment and a fractional
+      // one. The 16 bytes subtracted presumably account for the GCM authentication tag.
+      int plaintextSegmentSize = GcmTransportCipher.CIPHERTEXT_BUFFER_SIZE - 16;
+      int halfSegmentSize = plaintextSegmentSize / 2;
+      int totalSize = plaintextSegmentSize + halfSegmentSize;
+
+      // Set up some fragmented segments to test: half a segment of 'a', a small fragment
+      // of 'b', and the remainder filled with 'c'.
+      ByteBuffer halfSegment = getTestByteBuf(halfSegmentSize, (byte) 'a');
+      int smallFragmentSize = 128;
+      ByteBuffer smallFragment = getTestByteBuf(smallFragmentSize, (byte) 'b');
+      int remainderSize = totalSize - halfSegmentSize - smallFragmentSize;
+      ByteBuffer remainder = getTestByteBuf(remainderSize, (byte) 'c');
+      FakeRegion fakeRegion = new FakeRegion(halfSegment, smallFragment, remainder);
+      assertEquals(totalSize, fakeRegion.count());
+
+      ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
+      ChannelPromise promise = mock(ChannelPromise.class);
+      ArgumentCaptor<GcmTransportCipher.GcmEncryptedMessage> captorWrappedEncrypted =
+          ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class);
+      encryptionHandler.write(ctx, fakeRegion, promise);
+      verify(ctx).write(captorWrappedEncrypted.capture(), eq(promise));
+
+      // Get the encrypted value and pass it to the decryption handler
+      GcmTransportCipher.GcmEncryptedMessage encrypted = captorWrappedEncrypted.getValue();
+      ByteBuffer ciphertextBuffer = ByteBuffer.allocate((int) encrypted.count());
+      ByteBufferWriteableChannel channel = new ByteBufferWriteableChannel(ciphertextBuffer);
+
+      // FakeRegion hands over at most one source buffer per transferTo call, so plaintext
+      // arrives in pieces that do not line up with segment boundaries. The encrypted message
+      // is expected to buffer partial segment plaintext until it can emit a segment.
+      long ciphertextTransferred = 0;
+      while (ciphertextTransferred < encrypted.count()) {
+        long chunkTransferred = encrypted.transferTo(channel, 0);
+        ciphertextTransferred += chunkTransferred;
+      }
+      assertEquals(encrypted.count(), ciphertextTransferred);
+
+      ciphertextBuffer.flip();
+      ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer);
+
+      // Capture the decrypted values and verify them
+      ArgumentCaptor<ByteBuf> captorPlaintext = ArgumentCaptor.forClass(ByteBuf.class);
+      decryptionHandler.channelRead(ctx, ciphertext);
+      verify(ctx, times(2)).fireChannelRead(captorPlaintext.capture());
+      ByteBuf plaintext = captorPlaintext.getValue();
+      // We expect this to be the last partial plaintext segment
+      int expectedLength = totalSize % plaintextSegmentSize;
+      assertEquals(expectedLength, plaintext.readableBytes());
+      // This will be the "remainder" segment that is filled with 'c'
+      assertEquals('c', plaintext.getByte(0));
+    }
+  }
+
+  @Test
+  public void testGcmUnalignedDecryption() throws Exception {
+    TransportConf gcmConf = getConf(2, false);
+    try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf);
+         AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) {
+      AuthMessage clientChallenge = client.challenge();
+      AuthMessage serverResponse = server.response(clientChallenge);
+      client.deriveSessionCipher(clientChallenge, serverResponse);
+      // Verify that it derives a GcmTransportCipher
+      GcmTransportCipher serverCipher =
+          assertInstanceOf(GcmTransportCipher.class, server.sessionCipher());
+      GcmTransportCipher.EncryptionHandler encryptionHandler =
+          serverCipher.getEncryptionHandler();
+      GcmTransportCipher.DecryptionHandler decryptionHandler =
+          serverCipher.getDecryptionHandler();
+      // Allocate 1.5x the plaintext segment size to test one full segment and a fractional
+      // one. The 16 bytes subtracted presumably account for the GCM authentication tag.
+      int plaintextSegmentSize = GcmTransportCipher.CIPHERTEXT_BUFFER_SIZE - 16;
+      int plaintextSize = plaintextSegmentSize + (plaintextSegmentSize / 2);
+      byte[] data = new byte[plaintextSize];
+      Arrays.fill(data, (byte) 'x');
+      ByteBuf buf = Unpooled.wrappedBuffer(data);
+
+      // Mock the context and capture the arguments passed to it
+      ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
+      ChannelPromise promise = mock(ChannelPromise.class);
+      ArgumentCaptor<GcmTransportCipher.GcmEncryptedMessage> captorWrappedEncrypted =
+          ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class);
+      encryptionHandler.write(ctx, buf, promise);
+      verify(ctx).write(captorWrappedEncrypted.capture(), eq(promise));
+
+      // Get the encrypted value and pass it to the decryption handler
+      GcmTransportCipher.GcmEncryptedMessage encrypted = captorWrappedEncrypted.getValue();
+      ByteBuffer ciphertextBuffer = ByteBuffer.allocate((int) encrypted.count());
+      ByteBufferWriteableChannel channel = new ByteBufferWriteableChannel(ciphertextBuffer);
+      encrypted.transferTo(channel, 0);
+      ciphertextBuffer.flip();
+      ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer);
+
+      // The first two readableBytes() calls report only a partial length, so the decryption
+      // handler starts from a chunk that is not aligned with ciphertext segment boundaries.
+      int firstChunkSize = plaintextSize / 2;
+      ByteBuf mockCiphertext = spy(ciphertext);
+      when(mockCiphertext.readableBytes())
+          .thenReturn(firstChunkSize, firstChunkSize).thenCallRealMethod();
+
+      // Capture the decrypted values and verify them
+      ArgumentCaptor<ByteBuf> captorPlaintext = ArgumentCaptor.forClass(ByteBuf.class);
+      decryptionHandler.channelRead(ctx, mockCiphertext);
+      verify(ctx, times(2)).fireChannelRead(captorPlaintext.capture());
+      ByteBuf lastPlaintextSegment = captorPlaintext.getValue();
+      assertEquals(plaintextSegmentSize / 2, lastPlaintextSegment.readableBytes());
+      assertEquals('x', lastPlaintextSegment.getByte((plaintextSegmentSize / 2) - 10));
+    }
+  }
+
+  @Test
+  public void testCorruptGcmEncryptedMessage() throws Exception {
+    TransportConf gcmConf = getConf(2, false);
+
+    try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf);
+         AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) {
+      AuthMessage clientChallenge = client.challenge();
+      AuthMessage serverResponse = server.response(clientChallenge);
+      client.deriveSessionCipher(clientChallenge, serverResponse);
+
+      GcmTransportCipher serverCipher =
+          assertInstanceOf(GcmTransportCipher.class, server.sessionCipher());
+      GcmTransportCipher.EncryptionHandler encryptionHandler =
+          serverCipher.getEncryptionHandler();
+      GcmTransportCipher.DecryptionHandler decryptionHandler =
+          serverCipher.getDecryptionHandler();
+      // 32 KiB of zero bytes as plaintext.
+      byte[] zeroData = new byte[1024 * 32];
+      ByteBuf buf = Unpooled.wrappedBuffer(zeroData);
+
+      // Mock the context and capture the arguments passed to it
+      ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
+      ChannelPromise promise = mock(ChannelPromise.class);
+      ArgumentCaptor<GcmTransportCipher.GcmEncryptedMessage> captorWrappedEncrypted =
+          ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class);
+      encryptionHandler.write(ctx, buf, promise);
+      verify(ctx).write(captorWrappedEncrypted.capture(), eq(promise));
+
+      GcmTransportCipher.GcmEncryptedMessage encrypted = captorWrappedEncrypted.getValue();
+      ByteBuffer ciphertextBuffer = ByteBuffer.allocate((int) encrypted.count());
+      ByteBufferWriteableChannel channel = new ByteBufferWriteableChannel(ciphertextBuffer);
+      encrypted.transferTo(channel, 0);
+      ciphertextBuffer.flip();
+      ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer);
+
+      // Flip every bit of the byte at index 100. GCM authentication must reject the
+      // tampered ciphertext.
+      byte b = ciphertext.getByte(100);
+      ciphertext.setByte(100, ~b & 0xFF);
+      assertThrows(AEADBadTagException.class,
+          () -> decryptionHandler.channelRead(ctx, ciphertext));
+    }
+  }
+}
diff --git 
a/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java
 
b/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java
index da62d3b2de31..8977f29034fe 100644
--- 
a/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java
+++ 
b/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java
@@ -41,10 +41,10 @@ import static org.mockito.Mockito.when;
 public class TransportCipherSuite {
 
   @Test
-  public void testBufferNotLeaksOnInternalError() throws IOException {
+  public void testCtrBufferNotLeaksOnInternalError() throws IOException {
     String algorithm = "TestAlgorithm";
     TransportConf conf = new TransportConf("Test", MapConfigProvider.EMPTY);
-    TransportCipher cipher = new TransportCipher(conf.cryptoConf(), 
conf.cipherTransformation(),
+    CtrTransportCipher cipher = new CtrTransportCipher(conf.cryptoConf(),
       new SecretKeySpec(new byte[256], algorithm), new byte[0], new byte[0]) {
 
       @Override
diff --git a/docs/security.md b/docs/security.md
index 4c73c749788f..1b5dcb383645 100644
--- a/docs/security.md
+++ b/docs/security.md
@@ -207,6 +207,15 @@ The following table describes the different options 
available for configuring th
   </td>
   <td>2.2.0</td>
 </tr>
+<tr>
+  <td><code>spark.network.crypto.cipher</code></td>
+  <td>AES/CTR/NoPadding</td>
+  <td>
+    Cipher mode to use. Defaults to "AES/CTR/NoPadding" for backward compatibility; note that
+    this mode is not authenticated. Using "AES/GCM/NoPadding", an authenticated encryption
+    mode, is recommended.
+  </td>
+  <td>4.0.0</td>
+</tr>
 <tr>
   <td><code>spark.network.crypto.authEngineVersion</code></td>
   <td>1</td>


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to