mridulm commented on code in PR #43541:
URL: https://github.com/apache/spark/pull/43541#discussion_r1374076018
##########
common/network-common/src/main/java/org/apache/spark/network/TransportContext.java:
##########
@@ -189,15 +204,32 @@ public TransportChannelHandler
initializePipeline(SocketChannel channel) {
*/
public TransportChannelHandler initializePipeline(
SocketChannel channel,
- RpcHandler channelRpcHandler) {
+ RpcHandler channelRpcHandler,
+ boolean isClient) {
try {
TransportChannelHandler channelHandler = createChannelHandler(channel,
channelRpcHandler);
ChannelPipeline pipeline = channel.pipeline();
+
if (nettyLogger.getLoggingHandler() != null) {
pipeline.addLast("loggingHandler", nettyLogger.getLoggingHandler());
}
+
+ if (sslEncryptionEnabled()) {
+ SslHandler sslHandler;
+ try {
+ sslHandler = new SslHandler(
+ sslFactory.createSSLEngine(isClient, pipeline.channel().alloc()));
Review Comment:
nit:
```suggestion
sslFactory.createSSLEngine(isClient, channel.alloc()));
```
##########
common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java:
##########
@@ -293,6 +296,26 @@ public void initChannel(SocketChannel ch) {
} else if (cf.cause() != null) {
throw new IOException(String.format("Failed to connect to %s", address),
cf.cause());
}
+ if (context.sslEncryptionEnabled()) {
+ final SslHandler sslHandler =
cf.channel().pipeline().get(SslHandler.class);
+ Future<Channel> future = sslHandler.handshakeFuture().addListener(
+ new GenericFutureListener<Future<Channel>>() {
+ @Override
+ public void operationComplete(final Future<Channel> handshakeFuture)
{
+ if (handshakeFuture.isSuccess()) {
+ logger.debug("{} successfully completed TLS handshake to ",
address);
+ } else {
+ if (logger.isDebugEnabled()) {
+ logger.debug(
+ "failed to complete TLS handshake to " + address,
+ handshakeFuture.cause());
+ }
+ cf.channel().close();
+ }
+ }
+ });
+ future.await(conf.connectionTimeoutMs());
Review Comment:
Should we throw an exception when the await fails (after closing the connection)?
##########
common/network-common/src/main/java/org/apache/spark/network/TransportContext.java:
##########
@@ -189,15 +204,32 @@ public TransportChannelHandler
initializePipeline(SocketChannel channel) {
*/
public TransportChannelHandler initializePipeline(
SocketChannel channel,
- RpcHandler channelRpcHandler) {
+ RpcHandler channelRpcHandler,
+ boolean isClient) {
try {
TransportChannelHandler channelHandler = createChannelHandler(channel,
channelRpcHandler);
ChannelPipeline pipeline = channel.pipeline();
+
if (nettyLogger.getLoggingHandler() != null) {
pipeline.addLast("loggingHandler", nettyLogger.getLoggingHandler());
}
+
+ if (sslEncryptionEnabled()) {
+ SslHandler sslHandler;
+ try {
+ sslHandler = new SslHandler(
+ sslFactory.createSSLEngine(isClient, pipeline.channel().alloc()));
+ } catch (Exception e) {
+ throw new RuntimeException("Error creating Netty SslHandler", e);
+ }
+ pipeline.addFirst("NettySslEncryptionHandler", sslHandler);
+ // Cannot use zero-copy with HTTPS, so we add in our
ChunkedWriteHandler just before the
+ // MessageEncoder
+ pipeline.addLast("chunkedWriter", new ChunkedWriteHandler());
Review Comment:
`addFirst` and `addLast` for `sslHandler` should be equivalent at this point.
But if we want to use `addFirst`, then perhaps ensure `ChunkedWriteHandler`
is added immediately after `sslHandler` using `addAfter`?
##########
common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java:
##########
@@ -293,6 +296,26 @@ public void initChannel(SocketChannel ch) {
} else if (cf.cause() != null) {
throw new IOException(String.format("Failed to connect to %s", address),
cf.cause());
}
+ if (context.sslEncryptionEnabled()) {
+ final SslHandler sslHandler =
cf.channel().pipeline().get(SslHandler.class);
+ Future<Channel> future = sslHandler.handshakeFuture().addListener(
+ new GenericFutureListener<Future<Channel>>() {
+ @Override
+ public void operationComplete(final Future<Channel> handshakeFuture)
{
+ if (handshakeFuture.isSuccess()) {
+ logger.debug("{} successfully completed TLS handshake to ",
address);
+ } else {
+ if (logger.isDebugEnabled()) {
+ logger.debug(
Review Comment:
Do we want to make this `info` instead? I am assuming it won't be noisy, and
when it does fail, it is something we want to know about?
##########
common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java:
##########
@@ -86,6 +85,17 @@ public class ExternalShuffleIntegrationSuite {
new byte[54321],
};
+ private static TransportConf createTransportConf(String maxRetries, String
rddEnabled) {
Review Comment:
nit: declare the parameters using their actual types and convert them to
`String` inside this method.
```suggestion
private static TransportConf createTransportConf(int maxRetries, boolean
rddEnabled) {
```
##########
common/network-common/src/main/java/org/apache/spark/network/TransportContext.java:
##########
@@ -189,15 +204,32 @@ public TransportChannelHandler
initializePipeline(SocketChannel channel) {
*/
public TransportChannelHandler initializePipeline(
SocketChannel channel,
- RpcHandler channelRpcHandler) {
+ RpcHandler channelRpcHandler,
+ boolean isClient) {
try {
TransportChannelHandler channelHandler = createChannelHandler(channel,
channelRpcHandler);
ChannelPipeline pipeline = channel.pipeline();
+
Review Comment:
super nit:
```suggestion
```
##########
common/network-common/src/main/java/org/apache/spark/network/TransportContext.java:
##########
@@ -223,6 +255,33 @@ protected MessageToMessageDecoder<ByteBuf> getDecoder() {
return DECODER;
}
+ private SSLFactory createSslFactory() {
+ if (conf.sslRpcEnabled()) {
+ if (conf.sslRpcEnabledAndKeysAreValid()) {
+ return new SSLFactory.Builder()
+ .openSslEnabled(conf.sslRpcOpenSslEnabled())
+ .requestedProtocol(conf.sslRpcProtocol())
+ .requestedCiphers(conf.sslRpcRequestedCiphers())
+ .keyStore(conf.sslRpcKeyStore(), conf.sslRpcKeyStorePassword())
+ .privateKey(conf.sslRpcPrivateKey())
+ .keyPassword(conf.sslRpcKeyPassword())
+ .certChain(conf.sslRpcCertChain())
+ .trustStore(
+ conf.sslRpcTrustStore(),
+ conf.sslRpcTrustStorePassword(),
+ conf.sslRpcTrustStoreReloadingEnabled(),
+ conf.sslRpctrustStoreReloadIntervalMs())
+ .build();
+ } else {
+ logger.error("RPC SSL encryption enabled but keys not found!" +
+ "Please ensure the configured keys are present.");
+ throw new RuntimeException("RPC SSL encryption enabled but keys not
found!");
Review Comment:
```suggestion
throw new IllegalArgumentException("RPC SSL encryption enabled but
keys not found!");
```
##########
common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleTransportContextSuite.java:
##########
@@ -90,15 +94,22 @@ private ByteBuf getDecodableMessageBuf(Message req) throws
Exception {
public void testInitializePipeline() throws IOException {
// SPARK-43987: test that the FinalizedHandler is added to the pipeline
only when configured
for (boolean enabled : new boolean[]{true, false}) {
- ShuffleTransportContext ctx = createShuffleTransportContext(enabled);
- SocketChannel channel = new NioSocketChannel();
- RpcHandler rpcHandler = mock(RpcHandler.class);
- ctx.initializePipeline(channel, rpcHandler);
- String handlerName =
ShuffleTransportContext.FinalizedHandler.HANDLER_NAME;
- if (enabled) {
- Assertions.assertNotNull(channel.pipeline().get(handlerName));
- } else {
- Assertions.assertNull(channel.pipeline().get(handlerName));
+ for (boolean isClient: new boolean[]{true, false}) {
Review Comment:
super nit: `isClient` -> `client`
##########
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleTransportContext.java:
##########
@@ -52,7 +53,8 @@
* */
public class ShuffleTransportContext extends TransportContext {
private static final Logger logger =
LoggerFactory.getLogger(ShuffleTransportContext.class);
- private static final ShuffleMessageDecoder SHUFFLE_DECODER =
+ @VisibleForTesting
+ protected static ShuffleMessageDecoder SHUFFLE_DECODER =
new ShuffleMessageDecoder(MessageDecoder.INSTANCE);
Review Comment:
Instead of exposing the variable, add a method to reinitialize it, and
annotate that method as intended for use by tests only (e.g. `@VisibleForTesting`).
##########
common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java:
##########
@@ -86,6 +85,17 @@ public class ExternalShuffleIntegrationSuite {
new byte[54321],
};
+ private static TransportConf createTransportConf(String maxRetries, String
rddEnabled) {
+ HashMap<String, String> config = new HashMap<>();
+ config.put("spark.shuffle.io.maxRetries", maxRetries);
+ config.put(Constants.SHUFFLE_SERVICE_FETCH_RDD_ENABLED, rddEnabled);
+ return new TransportConf("shuffle", new MapConfigProvider(config));
+ }
+
+ protected TransportConf createTransportConfForFetchNoServerTest() {
Review Comment:
It is unclear to me why this method is named this way ...
##########
common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleTransportContextSuite.java:
##########
@@ -90,15 +94,22 @@ private ByteBuf getDecodableMessageBuf(Message req) throws
Exception {
public void testInitializePipeline() throws IOException {
// SPARK-43987: test that the FinalizedHandler is added to the pipeline
only when configured
for (boolean enabled : new boolean[]{true, false}) {
- ShuffleTransportContext ctx = createShuffleTransportContext(enabled);
- SocketChannel channel = new NioSocketChannel();
- RpcHandler rpcHandler = mock(RpcHandler.class);
- ctx.initializePipeline(channel, rpcHandler);
- String handlerName =
ShuffleTransportContext.FinalizedHandler.HANDLER_NAME;
- if (enabled) {
- Assertions.assertNotNull(channel.pipeline().get(handlerName));
- } else {
- Assertions.assertNull(channel.pipeline().get(handlerName));
+ for (boolean isClient: new boolean[]{true, false}) {
+ // Since the decoder is not Shareable, reset it between test runs to
avoid errors since it's
+ // used both across ShuffleTransportContextSuite and
SslShuffleTransportContextSuite
+ // and server/clients
Review Comment:
The decoder is not actually used here (other than to configure the pipeline),
so why do we need to reset it?
##########
common/network-common/src/test/java/org/apache/spark/network/SslChunkFetchIntegrationSuite.java:
##########
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network;
+
+import java.io.File;
+import java.io.RandomAccessFile;
+import java.nio.ByteBuffer;
+import java.util.Random;
+
+import com.google.common.io.Closeables;
+import org.junit.jupiter.api.BeforeAll;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.network.ssl.SslSampleConfigs;
+
+
+public class SslChunkFetchIntegrationSuite extends ChunkFetchIntegrationSuite {
+
+ @BeforeAll
+ public static void setUp() throws Exception {
+ int bufSize = 100000;
+ final ByteBuffer buf = ByteBuffer.allocate(bufSize);
+ for (int i = 0; i < bufSize; i ++) {
+ buf.put((byte) i);
+ }
+ buf.flip();
+ bufferChunk = new NioManagedBuffer(buf);
+
+ testFile = File.createTempFile("shuffle-test-file", "txt");
+ testFile.deleteOnExit();
+ RandomAccessFile fp = new RandomAccessFile(testFile, "rw");
+ boolean shouldSuppressIOException = true;
+ try {
+ byte[] fileContent = new byte[1024];
+ new Random().nextBytes(fileContent);
+ fp.write(fileContent);
+ shouldSuppressIOException = false;
+ } finally {
+ Closeables.close(fp, shouldSuppressIOException);
+ }
+
+ final TransportConf conf = new TransportConf(
+ "shuffle",
SslSampleConfigs.createDefaultConfigProviderForRpcNamespace());
+ fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10,
testFile.length() - 25);
+
+ streamManager = new StreamManager() {
+ @Override
+ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
+ assertEquals(STREAM_ID, streamId);
+ if (chunkIndex == BUFFER_CHUNK_INDEX) {
+ return new NioManagedBuffer(buf);
+ } else if (chunkIndex == FILE_CHUNK_INDEX) {
+ return new FileSegmentManagedBuffer(conf, testFile, 10,
testFile.length() - 25);
+ } else {
+ throw new IllegalArgumentException("Invalid chunk index: " +
chunkIndex);
+ }
+ }
+ };
+ RpcHandler handler = new RpcHandler() {
+ @Override
+ public void receive(
+ TransportClient client,
+ ByteBuffer message,
+ RpcResponseCallback callback) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public StreamManager getStreamManager() {
+ return streamManager;
+ }
+ };
+ context = new TransportContext(conf, handler);
+ server = context.createServer();
+ clientFactory = context.createClientFactory();
+ }
Review Comment:
If I am not wrong, the only difference between this and
`ChunkFetchIntegrationSuite.setUp` is `conf`, right?
If so, instead of duplicating the method, how about passing the `conf` to a
common static method that initializes both suites?
(Same comment for the other suites too.)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]