This is an automated email from the ASF dual-hosted git repository.
nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new b1f8ec835 [CELEBORN-1351] Introduce SSLFactory and enable TLS support
b1f8ec835 is described below
commit b1f8ec83575d556cd5d9da1620b7fc75b4dcee60
Author: Mridul Muralidharan <mridulatgmail.com>
AuthorDate: Mon Apr 8 10:42:29 2024 +0800
[CELEBORN-1351] Introduce SSLFactory and enable TLS support
### What changes were proposed in this pull request?
Add SSLFactory, and wire up TLS support with the rest of Celeborn to enable
secure over-the-wire communication.
### Why are the changes needed?
Add support for TLS to secure wire communication.
This is the last PR to add basic support for TLS.
There will be a follow-up for CELEBORN-1356, and documentation, of course!
### Does this PR introduce _any_ user-facing change?
Yes, this completes basic TLS support in Celeborn.
### How was this patch tested?
Existing tests, augmented with additional unit tests.
Closes #2438 from mridulm/add-sslfactory-and-related-changes.
Authored-by: Mridul Muralidharan <mridulatgmail.com>
Signed-off-by: SteNicholas <[email protected]>
---
LICENSE | 1 +
.../apache/celeborn/client/ShuffleClientImpl.java | 10 +-
.../celeborn/common/network/TransportContext.java | 83 ++++-
.../network/client/TransportClientFactory.java | 33 +-
.../common/network/server/TransportServer.java | 2 +-
.../celeborn/common/network/ssl/SSLFactory.java | 393 +++++++++++++++++++++
.../celeborn/common/rpc/netty/NettyRpcEnv.scala | 6 +-
.../common/network/RpcIntegrationSuiteJ.java | 11 +-
.../common/network/SSLRpcIntegrationSuiteJ.java | 49 +++
.../network/SSLTransportClientFactorySuiteJ.java | 49 +++
.../network/TransportClientFactorySuiteJ.java | 25 +-
.../celeborn/common/network/sasl/SaslTestBase.java | 3 +
.../common/network/ssl/SslConnectivitySuiteJ.java | 341 ++++++++++++++++++
.../common/network/ssl/SslSampleConfigs.java | 1 -
.../common/rpc/netty/SSLNettyRpcEnvSuite.scala | 44 +++
.../celeborn/service/deploy/worker/Worker.scala | 18 +-
.../network/RequestTimeoutIntegrationSuiteJ.java | 19 +-
.../SSLRequestTimeoutIntegrationSuiteJ.java | 49 +++
.../storage/ChunkFetchIntegrationSuiteJ.java | 20 +-
.../storage/ReducePartitionDataWriterSuiteJ.java | 24 +-
.../storage/SSLChunkFetchIntegrationSuiteJ.java | 49 +++
.../SSLReducePartitionDataWriterSuiteJ.java | 45 +++
22 files changed, 1233 insertions(+), 42 deletions(-)
diff --git a/LICENSE b/LICENSE
index d5f68e4d7..30b865ffc 100644
--- a/LICENSE
+++ b/LICENSE
@@ -215,6 +215,7 @@ Apache Spark
./common/src/main/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeader.java
./common/src/main/java/org/apache/celeborn/common/network/protocol/SslMessageEncoder.java
./common/src/main/java/org/apache/celeborn/common/network/ssl/ReloadingX509TrustManager.java
+./common/src/main/java/org/apache/celeborn/common/network/ssl/SSLFactory.java
./common/src/main/java/org/apache/celeborn/common/network/util/NettyLogger.java
./common/src/main/java/org/apache/celeborn/common/unsafe/Platform.java
./common/src/main/java/org/apache/celeborn/common/util/JavaUtils.java
diff --git
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index 0494d73e9..6aeca765c 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -88,6 +88,7 @@ public class ShuffleClientImpl extends ShuffleClient {
protected RpcEndpointRef lifecycleManagerRef;
+ private TransportContext transportContext;
protected TransportClientFactory dataClientFactory;
protected final int BATCH_HEADER_SIZE = 4 * 4;
@@ -211,12 +212,12 @@ public class ShuffleClientImpl extends ShuffleClient {
if (dataClientFactory != null) {
return;
}
- TransportContext context =
+ this.transportContext =
new TransportContext(
dataTransportConf, new BaseMessageHandler(),
conf.clientCloseIdleConnections());
if (!authEnabled) {
logger.info("Initializing data client factory for {}.", appUniqueId);
- dataClientFactory = context.createClientFactory();
+ dataClientFactory = transportContext.createClientFactory();
} else if (lifecycleManagerRef != null) {
PbApplicationMetaRequest pbApplicationMetaRequest =
PbApplicationMetaRequest.newBuilder().setAppId(appUniqueId).build();
@@ -232,7 +233,7 @@ public class ShuffleClientImpl extends ShuffleClient {
dataTransportConf,
appUniqueId,
new SaslCredentials(appUniqueId,
pbApplicationMeta.getSecret())));
- dataClientFactory = context.createClientFactory(bootstraps);
+ dataClientFactory = transportContext.createClientFactory(bootstraps);
}
}
@@ -1742,6 +1743,9 @@ public class ShuffleClientImpl extends ShuffleClient {
if (null != dataClientFactory) {
dataClientFactory.close();
}
+ if (null != transportContext) {
+ transportContext.close();
+ }
if (null != pushDataRetryPool) {
pushDataRetryPool.shutdown();
}
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
b/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
index ec0d1fd87..58366a087 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java
@@ -17,14 +17,19 @@
package org.apache.celeborn.common.network;
+import java.io.Closeable;
import java.util.Collections;
import java.util.List;
+import javax.annotation.Nullable;
+
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.ssl.SslHandler;
+import io.netty.handler.stream.ChunkedWriteHandler;
import io.netty.handler.timeout.IdleStateHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -35,7 +40,9 @@ import
org.apache.celeborn.common.network.client.TransportClientBootstrap;
import org.apache.celeborn.common.network.client.TransportClientFactory;
import org.apache.celeborn.common.network.client.TransportResponseHandler;
import org.apache.celeborn.common.network.protocol.MessageEncoder;
+import org.apache.celeborn.common.network.protocol.SslMessageEncoder;
import org.apache.celeborn.common.network.server.*;
+import org.apache.celeborn.common.network.ssl.SSLFactory;
import org.apache.celeborn.common.network.util.FrameDecoder;
import org.apache.celeborn.common.network.util.NettyLogger;
import org.apache.celeborn.common.network.util.TransportConf;
@@ -54,7 +61,7 @@ import
org.apache.celeborn.common.network.util.TransportFrameDecoder;
* channel. As each TransportChannelHandler contains a TransportClient, this
enables server
* processes to send messages back to the client on an existing channel.
*/
-public class TransportContext {
+public class TransportContext implements Closeable {
private static final Logger logger =
LoggerFactory.getLogger(TransportContext.class);
private static final NettyLogger nettyLogger = new NettyLogger();
@@ -62,10 +69,13 @@ public class TransportContext {
private final BaseMessageHandler msgHandler;
private final ChannelDuplexHandler channelsLimiter;
private final boolean closeIdleConnections;
+ // Non-null if SSL is enabled, null otherwise.
+ @Nullable private final SSLFactory sslFactory;
private final boolean enableHeartbeat;
private final AbstractSource source;
private static final MessageEncoder ENCODER = MessageEncoder.INSTANCE;
+ private static final SslMessageEncoder SSL_ENCODER =
SslMessageEncoder.INSTANCE;
public TransportContext(
TransportConf conf,
@@ -77,6 +87,7 @@ public class TransportContext {
this.conf = conf;
this.msgHandler = msgHandler;
this.closeIdleConnections = closeIdleConnections;
+ this.sslFactory = createSslFactory();
this.channelsLimiter = channelsLimiter;
this.enableHeartbeat = enableHeartbeat;
this.source = source;
@@ -135,31 +146,53 @@ public class TransportContext {
return createServer(null, 0, Collections.emptyList());
}
+  /**
+   * Whether channels set up by this context are wrapped in TLS.
+   *
+   * @return {@code true} iff an {@link SSLFactory} was built from the transport configuration
+   */
+  public boolean sslEncryptionEnabled() {
+    return sslFactory != null;
+  }
+
+ /**
+ * Sets up the channel pipeline using this context's own message handler.
+ *
+ * @param isClient true for outbound (client-side) channels, false for channels accepted by a
+ * server; when SSL is enabled this selects the SSL engine mode
+ */
public TransportChannelHandler initializePipeline(
- SocketChannel channel, ChannelInboundHandlerAdapter decoder) {
- return initializePipeline(channel, decoder, msgHandler);
+ SocketChannel channel, ChannelInboundHandlerAdapter decoder, boolean
isClient) {
+ return initializePipeline(channel, decoder, msgHandler, isClient);
}
+ /**
+ * Sets up the channel pipeline with a fresh TransportFrameDecoder and the given (possibly
+ * bootstrap-wrapped) message handler.
+ *
+ * @param isClient true for outbound (client-side) channels, false for server-accepted ones
+ */
public TransportChannelHandler initializePipeline(
- SocketChannel channel, BaseMessageHandler resolvedMsgHandler) {
- return initializePipeline(channel, new TransportFrameDecoder(),
resolvedMsgHandler);
+ SocketChannel channel, BaseMessageHandler resolvedMsgHandler, boolean
isClient) {
+ return initializePipeline(channel, new TransportFrameDecoder(),
resolvedMsgHandler, isClient);
}
public TransportChannelHandler initializePipeline(
SocketChannel channel,
ChannelInboundHandlerAdapter decoder,
- BaseMessageHandler resolvedMsgHandler) {
+ BaseMessageHandler resolvedMsgHandler,
+ boolean isClient) {
try {
ChannelPipeline pipeline = channel.pipeline();
if (nettyLogger.getLoggingHandler() != null) {
pipeline.addLast("loggingHandler", nettyLogger.getLoggingHandler());
}
+ if (sslEncryptionEnabled()) {
+ if (!isClient && !sslFactory.hasKeyManagers()) {
+ throw new IllegalStateException("Not a client connection and no keys
configured");
+ }
+
+ SslHandler sslHandler;
+ try {
+ sslHandler = new SslHandler(sslFactory.createSSLEngine(isClient,
channel.alloc()));
+ } catch (Exception e) {
+ throw new IllegalStateException("Error creating Netty SslHandler",
e);
+ }
+ pipeline.addFirst("NettySslEncryptionHandler", sslHandler);
+ // Cannot use zero-copy with HTTPS, so we add in our
ChunkedWriteHandler just before the
+ // MessageEncoder
+ pipeline.addLast("chunkedWriter", new ChunkedWriteHandler());
+ }
+
if (channelsLimiter != null) {
pipeline.addLast("limiter", channelsLimiter);
}
TransportChannelHandler channelHandler = createChannelHandler(channel,
resolvedMsgHandler);
pipeline
- .addLast("encoder", ENCODER)
+ .addLast("encoder", sslEncryptionEnabled() ? SSL_ENCODER : ENCODER)
.addLast(FrameDecoder.HANDLER_NAME, decoder)
.addLast(
"idleStateHandler",
@@ -174,6 +207,35 @@ public class TransportContext {
}
}
+  /**
+   * Builds the {@link SSLFactory} for this context's module from {@code conf}.
+   *
+   * @return the configured factory, or {@code null} when SSL is disabled for the module
+   * @throws IllegalArgumentException if SSL is enabled but {@code
+   *     conf.sslEnabledAndKeysAreValid()} reports the configured keys as unusable
+   */
+  private SSLFactory createSslFactory() {
+    if (!conf.sslEnabled()) {
+      return null;
+    }
+    if (!conf.sslEnabledAndKeysAreValid()) {
+      logger.error(
+          "SSL encryption enabled but keys not found for "
+              + conf.getModuleName()
+              + "! Please ensure the configured keys are present.");
+      throw new IllegalArgumentException(
+          "SSL encryption enabled for " + conf.getModuleName() + " but keys not found!");
+    }
+    return new SSLFactory.Builder()
+        .requestedProtocol(conf.sslProtocol())
+        .requestedCiphers(conf.sslRequestedCiphers())
+        .keyStore(conf.sslKeyStore(), conf.sslKeyStorePassword())
+        .trustStore(
+            conf.sslTrustStore(),
+            conf.sslTrustStorePassword(),
+            conf.sslTrustStoreReloadingEnabled(),
+            conf.sslTrustStoreReloadIntervalMs())
+        .build();
+  }
+
private TransportChannelHandler createChannelHandler(
Channel channel, BaseMessageHandler msgHandler) {
TransportResponseHandler responseHandler = new
TransportResponseHandler(conf, channel);
@@ -198,4 +260,11 @@ public class TransportContext {
public BaseMessageHandler getMsgHandler() {
return msgHandler;
}
+
+  /** Releases resources held by the SSL factory when SSL is enabled; otherwise does nothing. */
+  @Override
+  public void close() {
+    if (sslFactory == null) {
+      return;
+    }
+    sslFactory.destroy();
+  }
}
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java
index f51d04a57..26191be26 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java
@@ -34,6 +34,9 @@ import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.*;
import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.ssl.SslHandler;
+import io.netty.util.concurrent.Future;
+import io.netty.util.concurrent.GenericFutureListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -81,6 +84,7 @@ public class TransportClientFactory implements Closeable {
private final int numConnectionsPerPeer;
private final int connectTimeoutMs;
+ private final int connectionTimeoutMs;
private final int receiveBuf;
@@ -97,6 +101,7 @@ public class TransportClientFactory implements Closeable {
this.connectionPool = JavaUtils.newConcurrentHashMap();
this.numConnectionsPerPeer = conf.numConnectionsPerPeer();
this.connectTimeoutMs = conf.connectTimeoutMs();
+ this.connectionTimeoutMs = conf.connectionTimeoutMs();
this.receiveBuf = conf.receiveBuf();
this.sendBuf = conf.sendBuf();
this.rand = new Random();
@@ -237,7 +242,7 @@ public class TransportClientFactory implements Closeable {
new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel ch) {
- TransportChannelHandler clientHandler =
context.initializePipeline(ch, decoder);
+ TransportChannelHandler clientHandler =
context.initializePipeline(ch, decoder, true);
clientRef.set(clientHandler.getClient());
channelRef.set(ch);
}
@@ -252,6 +257,32 @@ public class TransportClientFactory implements Closeable {
} else if (cf.cause() != null) {
throw new CelebornIOException(String.format("Failed to connect to %s",
address), cf.cause());
}
+ if (context.sslEncryptionEnabled()) {
+ final SslHandler sslHandler =
cf.channel().pipeline().get(SslHandler.class);
+ Future<Channel> future =
+ sslHandler
+ .handshakeFuture()
+ .addListener(
+ new GenericFutureListener<Future<Channel>>() {
+ @Override
+ public void operationComplete(final Future<Channel>
handshakeFuture) {
+ if (handshakeFuture.isSuccess()) {
+ logger.debug("successfully completed TLS handshake to
{}", address);
+ } else {
+ logger.info(
+ "failed to complete TLS handshake to {}",
+ address,
+ handshakeFuture.cause());
+ cf.channel().close();
+ }
+ }
+ });
+ if (!future.await(connectionTimeoutMs)) {
+ cf.channel().close();
+ throw new IOException(
+ String.format("Failed to connect to %s within connection timeout",
address));
+ }
+ }
TransportClient client = clientRef.get();
assert client != null : "Channel future completed successfully with null
client";
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/server/TransportServer.java
b/common/src/main/java/org/apache/celeborn/common/network/server/TransportServer.java
index 53f7b7169..d849327df 100644
---
a/common/src/main/java/org/apache/celeborn/common/network/server/TransportServer.java
+++
b/common/src/main/java/org/apache/celeborn/common/network/server/TransportServer.java
@@ -142,7 +142,7 @@ public class TransportServer implements Closeable {
"Adding bootstrap to TransportServer {}.",
bootstrap.getClass().getName());
baseHandler = bootstrap.doBootstrap(ch, baseHandler);
}
- context.initializePipeline(ch, baseHandler);
+ context.initializePipeline(ch, baseHandler, false);
}
});
}
diff --git
a/common/src/main/java/org/apache/celeborn/common/network/ssl/SSLFactory.java
b/common/src/main/java/org/apache/celeborn/common/network/ssl/SSLFactory.java
new file mode 100644
index 000000000..443db6a9c
--- /dev/null
+++
b/common/src/main/java/org/apache/celeborn/common/network/ssl/SSLFactory.java
@@ -0,0 +1,393 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.ssl;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.security.GeneralSecurityException;
+import java.security.KeyStore;
+import java.security.KeyStoreException;
+import java.security.NoSuchAlgorithmException;
+import java.security.UnrecoverableKeyException;
+import java.security.cert.CertificateException;
+import java.security.cert.X509Certificate;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import javax.net.ssl.KeyManager;
+import javax.net.ssl.KeyManagerFactory;
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLEngine;
+import javax.net.ssl.TrustManager;
+import javax.net.ssl.TrustManagerFactory;
+import javax.net.ssl.X509TrustManager;
+
+import com.google.common.io.Files;
+import io.netty.buffer.ByteBufAllocator;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.common.util.JavaUtils;
+
+/**
+ * SSLFactory to initialize and configure use of JSSE for SSL in Celeborn.
+ *
+ * <p>Instances are created through {@link Builder}. Use {@link #hasKeyManagers()} to check
+ * whether a key store was configured before creating server-side engines. If no trust store is
+ * configured, or the configured trust store file does not exist, every peer certificate is
+ * accepted without verification.
+ *
+ * <p>Note: code was initially copied from Apache Spark.
+ */
+public class SSLFactory {
+  private static final Logger logger = LoggerFactory.getLogger(SSLFactory.class);
+
+  /** For a configuration specifying keystore/truststore files */
+  private SSLContext jdkSslContext;
+
+  // Null when no key store was configured (or after destroy()).
+  private KeyManager[] keyManagers;
+  private TrustManager[] trustManagers;
+  private String requestedProtocol;
+  private String[] requestedCiphers;
+
+  private SSLFactory(final Builder b) {
+    this.requestedProtocol = b.requestedProtocol;
+    this.requestedCiphers = b.requestedCiphers;
+    try {
+      initJdkSslContext(b);
+    } catch (Exception e) {
+      throw new RuntimeException("SSLFactory creation failed", e);
+    }
+  }
+
+  /** Loads key/trust managers from the builder settings and initializes the JDK SSLContext. */
+  private void initJdkSslContext(final Builder b) throws IOException, GeneralSecurityException {
+    this.keyManagers =
+        null != b.keyStore ? keyManagers(b.keyStore, b.keyPassword, b.keyStorePassword) : null;
+    this.trustManagers =
+        trustStoreManagers(
+            b.trustStore,
+            b.trustStorePassword,
+            b.trustStoreReloadingEnabled,
+            b.trustStoreReloadIntervalMs);
+    this.jdkSslContext = createSSLContext(requestedProtocol, keyManagers, trustManagers);
+  }
+
+  /**
+   * @return {@code true} if a key store was configured for this factory and it has not been
+   *     destroyed
+   */
+  public boolean hasKeyManagers() {
+    return null != keyManagers;
+  }
+
+  /**
+   * Releases the state held by this factory, destroying any {@link ReloadingX509TrustManager}
+   * created for the trust store. Creating engines after this call is not supported.
+   *
+   * <p>If the calling thread is interrupted while a trust manager is being destroyed, the
+   * remaining trust managers are still destroyed and the thread's interrupt status is restored
+   * before returning.
+   */
+  public void destroy() {
+    if (trustManagers != null) {
+      boolean interrupted = false;
+      for (TrustManager trustManager : trustManagers) {
+        if (trustManager instanceof ReloadingX509TrustManager) {
+          try {
+            ((ReloadingX509TrustManager) trustManager).destroy();
+          } catch (InterruptedException ex) {
+            interrupted = true;
+            logger.info("Interrupted while destroying trust manager: {}", ex, ex);
+          }
+        }
+      }
+      trustManagers = null;
+      if (interrupted) {
+        // Re-assert the interrupt only after every trust manager had a chance to shut down, so
+        // callers further up the stack can still observe it.
+        Thread.currentThread().interrupt();
+      }
+    }
+
+    keyManagers = null;
+    jdkSslContext = null;
+    requestedProtocol = null;
+    requestedCiphers = null;
+  }
+
+  /** Builder class to construct instances of {@link SSLFactory} with specific options */
+  public static class Builder {
+    private String requestedProtocol;
+    private String[] requestedCiphers;
+    private File keyStore;
+    private String keyStorePassword;
+    private String keyPassword;
+    private File trustStore;
+    private String trustStorePassword;
+    private boolean trustStoreReloadingEnabled;
+    private int trustStoreReloadIntervalMs;
+
+    /**
+     * Sets the requested protocol, e.g., "TLSv1.2", "TLSv1.3", etc.
+     *
+     * @param requestedProtocol The requested protocol; {@code null} selects "TLSv1.3"
+     * @return The builder object
+     */
+    public Builder requestedProtocol(String requestedProtocol) {
+      this.requestedProtocol = requestedProtocol == null ? "TLSv1.3" : requestedProtocol;
+      return this;
+    }
+
+    /**
+     * Sets the requested cipher suites, e.g., "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", etc. A
+     * {@code null} or empty array selects the built-in default list.
+     *
+     * @param requestedCiphers The requested ciphers
+     * @return The builder object
+     */
+    public Builder requestedCiphers(String[] requestedCiphers) {
+      this.requestedCiphers = requestedCiphers;
+      return this;
+    }
+
+    /**
+     * Sets the Keystore and Keystore password
+     *
+     * @param keyStore The key store file to use; {@code null} means no key managers are created
+     * @param keyStorePassword The password for the key store
+     * @return The builder object
+     */
+    public Builder keyStore(File keyStore, String keyStorePassword) {
+      this.keyStore = keyStore;
+      this.keyStorePassword = keyStorePassword;
+      return this;
+    }
+
+    /**
+     * Sets the key password. If not set, the key store password is used for the private key.
+     *
+     * @param keyPassword The password for the private key in the key store
+     * @return The builder object
+     */
+    public Builder keyPassword(String keyPassword) {
+      this.keyPassword = keyPassword;
+      return this;
+    }
+
+    /**
+     * Sets the trust-store, trust-store password, whether to use a Reloading TrustStore, and the
+     * trust-store reload interval, if enabled
+     *
+     * @param trustStore The trust store file to use
+     * @param trustStorePassword The password for the trust store
+     * @param trustStoreReloadingEnabled Whether trust store reloading is enabled
+     * @param trustStoreReloadIntervalMs The interval at which to reload the trust store file
+     * @return The builder object
+     */
+    public Builder trustStore(
+        File trustStore,
+        String trustStorePassword,
+        boolean trustStoreReloadingEnabled,
+        int trustStoreReloadIntervalMs) {
+      this.trustStore = trustStore;
+      this.trustStorePassword = trustStorePassword;
+      this.trustStoreReloadingEnabled = trustStoreReloadingEnabled;
+      this.trustStoreReloadIntervalMs = trustStoreReloadIntervalMs;
+      return this;
+    }
+
+    /**
+     * Builds our {@link SSLFactory}
+     *
+     * @return The built {@link SSLFactory}
+     * @throws RuntimeException if key or trust material cannot be loaded
+     */
+    public SSLFactory build() {
+      return new SSLFactory(this);
+    }
+  }
+
+  /**
+   * Returns an initialized {@link SSLContext}
+   *
+   * @param requestedProtocol The requested protocol to use
+   * @param keyManagers The list of key managers to use (may be {@code null})
+   * @param trustManagers The list of trust managers to use
+   * @return The built {@link SSLContext}
+   * @throws GeneralSecurityException if the protocol is unknown or initialization fails
+   */
+  private static SSLContext createSSLContext(
+      String requestedProtocol, KeyManager[] keyManagers, TrustManager[] trustManagers)
+      throws GeneralSecurityException {
+    SSLContext sslContext = SSLContext.getInstance(requestedProtocol);
+    sslContext.init(keyManagers, trustManagers, null);
+    return sslContext;
+  }
+
+  /**
+   * Creates a new {@link SSLEngine} restricted to the requested protocol and cipher suites.
+   *
+   * <p>In server mode the engine asks the peer for a client certificate but does not require one
+   * ({@link SSLEngine#setWantClientAuth(boolean)}); mandatory client authentication is not
+   * supported.
+   *
+   * @param isClient Whether the engine is used in a client context
+   * @param allocator The {@link ByteBufAllocator} to use (currently unused)
+   * @return A valid {@link SSLEngine}.
+   */
+  public SSLEngine createSSLEngine(boolean isClient, ByteBufAllocator allocator) {
+    SSLEngine engine = createEngine(isClient, allocator);
+    engine.setUseClientMode(isClient);
+    engine.setWantClientAuth(true);
+    engine.setEnabledProtocols(enabledProtocols(engine, requestedProtocol));
+    engine.setEnabledCipherSuites(enabledCipherSuites(engine, requestedCiphers));
+    return engine;
+  }
+
+  // Both parameters are currently unused: the engine always comes from the JDK SSLContext.
+  private SSLEngine createEngine(boolean isClient, ByteBufAllocator allocator) {
+    return jdkSslContext.createSSLEngine();
+  }
+
+  private static final X509Certificate[] EMPTY_CERT_ARRAY = new X509Certificate[0];
+
+  /**
+   * Trust managers that accept every certificate chain without any verification. Only used when
+   * no trust store file is available.
+   */
+  private static TrustManager[] credulousTrustStoreManagers() {
+    return new TrustManager[] {
+      new X509TrustManager() {
+        @Override
+        public void checkClientTrusted(X509Certificate[] x509Certificates, String s)
+            throws CertificateException {}
+
+        @Override
+        public void checkServerTrusted(X509Certificate[] x509Certificates, String s)
+            throws CertificateException {}
+
+        @Override
+        public X509Certificate[] getAcceptedIssuers() {
+          return EMPTY_CERT_ARRAY;
+        }
+      }
+    };
+  }
+
+  /**
+   * Chooses the trust managers: trust-all when no trust store file is available, otherwise a
+   * reloading or a static manager backed by the trust store.
+   */
+  private static TrustManager[] trustStoreManagers(
+      File trustStore,
+      String trustStorePassword,
+      boolean trustStoreReloadingEnabled,
+      int trustStoreReloadIntervalMs)
+      throws IOException, GeneralSecurityException {
+    if (trustStore == null || !trustStore.exists()) {
+      if (trustStore != null) {
+        // A mistyped path must not silently disable peer verification without a trace.
+        logger.warn(
+            "Configured trust store {} does not exist; peer certificates will not be verified.",
+            trustStore.getAbsolutePath());
+      }
+      return credulousTrustStoreManagers();
+    } else {
+      if (trustStoreReloadingEnabled) {
+        ReloadingX509TrustManager reloading =
+            new ReloadingX509TrustManager(
+                KeyStore.getDefaultType(),
+                trustStore,
+                trustStorePassword,
+                trustStoreReloadIntervalMs);
+        reloading.init();
+        return new TrustManager[] {reloading};
+      } else {
+        return defaultTrustManagers(trustStore, trustStorePassword);
+      }
+    }
+  }
+
+  /** Builds trust managers from a trust store file loaded once at construction time. */
+  private static TrustManager[] defaultTrustManagers(File trustStore, String trustStorePassword)
+      throws IOException, KeyStoreException, CertificateException, NoSuchAlgorithmException {
+    try (InputStream input = Files.asByteSource(trustStore).openStream()) {
+      KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
+      char[] passwordCharacters =
+          trustStorePassword != null ? trustStorePassword.toCharArray() : null;
+      ks.load(input, passwordCharacters);
+      TrustManagerFactory tmf =
+          TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
+      tmf.init(ks);
+      return tmf.getTrustManagers();
+    }
+  }
+
+  /** Builds key managers; the key password falls back to the key store password when unset. */
+  private static KeyManager[] keyManagers(
+      File keyStore, String keyPassword, String keyStorePassword)
+      throws NoSuchAlgorithmException, CertificateException, KeyStoreException, IOException,
+          UnrecoverableKeyException {
+    KeyManagerFactory factory =
+        KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
+    char[] keyStorePasswordChars = keyStorePassword != null ? keyStorePassword.toCharArray() : null;
+    char[] keyPasswordChars =
+        keyPassword != null ? keyPassword.toCharArray() : keyStorePasswordChars;
+    factory.init(loadKeyStore(keyStore, keyStorePasswordChars), keyPasswordChars);
+    return factory.getKeyManagers();
+  }
+
+  private static KeyStore loadKeyStore(File keyStore, char[] keyStorePassword)
+      throws KeyStoreException, IOException, CertificateException, NoSuchAlgorithmException {
+    if (keyStore == null) {
+      throw new KeyStoreException(
+          "keyStore cannot be null. Please configure celeborn.ssl.<module>.keyStore");
+    }
+
+    KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
+    FileInputStream fin = new FileInputStream(keyStore);
+    try {
+      ks.load(fin, keyStorePassword);
+      return ks;
+    } finally {
+      JavaUtils.closeQuietly(fin);
+    }
+  }
+
+  /**
+   * Returns the requested protocol (or TLSv1.3/TLSv1.2 when none was requested) restricted to
+   * what the engine supports, falling back to every supported protocol if none match.
+   */
+  private static String[] enabledProtocols(SSLEngine engine, String requestedProtocol) {
+    String[] supportedProtocols = engine.getSupportedProtocols();
+    String[] defaultProtocols = {"TLSv1.3", "TLSv1.2"};
+    String[] enabledProtocols =
+        ((requestedProtocol == null || requestedProtocol.isEmpty())
+            ? defaultProtocols
+            : new String[] {requestedProtocol});
+
+    List<String> protocols = addIfSupported(supportedProtocols, enabledProtocols);
+    if (!protocols.isEmpty()) {
+      return protocols.toArray(new String[protocols.size()]);
+    } else {
+      return supportedProtocols;
+    }
+  }
+
+  /**
+   * Returns the requested ciphers (or the built-in list when none were requested) restricted to
+   * the supported ones, falling back to {@code defaultCiphers} if none match.
+   */
+  private static String[] enabledCipherSuites(
+      String[] supportedCiphers, String[] defaultCiphers, String[] requestedCiphers) {
+    String[] baseCiphers =
+        new String[] {
+          // We take ciphers from the mozilla modern list first (for TLS 1.3):
+          // https://wiki.mozilla.org/Security/Server_Side_TLS
+          "TLS_CHACHA20_POLY1305_SHA256",
+          "TLS_AES_128_GCM_SHA256",
+          "TLS_AES_256_GCM_SHA384",
+          // Next we have the TLS1.2 ciphers for intermediate compatibility (older JDK 8 builds
+          // do not support TLS1.3)
+          "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
+          "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
+          "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
+          "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
+          "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
+          "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
+          "TLS_DHE_RSA_WITH_AES_128_GCM_SHA256",
+          "TLS_DHE_RSA_WITH_AES_256_GCM_SHA384"
+        };
+    String[] enabledCiphers =
+        ((requestedCiphers == null || requestedCiphers.length == 0)
+            ? baseCiphers
+            : requestedCiphers);
+
+    List<String> ciphers = addIfSupported(supportedCiphers, enabledCiphers);
+    if (!ciphers.isEmpty()) {
+      return ciphers.toArray(new String[ciphers.size()]);
+    } else {
+      // Use the default from JDK as fallback.
+      return defaultCiphers;
+    }
+  }
+
+  private static String[] enabledCipherSuites(SSLEngine engine, String[] requestedCiphers) {
+    return enabledCipherSuites(
+        engine.getSupportedCipherSuites(), engine.getEnabledCipherSuites(), requestedCiphers);
+  }
+
+  /** Returns {@code names} filtered to those present in {@code supported}, preserving order. */
+  private static List<String> addIfSupported(String[] supported, String... names) {
+    List<String> enabled = new ArrayList<>();
+    Set<String> supportedSet = new HashSet<>(Arrays.asList(supported));
+    for (String n : names) {
+      if (supportedSet.contains(n)) {
+        enabled.add(n);
+      }
+    }
+    return enabled;
+  }
+}
diff --git
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala
index b0f470efe..151a79104 100644
---
a/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala
+++
b/common/src/main/scala/org/apache/celeborn/common/rpc/netty/NettyRpcEnv.scala
@@ -59,7 +59,8 @@ class NettyRpcEnv(
private var worker: RpcEndpoint = null
- private val transportContext =
+ // Visible for tests
+ private[netty] val transportContext =
new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this))
private def createClientBootstraps():
java.util.List[TransportClientBootstrap] = {
@@ -342,6 +343,9 @@ class NettyRpcEnv(
if (clientFactory != null) {
clientFactory.close()
}
+ if (null != transportContext) {
+ transportContext.close();
+ }
if (clientConnectionExecutor != null) {
clientConnectionExecutor.shutdownNow()
}
diff --git
a/common/src/test/java/org/apache/celeborn/common/network/RpcIntegrationSuiteJ.java
b/common/src/test/java/org/apache/celeborn/common/network/RpcIntegrationSuiteJ.java
index ec1834b4d..76cc0dc2c 100644
---
a/common/src/test/java/org/apache/celeborn/common/network/RpcIntegrationSuiteJ.java
+++
b/common/src/test/java/org/apache/celeborn/common/network/RpcIntegrationSuiteJ.java
@@ -43,7 +43,9 @@ import org.apache.celeborn.common.network.util.TransportConf;
import org.apache.celeborn.common.util.JavaUtils;
public class RpcIntegrationSuiteJ {
+ static final String TEST_MODULE = "shuffle";
static TransportConf conf;
+ static TransportContext context;
static TransportServer server;
static TransportClientFactory clientFactory;
static BaseMessageHandler handler;
@@ -52,7 +54,11 @@ public class RpcIntegrationSuiteJ {
+ /** Builds the shared fixtures from a default (non-SSL) configuration. */
@BeforeClass
public static void setUp() throws Exception {
- conf = new TransportConf("shuffle", new CelebornConf());
+ initialize(new CelebornConf());
+ }
+
+ static void initialize(CelebornConf celebornConf) throws Exception {
+ conf = new TransportConf(TEST_MODULE, celebornConf);
testData = new StreamTestHelper();
handler =
new BaseMessageHandler() {
@@ -91,7 +97,7 @@ public class RpcIntegrationSuiteJ {
return true;
}
};
- TransportContext context = new TransportContext(conf, handler);
+ context = new TransportContext(conf, handler);
server = context.createServer();
clientFactory = context.createClientFactory();
oneWayMsgs = new ArrayList<>();
@@ -101,6 +107,7 @@ public class RpcIntegrationSuiteJ {
+ // Closes the server and client factory, then the transport context (which releases any SSL
+ // factory it created), and finally removes the stream test data.
public static void tearDown() {
server.close();
clientFactory.close();
+ context.close();
testData.cleanup();
}
diff --git
a/common/src/test/java/org/apache/celeborn/common/network/SSLRpcIntegrationSuiteJ.java
b/common/src/test/java/org/apache/celeborn/common/network/SSLRpcIntegrationSuiteJ.java
new file mode 100644
index 000000000..5c77f995a
--- /dev/null
+++
b/common/src/test/java/org/apache/celeborn/common/network/SSLRpcIntegrationSuiteJ.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network;
+
+import static org.junit.Assert.assertTrue;
+
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.network.ssl.SslSampleConfigs;
+
+public class SSLRpcIntegrationSuiteJ extends RpcIntegrationSuiteJ {
+
+ @BeforeClass
+ public static void setUp() throws Exception {
+ // set up SSL for TEST_MODULE
+ RpcIntegrationSuiteJ.initialize(
+ TestHelper.updateCelebornConfWithMap(
+ new CelebornConf(),
SslSampleConfigs.createDefaultConfigMapForModule(TEST_MODULE)));
+ }
+
+ @AfterClass
+ public static void tearDown() {
+ RpcIntegrationSuiteJ.tearDown();
+ }
+
+ @Test
+ public void validateSslConfig() {
+ // this is to ensure ssl config has been applied.
+ assertTrue(conf.sslEnabled());
+ }
+}
diff --git
a/common/src/test/java/org/apache/celeborn/common/network/SSLTransportClientFactorySuiteJ.java
b/common/src/test/java/org/apache/celeborn/common/network/SSLTransportClientFactorySuiteJ.java
new file mode 100644
index 000000000..8c4884373
--- /dev/null
+++
b/common/src/test/java/org/apache/celeborn/common/network/SSLTransportClientFactorySuiteJ.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network;
+
+import static org.junit.Assert.assertTrue;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.network.ssl.SslSampleConfigs;
+
+public class SSLTransportClientFactorySuiteJ extends
TransportClientFactorySuiteJ {
+
+ @Before
+ public void setUp() {
+ // set up SSL for TEST_MODULE
+ doSetup(
+ TestHelper.updateCelebornConfWithMap(
+ new CelebornConf(),
SslSampleConfigs.createDefaultConfigMapForModule(TEST_MODULE)));
+ }
+
+ @After
+ public void tearDown() {
+ super.tearDown();
+ }
+
+ @Test
+ public void validateSslConfig() {
+ // this is to ensure ssl config has been applied.
+ assertTrue(getTransportContextConf().sslEnabled());
+ }
+}
diff --git
a/common/src/test/java/org/apache/celeborn/common/network/TransportClientFactorySuiteJ.java
b/common/src/test/java/org/apache/celeborn/common/network/TransportClientFactorySuiteJ.java
index 022af8178..26a9b4885 100644
---
a/common/src/test/java/org/apache/celeborn/common/network/TransportClientFactorySuiteJ.java
+++
b/common/src/test/java/org/apache/celeborn/common/network/TransportClientFactorySuiteJ.java
@@ -37,23 +37,36 @@ import
org.apache.celeborn.common.network.util.TransportConf;
import org.apache.celeborn.common.util.JavaUtils;
public class TransportClientFactorySuiteJ {
+
+ static final String TEST_MODULE = "shuffle";
+
private TransportContext context;
private TransportServer server1;
private TransportServer server2;
- @Before
- public void setUp() {
- TransportConf conf = new TransportConf("shuffle", new CelebornConf());
+ protected void doSetup(CelebornConf celebornConf) {
+ TransportConf conf = new TransportConf(TEST_MODULE, celebornConf);
BaseMessageHandler handler = new BaseMessageHandler();
context = new TransportContext(conf, handler);
server1 = context.createServer();
server2 = context.createServer();
}
+ @Before
+ public void setUp() {
+ doSetup(new CelebornConf());
+ }
+
+ // for validation in subclasses
+ TransportConf getTransportContextConf() {
+ return context.getConf();
+ }
+
@After
public void tearDown() {
JavaUtils.closeQuietly(server1);
JavaUtils.closeQuietly(server2);
+ JavaUtils.closeQuietly(context);
}
/**
@@ -68,7 +81,7 @@ public class TransportClientFactorySuiteJ {
CelebornConf _conf = new CelebornConf();
_conf.set("celeborn.shuffle.io.numConnectionsPerPeer",
Integer.toString(maxConnections));
- TransportConf conf = new TransportConf("shuffle", _conf);
+ TransportConf conf = new TransportConf(TEST_MODULE, _conf);
BaseMessageHandler handler = new BaseMessageHandler();
TransportContext context = new TransportContext(conf, handler);
@@ -114,6 +127,7 @@ public class TransportClientFactorySuiteJ {
}
factory.close();
+ context.close();
}
@Test
@@ -177,7 +191,7 @@ public class TransportClientFactorySuiteJ {
public void closeIdleConnectionForRequestTimeOut() throws IOException,
InterruptedException {
CelebornConf _conf = new CelebornConf();
_conf.set("celeborn.shuffle.io.connectionTimeout", "1s");
- TransportConf conf = new TransportConf("shuffle", _conf);
+ TransportConf conf = new TransportConf(TEST_MODULE, _conf);
TransportContext context = new TransportContext(conf, new
BaseMessageHandler(), true);
try (TransportClientFactory factory = context.createClientFactory()) {
TransportClient c1 = factory.createClient(getLocalHost(),
server1.getPort());
@@ -188,6 +202,7 @@ public class TransportClientFactorySuiteJ {
}
assertFalse(c1.isActive());
}
+ context.close();
}
@Test(expected = IOException.class)
diff --git
a/common/src/test/java/org/apache/celeborn/common/network/sasl/SaslTestBase.java
b/common/src/test/java/org/apache/celeborn/common/network/sasl/SaslTestBase.java
index c6ca51831..6a34adb08 100644
---
a/common/src/test/java/org/apache/celeborn/common/network/sasl/SaslTestBase.java
+++
b/common/src/test/java/org/apache/celeborn/common/network/sasl/SaslTestBase.java
@@ -137,6 +137,9 @@ public class SaslTestBase {
if (server != null) {
server.close();
}
+ if (null != ctx) {
+ ctx.close();
+ }
}
}
}
diff --git
a/common/src/test/java/org/apache/celeborn/common/network/ssl/SslConnectivitySuiteJ.java
b/common/src/test/java/org/apache/celeborn/common/network/ssl/SslConnectivitySuiteJ.java
new file mode 100644
index 000000000..5f8b1902c
--- /dev/null
+++
b/common/src/test/java/org/apache/celeborn/common/network/ssl/SslConnectivitySuiteJ.java
@@ -0,0 +1,341 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.ssl;
+
+import static org.apache.celeborn.common.util.JavaUtils.getLocalHost;
+import static org.junit.Assert.*;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Map;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.junit.Test;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.network.TestHelper;
+import org.apache.celeborn.common.network.TransportContext;
+import org.apache.celeborn.common.network.client.RpcResponseCallback;
+import org.apache.celeborn.common.network.client.TransportClient;
+import org.apache.celeborn.common.network.client.TransportClientFactory;
+import org.apache.celeborn.common.network.protocol.RequestMessage;
+import org.apache.celeborn.common.network.server.BaseMessageHandler;
+import org.apache.celeborn.common.network.server.TransportServer;
+import org.apache.celeborn.common.network.util.TransportConf;
+import org.apache.celeborn.common.util.JavaUtils;
+
+/**
+ * A few negative tests to ensure that a non-SSL client can't talk to an SSL
server and vice versa.
+ * Also, a few tests to ensure non-SSL clients and servers can talk to each
other, and that SSL clients and
+ * servers can also talk to each other.
+ */
+public class SslConnectivitySuiteJ {
+
+ private static final String TEST_MODULE = "rpc";
+
+ private static final String RESPONSE_PREFIX = "Test-prefix...";
+ private static final String RESPONSE_SUFFIX = "...Suffix";
+
+ private static final TestBaseMessageHandler DEFAULT_HANDLER = new
TestBaseMessageHandler();
+
+ private static TransportConf createTransportConf(
+ String module,
+ boolean enableSsl,
+ boolean useDefault,
+ Function<CelebornConf, CelebornConf> postProcessConf) {
+
+ CelebornConf celebornConf = new CelebornConf();
+ // in case the default gets flipped to true in future
+ celebornConf.set("celeborn.ssl." + module + ".enabled", "false");
+ if (enableSsl) {
+ Map<String, String> configMap;
+ if (useDefault) {
+ configMap = SslSampleConfigs.createDefaultConfigMapForModule(module);
+ } else {
+ configMap = SslSampleConfigs.createAnotherConfigMapForModule(module);
+ }
+ TestHelper.updateCelebornConfWithMap(celebornConf, configMap);
+ }
+
+ celebornConf = postProcessConf.apply(celebornConf);
+
+ TransportConf conf = new TransportConf(module, celebornConf);
+ assertEquals(enableSsl, conf.sslEnabled());
+
+ return conf;
+ }
+
+ private static class TestTransportState implements Closeable {
+ final TransportContext serverContext;
+ final TransportContext clientContext;
+ final TransportServer server;
+ final TransportClientFactory clientFactory;
+
+ TestTransportState(
+ BaseMessageHandler handler, TransportConf serverConf, TransportConf
clientConf) {
+ this.serverContext = new TransportContext(serverConf, handler);
+ this.clientContext = new TransportContext(clientConf, handler);
+
+ this.server = serverContext.createServer();
+ this.clientFactory = clientContext.createClientFactory();
+ }
+
+ TransportClient createClient() throws IOException, InterruptedException {
+ return clientFactory.createClient(getLocalHost(), server.getPort());
+ }
+
+ @Override
+ public void close() throws IOException {
+ JavaUtils.closeQuietly(server);
+ JavaUtils.closeQuietly(clientFactory);
+ JavaUtils.closeQuietly(serverContext);
+ JavaUtils.closeQuietly(clientContext);
+ }
+ }
+
+ private static class TestBaseMessageHandler extends BaseMessageHandler {
+
+ @Override
+ public void receive(
+ TransportClient client, RequestMessage requestMessage,
RpcResponseCallback callback) {
+ try {
+ String msg =
JavaUtils.bytesToString(requestMessage.body().nioByteBuffer());
+ String response = RESPONSE_PREFIX + msg + RESPONSE_SUFFIX;
+ callback.onSuccess(JavaUtils.stringToBytes(response));
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public boolean checkRegistered() {
+ return true;
+ }
+ }
+
+ // Pair<Success, Failure> ... should be an Either actually.
+
+ private Pair<String, String> sendRPC(TransportClient client, String message,
boolean canTimeout)
+ throws Exception {
+ final Semaphore sem = new Semaphore(0);
+
+ final AtomicReference<String> response = new AtomicReference<>(null);
+ final AtomicReference<String> errorResponse = new AtomicReference<>(null);
+
+ RpcResponseCallback callback =
+ new RpcResponseCallback() {
+ @Override
+ public void onSuccess(ByteBuffer message) {
+ String res = JavaUtils.bytesToString(message);
+ response.set(res);
+ sem.release();
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ errorResponse.set(e.getMessage());
+ sem.release();
+ }
+ };
+
+ client.sendRpc(JavaUtils.stringToBytes(message), callback);
+
+ if (!sem.tryAcquire(1, 5, TimeUnit.SECONDS)) {
+ if (canTimeout) {
+ throw new IOException("Timed out sending rpc message");
+ } else {
+ fail("Timeout getting response from the server");
+ }
+ }
+
+ return Pair.of(response.get(), errorResponse.get());
+ }
+
+ // A basic validation test to check if a non-SSL client and a non-SSL server can
talk to each other.
+ // This not only validates the non-SSL flow, but also verifies that
the negative
+ // tests are not failing because of some unrelated error.
+ @Test
+ public void testNormalConnectivityWorks() throws Exception {
+ testSuccessfulConnectivity(
+ false,
+ // does not matter what these are set to
+ true,
+ true,
+ Function.identity(),
+ Function.identity());
+ }
+
+ // Both server and client are on SSL, and should be able to successfully
communicate
+ // This is the SSL version of testNormalConnectivityWorks above.
+ @Test
+ public void testBasicSslClientConnectivityWorks() throws Exception {
+
+ final Function<CelebornConf, CelebornConf> updateClientConf =
+ conf -> {
+ // ignore the incoming conf, and return a new one which only has ssl enabled
= true
+ // This is essentially testing two things:
+ // a) client can talk SSL to a server, and does not need anything
else to be specified
+ // b) when no truststore is configured at the client, it trusts all
server certs
+ CelebornConf newConf = new CelebornConf();
+ newConf.set("celeborn.ssl." + TEST_MODULE + ".enabled", "true");
+ return newConf;
+ };
+
+ // primaryConfigForClient param does not matter - we are completely
overriding it above
+ // in updateClientConf. Adding both just for completeness' sake.
+ testSuccessfulConnectivity(true, true, true, Function.identity(),
updateClientConf);
+ testSuccessfulConnectivity(true, true, false, Function.identity(),
updateClientConf);
+
+ // just for validation, this should fail (ssl for client, plain for
server) ...
+ testConnectivityFailure(false, true, false, false, Function.identity(),
updateClientConf);
+ }
+
+ // Only SSL client can talk to a SSL server.
+ @Test
+ public void testSslServerNormalClientFails() throws Exception {
+ // Will fail for both primary and secondary jks (primaryConfigForServer) -
adding
+ // both just for completeness
+ testConnectivityFailure(true, false, true, true, Function.identity(),
Function.identity());
+ testConnectivityFailure(true, false, false, true, Function.identity(),
Function.identity());
+ }
+
+ // Only a non-SSL client can talk to a non-SSL server
+ @Test
+ public void testSslClientNormalServerFails() throws Exception {
+ // Will fail for both values of primaryConfigForClient - adding both just for
completeness
+ testConnectivityFailure(false, true, false, false, Function.identity(),
Function.identity());
+ testConnectivityFailure(false, true, false, true, Function.identity(),
Function.identity());
+ }
+
+ @Test
+ public void testUntrustedServerCertFails() throws Exception {
+ final String trustStoreKey = "celeborn.ssl." + TEST_MODULE + ".trustStore";
+
+ final Function<CelebornConf, CelebornConf> updateConf =
+ conf -> {
+ assertTrue(conf.getOption(trustStoreKey).isDefined());
+ conf.set(trustStoreKey, SslSampleConfigs.TRUST_STORE_WITHOUT_CA);
+ return conf;
+ };
+
+ // will fail for all combinations - since we don't have the CAs in the
truststore
+ testConnectivityFailure(true, true, true, false, updateConf, updateConf);
+ testConnectivityFailure(true, true, true, true, updateConf, updateConf);
+ testConnectivityFailure(true, true, false, true, updateConf, updateConf);
+ testConnectivityFailure(true, true, false, false, updateConf, updateConf);
+ }
+
+ // This is a variant of testUntrustedServerCertFails - where the server does
not trust the
+ // client cert, and the client does provide a cert.
+ @Test
+ public void testUntrustedClientCertFails() throws Exception {
+ final String trustStoreKey = "celeborn.ssl." + TEST_MODULE + ".trustStore";
+
+ final Function<CelebornConf, CelebornConf> updateConf =
+ conf -> {
+ assertTrue(conf.getOption(trustStoreKey).isDefined());
+ conf.set(trustStoreKey, SslSampleConfigs.TRUST_STORE_WITHOUT_CA);
+ return conf;
+ };
+
+ // will fail for all combinations - since server does not have the client
cert's CA
+ // in its truststore
+ testConnectivityFailure(true, true, true, false, updateConf,
Function.identity());
+ testConnectivityFailure(true, true, true, true, updateConf,
Function.identity());
+ testConnectivityFailure(true, true, false, true, updateConf,
Function.identity());
+ testConnectivityFailure(true, true, false, false, updateConf,
Function.identity());
+ }
+
+ @Test
+ public void testUntrustedServerCertWorksIfTrustStoreDisabled() throws
Exception {
+ // Same as testUntrustedServerCertFails, but remove the truststore - which
should result in
+ // accepting all certs. The truststore is removed at both the client and the server side.
+
+ final String trustStoreKey = "celeborn.ssl." + TEST_MODULE + ".trustStore";
+
+ final Function<CelebornConf, CelebornConf> updateConf =
+ conf -> {
+ assertTrue(conf.getOption(trustStoreKey).isDefined());
+ conf.unset(trustStoreKey);
+ assertNull(new TransportConf(TEST_MODULE, conf).sslTrustStore());
+ return conf;
+ };
+
+ // checking enableSsl == false at both client and server does not make
sense in this context,
+ // as no SSL configs (and hence no truststore for updateConf to remove) are set when SSL is disabled.
+ testSuccessfulConnectivity(true, true, true, updateConf, updateConf);
+ testSuccessfulConnectivity(true, true, false, updateConf, updateConf);
+ testSuccessfulConnectivity(true, false, true, updateConf, updateConf);
+ testSuccessfulConnectivity(true, false, false, updateConf, updateConf);
+ }
+
+ private void testSuccessfulConnectivity(
+ boolean enableSsl,
+ boolean primaryConfigForServer,
+ boolean primaryConfigForClient,
+ Function<CelebornConf, CelebornConf> postProcessServerConf,
+ Function<CelebornConf, CelebornConf> postProcessClientConf)
+ throws Exception {
+ try (TestTransportState state =
+ new TestTransportState(
+ DEFAULT_HANDLER,
+ createTransportConf(
+ TEST_MODULE, enableSsl, primaryConfigForServer,
postProcessServerConf),
+ createTransportConf(
+ TEST_MODULE, enableSsl, primaryConfigForClient,
postProcessClientConf));
+ TransportClient client = state.createClient()) {
+
+ String msg = " hi ";
+ Pair<String, String> response = sendRPC(client, msg, false);
+ assertNotNull("Failed ? " + response.getRight(), response.getLeft());
+ assertNull(response.getRight());
+ assertEquals(RESPONSE_PREFIX + msg + RESPONSE_SUFFIX,
response.getLeft());
+ }
+ }
+
+ private void testConnectivityFailure(
+ boolean serverSsl,
+ boolean clientSsl,
+ boolean primaryConfigForServer,
+ boolean primaryConfigForClient,
+ Function<CelebornConf, CelebornConf> postProcessServerConf,
+ Function<CelebornConf, CelebornConf> postProcessClientConf)
+ throws Exception {
+ try (TestTransportState state =
+ new TestTransportState(
+ DEFAULT_HANDLER,
+ createTransportConf(
+ TEST_MODULE, serverSsl, primaryConfigForServer,
postProcessServerConf),
+ createTransportConf(
+ TEST_MODULE, clientSsl, primaryConfigForClient,
postProcessClientConf));
+ TransportClient client = state.createClient()) {
+
+ String msg = " hi ";
+ Pair<String, String> response = sendRPC(client, msg, true);
+ assertNull(response.getLeft());
+ assertNotNull(response.getRight());
+ } catch (IOException ioEx) {
+ // this is fine - expected to fail
+ }
+ }
+}
diff --git
a/common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java
b/common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java
index 653709a31..c33baf92a 100644
---
a/common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java
+++
b/common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java
@@ -62,7 +62,6 @@ public class SslSampleConfigs {
Map<String, String> confMap = new HashMap<>();
confMap.put("celeborn.ssl." + module + ".enabled", "true");
confMap.put("celeborn.ssl." + module + ".trustStoreReloadingEnabled",
"false");
- confMap.put("celeborn.ssl." + module + ".openSslEnabled", "false");
confMap.put("celeborn.ssl." + module + ".trustStoreReloadIntervalMs",
"10000");
if (forDefault) {
confMap.put("celeborn.ssl." + module + ".keyStore",
DEFAULT_KEY_STORE_PATH);
diff --git
a/common/src/test/scala/org/apache/celeborn/common/rpc/netty/SSLNettyRpcEnvSuite.scala
b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/SSLNettyRpcEnvSuite.scala
new file mode 100644
index 000000000..79de7f313
--- /dev/null
+++
b/common/src/test/scala/org/apache/celeborn/common/rpc/netty/SSLNettyRpcEnvSuite.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.rpc.netty
+
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.network.TestHelper
+import org.apache.celeborn.common.network.ssl.SslSampleConfigs
+import org.apache.celeborn.common.protocol.TransportModuleConstants
+
+class SSLNettyRpcEnvSuite extends NettyRpcEnvSuite {
+
+ override def createCelebornConf(): CelebornConf = {
+ val conf = super.createCelebornConf()
+ TestHelper.updateCelebornConfWithMap(
+ conf,
+
SslSampleConfigs.createDefaultConfigMapForModule(TransportModuleConstants.RPC_MODULE))
+ conf
+ }
+
+ test("verify rpc env is using SSL") {
+ env match {
+ case nettyRpcEnv: NettyRpcEnv =>
+ assert(nettyRpcEnv.transportContext.getConf.sslEnabled())
+ assert(nettyRpcEnv.transportContext.sslEncryptionEnabled())
+ case _ =>
+ throw new IllegalArgumentException("Expected NettyRpcEnv, found = " +
env)
+ }
+ }
+}
diff --git
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
index 5ec378060..88ad05c4e 100644
---
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
+++
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala
@@ -196,7 +196,7 @@ private[celeborn] class Worker(
}
val pushDataHandler = new PushDataHandler(workerSource)
- private val pushServer = {
+ private val (pushServerTransportContext, pushServer) = {
val closeIdleConnections = conf.workerCloseIdleConnections
val numThreads =
conf.workerPushIoThreads.getOrElse(storageManager.totalFlusherThread)
val transportConf =
@@ -210,11 +210,13 @@ private[celeborn] class Worker(
pushServerLimiter,
conf.workerPushHeartbeatEnabled,
workerSource)
- transportContext.createServer(conf.workerPushPort,
getServerBootstraps(transportConf))
+ (
+ transportContext,
+ transportContext.createServer(conf.workerPushPort,
getServerBootstraps(transportConf)))
}
val replicateHandler = new PushDataHandler(workerSource)
- val (replicateServer, replicateClientFactory) = {
+ val (replicateTransportContext, replicateServer, replicateClientFactory) = {
val closeIdleConnections = conf.workerCloseIdleConnections
val numThreads =
conf.workerReplicateIoThreads.getOrElse(storageManager.totalFlusherThread)
@@ -230,12 +232,13 @@ private[celeborn] class Worker(
false,
workerSource)
(
+ transportContext,
transportContext.createServer(conf.workerReplicatePort),
transportContext.createClientFactory())
}
var fetchHandler: FetchHandler = _
- private val fetchServer = {
+ private val (fetchServerTransportContext, fetchServer) = {
val closeIdleConnections = conf.workerCloseIdleConnections
val numThreads =
conf.workerFetchIoThreads.getOrElse(storageManager.totalFlusherThread)
val transportConf =
@@ -248,7 +251,9 @@ private[celeborn] class Worker(
closeIdleConnections,
conf.workerFetchHeartbeatEnabled,
workerSource)
- transportContext.createServer(conf.workerFetchPort,
getServerBootstraps(transportConf))
+ (
+ transportContext,
+ transportContext.createServer(conf.workerFetchPort,
getServerBootstraps(transportConf)))
}
private val pushPort = pushServer.getPort
@@ -551,6 +556,9 @@ private[celeborn] class Worker(
replicateServer.shutdown(exitKind)
fetchServer.shutdown(exitKind)
pushServer.shutdown(exitKind)
+ replicateTransportContext.close()
+ fetchServerTransportContext.close()
+ pushServerTransportContext.close()
metricsSystem.stop()
if (conf.internalPortEnabled) {
internalRpcEnvInUse.stop(internalRpcEndpointRef)
diff --git
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/RequestTimeoutIntegrationSuiteJ.java
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/RequestTimeoutIntegrationSuiteJ.java
index fb083d392..d9494c429 100644
---
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/RequestTimeoutIntegrationSuiteJ.java
+++
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/RequestTimeoutIntegrationSuiteJ.java
@@ -60,6 +60,8 @@ import
org.apache.celeborn.service.deploy.worker.storage.ChunkStreamManager;
*/
public class RequestTimeoutIntegrationSuiteJ {
+ static final String TEST_MODULE = "shuffle";
+
private TransportServer server;
private TransportClientFactory clientFactory;
@@ -68,11 +70,14 @@ public class RequestTimeoutIntegrationSuiteJ {
// A large timeout that "shouldn't happen", for the sake of faulty tests not
hanging forever.
private static final int FOREVER = 60 * 1000;
+ protected void doSetup(CelebornConf celebornConf) {
+ celebornConf.set("celeborn.shuffle.io.connectionTimeout", "2s");
+ conf = new TransportConf(TEST_MODULE, celebornConf);
+ }
+
@Before
public void setUp() throws Exception {
- CelebornConf _conf = new CelebornConf();
- _conf.set("celeborn.shuffle.io.connectionTimeout", "2s");
- conf = new TransportConf("shuffle", _conf);
+ doSetup(new CelebornConf());
}
@After
@@ -85,6 +90,11 @@ public class RequestTimeoutIntegrationSuiteJ {
}
}
+ // for validation in subclasses
+ protected TransportConf getConf() {
+ return this.conf;
+ }
+
// Basic suite: First request completes quickly, and second waits for longer
than network timeout.
@Test
public void timeoutInactiveRequests() throws Exception {
@@ -136,6 +146,7 @@ public class RequestTimeoutIntegrationSuiteJ {
callback1.latch.await(60, TimeUnit.SECONDS);
assertNotNull(callback1.failure);
assertTrue(callback1.failure instanceof IOException);
+ context.close();
semaphore.release();
}
@@ -195,6 +206,7 @@ public class RequestTimeoutIntegrationSuiteJ {
callback1.latch.await();
assertEquals(responseSize, callback1.successLength);
assertNull(callback1.failure);
+ context.close();
}
// The timeout is relative to the LAST request sent, which is kinda weird,
but still.
@@ -265,6 +277,7 @@ public class RequestTimeoutIntegrationSuiteJ {
callback1.latch.await(60, TimeUnit.SECONDS);
// failed at same time as previous
assertTrue(callback1.failure instanceof IOException);
+ context.close();
}
/**
diff --git
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/SSLRequestTimeoutIntegrationSuiteJ.java
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/SSLRequestTimeoutIntegrationSuiteJ.java
new file mode 100644
index 000000000..4037c6f1b
--- /dev/null
+++
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/SSLRequestTimeoutIntegrationSuiteJ.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.worker.network;
+
+import static org.junit.Assert.assertTrue;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.network.TestHelper;
+import org.apache.celeborn.common.network.ssl.SslSampleConfigs;
+
+public class SSLRequestTimeoutIntegrationSuiteJ extends
RequestTimeoutIntegrationSuiteJ {
+ @Before
+ public void setUp() {
+ // set up SSL for TEST_MODULE
+ doSetup(
+ TestHelper.updateCelebornConfWithMap(
+ new CelebornConf(),
SslSampleConfigs.createDefaultConfigMapForModule(TEST_MODULE)));
+ }
+
+ @After
+ public void tearDown() {
+ super.tearDown();
+ }
+
+ @Test
+ public void validateSslConfig() {
+ // this is to ensure ssl config has been applied.
+ assertTrue(super.getConf().sslEnabled());
+ }
+}
diff --git
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkFetchIntegrationSuiteJ.java
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkFetchIntegrationSuiteJ.java
index 36e12ca4c..f4eaf618e 100644
---
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkFetchIntegrationSuiteJ.java
+++
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkFetchIntegrationSuiteJ.java
@@ -55,10 +55,12 @@ import
org.apache.celeborn.common.network.util.TransportConf;
import org.apache.celeborn.common.protocol.PbChunkFetchRequest;
public class ChunkFetchIntegrationSuiteJ {
+ static final String TEST_MODULE = "shuffle";
static final long STREAM_ID = 1;
static final int BUFFER_CHUNK_INDEX = 0;
static final int FILE_CHUNK_INDEX = 1;
+ static TransportContext transportContext;
static TransportServer server;
static TransportClientFactory clientFactory;
static ChunkStreamManager chunkStreamManager;
@@ -69,6 +71,10 @@ public class ChunkFetchIntegrationSuiteJ {
@BeforeClass
public static void setUp() throws Exception {
+ initialize((new CelebornConf()));
+ }
+
+ static void initialize(CelebornConf celebornConf) throws Exception {
int bufSize = 100_000;
final ByteBuffer buf = ByteBuffer.allocate(bufSize);
for (int i = 0; i < bufSize; i++) {
@@ -90,7 +96,7 @@ public class ChunkFetchIntegrationSuiteJ {
Closeables.close(fp, shouldSuppressIOException);
}
- final TransportConf conf = new TransportConf("shuffle", new
CelebornConf());
+ final TransportConf conf = new TransportConf(TEST_MODULE, celebornConf);
fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10,
testFile.length() - 25);
chunkStreamManager =
@@ -144,9 +150,9 @@ public class ChunkFetchIntegrationSuiteJ {
return true;
}
};
- TransportContext context = new TransportContext(conf, handler);
- server = context.createServer();
- clientFactory = context.createClientFactory();
+ transportContext = new TransportContext(conf, handler);
+ server = transportContext.createServer();
+ clientFactory = transportContext.createClientFactory();
}
@AfterClass
@@ -154,6 +160,7 @@ public class ChunkFetchIntegrationSuiteJ {
bufferChunk.release();
server.close();
clientFactory.close();
+ transportContext.close();
testFile.delete();
}
@@ -205,6 +212,11 @@ public class ChunkFetchIntegrationSuiteJ {
return res;
}
+ // for subclasses to validate
+ TransportConf fetchTransportConf() {
+ return transportContext.getConf();
+ }
+
@Test
public void fetchBufferChunk() throws Exception {
FetchResult res = fetchChunks(Arrays.asList(BUFFER_CHUNK_INDEX));
diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ReducePartitionDataWriterSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ReducePartitionDataWriterSuiteJ.java
index 1271d2c45..6417e2aa0 100644
--- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ReducePartitionDataWriterSuiteJ.java
+++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ReducePartitionDataWriterSuiteJ.java
@@ -90,14 +90,13 @@ public class ReducePartitionDataWriterSuiteJ {
private static LocalFlusher localFlusher = null;
private static WorkerSource source = null;
- private static TransportServer server;
- private static TransportClientFactory clientFactory;
+ private TransportContext transportContext;
+ private TransportServer server;
+ private TransportClientFactory clientFactory;
private static long streamId;
private static int numChunks;
private final UserIdentifier userIdentifier = new
UserIdentifier("mock-tenantId", "mock-name");
- private static final TransportConf transConf = new TransportConf("shuffle",
new CelebornConf());
-
@BeforeClass
public static void beforeAll() {
tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"),
"celeborn");
@@ -138,7 +137,13 @@ public class ReducePartitionDataWriterSuiteJ {
MemoryManager.initialize(conf);
}
- public static void setupChunkServer(DiskFileInfo info) throws IOException {
+ protected TransportConf createModuleTransportConf(String module) {
+ return new TransportConf(module, new CelebornConf());
+ }
+
+ public void setupChunkServer(DiskFileInfo info) throws IOException {
+ TransportConf transConf = createModuleTransportConf("shuffle");
+
FetchHandler handler =
new FetchHandler(transConf.getCelebornConf(), transConf,
mock(WorkerSource.class)) {
@Override
@@ -166,10 +171,10 @@ public class ReducePartitionDataWriterSuiteJ {
.when(sorter)
.getSortedFileInfo(anyString(), anyString(), eq(info), anyInt(),
anyInt());
handler.setPartitionsSorter(sorter);
- TransportContext context = new TransportContext(transConf, handler);
- server = context.createServer();
+ transportContext = new TransportContext(transConf, handler);
+ server = transportContext.createServer();
- clientFactory = context.createClientFactory();
+ clientFactory = transportContext.createClientFactory();
}
@AfterClass
@@ -184,9 +189,10 @@ public class ReducePartitionDataWriterSuiteJ {
}
}
- public static void closeChunkServer() {
+ public void closeChunkServer() {
server.close();
clientFactory.close();
+ transportContext.close();
}
static class FetchResult {
diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/SSLChunkFetchIntegrationSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/SSLChunkFetchIntegrationSuiteJ.java
new file mode 100644
index 000000000..e3cc85494
--- /dev/null
+++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/SSLChunkFetchIntegrationSuiteJ.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.worker.storage;
+
+import static org.junit.Assert.assertTrue;
+
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.network.TestHelper;
+import org.apache.celeborn.common.network.ssl.SslSampleConfigs;
+
+/**
+ * Re-runs every {@link ChunkFetchIntegrationSuiteJ} test with TLS enabled for the {@code shuffle}
+ * transport module, using the sample SSL configuration from {@link SslSampleConfigs}.
+ */
+public class SSLChunkFetchIntegrationSuiteJ extends ChunkFetchIntegrationSuiteJ {
+
+  // Shadows the superclass @BeforeClass method (JUnit 4 then runs only this one), so the shared
+  // server and client factory are built from an SSL-enabled conf instead of the default one.
+  @BeforeClass
+  public static void setUp() throws Exception {
+    ChunkFetchIntegrationSuiteJ.initialize(
+        TestHelper.updateCelebornConfWithMap(
+            new CelebornConf(), SslSampleConfigs.createDefaultConfigMapForModule(TEST_MODULE)));
+  }
+
+  // Shadowing the superclass @AfterClass method suppresses it, so delegate to it explicitly.
+  @AfterClass
+  public static void tearDown() {
+    ChunkFetchIntegrationSuiteJ.tearDown();
+  }
+
+  @Test
+  public void validateSslConfig() {
+    // Ensures the SSL configuration was actually applied to the transport conf under test.
+    assertTrue(fetchTransportConf().sslEnabled());
+  }
+}
diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/SSLReducePartitionDataWriterSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/SSLReducePartitionDataWriterSuiteJ.java
new file mode 100644
index 000000000..e6e37ed51
--- /dev/null
+++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/SSLReducePartitionDataWriterSuiteJ.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.worker.storage;
+
+import static org.junit.Assert.assertTrue;
+
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.network.TestHelper;
+import org.apache.celeborn.common.network.ssl.SslSampleConfigs;
+import org.apache.celeborn.common.network.util.TransportConf;
+
+/**
+ * Re-runs every {@link ReducePartitionDataWriterSuiteJ} test with the chunk server and client
+ * built from a transport conf that has the sample SSL configuration applied.
+ */
+public class SSLReducePartitionDataWriterSuiteJ extends ReducePartitionDataWriterSuiteJ {
+
+  /**
+   * Creates the transport conf for {@code module}, applying the sample SSL settings for that
+   * module (via {@link TestHelper#updateCelebornConfWithMap}) to a default {@link CelebornConf}.
+   */
+  @Override
+  protected TransportConf createModuleTransportConf(String module) {
+    CelebornConf conf =
+        TestHelper.updateCelebornConfWithMap(
+            new CelebornConf(), SslSampleConfigs.createDefaultConfigMapForModule(module));
+    return new TransportConf(module, conf);
+  }
+
+  // Shadowing the superclass @BeforeClass / @AfterClass methods suppresses them in JUnit 4, so
+  // each one delegates to its superclass counterpart explicitly.
+  @BeforeClass
+  public static void beforeAll() {
+    ReducePartitionDataWriterSuiteJ.beforeAll();
+  }
+
+  @AfterClass
+  public static void afterAll() {
+    ReducePartitionDataWriterSuiteJ.afterAll();
+  }
+
+  @Test
+  public void validateSslConfig() {
+    // Ensures the override yields an SSL-enabled conf for the module the parent suite uses.
+    assertTrue(createModuleTransportConf("shuffle").sslEnabled());
+  }
+}