This is an automated email from the ASF dual-hosted git repository. mridulm80 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 884f6f71172 [SPARK-45544][CORE] Integrate SSL support into TransportContext 884f6f71172 is described below commit 884f6f71172156ccc7d95ed022c8fb8baadc3c0a Author: Hasnain Lakhani <hasnain.lakh...@databricks.com> AuthorDate: Sun Oct 29 20:58:18 2023 -0500 [SPARK-45544][CORE] Integrate SSL support into TransportContext ### What changes were proposed in this pull request? This integrates SSL support into TransportContext and related modules so that the RPC SSL functionality can work when properly configured. ### Why are the changes needed? This is needed in order to support SSL for RPC connections. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? CI Ran the following tests: ``` build/sbt -P yarn > project network-common > testOnly > project network-shuffle > testOnly > project core > testOnly *Ssl* > project yarn > testOnly org.apache.spark.network.yarn.SslYarnShuffleServiceWithRocksDBBackendSuite ``` I verified traffic was encrypted with TLS using two mechanisms: * Enabled trace-level logging for Netty and JDK SSL and saw logs confirming TLS handshakes were happening * I ran Wireshark on my machine and snooped on traffic while sending queries shuffling a fixed string. Without any encryption, I could find that string in the network traffic. With this encryption enabled, that string did not show up, and Wireshark logs confirmed a TLS handshake was happening. ### Was this patch authored or co-authored using generative AI tooling? No Closes #43541 from hasnain-db/spark-tls-final. 
Authored-by: Hasnain Lakhani <hasnain.lakh...@databricks.com> Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com> --- .../org/apache/spark/network/TransportContext.java | 70 ++++++++++++++++++++-- .../network/client/TransportClientFactory.java | 26 +++++++- .../spark/network/server/TransportServer.java | 2 +- .../apache/spark/network/util/TransportConf.java | 8 --- .../spark/network/ChunkFetchIntegrationSuite.java | 6 +- .../network/SslChunkFetchIntegrationSuite.java | 22 ++++--- .../client/SslTransportClientFactorySuite.java | 29 +++++---- .../client/TransportClientFactorySuite.java | 8 +-- .../network/shuffle/ShuffleTransportContext.java | 10 ++-- .../shuffle/ExternalShuffleIntegrationSuite.java | 29 +++++---- .../shuffle/ExternalShuffleSecuritySuite.java | 14 ++++- .../shuffle/ShuffleTransportContextSuite.java | 33 +++++----- .../SslExternalShuffleIntegrationSuite.java | 44 ++++++++++++++ .../shuffle/SslExternalShuffleSecuritySuite.java | 35 +++++++---- .../shuffle/SslShuffleTransportContextSuite.java | 28 +++++---- .../network/yarn/SslYarnShuffleServiceSuite.scala | 2 +- 16 files changed, 265 insertions(+), 101 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java index 51d074a4ddb..90ca4f4c46a 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java @@ -23,13 +23,17 @@ import io.netty.handler.codec.MessageToMessageDecoder; import java.io.Closeable; import java.util.ArrayList; import java.util.List; +import javax.annotation.Nullable; import com.codahale.metrics.Counter; import io.netty.channel.Channel; import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoopGroup; import io.netty.channel.socket.SocketChannel; +import io.netty.handler.ssl.SslHandler; +import 
io.netty.handler.stream.ChunkedWriteHandler; import io.netty.handler.timeout.IdleStateHandler; +import io.netty.handler.codec.MessageToMessageEncoder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -37,6 +41,8 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.client.TransportResponseHandler; +import org.apache.spark.network.protocol.Message; +import org.apache.spark.network.protocol.SslMessageEncoder; import org.apache.spark.network.protocol.MessageDecoder; import org.apache.spark.network.protocol.MessageEncoder; import org.apache.spark.network.server.ChunkFetchRequestHandler; @@ -45,6 +51,7 @@ import org.apache.spark.network.server.TransportChannelHandler; import org.apache.spark.network.server.TransportRequestHandler; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.TransportServerBootstrap; +import org.apache.spark.network.ssl.SSLFactory; import org.apache.spark.network.util.IOMode; import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.NettyLogger; @@ -72,6 +79,8 @@ public class TransportContext implements Closeable { private final TransportConf conf; private final RpcHandler rpcHandler; private final boolean closeIdleConnections; + // Non-null if SSL is enabled, null otherwise. + @Nullable private final SSLFactory sslFactory; // Number of registered connections to the shuffle service private Counter registeredConnections = new Counter(); @@ -87,7 +96,8 @@ public class TransportContext implements Closeable { * RPC to load it and cause to load the non-exist matcher class again. JVM will report * `ClassCircularityError` to prevent such infinite recursion. 
(See SPARK-17714) */ - private static final MessageEncoder ENCODER = MessageEncoder.INSTANCE; + private static final MessageToMessageEncoder<Message> ENCODER = MessageEncoder.INSTANCE; + private static final MessageToMessageEncoder<Message> SSL_ENCODER = SslMessageEncoder.INSTANCE; private static final MessageDecoder DECODER = MessageDecoder.INSTANCE; // Separate thread pool for handling ChunkFetchRequest. This helps to enable throttling @@ -125,6 +135,7 @@ public class TransportContext implements Closeable { this.conf = conf; this.rpcHandler = rpcHandler; this.closeIdleConnections = closeIdleConnections; + this.sslFactory = createSslFactory(); if (conf.getModuleName() != null && conf.getModuleName().equalsIgnoreCase("shuffle") && @@ -171,8 +182,12 @@ public class TransportContext implements Closeable { return createServer(0, new ArrayList<>()); } - public TransportChannelHandler initializePipeline(SocketChannel channel) { - return initializePipeline(channel, rpcHandler); + public TransportChannelHandler initializePipeline(SocketChannel channel, boolean isClient) { + return initializePipeline(channel, rpcHandler, isClient); + } + + public boolean sslEncryptionEnabled() { + return this.sslFactory != null; } /** @@ -189,15 +204,30 @@ public class TransportContext implements Closeable { */ public TransportChannelHandler initializePipeline( SocketChannel channel, - RpcHandler channelRpcHandler) { + RpcHandler channelRpcHandler, + boolean isClient) { try { TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler); ChannelPipeline pipeline = channel.pipeline(); if (nettyLogger.getLoggingHandler() != null) { pipeline.addLast("loggingHandler", nettyLogger.getLoggingHandler()); } + + if (sslEncryptionEnabled()) { + SslHandler sslHandler; + try { + sslHandler = new SslHandler(sslFactory.createSSLEngine(isClient, channel.alloc())); + } catch (Exception e) { + throw new IllegalStateException("Error creating Netty SslHandler", e); + } + 
pipeline.addFirst("NettySslEncryptionHandler", sslHandler); + // Cannot use zero-copy with HTTPS, so we add in our ChunkedWriteHandler just before the + // MessageEncoder + pipeline.addLast("chunkedWriter", new ChunkedWriteHandler()); + } + pipeline - .addLast("encoder", ENCODER) + .addLast("encoder", sslEncryptionEnabled()? SSL_ENCODER : ENCODER) .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder()) .addLast("decoder", getDecoder()) .addLast("idleStateHandler", @@ -223,6 +253,33 @@ public class TransportContext implements Closeable { return DECODER; } + private SSLFactory createSslFactory() { + if (conf.sslRpcEnabled()) { + if (conf.sslRpcEnabledAndKeysAreValid()) { + return new SSLFactory.Builder() + .openSslEnabled(conf.sslRpcOpenSslEnabled()) + .requestedProtocol(conf.sslRpcProtocol()) + .requestedCiphers(conf.sslRpcRequestedCiphers()) + .keyStore(conf.sslRpcKeyStore(), conf.sslRpcKeyStorePassword()) + .privateKey(conf.sslRpcPrivateKey()) + .keyPassword(conf.sslRpcKeyPassword()) + .certChain(conf.sslRpcCertChain()) + .trustStore( + conf.sslRpcTrustStore(), + conf.sslRpcTrustStorePassword(), + conf.sslRpcTrustStoreReloadingEnabled(), + conf.sslRpctrustStoreReloadIntervalMs()) + .build(); + } else { + logger.error("RPC SSL encryption enabled but keys not found!" + + "Please ensure the configured keys are present."); + throw new IllegalArgumentException("RPC SSL encryption enabled but keys not found!"); + } + } else { + return null; + } + } + /** * Creates the server- and client-side handler which is used to handle both RequestMessages and * ResponseMessages. 
The channel is expected to have been successfully created, though certain @@ -255,5 +312,8 @@ public class TransportContext implements Closeable { if (chunkFetchWorkers != null) { chunkFetchWorkers.shutdownGracefully(); } + if (sslFactory != null) { + sslFactory.destroy(); + } } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 4c1efd69206..fd48020caac 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -39,6 +39,9 @@ import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.socket.SocketChannel; +import io.netty.handler.ssl.SslHandler; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -268,7 +271,7 @@ public class TransportClientFactory implements Closeable { bootstrap.handler(new ChannelInitializer<SocketChannel>() { @Override public void initChannel(SocketChannel ch) { - TransportChannelHandler clientHandler = context.initializePipeline(ch); + TransportChannelHandler clientHandler = context.initializePipeline(ch, true); clientRef.set(clientHandler.getClient()); channelRef.set(ch); } @@ -293,6 +296,27 @@ public class TransportClientFactory implements Closeable { } else if (cf.cause() != null) { throw new IOException(String.format("Failed to connect to %s", address), cf.cause()); } + if (context.sslEncryptionEnabled()) { + final SslHandler sslHandler = cf.channel().pipeline().get(SslHandler.class); + Future<Channel> future = sslHandler.handshakeFuture().addListener( + new GenericFutureListener<Future<Channel>>() { + @Override + public void 
operationComplete(final Future<Channel> handshakeFuture) { + if (handshakeFuture.isSuccess()) { + logger.debug("{} successfully completed TLS handshake to ", address); + } else { + logger.info( + "failed to complete TLS handshake to " + address, handshakeFuture.cause()); + cf.channel().close(); + } + } + }); + if (!future.await(conf.connectionTimeoutMs())) { + cf.channel().close(); + throw new IOException( + String.format("Failed to connect to %s within connection timeout", address)); + } + } TransportClient client = clientRef.get(); Channel channel = channelRef.get(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java index 5b5b3f9d901..6f2e4b8a502 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -140,7 +140,7 @@ public class TransportServer implements Closeable { for (TransportServerBootstrap bootstrap : bootstraps) { rpcHandler = bootstrap.doBootstrap(ch, rpcHandler); } - context.initializePipeline(ch, rpcHandler); + context.initializePipeline(ch, rpcHandler, false); } }); diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 3ebb38e310f..eb85d2bb561 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -401,14 +401,6 @@ public class TransportConf { } } - /** - * If we can dangerously fallback to unencrypted connections if RPC over SSL is enabled - * but the key files are not present - */ - public boolean sslRpcDangerouslyFallbackIfKeysNotPresent() { - return 
conf.getBoolean("spark.ssl.rpc.dangerouslyFallbackIfKeysNotPresent", false); - } - /** * Flag indicating whether to share the pooled ByteBuf allocators between the different Netty * channels. If enabled then only two pooled ByteBuf allocators are created: one where caching diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index 2026d3b9524..576a106934f 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -65,8 +65,13 @@ public class ChunkFetchIntegrationSuite { static ManagedBuffer bufferChunk; static ManagedBuffer fileChunk; + // This is split out so it can be invoked in a subclass with a different config @BeforeAll public static void setUp() throws Exception { + doSetUpWithConfig(new TransportConf("shuffle", MapConfigProvider.EMPTY)); + } + + public static void doSetUpWithConfig(final TransportConf conf) throws Exception { int bufSize = 100000; final ByteBuffer buf = ByteBuffer.allocate(bufSize); for (int i = 0; i < bufSize; i ++) { @@ -88,7 +93,6 @@ public class ChunkFetchIntegrationSuite { Closeables.close(fp, shouldSuppressIOException); } - final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); streamManager = new StreamManager() { diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala b/common/network-common/src/test/java/org/apache/spark/network/SslChunkFetchIntegrationSuite.java similarity index 59% copy from resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala copy to 
common/network-common/src/test/java/org/apache/spark/network/SslChunkFetchIntegrationSuite.java index 322d6bfdb7c..783ffd4b8c1 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala +++ b/common/network-common/src/test/java/org/apache/spark/network/SslChunkFetchIntegrationSuite.java @@ -14,21 +14,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package org.apache.spark.network; -package org.apache.spark.network.yarn +import org.junit.jupiter.api.BeforeAll; -import org.apache.spark.network.ssl.SslSampleConfigs +import org.apache.spark.network.util.TransportConf; +import org.apache.spark.network.ssl.SslSampleConfigs; -class SslYarnShuffleServiceWithRocksDBBackendSuite - extends YarnShuffleServiceWithRocksDBBackendSuite { - /** - * Override to add "spark.ssl.rpc.*" configuration parameters... - */ - override def beforeEach(): Unit = { - super.beforeEach() - // Same as SSLTestUtils.updateWithSSLConfig(), which is not available to import here. - SslSampleConfigs.createDefaultConfigMap().entrySet(). 
- forEach(entry => yarnConfig.set(entry.getKey, entry.getValue)) +public class SslChunkFetchIntegrationSuite extends ChunkFetchIntegrationSuite { + + @BeforeAll + public static void setUp() throws Exception { + doSetUpWithConfig(new TransportConf( + "shuffle", SslSampleConfigs.createDefaultConfigProviderForRpcNamespace())); } } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala b/common/network-common/src/test/java/org/apache/spark/network/client/SslTransportClientFactorySuite.java similarity index 51% copy from resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala copy to common/network-common/src/test/java/org/apache/spark/network/client/SslTransportClientFactorySuite.java index 322d6bfdb7c..79b76b633f9 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala +++ b/common/network-common/src/test/java/org/apache/spark/network/client/SslTransportClientFactorySuite.java @@ -15,20 +15,25 @@ * limitations under the License. */ -package org.apache.spark.network.yarn +package org.apache.spark.network.client; -import org.apache.spark.network.ssl.SslSampleConfigs +import org.junit.jupiter.api.BeforeEach; -class SslYarnShuffleServiceWithRocksDBBackendSuite - extends YarnShuffleServiceWithRocksDBBackendSuite { +import org.apache.spark.network.ssl.SslSampleConfigs; +import org.apache.spark.network.server.NoOpRpcHandler; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.util.TransportConf; +import org.apache.spark.network.TransportContext; - /** - * Override to add "spark.ssl.rpc.*" configuration parameters... - */ - override def beforeEach(): Unit = { - super.beforeEach() - // Same as SSLTestUtils.updateWithSSLConfig(), which is not available to import here. - SslSampleConfigs.createDefaultConfigMap().entrySet(). 
- forEach(entry => yarnConfig.set(entry.getKey, entry.getValue)) +public class SslTransportClientFactorySuite extends TransportClientFactorySuite { + + @BeforeEach + public void setUp() { + conf = new TransportConf( + "shuffle", SslSampleConfigs.createDefaultConfigProviderForRpcNamespace()); + RpcHandler rpcHandler = new NoOpRpcHandler(); + context = new TransportContext(conf, rpcHandler); + server1 = context.createServer(); + server2 = context.createServer(); } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/client/TransportClientFactorySuite.java b/common/network-common/src/test/java/org/apache/spark/network/client/TransportClientFactorySuite.java index 49a2d570d96..b57f0be920c 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/client/TransportClientFactorySuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/client/TransportClientFactorySuite.java @@ -44,10 +44,10 @@ import org.apache.spark.network.util.TransportConf; import static org.junit.jupiter.api.Assertions.*; public class TransportClientFactorySuite { - private TransportConf conf; - private TransportContext context; - private TransportServer server1; - private TransportServer server2; + protected TransportConf conf; + protected TransportContext context; + protected TransportServer server1; + protected TransportServer server2; @BeforeEach public void setUp() { diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleTransportContext.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleTransportContext.java index e0971d49510..feaaa570b73 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleTransportContext.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleTransportContext.java @@ -22,6 +22,7 @@ import java.util.List; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; +import 
io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.EventLoopGroup; import io.netty.channel.SimpleChannelInboundHandler; @@ -81,16 +82,16 @@ public class ShuffleTransportContext extends TransportContext { } @Override - public TransportChannelHandler initializePipeline(SocketChannel channel) { - TransportChannelHandler ch = super.initializePipeline(channel); + public TransportChannelHandler initializePipeline(SocketChannel channel, boolean isClient) { + TransportChannelHandler ch = super.initializePipeline(channel, isClient); addHandlerToPipeline(channel, ch); return ch; } @Override public TransportChannelHandler initializePipeline(SocketChannel channel, - RpcHandler channelRpcHandler) { - TransportChannelHandler ch = super.initializePipeline(channel, channelRpcHandler); + RpcHandler channelRpcHandler, boolean isClient) { + TransportChannelHandler ch = super.initializePipeline(channel, channelRpcHandler, isClient); addHandlerToPipeline(channel, ch); return ch; } @@ -112,6 +113,7 @@ public class ShuffleTransportContext extends TransportContext { return finalizeWorkers == null ? 
super.getDecoder() : SHUFFLE_DECODER; } + @ChannelHandler.Sharable static class ShuffleMessageDecoder extends MessageToMessageDecoder<ByteBuf> { private final MessageDecoder delegate; diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index b5ffa30f62d..73cb133f17e 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -32,7 +32,6 @@ import java.util.concurrent.Future; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.Sets; import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.server.OneForOneStreamManager; @@ -57,11 +56,11 @@ public class ExternalShuffleIntegrationSuite { private static final String APP_ID = "app-id"; private static final String SORT_MANAGER = "org.apache.spark.shuffle.sort.SortShuffleManager"; - private static final int RDD_ID = 1; - private static final int SPLIT_INDEX_VALID_BLOCK = 0; + protected static final int RDD_ID = 1; + protected static final int SPLIT_INDEX_VALID_BLOCK = 0; private static final int SPLIT_INDEX_MISSING_FILE = 1; - private static final int SPLIT_INDEX_CORRUPT_LENGTH = 2; - private static final int SPLIT_INDEX_VALID_BLOCK_TO_RM = 3; + protected static final int SPLIT_INDEX_CORRUPT_LENGTH = 2; + protected static final int SPLIT_INDEX_VALID_BLOCK_TO_RM = 3; private static final int SPLIT_INDEX_MISSING_BLOCK_TO_RM = 4; // Executor 0 is sort-based @@ -86,8 +85,20 @@ public class ExternalShuffleIntegrationSuite { new byte[54321], }; + private static TransportConf createTransportConf(int maxRetries, boolean rddEnabled) { + 
HashMap<String, String> config = new HashMap<>(); + config.put("spark.shuffle.io.maxRetries", String.valueOf(maxRetries)); + config.put(Constants.SHUFFLE_SERVICE_FETCH_RDD_ENABLED, String.valueOf(rddEnabled)); + return new TransportConf("shuffle", new MapConfigProvider(config)); + } + + // This is split out so it can be invoked in a subclass with a different config @BeforeAll public static void beforeAll() throws IOException { + doBeforeAllWithConfig(createTransportConf(0, true)); + } + + public static void doBeforeAllWithConfig(TransportConf transportConf) throws IOException { Random rand = new Random(); for (byte[] block : exec0Blocks) { @@ -105,10 +116,7 @@ public class ExternalShuffleIntegrationSuite { dataContext0.insertCachedRddData(RDD_ID, SPLIT_INDEX_VALID_BLOCK, exec0RddBlockValid); dataContext0.insertCachedRddData(RDD_ID, SPLIT_INDEX_VALID_BLOCK_TO_RM, exec0RddBlockToRemove); - HashMap<String, String> config = new HashMap<>(); - config.put("spark.shuffle.io.maxRetries", "0"); - config.put(Constants.SHUFFLE_SERVICE_FETCH_RDD_ENABLED, "true"); - conf = new TransportConf("shuffle", new MapConfigProvider(config)); + conf = transportConf; handler = new ExternalBlockHandler( new OneForOneStreamManager(), new ExternalShuffleBlockResolver(conf, null) { @@ -319,8 +327,7 @@ public class ExternalShuffleIntegrationSuite { @Test public void testFetchNoServer() throws Exception { - TransportConf clientConf = new TransportConf("shuffle", - new MapConfigProvider(ImmutableMap.of("spark.shuffle.io.maxRetries", "0"))); + TransportConf clientConf = createTransportConf(0, false); registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); FetchResult execFetch = fetchBlocks("exec-0", new String[]{"shuffle_1_0_0", "shuffle_1_0_1"}, clientConf, 1 /* port */); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java 
b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index b8beec303ae..76f82800c50 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -39,10 +39,19 @@ import org.apache.spark.network.util.TransportConf; public class ExternalShuffleSecuritySuite { - TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); + TransportConf conf = createTransportConf(false); TransportServer server; TransportContext transportContext; + protected TransportConf createTransportConf(boolean encrypt) { + if (encrypt) { + return new TransportConf("shuffle", new MapConfigProvider( + ImmutableMap.of("spark.authenticate.enableSaslEncryption", "true"))); + } else { + return new TransportConf("shuffle", MapConfigProvider.EMPTY); + } + } + @BeforeEach public void beforeEach() throws IOException { transportContext = new TransportContext(conf, new ExternalBlockHandler(conf, null)); @@ -92,8 +101,7 @@ public class ExternalShuffleSecuritySuite { throws IOException, InterruptedException { TransportConf testConf = conf; if (encrypt) { - testConf = new TransportConf("shuffle", new MapConfigProvider( - ImmutableMap.of("spark.authenticate.enableSaslEncryption", "true"))); + testConf = createTransportConf(encrypt); } try (ExternalBlockStoreClient client = diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleTransportContextSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleTransportContextSuite.java index 5484e8131a8..de164474766 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleTransportContextSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleTransportContextSuite.java @@ -60,13 +60,16 @@ public class 
ShuffleTransportContextSuite { blockHandler = mock(ExternalBlockHandler.class); } - ShuffleTransportContext createShuffleTransportContext(boolean separateFinalizeThread) - throws IOException { + protected TransportConf createTransportConf(boolean separateFinalizeThread) { Map<String, String> configs = new HashMap<>(); configs.put("spark.shuffle.server.finalizeShuffleMergeThreadsPercent", - separateFinalizeThread ? "1" : "0"); - TransportConf transportConf = new TransportConf("shuffle", - new MapConfigProvider(configs)); + separateFinalizeThread ? "1" : "0"); + return new TransportConf("shuffle", new MapConfigProvider(configs)); + } + + ShuffleTransportContext createShuffleTransportContext(boolean separateFinalizeThread) + throws IOException { + TransportConf transportConf = createTransportConf(separateFinalizeThread); return new ShuffleTransportContext(transportConf, blockHandler, true); } @@ -90,15 +93,17 @@ public class ShuffleTransportContextSuite { public void testInitializePipeline() throws IOException { // SPARK-43987: test that the FinalizedHandler is added to the pipeline only when configured for (boolean enabled : new boolean[]{true, false}) { - ShuffleTransportContext ctx = createShuffleTransportContext(enabled); - SocketChannel channel = new NioSocketChannel(); - RpcHandler rpcHandler = mock(RpcHandler.class); - ctx.initializePipeline(channel, rpcHandler); - String handlerName = ShuffleTransportContext.FinalizedHandler.HANDLER_NAME; - if (enabled) { - Assertions.assertNotNull(channel.pipeline().get(handlerName)); - } else { - Assertions.assertNull(channel.pipeline().get(handlerName)); + for (boolean client: new boolean[]{true, false}) { + ShuffleTransportContext ctx = createShuffleTransportContext(enabled); + SocketChannel channel = new NioSocketChannel(); + RpcHandler rpcHandler = mock(RpcHandler.class); + ctx.initializePipeline(channel, rpcHandler, client); + String handlerName = ShuffleTransportContext.FinalizedHandler.HANDLER_NAME; + if (enabled) { + 
Assertions.assertNotNull(channel.pipeline().get(handlerName)); + } else { + Assertions.assertNull(channel.pipeline().get(handlerName)); + } } } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslExternalShuffleIntegrationSuite.java new file mode 100644 index 00000000000..3591ccad150 --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslExternalShuffleIntegrationSuite.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle; + +import java.io.IOException; +import java.util.HashMap; + +import org.junit.jupiter.api.BeforeAll; + +import org.apache.spark.network.util.TransportConf; +import org.apache.spark.network.ssl.SslSampleConfigs; + +public class SslExternalShuffleIntegrationSuite extends ExternalShuffleIntegrationSuite { + + private static TransportConf createTransportConf(int maxRetries, boolean rddEnabled) { + HashMap<String, String> config = new HashMap<>(); + config.put("spark.shuffle.io.maxRetries", String.valueOf(maxRetries)); + config.put(Constants.SHUFFLE_SERVICE_FETCH_RDD_ENABLED, String.valueOf(rddEnabled)); + return new TransportConf( + "shuffle", + SslSampleConfigs.createDefaultConfigProviderForRpcNamespaceWithAdditionalEntries(config) + ); + } + + @BeforeAll + public static void beforeAll() throws IOException { + doBeforeAllWithConfig(createTransportConf(0, true)); + } +} diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslExternalShuffleSecuritySuite.java similarity index 50% copy from resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala copy to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslExternalShuffleSecuritySuite.java index 322d6bfdb7c..061d63dbcd7 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslExternalShuffleSecuritySuite.java @@ -15,20 +15,31 @@ * limitations under the License. 
*/ -package org.apache.spark.network.yarn +package org.apache.spark.network.shuffle; -import org.apache.spark.network.ssl.SslSampleConfigs +import com.google.common.collect.ImmutableMap; -class SslYarnShuffleServiceWithRocksDBBackendSuite - extends YarnShuffleServiceWithRocksDBBackendSuite { +import org.apache.spark.network.ssl.SslSampleConfigs; +import org.apache.spark.network.util.TransportConf; - /** - * Override to add "spark.ssl.rpc.*" configuration parameters... - */ - override def beforeEach(): Unit = { - super.beforeEach() - // Same as SSLTestUtils.updateWithSSLConfig(), which is not available to import here. - SslSampleConfigs.createDefaultConfigMap().entrySet(). - forEach(entry => yarnConfig.set(entry.getKey, entry.getValue)) +public class SslExternalShuffleSecuritySuite extends ExternalShuffleSecuritySuite { + + @Override + protected TransportConf createTransportConf(boolean encrypt) { + if (encrypt) { + return new TransportConf( + "shuffle", + SslSampleConfigs.createDefaultConfigProviderForRpcNamespaceWithAdditionalEntries( + ImmutableMap.of( + "spark.authenticate.enableSaslEncryption", + "true") + ) + ); + } else { + return new TransportConf( + "shuffle", + SslSampleConfigs.createDefaultConfigProviderForRpcNamespace() + ); + } } } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslShuffleTransportContextSuite.java similarity index 55% copy from resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala copy to common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslShuffleTransportContextSuite.java index 322d6bfdb7c..51463bbad55 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/SslShuffleTransportContextSuite.java @@ 
-15,20 +15,24 @@ * limitations under the License. */ -package org.apache.spark.network.yarn +package org.apache.spark.network.shuffle; -import org.apache.spark.network.ssl.SslSampleConfigs +import com.google.common.collect.ImmutableMap; -class SslYarnShuffleServiceWithRocksDBBackendSuite - extends YarnShuffleServiceWithRocksDBBackendSuite { +import org.apache.spark.network.ssl.SslSampleConfigs; +import org.apache.spark.network.util.TransportConf; - /** - * Override to add "spark.ssl.rpc.*" configuration parameters... - */ - override def beforeEach(): Unit = { - super.beforeEach() - // Same as SSLTestUtils.updateWithSSLConfig(), which is not available to import here. - SslSampleConfigs.createDefaultConfigMap().entrySet(). - forEach(entry => yarnConfig.set(entry.getKey, entry.getValue)) +public class SslShuffleTransportContextSuite extends ShuffleTransportContextSuite { + + @Override + protected TransportConf createTransportConf(boolean separateFinalizeThread) { + return new TransportConf( + "shuffle", + SslSampleConfigs.createDefaultConfigProviderForRpcNamespaceWithAdditionalEntries( + ImmutableMap.of( + "spark.shuffle.server.finalizeShuffleMergeThreadsPercent", + separateFinalizeThread ? "1" : "0") + ) + ); } } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala index 322d6bfdb7c..06b91faf44a 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/SslYarnShuffleServiceSuite.scala @@ -28,7 +28,7 @@ class SslYarnShuffleServiceWithRocksDBBackendSuite override def beforeEach(): Unit = { super.beforeEach() // Same as SSLTestUtils.updateWithSSLConfig(), which is not available to import here. - SslSampleConfigs.createDefaultConfigMap().entrySet(). 
+ SslSampleConfigs.createDefaultConfigMapForRpcNamespace().entrySet(). forEach(entry => yarnConfig.set(entry.getKey, entry.getValue)) } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org