This is an automated email from the ASF dual-hosted git repository.

nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 3ff8812cd [CELEBORN-1348] Update infrastructure for SSL communication
3ff8812cd is described below

commit 3ff8812cdd43d7c1295ab99de40e9cbbd2ffd8d2
Author: Mridul Muralidharan <mridulatgmail.com>
AuthorDate: Mon Apr 1 19:59:44 2024 +0800

    [CELEBORN-1348] Update infrastructure for SSL communication
    
    ### What changes were proposed in this pull request?
    
    Update infrastructure for SSL support.
    Please see #2416 for the consolidated PR with all the changes for reference.
    
    ### Why are the changes needed?
    
    At a high level, the changes are:
    * Add `ManagedBuffer.convertToNettyForSsl`, which converts a buffer into a
      Netty object that can be written with SSL encryption (no zero-copy).
    * Add `EncryptedMessageWithHeader`, which wraps the message header and
      body for use with SSL.
    * Add `SslMessageEncoder`, an encoder for messages sent over SSL.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    
    The overall PR #2416 (and this PR as well) passes all tests, and this PR
    includes the relevant subset of tests.
    
    Closes #2427 from mridulm/update-infra-for-ssl.
    
    Authored-by: Mridul Muralidharan <mridulatgmail.com>
    Signed-off-by: SteNicholas <[email protected]>
---
 LICENSE                                            |   3 +
 .../flink/buffer/FlinkNettyManagedBuffer.java      |   6 +
 .../network/buffer/FileSegmentManagedBuffer.java   |   7 +
 .../common/network/buffer/ManagedBuffer.java       |  11 ++
 .../common/network/buffer/NettyManagedBuffer.java  |   5 +
 .../common/network/buffer/NioManagedBuffer.java    |   5 +
 .../protocol/EncryptedMessageWithHeader.java       | 149 +++++++++++++++++++
 .../common/network/protocol/SslMessageEncoder.java | 105 ++++++++++++++
 .../celeborn/common/network/TestManagedBuffer.java |   5 +
 .../protocol/EncryptedMessageWithHeaderSuiteJ.java | 160 +++++++++++++++++++++
 10 files changed, 456 insertions(+)

diff --git a/LICENSE b/LICENSE
index 76555a026..d5f68e4d7 100644
--- a/LICENSE
+++ b/LICENSE
@@ -212,11 +212,14 @@ Apache License 2.0
 Apache Spark
 
./client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
 
./client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
+./common/src/main/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeader.java
+./common/src/main/java/org/apache/celeborn/common/network/protocol/SslMessageEncoder.java
 
./common/src/main/java/org/apache/celeborn/common/network/ssl/ReloadingX509TrustManager.java
 ./common/src/main/java/org/apache/celeborn/common/network/util/NettyLogger.java
 ./common/src/main/java/org/apache/celeborn/common/unsafe/Platform.java
 ./common/src/main/java/org/apache/celeborn/common/util/JavaUtils.java
 ./common/src/main/scala/org/apache/celeborn/common/util/SignalUtils.scala
+./common/src/test/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeaderSuiteJ.java
 
./common/src/test/java/org/apache/celeborn/common/network/ssl/ReloadingX509TrustManagerSuiteJ.java
 
./common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java
 
./worker/src/main/java/org/apache/celeborn/service/deploy/worker/shuffledb/DB.java
diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/FlinkNettyManagedBuffer.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/FlinkNettyManagedBuffer.java
index aeb736a39..e3add9117 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/FlinkNettyManagedBuffer.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/FlinkNettyManagedBuffer.java
@@ -17,6 +17,7 @@
 
 package org.apache.celeborn.plugin.flink.buffer;
 
+import java.io.IOException;
 import java.io.InputStream;
 import java.nio.ByteBuffer;
 
@@ -64,4 +65,9 @@ public class FlinkNettyManagedBuffer extends ManagedBuffer {
   public Object convertToNetty() {
     return buf.duplicate().retain();
   }
+
+  @Override
+  public Object convertToNettyForSsl() throws IOException {
+    return buf.duplicate().retain();
+  }
 }
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/buffer/FileSegmentManagedBuffer.java
 
b/common/src/main/java/org/apache/celeborn/common/network/buffer/FileSegmentManagedBuffer.java
index 6af9e4305..5d11e8780 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/buffer/FileSegmentManagedBuffer.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/buffer/FileSegmentManagedBuffer.java
@@ -24,6 +24,7 @@ import java.nio.file.StandardOpenOption;
 
 import com.google.common.io.ByteStreams;
 import io.netty.channel.DefaultFileRegion;
+import io.netty.handler.stream.ChunkedStream;
 import org.apache.commons.lang3.builder.ToStringBuilder;
 import org.apache.commons.lang3.builder.ToStringStyle;
 
@@ -132,6 +133,12 @@ public final class FileSegmentManagedBuffer extends 
ManagedBuffer {
     }
   }
 
+  @Override
+  public Object convertToNettyForSsl() throws IOException {
+    // Cannot use zero-copy with SSL
+    return new ChunkedStream(createInputStream(), 
conf.maxSslEncryptedBlockSize());
+  }
+
   public File getFile() {
     return file;
   }
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/buffer/ManagedBuffer.java
 
b/common/src/main/java/org/apache/celeborn/common/network/buffer/ManagedBuffer.java
index ce320d9d7..9ab05781b 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/buffer/ManagedBuffer.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/buffer/ManagedBuffer.java
@@ -71,4 +71,15 @@ public abstract class ManagedBuffer {
    * the caller will be responsible for releasing this new reference.
    */
   public abstract Object convertToNetty() throws IOException;
+
+  /**
+   * Convert the buffer into a Netty object, used to write the data out with 
SSL encryption, which
+   * cannot use {@link io.netty.channel.FileRegion}. The return value is 
either a {@link
+   * io.netty.buffer.ByteBuf}, a {@link 
io.netty.handler.stream.ChunkedStream}, or a {@link
+   * java.io.InputStream}.
+   *
+   * <p>If this method returns a ByteBuf, then that buffer's reference count 
will be incremented and
+   * the caller will be responsible for releasing this new reference.
+   */
+  public abstract Object convertToNettyForSsl() throws IOException;
 }
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/buffer/NettyManagedBuffer.java
 
b/common/src/main/java/org/apache/celeborn/common/network/buffer/NettyManagedBuffer.java
index 60cf8625b..0528c8c74 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/buffer/NettyManagedBuffer.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/buffer/NettyManagedBuffer.java
@@ -76,6 +76,11 @@ public class NettyManagedBuffer extends ManagedBuffer {
     return buf.duplicate().retain();
   }
 
+  @Override
+  public Object convertToNettyForSsl() throws IOException {
+    return buf.duplicate().retain();
+  }
+
   @Override
   public String toString() {
     return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/buffer/NioManagedBuffer.java
 
b/common/src/main/java/org/apache/celeborn/common/network/buffer/NioManagedBuffer.java
index b14cb1f8e..97c31ef2c 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/buffer/NioManagedBuffer.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/buffer/NioManagedBuffer.java
@@ -64,6 +64,11 @@ public class NioManagedBuffer extends ManagedBuffer {
     return Unpooled.wrappedBuffer(buf);
   }
 
+  @Override
+  public Object convertToNettyForSsl() throws IOException {
+    return Unpooled.wrappedBuffer(buf);
+  }
+
   @Override
   public String toString() {
     return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeader.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeader.java
new file mode 100644
index 000000000..df2ab1a92
--- /dev/null
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeader.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.protocol;
+
+import java.io.EOFException;
+import java.io.InputStream;
+
+import javax.annotation.Nullable;
+
+import com.google.common.base.Preconditions;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufAllocator;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.stream.ChunkedInput;
+import io.netty.handler.stream.ChunkedStream;
+
+import org.apache.celeborn.common.network.buffer.ManagedBuffer;
+
+/**
+ * A wrapper message that holds two separate pieces (a header and a body).
+ *
+ * <p>The header must be a ByteBuf, while the body can be any InputStream or 
ChunkedStream. Based on
+ * 
common/network-common/org.apache.spark.network.protocol.EncryptedMessageWithHeader
+ */
+public class EncryptedMessageWithHeader implements ChunkedInput<ByteBuf> {
+
+  @Nullable private final ManagedBuffer managedBuffer;
+  private final ByteBuf header;
+  private final int headerLength;
+  private final Object body;
+  private final long bodyLength;
+  private long totalBytesTransferred;
+
+  /**
+   * Construct a new EncryptedMessageWithHeader.
+   *
+   * @param managedBuffer the {@link ManagedBuffer} that the message body came 
from. This needs to
+   *     be passed in so that the buffer can be freed when this message is 
deallocated. Ownership of
+   *     the caller's reference to this buffer is transferred to this class, 
so if the caller wants
+   *     to continue to use the ManagedBuffer in other messages then they will 
need to call retain()
+   *     on it before passing it to this constructor.
+   * @param header the message header.
+   * @param body the message body.
+   * @param bodyLength the length of the message body, in bytes.
+   */
+  public EncryptedMessageWithHeader(
+      @Nullable ManagedBuffer managedBuffer, ByteBuf header, Object body, long 
bodyLength) {
+    Preconditions.checkArgument(
+        body instanceof InputStream || body instanceof ChunkedStream,
+        "Body must be an InputStream or a ChunkedStream.");
+    this.managedBuffer = managedBuffer;
+    this.header = header;
+    this.headerLength = header.readableBytes();
+    this.body = body;
+    this.bodyLength = bodyLength;
+    this.totalBytesTransferred = 0;
+  }
+
+  @Override
+  public ByteBuf readChunk(ChannelHandlerContext ctx) throws Exception {
+    return readChunk(ctx.alloc());
+  }
+
+  @Override
+  public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception {
+    if (isEndOfInput()) {
+      return null;
+    }
+
+    if (totalBytesTransferred < headerLength) {
+      totalBytesTransferred += headerLength;
+      return header.retain();
+    } else if (body instanceof InputStream) {
+      InputStream stream = (InputStream) body;
+      int available = stream.available();
+      if (available <= 0) {
+        available = (int) (length() - totalBytesTransferred);
+      } else {
+        available = (int) Math.min(available, length() - 
totalBytesTransferred);
+      }
+      ByteBuf buffer = allocator.buffer(available);
+      int toRead = Math.min(available, buffer.writableBytes());
+      int read = buffer.writeBytes(stream, toRead);
+      if (read >= 0) {
+        totalBytesTransferred += read;
+        return buffer;
+      } else {
+        throw new EOFException("Unable to read bytes from InputStream");
+      }
+    } else if (body instanceof ChunkedStream) {
+      ChunkedStream stream = (ChunkedStream) body;
+      long old = stream.transferredBytes();
+      ByteBuf buffer = stream.readChunk(allocator);
+      long read = stream.transferredBytes() - old;
+      if (read >= 0) {
+        totalBytesTransferred += read;
+        assert (totalBytesTransferred <= length());
+        return buffer;
+      } else {
+        throw new EOFException("Unable to read bytes from ChunkedStream");
+      }
+    } else {
+      return null;
+    }
+  }
+
+  @Override
+  public long length() {
+    return headerLength + bodyLength;
+  }
+
+  @Override
+  public long progress() {
+    return totalBytesTransferred;
+  }
+
+  @Override
+  public boolean isEndOfInput() throws Exception {
+    return (headerLength + bodyLength) == totalBytesTransferred;
+  }
+
+  @Override
+  public void close() throws Exception {
+    header.release();
+    if (managedBuffer != null) {
+      managedBuffer.release();
+    }
+    if (body instanceof InputStream) {
+      ((InputStream) body).close();
+    } else if (body instanceof ChunkedStream) {
+      ((ChunkedStream) body).close();
+    }
+  }
+}
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/SslMessageEncoder.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/SslMessageEncoder.java
new file mode 100644
index 000000000..508b6a13d
--- /dev/null
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/SslMessageEncoder.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.protocol;
+
+import java.io.InputStream;
+import java.util.List;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.MessageToMessageEncoder;
+import io.netty.handler.stream.ChunkedStream;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Encoder used by the server side to encode secure (SSL) server-to-client 
responses. This encoder
+ * is stateless so it is safe to be shared by multiple threads. Based on
+ * common/network-common/org.apache.spark.network.protocol.SslMessageEncoder
+ */
+@ChannelHandler.Sharable
+public final class SslMessageEncoder extends MessageToMessageEncoder<Message> {
+
+  private static final Logger logger = 
LoggerFactory.getLogger(SslMessageEncoder.class);
+  public static final SslMessageEncoder INSTANCE = new SslMessageEncoder();
+
+  private SslMessageEncoder() {}
+
+  /**
+   * Encodes a Message by invoking its encode() method. For non-data messages, 
we will add one
+   * ByteBuf to 'out' containing the total frame length, the message type, and 
the message itself.
+   * In the case of a ChunkFetchSuccess, we will also add the ManagedBuffer 
corresponding to the
+   * data to 'out'.
+   */
+  @Override
+  public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) 
throws Exception {
+    Object body = null;
+    int bodyLength = 0;
+
+    // If the message has a body, take it out...
+    // For SSL, zero-copy transfer will not work, so we will check if
+    // the body is an InputStream, and if so, use an EncryptedMessageWithHeader
+    // to wrap the header+body appropriately (for thread safety).
+    if (in.body() != null) {
+      try {
+        bodyLength = (int) in.body().size();
+        body = in.body().convertToNettyForSsl();
+      } catch (Exception e) {
+        in.body().release();
+        if (in instanceof ResponseMessage) {
+          ResponseMessage resp = (ResponseMessage) in;
+          // Re-encode this message as a failure response.
+          String error = e.getMessage() != null ? e.getMessage() : "null";
+          logger.error(
+              String.format("Error processing %s for client %s", in, 
ctx.channel().remoteAddress()),
+              e);
+          encode(ctx, resp.createFailureResponse(error), out);
+        } else {
+          throw e;
+        }
+        return;
+      }
+    }
+
+    Message.Type msgType = in.type();
+    // message size, message type size, body size, message encoded length
+    int headerLength = 4 + msgType.encodedLength() + 4 + in.encodedLength();
+    ByteBuf header = ctx.alloc().heapBuffer(headerLength);
+    header.writeInt(in.encodedLength());
+    msgType.encode(header);
+    header.writeInt(bodyLength);
+    in.encode(header);
+    assert header.writableBytes() == 0;
+
+    if (body != null && bodyLength > 0) {
+      if (body instanceof ByteBuf) {
+        out.add(Unpooled.wrappedBuffer(header, (ByteBuf) body));
+      } else if (body instanceof InputStream || body instanceof ChunkedStream) 
{
+        // For now, assume the InputStream is doing proper chunking.
+        out.add(new EncryptedMessageWithHeader(in.body(), header, body, 
bodyLength));
+      } else {
+        throw new IllegalArgumentException(
+            "Body must be a ByteBuf, ChunkedStream or an InputStream");
+      }
+    } else {
+      out.add(header);
+    }
+  }
+}
diff --git 
a/common/src/test/java/org/apache/celeborn/common/network/TestManagedBuffer.java
 
b/common/src/test/java/org/apache/celeborn/common/network/TestManagedBuffer.java
index b5f196fe2..ad3cc4521 100644
--- 
a/common/src/test/java/org/apache/celeborn/common/network/TestManagedBuffer.java
+++ 
b/common/src/test/java/org/apache/celeborn/common/network/TestManagedBuffer.java
@@ -79,6 +79,11 @@ public class TestManagedBuffer extends ManagedBuffer {
     return underlying.convertToNetty();
   }
 
+  @Override
+  public Object convertToNettyForSsl() throws IOException {
+    return underlying.convertToNettyForSsl();
+  }
+
   @Override
   public int hashCode() {
     return underlying.hashCode();
diff --git 
a/common/src/test/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeaderSuiteJ.java
 
b/common/src/test/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeaderSuiteJ.java
new file mode 100644
index 000000000..0fbf7e9c9
--- /dev/null
+++ 
b/common/src/test/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeaderSuiteJ.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.protocol;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+
+import java.io.ByteArrayInputStream;
+import java.io.InputStream;
+import java.util.Random;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufAllocator;
+import io.netty.buffer.Unpooled;
+import io.netty.handler.stream.ChunkedStream;
+import org.junit.Test;
+
+import org.apache.celeborn.common.network.buffer.ManagedBuffer;
+import org.apache.celeborn.common.network.buffer.NettyManagedBuffer;
+
+/*
+ * Based on 
common/network-common/org.apache.spark.network.protocol.EncryptedMessageWithHeaderSuite
+ */
+public class EncryptedMessageWithHeaderSuiteJ {
+
+  // Tests the case where the body is an input stream and that we manage the 
refcounts of the
+  // buffer properly
+  @Test
+  public void testInputStreamBodyFromManagedBuffer() throws Exception {
+    byte[] randomData = new byte[128];
+    new Random().nextBytes(randomData);
+    ByteBuf sourceBuffer = Unpooled.copiedBuffer(randomData);
+    InputStream body = new ByteArrayInputStream(sourceBuffer.array());
+    ByteBuf header = Unpooled.copyLong(42);
+
+    long expectedHeaderValue = header.getLong(header.readerIndex());
+    assertEquals(1, header.refCnt());
+    assertEquals(1, sourceBuffer.refCnt());
+    ManagedBuffer managedBuf = new NettyManagedBuffer(sourceBuffer);
+
+    EncryptedMessageWithHeader msg =
+        new EncryptedMessageWithHeader(managedBuf, header, body, 
managedBuf.size());
+    ByteBufAllocator allocator = ByteBufAllocator.DEFAULT;
+
+    // First read should just read the header
+    ByteBuf headerResult = msg.readChunk(allocator);
+    assertEquals(header.capacity(), headerResult.readableBytes());
+    assertEquals(expectedHeaderValue, headerResult.readLong());
+    assertEquals(header.capacity(), msg.progress());
+    assertFalse(msg.isEndOfInput());
+
+    // Second read should read the body
+    ByteBuf bodyResult = msg.readChunk(allocator);
+    assertEquals(randomData.length + header.capacity(), msg.progress());
+    assertTrue(msg.isEndOfInput());
+
+    // Validate we read it all
+    assertEquals(bodyResult.readableBytes(), randomData.length);
+    for (int i = 0; i < randomData.length; i++) {
+      assertEquals(bodyResult.readByte(), randomData[i]);
+    }
+
+    // Closing the message should release the source buffer
+    msg.close();
+    assertEquals(0, sourceBuffer.refCnt());
+
+    // The header still has a reference we got
+    assertEquals(1, header.refCnt());
+    headerResult.release();
+    assertEquals(0, header.refCnt());
+  }
+
+  // Tests the case where the body is a chunked stream and that we are fine 
when there is no
+  // input managed buffer
+  @Test
+  public void testChunkedStream() throws Exception {
+    int bodyLength = 129;
+    int chunkSize = 64;
+    byte[] randomData = new byte[bodyLength];
+    new Random().nextBytes(randomData);
+    InputStream inputStream = new ByteArrayInputStream(randomData);
+    ChunkedStream body = new ChunkedStream(inputStream, chunkSize);
+    ByteBuf header = Unpooled.copyLong(42);
+
+    long expectedHeaderValue = header.getLong(header.readerIndex());
+    assertEquals(1, header.refCnt());
+
+    EncryptedMessageWithHeader msg = new EncryptedMessageWithHeader(null, 
header, body, bodyLength);
+    ByteBufAllocator allocator = ByteBufAllocator.DEFAULT;
+
+    // First read should just read the header
+    ByteBuf headerResult = msg.readChunk(allocator);
+    assertEquals(header.capacity(), headerResult.readableBytes());
+    assertEquals(expectedHeaderValue, headerResult.readLong());
+    assertEquals(header.capacity(), msg.progress());
+    assertFalse(msg.isEndOfInput());
+
+    // Next 2 reads should read full buffers
+    int readIndex = 0;
+    for (int i = 1; i <= 2; i++) {
+      ByteBuf bodyResult = msg.readChunk(allocator);
+      assertEquals(header.capacity() + (i * chunkSize), msg.progress());
+      assertFalse(msg.isEndOfInput());
+
+      // Validate we read data correctly
+      assertEquals(bodyResult.readableBytes(), chunkSize);
+      assert (bodyResult.readableBytes() < (randomData.length - readIndex));
+      while (bodyResult.readableBytes() > 0) {
+        assertEquals(bodyResult.readByte(), randomData[readIndex++]);
+      }
+    }
+
+    // Last read should be partial
+    ByteBuf bodyResult = msg.readChunk(allocator);
+    assertEquals(header.capacity() + bodyLength, msg.progress());
+    assertTrue(msg.isEndOfInput());
+
+    // Validate we read the byte properly
+    assertEquals(bodyResult.readableBytes(), 1);
+    assertEquals(bodyResult.readByte(), randomData[readIndex]);
+
+    // Closing the message should close the input stream
+    msg.close();
+    assertTrue(body.isEndOfInput());
+
+    // The header still has a reference we got
+    assertEquals(1, header.refCnt());
+    headerResult.release();
+    assertEquals(0, header.refCnt());
+  }
+
+  @Test
+  public void testByteBufIsNotSupported() throws Exception {
+    // Validate that ByteBufs are not supported. This test can be updated
+    // when we add support for them
+    ByteBuf header = Unpooled.copyLong(42);
+    assertThrows(
+        IllegalArgumentException.class,
+        () -> {
+          EncryptedMessageWithHeader msg = new 
EncryptedMessageWithHeader(null, header, header, 4);
+        });
+  }
+}

Reply via email to