This is an automated email from the ASF dual-hosted git repository.
mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 41420ca1033 [SPARK-45429][CORE] Add helper classes for SSL RPC communication
41420ca1033 is described below
commit 41420ca10336d6dc876337a5fb562c5e68ffe8f9
Author: Hasnain Lakhani <[email protected]>
AuthorDate: Sun Oct 15 00:45:44 2023 -0500
[SPARK-45429][CORE] Add helper classes for SSL RPC communication
### What changes were proposed in this pull request?
This PR adds helper classes for SSL RPC communication. They are needed to
work around the fact that `netty`'s encryption support does not allow
zero-copy transfers. These classes mirror the existing `MessageWithHeader`
and `MessageEncoder` classes with very minor differences, but the differences
were just large enough that refactoring or consolidating them did not seem
worthwhile, and since these classes are not expected to change much, the
duplication should be acceptable.
### Why are the changes needed?
These classes are needed to convert `ManagedBuffer`s into a form that
`netty` can transfer over the network, since netty's encryption support
does not allow zero-copy transfers.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added unit tests
```
build/sbt
> project network-common
> testOnly org.apache.spark.network.protocol.EncryptedMessageWithHeaderSuite
```
The rest of the changes and integration were tested as part of
https://github.com/apache/spark/pull/42685
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #43244 from hasnain-db/spark-tls-helpers.
Authored-by: Hasnain Lakhani <[email protected]>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
.../protocol/EncryptedMessageWithHeader.java | 148 ++++++++++++++++++++
.../spark/network/protocol/SslMessageEncoder.java | 108 +++++++++++++++
.../protocol/EncryptedMessageWithHeaderSuite.java | 154 +++++++++++++++++++++
3 files changed, 410 insertions(+)
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/EncryptedMessageWithHeader.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/EncryptedMessageWithHeader.java
new file mode 100644
index 00000000000..7e7ba85ebf6
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/EncryptedMessageWithHeader.java
@@ -0,0 +1,148 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network.protocol;
+
+import java.io.EOFException;
+import java.io.InputStream;
+import javax.annotation.Nullable;
+
+import com.google.common.base.Preconditions;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufAllocator;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.stream.ChunkedStream;
+import io.netty.handler.stream.ChunkedInput;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * A wrapper message that holds two separate pieces (a header and a body).
+ *
+ * The header must be a ByteBuf, while the body can be any InputStream or ChunkedStream
+ */
+public class EncryptedMessageWithHeader implements ChunkedInput<ByteBuf> {
+
+ @Nullable private final ManagedBuffer managedBuffer;
+ private final ByteBuf header;
+ private final int headerLength;
+ private final Object body;
+ private final long bodyLength;
+ private long totalBytesTransferred;
+
+ /**
+ * Construct a new EncryptedMessageWithHeader.
+ *
+ * @param managedBuffer the {@link ManagedBuffer} that the message body came from. This needs to
+ * be passed in so that the buffer can be freed when this message is
+ * deallocated. Ownership of the caller's reference to this buffer is
+ * transferred to this class, so if the caller wants to continue to use the
+ * ManagedBuffer in other messages then they will need to call retain() on
+ * it before passing it to this constructor.
+ * @param header the message header.
+ * @param body the message body.
+ * @param bodyLength the length of the message body, in bytes.
+ */
+
+ public EncryptedMessageWithHeader(
+ @Nullable ManagedBuffer managedBuffer, ByteBuf header, Object body, long bodyLength) {
+ Preconditions.checkArgument(body instanceof InputStream || body instanceof ChunkedStream,
+ "Body must be an InputStream or a ChunkedStream.");
+ this.managedBuffer = managedBuffer;
+ this.header = header;
+ this.headerLength = header.readableBytes();
+ this.body = body;
+ this.bodyLength = bodyLength;
+ this.totalBytesTransferred = 0;
+ }
+
+ @Override
+ public ByteBuf readChunk(ChannelHandlerContext ctx) throws Exception {
+ return readChunk(ctx.alloc());
+ }
+
+ @Override
+ public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception {
+ if (isEndOfInput()) {
+ return null;
+ }
+
+ if (totalBytesTransferred < headerLength) {
+ totalBytesTransferred += headerLength;
+ return header.retain();
+ } else if (body instanceof InputStream) {
+ InputStream stream = (InputStream) body;
+ int available = stream.available();
+ if (available <= 0) {
+ available = (int) (length() - totalBytesTransferred);
+ } else {
+ available = (int) Math.min(available, length() - totalBytesTransferred);
+ }
+ ByteBuf buffer = allocator.buffer(available);
+ int toRead = Math.min(available, buffer.writableBytes());
+ int read = buffer.writeBytes(stream, toRead);
+ if (read >= 0) {
+ totalBytesTransferred += read;
+ return buffer;
+ } else {
+ throw new EOFException("Unable to read bytes from InputStream");
+ }
+ } else if (body instanceof ChunkedStream) {
+ ChunkedStream stream = (ChunkedStream) body;
+ long old = stream.transferredBytes();
+ ByteBuf buffer = stream.readChunk(allocator);
+ long read = stream.transferredBytes() - old;
+ if (read >= 0) {
+ totalBytesTransferred += read;
+ assert(totalBytesTransferred <= length());
+ return buffer;
+ } else {
+ throw new EOFException("Unable to read bytes from ChunkedStream");
+ }
+ } else {
+ return null;
+ }
+ }
+
+ @Override
+ public long length() {
+ return headerLength + bodyLength;
+ }
+
+ @Override
+ public long progress() {
+ return totalBytesTransferred;
+ }
+
+ @Override
+ public boolean isEndOfInput() throws Exception {
+ return (headerLength + bodyLength) == totalBytesTransferred;
+ }
+
+ @Override
+ public void close() throws Exception {
+ header.release();
+ if (managedBuffer != null) {
+ managedBuffer.release();
+ }
+ if (body instanceof InputStream) {
+ ((InputStream) body).close();
+ } else if (body instanceof ChunkedStream) {
+ ((ChunkedStream) body).close();
+ }
+ }
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/SslMessageEncoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/SslMessageEncoder.java
new file mode 100644
index 00000000000..f43d0789ee6
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/SslMessageEncoder.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network.protocol;
+
+import java.io.InputStream;
+import java.util.List;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.MessageToMessageEncoder;
+import io.netty.handler.stream.ChunkedStream;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Encoder used by the server side to encode secure (SSL) server-to-client responses.
+ * This encoder is stateless so it is safe to be shared by multiple threads.
+ */
[email protected]
+public final class SslMessageEncoder extends MessageToMessageEncoder<Message> {
+
+ private final Logger logger = LoggerFactory.getLogger(SslMessageEncoder.class);
+
+ private SslMessageEncoder() {}
+
+ public static final SslMessageEncoder INSTANCE = new SslMessageEncoder();
+
+ /**
+ * Encodes a Message by invoking its encode() method. For non-data messages, we will add one
+ * ByteBuf to 'out' containing the total frame length, the message type, and the message itself.
+ * In the case of a ChunkFetchSuccess, we will also add the ManagedBuffer corresponding to the
+ * data to 'out'.
+ */
+ @Override
+ public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) throws Exception {
+ Object body = null;
+ long bodyLength = 0;
+ boolean isBodyInFrame = false;
+
+ // If the message has a body, take it out...
+ // For SSL, zero-copy transfer will not work, so we will check if
+ // the body is an InputStream, and if so, use an EncryptedMessageWithHeader
+ // to wrap the header+body appropriately (for thread safety).
+ if (in.body() != null) {
+ try {
+ bodyLength = in.body().size();
+ body = in.body().convertToNettyForSsl();
+ isBodyInFrame = in.isBodyInFrame();
+ } catch (Exception e) {
+ in.body().release();
+ if (in instanceof AbstractResponseMessage) {
+ AbstractResponseMessage resp = (AbstractResponseMessage) in;
+ // Re-encode this message as a failure response.
+ String error = e.getMessage() != null ? e.getMessage() : "null";
+ logger.error(String.format("Error processing %s for client %s",
+ in, ctx.channel().remoteAddress()), e);
+ encode(ctx, resp.createFailureResponse(error), out);
+ } else {
+ throw e;
+ }
+ return;
+ }
+ }
+
+ Message.Type msgType = in.type();
+ // All messages have the frame length, message type, and message itself. The frame length
+ // may optionally include the length of the body data, depending on what message is being sent.
+ int headerLength = 8 + msgType.encodedLength() + in.encodedLength();
+ long frameLength = headerLength + (isBodyInFrame ? bodyLength : 0);
+ ByteBuf header = ctx.alloc().buffer(headerLength);
+ header.writeLong(frameLength);
+ msgType.encode(header);
+ in.encode(header);
+ assert header.writableBytes() == 0;
+
+ if (body != null && bodyLength > 0) {
+ if (body instanceof ByteBuf) {
+ out.add(Unpooled.wrappedBuffer(header, (ByteBuf) body));
+ } else if (body instanceof InputStream || body instanceof ChunkedStream) {
+ // For now, assume the InputStream is doing proper chunking.
+ out.add(new EncryptedMessageWithHeader(in.body(), header, body, bodyLength));
+ } else {
+ throw new IllegalArgumentException(
+ "Body must be a ByteBuf, ChunkedStream or an InputStream");
+ }
+ } else {
+ out.add(header);
+ }
+ }
+}
diff --git a/common/network-common/src/test/java/org/apache/spark/network/protocol/EncryptedMessageWithHeaderSuite.java b/common/network-common/src/test/java/org/apache/spark/network/protocol/EncryptedMessageWithHeaderSuite.java
new file mode 100644
index 00000000000..7478fa1db71
--- /dev/null
+++ b/common/network-common/src/test/java/org/apache/spark/network/protocol/EncryptedMessageWithHeaderSuite.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.io.ByteArrayInputStream;
+import java.io.InputStream;
+import java.util.Random;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.buffer.ByteBufAllocator;
+import io.netty.handler.stream.ChunkedStream;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+public class EncryptedMessageWithHeaderSuite {
+
+ // Tests the case where the body is an input stream and that we manage the refcounts of the
+ // buffer properly
+ @Test
+ public void testInputStreamBodyFromManagedBuffer() throws Exception {
+ byte[] randomData = new byte[128];
+ new Random().nextBytes(randomData);
+ ByteBuf sourceBuffer = Unpooled.copiedBuffer(randomData);
+ InputStream body = new ByteArrayInputStream(sourceBuffer.array());
+ ByteBuf header = Unpooled.copyLong(42);
+
+ long expectedHeaderValue = header.getLong(header.readerIndex());
+ assertEquals(1, header.refCnt());
+ assertEquals(1, sourceBuffer.refCnt());
+ ManagedBuffer managedBuf = new NettyManagedBuffer(sourceBuffer);
+
+ EncryptedMessageWithHeader msg = new EncryptedMessageWithHeader(
+ managedBuf, header, body, managedBuf.size());
+ ByteBufAllocator allocator = ByteBufAllocator.DEFAULT;
+
+ // First read should just read the header
+ ByteBuf headerResult = msg.readChunk(allocator);
+ assertEquals(header.capacity(), headerResult.readableBytes());
+ assertEquals(expectedHeaderValue, headerResult.readLong());
+ assertEquals(header.capacity(), msg.progress());
+ assertFalse(msg.isEndOfInput());
+
+ // Second read should read the body
+ ByteBuf bodyResult = msg.readChunk(allocator);
+ assertEquals(randomData.length + header.capacity(), msg.progress());
+ assertTrue(msg.isEndOfInput());
+
+ // Validate we read it all
+ assertEquals(bodyResult.readableBytes(), randomData.length);
+ for (int i = 0; i < randomData.length; i++) {
+ assertEquals(bodyResult.readByte(), randomData[i]);
+ }
+
+ // Closing the message should release the source buffer
+ msg.close();
+ assertEquals(0, sourceBuffer.refCnt());
+
+ // The header still has a reference we got
+ assertEquals(1, header.refCnt());
+ headerResult.release();
+ assertEquals(0, header.refCnt());
+ }
+
+ // Tests the case where the body is a chunked stream and that we are fine when there is no
+ // input managed buffer
+ @Test
+ public void testChunkedStream() throws Exception {
+ int bodyLength = 129;
+ int chunkSize = 64;
+ byte[] randomData = new byte[bodyLength];
+ new Random().nextBytes(randomData);
+ InputStream inputStream = new ByteArrayInputStream(randomData);
+ ChunkedStream body = new ChunkedStream(inputStream, chunkSize);
+ ByteBuf header = Unpooled.copyLong(42);
+
+ long expectedHeaderValue = header.getLong(header.readerIndex());
+ assertEquals(1, header.refCnt());
+
+ EncryptedMessageWithHeader msg = new EncryptedMessageWithHeader(null, header, body, bodyLength);
+ ByteBufAllocator allocator = ByteBufAllocator.DEFAULT;
+
+ // First read should just read the header
+ ByteBuf headerResult = msg.readChunk(allocator);
+ assertEquals(header.capacity(), headerResult.readableBytes());
+ assertEquals(expectedHeaderValue, headerResult.readLong());
+ assertEquals(header.capacity(), msg.progress());
+ assertFalse(msg.isEndOfInput());
+
+ // Next 2 reads should read full buffers
+ int readIndex = 0;
+ for (int i = 1; i <= 2; i++) {
+ ByteBuf bodyResult = msg.readChunk(allocator);
+ assertEquals(header.capacity() + (i*chunkSize), msg.progress());
+ assertFalse(msg.isEndOfInput());
+
+ // Validate we read data correctly
+ assertEquals(bodyResult.readableBytes(), chunkSize);
+ assert(bodyResult.readableBytes() < (randomData.length - readIndex));
+ while (bodyResult.readableBytes() > 0) {
+ assertEquals(bodyResult.readByte(), randomData[readIndex++]);
+ }
+ }
+
+ // Last read should be partial
+ ByteBuf bodyResult = msg.readChunk(allocator);
+ assertEquals(header.capacity() + bodyLength, msg.progress());
+ assertTrue(msg.isEndOfInput());
+
+ // Validate we read the byte properly
+ assertEquals(bodyResult.readableBytes(), 1);
+ assertEquals(bodyResult.readByte(), randomData[readIndex]);
+
+ // Closing the message should close the input stream
+ msg.close();
+ assertTrue(body.isEndOfInput());
+
+ // The header still has a reference we got
+ assertEquals(1, header.refCnt());
+ headerResult.release();
+ assertEquals(0, header.refCnt());
+ }
+
+ @Test
+ public void testByteBufIsNotSupported() throws Exception {
+ // Validate that ByteBufs are not supported. This test can be updated
+ // when we add support for them
+ ByteBuf header = Unpooled.copyLong(42);
+ assertThrows(IllegalArgumentException.class, () -> {
+ EncryptedMessageWithHeader msg = new EncryptedMessageWithHeader(
+ null, header, header, 4);
+ });
+ }
+}